# Training a GraphSAGE-based GNN Model for Food Recipe Recommendation

## Environment setup

In [1]:
import gc
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch import nn

from tqdm import tqdm
from torch_geometric.nn import LGConv, SAGEConv, GATv2Conv, to_hetero
from torch_geometric.loader import LinkNeighborLoader

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import SAGEConv, GATv2Conv, LGConv
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.transforms import ToUndirected
from torch_geometric.data import HeteroData
from torch_geometric.nn import LayerNorm, BatchNorm

import copy

In [None]:
def flush():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()

flush()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Graph dataset loading

In this step, we load the graphs already generated in the graph dataset generation step.    
Since generating graph datasets is time consiming, we won't add them to each notebook.

In [None]:
def load_graph(file_path):
    return torch.load(file_path)

dataset_version = 1
base_data_path = f"../data/graph/v{dataset_version}"

train_graph = load_graph(f"{base_data_path}/train_graph.pt")
validation_graph = load_graph(f"{base_data_path}/validation_graph.pt")
test_graph = load_graph(f"{base_data_path}/test_graph.pt")
metadata = train_graph.metadata()

train_graph

In [None]:
validation_graph

In [None]:
print("Train graph information: ")
print("Number of nodes:", train_graph.num_nodes)
print("Number of edges:", train_graph.num_edges)
print("Metadata:", train_graph.metadata())
print("Edge index:", train_graph['user', 'rates', 'recipe'].edge_index)
print("Recipe embeddings dimension: ", train_graph['recipe'].x.size(1))
print("Type of ('user', 'rates', 'recipe') edge index", train_graph[('user', 'rates', 'recipe')].edge_index.dtype)  
print("Type of ('user', 'rates', 'recipe') edge index: ", train_graph[('user', 'rates', 'recipe')].edge_label_index.dtype) 

In [None]:
# Extract edge_label for 'user', 'rates', 'recipe'
ratings = train_graph[('user', 'rates', 'recipe')].edge_label

# Verify statistics of ratings
print("Statistics of edge_label (rates) in train_graph:")
print(f"Min rating: {ratings.min().item()}")
print(f"Max rating: {ratings.max().item()}")
print(f"Unique ratings: {ratings.unique().tolist()}")
print(f"Total number of ratings: {ratings.size(0)}")

## Model Implementation


In [7]:
# ############################################################
# # Model Definition
# ############################################################
# class HeteroGNN(nn.Module):
#     def __init__(self, metadata, hidden_channels=64, out_channels=1, model_type='sage', 
#                  num_layers=2, dropout=0.5, l2_reg=1e-5, normalize=True):
#         super().__init__()
#         self.metadata = metadata
#         self.model_type = model_type.lower()
#         self.hidden_channels = hidden_channels
#         self.out_channels = out_channels
#         self.num_layers = num_layers
#         self.dropout = dropout
#         self.l2_reg = l2_reg
#         self.normalize = normalize

#         # User embedding
#         user_node_count = train_graph['user'].num_nodes
#         self.user_emb = nn.Embedding(user_node_count, self.hidden_channels)
#         nn.init.xavier_uniform_(self.user_emb.weight)

#         # Recipe features
#         recipe_x_dim = train_graph['recipe'].x.size(-1)
#         self.recipe_norm = BatchNorm(recipe_x_dim, affine=True) if self.normalize else nn.Identity()
#         self.recipe_lin = nn.Linear(recipe_x_dim, self.hidden_channels)
#         nn.init.xavier_uniform_(self.recipe_lin.weight)

#         # Select GNN Layer
#         if self.model_type == 'sage':
#             self.conv_class = SAGEConv
#         elif self.model_type == 'gat':
#             self.conv_class = GATv2Conv
#         elif self.model_type == 'lightgcn':
#             self.conv_class = LGConv
#         else:
#             raise ValueError("model_type should be one of ['sage', 'gat', 'lightgcn']")

