## GraphSAGE-based Recommender Model (SageRecModel) for Food Recipe Recommendation (V0)

## Environment setup

In [1]:
import gc

import torch
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.loader import LinkNeighborLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def flush():
  """Release cached memory between heavy notebook steps.

  Runs Python garbage collection and, when a CUDA device is available,
  empties the CUDA caching allocator and resets its peak-memory statistics.
  On hosts without CUDA only the garbage collection step runs, so the helper
  can be called safely on both GPU and CPU-only machines.
  """
  gc.collect()
  # The CUDA bookkeeping calls can fail when no CUDA device is present,
  # so they are only issued when CUDA is actually available.
  if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Start from a clean memory state before anything is loaded.
flush()

# Prefer the GPU when one is present; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

## Graph dataset loading

In this step, we load the graphs already generated in the graph dataset generation step.    
Since generating the graph datasets is time-consuming, we do not regenerate them in each notebook.

In [3]:
def load_graph(file_path):
    """Load a serialized graph object from disk.

    Parameters:
    - file_path: Path of the serialized graph file to read.

    Returns:
    - The deserialized object (in this notebook, a PyG `HeteroData` graph).
    """
    # The graph files hold full pickled Python objects rather than bare tensors,
    # so full unpickling is requested explicitly instead of relying on the
    # `torch.load` default, which differs between PyTorch releases.
    # SECURITY: full unpickling can execute arbitrary code; only load graph
    # files that come from a trusted source.
    return torch.load(file_path, weights_only=False)

# Version of the pre-generated graph dataset and the directory that holds it.
dataset_version = 1
base_data_path = f"../data/graph/v{dataset_version}"

# Load the three splits in a fixed order: train, validation, test.
train_graph, validation_graph, test_graph = (
    load_graph(f"{base_data_path}/{split_name}_graph.pt")
    for split_name in ("train", "validation", "test")
)

# Show the training graph summary as the cell output.
train_graph

  return torch.load(file_path)


HeteroData(
  user={ num_nodes=226570 },
  recipe={ x=[231637, 3081] },
  (user, rates, recipe)={
    edge_index=[2, 770011],
    edge_label=[192502, 1],
    edge_label_index=[2, 192502],
  },
  (recipe, rev_rates, user)={ edge_index=[2, 770011] }
)

In [4]:
# Quick summary of the training graph: size, schema and the rating edges.
print("Train graph information: ")
print("Number of nodes:", train_graph.num_nodes)
print("Number of edges:", train_graph.num_edges)
print("Metadata:", train_graph.metadata())
print("Edge index:", train_graph['user', 'rates', 'recipe'].edge_index)

Train graph information: 
Numbder of nodes: 458207
Numbder of edges: 1540022
Metadata: (['user', 'recipe'], [('user', 'rates', 'recipe'), ('recipe', 'rev_rates', 'user')])
Edge index: tensor([[  3106,    317,  16543,  ...,    541, 208023,    489],
        [211809,   6600, 109688,  ...,  62108,  96459, 200804]])


## Model implementation
In this step, we implement a heterogeneous GNN model for the edge rating prediction task in a recipe recommendation system. 
Our model uses the GraphSAGE architecture to encode node features into embeddings and a custom edge decoder to predict ratings between user and recipe nodes. The model consists of three main components: the encoder, the decoder, and the integration of both into a complete model for training and inference.

### GNNEncoder
It encodes node features into embeddings using the GraphSAGE architecture.

**Structure**:
- `conv1`: First SAGE convolutional layer for initial feature transformation.
- `conv2`: Second SAGE convolutional layer for output embedding generation.

**Forward Pass**:
- Takes node features `x` and edge connections `edge_index` as input.
- Applies `conv1` followed by a ReLU activation.
- Applies `conv2` to output the final node embeddings.

### EdgeDecoder
It acts as a prediction head that predicts edge labels (ratings) by decoding the node embeddings.

**Structure**:
- `lin1`: A fully connected layer that combines node embeddings from both ends of an edge.
- `lin2`: A linear layer that outputs a scalar representing the predicted edge label (e.g., a rating).

