In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import (
    Module, 
    Linear, 
    Dropout, 
    LayerNorm, 
    ModuleList, 
    TransformerEncoder, 
    TransformerEncoderLayer
)

import pandas as pd

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

class UserItemDataset:
    def __init__(self, user_file, item_file, interaction_file):
        self.users = pd.read_csv(user_file)
        self.items = pd.read_csv(item_file)
        self.interactions = pd.read_csv(interaction_file)

        # Strip whitespace from column names
        self.users.columns = self.users.columns.str.strip()
        self.items.columns = self.items.columns.str.strip()
        self.interactions.columns = self.interactions.columns.str.strip()

    def __len__(self):
        # Return the number of interactions (or samples) in the dataset
        return len(self.interactions)

    def __getitem__(self, index):
        # Get the interaction data
        interaction = self.interactions.iloc[index]
        
        user_id = interaction['user_id']  # Adjust based on your actual column names
        item_id = interaction['item_id']  # This should be present in interactions
        
        # Find the corresponding label from the items DataFrame
        item_row = self.items[self.items['merchant_id'] == item_id]
        
        if not item_row.empty:
            label = item_row['label'].values[0]  # Get the label if it exists
        else:
            label = 0  # Default value if no label is found (or handle as needed)

        return user_id, item_id, label