#         # GNN Layers
#         self.convs = nn.ModuleList()
#         if self.model_type in ['sage', 'gat']:
#             for _ in range(num_layers):
#                 self.convs.append(self.conv_class(self.hidden_channels, self.hidden_channels))
#         else:
#             # LightGCN doesn't require input/output dimensions
#             for _ in range(num_layers):
#                 self.convs.append(LGConv())

#         # Dropout layer (applies only if dropout > 0)
#         self.dropout_layer = nn.Dropout(dropout) if dropout > 0 else None

#         # Prediction Layer:
#         # - For LightGCN: rating = dot product
#         # - For SAGE/GAT: use MLP
#         if self.model_type in ['sage', 'gat']:
#             self.predict_mlp = nn.Sequential(
#                 nn.Linear(self.hidden_channels * 2, self.hidden_channels),
#                 nn.ReLU(),
#                 nn.Linear(self.hidden_channels, out_channels)
#             )
#         else:
#             self.predict_mlp = None

#     def forward(self, x_dict, edge_index_dict):
#         # Replace user node features with embeddings
#         x_dict['user'] = self.user_emb.weight
#         x_dict['recipe'] = self.recipe_norm(x_dict['recipe'])
#         x_dict['recipe'] = self.recipe_lin(x_dict['recipe'])

#         # Message passing
#         user_recipe_edges = edge_index_dict[('user', 'rates', 'recipe')]
#         recipe_user_edges = edge_index_dict[('recipe', 'rev_rates', 'user')]

#         for conv in self.convs:
#             x_dict['user'] = conv(x_dict['user'], user_recipe_edges)
#             x_dict['recipe'] = conv(x_dict['recipe'], recipe_user_edges)

#             # Apply dropout if defined
#             if self.dropout_layer:
#                 x_dict['user'] = self.dropout_layer(x_dict['user'])
#                 x_dict['recipe'] = self.dropout_layer(x_dict['recipe'])

#         return x_dict

#     def predict(self, user_emb, recipe_emb):
#         if self.model_type == 'lightgcn':
#             # LightGCN rating = dot product
#             return (user_emb * recipe_emb).sum(dim=-1, keepdim=True)
#         else:
#             # For SAGE/GAT: use MLP
#             combined = torch.cat([user_emb, recipe_emb], dim=-1)
#             return self.predict_mlp(combined)

#     def loss_l2_regularization(self):
#         l2_loss = torch.sum(self.user_emb.weight**2)
#         for name, param in self.named_parameters():
#             if 'weight' in name and param.requires_grad:
#                 l2_loss += torch.sum(param**2)
#         return self.l2_reg * l2_loss

# ############################################################
# # Data Loader for Mini-Batching
# ############################################################
# train_loader = LinkNeighborLoader(
#     data=train_graph,
#     num_neighbors=[10, 5],
#     edge_label_index=(('user', 'rates', 'recipe'), train_graph['user', 'rates', 'recipe'].edge_label_index),
#     edge_label=train_graph['user', 'rates', 'recipe'].edge_label,
#     batch_size=1024,
#     shuffle=True
# )

# ############################################################
# # Training Setup
# ############################################################
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model_type = 'sage'  # Change to 'sage', 'gat', or 'lightgcn' as needed
# dropout = 0.0 if model_type == 'lightgcn' else 0.5
# hidden_channels = 128
# model = HeteroGNN(metadata=metadata, hidden_channels=hidden_channels, num_layers=2, model_type=model_type, dropout=dropout).to(device)
# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# ############################################################
# # Training Function
# ############################################################
# def train(model, loader, optimizer):
#     model.train()
#     total_loss = 0
#     for batch in loader:
#         batch = batch.to(device)
#         optimizer.zero_grad()

#         x_dict = {
#             'user': torch.arange(batch['user'].num_nodes, device=device),
#             'recipe': batch['recipe'].x
#         }

#         out_dict = model(x_dict, batch.edge_index_dict)

#         # Extract edge indices for rating prediction
#         user_nodes = batch['user', 'rates', 'recipe'].edge_label_index[0]
#         recipe_nodes = batch['user', 'rates', 'recipe'].edge_label_index[1]

