<a href="https://colab.research.google.com/github/subhashpolisetti/timegpt-tabula-rdl-forecasting/blob/main/Graph_Neural_Network_for_Driver_Position_Prediction_RelBench.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# End-to-End Machine Learning Model Training and Evaluation with RelBench, PyTorch, and PyTorch Geometric

## Overview
This Colab notebook walks through training and evaluating a **Graph Neural Network (GNN)** for **driver position prediction** in Formula 1 using **RelBench**, **PyTorch**, and **PyTorch Geometric**. It uses the `rel-f1` dataset, which contains relational tables covering Formula 1 races, drivers, constructor standings, and results. The goal is to predict each driver's **position** from historical data. The target is real-valued (the training table contains values such as 10.75), so the task is treated as regression and trained with an L1 loss.

In this notebook, we will cover:
1. **Package Installation**: Installing all the required libraries.
2. **Dataset Loading and Preprocessing**: Loading and preprocessing the `rel-f1` dataset.
3. **Model Building**: Constructing a Graph Neural Network model.
4. **Model Training**: Training the model and keeping the weights with the best validation score.
5. **Model Evaluation**: Assessing the model performance using common metrics such as **MAE**, **RMSE**, and **R²**.
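
As a reference for the three metrics named in step 5, here is a minimal pure-Python sketch of how MAE, RMSE, and R² are commonly defined. The helper `regression_metrics` and the toy targets are assumptions made for illustration; in the notebook itself the metrics come from `task.evaluate`.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return MAE, RMSE and R^2 for two equal-length sequences of numbers."""
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / n
    rmse = math.sqrt(sum(e * e for e in errors) / n)
    # R^2 compares the squared error with that of always predicting the mean.
    mean_true = sum(y_true) / n
    ss_res = sum(e * e for e in errors)
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    r2 = 1.0 - ss_res / ss_tot  # undefined if all targets are identical
    return {"mae": mae, "rmse": rmse, "r2": r2}


# Toy targets, made up for illustration (not taken from rel-f1).
y_true = [3.0, 10.0, 15.0, 8.0]
perfect = regression_metrics(y_true, y_true)
mean_value = sum(y_true) / len(y_true)
baseline = regression_metrics(y_true, [mean_value] * len(y_true))
print(perfect)
print(baseline)
```

Predicting the mean of the targets for every example gives an R² of exactly 0, so an R² below zero means the predictions fit worse than that constant baseline.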

## Steps

### 1. Install Required Packages
The first step is to install the required Python packages:
- **PyTorch**: For building and training the model.
- **PyTorch Geometric**: For graph neural networks.
- **PyTorch Frame**: For encoding table columns into `TensorFrame` objects.
- **RelBench**: For handling the dataset and evaluation.
- **Sentence-Transformers**: For embedding text data.

```python
!pip install torch==2.4.0
!pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
!pip install pytorch_frame
!pip install relbench
!pip install -U sentence-transformers
```

In [None]:
pip install torch==2.0.0+cu117 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117


In [None]:
pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.0.0+cu117.html


In [None]:
# Install required packages.
!pip install torch==2.4.0
!pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
!pip install pytorch_frame
!pip install relbench


In [None]:
import os
import torch
import relbench

relbench.__version__

'1.0.0'
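
Because the PyTorch Geometric wheel index in the install commands embeds the PyTorch version and build tag, it can help to check what is actually installed before continuing. The sketch below uses only the standard library; `installed_versions` and `pyg_wheel_index` are hypothetical helpers written for this example, and the URL pattern is simply the one that appears in the install commands above.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def pyg_wheel_index(torch_version, build_tag):
    """Format the wheel index URL passed to `pip install ... -f <url>`."""
    return f"https://data.pyg.org/whl/torch-{torch_version}+{build_tag}.html"


report = installed_versions(["torch", "torch-geometric", "pytorch-frame", "relbench"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")

# Example: the CPU wheel index for PyTorch 2.4.0, as used in the install cell.
print(pyg_wheel_index("2.4.0", "cpu"))
```

If the reported `torch` version does not match the one in the wheel URL, the compiled extensions (`torch-sparse`, `torch-scatter`, and so on) may fail to import.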

In [None]:
import numpy as np

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

In [None]:
train_table

Table(df=
           date  driverId  position
0    2004-07-05        10     10.75
1    2004-07-05        47     12.00
2    2004-03-07         7     15.00
3    2004-01-07        10      9.00
4    2003-09-09        52     13.00
...         ...       ...       ...
7448 1995-08-22        96     15.75
7449 1975-06-08       228      8.00
7450 1965-05-31       418     16.00
7451 1961-08-20       467     37.00
7452 1954-05-29       677     30.00

[7453 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)

In [None]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Bookkeeping: fix the random seeds for reproducibility
from torch_geometric.seed import seed_everything

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
root_dir = "./data"

cuda


In [None]:
from relbench.modeling.utils import get_stype_proposal

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.04 seconds.


{'drivers': {'driverId': <stype.numerical: 'numerical'>,
  'driverRef': <stype.text_embedded: 'text_embedded'>,
  'code': <stype.text_embedded: 'text_embedded'>,
  'forename': <stype.text_embedded: 'text_embedded'>,
  'surname': <stype.text_embedded: 'text_embedded'>,
  'dob': <stype.timestamp: 'timestamp'>,
  'nationality': <stype.text_embedded: 'text_embedded'>},
 'races': {'raceId': <stype.numerical: 'numerical'>,
  'year': <stype.categorical: 'categorical'>,
  'round': <stype.numerical: 'numerical'>,
  'circuitId': <stype.numerical: 'numerical'>,
  'name': <stype.text_embedded: 'text_embedded'>,
  'date': <stype.timestamp: 'timestamp'>,
  'time': <stype.timestamp: 'timestamp'>},
 'constructor_standings': {'constructorStandingsId': <stype.numerical: 'numerical'>,
  'raceId': <stype.numerical: 'numerical'>,
  'constructorId': <stype.numerical: 'numerical'>,
  'points': <stype.numerical: 'numerical'>,
  'position': <stype.numerical: 'numerical'>,
  'wins': <stype.numerical: 'numerical

In [None]:
!pip install -U sentence-transformers # we need another package for text encoding
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))
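
The model name `average_word_embeddings_glove.6B.300d` suggests that a sentence is represented by the average of its word vectors. The following is a toy sketch of that idea in plain Python, assuming whitespace tokenization and a made-up 3-dimensional vocabulary; the real model's tokenizer and its handling of unknown words may differ, and `average_word_embedding` is a hypothetical helper.

```python
def average_word_embedding(sentence, word_vectors, dim):
    """Average the vectors of the known words in a sentence."""
    tokens = sentence.lower().split()
    known = [word_vectors[token] for token in tokens if token in word_vectors]
    if not known:
        # No known word: fall back to a zero vector (an assumption of this sketch).
        return [0.0] * dim
    return [sum(vec[i] for vec in known) / len(known) for i in range(dim)]


# Toy 3-dimensional "word vectors", made up for illustration.
toy_vectors = {
    "monaco": [1.0, 0.0, 2.0],
    "grand": [0.0, 1.0, 0.0],
    "prix": [2.0, 2.0, 1.0],
}

emb = average_word_embedding("Monaco Grand Prix", toy_vectors, dim=3)
print(emb)
```

Every sentence maps to a vector of the same length regardless of how many words it has, which is what lets text columns such as race names be stored as fixed-size embeddings.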





In [None]:
# Import necessary functions for text embedding and graph creation
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

# Set up the text embedding configuration using the GloveTextEmbedding model
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device),  # Specifying the text embedding model (GloVe embeddings)
    batch_size=256  # Setting the batch size for embedding
)

# Create the primary key (pkey) and foreign key (fkey) graph using the relational data
data, col_stats_dict = make_pkey_fkey_graph(
    db,  # The database containing the relational data
    col_to_stype_dict=col_to_stype_dict,  # A dictionary mapping column names to their types
    text_embedder_cfg=text_embedder_cfg,  # Configuration for the text embedding model
    cache_dir=os.path.join(
        root_dir, "rel-f1_materialized_cache"  # Directory to store the materialized graph for convenience
    ),
)



We can now inspect `data`, our main graph object. `data` is a heterogeneous, temporal graph: each node type corresponds to the database table its nodes come from, and node types whose tables have a time column also carry a `time` attribute.

In [None]:
data

HeteroData(
  drivers={ tf=TensorFrame([857, 6]) },
  races={
    tf=TensorFrame([820, 5]),
    time=[820],
  },
  constructor_standings={
    tf=TensorFrame([10170, 4]),
    time=[10170],
  },
  constructor_results={
    tf=TensorFrame([9408, 2]),
    time=[9408],
  },
  results={
    tf=TensorFrame([20323, 11]),
    time=[20323],
  },
  qualifying={
    tf=TensorFrame([4082, 3]),
    time=[4082],
  },
  circuits={ tf=TensorFrame([77, 7]) },
  constructors={ tf=TensorFrame([211, 3]) },
  standings={
    tf=TensorFrame([28115, 4]),
    time=[28115],
  },
  (races, f2p_circuitId, circuits)={ edge_index=[2, 820] },
  (circuits, rev_f2p_circuitId, races)={ edge_index=[2, 820] },
  (constructor_standings, f2p_raceId, races)={ edge_index=[2, 10170] },
  (races, rev_f2p_raceId, constructor_standings)={ edge_index=[2, 10170] },
  (constructor_standings, f2p_constructorId, constructors)={ edge_index=[2, 10170] },
  (constructors, rev_f2p_constructorId, constructor_standings)={ edge_index=[2, 1

In [None]:
data["races"].tf

TensorFrame(
  num_cols=5,
  num_rows=820,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

In [None]:
list(data["races"].keys())

['tf', 'time']

In [None]:
data["races"].tf[10]

TensorFrame(
  num_cols=5,
  num_rows=1,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

In [None]:
data["races"].tf[10:20]

TensorFrame(
  num_cols=5,
  num_rows=10,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

In [None]:
data[("races", "f2p_circuitId", "circuits")]

{'edge_index': tensor([[  0,   1,   2,  ..., 817, 818, 819],
        [  8,   5,  18,  ...,  21,  17,  23]])}
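
The `edge_index` above has one column per row of `races`: the top row holds race row indices and the bottom row holds the row index of the circuit each race points to. Below is a minimal sketch of how such an index could be derived from a foreign-key column, in plain Python with made-up IDs. `fkey_edges` and `reverse_edges` are hypothetical helpers for illustration, not RelBench's implementation of `make_pkey_fkey_graph`.

```python
def fkey_edges(fkey_values, pkey_values):
    """Build [source_rows, target_rows] from a foreign-key column."""
    pkey_to_row = {pkey: row for row, pkey in enumerate(pkey_values)}
    src, dst = [], []
    for row, fkey in enumerate(fkey_values):
        if fkey in pkey_to_row:  # skip missing or dangling keys
            src.append(row)
            dst.append(pkey_to_row[fkey])
    return [src, dst]


def reverse_edges(edge_index):
    """Swap source and target rows to get the reverse edge type."""
    src, dst = edge_index
    return [list(dst), list(src)]


# Made-up primary keys of a "circuits" table and foreign keys of a "races" table.
circuit_ids = [101, 102, 103]
race_circuit_ids = [102, 101, 102, 103]

f2p = fkey_edges(race_circuit_ids, circuit_ids)  # races -> circuits
rev = reverse_edges(f2p)                         # circuits -> races
print(f2p)
print(rev)
```

Adding the reverse edge type, as in the `rev_f2p_*` entries of `data`, lets messages flow in both directions between the two tables.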

In [None]:
# Importing necessary functions from relbench and torch_geometric for graph processing
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

# Dictionary to store loaders for different data splits (train, validation, and test)
loader_dict = {}

# Iterate over different splits: train, val, and test tables
for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    # Get the input data for nodes from the table, this prepares the data for the training task
    table_input = get_node_train_table_input(
        table=table,  # Data table for the split (train, val, or test)
        task=task,    # The task at hand (could be regression, classification, etc.)
    )

    # Name of the entity table: table_input.nodes is a (node_type, node_indices) pair
    entity_table = table_input.nodes[0]

    # Create a NeighborLoader for sampling neighborhoods of nodes (for mini-batch processing)
    loader_dict[split] = NeighborLoader(
        data,  # The graph data
        num_neighbors=[128 for i in range(2)],  # Two hops, sampling up to 128 neighbors per node at each hop
        time_attr="time",  # Node attribute holding the timestamps used for temporal sampling
        input_nodes=table_input.nodes,  # Seed nodes for the mini-batches
        input_time=table_input.time,  # Seed time of each seed node
        transform=table_input.transform,  # Per-batch transform supplied by RelBench (the training loop reads targets from batch[entity_table].y)
        batch_size=512,  # Number of seed nodes per mini-batch
        temporal_strategy="uniform",  # Sample uniformly among temporally valid neighbors ("last" is the other option)
        shuffle=split == "train",  # Shuffle the data only for training
        num_workers=0,  # Load batches in the main process
        persistent_workers=False,  # Only relevant when num_workers > 0
    )
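
The `time_attr` and `input_time` arguments above make the sampling temporal: a seed node may only see neighbors whose timestamp is not later than its own seed time, which keeps future information out of the prediction. A simplified single-hop sketch of that constraint in plain Python follows; `sample_temporal_neighbors` and the toy timestamps are assumptions for illustration, not PyTorch Geometric's actual sampler.

```python
import random


def sample_temporal_neighbors(neighbors, neighbor_time, seed_time, k, rng):
    """Sample up to k neighbors whose timestamp is not after seed_time."""
    eligible = [n for n in neighbors if neighbor_time[n] <= seed_time]
    if len(eligible) <= k:
        return eligible
    # "uniform" strategy: every eligible neighbor is equally likely.
    return rng.sample(eligible, k)


# Made-up timestamps (years) for four neighboring rows.
result_time = {0: 2001, 1: 2003, 2: 2005, 3: 2007}
rng = random.Random(42)

sampled = sample_temporal_neighbors(
    [0, 1, 2, 3], result_time, seed_time=2004, k=128, rng=rng
)
print(sampled)
```

With a seed time of 2004, only the rows from 2001 and 2003 are eligible, however large `k` is.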


In [None]:
# Import necessary libraries and modules
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):
    """
    Heterogeneous temporal GNN: per-table feature encoders, a relative-time encoder,
    GraphSAGE message passing, and an MLP prediction head.
    """

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        shallow_list: List[NodeType] = [],
        id_awareness: bool = False,
    ):
        """
        Initialize the model components, including graph neural networks, encoders, and embeddings.

        Args:
            data: The input heterogeneous graph data
            col_stats_dict: Column statistics for each node type
            num_layers: Number of layers for the GNN
            channels: Number of channels for the model layers
            out_channels: Number of output channels for the prediction layer
            aggr: Aggregation method for GNN
            norm: Normalization method for layers
            shallow_list: List of node types to add shallow embeddings
            id_awareness: Whether to add ID-awareness to the model
        """
        super().__init__()

        # Encoder to process different columns of data for each node type
        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )

        # Temporal encoder to process time-based features
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )

        # Graph neural network for learning representations
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )

        # Fully connected layer for making predictions
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )

        # Shallow embeddings for the specified node types
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        # ID-awareness embedding if required
        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)

        # Initialize the model parameters
        self.reset_parameters()

    def reset_parameters(self):
        """
        Reset the parameters of all components in the model.
        """
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()

        # Initialize shallow embeddings
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)

        # Initialize ID-awareness embedding if necessary
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """
        Forward pass for the model, computing the node representations and final predictions.

        Args:
            batch: The input data batch
            entity_table: The entity table for which predictions are made

        Returns:
            Tensor: Model's prediction for the given entity table
        """
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        # Temporal encoding for time-sensitive data
        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        # Add temporal information to the node embeddings
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        # Add shallow embeddings for specified node types
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        # Apply the graph neural network (GNN)
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        # Return the prediction from the head (fully connected layer)
        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        """
        Perform a forward pass with ID-awareness for destination node table readout.

        Args:
            batch: The input data batch
            entity_table: The entity table for which predictions are made
            dst_table: The destination node type

        Returns:
            Tensor: Prediction for the destination table
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )

        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        # Add temporal information to the node embeddings
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        # Add shallow embeddings for specified node types
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        # Apply the graph neural network (GNN)
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        # Return the prediction for the destination table
        return self.head(x_dict[dst_table])


# Instantiate the model with specified parameters
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",  # Aggregation method for GNN
    norm="batch_norm",  # Normalization method
).to(device)  # Move the model to the device (GPU or CPU)

# Define the optimizer and number of epochs
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)  # Adam optimizer
epochs = 10  # Number of training epochs


In [None]:
# Function to train the model for one epoch
def train() -> float:
    model.train()  # Set the model to training mode

    loss_accum = count_accum = 0  # Initialize variables to accumulate loss and count number of examples

    # Iterate through each batch in the training loader
    for batch in tqdm(loader_dict["train"]):  # tqdm provides a progress bar for iterations
        batch = batch.to(device)  # Move the batch to the device (GPU or CPU)

        optimizer.zero_grad()  # Reset the gradients of the model
        pred = model(
            batch,  # Forward pass: Pass the batch to the model
            task.entity_table,  # Specify the entity table for the task
        )

        # If the prediction has a size of 1 for the second dimension (i.e., a single value per example),
        # reshape it to a flat vector to match the target labels.
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        # Compute the loss between the predicted values and the actual target labels
        # (use task.entity_table rather than relying on a global left over from the loader cell)
        loss = loss_fn(pred.float(), batch[task.entity_table].y.float())
        loss.backward()  # Perform backpropagation to calculate gradients
        optimizer.step()  # Update the model parameters using the computed gradients

        # Accumulate the loss and number of examples for computing the average loss
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    # Return the average loss for the epoch
    return loss_accum / count_accum


# Function to test the model's performance on a given loader
@torch.no_grad()  # Disable gradient calculation to save memory and computations during testing
def test(loader: NeighborLoader) -> np.ndarray:
    model.eval()  # Set the model to evaluation mode

    pred_list = []  # Initialize a list to store predictions
    for batch in loader:  # Iterate through each batch in the loader
        batch = batch.to(device)  # Move the batch to the device (GPU or CPU)
        pred = model(
            batch,  # Forward pass: Pass the batch to the model
            task.entity_table,  # Specify the entity table for the task
        )

        # If the prediction has a size of 1 for the second dimension (i.e., a single value per example),
        # reshape it to a flat vector to match the target labels.
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        # Append the predictions (detached from the computational graph) to the list
        pred_list.append(pred.detach().cpu())

    # Concatenate all predictions into a single array and return
    return torch.cat(pred_list, dim=0).numpy()
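
`train()` weights each batch loss by its batch size before dividing by the total count, so the result is a per-example mean even when the last batch is smaller. The sketch below illustrates this with the batch sizes implied by the notebook (7453 training rows at `batch_size=512`, giving 14 full batches and one of 285) and made-up loss values. `epoch_mean_loss` is a hypothetical helper that mirrors the accumulation in `train()`.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Per-example mean loss from per-batch mean losses and batch sizes."""
    loss_accum = 0.0
    count_accum = 0
    for loss, size in zip(batch_losses, batch_sizes):
        loss_accum += loss * size
        count_accum += size
    return loss_accum / count_accum


# 7453 rows with batch_size=512: 14 full batches and a final batch of 285.
batch_sizes = [512] * 14 + [285]
# Made-up per-batch mean losses; the small final batch happens to be an outlier.
batch_losses = [4.0] * 14 + [8.0]

weighted = epoch_mean_loss(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)
print(weighted, naive)
```

The naive average of batch means gives the small final batch the same weight as a full one, which overstates its influence on the reported epoch loss.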


In [None]:
# Initialize state_dict to store the model's best weights based on validation metrics
state_dict = None

# Set the initial best validation metric based on whether we want to maximize or minimize it
best_val_metric = -math.inf if higher_is_better else math.inf

# Loop through each epoch for training
for epoch in range(1, epochs + 1):
    # Train the model for one epoch and calculate the training loss
    train_loss = train()

    # Evaluate the model on the validation set
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)

    # Print the current epoch's training loss and validation metrics
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    # Check if we should update the best model weights based on the validation metrics
    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        # Update the best validation metric and store the current model's weights
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())

# Load the best model weights after training
model.load_state_dict(state_dict)

# Re-evaluate the model on the validation set with the best weights
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

# Evaluate the model on the test set using the best weights
test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")


Epoch: 01, Train loss: 4.5940603298195954, Val metrics: {'r2': 0.26664876365429346, 'mae': 3.192621118001486, 'rmse': 3.970144408166473}
Epoch: 02, Train loss: 4.568261575973963, Val metrics: {'r2': 0.28401349234585715, 'mae': 3.1377921566297466, 'rmse': 3.9228590932193326}
Epoch: 03, Train loss: 4.508331975344442, Val metrics: {'r2': 0.2779053222548966, 'mae': 3.145869382063229, 'rmse': 3.9395567561975118}
Epoch: 04, Train loss: 4.454095084475588, Val metrics: {'r2': 0.2748767008953672, 'mae': 3.1624791348545886, 'rmse': 3.9478097883334575}
Epoch: 05, Train loss: 4.428226253993716, Val metrics: {'r2': 0.2624319702044985, 'mae': 3.2122242543724435, 'rmse': 3.9815422768300364}
Epoch: 06, Train loss: 4.3808791889818695, Val metrics: {'r2': 0.23813335297839988, 'mae': 3.2412406909282634, 'rmse': 4.046595277497622}
Epoch: 07, Train loss: 4.338853529823013, Val metrics: {'r2': 0.2456946012502239, 'mae': 3.247610264136621, 'rmse': 4.026464715586905}
Epoch: 08, Train loss: 4.280953785724709, Val metrics: {'r2': 0.2494808324189115, 'mae': 3.227723037073751, 'rmse': 4.016346595592636}
Epoch: 09, Train loss: 4.242680662992767, Val metrics: {'r2': 0.25660482780398064, 'mae': 3.185533819345132, 'rmse': 3.997239384230069}
Epoch: 10, Train loss: 4.173513173641142, Val metrics: {'r2': 0.24525833749507364, 'mae': 3.2292002695755078, 'rmse': 4.027628930176982}
Best Val metrics: {'r2': 0.2855097827935169, 'mae': 3.132982980591819, 'rmse': 3.918757894088351}
Best test metrics: {'r2': -0.026182202542586408, 'mae': 4.38193179076178, 'rmse': 5.278154456184378}
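
The checkpoint selection in the training loop can be replayed on the validation MAE values logged above. The sketch below is plain Python; `select_best` is a hypothetical helper that mirrors the `higher_is_better` comparison in the loop, and the dictionary simply copies the per-epoch `mae` values from the log.

```python
import math

# Validation MAE per epoch, copied from the training log above.
val_mae_by_epoch = {
    1: 3.192621118001486,
    2: 3.1377921566297466,
    3: 3.145869382063229,
    4: 3.1624791348545886,
    5: 3.2122242543724435,
    6: 3.2412406909282634,
    7: 3.247610264136621,
    8: 3.227723037073751,
    9: 3.185533819345132,
    10: 3.2292002695755078,
}


def select_best(metric_by_epoch, higher_is_better):
    """Return (epoch, value) of the best metric, keeping the first on ties."""
    best_epoch = None
    best_value = -math.inf if higher_is_better else math.inf
    for epoch, value in metric_by_epoch.items():
        improved = value > best_value if higher_is_better else value < best_value
        if improved:
            best_epoch, best_value = epoch, value
    return best_epoch, best_value


best_epoch, best_mae = select_best(val_mae_by_epoch, higher_is_better=False)
print(best_epoch, best_mae)
```

The weights restored after training are therefore those from epoch 2, even though the training loss keeps decreasing through epoch 10. The re-evaluated "Best Val metrics" MAE (about 3.133) differs slightly from the epoch-2 value; one plausible explanation, stated here as an assumption, is that neighbor sampling is random at evaluation time.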