**Forward Pass**:
- Extracts embeddings for connected nodes (e.g., user and recipe).
- Concatenates these embeddings and passes them through `lin1` with a ReLU activation.
- Outputs a single value through `lin2` representing the predicted edge rating.

### SageRecModel
It integrates the encoder and decoder to create a complete GNN model for edge rating prediction.

**Structure**:
- `Encoder`: Instantiates the `GNNEncoder` and adapts it to heterogeneous graphs using `to_hetero`, allowing the model to handle different types of nodes and edges.
- `Decoder`: A custom `EdgeDecoder` for predicting edge labels based on embeddings.

**User Embeddings**:
Due to the lack of user features, user embeddings are randomly initialized and then learned during training to capture user-specific patterns in interactions.

**Forward Pass**:
- Accepts a dictionary of node features `x_dict`, an edge index dictionary `edge_index_dict`, and the edge label index `edge_label_index`.
- Passes `x_dict` and `edge_index_dict` to the encoder to generate node embeddings.
- Uses the `decoder` to predict edge labels from these embeddings.

In [5]:
class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder that turns node features into embeddings.

    Parameters:
    - hidden_channels (int): Output size of the first SAGE layer.
    - out_channels (int): Output size of the second SAGE layer (embedding size).
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) leaves the source/target input sizes unspecified here.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the resulting embeddings."""
        hidden = self.conv1(x, edge_index).relu()
        embeddings = self.conv2(hidden, edge_index)
        return embeddings


class EdgeDecoder(torch.nn.Module):
    """Prediction head that scores (user, recipe) pairs from node embeddings.

    Parameters:
    - hidden_channels (int): Size of each node embedding; the first linear
      layer consumes the concatenation of two such embeddings.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # Two-layer MLP: concatenated pair embedding -> hidden -> scalar.
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        """Return one predicted value per edge in `edge_label_index` as a 1-D tensor."""
        user_idx, recipe_idx = edge_label_index
        user_z = z_dict['user'][user_idx]
        recipe_z = z_dict['recipe'][recipe_idx]
        pair_z = torch.cat([user_z, recipe_z], dim=-1)

        hidden = self.lin1(pair_z).relu()
        scores = self.lin2(hidden)
        return scores.view(-1)


class SageRecModel(torch.nn.Module):
    """GraphSAGE-based rating predictor for (user, rates, recipe) edges.

    Combines a learnable user embedding table, a heterogeneous GraphSAGE
    encoder (a `GNNEncoder` converted with `to_hetero`) and an `EdgeDecoder`.

    Parameters:
    - hidden_channels (int): Embedding size used by every component.
    - num_users (int): Number of rows of the user embedding table.
    - metadata (tuple, optional): Graph metadata `(node_types, edge_types)`
      passed to `to_hetero`. When omitted, the notebook-level
      `train_graph.metadata()` is used, which matches the previous behavior.
    """

    def __init__(self, hidden_channels, num_users, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible default: fall back to the notebook's training graph.
            metadata = train_graph.metadata()
        self.user_embedding = torch.nn.Embedding(num_users, hidden_channels)
        self.encoder = to_hetero(GNNEncoder(hidden_channels, hidden_channels), metadata, aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index, user_n_id=None):
        """Predict one rating per edge in `edge_label_index`.

        Parameters:
        - x_dict (dict): Node features per node type; the 'user' entry is
          supplied by this model and any existing one is ignored.
        - edge_index_dict (dict): Edge indices per edge type.
        - edge_label_index (Tensor): Endpoints (user row, recipe row) of the
          edges to score.
        - user_n_id (Tensor, optional): Global user ids of the user nodes in
          the given (sub)graph, in local node order. Pass this when the inputs
          come from a sampled mini-batch, whose node indices are local to the
          batch, so each local user is mapped to its own embedding row (with
          PyG neighbor loaders this is typically `batch['user'].n_id` --
          confirm for the installed version). When omitted, the whole
          embedding table is used in row order, which is only correct if local
          user indices equal global user ids (e.g. full-graph inputs).

        Returns:
        - Tensor: 1-D tensor of predicted ratings.
        """
        # Shallow copy so the caller's dictionary is not modified.
        x_dict = dict(x_dict)
        if user_n_id is None:
            x_dict['user'] = self.user_embedding.weight
        else:
            x_dict['user'] = self.user_embedding(user_n_id)

        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)