#         user_emb = out_dict['user'][user_nodes]
#         recipe_emb = out_dict['recipe'][recipe_nodes]

#         pred = model.predict(user_emb, recipe_emb)
#         target = batch['user', 'rates', 'recipe'].edge_label.float()

#         if target.dim() == 1:
#             target = target.unsqueeze(-1)

#         loss = criterion(pred, target)
#         loss += model.loss_l2_regularization()

#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item() * target.size(0)

#     return total_loss / len(loader.dataset)

# ############################################################
# # Training Loop
# ############################################################
# num_epochs = 5
# for epoch in range(1, num_epochs + 1):
#     train_mse = train(model, train_loader, optimizer)
#     scheduler.step(train_mse)
#     print(f"Epoch {epoch:03d}, Training MSE: {train_mse:.4f}")

In [8]:
# GraphSage
# Epoch 001, Training MSE: 2.6171
# Epoch 002, Training MSE: 1.7348
# Epoch 003, Training MSE: 1.7048
# Epoch 004, Training MSE: 1.6856
# Epoch 005, Training MSE: 1.6774

# LightGCN
# Epoch 001, Training MSE: 15.5694
# Epoch 002, Training MSE: 14.0834
# Epoch 003, Training MSE: 14.0645
# Epoch 004, Training MSE: 13.9592
# Epoch 005, Training MSE: 13.9545

# GAT
# Epoch 001, Training MSE: 2.4236
# Epoch 002, Training MSE: 1.7843
# Epoch 003, Training MSE: 1.7651
# Epoch 004, Training MSE: 1.7506
# Epoch 005, Training MSE: 1.7458

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import SAGEConv, GATv2Conv, LGConv
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import BatchNorm
from torch_geometric.data import HeteroData

