## GATv2-based Recommender Model (GATv2RecModel) for Food Recipe Recommendation (V0)

## Environment setup

In [1]:
import gc
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

from tqdm import tqdm
from torch_geometric.nn import GATv2Conv, to_hetero
from torch_geometric.loader import LinkNeighborLoader

In [2]:
def flush():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()

flush()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Graph dataset loading
In this step, we load the graphs already generated in the graph dataset generation step.    
Please read the README.md to find how to generate the graph datasets. Since generating graph datasets is time consiming and memory intensive, we didn't add them here. 

In [3]:
def load_graph(file_path):
    return torch.load(file_path)

dataset_version = 1
base_data_path = f"../data/graph/v{dataset_version}"

train_graph = load_graph(f"{base_data_path}/train_graph.pt")
validation_graph = load_graph(f"{base_data_path}/validation_graph.pt")
test_graph = load_graph(f"{base_data_path}/test_graph.pt")

train_graph

  return torch.load(file_path)


HeteroData(
  user={ num_nodes=226570 },
  recipe={ x=[231637, 3081] },
  (user, rates, recipe)={
    edge_index=[2, 770011],
    edge_label=[192502, 1],
    edge_label_index=[2, 192502],
  },
  (recipe, rev_rates, user)={ edge_index=[2, 770011] }
)

In [4]:
print("Train graph information: ")
print("Numbder of nodes:", train_graph.num_nodes)
print("Numbder of edges:", train_graph.num_edges)
print("Metadata:", train_graph.metadata())
print("Edge index:", train_graph['user', 'rates', 'recipe'].edge_index)
print("Recipe embeddings dimension: ", train_graph['recipe'].x.size(1))
print(train_graph[('user', 'rates', 'recipe')].edge_index.dtype)  # Should be torch.long
print(train_graph[('user', 'rates', 'recipe')].edge_label_index.dtype)  # Should be torch.long

Train graph information: 
Numbder of nodes: 458207
Numbder of edges: 1540022
Metadata: (['user', 'recipe'], [('user', 'rates', 'recipe'), ('recipe', 'rev_rates', 'user')])
Edge index: tensor([[  3106,    317,  16543,  ...,    541, 208023,    489],
        [211809,   6600, 109688,  ...,  62108,  96459, 200804]])
Recipe embeddings dimension:  3081
torch.int64
torch.int64


## Model Implementation


In [5]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_dim, out_channels, heads=4, dropout=0.5):
        super().__init__()
        # GATv2Conv with attention heads
        self.conv1 = GATv2Conv((-1, -1), hidden_dim, heads=heads, dropout=dropout, add_self_loops=False)
        self.conv2 = GATv2Conv((-1, -1), out_channels, heads=1, concat=False, dropout=dropout, add_self_loops=False)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()  # Apply first GATv2Conv and ReLU activation
        x = self.conv2(x, edge_index)         # Apply second GATv2Conv
        return x

class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(2 * hidden_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['recipe'][col]], dim=-1)
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)

class GATv2RecModel(torch.nn.Module):
    def __init__(self, hidden_dim, num_users, item_feature_dim, args):
        super().__init__()
        self.user_embedding = torch.nn.Embedding(num_users, hidden_dim)
        self.recipe_transform = torch.nn.Linear(item_feature_dim, hidden_dim)
        
        # Initialize GNN Encoder with GATv2Conv
        self.encoder = GNNEncoder(hidden_dim, hidden_dim, heads=4, dropout=args['dropout_rate'])
        self.encoder = to_hetero(self.encoder, train_graph.metadata(), aggr='sum')
        
        self.decoder = EdgeDecoder(hidden_dim)
        self.dropout = torch.nn.Dropout(p=args['dropout_rate'])
        self._initialize_embeddings()

    def _initialize_embeddings(self):
        torch.nn.init.xavier_uniform_(self.user_embedding.weight)
        self.user_embedding.weight.data = F.normalize(self.user_embedding.weight.data, p=2, dim=1)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        # Normalize user embeddings
        x_dict['user'] = F.normalize(self.user_embedding.weight, p=2, dim=1)

        # Normalize and transform recipe features and regularize
        x_dict['recipe'] = F.normalize(x_dict['recipe'], p=2, dim=1)
        x_dict['recipe'] = self.recipe_transform(x_dict['recipe']).relu()
        x_dict['recipe'] = self.dropout(x_dict['recipe'])  # Apply dropout

        # Encode the heterogeneous graph with GATv2Conv
        z_dict = self.encoder(x_dict, edge_index_dict)

        # Decode the edge embeddings to predict ratings
        return self.decoder(z_dict, edge_label_index)

