# Training a GraphSAGE-based GNN Model for Food Recipe Recommendation

## Environment setup

In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.loader import LinkNeighborLoader

  from .autonotebook import tqdm as notebook_tqdm


## Graph dataset loading

In this step, we load the graphs already generated in the graph dataset generation step.    
Since generating graph datasets is time-consuming, we load the saved files here instead of regenerating them in each notebook.

In [2]:
def load_graph(file_path):
    """Load a serialized graph object from ``file_path``.

    Parameters:
    - file_path (str or path-like): Location of a file written with ``torch.save``.

    Returns:
    - The deserialized object (in this notebook, a PyG ``HeteroData`` graph).
    """
    # The graph files hold fully pickled Python objects rather than plain tensor
    # state dicts. PyTorch >= 2.6 defaults to ``weights_only=True``, which refuses
    # such objects, and earlier releases warn when the flag is left implicit.
    # Unpickling can execute arbitrary code, so only load files you trust.
    return torch.load(file_path, weights_only=False)

# Version of the pre-generated graph dataset; it selects the directory below.
dataset_version = 1
base_data_path = f"../data/graph/v{dataset_version}"

# Load the training and validation graphs saved by the dataset generation step.
train_data = load_graph(f"{base_data_path}/train_graph.pt")
val_data = load_graph(f"{base_data_path}/validation_graph.pt")
# The test graph is not loaded in this notebook:
# test_graph = load_graph(f"{base_data_path}/test_graph.pt")

# Bare expression: the notebook displays a summary of the training graph.
train_data

  return torch.load(file_path)


HeteroData(
  user={ num_nodes=226570 },
  recipe={ x=[231637, 3081] },
  (user, rates, recipe)={
    edge_index=[2, 770011],
    edge_label=[192502, 1],
    edge_label_index=[2, 192502],
  },
  (recipe, rev_rates, user)={ edge_index=[2, 770011] }
)

In [3]:
# Bare expression: the notebook displays a summary of the validation graph.
val_data

HeteroData(
  user={ num_nodes=226570 },
  recipe={ x=[231637, 3081] },
  (user, rates, recipe)={
    edge_index=[2, 962513],
    edge_label=[56618, 1],
    edge_label_index=[2, 56618],
  },
  (recipe, rev_rates, user)={ edge_index=[2, 962513] }
)

In [4]:
# Inspect the (user -> recipe) connectivity: row 0 indexes users, row 1 recipes.
train_data['user', 'rates', 'recipe'].edge_index

tensor([[  3106,    317,  16543,  ...,    541, 208023,    489],
        [211809,   6600, 109688,  ...,  62108,  96459, 200804]])

In [5]:
# Node types and edge types of the graph; `to_hetero` is given this same metadata.
train_data.metadata()

(['user', 'recipe'],
 [('user', 'rates', 'recipe'), ('recipe', 'rev_rates', 'user')])

## Model implementation

### Model Architecture Overview
This GNN model is designed for edge rating prediction in a recipe recommendation system. It uses the GraphSAGE architecture to encode node features into embeddings and a custom edge decoder to predict ratings between user and recipe nodes. The model consists of three main components: the encoder, the decoder, and the integration of both into a complete model for training and inference.

### GNNEncoder Class
**Purpose**: Encodes node features into embeddings using the GraphSAGE architecture.

**Structure**:
- `conv1`: First SAGE convolutional layer for initial feature transformation.
- `conv2`: Second SAGE convolutional layer for output embedding generation.

**Forward Pass**:
- Takes node features `x` and edge connections `edge_index` as input.
- Applies `conv1` followed by a ReLU activation.
- Applies `conv2` to output the final node embeddings.

### EdgeDecoder Class
**Purpose**: Decodes the node embeddings to predict edge labels, such as ratings between users and recipes.

**Structure**:
- `lin1`: A fully connected layer that combines node embeddings from both ends of an edge.
- `lin2`: A linear layer that outputs a scalar representing the predicted edge label (e.g., a rating).

**Forward Pass**:
- Extracts embeddings for connected nodes (e.g., user and recipe).
- Concatenates these embeddings and passes them through `lin1` with a ReLU activation.
- Outputs a single value through `lin2` representing the predicted edge rating.

### Model Class
**Purpose**: Integrates the encoder and decoder to create a complete GNN model.

**Structure**:
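- `user_embedding`: A learnable embedding table with one row per user; it supplies the input features for `user` nodes, which carry no feature matrix of their own.
- `encoder`: Instantiates the `GNNEncoder` and adapts it to heterogeneous graphs using `to_hetero`, allowing the model to handle different types of nodes and edges.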
- `decoder`: A custom `EdgeDecoder` for predicting edge labels based on embeddings.

**Forward Pass**:
- Accepts a dictionary of node features `x_dict`, an edge index dictionary `edge_index_dict`, and the edge label index `edge_label_index`.
- Passes `x_dict` and `edge_index_dict` to the encoder to generate node embeddings.
- Uses the `decoder` to predict edge labels from these embeddings.

**Execution Context**:
- The model is set to run on a GPU if available, with `hidden_channels` set to 32 for embedding dimensions.

In [6]:
class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder that turns node features into embeddings.

    Parameters:
    - hidden_channels (int): Output size of the first SAGE layer.
    - out_channels (int): Output size of the second SAGE layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): the source/target input sizes are left for PyG to infer.
        self.conv1 = SAGEConv(in_channels=(-1, -1), out_channels=hidden_channels)
        self.conv2 = SAGEConv(in_channels=(-1, -1), out_channels=out_channels)

    def forward(self, x, edge_index):
        """Apply ``conv1``, a ReLU, then ``conv2`` and return the result."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)


class EdgeDecoder(torch.nn.Module):
    """Two-layer MLP that scores (user, recipe) pairs from their embeddings.

    Parameters:
    - hidden_channels (int): Size of each node embedding; the first linear
      layer consumes the concatenation of two of them.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Predict one scalar per edge in ``edge_label_index``.

        Parameters:
        - z_dict (dict): Embeddings keyed by node type; the 'user' and
          'recipe' entries are read.
        - edge_label_index (Tensor): Two rows of indices, users first and
          recipes second.

        Returns:
        - Tensor: Flattened 1-D tensor with one prediction per edge.
        """
        user_idx, recipe_idx = edge_label_index
        user_z = z_dict['user'][user_idx]
        recipe_z = z_dict['recipe'][recipe_idx]

        # Concatenate both endpoint embeddings along the feature axis.
        pair_features = torch.cat((user_z, recipe_z), dim=-1)
        hidden = torch.relu(self.lin1(pair_features))
        scores = self.lin2(hidden)
        return scores.view(-1)


class Model(torch.nn.Module):
    """GraphSAGE encoder plus edge decoder for user -> recipe rating prediction.

    Parameters:
    - hidden_channels (int): Size of the user embeddings and of every hidden
      and output representation in the encoder and decoder.
    - num_users (int): Number of rows in the learnable user embedding table.
    - metadata (tuple, optional): ``(node_types, edge_types)`` describing the
      heterogeneous graph, as returned by ``HeteroData.metadata()``. When
      omitted, the module-level ``train_data.metadata()`` is used, which is
      the behavior this class always had.
    """

    def __init__(self, hidden_channels, num_users, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible default: rely on the notebook-level graph.
            metadata = train_data.metadata()

        # User nodes carry no feature matrix, so their inputs are learned.
        self.user_embedding = torch.nn.Embedding(num_users, hidden_channels)
        # Replicate the homogeneous encoder per edge type; messages reaching
        # the same node type from different edge types are summed.
        self.encoder = to_hetero(
            GNNEncoder(hidden_channels, hidden_channels), metadata, aggr='sum'
        )
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index, user_n_id=None):
        """Encode the graph and predict a rating for each labelled edge.

        Parameters:
        - x_dict (dict): Node features keyed by node type. Any 'user' entry is
          replaced by learned embeddings; the caller's dict is not modified.
        - edge_index_dict (dict): Edge indices keyed by edge type.
        - edge_label_index (Tensor): Two rows (user, recipe) of node indices
          for the edges to score.
        - user_n_id (Tensor, optional): Global ids of the user nodes present
          in this (sub)graph, in local-index order. Pass it for sampled
          mini-batches so that local index ``i`` receives the embedding of
          the right global user (presumably ``batch['user'].n_id`` from a
          PyG neighbor loader; verify against the PyG version in use).
          When omitted, the full embedding table is used, which is only
          aligned when local and global user indices coincide.

        Returns:
        - Tensor: 1-D tensor with one predicted rating per labelled edge.
        """
        if user_n_id is None:
            user_x = self.user_embedding.weight
        else:
            user_x = self.user_embedding(user_n_id)

        # Build a new dict instead of writing into the caller's x_dict.
        x_dict = {**x_dict, 'user': user_x}

        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The user embedding table is sized from the number of user nodes in the graph.
num_users = train_data['user']['num_nodes']

model = Model(hidden_channels=32, num_users=num_users).to(device)

# Bare expression: the notebook displays the module structure.
model

Model(
  (user_embedding): Embedding(226570, 32)
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
  )
  (decoder): EdgeDecoder(
    (lin1): Linear(in_features=64, out_features=32, bias=True)
    (lin2): Linear(in_features=32, out_features=1, bias=True)
  )
)

## Model Training

In this step, we train the Graph Neural Network (GNN) model to optimize for the edge rating prediction task, where we predict ratings from users to recipes.

### Mini-Batching
Since our graph dataset is very large, we need to perform mini-batching to manage memory and computational resources effectively.
Using PyG’s `LinkNeighborLoader`, we divide the data into smaller, manageable batches, allowing the model to process subsets of the graph at each step.  
This loader samples neighboring nodes and edges for each target edge in the batch, focusing on the local neighborhood of each user-recipe interaction. This enables the model to capture relevant context without requiring the full graph in memory, making it ideal for efficient training with large datasets.

### Model Optimization
- **Optimizer**: We use `torch.optim.Adam` to adjust the model's parameters during training. The Adam optimizer effectively handles sparse gradients, helping the GNN learn efficiently from the graph data. The learning rate is set to 0.01 but can be modified during model tuning.

- **Loss Function**: Mean Squared Error (MSE) loss measures the difference between predicted and actual ratings. Given the continuous nature of ratings, MSE is a suitable choice for our link regression task and is calculated as:

  $$
  \text{MSE Loss} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
  $$

  where $y_i$ is the true rating, $\hat{y}_i$ is the predicted rating, and $N$ is the number of samples.


### Training Process
For each epoch, we perform the following steps:
  - **Batch Sampling**: The `train_loader` samples mini-batches of edges, allowing the model to focus on manageable portions of the data.
  - **Prediction and Loss Calculation**: The model predicts ratings for each mini-batch and calculates the Mean Squared Error (MSE) loss against actual ratings.
  - **Backpropagation and Parameter Update**: The optimizer updates the model’s parameters to minimize the loss.


### Validation
After each epoch, we also evaluate the model's performance on the validation set to check that it generalizes well.
Although MSE is used for training because of its efficient gradient properties, we use **Root Mean Squared Error (RMSE)** for evaluation since it provides error values in the same units as the target variable, making it more interpretable when assessing model performance.

In [7]:
def create_link_neighbor_loader(data, edge_type, batch_size=1024, num_neighbors=None, shuffle=True, num_workers=4):
    """
    Creates a LinkNeighborLoader for the specified edge type in a HeteroData object.

    Parameters:
    - data (HeteroData): The heterogeneous graph data.
    - edge_type (tuple): The edge type for which to create the loader, e.g., ('user', 'rates', 'recipe').
    - batch_size (int): Number of edges to include in each batch.
    - num_neighbors (list, optional): Number of neighbors to sample at each layer.
      Defaults to [10, 10] when omitted.
    - shuffle (bool): Whether to shuffle the data.
    - num_workers (int): Number of subprocesses to use for data loading.

    Returns:
    - loader (LinkNeighborLoader): The data loader for the specified edge type.

    Raises:
    - ValueError: If `edge_type` is not one of `data.edge_types`.
    """
    # Ensure the edge_type exists in the data
    if edge_type not in data.edge_types:
        raise ValueError(f"Edge type {edge_type} not found in the data.")

    # Build the default per call: a list literal in the signature would be a
    # single object shared (and mutable) across every invocation.
    if num_neighbors is None:
        num_neighbors = [10, 10]

    # Use the supervision edges when the split provides them; otherwise fall
    # back to every edge of this type, with no labels attached.
    edge_store = data[edge_type]
    edge_label_index = edge_store.get('edge_label_index', edge_store.edge_index)
    edge_label = edge_store.get('edge_label', None)

    # Create the LinkNeighborLoader
    loader = LinkNeighborLoader(
        data=data,
        num_neighbors=num_neighbors,
        edge_label_index=(edge_type, edge_label_index),
        edge_label=edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return loader


# Edge type whose ratings the model learns to predict
edge_type = ('user', 'rates', 'recipe')

# Settings shared by the training and validation loaders
common_loader_args = {
    'edge_type': edge_type,
    'batch_size': 1024,  # Adjust based on your GPU memory capacity
    'num_workers': 4,  # Adjust based on your system
}

# Training loader: edges are shuffled; 10 neighbors are sampled at each of 2 layers
train_loader = create_link_neighbor_loader(
    train_data, num_neighbors=[10, 10], shuffle=True, **common_loader_args
)

# Validation loader: same sampling settings, no shuffling
validation_loader = create_link_neighbor_loader(
    val_data, num_neighbors=[10, 10], shuffle=False, **common_loader_args
)

train_loader



LinkNeighborLoader()

In [None]:
# Adam over all parameters of the model, with a learning rate of 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    """Run one optimization pass over ``train_loader``.

    Returns:
    - float: Per-batch MSE losses weighted by batch size, summed and divided
      by ``len(train_loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0
    for batch in train_loader:
        batch = batch.to(device)
        rating_edges = batch['user', 'rates', 'recipe']

        optimizer.zero_grad()
        # NOTE(review): no sampled-node ids are passed to the model here, so
        # user features are not restricted to the users of this batch.
        # Confirm that this is intended.
        pred = model(batch.x_dict, batch.edge_index_dict, rating_edges.edge_label_index)
        # Flatten the labels so they line up with the 1-D predictions.
        target = rating_edges.edge_label.float().view(-1)
        batch_loss = F.mse_loss(pred, target)
        batch_loss.backward()
        optimizer.step()

        # Undo the per-batch mean so the final division averages over edges.
        weighted_loss_sum += batch_loss.item() * target.size(0)
    return weighted_loss_sum / len(train_loader.dataset)

@torch.no_grad()
def test(loader):
    """Compute the RMSE of clamped predictions over every batch of ``loader``.

    Parameters:
    - loader: Iterable of mini-batches exposing a ``dataset`` with a length.

    Returns:
    - float: Square root of the summed squared error divided by
      ``len(loader.dataset)``.
    """
    model.eval()
    squared_error_sum = 0
    for batch in loader:
        batch = batch.to(device)
        rating_edges = batch['user', 'rates', 'recipe']

        raw_pred = model(batch.x_dict, batch.edge_index_dict, rating_edges.edge_label_index)
        # Predictions are limited to the range [0, 5] before scoring.
        pred = raw_pred.clamp(min=0, max=5)
        # Flatten the labels so they line up with the 1-D predictions.
        target = rating_edges.edge_label.float().view(-1)
        squared_error_sum += F.mse_loss(pred, target, reduction='sum').item()

    mean_squared_error = squared_error_sum / len(loader.dataset)
    return mean_squared_error ** 0.5


# Training loop: epochs are numbered 1 through 49
last_epoch = 49
for epoch in range(1, last_epoch + 1):
    loss = train()
    # Score both splits with the evaluation routine after every epoch.
    train_rmse, val_rmse = test(train_loader), test(validation_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train RMSE: {train_rmse:.4f}, Validation RMSE: {val_rmse:.4f}')


Epoch: 001, Loss: 664013962.6901, Train RMSE: 4.1559, Val RMSE: 4.1458
Epoch: 002, Loss: 45404.5548, Train RMSE: 1.5144, Val RMSE: 1.5084
Epoch: 003, Loss: 2.6491, Train RMSE: 1.4118, Val RMSE: 1.4216
Epoch: 004, Loss: 2.1862, Train RMSE: 1.3855, Val RMSE: 1.3853
Epoch: 005, Loss: 2.0773, Train RMSE: 1.4049, Val RMSE: 1.4063
Epoch: 006, Loss: 2.1642, Train RMSE: 1.3710, Val RMSE: 1.3686
Epoch: 007, Loss: 2.0484, Train RMSE: 1.7870, Val RMSE: 1.7787
Epoch: 008, Loss: 2.2857, Train RMSE: 1.3196, Val RMSE: 1.3240
Epoch: 009, Loss: 1.7969, Train RMSE: 1.3078, Val RMSE: 1.3110
Epoch: 010, Loss: 1.7569, Train RMSE: 1.2921, Val RMSE: 1.2978
Epoch: 011, Loss: 1.7577, Train RMSE: 1.3083, Val RMSE: 1.3109
Epoch: 012, Loss: 1.7161, Train RMSE: 1.2879, Val RMSE: 1.2967
Epoch: 013, Loss: 1.8159, Train RMSE: 1.2942, Val RMSE: 1.3002
Epoch: 014, Loss: 1.6900, Train RMSE: 1.2763, Val RMSE: 1.2829
Epoch: 015, Loss: 1.6682, Train RMSE: 1.2718, Val RMSE: 1.2767
Epoch: 016, Loss: 1.6558, Train RMSE: 1.276