############################################################
# Model Definition
############################################################
class HeteroGNN(nn.Module):
    def __init__(self, metadata, hidden_channels=64, out_channels=1, model_type='sage', 
                 num_layers=2, dropout=0.5, l2_reg=1e-5, normalize=True):
        super().__init__()
        self.metadata = metadata
        self.model_type = model_type.lower()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.l2_reg = l2_reg
        self.normalize = normalize

        # User embeddings
        user_node_count = metadata[0].index('user') != -1 and train_graph['user'].num_nodes
        self.user_emb = nn.Embedding(user_node_count, self.hidden_channels)
        nn.init.xavier_uniform_(self.user_emb.weight)

        # Recipe features
        recipe_x_dim = train_graph['recipe'].x.size(-1)
        self.recipe_norm = BatchNorm(recipe_x_dim, affine=True) if self.normalize else nn.Identity()
        self.recipe_lin = nn.Linear(recipe_x_dim, self.hidden_channels)
        nn.init.xavier_uniform_(self.recipe_lin.weight)

        # Select GNN layer based on model_type
        if self.model_type == 'sage':
            # GraphSAGE: typically mean aggregator (default) is used.
            ConvClass = SAGEConv
            conv_args = (self.hidden_channels, self.hidden_channels)
        elif self.model_type == 'gat':
            # GATv2: use multiple heads, no static attention problem.
            # The original GAT paper often used multiple heads (e.g., 8 heads).
            # Here we choose heads=4 and concat=False to keep output dimension stable.
            # This is a reasonable adaptation staying close to GAT-style architectures.
            ConvClass = GATv2Conv
            conv_args = (self.hidden_channels, self.hidden_channels)
            self.gat_heads = 4
        elif self.model_type == 'lightgcn':
            # LightGCN: no input/output dimensions needed, no nonlinearities, no features transformed.
            # Just LGConv layers.
            ConvClass = LGConv
            conv_args = ()
        else:
            raise ValueError("model_type should be one of ['sage', 'gat', 'lightgcn']")

        # GNN Layers
        self.convs = nn.ModuleList()
        if self.model_type == 'sage':
            for _ in range(num_layers):
                self.convs.append(ConvClass(*conv_args))
        elif self.model_type == 'gat':
            # For GATv2Conv, specify heads and concat=False for simplicity
            # Both layers have the same dimension since concat=False.
            for _ in range(num_layers):
                self.convs.append(ConvClass(self.hidden_channels, self.hidden_channels, heads=self.gat_heads, concat=False))
        else:  # lightgcn
            for _ in range(num_layers):
                self.convs.append(ConvClass())

        # Dropout layer (applies only if dropout > 0)
        self.dropout_layer = nn.Dropout(dropout) if dropout > 0 else None

        # Prediction Layer:
        # - For LightGCN: rating = dot product of final user & item embeddings.
        # - For SAGE/GAT: use an MLP for rating prediction.
        if self.model_type in ['sage', 'gat']:
            self.predict_mlp = nn.Sequential(
                nn.Linear(self.hidden_channels * 2, self.hidden_channels),
                nn.ReLU(),
                nn.Linear(self.hidden_channels, out_channels)
            )
        else:
            self.predict_mlp = None

    def forward(self, x_dict, edge_index_dict):
        # Replace user node features with embeddings
        x_user = self.user_emb.weight
        x_recipe = self.recipe_norm(x_dict['recipe'])
        x_recipe = self.recipe_lin(x_recipe)

        # For LightGCN, we need to sum embeddings from all layers, including the initial one.
        if self.model_type == 'lightgcn':
            # Initial embeddings
            user_emb_layers = [x_user]
            recipe_emb_layers = [x_recipe]
        else:
            user_emb_layers = []
            recipe_emb_layers = []

        # Message passing
        user_recipe_edges = edge_index_dict[('user', 'rates', 'recipe')]
        recipe_user_edges = edge_index_dict[('recipe', 'rev_rates', 'user')]

        x_u = x_user
        x_r = x_recipe
        for conv in self.convs:
            x_u = conv(x_u, user_recipe_edges)
            x_r = conv(x_r, recipe_user_edges)

            # Dropout if applicable
            if self.dropout_layer and self.model_type in ['sage', 'gat']:
                x_u = self.dropout_layer(x_u)
                x_r = self.dropout_layer(x_r)

            if self.model_type == 'lightgcn':
                # Accumulate embeddings from each layer for LightGCN
                user_emb_layers.append(x_u)
                recipe_emb_layers.append(x_r)

        if self.model_type == 'lightgcn':
            # Final embedding is the sum of all layer embeddings (including initial)
            x_user_final = torch.stack(user_emb_layers, dim=0).mean(dim=0)  # LightGCN uses averaged sum
            x_recipe_final = torch.stack(recipe_emb_layers, dim=0).mean(dim=0)
        else:
            x_user_final = x_u
            x_recipe_final = x_r

        out_dict = {
            'user': x_user_final,
            'recipe': x_recipe_final
        }
        return out_dict

    def predict(self, user_emb, recipe_emb):
        if self.model_type == 'lightgcn':
            # LightGCN rating = dot product
            return (user_emb * recipe_emb).sum(dim=-1, keepdim=True)
        else:
            # For SAGE/GAT: use MLP
            combined = torch.cat([user_emb, recipe_emb], dim=-1)
            return self.predict_mlp(combined)

    def loss_l2_regularization(self):
        l2_loss = torch.sum(self.user_emb.weight**2)
        for name, param in self.named_parameters():
            if 'weight' in name and param.requires_grad:
                l2_loss += torch.sum(param**2)
        return self.l2_reg * l2_loss


############################################################
# Data Loader for Mini-Batching
############################################################
train_loader = LinkNeighborLoader(
    data=train_graph,
    num_neighbors=[10, 5],
    edge_label_index=(('user', 'rates', 'recipe'), train_graph['user', 'rates', 'recipe'].edge_label_index),
    edge_label=train_graph['user', 'rates', 'recipe'].edge_label,
    batch_size=1024,
    shuffle=True
)

