# A Generic GNN Model Framework for Food Recipe Recommendation

## Environment setup

In [1]:
import os
import gc
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch import nn

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import SAGEConv, GATv2Conv, LGConv, LayerNorm, BatchNorm, HeteroConv
from torch_geometric.loader import LinkNeighborLoader, DataLoader, NodeLoader
from torch_geometric.data import HeteroData, Dataset

import copy

In [2]:
def flush():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()

flush()

os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 
os.environ['TORCH_USE_CUDA_DSA'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Graph Dataset Retrieval

In this step, we load the graphs already generated in the graph dataset generation step.  
In order to generate the graph dataset files, please follow the instruction on README.md. Since generating graph datasets is time consiming, we won't add them to each notebook.

In [3]:
def load_graph(file_path):
    return torch.load(file_path)

dataset_version = 1
base_data_path = f"../data/graph/v{dataset_version}"

train_graph = load_graph(f"{base_data_path}/train_graph.pt")
val_graph = load_graph(f"{base_data_path}/validation_graph.pt")
test_graph = load_graph(f"{base_data_path}/test_graph.pt")

train_graph

  return torch.load(file_path)


HeteroData(
  user={ num_nodes=226570 },
  recipe={ x=[231637, 3081] },
  (user, rates, recipe)={
    edge_index=[2, 770011],
    edge_label=[192502, 1],
    edge_label_index=[2, 192502],
  },
  (recipe, rev_rates, user)={ edge_index=[2, 770011] }
)

In [4]:
print("Train graph information: ")
print("Number of nodes:", train_graph.num_nodes)
print("Number of edges:", train_graph.num_edges)
print("Metadata:", train_graph.metadata())
print("Edge index:", train_graph['user', 'rates', 'recipe'].edge_index)
print("Recipe node_embeddings dimension: ", train_graph['recipe'].x.size(1))
print("Type of ('user', 'rates', 'recipe') edge index", train_graph[('user', 'rates', 'recipe')].edge_index.dtype)  
print("Type of ('user', 'rates', 'recipe') edge index: ", train_graph[('user', 'rates', 'recipe')].edge_label_index.dtype) 

Train graph information: 
Number of nodes: 458207
Number of edges: 1540022
Metadata: (['user', 'recipe'], [('user', 'rates', 'recipe'), ('recipe', 'rev_rates', 'user')])
Edge index: tensor([[  3106,    317,  16543,  ...,    541, 208023,    489],
        [211809,   6600, 109688,  ...,  62108,  96459, 200804]])
Recipe node_embeddings dimension:  3081
Type of ('user', 'rates', 'recipe') edge index torch.int64
Type of ('user', 'rates', 'recipe') edge index:  torch.int64


In [5]:
# Check for overlapping edges
train_edges = set(zip(train_graph['user', 'rates', 'recipe'].edge_label_index[0].tolist(),
                      train_graph['user', 'rates', 'recipe'].edge_label_index[1].tolist()))

val_edges = set(zip(val_graph['user', 'rates', 'recipe'].edge_label_index[0].tolist(),
                    val_graph['user', 'rates', 'recipe'].edge_label_index[1].tolist()))

test_edges = set(zip(test_graph['user', 'rates', 'recipe'].edge_label_index[0].tolist(),
                     test_graph['user', 'rates', 'recipe'].edge_label_index[1].tolist()))

overlap_val = train_edges & val_edges
overlap_test = train_edges & test_edges

print(f"Overlap between Training and Validation: {len(overlap_val)} edges")
print(f"Overlap between Training and Test: {len(overlap_test)} edges")


Overlap between Training and Validation: 0 edges
Overlap between Training and Test: 0 edges


## Model Implementation


In [6]:
import torch
from torch_geometric.nn import LGConv

class LGConvWrapper(torch.nn.Module):
    def __init__(self, lgconv):
        super(LGConvWrapper, self).__init__()
        self.lgconv = lgconv

    def forward(self, x, edge_index, edge_weight=None, **kwargs):
        # LGConv uses x, edge_index, and optional edge_weight
        if isinstance(x, tuple):  # If HeteroConv passes (x_src, x_dst)
            x = x[0]  # Use only x_src
        return self.lgconv(x, edge_index, edge_weight)


import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GATv2Conv, HeteroConv
from torch_geometric.nn.norm import BatchNorm

class HeteroGNN(torch.nn.Module):
    def __init__(self, graph, model_type, hidden_channels, num_layers, dropout, l2_reg):
        super(HeteroGNN, self).__init__()
        self.model_type = model_type.lower()
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.l2_reg = l2_reg

        # User node embeddings
        user_node_count = graph['user'].num_nodes
        self.user_emb = torch.nn.Embedding(user_node_count, self.hidden_channels)
        torch.nn.init.xavier_uniform_(self.user_emb.weight)

        # Recipe Features
        recipe_x_dim = graph['recipe'].x.size(-1)
        self.recipe_norm = BatchNorm(recipe_x_dim, affine=True)
        self.recipe_lin = torch.nn.Linear(recipe_x_dim, self.hidden_channels)
        torch.nn.init.xavier_uniform_(self.recipe_lin.weight)

        # Define Convolutional Layers and BatchNorms
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv_dict = {}
            for edge_type in graph.edge_types:
                if self.model_type == 'sage':
                    conv_dict[edge_type] = SAGEConv((-1, -1), self.hidden_channels)
                elif self.model_type == 'gat':
                    conv_dict[edge_type] = GATv2Conv((-1, -1), self.hidden_channels, heads=4, concat=False)
                elif self.model_type == 'lightgcn':
                    conv_dict[edge_type] = LGConvWrapper(LGConv(normalize=True))
                else:
                    raise ValueError("model_type should be one of ['sage', 'gat', 'lightgcn']")
            self.convs.append(HeteroConv(conv_dict, aggr='mean'))

            if self.model_type in ['sage', 'gat']:
                self.batch_norms.append(BatchNorm(self.hidden_channels))
            else:
                self.batch_norms.append(torch.nn.Identity())  # No BatchNorm for LightGCN

        # Dropout Layer
        self.dropout_layer = torch.nn.Dropout(dropout) if dropout > 0 else None

        # Prediction Layer (out_channels fixed to 1 for consistency)
        if self.model_type in ['sage', 'gat']:
            self.predict_mlp = torch.nn.Sequential(
                torch.nn.Linear(self.hidden_channels * 2, self.hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(self.hidden_channels, 1)  # Output fixed to 1
            )
        else:
            self.predict_mlp = None

    def forward(self, x_dict, edge_index_dict):
        # Replace user node features with node embeddings
        x_dict['user'] = self.user_emb.weight  # [num_users, hidden_channels]
        x_dict['recipe'] = self.recipe_norm(x_dict['recipe'])  # [num_recipes, recipe_x_dim]
        x_dict['recipe'] = self.recipe_lin(x_dict['recipe'])  # [num_recipes, hidden_channels]

        # Apply HeteroConv layers
        for conv, bn in zip(self.convs, self.batch_norms):
            x_dict = conv(x_dict, edge_index_dict)
            if self.model_type in ['sage', 'gat']:
                # Apply BatchNorm, Dropout, and Activation only for GraphSAGE and GAT
                x_dict = {key: bn(x) for key, x in x_dict.items()}
                if self.dropout_layer:
                    x_dict = {key: self.dropout_layer(x) for key, x in x_dict.items()}
                x_dict = {key: F.relu(x) for key, x in x_dict.items()}
            else:
                # For LightGCN, no BatchNorm, Dropout, or Activation
                pass

        # Collect final node embeddings
        out_dict = {
            'user': x_dict['user'],     # [num_users, hidden_channels]
            'recipe': x_dict['recipe']  # [num_recipes, hidden_channels]
        }
        return out_dict

    def predict(self, user_emb, recipe_emb):
        """
        Predict edge ratings between user embeddings and recipe embeddings.
        Clamps the predictions between 0 and 5 for all model types.
        Ensures consistent output shape across all models.
        """
        if self.model_type == 'lightgcn':
            # Inner product for LightGCN
            pred = (user_emb * recipe_emb).sum(dim=-1, keepdim=True)  # [batch_size, 1]
        else:
            # MLP for GraphSAGE and GAT
            combined = torch.cat([user_emb, recipe_emb], dim=-1)  # [batch_size, hidden_channels * 2]
            pred = self.predict_mlp(combined)  # [batch_size, 1]

        # Clamp predictions between 0 and 5
        pred = pred.clamp(min=0, max=5)
        return pred

    def loss_l2_regularization(self):
        """
        Computes L2 regularization loss for model parameters.
        """
        l2_loss = torch.sum(self.user_emb.weight**2)
        for name, param in self.named_parameters():
            if 'weight' in name and param.requires_grad:
                l2_loss += torch.sum(param**2)
        return self.l2_reg * l2_loss


## Model Training & Development

In [7]:
def train_epoch(model, loader, optimizer, device):
    """
    Trains the model for one epoch and computes training MSE.

    Parameters:
    - model (nn.Module): The GNN model.
    - loader (LinkNeighborLoader): Data loader for training.
    - optimizer (torch.optim.Optimizer): Optimizer.
    - device (torch.device): Device to run computations on.

    Returns:
    - average_mse (float): Average MSE over the training set.
    """
    model.train()
    mse_sum = 0.0
    count = 0

    for batch in tqdm(loader, desc="Training", unit="batch", leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Prepare node features
        x_dict = {
            'user': torch.arange(batch['user'].num_nodes, device=device),
            'recipe': batch['recipe'].x
        }

        # Forward pass
        out_dict = model(x_dict, batch.edge_index_dict)

        # Extract edge information
        edge = batch['user', 'rates', 'recipe']
        user_emb = out_dict['user'][edge.edge_label_index[0]]
        recipe_emb = out_dict['recipe'][edge.edge_label_index[1]]

        # Prediction and target
        pred = model.predict(user_emb, recipe_emb)  # [batch_size, 1]
        target = edge.edge_label.float().view_as(pred)  # [batch_size, 1]

        # Compute loss
        loss = F.mse_loss(pred, target) + model.loss_l2_regularization()
        loss.backward()
        optimizer.step()

        # Accumulate MSE
        mse_sum += F.mse_loss(pred, target, reduction='sum').item()
        count += target.size(0)

    average_mse = mse_sum / count
    return average_mse


@torch.no_grad()
def evaluate_mse(model, loader, device):
    """
    Evaluates the model on the given data loader and computes MSE.

    Parameters:
    - model (nn.Module): The GNN model.
    - loader (LinkNeighborLoader): Data loader for evaluation.
    - device (torch.device): Device to run computations on.

    Returns:
    - average_mse (float): Average MSE over the evaluation set.
    """
    model.eval()
    mse_sum = 0.0
    count = 0

    for batch in tqdm(loader, desc="Evaluating", unit="batch", leave=False):
        batch = batch.to(device)

        # Prepare node features
        x_dict = {
            'user': torch.arange(batch['user'].num_nodes, device=device),
            'recipe': batch['recipe'].x
        }

        # Forward pass
        out_dict = model(x_dict, batch.edge_index_dict)

        # Extract edge information
        edge = batch['user', 'rates', 'recipe']
        user_emb = out_dict['user'][edge.edge_label_index[0]]
        recipe_emb = out_dict['recipe'][edge.edge_label_index[1]]

        # Prediction and target
        pred = model.predict(user_emb, recipe_emb)  # [batch_size, 1]
        target = edge.edge_label.float().view_as(pred)  # [batch_size, 1]

        # Accumulate MSE
        mse_sum += F.mse_loss(pred, target, reduction='sum').item()
        count += target.size(0)

    average_mse = mse_sum / count
    return average_mse


def create_link_neighbor_loader(data, edge_type, batch_size, num_neighbors, shuffle, num_workers):
    """
    Creates a LinkNeighborLoader for the specified edge type.

    Parameters:
    - data (HeteroData): The input graph data.
    - edge_type (tuple): The edge type for link prediction (e.g., ('user', 'rates', 'recipe')).
    - batch_size (int): Number of samples per batch.
    - num_neighbors (list): Number of neighbors to sample for each layer.
    - shuffle (bool): Whether to shuffle the data.
    - num_workers (int): Number of subprocesses for data loading.

    Returns:
    - LinkNeighborLoader: Configured data loader.
    """
    edge_label_index = data[edge_type].edge_label_index
    edge_label = data[edge_type].edge_label

    loader = LinkNeighborLoader(
        data=data,
        num_neighbors=num_neighbors,
        edge_label_index=(edge_type, edge_label_index),
        edge_label=edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return loader


# Import necessary libraries
import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from tqdm import tqdm

# Assuming LGConvWrapper, HeteroGNN, train_epoch, evaluate_mse, and create_link_neighbor_loader are defined as above

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model configuration
model_type = 'lightgcn'  # Options: 'sage', 'gat', 'lightgcn'
hidden_channels = 128
num_layers = 3
dropout = 0.5
l2_reg = 1e-5

# Initialize the model with the entire train_graph.
model = HeteroGNN(
    graph=train_graph, 
    model_type=model_type,
    hidden_channels=hidden_channels, 
    num_layers=num_layers, 
    dropout=dropout,
    l2_reg=l2_reg
).to(device)

print(model)

# Define edge type for link prediction
edge_type = ('user', 'rates', 'recipe')
batch_size = 512  # Adjust based on your GPU memory
num_neighbors = [10] + [5] * (num_layers - 1)
num_workers = 3

# Create loaders
train_loader = create_link_neighbor_loader(
    data=train_graph,
    edge_type=edge_type,
    batch_size=batch_size,
    num_neighbors=num_neighbors,
    shuffle=True,
    num_workers=num_workers
)

val_loader = create_link_neighbor_loader(
    data=val_graph,
    edge_type=edge_type,
    batch_size=batch_size,
    num_neighbors=num_neighbors,
    shuffle=False,        # No need to shuffle for evaluation
    num_workers=num_workers
)

# Optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Training loop
num_epochs = 20
best_val_mse = float('inf')

for epoch in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch:02d}/{num_epochs}")

    # Training phase
    train_mse = train_epoch(model, train_loader, optimizer, device)

    # Evaluation phase
    val_mse = evaluate_mse(model, val_loader, device)

    # Scheduler step based on validation MSE
    scheduler.step(val_mse)

    # Checkpointing
    if val_mse < best_val_mse:
        best_val_mse = val_mse
        torch.save(model.state_dict(), 'best_hetero_gnn.pth')
        print(f"New best model saved with Validation MSE: {val_mse:.4f}")

    print(f"Train MSE: {train_mse:.4f}, Validation MSE: {val_mse:.4f}")


HeteroGNN(
  (user_emb): Embedding(226570, 128)
  (recipe_norm): BatchNorm(3081, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (recipe_lin): Linear(in_features=3081, out_features=128, bias=True)
  (convs): ModuleList(
    (0-2): 3 x HeteroConv(num_relations=2)
  )
  (batch_norms): ModuleList(
    (0-2): 3 x Identity()
  )
  (dropout_layer): Dropout(p=0.5, inplace=False)
)





Epoch 01/20


                                                    

TypeError: LGConvWrapper.forward() missing 1 required positional argument: 'edge_index'

# Evaluation
Finally, we evaluate the performance of our model on the validation and test graphs using the **Root Mean Squared Error (RMSE)** and **@Recall@k** metrics. 
Although MSE is used as a loss function for training due to its efficient gradient properties, we use RMSE for evaluation because it provides error values in the same units as the target variable, making it more interpretable when assessing model performance.

In [8]:
@torch.no_grad()
def get_all_node_embeddings(model, data, device, node_types=['user', 'recipe']):
    """
    Computes node_embeddings for all specified node types using a full forward pass.

    Parameters:
    - model (nn.Module): The trained GNN model.
    - data (HeteroData): The entire graph data.
    - device (torch.device): The device to perform computations on.
    - node_types (list): List of node types to compute node_embeddings for.

    Returns:
    - node_embeddings_dict (dict): Dictionary mapping node types to their embeddings.
    """
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        # Prepare node features
        x_dict = {}
        for node_type in node_types:
            if node_type == 'user':
                # Use node indices as placeholders; model will replace them with embeddings
                x_dict[node_type] = torch.arange(data[node_type].num_nodes, device=device)
            else:
                # Use actual features for other node types
                x_dict[node_type] = data[node_type].x.to(device)

        # Forward pass
        out_dict = model(x_dict, data.edge_index_dict)

    node_embeddings_dict = {}
    for node_type in node_types:
        node_embeddings_dict[node_type] = out_dict[node_type]

    return node_embeddings_dict


node_embeddings = get_all_node_embeddings(
    model=model,
    data=train_graph,  # Ensure 'train_graph' includes all nodes
    device=device,
    node_types=['user', 'recipe']
)

node_embeddings

NameError: name 'model' is not defined

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.loader import LinkNeighborLoader
from tqdm import tqdm
from collections import defaultdict

@torch.no_grad()
def evaluate_by_rmse(model, data_loader, node_embeddings, device):
    model.eval()
    total_loss = 0.0
    total_edges = 0

    for batch in tqdm(data_loader, desc='Evaluating RMSE', leave=False):
        edge_type = ('user', 'rates', 'recipe')  # Adjust if your edge type differs
        edge_label_index = batch[edge_type].edge_label_index.to(device)  # [2, E]
        edge_label = batch[edge_type].edge_label.to(device).view(-1)  # [E]

        users = edge_label_index[0]    # [E]
        recipes = edge_label_index[1]  # [E]

        user_emb = node_embeddings['user'][users]        # [E, hidden_dim]
        recipe_emb = node_embeddings['recipe'][recipes]  # [E, hidden_dim]

        # Generate predictions
        pred = model.predict(user_emb, recipe_emb).squeeze()
        pred = pred.clamp(min=0, max=5)  # Clamp predictions

        # Compute MSE
        mse = F.mse_loss(pred, edge_label, reduction='sum').item()
        total_loss += mse
        total_edges += edge_label.size(0)

    average_rmse = (total_loss / total_edges) ** 0.5
    return average_rmse


@torch.no_grad()
def evaluate_by_recall_at_k(model, data_loader, node_embeddings, k, relevance_threshold, device):
    model.eval()
    user_predictions = defaultdict(list)
    user_true_items = defaultdict(set)

    for batch in tqdm(data_loader, desc=f"Evaluating Recall@{k}", leave=False):
        edge_type = ('user', 'rates', 'recipe')  # Adjust if your edge type differs
        edge_label_index = batch[edge_type].edge_label_index.to(device)  # [2, E]
        edge_label = batch[edge_type].edge_label.to(device).view(-1)  # [E]

        users = edge_label_index[0]    # [E]
        recipes = edge_label_index[1]  # [E]

        user_emb = node_embeddings['user'][users]        # [E, hidden_dim]
        recipe_emb = node_embeddings['recipe'][recipes]  # [E, hidden_dim]

        # Generate predictions
        pred = model.predict(user_emb, recipe_emb).squeeze()
        pred = pred.clamp(min=0, max=5)  # Clamp predictions

        # Populate user_predictions and user_true_items
        for i, user_id in enumerate(users.cpu().numpy()):
            user_predictions[user_id].append((pred[i].item(), recipes[i].item()))
            if edge_label[i].item() >= relevance_threshold:
                user_true_items[user_id].add(recipes[i].item())

    recalls = []
    for user_id in user_predictions:
        # Sort predictions by score in descending order and get top-k items
        sorted_predictions = sorted(user_predictions[user_id], key=lambda x: x[0], reverse=True)
        topk_recipes = {item for _, item in sorted_predictions[:k]}
        true_items = user_true_items[user_id]

        if true_items:
            recall = len(topk_recipes & true_items) / len(true_items)
            recalls.append(recall)

    recall_at_k = sum(recalls) / len(recalls) if recalls else 0.0
    return recall_at_k


def create_link_neighbor_loader(data, edge_type, batch_size, num_neighbors, shuffle, num_workers):
    if edge_type not in data.edge_types:
        raise ValueError(f"Edge type {edge_type} not found in the data.")

    edge_label_index = data[edge_type].edge_label_index
    edge_label = data[edge_type].edge_label

    loader = LinkNeighborLoader(
        data=data,
        num_neighbors=num_neighbors,
        edge_label_index=(edge_type, edge_label_index),
        edge_label=edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return loader

test_loader = create_link_neighbor_loader(
    data=test_graph,
    edge_type=edge_type,
    batch_size=batch_size,
    num_neighbors=num_neighbors,
    shuffle=False,        # No need to shuffle for evaluation
    num_workers=num_workers
)



# List of evaluation datasets and their corresponding loaders
evaluation_sets = [
    ('Validation', val_graph, val_loader),
    ('Test', test_graph, test_loader)
]

# Define standalone variables for recall evaluation
k = 5
rating_threshold = 4 

# Evaluation for Recall@K and RMSE
for data_split, graph, loader in evaluation_sets:
    # Compute RMSE.
    rmse = evaluate_by_rmse(
        model=model,
        data_loader=loader,
        node_embeddings=node_embeddings,
        device=device
    )

    # Compute Recall@K.
    recall = evaluate_by_recall_at_k(
        model=model,
        data_loader=loader,
        node_embeddings=node_embeddings,
        k=k,
        relevance_threshold=rating_threshold,
        device=device
    )
        
    print(f"{data_split} set: RMSE = {rmse:.4f}, Recall@{k} = {recall:.4f}")