In [1]:
# Install dependencies with the `%pip` magic (instead of `!pip`) so the packages
# are installed into the environment of the running kernel.
# torch is pinned to 2.0.0+cu117 and the PyG extension wheels are taken from the
# matching wheel index.
%pip install -q torch==2.0.0+cu117 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
%pip install -q torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
%pip install -q pytorch_frame
%pip install -q -U sentence-transformers
# Pinned to the RelBench version this notebook was executed with.
%pip install -q relbench==1.1.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 GB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.3/63.3 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.3/132.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m56.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m36.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.4/4.4 MB[0m [31m65.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.4/4.4 MB[0m [31m45.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for lit (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
import math
import copy
import numpy as np
from tqdm import tqdm
from typing import List, Optional, Any, Dict
from sentence_transformers import SentenceTransformer

import torch
from torch import Tensor
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType
from torch_geometric.seed import seed_everything
from torch_geometric.loader import NeighborLoader
import torch_frame
from torch_frame.data.stats import StatType
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch.nn import BCEWithLogitsLoss, Embedding, ModuleDict, L1Loss

import relbench
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph, get_node_train_table_input
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder

# Print the installed RelBench version for the record.
print(relbench.__version__)



1.1.0


In [4]:
# Fetch the rel-f1 dataset and its driver-position task (both requested with download=True).
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

# One table per split, unpacked in train / val / test order.
train_table, val_table, test_table = (
    task.get_table(split_name) for split_name in ("train", "val", "test")
)

# Training configuration: a single output trained with L1 loss; model selection
# uses the "mae" validation metric, where lower is better.
tune_metric, higher_is_better = "mae", False
loss_fn = L1Loss()
out_channels = 1

In [7]:
# Show the training table (its columns, foreign keys and time column).
print(train_table)

Table(df=
           date  driverId  position
0    2004-07-05        10     10.75
1    2004-07-05        47     12.00
2    2004-03-07         7     15.00
3    2004-01-07        10      9.00
4    2003-09-09        52     13.00
...         ...       ...       ...
7448 1995-08-22        96     15.75
7449 1975-06-08       228      8.00
7450 1965-05-31       418     16.00
7451 1961-08-20       467     37.00
7452 1954-05-29       677     30.00

[7453 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)


In [8]:
# Fix the random seed via torch_geometric's seed_everything.
seed_everything(42)
# Base directory; the materialized-graph cache folder is created underneath it.
root_dir = "./data"

In [9]:
# Load the database behind the dataset.
db = dataset.get_db()
# Proposed semantic type (stype) for every column of every table:
# a dict of table name -> {column name -> stype}.
col_to_stype_dict = get_stype_proposal(db)
print(col_to_stype_dict)

Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.05 seconds.
{'results': {'resultId': <stype.numerical: 'numerical'>, 'raceId': <stype.numerical: 'numerical'>, 'driverId': <stype.numerical: 'numerical'>, 'constructorId': <stype.numerical: 'numerical'>, 'number': <stype.numerical: 'numerical'>, 'grid': <stype.numerical: 'numerical'>, 'position': <stype.numerical: 'numerical'>, 'positionOrder': <stype.numerical: 'numerical'>, 'points': <stype.numerical: 'numerical'>, 'laps': <stype.numerical: 'numerical'>, 'milliseconds': <stype.numerical: 'numerical'>, 'fastestLap': <stype.numerical: 'numerical'>, 'rank': <stype.numerical: 'numerical'>, 'statusId': <stype.numerical: 'numerical'>, 'date': <stype.timestamp: 'timestamp'>}, 'constructor_results': {'constructorResultsId': <stype.numerical: 'numerical'>, 'raceId': <stype.numerical: 'numerical'>, 'constructorId': <stype.numerical: 'numerical'>, 'points': <stype.numerical: 'numerical'>, 'date': <stype.timestamp: 'timest

In [10]:
class GloveTextEmbedding:
    """Callable that turns a list of sentences into a tensor of embeddings.

    The embeddings are produced by a ``SentenceTransformer`` loaded from
    ``sentence-transformers/average_word_embeddings_glove.6B.300d``.

    Args:
        device: Optional device forwarded to ``SentenceTransformer``.
    """

    def __init__(self, device: Optional[torch.device] = None):
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` with the wrapped model and return a tensor."""
        encoded = self.model.encode(sentences)
        # `encode` yields a NumPy array here; from_numpy wraps it as a tensor.
        return torch.from_numpy(encoded)

In [12]:
# Text columns are embedded with the GloVe wrapper, 256 sentences per batch.
text_embedder_cfg = TextEmbedderConfig(text_embedder=GloveTextEmbedding(), batch_size=256)

# Build the primary-key/foreign-key graph from the database. The call returns the
# graph together with per-column statistics; cache_dir points below root_dir.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(root_dir, "rel-f1_materialized_cache"),
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/248 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

(…)beddings/whitespacetokenizer_config.json:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

(…)WordEmbeddings/wordembedding_config.json:   0%|          | 0.00/164 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/480M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 87.81it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 181.93it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 184.43it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 176.52it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 177.31it/s]
  ser = pd.to_datetime(ser, format=time_format)
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 154.42it/s]
  ser = pd.to_datetime(ser, format=self.format, errors='coerce')
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 107.97it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 100.06it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 161.28it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 157.71it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 251.28it/

In [13]:
# Show the node and edge stores of the graph built above.
print(data)

HeteroData(
  results={
    tf=TensorFrame([20323, 11]),
    time=[20323],
  },
  constructor_results={
    tf=TensorFrame([9408, 2]),
    time=[9408],
  },
  drivers={ tf=TensorFrame([857, 6]) },
  races={
    tf=TensorFrame([820, 5]),
    time=[820],
  },
  qualifying={
    tf=TensorFrame([4082, 3]),
    time=[4082],
  },
  constructors={ tf=TensorFrame([211, 3]) },
  standings={
    tf=TensorFrame([28115, 4]),
    time=[28115],
  },
  constructor_standings={
    tf=TensorFrame([10170, 4]),
    time=[10170],
  },
  circuits={ tf=TensorFrame([77, 7]) },
  (results, f2p_raceId, races)={ edge_index=[2, 20323] },
  (races, rev_f2p_raceId, results)={ edge_index=[2, 20323] },
  (results, f2p_driverId, drivers)={ edge_index=[2, 20323] },
  (drivers, rev_f2p_driverId, results)={ edge_index=[2, 20323] },
  (results, f2p_constructorId, constructors)={ edge_index=[2, 20323] },
  (constructors, rev_f2p_constructorId, results)={ edge_index=[2, 20323] },
  (constructor_results, f2p_raceId, races)=

In [14]:
# Inspect the TensorFrame of the "races" node type and list the attributes
# stored for that node type.
print(data["races"].tf)
print(list(data["races"].keys()))

TensorFrame(
  num_cols=5,
  num_rows=820,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)
['tf', 'time']


In [15]:
# A TensorFrame can be indexed by a single row or by a slice of rows.
print(data["races"].tf[10])
print(data["races"].tf[10:20])

TensorFrame(
  num_cols=5,
  num_rows=1,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)
TensorFrame(
  num_cols=5,
  num_rows=10,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)


In [16]:
# Edge store for the (races, f2p_circuitId, circuits) edge type.
# NOTE(review): presumably the foreign-key link races.circuitId -> circuits; confirm.
print(data[("races", "f2p_circuitId", "circuits")])

{'edge_index': tensor([[  0,   1,   2,  ..., 817, 818, 819],
        [  8,   5,  18,  ...,  21,  17,  23]])}


In [17]:
loader_dict = {}

# Build one NeighborLoader per split; only the training loader shuffles.
for split, table in zip(
    ("train", "val", "test"),
    (train_table, val_table, test_table),
):
    table_input = get_node_train_table_input(table=table, task=task)
    # `entity_table` stays a module-level name so later cells can refer to it.
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[128] * 2,  # 128 neighbors for each of the 2 hops
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=(split == "train"),
        num_workers=0,
        persistent_workers=False,
    )

In [19]:
class Model(torch.nn.Module):
    """Heterogeneous GNN that predicts from a sampled subgraph.

    The pipeline is: ``HeteroEncoder`` on the per-node-type TensorFrames,
    plus the output of ``HeteroTemporalEncoder`` for node types that carry a
    ``time`` attribute, plus optional shallow per-node embeddings, followed by
    ``HeteroGraphSAGE`` and a single-layer ``MLP`` head.

    Args:
        data: Full graph; used for node/edge types, column names and node counts.
        col_stats_dict: Column statistics passed to the encoder as
            ``node_to_col_stats``.
        num_layers: Number of layers of the ``HeteroGraphSAGE`` module.
        channels: Hidden width shared by all sub-modules.
        out_channels: Output width of the head.
        aggr: Aggregation passed to ``HeteroGraphSAGE``.
        norm: Normalization passed to the ``MLP`` head.
        shallow_list: Node types that receive a learnable shallow embedding
            added to their input features. ``None`` (default) means none.
        id_awareness: If True, create the embedding that
            ``forward_dst_readout`` adds to the seed rows.
    """

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: Optional[List[NodeType]] = None,
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        # `None` replaces the former mutable `[]` default argument.
        if shallow_list is None:
            shallow_list = []

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        # Only node types that have a `time` attribute get a temporal encoding.
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        # One embedding row per node of every type listed in `shallow_list`.
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every sub-module and the optional embeddings."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def _add_time_and_shallow_embeddings(
        self,
        x_dict: Dict[NodeType, Tensor],
        batch: HeteroData,
        seed_time: Tensor,
    ) -> Dict[NodeType, Tensor]:
        """Add temporal encodings, then shallow embeddings, to ``x_dict``."""
        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        # Shallow embeddings are looked up by the node ids stored in `n_id`.
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        return x_dict

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """Predict for the seed rows of ``entity_table``.

        Args:
            batch: Sampled subgraph; must provide ``seed_time`` on the
                ``entity_table`` store.
            entity_table: Node type whose rows are predicted.

        Returns:
            Head output for the first ``seed_time.size(0)`` rows of
            ``entity_table``.
        """
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        x_dict = self._add_time_and_shallow_embeddings(x_dict, batch, seed_time)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        # Only the leading rows are read out.
        # NOTE(review): presumably the loader places seed nodes first; confirm.
        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        """Predict for every row of ``dst_table`` with ID-aware seed rows.

        Args:
            batch: Sampled subgraph; must provide ``seed_time`` on the
                ``entity_table`` store.
            entity_table: Node type of the seed rows.
            dst_table: Node type whose rows are read out by the head.

        Returns:
            Head output for all rows of ``dst_table`` in the batch.

        Raises:
            RuntimeError: If the model was built with ``id_awareness=False``.
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        x_dict = self._add_time_and_shallow_embeddings(x_dict, batch, seed_time)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


# Instantiate the model. `out_channels` comes from the task configuration cell
# instead of being hard-coded a second time, so the two cannot drift apart.
# num_layers=2 matches the two entries of `num_neighbors` used by the loaders.
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=out_channels,
    aggr="sum",
    norm="batch_norm",
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

In [20]:
def train() -> float:
    """Run one training epoch over ``loader_dict["train"]``.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn``, ``task`` and
    ``loader_dict``.

    Returns:
        The mean training loss, weighted by the number of predictions per batch.
    """
    model.train()

    # Use the task's entity table for both the forward pass and the label
    # lookup, rather than a loop variable leaked from the loader-building cell.
    entity = task.entity_table

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        optimizer.zero_grad()
        pred = model(
            batch,
            entity,
        )
        # Flatten (N, 1) predictions to (N,) so they line up with the targets.
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity].y.float())
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; re-weight it by batch size.
        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    """Collect the model's predictions for every batch of ``loader``.

    Runs the module-level ``model`` in eval mode with gradients disabled.

    Returns:
        All batch predictions concatenated along dim 0, as a NumPy array.
    """
    model.eval()

    def _predict(batch):
        # Flatten (N, 1) outputs to (N,); wider outputs are left unchanged.
        out = model(batch, task.entity_table)
        return out.view(-1) if out.size(1) == 1 else out

    outputs = [_predict(batch).detach().cpu() for batch in loader]
    return torch.cat(outputs, dim=0).numpy()

In [21]:
# Train for `epochs` epochs and keep a copy of the weights from the epoch with
# the best validation value of `tune_metric`.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        # Deep copy: the live state dict would otherwise change with further training.
        state_dict = copy.deepcopy(model.state_dict())


# `state_dict` is still None when no epoch ran or no epoch improved on the
# initial value (e.g. the metric was NaN); in that case keep the current weights
# instead of failing inside load_state_dict.
if state_dict is not None:
    model.load_state_dict(state_dict)
else:
    print("Warning: no checkpoint was selected; evaluating the current model weights.")
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

# The test split is evaluated without passing a table.
test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

100%|██████████| 15/15 [01:39<00:00,  6.62s/it]


Epoch: 01, Train loss: 9.099163366678848, Val metrics: {'r2': -0.2675986401170305, 'mae': 4.392142855173441, 'rmse': 5.219647467271514}


100%|██████████| 15/15 [01:30<00:00,  6.05s/it]


Epoch: 02, Train loss: 5.910797346906663, Val metrics: {'r2': -0.3550511039380364, 'mae': 4.306414009685427, 'rmse': 5.396698139701008}


100%|██████████| 15/15 [01:30<00:00,  6.05s/it]


Epoch: 03, Train loss: 5.5677416742079915, Val metrics: {'r2': 0.015994014881060004, 'mae': 3.7902872582475746, 'rmse': 4.5988491013713775}


100%|██████████| 15/15 [01:32<00:00,  6.14s/it]


Epoch: 04, Train loss: 5.422654834702785, Val metrics: {'r2': 0.03601160515959145, 'mae': 3.7473916792121025, 'rmse': 4.551831662288851}


100%|██████████| 15/15 [01:33<00:00,  6.25s/it]


Epoch: 05, Train loss: 5.387928307944139, Val metrics: {'r2': 0.029777992367798878, 'mae': 3.6591208485340228, 'rmse': 4.566525113673676}


100%|██████████| 15/15 [01:34<00:00,  6.31s/it]


Epoch: 06, Train loss: 5.303422879999985, Val metrics: {'r2': 0.23147598946830195, 'mae': 3.2190776916687374, 'rmse': 4.0642368558318465}


100%|██████████| 15/15 [01:29<00:00,  6.00s/it]


Epoch: 07, Train loss: 5.03447160589028, Val metrics: {'r2': 0.2480980571656164, 'mae': 3.2604838240043112, 'rmse': 4.020044801583635}


100%|██████████| 15/15 [01:30<00:00,  6.03s/it]


Epoch: 08, Train loss: 4.911777497297579, Val metrics: {'r2': 0.20175415396362728, 'mae': 3.3278617384916314, 'rmse': 4.142081341392157}


100%|██████████| 15/15 [01:31<00:00,  6.07s/it]


Epoch: 09, Train loss: 4.833422489299337, Val metrics: {'r2': 0.26563119784378597, 'mae': 3.1834546666345998, 'rmse': 3.9728978519413483}


100%|██████████| 15/15 [01:30<00:00,  6.03s/it]


Epoch: 10, Train loss: 4.770833364064687, Val metrics: {'r2': 0.2511895039171441, 'mae': 3.253102861059134, 'rmse': 4.011772077277201}




Best Val metrics: {'r2': 0.26551552856427907, 'mae': 3.1854161768972515, 'rmse': 3.973210722080112}
Best test metrics: {'r2': 0.21654510498338153, 'mae': 3.761037751248008, 'rmse': 4.611867856705597}


