# Training a GraphSAGE-based GNN Model for Food Recipe Recommendation

## Environment setup

In [1]:
import gc

import torch
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.loader import LinkNeighborLoader



In [2]:
def flush():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()

flush()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Graph dataset loading

In this step, we load the graphs produced earlier in the graph dataset generation step. Since generating the graph datasets is time-consuming, we do not repeat that step in each notebook.

In [3]:
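def load_graph(file_path):
    # The saved files are pickled HeteroData objects rather than plain tensors, so full
    # unpickling is required (weights_only=False); only load files from a trusted source.
    return torch.load(file_path, weights_only=False)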

dataset_version = 1
base_data_path = f"../data/graph/v{dataset_version}"

train_graph = load_graph(f"{base_data_path}/train_graph.pt")
validation_graph = load_graph(f"{base_data_path}/validation_graph.pt")
test_graph = load_graph(f"{base_data_path}/test_graph.pt")

train_graph



HeteroData(
  user={ num_nodes=226570 },
  recipe={ x=[231637, 3081] },
  (user, rates, recipe)={
    edge_index=[2, 770011],
    edge_label=[192502, 1],
    edge_label_index=[2, 192502],
  },
  (recipe, rev_rates, user)={ edge_index=[2, 770011] }
)

In [4]:
print("Train graph information: ")
print("Number of nodes:", train_graph.num_nodes)
print("Number of edges:", train_graph.num_edges)
print("Metadata:", train_graph.metadata())
print("Edge index:", train_graph['user', 'rates', 'recipe'].edge_index)
print("Recipe embeddings dimension: ", train_graph['recipe'].x.size(1))

Train graph information: 
Number of nodes: 458207
Number of edges: 1540022
Metadata: (['user', 'recipe'], [('user', 'rates', 'recipe'), ('recipe', 'rev_rates', 'user')])
Edge index: tensor([[  3106,    317,  16543,  ...,    541, 208023,    489],
        [211809,   6600, 109688,  ...,  62108,  96459, 200804]])
Recipe embeddings dimension:  3081


## Model Implementation
In this step, we implement a heterogeneous GNN model for edge rating prediction in a recipe recommendation system. The model uses the GraphSAGE architecture to encode node features into embeddings and a custom edge decoder to predict the rating on each user-recipe edge. Its main components are the encoder, the decoder, and a wrapper model that adds learnable user embeddings, a recipe feature projection, and dropout regularization.

### GNNEncoder
This module encodes node features into embeddings using the GraphSAGE architecture.

**Structure**:
- `conv1`: The first SAGE convolutional layer, performing initial feature transformation.
- `conv2`: The second SAGE convolutional layer, producing the final node embeddings.

**Forward Pass**:
- Takes node features `x` and edge connections `edge_index` as input.
- Applies `conv1` with a ReLU activation for non-linearity.
- Applies `conv2` to generate the final node embeddings.
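The following pure-Python sketch shows one GraphSAGE step with mean aggregation on a bipartite graph. It assumes the update documented for PyG's `SAGEConv` with default arguments, x_i' = W_1 x_i + W_2 · mean_{j ∈ N(i)} x_j, and it leaves out the layer's learnable bias on the neighbor term. The function and argument names are illustrative only. After `to_hetero`, the encoder holds one such layer per edge type. Each node type here receives messages through a single edge type, so `aggr='sum'` has only one result to combine per node type.

```python
def matvec(w, v):
    """Multiply matrix w (a list of rows) by vector v."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def sage_mean_layer(x_src, x_dst, edges, w_self, w_neigh):
    """One bipartite GraphSAGE step with mean aggregation (no bias, no activation).

    edges: list of (src, dst) pairs; messages flow from x_src[src] to x_dst[dst].
    w_self projects each destination node's own features, w_neigh projects the
    mean of its neighbors' features.
    """
    dim = len(x_src[0])
    sums = [[0.0] * dim for _ in x_dst]
    counts = [0] * len(x_dst)
    for src, dst in edges:
        counts[dst] += 1
        for k in range(dim):
            sums[dst][k] += x_src[src][k]
    out = []
    for i, x_i in enumerate(x_dst):
        # A node without neighbors aggregates to the zero vector.
        mean = [s / counts[i] for s in sums[i]] if counts[i] else [0.0] * dim
        out.append([a + b for a, b in zip(matvec(w_self, x_i), matvec(w_neigh, mean))])
    return out


users = [[1.0, 0.0], [0.0, 1.0]]
recipes = [[1.0, 1.0], [2.0, 0.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
# Both users rate recipe 0; recipe 1 has no ratings.
print(sage_mean_layer(users, recipes, [(0, 0), (1, 0)], identity, identity))
# -> [[1.5, 1.5], [2.0, 0.0]]
```

In the encoder, `conv1` applies this kind of update and then a ReLU, and `conv2` applies it again to the result. Each node therefore ends up with information from its two-hop neighborhood.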

### EdgeDecoder
The EdgeDecoder serves as the prediction head, decoding the node embeddings to predict edge labels (ratings).

**Structure**:
- `lin1`: A fully connected layer that combines the embeddings from both nodes in an edge.
- `lin2`: A linear layer that outputs a scalar representing the predicted edge label, such as a rating.

**Forward Pass**:
- Extracts embeddings for connected nodes (e.g., user and recipe).
- Concatenates these embeddings and applies `lin1` with a ReLU activation.
- Passes the result through `lin2`, which outputs the predicted edge rating as a scalar.
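For a single edge, the decoder does the following: concatenate the two embeddings, apply a linear layer with ReLU, then apply a linear layer that produces one number. The pure-Python sketch below assumes `hidden_dim = 2` and uses hand-picked toy weights. The helper names are hypothetical, and the real module does the same computation for all edges in a batch at once.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def linear(w, b, v):
    """Affine map: w (list of rows) times v, plus bias b."""
    return [sum(wij * vj for wij, vj in zip(row, v)) + bi for row, bi in zip(w, b)]


def decode_rating(z_user, z_recipe, w1, b1, w2, b2):
    z = z_user + z_recipe          # list concatenation: length 2 * hidden_dim
    h = relu(linear(w1, b1, z))    # lin1 followed by ReLU: length hidden_dim
    return linear(w2, b2, h)[0]    # lin2: a single scalar rating


w1 = [[1.0, 0.0, 0.0, 1.0], [-1.0, 0.0, 0.0, 0.0]]
b1 = [0.0, 0.0]
w2 = [[1.0, 1.0]]
b2 = [0.5]
print(decode_rating([1.0, 0.0], [0.0, 2.0], w1, b1, w2, b2))  # -> 3.5
```

Here the second hidden unit is negative before the ReLU, so it is zeroed out. The prediction is therefore 3.0 + 0.5 = 3.5.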

### SageRecModel
The SageRecModel class integrates the encoder and decoder into a complete model tailored for edge rating prediction in a heterogeneous graph. It includes mechanisms for handling missing features for user nodes and regularizing the model during training.

**Structure**:
- **User Embeddings**: Users are represented with a learnable embedding layer (`user_embedding`) because they have no explicit features. The layer is initialized with Xavier-uniform weights whose rows are then L2-normalized, and it is updated during training.
- **Recipe Feature Transformation**: A linear transformation layer (`recipe_transform`) is applied to recipe features, preparing them for use in the encoder.
- **Encoder**: Instantiates the `GNNEncoder` and adapts it to heterogeneous graphs using `to_hetero`, allowing the model to process multiple types of nodes and edges.
- **Decoder**: A custom `EdgeDecoder` for predicting edge labels based on node embeddings.
- **Dropout Layer**: A dropout layer is applied to the transformed recipe embeddings to prevent overfitting and enhance generalization.
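A quick parameter count shows where the model's capacity sits. It uses the sizes printed for the model in the next cell (226570 users, 3081 recipe features, `hidden_dim = 32`). The sketch assumes PyG's default `SAGEConv` layout: a biased linear map on the aggregated neighbors plus an unbiased linear map on the node's own features. The helper names are hypothetical.

```python
def linear_params(in_features, out_features, bias=True):
    return in_features * out_features + (out_features if bias else 0)


def sage_conv_params(in_src, in_dst, out):
    # Assumed SAGEConv defaults: biased projection of the aggregated neighbors
    # plus an unbiased projection of the node's own features.
    return linear_params(in_src, out) + linear_params(in_dst, out, bias=False)


num_users, recipe_dim, hidden = 226570, 3081, 32
counts = {
    'user_embedding': num_users * hidden,
    'recipe_transform': linear_params(recipe_dim, hidden),
    # to_hetero duplicates conv1 and conv2 for both edge types: 4 layers of 32 -> 32.
    'encoder': 4 * sage_conv_params(hidden, hidden, hidden),
    'decoder': linear_params(2 * hidden, hidden) + linear_params(hidden, 1),
}
total = sum(counts.values())
for name, n in counts.items():
    print(f'{name:>16}: {n:>9,} ({n / total:.1%})')
print(f'{"total":>16}: {total:>9,}')
```

Under these assumptions, the model has about 7.36 million parameters, and the user embedding table accounts for roughly 98.5% of them. The per-user embeddings therefore dominate both memory use and optimizer state.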

**Forward Pass**:
- Takes a dictionary of node features `x_dict`, an edge index dictionary `edge_index_dict`, and the edge label index `edge_label_index`.
- L2-normalizes the rows of the `user` embedding table and uses them as the `user` node features.
- L2-normalizes the `recipe` features, projects them with `recipe_transform`, and applies ReLU and dropout.
- Passes `x_dict` and `edge_index_dict` to the encoder to produce node embeddings (`z_dict`).
- Uses the `decoder` to predict edge labels based on these embeddings, generating edge rating predictions.
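In a sampled mini-batch, `LinkNeighborLoader` relabels the nodes of each type to local indices `0..n-1`. As a result, `edge_index_dict` and `edge_label_index` refer to positions within the batch rather than to global user ids. A common way to build learnable user features for such a batch is to:

1. gather only the embedding rows of the sampled users, by their global ids, and
2. L2-normalize those rows.

The pure-Python sketch below illustrates this pattern. The helper names are hypothetical. It assumes that recent PyG loaders expose the global ids as `batch['user'].n_id`. The normalization mirrors `F.normalize(x, p=2, dim=1)`, which divides each row by `max(norm, eps)`.

```python
import math


def l2_normalize_rows(rows, eps=1e-12):
    """Divide each row by max(its L2 norm, eps), mirroring F.normalize(x, p=2, dim=1)."""
    normalized = []
    for row in rows:
        norm = math.sqrt(sum(a * a for a in row))
        normalized.append([a / max(norm, eps) for a in row])
    return normalized


def batch_user_features(embedding_table, global_ids):
    """Hypothetical helper: build user features for one sampled batch.

    Local user i in the batch is global user global_ids[i], so row i of the
    result is taken from embedding_table[global_ids[i]] and then normalized.
    """
    return l2_normalize_rows([embedding_table[g] for g in global_ids])


table = [[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]]  # toy table: 3 users, dimension 2
print(batch_user_features(table, global_ids=[2, 0]))  # -> [[0.0, 1.0], [0.6, 0.8]]
```

In PyTorch, the gather step corresponds to calling the embedding layer on an index tensor, for example `self.user_embedding(n_id)`. This assumes `n_id` holds the batch's global user ids.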

In [5]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_dim, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_dim)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)

        return x

class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(2 * hidden_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['recipe'][col]], dim=-1)
        z = self.lin1(z).relu()
        z = self.lin2(z)

        return z.view(-1)

class SageRecModel(torch.nn.Module):
    def __init__(self, hidden_dim, num_users, recipe_feature_dim, args):
        super().__init__()
        self.user_embedding = torch.nn.Embedding(num_users, hidden_dim)
        self.recipe_transform = torch.nn.Linear(recipe_feature_dim, hidden_dim)
        self.encoder = GNNEncoder(hidden_dim, hidden_dim)
        self.encoder = to_hetero(self.encoder, train_graph.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_dim)
        self.dropout = torch.nn.Dropout(p=args['dropout_rate'])
        self._initialize_embeddings()

    def _initialize_embeddings(self):
        torch.nn.init.xavier_uniform_(self.user_embedding.weight)
        self.user_embedding.weight.data = F.normalize(self.user_embedding.weight.data, p=2, dim=1)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        # Normalize user embeddings
        x_dict['user'] = F.normalize(self.user_embedding.weight, p=2, dim=1)

        # Normalize and transform recipe features and regularize.
        x_dict['recipe'] = F.normalize(x_dict['recipe'], p=2, dim=1)
        x_dict['recipe'] = self.recipe_transform(x_dict['recipe']).relu()
        x_dict['recipe'] = self.dropout(x_dict['recipe'])  # Apply dropout

        z_dict = self.encoder(x_dict, edge_index_dict)

        return self.decoder(z_dict, edge_label_index)

hidden_dim = 32 # Embedding dimension size.
recipe_feature_dim = train_graph['recipe'].x.size(1)
num_users = train_graph['user']['num_nodes']
args = {
    'dropout_rate': 0.5
}
model = SageRecModel(hidden_dim=hidden_dim, num_users=num_users, recipe_feature_dim=recipe_feature_dim, args=args).to(device)

print(model.parameters())

model

<generator object Module.parameters at 0x7f68e6157ae0>


SageRecModel(
  (user_embedding): Embedding(226570, 32)
  (recipe_transform): Linear(in_features=3081, out_features=32, bias=True)
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
  )
  (decoder): EdgeDecoder(
    (lin1): Linear(in_features=64, out_features=32, bias=True)
    (lin2): Linear(in_features=32, out_features=1, bias=True)
  )
  (dropout): Dropout(p=0.5, inplace=False)
)

## Model Training

In this step, we train the initialized `SageRecModel` on the edge rating prediction task, that is, predicting the rating a user gives a recipe.

### Mini-Batching
To keep memory and computation manageable on our large graphs, we train on mini-batches of target edges, which the model processes one after another. For each batch, PyG’s `LinkNeighborLoader` samples a limited number of neighbors around the endpoints of every target user-recipe edge (up to 15 in the first hop and 10 in the second, as configured below). The model thus sees the local neighborhood of each interaction without processing the entire graph at once, which bounds the size of every batch.
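The sketch below shows the idea behind this sampling on a toy homogeneous graph. It starts from both endpoints of a target edge and keeps at most `fanouts[h]` randomly chosen neighbors per node at hop `h`, sampling without replacement. This is a simplified illustration, not PyG's sampler. The real sampler works per edge type on the heterogeneous graph and also relabels the sampled nodes. The function name is hypothetical.

```python
import random


def sample_khop(adj, seeds, fanouts, rng):
    """Uniformly sample up to fanouts[h] neighbors of each frontier node at hop h.

    adj: dict mapping node -> list of neighbor nodes.
    Returns the set of nodes in the sampled computation subgraph.
    """
    visited = set(seeds)
    frontier = list(seeds)
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = adj.get(node, [])
            # Sample without replacement, at most `fanout` neighbors.
            for nb in rng.sample(neighbors, min(fanout, len(neighbors))):
                if nb not in visited:  # deduplicate nodes reached more than once
                    visited.add(nb)
                    next_frontier.append(nb)
        frontier = next_frontier
    return visited


adj = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 5], 4: [1], 5: [3]}
# Target edge (0, 1): both endpoints are seeds; fanouts mimic num_neighbors.
print(sorted(sample_khop(adj, seeds=[0, 1], fanouts=[2, 1], rng=random.Random(0))))
```

With `num_neighbors = [15, 10]`, each seed node can pull in at most 1 + 15 + 15 · 10 = 166 nodes before deduplication. Batch size therefore grows with the fanouts, not with the size of the full graph.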


### Optimization
- **Optimizer**: We optimize the model’s parameters with Adam, which adapts the step size of each parameter individually. Regularization is applied through weight decay (an L2 penalty added to the gradients), which discourages large weights and helps prevent overfitting.  
To further stabilize training, we use a `StepLR` learning rate scheduler that multiplies the learning rate by `gamma = 0.1` every 10 epochs. The smaller steps in later epochs reduce overshooting and allow finer adjustments as the model approaches convergence.


- **Loss Function**: The Mean Squared Error (MSE) loss measures the average squared difference between predicted and actual ratings. Because ratings are treated as continuous values, MSE is a natural choice for this link regression task. It is calculated as:

  $$
  \text{MSE Loss} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
  $$

  where $y_i$ is the true rating, $\hat{y}_i$ is the predicted rating, and $N$ is the number of samples.
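With `learning_rate = 0.001`, `step_size = 10`, and `gamma = 0.1`, the `StepLR` schedule can be written in closed form, assuming `scheduler.step()` runs once at the end of each epoch as in `train()`. The helper below is a hypothetical illustration of that arithmetic, not the PyTorch class.

```python
def step_lr(base_lr, epoch, step_size=10, gamma=0.1):
    """Learning rate in effect during `epoch` (1-based).

    Epoch e runs after (e - 1) scheduler steps, and the rate is multiplied
    by gamma once every `step_size` steps.
    """
    return base_lr * gamma ** ((epoch - 1) // step_size)


for epoch in (1, 10, 11, 20, 21, 30):
    print(epoch, step_lr(1e-3, epoch))
```

Under this schedule, epochs 1-10 train with a learning rate of 1e-3, epochs 11-20 with 1e-4, and epochs 21-30 with 1e-5.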


### Training Process
For each epoch, we perform the following steps:
  1. Batch sampling: The training loader yields mini-batches of target edges together with their sampled neighborhoods.
  2. Forward propagation: The model predicts a rating for each target edge, and the MSE loss is computed against the true ratings.
  3. Backward propagation: `loss.backward()` computes the gradients of the loss with respect to the model’s parameters.
  4. Parameter update: The optimizer updates the parameters using these gradients.
  5. Learning rate adjustment: At the end of each epoch, `scheduler.step()` advances the schedule, which lowers the learning rate every 10 epochs.
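The five steps can be traced in a toy setting. The pure-Python sketch below runs mini-batch gradient descent on a one-feature linear model, y ≈ w · x + b, with an MSE loss. It is a stand-in for the GNN, not the notebook's training code. The function name, data, and hyperparameters are assumptions chosen for illustration. Like `train()`, it weights each batch loss by the batch size so that the epoch loss is an average over samples.

```python
import random


def train_toy_regressor(data, epochs=30, batch_size=4, lr=0.3, step_size=10, gamma=0.1, seed=0):
    """Mini-batch gradient descent for y ~ w * x + b with an MSE loss."""
    rng = random.Random(seed)
    data = list(data)  # copy so shuffling does not mutate the caller's list
    w, b = 0.0, 0.0
    history = []
    for epoch in range(1, epochs + 1):
        rng.shuffle(data)  # 1. batch sampling: shuffled mini-batches
        total_loss = 0.0
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            n = len(batch)
            errors = [(w * x + b) - y for x, y in batch]  # 2. forward pass: prediction - target
            loss = sum(e * e for e in errors) / n  # batch MSE
            grad_w = 2 * sum(e * x for e, (x, _) in zip(errors, batch)) / n  # 3. backward pass
            grad_b = 2 * sum(errors) / n
            w -= lr * grad_w  # 4. parameter update
            b -= lr * grad_b
            total_loss += loss * n  # weight by batch size, as in train()
        history.append(total_loss / len(data))  # per-sample MSE for the epoch
        if epoch % step_size == 0:  # 5. StepLR-style learning rate decay
            lr *= gamma
    return w, b, history


points = [(i / 39, 2 * (i / 39) + 1) for i in range(40)]
w, b, history = train_toy_regressor(points)
print(f'w={w:.3f}, b={b:.3f}, first epoch MSE={history[0]:.4f}, last epoch MSE={history[-1]:.4f}')
```

In the GNN, `loss.backward()` replaces the hand-derived gradients and `optimizer.step()` replaces the manual update. The loop structure stays the same.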

In [6]:
def create_link_neighbor_loader(data, edge_type, batch_size=1024, num_neighbors=[10, 10], shuffle=True, num_workers=4):
    """
    Creates a LinkNeighborLoader for the specified edge type in a HeteroData object.

    Parameters:
    - data (HeteroData): The heterogeneous graph data.
    - edge_type (tuple): The edge type for which to create the loader, e.g., ('user', 'rates', 'recipe').
    - batch_size (int): Number of edges to include in each batch.
    - num_neighbors (list): Number of neighbors to sample at each layer.
    - shuffle (bool): Whether to shuffle the data.
    - num_workers (int): Number of subprocesses to use for data loading.

    Returns:
    - loader (LinkNeighborLoader): The data loader for the specified edge type.
    """
    # Ensure the edge_type exists in the data
    if edge_type not in data.edge_types:
        raise ValueError(f"Edge type {edge_type} not found in the data.")

    # Access the edge_label_index and edge_label for the specified edge type
    edge_label_index = data[edge_type].get('edge_label_index', data[edge_type].edge_index)
    edge_label = data[edge_type].get('edge_label', None)

    # Create the LinkNeighborLoader
    loader = LinkNeighborLoader(
        data=data,
        num_neighbors=num_neighbors,
        edge_label_index=(edge_type, edge_label_index),
        edge_label=edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return loader


edge_type = ('user', 'rates', 'recipe') # Define the edge type of interest
batch_size = 1024  # Adjust based on your GPU memory capacity
num_neighbors = [15, 10] # Number of neighbors to sample at each layer
num_workers = 4  # Adjust based on your system

# Create the training data loader
train_data_loader = create_link_neighbor_loader(
    data=train_graph,
    edge_type=edge_type,
    batch_size=batch_size, 
    num_neighbors=num_neighbors, 
    shuffle=True,
    num_workers=num_workers 
)

weight_decay = 0.0001
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Implement a learning rate scheduler
step_size = 10
gamma = 0.1
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)



In [7]:
def train(model, data_loader, optimizer, scheduler):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc='Training', unit='batch', leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch.x_dict, batch.edge_index_dict, batch['user', 'rates', 'recipe'].edge_label_index)
        # Flatten target to match pred.
        target = batch['user', 'rates', 'recipe'].edge_label.float().view(-1)
        loss = F.mse_loss(pred, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * target.size(0)
    
    # Compute average loss (MSE) per data point (edge). 
    mse = total_loss / len(data_loader.dataset)

    # Step the scheduler to update the learning rate
    scheduler.step()

    return mse


# Training loop
num_epochs = 30
for epoch in range(1, num_epochs + 1):
    loss = train(model, train_data_loader, optimizer, scheduler)
    print(f'Epoch: {epoch:03d}, Loss (MSE): {loss:.4f}')

flush()

Epoch: 001, Loss (MSE): 4.4525
Epoch: 002, Loss (MSE): 1.5808
Epoch: 003, Loss (MSE): 1.5315
Epoch: 004, Loss (MSE): 1.5153
Epoch: 005, Loss (MSE): 1.5114
Epoch: 006, Loss (MSE): 1.5098
Epoch: 007, Loss (MSE): 1.5078
Epoch: 008, Loss (MSE): 1.5071
Epoch: 009, Loss (MSE): 1.5059
Epoch: 010, Loss (MSE): 1.5043
Epoch: 011, Loss (MSE): 1.5032
Epoch: 012, Loss (MSE): 1.5034
Epoch: 013, Loss (MSE): 1.5021
Epoch: 014, Loss (MSE): 1.5025
Epoch: 015, Loss (MSE): 1.5023
Epoch: 016, Loss (MSE): 1.5018
Epoch: 017, Loss (MSE): 1.5014
Epoch: 018, Loss (MSE): 1.5011
Epoch: 019, Loss (MSE): 1.5007
Epoch: 020, Loss (MSE): 1.5009
Epoch: 021, Loss (MSE): 1.4999
Epoch: 022, Loss (MSE): 1.5003
Epoch: 023, Loss (MSE): 1.5002
Epoch: 024, Loss (MSE): 1.5004
Epoch: 025, Loss (MSE): 1.5011
Epoch: 026, Loss (MSE): 1.5002
Epoch: 027, Loss (MSE): 1.5005
Epoch: 028, Loss (MSE): 1.5001
Epoch: 029, Loss (MSE): 1.5007
Epoch: 030, Loss (MSE): 1.4997

## Evaluation
Finally, we evaluate the model on the validation and test graphs using the **Root Mean Squared Error (RMSE)** metric. MSE serves as the training loss because its gradient is simple and smooth. For evaluation we report RMSE instead, because it is expressed in the same units as the ratings and is therefore easier to interpret.

We also report the RMSE on the training graph so that it can be compared with the validation and test errors to check for overfitting.
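`evaluate_by_rmse` works in three steps:

1. It clamps predictions to [0, 5], which can only reduce the error when the true ratings lie in that range.
2. It sums squared errors over all batches with `reduction='sum'`.
3. It divides by the number of evaluated edges before taking the square root.

The result is the RMSE over the whole split. Averaging per-batch RMSEs would give a different, unevenly weighted value, especially when the last batch is smaller. The pure-Python sketch below uses a hypothetical helper name and toy numbers.

```python
import math


def rmse_over_batches(batches, low=0.0, high=5.0):
    """Global RMSE over all edges, accumulated batch by batch.

    batches: iterable of (predictions, targets) pairs. Each prediction is clamped
    to [low, high] first, like pred.clamp(min=0, max=5) in evaluate_by_rmse.
    """
    squared_error_sum, count = 0.0, 0
    for preds, targets in batches:
        for p, y in zip(preds, targets):
            p = min(max(p, low), high)
            squared_error_sum += (p - y) ** 2  # like F.mse_loss(..., reduction='sum')
            count += 1
    return math.sqrt(squared_error_sum / count)  # mean over all edges, then square root


batches = [([5.5, 3.0], [5.0, 4.0]), ([1.0], [3.0])]
print(rmse_over_batches(batches))  # sqrt(5 / 3): errors 0 (after clamping), 1, and 2
mean_of_batch_rmses = (math.sqrt(1 / 2) + 2.0) / 2  # averaging per-batch RMSEs instead
print(mean_of_batch_rmses)
```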

In [8]:
validation_data_loader = create_link_neighbor_loader(
    data=validation_graph,
    edge_type=edge_type,
    batch_size=batch_size, 
    num_neighbors=num_neighbors, 
    shuffle=False, # No need to shuffle during validation
    num_workers=num_workers 
)

test_data_loader = create_link_neighbor_loader(
    data=test_graph,
    edge_type=edge_type,
    batch_size=batch_size, 
    num_neighbors=num_neighbors, 
    shuffle=False, # No need to shuffle during testing
    num_workers=num_workers 
)

In [9]:
@torch.no_grad()
def evaluate_by_rmse(data_loader):
    model.eval()
    total_loss = 0
    for batch in tqdm(data_loader, desc='Evaluating', unit='batch', leave=False):
        batch = batch.to(device)
        pred = model(batch.x_dict, batch.edge_index_dict, batch['user', 'rates', 'recipe'].edge_label_index)
        # Clamp predictions to the valid rating range [0, 5].
        pred = pred.clamp(min=0, max=5)
        # Flatten target to match pred
        target = batch['user', 'rates', 'recipe'].edge_label.float().view(-1)  
        loss = F.mse_loss(pred, target, reduction='sum')
        total_loss += loss.item()

    rmse = (total_loss / len(data_loader.dataset)) ** 0.5
    return rmse

evaluation_data_loaders = {
    'Train': train_data_loader,
    'Validation': validation_data_loader,
    'Test': test_data_loader
}
evaluation_results = {}
for data_split_name, data_loader in evaluation_data_loaders.items():
    evaluation_results[data_split_name] = evaluate_by_rmse(data_loader)
    print(f'{data_split_name} RMSE: {evaluation_results[data_split_name]:.4f}')
    flush()


Train RMSE: 1.2244
Validation RMSE: 1.2331
Test RMSE: 1.2175
