# Training a GraphSAGE-based GNN Model for Food Recipe Recommendation

## Environment setup

In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.nn import SAGEConv, to_hetero



## Graph dataset loading

In this step, we load the graphs that were already produced in the graph dataset generation step. Since generating the graph datasets is time-consuming, we do not regenerate them in each notebook.
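A slightly more defensive loader could look like the sketch below. It is only an illustration: the name `load_graph_checked` is hypothetical, and the `weights_only=False` argument assumes a PyTorch release that supports it. On recent releases, where `torch.load` defaults to `weights_only=True`, unpickling a full Python object such as a graph may require it.

```python
import os

import torch


def load_graph_checked(file_path):
    """Load a saved graph object, failing early if the file is missing."""
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Graph file not found: {file_path}. "
            "Run the graph dataset generation step first."
        )
    # weights_only=False unpickles arbitrary Python objects,
    # so only use it on files you created or trust.
    return torch.load(file_path, weights_only=False)
```

Failing early with a clear message is useful here because the graph files are produced by a separate notebook and may not exist yet.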

In [2]:
def load_graph(file_path):
    return torch.load(file_path)

dataset_version = 1
base_data_path = f"../data/graph/v{dataset_version}"

train_data = load_graph(f"{base_data_path}/train_graph.pt")
val_data = load_graph(f"{base_data_path}/validation_graph.pt")
# test_graph = load_graph(f"{base_data_path}/test_graph.pt")

train_data



HeteroData(
  user={ num_nodes=226570 },
  recipe={ x=[231637, 3081] },
  (user, rates, recipe)={
    edge_index=[2, 770011],
    edge_label=[192502, 1],
    edge_label_index=[2, 192502],
  },
  (recipe, rev_rates, user)={ edge_index=[2, 770011] }
)

In [3]:
train_data['user', 'rates', 'recipe'].edge_index

tensor([[  3106,    317,  16543,  ...,    541, 208023,    489],
        [211809,   6600, 109688,  ...,  62108,  96459, 200804]])

In [4]:
train_data.metadata()

(['user', 'recipe'],
 [('user', 'rates', 'recipe'), ('recipe', 'rev_rates', 'user')])

## Model implementation

### Model Architecture Overview
This GNN model is designed for edge rating prediction in a recipe recommendation system. It uses the GraphSAGE architecture to encode node features into embeddings and a custom edge decoder to predict ratings between user and recipe nodes. The model consists of three main components: the encoder, the decoder, and the integration of both into a complete model for training and inference.

### GNNEncoder Class
**Purpose**: Encodes node features into embeddings using the GraphSAGE architecture.

**Structure**:
- `conv1`: First SAGE convolutional layer for initial feature transformation.
- `conv2`: Second SAGE convolutional layer for output embedding generation.

**Forward Pass**:
- Takes node features `x` and edge connections `edge_index` as input.
- Applies `conv1` followed by a ReLU activation.
- Applies `conv2` to output the final node embeddings.
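
As a rough illustration of what each SAGE layer computes, the sketch below reimplements mean aggregation in plain PyTorch for a single edge type: each target node receives the mean of its neighbours' features, transformed by one weight matrix, plus its own features transformed by another. This is not PyG's `SAGEConv` implementation; the function name `sage_mean_layer` and its weight arguments are hypothetical, and bias terms are omitted.

```python
import torch


def sage_mean_layer(x_src, x_dst, edge_index, w_neigh, w_root):
    """out_i = mean_{j in N(i)} x_src[j] @ w_neigh + x_dst[i] @ w_root"""
    src, dst = edge_index
    num_dst = x_dst.size(0)
    # Sum the source features that arrive at each target node.
    summed = torch.zeros(num_dst, x_src.size(1)).index_add_(0, dst, x_src[src])
    # Count incoming edges per target node to turn the sum into a mean.
    degree = torch.zeros(num_dst).index_add_(0, dst, torch.ones(dst.numel()))
    mean = summed / degree.clamp(min=1).unsqueeze(-1)
    return mean @ w_neigh + x_dst @ w_root


# Toy graph: 3 source nodes (users), 2 target nodes (recipes).
x_user = torch.tensor([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
x_recipe = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
edge_index = torch.tensor([[0, 1, 2],   # source (user) indices
                           [0, 0, 1]])  # target (recipe) indices
eye = torch.eye(2)  # identity weights keep the arithmetic easy to follow

out = sage_mean_layer(x_user, x_recipe, edge_index, eye, eye)
print(out)  # rows: [3., 1.] and [0., 2.]
```

Recipe 0 receives the mean of users 0 and 1, `[2, 0]`, plus its own features `[1, 1]`. Recipe 1 receives user 2's features `[0, 2]` plus its own zeros.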

### EdgeDecoder Class
**Purpose**: Decodes the node embeddings to predict edge labels, such as ratings between users and recipes.

**Structure**:
- `lin1`: A fully connected layer that combines node embeddings from both ends of an edge.
- `lin2`: A linear layer that outputs a scalar representing the predicted edge label (e.g., a rating).

**Forward Pass**:
- Extracts embeddings for connected nodes (e.g., user and recipe).
- Concatenates these embeddings and passes them through `lin1` with a ReLU activation.
- Outputs a single value through `lin2` representing the predicted edge rating.
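
The indexing step is the least obvious part of the decoder, so here is a small shape trace with random embeddings. The node counts, the hidden size of 8, and the edge pairs are made up for illustration.

```python
import torch

torch.manual_seed(0)
hidden = 8

# Stand-in embeddings: 5 users and 7 recipes.
z_dict = {"user": torch.randn(5, hidden), "recipe": torch.randn(7, hidden)}

# Four (user, recipe) pairs whose ratings should be predicted.
edge_label_index = torch.tensor([[0, 0, 3, 4],
                                 [2, 6, 6, 1]])
row, col = edge_label_index

# Gather one user and one recipe embedding per edge, then concatenate.
pair = torch.cat([z_dict["user"][row], z_dict["recipe"][col]], dim=-1)

lin1 = torch.nn.Linear(2 * hidden, hidden)
lin2 = torch.nn.Linear(hidden, 1)
scores = lin2(lin1(pair).relu()).view(-1)

print(pair.shape)    # torch.Size([4, 16])
print(scores.shape)  # torch.Size([4])
```

The decoder produces one scalar per labelled edge, no matter how many nodes the graph contains.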

### Model Class
**Purpose**: Integrates the encoder and decoder to create a complete GNN model.

**Structure**:
- `encoder`: A `GNNEncoder` converted with `to_hetero`, which replicates each layer per edge type so the model can handle different types of nodes and edges. With `aggr='sum'`, messages that reach the same node type through several edge types are summed.
- `decoder`: A custom `EdgeDecoder` for predicting edge labels based on embeddings.

**Forward Pass**:
- Accepts a dictionary of node features `x_dict`, an edge index dictionary `edge_index_dict`, and the edge label index `edge_label_index`.
- Passes `x_dict` and `edge_index_dict` to the encoder to generate node embeddings.
- Uses the `decoder` to predict edge labels from these embeddings.

**Execution Context**:
- The model is moved to a GPU when one is available and falls back to the CPU otherwise. `hidden_channels` is set to 32, which is the size of both the hidden layer and the output embeddings.

In [5]:
from torch_geometric.nn import SAGEConv, to_hetero

class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) defers the (source, target) input sizes to the first
        # forward pass, where they are inferred from the data.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # The input is a user embedding concatenated with a recipe embedding.
        self.lin1 = torch.nn.Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        # row holds user indices and col holds recipe indices, one pair per edge.
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['recipe'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        # Flatten [num_edges, 1] to [num_edges].
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, train_data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model(hidden_channels=32).to(device)

model

Model(
  (encoder): GraphModule(
    (conv1): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
    (conv2): ModuleDict(
      (user__rates__recipe): SAGEConv((-1, -1), 32, aggr=mean)
      (recipe__rev_rates__user): SAGEConv((-1, -1), 32, aggr=mean)
    )
  )
  (decoder): EdgeDecoder(
    (lin1): Linear(in_features=64, out_features=32, bias=True)
    (lin2): Linear(in_features=32, out_features=1, bias=True)
  )
)

## Model Training

In this step, we perform training to optimize our GNN model for the edge rating prediction task.

### Key Training Components
- **Optimizer**: `torch.optim.Adam` updates the model's parameters to minimize the loss. The learning rate is set to 0.01 and can be adjusted during model tuning.
- **Loss Function**: Mean Squared Error (MSE) loss is used for training because the targets (ratings) are continuous values. The loss formula is:

$$
\text{MSE Loss} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
$$

where $y_i$ is the actual rating, $\hat{y}_i$ is the predicted rating, and $N$ is the total number of samples.
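
A small worked example of the formula, with made-up ratings, is sketched below. It also shows a shape pitfall that matters for this dataset: the graph stores `edge_label` with shape `[N, 1]`, while the decoder returns predictions of shape `[N]`. If the two are combined without flattening, broadcasting produces an `[N, N]` result instead of an element-wise difference.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([4.0, 3.0, 5.0])            # shape [3], like the decoder output
target = torch.tensor([[5.0], [3.0], [4.0]])    # shape [3, 1], like edge_label

# Flatten the target so both tensors have shape [3].
mse = F.mse_loss(pred, target.view(-1))
rmse = mse.sqrt()
print(round(mse.item(), 4))   # 0.6667  ((1 + 0 + 1) / 3)
print(round(rmse.item(), 4))  # 0.8165

# Pitfall: without flattening, [3] and [3, 1] broadcast to [3, 3].
print((pred - target).shape)  # torch.Size([3, 3])
```

With roughly 190,000 labelled edges, the broadcast version would compare every prediction against every target, which is both incorrect and very memory-hungry.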

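### Training Process
- In each epoch, the model predicts ratings for the labelled training edges and computes the MSE loss against the actual ratings.
- The optimizer then updates the model's parameters to reduce the loss in subsequent epochs.
- After each update, the RMSE is reported on the training and validation graphs, with predictions clamped to the range 0 to 5.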

In [6]:
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Move both graphs to the target device once, instead of on every epoch.
train_data = train_data.to(device)
val_data = val_data.to(device)

def train():
    model.train()
    optimizer.zero_grad()
    pred = model(train_data.x_dict, train_data.edge_index_dict,
                 train_data['user', 'recipe'].edge_label_index)
    # edge_label has shape [N, 1]; flatten it to [N] so it matches pred.
    # Otherwise mse_loss broadcasts the two tensors to [N, N].
    target = train_data['user', 'recipe'].edge_label.view(-1).float()
    loss = F.mse_loss(pred, target)
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test(data):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict,
                 data['user', 'recipe'].edge_label_index)
    # Clamp predictions to the rating range assumed here, [0, 5].
    pred = pred.clamp(min=0, max=5)
    target = data['user', 'recipe'].edge_label.view(-1).float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)


for epoch in range(1, 301):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}')

AttributeError: 'NoneType' object has no attribute 'dim'
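
This error is most likely caused by the `user` node type having no feature matrix. The `HeteroData` summary printed earlier lists only `num_nodes` for `user`, so `x_dict` has no `'user'` entry and the encoder receives `None` for it. One possible workaround, sketched here as an assumption and not as the notebook's confirmed fix, is to give every featureless node type a constant feature before calling the model. The helper name `add_constant_features` is hypothetical.

```python
import torch


def add_constant_features(data, node_type="user", dim=1):
    """Attach an all-ones feature matrix to a node type that has no `x`."""
    store = data[node_type]
    if getattr(store, "x", None) is None:
        # Every node gets the same input; message passing from neighbouring
        # nodes is what differentiates the resulting embeddings.
        store.x = torch.ones(store.num_nodes, dim)
    return data


# Hypothetical usage with the notebook's graphs, before calling the model:
# train_data = add_constant_features(train_data)
# val_data = add_constant_features(val_data)
```

Because the encoder uses lazy `(-1, -1)` input sizes, it should infer the new user feature dimension on the first forward pass. A one-hot identity matrix is another common choice, but with about 226,000 users a dense identity would be very large. A learnable `torch.nn.Embedding` per user is a further alternative at the cost of extra parameters.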