# Model initialization with hyperparameters
hidden_dim = 64  # Embedding dimension size
item_feature_dim = train_graph['recipe'].x.size(1)
num_users = train_graph['user']['num_nodes']
args = {
    'dropout_rate': 0.6
}

# Instantiate the model
model = GATv2RecModel(hidden_dim=hidden_dim, num_users=num_users, item_feature_dim=item_feature_dim, args=args).to(device)

model

GATv2RecModel(
  (user_embedding): Embedding(226570, 64)
  (recipe_transform): Linear(in_features=3081, out_features=64, bias=True)
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rates__recipe): GATv2Conv((-1, -1), 64, heads=4)
      (recipe__rev_rates__user): GATv2Conv((-1, -1), 64, heads=4)
    )
    (conv2): ModuleDict(
      (user__rates__recipe): GATv2Conv((-1, -1), 64, heads=1)
      (recipe__rev_rates__user): GATv2Conv((-1, -1), 64, heads=1)
    )
  )
  (decoder): EdgeDecoder(
    (lin1): Linear(in_features=128, out_features=64, bias=True)
    (lin2): Linear(in_features=64, out_features=1, bias=True)
  )
  (dropout): Dropout(p=0.6, inplace=False)
)

## Model Training

In this step, we train our initialized \texttt{SageRecModel} model to optimize for the edge rating prediction task, where we predict user ratings for recipes.

### Mini-Batching
To manage memory and computation on our large graph datasets, we use mini-batching, dividing the data into smaller, manageable subsets that the model processes sequentially. With PyG’s `LinkNeighborLoader`, we sample neighbors around each target edge, focusing on the local neighborhood of each user-recipe interaction. This approach enables the model to capture essential neighborhood context without loading the entire graph, making it highly efficient for large-scale training.


### Optimization
- **Optimizer**: To optimize the model’s parameters, we use Adam optimizer that handles sparse gradients well, suitable for graph neural networks. Regularization is applied through weight decay, which discourages complex solutions and helps prevent overfitting, leading to better generalization.  
To further stabilize training, we use a learning rate scheduler that dynamically adjusts the learning rate, reducing it gradually as the model approaches convergence. This prevents overshooting during optimization and enables fine-tuning for more accurate predictions.


- **Loss Function**: Mean Squared Error (MSE) loss measures the difference between predicted and actual ratings. Given the continuous nature of ratings, MSE is a suitable choice for our link regression task and is calculated as:

  $$
  \text{MSE Loss} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
  $$

  where $y_i$ is the true rating, $\hat{y}_i$ is the predicted rating, and $N$ is the number of samples.


### Training Process
For each epoch, we perform the following steps:
  1. Batch sampling: Sampling mini-batches of edges with the train data loader, allowing the model to process a portion of the data at each step.
  2. Forward propagation: Generating predictions for each mini-batch and calculating the MSE loss based on the difference between predicted and actual ratings.
  3. Backward Ppropagation: Updating the model’s parameters based on the computed gradients.
  4. Parameter update: Updating the model’s parameters based on the computed gradients by optimizer.
  5. Learning rate adjustment: Adjusting the learning rate periodically with the scheduler, stabilizing training as the model converges.