num_users = train_graph['user']['num_nodes']
# Embedding dimension size.
hidden_channels = 32
model = SageRecModel(hidden_channels=hidden_channels, num_users=num_users).to(device)

# Show the module tree as the cell output. (Printing `model.parameters()`
# only shows a generator object, so that debug print was dropped.)
model

<generator object Module.parameters at 0x7f7244129070>


SageRecModel(
  (user_embedding): Embedding(226570, 32)
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
  )
  (decoder): EdgeDecoder(
    (lin1): Linear(in_features=64, out_features=32, bias=True)
    (lin2): Linear(in_features=32, out_features=1, bias=True)
  )
)

## Model training

In this step, we train the Graph Neural Network (GNN) model to optimize for the edge rating prediction task, where we predict ratings from users to recipes.

### Mini-batching
Since our graph dataset is very large, we need to perform mini-batching to manage memory and computational resources effectively.
Using PyG’s `LinkNeighborLoader`, we divide the data into smaller, manageable batches, allowing the model to process subsets of the graph at each step.  
This loader samples neighboring nodes and edges for each target edge in the batch, focusing on the local neighborhood of each user-recipe interaction. This enables the model to capture relevant context without requiring the full graph in memory, making it ideal for efficient training with large datasets.

### Optimization
- **Optimizer**: We use `torch.optim.Adam` to adjust the model's parameters during training. The Adam optimizer effectively handles sparse gradients, helping the GNN learn efficiently from the graph data. The learning rate is set to 0.01 but can be modified during model tuning.

- **Loss Function**: Mean Squared Error (MSE) loss measures the difference between predicted and actual ratings. Given the continuous nature of ratings, MSE is a suitable choice for our link regression task and is calculated as:

  $$
  \text{MSE Loss} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
  $$

  where $y_i$ is the true rating, $\hat{y}_i$ is the predicted rating, and $N$ is the number of samples.


### Training process
For each epoch, we perform the following steps:
  - **Batch sampling**: The `train_data_loader` samples mini-batches of edges, allowing the model to process manageable portions of the data.
  - **Forward propagation**: The model predicts ratings for each mini-batch of edges and calculates the Mean Squared Error (MSE) loss between the predicted and actual ratings.
  - **Backward propagation**: Calculates the gradients of the loss with respect to the model’s parameters, including embeddings.
  - **Parameter update**: The optimizer uses these gradients to update the model’s parameters, refining the embeddings and other learnable parameters to minimize the loss.