############################################################
# Training Setup
############################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Example usage:
# For LightGCN (close to original): no dropout, no MLP, dot product rating
model_type = 'lightgcn'  # choose 'sage', 'gat', or 'lightgcn'
dropout = 0.0 if model_type == 'lightgcn' else 0.5
hidden_channels = 128

model = HeteroGNN(metadata=metadata, hidden_channels=hidden_channels, num_layers=2, model_type=model_type, dropout=dropout).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

############################################################
# Training Function
############################################################
def train(model, loader, optimizer):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        x_dict = {
            'user': torch.arange(batch['user'].num_nodes, device=device),
            'recipe': batch['recipe'].x
        }

        out_dict = model(x_dict, batch.edge_index_dict)

        # Extract edge indices for rating prediction
        user_nodes = batch['user', 'rates', 'recipe'].edge_label_index[0]
        recipe_nodes = batch['user', 'rates', 'recipe'].edge_label_index[1]

        user_emb = out_dict['user'][user_nodes]
        recipe_emb = out_dict['recipe'][recipe_nodes]

        pred = model.predict(user_emb, recipe_emb)
        target = batch['user', 'rates', 'recipe'].edge_label.float()

        if target.dim() == 1:
            target = target.unsqueeze(-1)

        loss = criterion(pred, target)
        loss += model.loss_l2_regularization()

        loss.backward()
        optimizer.step()
        total_loss += loss.item() * target.size(0)

    return total_loss / len(loader.dataset)

############################################################
# Training Loop (Example)
############################################################
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    train_mse = train(model, train_loader, optimizer)
    scheduler.step(train_mse)
    print(f"Epoch {epoch:03d}, Training MSE: {train_mse:.4f}")


In [14]:
# LightGCN:
# Epoch 001, Training MSE: 9.5261
# Epoch 002, Training MSE: 6.4682
# Epoch 003, Training MSE: 6.3267
# Epoch 004, Training MSE: 6.2442
# Epoch 005, Training MSE: 6.2378

## Model Training

In this step, we train our initialized \texttt{SageRecModel} model to optimize for the edge rating prediction task, where we predict user ratings for recipes.

### Mini-Batching
To manage memory and computation on our large graph datasets, we use mini-batching, dividing the data into smaller, manageable subsets that the model processes sequentially. With PyG’s `LinkNeighborLoader`, we sample neighbors around each target edge, focusing on the local neighborhood of each user-recipe interaction. This approach enables the model to capture essential neighborhood context without loading the entire graph, making it highly efficient for large-scale training.


### Optimization
- **Optimizer**: To optimize the model’s parameters, we use Adam optimizer that handles sparse gradients well, suitable for graph neural networks. Regularization is applied through weight decay, which discourages complex solutions and helps prevent overfitting, leading to better generalization.  
To further stabilize training, we use a learning rate scheduler that dynamically adjusts the learning rate, reducing it gradually as the model approaches convergence. This prevents overshooting during optimization and enables fine-tuning for more accurate predictions.


- **Loss Function**: Mean Squared Error (MSE) loss measures the difference between predicted and actual ratings. Given the continuous nature of ratings, MSE is a suitable choice for our link regression task and is calculated as:

  $$
  \text{MSE Loss} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
  $$

  where $y_i$ is the true rating, $\hat{y}_i$ is the predicted rating, and $N$ is the number of samples.


### Training Process
For each epoch, we perform the following steps:
  1. Batch sampling: Sampling mini-batches of edges with the train data loader, allowing the model to process a portion of the data at each step.
  2. Forward propagation: Generating predictions for each mini-batch and calculating the MSE loss based on the difference between predicted and actual ratings.
  3. Backward Ppropagation: Updating the model’s parameters based on the computed gradients.
  4. Parameter update: Updating the model’s parameters based on the computed gradients by optimizer.
  5. Learning rate adjustment: Adjusting the learning rate periodically with the scheduler, stabilizing training as the model converges.