In [6]:
def train(model, data_loader, optimizer, scheduler):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc='Training', unit='batch', leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch.x_dict, batch.edge_index_dict, batch['user', 'rates', 'recipe'].edge_label_index)
        # Flatten target to match pred.
        target = batch['user', 'rates', 'recipe'].edge_label.float().view(-1)
        loss = F.mse_loss(pred, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * target.size(0)
    
    # Compute average loss (MSE) per data point (edge). 
    mse = total_loss / len(data_loader.dataset)

    # Step the scheduler to update the learning rate
    scheduler.step()

    return mse


def create_link_neighbor_loader(data, edge_type, batch_size, num_neighbors, shuffle, num_workers):
    """
    Creates a LinkNeighborLoader for the specified edge type in a HeteroData object.

    Parameters:
    - data (HeteroData): The heterogeneous graph data.
    - edge_type (tuple): The edge type for which to create the loader, e.g., ('user', 'rates', 'recipe').
    - batch_size (int): Number of edges to include in each batch.
    - num_neighbors (list): Number of neighbors to sample at each layer.
    - shuffle (bool): Whether to shuffle the data.
    - num_workers (int): Number of subprocesses to use for data loading.

    Returns:
    - loader (LinkNeighborLoader): The data loader for the specified edge type.
    """
    # Ensure the edge_type exists in the data
    if edge_type not in data.edge_types:
        raise ValueError(f"Edge type {edge_type} not found in the data.")

    # Access the edge_label_index and edge_label for the specified edge type
    edge_label_index = data[edge_type].get('edge_label_index', data[edge_type].edge_index)
    edge_label = data[edge_type].get('edge_label', None)

    # Create the LinkNeighborLoader
    loader = LinkNeighborLoader(
        data=data,
        num_neighbors=num_neighbors,
        edge_label_index=(edge_type, edge_label_index),
        edge_label=edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return loader


edge_type = ('user', 'rates', 'recipe') # Define the edge type of interest
batch_size = 512  # Adjust based on your GPU memory capacity
num_neighbors = [10, 5] # Number of neighbors to sample at each layer
num_workers = 4  # Adjust based on your system

# Create the training data loader
train_data_loader = create_link_neighbor_loader(
    data=train_graph,
    edge_type=edge_type,
    batch_size=batch_size, 
    num_neighbors=num_neighbors, 
    shuffle=True,
    num_workers=num_workers 
)

weight_decay = 0.0001
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Implement a learning rate scheduler
step_size = 10
gamma = 0.1
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Training loop
num_epochs = 20
for epoch in range(1, num_epochs + 1):
    loss = train(model, train_data_loader, optimizer, scheduler)
    print(f'Epoch: {epoch:03d}, Loss (MSE): {loss:.4f}')

                                                              

Epoch: 001, Loss (MSE): 3.2935


                                                              

Epoch: 002, Loss (MSE): 1.5323


                                                              

Epoch: 003, Loss (MSE): 1.5269


                                                              

Epoch: 004, Loss (MSE): 1.5259


                                                              

Epoch: 005, Loss (MSE): 1.5254


                                                              

Epoch: 006, Loss (MSE): 1.5250


                                                              

Epoch: 007, Loss (MSE): 1.5249


                                                              

Epoch: 008, Loss (MSE): 1.5233


                                                              

Epoch: 009, Loss (MSE): 1.5236


                                                              

Epoch: 010, Loss (MSE): 1.5218


                                                              

Epoch: 011, Loss (MSE): 1.5190


                                                              

Epoch: 012, Loss (MSE): 1.5180


                                                              

Epoch: 013, Loss (MSE): 1.5178


                                                              

Epoch: 014, Loss (MSE): 1.5167


                                                              

Epoch: 015, Loss (MSE): 1.5189


                                                              

Epoch: 016, Loss (MSE): 1.5179


                                                              

Epoch: 017, Loss (MSE): 1.5183


                                                              

Epoch: 018, Loss (MSE): 1.5166


                                                              

Epoch: 019, Loss (MSE): 1.5175


                                                              

Epoch: 020, Loss (MSE): 1.5186




# Evaluation
Finally, we evaluate the performance of our model on the validation and test graphs using the **Root Mean Squared Error (RMSE)** and **@Recall@k** metrics. 
Although MSE is used as a loss function for training due to its efficient gradient properties, we use RMSE for evaluation because it provides error values in the same units as the target variable, making it more interpretable when assessing model performance.

We also report the evaluation results on the training graph to monitor the model error.

In [7]:
@torch.no_grad()
def encode_graph(model, graph, device):
    """
    Precomputes node embeddings for the full graph.
    """
    graph = graph.to(device)
    x_dict = graph.x_dict
    edge_index_dict = graph.edge_index_dict

    # Normalize and precompute embeddings
    x_dict['user'] = F.normalize(model.user_embedding.weight, p=2, dim=1)
    x_dict['recipe'] = F.normalize(model.recipe_transform(x_dict['recipe']), p=2, dim=1)
    return model.encoder(x_dict, edge_index_dict)


@torch.no_grad()
def calculate_rmse(decoder, z_dict, edge_label_index, edge_label, batch_size, device):
    """
    Computes RMSE for edge predictions.
    """
    total_loss = 0
    total_edges = 0

    # Ensure model is in evaluation mode
    model.eval()
    
    for i in tqdm(range(0, edge_label_index.size(1), batch_size), desc="RMSE", leave=False):
        batch_edge_index = edge_label_index[:, i:i+batch_size]
        batch_edge_label = edge_label[i:i+batch_size].to(device)
        
        # Remove extra dimension to match predictions
        batch_edge_label = batch_edge_label.view(-1)
        
        pred = decoder(z_dict, batch_edge_index).clamp(min=0, max=5)
        total_loss += F.mse_loss(pred, batch_edge_label, reduction='sum').item()
        total_edges += batch_edge_label.size(0)

    return (total_loss / total_edges) ** 0.5

@torch.no_grad()
def calculate_recall_at_k(decoder, z_dict, edge_label_index, edge_label, batch_size, k, relevance_threshold, device):
    """
    Computes Recall@K for edge predictions.
    """
    from collections import defaultdict
    user_predictions = defaultdict(list)
    user_true_items = defaultdict(set)

    # Ensure model is in evaluation mode
    model.eval()

    for i in tqdm(range(0, edge_label_index.size(1), batch_size), desc="Recall@K", leave=False):
        batch_edge_index = edge_label_index[:, i:i+batch_size]
        batch_edge_label = edge_label[i:i+batch_size].to(device)
        
        # Remove extra dimension to match predictions
        batch_edge_label = batch_edge_label.view(-1)
        
        pred = decoder(z_dict, batch_edge_index).clamp(min=0, max=5)

        users, items = batch_edge_index
        for j, user_id in enumerate(users.cpu().numpy()):
            user_predictions[user_id].append((pred[j].item(), items[j].item()))
            if batch_edge_label[j].item() >= relevance_threshold:
                user_true_items[user_id].add(items[j].item())

    recalls = []
    for user_id in user_predictions:
        top_k_pred_items = {item for _, item in sorted(user_predictions[user_id], key=lambda x: x[0], reverse=True)[:k]}
        true_items = user_true_items[user_id]
        if true_items:
            recalls.append(len(top_k_pred_items & true_items) / len(true_items))
    return sum(recalls) / len(recalls) if recalls else 0.0


k = 5
relevance_threshold = 4
model.eval()

for split_name, graph in zip(["Validation", "Test"], [validation_graph, test_graph]):
    z_dict = encode_graph(model, graph, device)

    edge_label_index = graph['user', 'rates', 'recipe'].edge_label_index
    edge_label = graph['user', 'rates', 'recipe'].edge_label.float()

    rmse = calculate_rmse(model.decoder, z_dict, edge_label_index, edge_label, batch_size, device)
    recall_at_k = calculate_recall_at_k(model.decoder, z_dict, edge_label_index, edge_label, batch_size, k, relevance_threshold, device)

    print(f"{split_name}: RMSE = {rmse:.4f}, Recall@{k} = {recall_at_k:.4f}")

                                                           

Validation: RMSE = 1.2432, Recall@5 = 0.9646


                                                           

Test: RMSE = 1.2270, Recall@5 = 0.9588