In [6]:
def create_link_neighbor_loader(data, edge_type, batch_size=1024, num_neighbors=[10, 10], shuffle=True, num_workers=4):
    """
    Build a LinkNeighborLoader over the labelled edges of one edge type.

    Parameters:
    - data (HeteroData): Heterogeneous graph to sample from.
    - edge_type (tuple): Edge type to iterate over, e.g., ('user', 'rates', 'recipe').
    - batch_size (int): Number of edges per mini-batch.
    - num_neighbors (list): Neighbors to sample per layer.
    - shuffle (bool): Whether the edges are visited in random order.
    - num_workers (int): Number of data-loading subprocesses.

    Returns:
    - loader (LinkNeighborLoader): Loader yielding sampled subgraphs for `edge_type`.

    Raises:
    - ValueError: If `edge_type` is not among `data.edge_types`.
    """
    known_edge_types = data.edge_types
    if edge_type not in known_edge_types:
        raise ValueError(f"Edge type {edge_type} not found in the data.")

    store = data[edge_type]
    # Fall back to the stored `edge_index` when no explicit `edge_label_index`
    # exists; labels are optional and default to None.
    fallback_index = store.edge_index
    label_index = store.get('edge_label_index', fallback_index)
    labels = store.get('edge_label', None)

    return LinkNeighborLoader(
        data=data,
        num_neighbors=num_neighbors,
        edge_label_index=(edge_type, label_index),
        edge_label=labels,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


# Loader configuration.
edge_type = ('user', 'rates', 'recipe')  # edge type whose ratings are predicted
batch_size = 1024  # tune to the available GPU memory
num_neighbors = [10, 10]  # neighbors sampled per layer
num_workers = 4  # tune to the machine

# Mini-batch loader over the training edges (shuffled every epoch).
train_data_loader = create_link_neighbor_loader(
    train_graph,
    edge_type,
    batch_size=batch_size,
    num_neighbors=num_neighbors,
    shuffle=True,
    num_workers=num_workers,
)

# Adam optimizer over all model parameters.
learning_rate = 0.01
optimizer = optim.Adam(model.parameters(), lr=learning_rate)



In [7]:
def train(model, data_loader, optimizer, edge_type=('user', 'rates', 'recipe')):
    """
    Run one training epoch and return the mean squared error per labelled edge.

    Parameters:
    - model: Module called as `model(x_dict, edge_index_dict, edge_label_index)`.
    - data_loader: Iterable of mini-batches exposing `x_dict`, `edge_index_dict`
      and, for `edge_type`, `edge_label_index` and `edge_label`.
    - optimizer: Optimizer updating the model parameters.
    - edge_type (tuple): Edge type whose labels are predicted.

    Returns:
    - mse (float): Loss averaged over all edges processed in the epoch, or NaN
      when the loader yields no edges.

    Uses the notebook-level `device` to place each batch.
    """
    model.train()
    total_loss = 0.0
    num_examples = 0
    for batch in tqdm(data_loader, desc='Training', unit='batch', leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()
        # NOTE(review): a sampled batch indexes its nodes locally; verify that the
        # model resolves user embeddings through the batch's global user ids
        # (e.g. batch['user'].n_id) rather than by local position.
        pred = model(batch.x_dict, batch.edge_index_dict, batch[edge_type].edge_label_index)
        # Flatten target to match pred.
        target = batch[edge_type].edge_label.float().view(-1)
        loss = F.mse_loss(pred, target)
        loss.backward()
        optimizer.step()
        # `loss` is a per-batch mean; weight it by the batch size so the epoch
        # average is taken per edge, not per batch.
        batch_examples = target.size(0)
        total_loss += loss.item() * batch_examples
        num_examples += batch_examples

    # Average over the edges actually seen instead of relying on the loader's
    # internal `dataset` attribute.
    if num_examples == 0:
        return float('nan')
    return total_loss / num_examples


# Training loop: one full pass over the training loader per epoch.
num_epochs = 30
for epoch in range(1, num_epochs + 1):
    epoch_mse = train(model, train_data_loader, optimizer)
    print(f'Epoch: {epoch:03d}, Loss (MSE): {epoch_mse:.4f}')

# Release cached memory once training is finished.
flush()

                                                              

Epoch: 001, Loss (MSE): 326693619.3203


                                                              

Epoch: 002, Loss (MSE): 18353875.0401


                                                              

Epoch: 003, Loss (MSE): 19401.1392


                                                              

Epoch: 004, Loss (MSE): 31.8266


                                                              

Epoch: 005, Loss (MSE): 23.8397


                                                              

Epoch: 006, Loss (MSE): 20.4781


                                                              

Epoch: 007, Loss (MSE): 19.7001


                                                              

Epoch: 008, Loss (MSE): 19.3548


                                                              

Epoch: 009, Loss (MSE): 18.9419


                                                              

Epoch: 010, Loss (MSE): 18.4420


                                                              

Epoch: 011, Loss (MSE): 17.9090


                                                              

Epoch: 012, Loss (MSE): 17.3473


                                                              

Epoch: 013, Loss (MSE): 16.7501


                                                              

Epoch: 014, Loss (MSE): 1867.7014


                                                              

Epoch: 015, Loss (MSE): 27.2822


                                                              

Epoch: 016, Loss (MSE): 3.7226


                                                              

Epoch: 017, Loss (MSE): 3.2966


                                                              

Epoch: 018, Loss (MSE): 2.8632


                                                              

Epoch: 019, Loss (MSE): 3.0871


                                                              

Epoch: 020, Loss (MSE): 3.6367


                                                              

Epoch: 021, Loss (MSE): 5.4770


                                                              

Epoch: 022, Loss (MSE): 1.9378


                                                              

Epoch: 023, Loss (MSE): 1.8474


                                                              

Epoch: 024, Loss (MSE): 1.8143


                                                              

Epoch: 025, Loss (MSE): 2.1140


                                                              

Epoch: 026, Loss (MSE): 2.1182


                                                              

Epoch: 027, Loss (MSE): 1.7623


                                                              

Epoch: 028, Loss (MSE): 1.7336


                                                              

Epoch: 029, Loss (MSE): 1.7184


                                                              

Epoch: 030, Loss (MSE): 1.7048


# Evaluation
Finally, we evaluate the performance of our model on the validation and test graphs using the **Root Mean Squared Error (RMSE)** metric. 
Although MSE is used as the loss function for training because of its efficient gradient properties, we use RMSE for evaluation since it provides error values in the same units as the target variable, making it more interpretable when assessing model performance.

We also report the evaluation result on the train graph to monitor the model error.

In [8]:
# Settings shared by both evaluation loaders; no shuffling is needed for evaluation.
eval_loader_settings = dict(
    edge_type=edge_type,
    batch_size=batch_size,
    num_neighbors=num_neighbors,
    shuffle=False,
    num_workers=num_workers,
)

validation_data_loader = create_link_neighbor_loader(data=validation_graph, **eval_loader_settings)
test_data_loader = create_link_neighbor_loader(data=test_graph, **eval_loader_settings)

In [9]:
@torch.no_grad()
def evaluate_by_rmse(data_loader, edge_type=('user', 'rates', 'recipe')):
    """
    Compute the RMSE of the notebook-level `model` over all edges of a loader.

    Parameters:
    - data_loader: Iterable of mini-batches exposing `x_dict`, `edge_index_dict`
      and, for `edge_type`, `edge_label_index` and `edge_label`.
    - edge_type (tuple): Edge type whose labels are predicted.

    Returns:
    - rmse (float): Root mean squared error over all edges processed, or NaN
      when the loader yields no edges.

    Uses the notebook-level `model` and `device`; gradients are disabled by the
    decorator and the model is switched to evaluation mode.
    """
    model.eval()
    total_squared_error = 0.0
    num_examples = 0
    for batch in tqdm(data_loader, desc='Evaluating', unit='batch', leave=False):
        batch = batch.to(device)
        # NOTE(review): a sampled batch indexes its nodes locally; verify that the
        # model resolves user embeddings through the batch's global user ids
        # (e.g. batch['user'].n_id) rather than by local position.
        pred = model(batch.x_dict, batch.edge_index_dict, batch[edge_type].edge_label_index)
        # min rating is 0 and max rating is 5.
        pred = pred.clamp(min=0, max=5)
        # Flatten target to match pred
        target = batch[edge_type].edge_label.float().view(-1)
        # Sum (not mean) of squared errors so batches of different sizes
        # contribute proportionally.
        total_squared_error += F.mse_loss(pred, target, reduction='sum').item()
        num_examples += target.size(0)

    # Average over the edges actually seen instead of relying on the loader's
    # internal `dataset` attribute.
    if num_examples == 0:
        return float('nan')
    return (total_squared_error / num_examples) ** 0.5

# Loaders to evaluate, keyed by split name. The existing variable name is kept
# unchanged so that any other cell referring to it keeps working.
evluation_data_loaders = {
    'Train': train_data_loader,
    'Validation': validation_data_loader,
    'Test': test_data_loader
}
# RMSE per split, filled in below so the numbers can be reused after this cell.
evaluation_results = {}
for data_split_name, data_loader in evluation_data_loaders.items():
    evaluation_result = evaluate_by_rmse(data_loader)
    evaluation_results[data_split_name] = evaluation_result
    print(f'{data_split_name} RMSE: {evaluation_result:.4f}')
    flush()


                                                                

Train RMSE: 1.2882


                                                              

Validation RMSE: 1.2979


                                                                

Test RMSE: 1.2842