In [10]:
# def create_link_neighbor_loader(data, edge_type, batch_size, num_neighbors, shuffle, num_workers):
#     """
#     Creates a LinkNeighborLoader for the specified edge type in a HeteroData object.

#     Parameters:
#     - data (HeteroData): The heterogeneous graph data.
#     - edge_type (tuple): The edge type for which to create the loader, e.g., ('user', 'rates', 'recipe').
#     - batch_size (int): Number of edges to include in each batch.
#     - num_neighbors (list): Number of neighbors to sample at each layer.
#     - shuffle (bool): Whether to shuffle the data.
#     - num_workers (int): Number of subprocesses to use for data loading.

#     Returns:
#     - loader (LinkNeighborLoader): The data loader for the specified edge type.
#     """
#     # Ensure the edge_type exists in the data
#     if edge_type not in data.edge_types:
#         raise ValueError(f"Edge type {edge_type} not found in the data.")

#     # Access the edge_label_index and edge_label for the specified edge type
#     edge_label_index = data[edge_type].get('edge_label_index', data[edge_type].edge_index)
#     edge_label = data[edge_type].get('edge_label', None)

#     # Create the LinkNeighborLoader
#     loader = LinkNeighborLoader(
#         data=data,
#         num_neighbors=num_neighbors,
#         edge_label_index=(edge_type, edge_label_index),
#         edge_label=edge_label,
#         batch_size=batch_size,
#         shuffle=shuffle,
#         num_workers=num_workers,
#     )

#     return loader


# edge_type = ('user', 'rates', 'recipe') # Define the edge type of interest
# batch_size = 512  # Adjust based on your GPU memory capacity
# num_neighbors = [10, 5, 5] # Number of neighbors to sample at each layer
# num_workers = 4  # Adjust based on your system

# # Create the training data loader
# train_data_loader = create_link_neighbor_loader(
#     data=train_graph,
#     edge_type=edge_type,
#     batch_size=batch_size, 
#     num_neighbors=num_neighbors, 
#     shuffle=True,
#     num_workers=num_workers 
# )

# weight_decay = 0.0001
# learning_rate = 0.001
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# # Implement a learning rate scheduler
# step_size = 10
# gamma = 0.1
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [11]:
# def train(model, data_loader, optimizer, scheduler):
#     model.train()
#     total_loss = 0
#     for batch in tqdm(data_loader, desc='Training', unit='batch', leave=False):
#         batch = batch.to(device)
#         optimizer.zero_grad()
#         pred = model(batch.x_dict, batch.edge_index_dict, batch['user', 'rates', 'recipe'].edge_label_index)
#         # Flatten target to match pred.
#         target = batch['user', 'rates', 'recipe'].edge_label.float().view(-1)
#         loss = F.mse_loss(pred, target)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item() * target.size(0)
    
#     # Compute average loss (MSE) per data point (edge). 
#     mse = total_loss / len(data_loader.dataset)

#     # Step the scheduler to update the learning rate
#     scheduler.step()

#     return mse


# # Training loop
# num_epochs = 20
# for epoch in range(1, num_epochs + 1):
#     loss = train(model, train_data_loader, optimizer, scheduler)
#     print(f'Epoch: {epoch:03d}, Loss (MSE): {loss:.4f}')

# Evaluation
Finally, we evaluate the performance of our model on the validation and test graphs using the **Root Mean Squared Error (RMSE)** and **@Recall@k** metrics. 
Although MSE is used as a loss function for training due to its efficient gradient properties, we use RMSE for evaluation because it provides error values in the same units as the target variable, making it more interpretable when assessing model performance.

We also report the evaluation results on the training graph to monitor the model error.

In [None]:
from torch.utils.data import DataLoader, Dataset