class GraphTransformerV2(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_feedforward, input_dim, num_weights=10, use_weights=True, dropout=0.1):
        super(GraphTransformerV2, self).__init__()
        self.num_weights = num_weights
        self.use_weights = use_weights
        
        # Adjust input_linear to handle the concatenated user-item embedding
        self.input_linear = Linear(input_dim, d_model)
        
        self.encoder_layer = TransformerEncoderLayer(
            d_model=d_model, 
            nhead=num_heads, 
            dim_feedforward=d_feedforward, 
            dropout=dropout, 
            batch_first=True
        )
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.output_linear = Linear(d_model, input_dim)
        self.dropout = Dropout(dropout)
        self.layer_norm = LayerNorm(d_model)
        
        if self.use_weights:
            self.weight_linears = ModuleList([Linear(input_dim, d_model) for _ in range(num_weights)])

    def forward(self, x, adjacency_matrix, graph_metrics, weights=None):
        # Ensure adjacency_matrix is a FloatTensor
        adjacency_matrix = adjacency_matrix.float()
        
        # Ensure graph_metrics is a FloatTensor
        graph_metrics = graph_metrics.float()

        # Validate and potentially adjust dimensions
        batch_size, input_dim = x.shape
        
        # Ensure adjacency matrix is square and matches batch size
        if adjacency_matrix.size(0) != batch_size or adjacency_matrix.size(1) != batch_size:
            # Create an identity-like matrix if dimensions don't match
            adjacency_matrix = torch.eye(batch_size, device=x.device)

        try:
            # Direct Connections
            direct_scores = adjacency_matrix @ x  # Matrix multiplication to get direct connection scores

            # Neighborhood Similarity (modified to handle potential dimension issues)
            try:
                neighborhood_similarity = self.compute_neighborhood_similarity(adjacency_matrix, x)
            except RuntimeError:
                # Fallback: use a simplified similarity if computation fails
                neighborhood_similarity = torch.zeros_like(x)

            # Graph Structure Scores - modify to handle 2D graph metrics
            if graph_metrics.dim() == 2:
                # Project graph metrics to match input dimensions
                graph_metrics_projected = self.project_graph_metrics(graph_metrics, input_dim)
                graph_structure_scores = graph_metrics_projected * x  # Element-wise multiplication instead of matrix multiplication
            else:
                graph_structure_scores = torch.zeros_like(x)

            # Combine DNG scores
            dng_scores = direct_scores + neighborhood_similarity + graph_structure_scores

            # Optional weighted processing
            if self.use_weights and weights is not None:
                weighted_x = torch.zeros_like(x)
                for i, weight in enumerate(weights.T):
                    weighted_x += self.weight_linears[i](x) * weight.unsqueeze(1)
                x = weighted_x
            else:
                x = self.input_linear(x)

            x = self.layer_norm(x)
            x = self.transformer_encoder(x.unsqueeze(1)).squeeze(1)  # Adjust for transformer input
            x = self.output_linear(x)
            x = self.dropout(x)

            # Combine with DNG scores
            final_scores = F.relu(x + dng_scores)
            return final_scores

        except RuntimeError as e:
            print(f"RuntimeError during forward pass: {e}")
            print(f"x shape: {x.shape}, adjacency_matrix shape: {adjacency_matrix.shape}, graph_metrics shape: {graph_metrics.shape}")
            raise

    def project_graph_metrics(self, graph_metrics, target_dim):
        """
        Project graph metrics to match target dimension
        
        Args:
        - graph_metrics: Tensor of shape [batch_size, num_metrics]
        - target_dim: Desired output dimension
        
        Returns:
        - Projected tensor of shape [batch_size, target_dim]
        """
        # If graph_metrics has fewer dimensions than target, repeat or expand
        if graph_metrics.size(1) < target_dim:
            # Repeat the metrics to fill the target dimension
            repeats = (target_dim + graph_metrics.size(1) - 1) // graph_metrics.size(1)
            graph_metrics = graph_metrics.repeat(1, repeats)[:, :target_dim]
        elif graph_metrics.size(1) > target_dim:
            # Truncate if too many metrics
            graph_metrics = graph_metrics[:, :target_dim]
        
        return graph_metrics

    def compute_neighborhood_similarity(self, adjacency_matrix, x):
        # Robust Jaccard similarity computation
        try:
            # Ensure adjacency matrix is binary
            binary_adj = (adjacency_matrix > 0).float()
            
            # Compute intersection
            intersection = binary_adj @ binary_adj.T
            
            # Compute row and column sums
            row_sums = binary_adj.sum(dim=1, keepdim=True)
            col_sums = binary_adj.sum(dim=0, keepdim=True)
            
            # Compute union
            union = row_sums + col_sums.T - intersection
            
            # Compute similarity with small epsilon to avoid division by zero
            similarity = intersection / (union + 1e-8)
            
            # Matrix multiplication with input
            return similarity @ x
        
        except RuntimeError:
            # Fallback to a simple similarity if computation fails
            return torch.zeros_like(x)

            

    def forward(self, x, adjacency_matrix, graph_metrics, weights=None):
        # Ensure adjacency_matrix is a FloatTensor
        adjacency_matrix = adjacency_matrix.float()
        
        # Ensure graph_metrics is a FloatTensor
        graph_metrics = graph_metrics.float()

        # Validate and potentially adjust dimensions
        batch_size, input_dim = x.shape
        
        # Ensure adjacency matrix is square and matches batch size
        if adjacency_matrix.size(0) != batch_size or adjacency_matrix.size(1) != batch_size:
            # Create an identity-like matrix if dimensions don't match
            adjacency_matrix = torch.eye(batch_size, device=x.device)

        try:
            # Direct Connections
            direct_scores = adjacency_matrix @ x  # Matrix multiplication to get direct connection scores

            # Neighborhood Similarity (modified to handle potential dimension issues)
            try:
                neighborhood_similarity = self.compute_neighborhood_similarity(adjacency_matrix, x)
            except RuntimeError:
                # Fallback: use a simplified similarity if computation fails
                neighborhood_similarity = torch.zeros_like(x)

            # Graph Structure Scores - modify to handle 2D graph metrics
            if graph_metrics.dim() == 2:
                # Project graph metrics to match input dimensions
                graph_metrics_projected = self.project_graph_metrics(graph_metrics, input_dim)
                graph_structure_scores = graph_metrics_projected * x  # Element-wise multiplication instead of matrix multiplication
            else:
                graph_structure_scores = torch.zeros_like(x)

            # Combine DNG scores
            dng_scores = direct_scores + neighborhood_similarity + graph_structure_scores

            # Optional weighted processing
            if self.use_weights and weights is not None:
                weighted_x = torch.zeros_like(x)
                for i, weight in enumerate(weights.T):
                    weighted_x += self.weight_linears[i](x) * weight.unsqueeze(1)
                x = weighted_x
            else:
                x = self.input_linear(x)

            x = self.layer_norm(x)
            x = self.transformer_encoder(x.unsqueeze(1)).squeeze(1)  # Adjust for transformer input
            x = self.output_linear(x)
            x = self.dropout(x)

            # Combine with DNG scores
            final_scores = F.relu(x + dng_scores)
            return final_scores

        except RuntimeError as e:
            print(f"RuntimeError during forward pass: {e}")
            print(f"x shape: {x.shape}, adjacency_matrix shape: {adjacency_matrix.shape}, graph_metrics shape: {graph_metrics.shape}")
            raise

    def project_graph_metrics(self, graph_metrics, target_dim):
        """
        Project graph metrics to match target dimension
        
        Args:
        - graph_metrics: Tensor of shape [batch_size, num_metrics]
        - target_dim: Desired output dimension
        
        Returns:
        - Projected tensor of shape [batch_size, target_dim]
        """
        # If graph_metrics has fewer dimensions than target, repeat or expand
        if graph_metrics.size(1) < target_dim:
            # Repeat the metrics to fill the target dimension
            repeats = (target_dim + graph_metrics.size(1) - 1) // graph_metrics.size(1)
            graph_metrics = graph_metrics.repeat(1, repeats)[:, :target_dim]
        elif graph_metrics.size(1) > target_dim:
            # Truncate if too many metrics
            graph_metrics = graph_metrics[:, :target_dim]
        
        return graph_metrics

    def compute_neighborhood_similarity(self, adjacency_matrix, x):
        # Robust Jaccard similarity computation
        try:
            # Ensure adjacency matrix is binary
            binary_adj = (adjacency_matrix > 0).float()
            
            # Compute intersection
            intersection = binary_adj @ binary_adj.T
            
            # Compute row and column sums
            row_sums = binary_adj.sum(dim=1, keepdim=True)
            col_sums = binary_adj.sum(dim=0, keepdim=True)
            
            # Compute union
            union = row_sums + col_sums.T - intersection
            
            # Compute similarity with small epsilon to avoid division by zero
            similarity = intersection / (union + 1e-8)
            
            # Matrix multiplication with input
            return similarity @ x
        
        except RuntimeError:
            # Fallback to a simple similarity if computation fails
            return torch.zeros_like(x)