class EdgeBatchDataset(Dataset):
    def __init__(self, data, edge_type, batch_size):
        """
        Dataset that batches only target edges and labels for evaluation.
        The entire graph is shared across all batches.
        """
        self.data = data
        self.edge_label_index = data[edge_type].edge_label_index
        self.edge_label = data[edge_type].edge_label
        self.batch_size = batch_size

    def __len__(self):
        return (self.edge_label.size(0) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        # Get batch range
        start = idx * self.batch_size
        end = min((idx + 1) * self.batch_size, self.edge_label.size(0))
        
        # Extract target edges and labels for this batch
        edge_label_index = self.edge_label_index[:, start:end]
        edge_label = self.edge_label[start:end]
        
        return edge_label_index, edge_label

def create_evaluation_data_loader(data, edge_type, batch_size, num_workers):
    """
    Creates a memory-efficient DataLoader for evaluation without cloning the graph.
    """
    dataset = EdgeBatchDataset(data, edge_type, batch_size)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    return loader


edge_type = ('user', 'rates', 'recipe')
batch_size = 1024
num_neighbors = [10, 5] # Number of neighbors to sample at each layer
num_workers = 4  # Adjust based on your system

# Create evaluation data loaders
evaluation_data_loaders = {}
evaluation_data_loaders["validation"] = create_evaluation_data_loader(validation_graph, edge_type, batch_size, num_workers)
evaluation_data_loaders["test"] = create_evaluation_data_loader(test_graph, edge_type, batch_size, num_workers)

evaluation_data_loaders 

In [None]:
@torch.no_grad()
def evaluate_by_rmse(model, data_loader, full_data, device):
    model.eval()
    total_loss = 0
    total_edges = 0

    for edge_label_index, edge_label in tqdm(data_loader, desc='Evaluating RMSE', leave=False):
        edge_label_index = edge_label_index.squeeze(0).to(device)
        edge_label = edge_label.squeeze(0).to(device).view(-1)  # Flatten target tensor

        # Forward pass using the full graph
        pred = model(full_data.x_dict, full_data.edge_index_dict, edge_label_index)
        pred = pred.clamp(min=0, max=5)  # Clamp predictions

        # Compute RMSE
        total_loss += F.mse_loss(pred, edge_label, reduction='sum').item()
        total_edges += edge_label.size(0)

    return (total_loss / total_edges) ** 0.5



def evaluate_by_recall_at_k(model, data_loader, full_data, k, relevance_threshold):
    model.eval()
    user_predictions = defaultdict(list)
    user_true_items = defaultdict(set)

    for edge_label_index, edge_label in tqdm(data_loader, desc=f"Evaluating Recall@{k}", leave=False):
        edge_label_index = edge_label_index.squeeze(0).to(device)
        edge_label = edge_label.squeeze(0).to(device).view(-1)

        # Forward pass
        pred = model(full_data.x_dict, full_data.edge_index_dict, edge_label_index)
        pred = pred.clamp(min=0, max=5)

        # Extract user and item indices
        users, items = edge_label_index

        for i, user_id in enumerate(users.cpu().numpy()):
            user_predictions[user_id].append((pred[i].item(), items[i].item()))
            if edge_label[i].item() >= relevance_threshold:
                user_true_items[user_id].add(items[i].item())

    recalls = []
    for user_id in user_predictions:
        # Sort predictions by score in descending order and get top-k items
        top_k_pred_items = {item for _, item in sorted(user_predictions[user_id], key=lambda x: x[0], reverse=True)[:k]}
        true_items = user_true_items[user_id]

        if true_items:
            recall = len(top_k_pred_items & true_items) / len(true_items)
            recalls.append(recall)

    return sum(recalls) / len(recalls) if recalls else 0.0


k = 5
relevance_threshold = 4

for data_split_name, data_loader in evaluation_data_loaders.items():
    full_data = validation_graph if data_split_name == "validation" else test_graph
    full_data = full_data.to(device)

    rmse = evaluate_by_rmse(model, data_loader, full_data, device)
    recall_at_k = evaluate_by_recall_at_k(model, data_loader, full_data, k, relevance_threshold)

    print(f"{data_split_name.capitalize()}: RMSE = {rmse:.4f}, Recall@{k} = {recall_at_k:.4f}")