In [2]:
pwd

'/Users/visheshyadav/Documents/GitHub/CoreRec/src/SANDBOX/Tmodel'

In [3]:
# Paths to the IJCAI-15 dataset CSV files
user_file = '/Users/visheshyadav/Documents/GitHub/CoreRec/src/SANDBOX/dataset/IJCAI-15/user_info_format1.csv'
item_file = '/Users/visheshyadav/Documents/GitHub/CoreRec/src/SANDBOX/dataset/IJCAI-15/train_format1.csv'
interaction_file = '/Users/visheshyadav/Documents/GitHub/CoreRec/src/SANDBOX/dataset/IJCAI-15/user_log_format1.csv'

# Load dataset
dataset = UserItemDataset(user_file, item_file, interaction_file)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Get maximum user and item IDs
max_user_id = dataset.users['user_id'].max()
max_item_id = dataset.interactions['item_id'].max()  # Access item_id from interactions
num_users = max_user_id + 1  # +1 because IDs are zero-indexed
num_items = max_item_id + 1  # +1 for the same reason

print(f"Number of users: {num_users}")
print(f"Number of items: {num_items}")

Number of users: 424171
Number of items: 1113167


In [4]:
class TransformerRecommender(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim=64, num_layers=2, num_heads=4, dropout=0.1):
        super(TransformerRecommender, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_dim = embedding_dim

        # Initialize the transformer model
        self.model = GraphTransformerV2(
            num_layers=num_layers,
            d_model=embedding_dim,
            num_heads=num_heads,
            d_feedforward=embedding_dim * 4,
            input_dim=2 * embedding_dim,  
            dropout=dropout
        )

        # User and Item embeddings
        self.user_embeddings = nn.Embedding(num_users, embedding_dim)
        self.item_embeddings = nn.Embedding(num_items, embedding_dim)

        # Loss function
        self.criterion = nn.BCEWithLogitsLoss()

        # Optimizer
        self.optimizer = torch.optim.Adam(self.parameters())

    def create_batch_graph_structure(self, batch_size):
        # Create adjacency matrix for the batch (batch_size x batch_size)
        adj_matrix = torch.zeros((batch_size, batch_size))

        # Create basic graph metrics
        graph_metrics = {
            'degree': torch.zeros(batch_size),
            'clustering': torch.zeros(batch_size),
            'centrality': torch.zeros(batch_size)
        }

        return adj_matrix, graph_metrics

    def update_batch_graph_structure(self, user_ids, item_ids, batch_size):
        # Create new batch-specific adjacency matrix
        adj_matrix = torch.zeros((batch_size, batch_size))

        # Create connections between users and items within the batch
        for i in range(batch_size):
            for j in range(batch_size):
                if user_ids[i] == user_ids[j] or item_ids[i] == item_ids[j]:
                    adj_matrix[i, j] = 1.0

        # Calculate basic graph metrics for the batch
        graph_metrics = {
            'degree': adj_matrix.sum(dim=1),
            'clustering': torch.zeros(batch_size),  # Simplified clustering coefficient
            'centrality': adj_matrix.sum(dim=0) / batch_size  # Simplified centrality measure
        }

        return adj_matrix, graph_metrics

    def forward(self, user_ids, item_ids):
        user_emb = self.user_embeddings(user_ids)
        item_emb = self.item_embeddings(item_ids)

        # Concatenate user and item embeddings
        input_emb = torch.cat([user_emb, item_emb], dim=1)  # Shape: [batch_size, 2*embedding_dim]

        # Update batch-specific graph structure
        batch_size = user_ids.size(0)
        adj_matrix, graph_metrics = self.update_batch_graph_structure(user_ids, item_ids, batch_size)

        # Convert graph_metrics to a tensor
        graph_metrics_tensor = torch.stack([
            graph_metrics['degree'],
            graph_metrics['clustering'],
            graph_metrics['centrality']
        ]).T  # Shape: [batch_size, 3]

        # Forward pass through the transformer model
        output = self.model(input_emb, adj_matrix, graph_metrics_tensor)

        return output.mean(dim=1)  # Return mean predictions

    def train_step(self, user_ids, item_ids, labels):
        self.train()  # Set the model to training mode
        self.optimizer.zero_grad()

        # Forward pass
        pred = self.forward(user_ids, item_ids)

        # Calculate loss
        loss = self.criterion(pred, labels.float())

        # Backward pass
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def predict(self, user_ids, item_ids):
        self.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            return self.forward(user_ids, item_ids)

    def eval(self):
        self.model.eval()  # Set the transformer model to evaluation mode     

In [5]:
# Define BiasMF
class BiasMF(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(BiasMF, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        self.user_bias = nn.Embedding(num_users, 1)
        self.item_bias = nn.Embedding(num_items, 1)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        user_bias = self.user_bias(user_ids)
        item_bias = self.item_bias(item_ids)
        return (user_embeds * item_embeds).sum(1, keepdim=True) + user_bias + item_bias 

# Define DMF (Deep Matrix Factorization)
class DMF(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(DMF, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        self.fc = nn.Linear(embedding_dim * 2, 1)

    def forward(self, user_ids, item_ids):
       user_embeds = self.user_embedding(user_ids)
       item_embeds = self.item_embedding(item_ids)
       x = torch.cat([user_embeds, item_embeds], dim=1)
       return self.fc(x)  # Ensure this outputs [batch_size, 1]

# Define AutoRec
class AutoRec(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(AutoRec, self).__init__()
        self.encoder = nn.Linear(num_items, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, num_items)

    def forward(self, x):
        x = torch.relu(self.encoder(x))
        return self.decoder(x)

# Define CDAE (Collaborative Denoising Autoencoder)
class CDAE(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(CDAE, self).__init__()
        self.encoder = nn.Linear(num_items, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, num_items)

    def forward(self, x):
        x = torch.relu(self.encoder(x))
        return self.decoder(x)

# Define NADE (Neural Autoregressive Distribution Estimator)
# class NADE(nn.Module):
#     def __init__(self, num_items, embedding_dim):
#         super(NADE, self).__init__()
#         self.fc = nn.Linear(num_items, embedding_dim)

#     def forward(self, x):
#         return self.fc(x)

# Define CF-UIcA (Collaborative Filtering with User-Item Contextual Attention)
class CF_UIcA(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(CF_UIcA, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define ST-GCN (Spatial Temporal Graph Convolutional Network)
class STGCN(nn.Module):
    def __init__(self, num_nodes, in_channels, out_channels):
        super(STGCN, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        return self.conv(x)

# Define NGCF (Next Generation Collaborative Filtering)
class NGCF(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(NGCF, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define NMTR (Neural Matrix Factorization with Temporal Regularization)
class NMTR(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(NMTR, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define DIPN (Deep Item-based Personalized Network)
class DIPN(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(DIPN, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define NGCF+M (NGCF with Memory)
class NGCF_M(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(NGCF_M, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define MBGCN (Multi-Branch Graph Convolutional Network)
class MBGCN(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(MBGCN, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define MATN (Multi-Attention Temporal Network)
class MATN(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(MATN, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define GNMR (Graph Neural Matrix Recommendation)
class GNMR(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(GNMR, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)

# Define MBRec (Multi-Branch Recommendation)
class MBRec(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(MBRec, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        return (user_embeds * item_embeds).sum(1)


In [6]:
embedding_dim = 64
# Initialize models
models = {
    "GraphTransformerV2": GraphTransformerV2(num_layers=3, d_model=128, num_heads=8, d_feedforward=256, input_dim=2),
    "BiasMF": BiasMF(num_users, num_items, embedding_dim),
    "DMF": DMF(num_users, num_items, embedding_dim),
    # "NCF": NCF(num_users, num_items, embedding_dim),
    "AutoRec": AutoRec(num_users, num_items, embedding_dim),
    "CDAE": CDAE(num_users, num_items, embedding_dim),
    # "NADE": NADE(num_items, embedding_dim),
    "CF_UIcA": CF_UIcA(num_users, num_items, embedding_dim),
    "STGCN": STGCN(num_items, 1, embedding_dim),
    "NGCF": NGCF(num_users, num_items, embedding_dim),
    "NMTR": NMTR(num_users, num_items, embedding_dim),
    "DIPN": DIPN(num_users, num_items, embedding_dim),
    "NGCF_M": NGCF_M(num_users, num_items, embedding_dim),
    "MBGCN": MBGCN(num_users, num_items, embedding_dim),
    "MATN": MATN(num_users, num_items, embedding_dim),
    "GNMR": GNMR(num_users, num_items, embedding_dim),
    "MBRec": MBRec(num_users, num_items, embedding_dim),
}

import torch
import torch.nn as nn
import torch.optim as optim

# Assume models are already defined and initialized in a dictionary 'models'
# Assume data_loader is properly set up

embedding_dim = 64
num_epochs = 10

for model_name, model in models.items():
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCEWithLogitsLoss()  # Binary classification

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_idx, (user_ids, item_ids, labels) in enumerate(data_loader):
            print(f'Epoch {epoch + 1}, Batch {batch_idx + 1}')
            print(f'User IDs shape: {user_ids.shape}, Item IDs shape: {item_ids.shape}, Labels shape: {labels.shape}')
            
            optimizer.zero_grad()
            
            if model_name == "GraphTransformerV2":
                # GraphTransformerV2 needs custom input preparation
                x = torch.cat((user_ids.unsqueeze(1), item_ids.unsqueeze(1)), dim=1).float()  # Shape: [batch_size, 2]
                adjacency_matrix = torch.eye(user_ids.size(0)).float()  # Identity matrix as adjacency matrix
                graph_metrics = torch.zeros(user_ids.size(0), 2).float()  # Placeholder for graph metrics
                output = model(x, adjacency_matrix, graph_metrics)
            else:
                # For other models, ensure they output [batch_size, 1]
                output = model(user_ids, item_ids)
                output = output.view(-1, 1)  # Ensure output is [batch_size, 1]

            # Adjust labels for binary classification
            labels = labels.view(-1, 1).float()  # Ensure labels are of shape [batch_size, 1]

            # Debugging output shapes
            print(f'Output shape: {output.shape}, Labels shape: {labels.shape}')

            # Calculate loss
            loss = criterion(output, labels)
            
            # Check for NaN loss
            if torch.isnan(loss):
                print(f"Loss is NaN for {model_name} in Epoch {epoch + 1}, Batch {batch_idx + 1}. Stopping training.")
                break

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Print loss for the current batch
            print(f'Batch {batch_idx + 1}, Loss: {loss.item()}')

        # Print average loss for the epoch
        print(f'{model_name} - Epoch {epoch + 1}, Average Loss: {total_loss / len(data_loader)}')

        # Optional: Early stopping if loss becomes NaN
        if torch.isnan(torch.tensor(total_loss)):
            break

Epoch 1, Batch 1
User IDs shape: torch.Size([64]), Item IDs shape: torch.Size([64]), Labels shape: torch.Size([64])
Output shape: torch.Size([64, 2]), Labels shape: torch.Size([64, 1])


ValueError: Target size (torch.Size([64, 1])) must be the same as input size (torch.Size([64, 2]))