You can work with your own data in the RelBench framework. This tutorial walks you through the steps to create custom datasets and tasks. Code in this notebook has been adapted from `relbench/datasets/f1.py` and `relbench/tasks/f1.py`.

# Custom Dataset

## Imports

In [6]:
import os

import numpy as np
import pandas as pd
import pooch

from relbench.base import Database, Dataset, Table
from relbench.utils import unzip_processor

## Overview

To define a custom dataset, we subclass the `relbench.base.Dataset` class. This requires specifying 3 things:
1. `val_timestamp` of type `pd.Timestamp`
2. `test_timestamp` of type `pd.Timestamp`
3. a `make_db` method that returns a `relbench.base.Database` object

These are described in further detail below.

### Temporal splitting

`val_timestamp` and `test_timestamp` define a single temporal split shared by the whole dataset. Every task defined on the dataset derives its train/val/test sets from this split.

Only database rows up to `val_timestamp` should be used by the model to make predictions on the val set. Similarly, only rows up to `test_timestamp` should be used to make predictions on the test set. Rows after `test_timestamp` are used only to compute ground-truth labels for the test set. Importantly, only rows up to `val_timestamp` can be used to obtain ground-truth labels for the train set, and only rows up to `test_timestamp` can be used to obtain ground-truth labels for the val set.

These rules prevent temporal leakage of information which, as we will see, is an important consideration in many aspects of Relational Deep Learning.
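The split rules above can be sketched in a few lines of pandas. This is a minimal illustration, not RelBench's internal implementation: the toy `races` rows and the `rows_up_to` helper are hypothetical, and treating "up to" as inclusive (`<=`) is an assumption.

```python
import pandas as pd

# Split timestamps borrowed from the F1 example below.
val_timestamp = pd.Timestamp("2005-01-01")
test_timestamp = pd.Timestamp("2010-01-01")

# Hypothetical toy table; "date" is the row creation time.
races = pd.DataFrame(
    {
        "raceId": [1, 2, 3, 4],
        "date": pd.to_datetime(
            ["2003-03-09", "2007-05-13", "2012-06-10", "2015-11-29"]
        ),
    }
)


def rows_up_to(df: pd.DataFrame, time_col: str, cutoff: pd.Timestamp) -> pd.DataFrame:
    """Return rows created at or before `cutoff` (inclusive is an assumption)."""
    return df[df[time_col] <= cutoff]


# Rows the model may read as input when predicting on each split.
val_inputs = rows_up_to(races, "date", val_timestamp)
test_inputs = rows_up_to(races, "date", test_timestamp)

# Rows that may be used to compute ground-truth labels for each split.
train_label_rows = rows_up_to(races, "date", val_timestamp)
val_label_rows = rows_up_to(races, "date", test_timestamp)
test_label_rows = races  # rows after test_timestamp are used only here

print(len(val_inputs), len(test_inputs), len(test_label_rows))  # 1 2 4
```

Note that the cutoff for val-set inputs is the same as the cutoff for train-set labels: anything the model sees when predicting on the val set could already have been used to build train labels, but nothing later.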

### The make_db function

The raw data for your dataset can be in any form. You will need to preprocess your dataset into the RelBench format to work with it in RelBench. The `make_db` function is the place to do this.

Inside the `make_db` function we first download the raw files (or read them from the local filesystem) and then build a `relbench.base.Database` object from them. The `make_db` function thus serves as documentation of your preprocessing steps, while also letting you develop and debug them conveniently within the RelBench framework.
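If your raw files live on the local filesystem rather than behind a URL, the reading step of `make_db` can be as simple as the sketch below. The helper name `read_raw_tables` and the one-CSV-per-table layout are assumptions for illustration; in a real `make_db` you would go on to clean these frames and wrap them in `relbench.base.Table` objects.

```python
import os
import tempfile

import pandas as pd


def read_raw_tables(raw_dir: str) -> dict:
    """Hypothetical helper: load every CSV in `raw_dir`, keyed by file stem."""
    frames = {}
    for fname in sorted(os.listdir(raw_dir)):
        stem, ext = os.path.splitext(fname)
        if ext == ".csv":
            frames[stem] = pd.read_csv(os.path.join(raw_dir, fname))
    return frames


# Usage: write two tiny made-up CSVs to a temporary directory and read them back.
with tempfile.TemporaryDirectory() as raw_dir:
    pd.DataFrame({"circuitId": [1, 2], "name": ["circuit_a", "circuit_b"]}).to_csv(
        os.path.join(raw_dir, "circuits.csv"), index=False
    )
    pd.DataFrame({"raceId": [10], "circuitId": [2]}).to_csv(
        os.path.join(raw_dir, "races.csv"), index=False
    )
    raw = read_raw_tables(raw_dir)

print(sorted(raw))  # ['circuits', 'races']
```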

### The Database object

The `relbench.base.Database` object is simply a collection of named `relbench.base.Table` objects. A `relbench.base.Table` object is instantiated by providing:
1. `df`, a `pd.DataFrame` holding the table content.
2. `fkey_col_to_pkey_table`, a dict mapping each foreign-key column in this table to the name of the table whose primary key it references.
3. `pkey_col`, the name of the primary-key column, if any.
4. `time_col`, the name of the time column, if any.
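Before wrapping DataFrames in `Table` objects, it can help to check that the four arguments are mutually consistent. The `check_table` function below is a hypothetical, pandas-only sanity check written for this tutorial, not part of RelBench: it verifies that the primary key is unique and non-null, that every foreign-key column exists and points to a known table name, and that `time_col` has a datetime dtype.

```python
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype


def check_table(df, fkey_col_to_pkey_table, pkey_col, time_col, table_names):
    """Hypothetical sanity check of the four Table arguments (not a RelBench API)."""
    problems = []
    if pkey_col is not None:
        if pkey_col not in df.columns:
            problems.append(f"missing primary key column {pkey_col!r}")
        elif df[pkey_col].isna().any() or df[pkey_col].duplicated().any():
            problems.append(f"primary key {pkey_col!r} must be unique and non-null")
    for fkey_col, pkey_table in fkey_col_to_pkey_table.items():
        if fkey_col not in df.columns:
            problems.append(f"missing foreign key column {fkey_col!r}")
        if pkey_table not in table_names:
            problems.append(f"{fkey_col!r} points to unknown table {pkey_table!r}")
    if time_col is not None:
        if time_col not in df.columns:
            problems.append(f"missing time column {time_col!r}")
        elif not is_datetime64_any_dtype(df[time_col]):
            problems.append(f"time column {time_col!r} is not a datetime dtype")
    return problems


# Toy table (made-up rows) and the set of table names in the database.
races = pd.DataFrame(
    {
        "raceId": [1, 2],
        "circuitId": [7, 8],
        "date": pd.to_datetime(["2009-03-29", "2009-04-05"]),
    }
)
names = {"races", "circuits"}

ok = check_table(races, {"circuitId": "circuits"}, "raceId", "date", names)
typo = check_table(races, {"circuitId": "tracks"}, "raceId", "date", names)
print(ok)    # []
print(typo)  # ["'circuitId' points to unknown table 'tracks'"]
```

A check like this catches a common mistake early: forgetting to convert a time column with `pd.to_datetime`, which leaves it as plain strings.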

The `time_col` denotes the creation time of a row in the database. If a table has no `time_col` (absent or `None`), each of its rows is treated as if it were created at time `-inf`. Note that a table can have other `pd.Timestamp` columns that do not record creation time. For example, a `Users` table might have a `CreationDate` column, which is suitable as `time_col`, and also a `DateOfBirth` column, which is better treated as an ordinary feature column.
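One way to picture the `-inf` convention: when restricting a table to what was known at some cutoff, a table without a `time_col` keeps all of its rows, while a table with a `time_col` keeps only rows created by the cutoff. The `snapshot` helper and the toy `users` table are hypothetical. `DateOfBirth` is deliberately never used for filtering, because it is a feature rather than a creation time.

```python
import pandas as pd


def snapshot(df, time_col, cutoff):
    """Hypothetical helper: rows visible at `cutoff`."""
    if time_col is None:
        return df  # every row counts as created at -inf, so all are visible
    return df[df[time_col] <= cutoff]


# Made-up users; only CreationDate says when a row entered the database.
users = pd.DataFrame(
    {
        "UserId": [1, 2, 3],
        "CreationDate": pd.to_datetime(["2010-01-01", "2012-06-01", "2015-03-01"]),
        "DateOfBirth": pd.to_datetime(["1990-05-05", "1985-01-20", "2001-09-09"]),
    }
)

cutoff = pd.Timestamp("2013-01-01")
with_time = snapshot(users, "CreationDate", cutoff)  # users 1 and 2
without_time = snapshot(users, None, cutoff)  # all three users
```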

## Annotated Sample Implementation

In [8]:
class F1Dataset(Dataset):
    ################################################################################
    # Choose the val_timestamp and test_timestamp carefully
    ################################################################################
    val_timestamp = pd.Timestamp("2005-01-01")
    test_timestamp = pd.Timestamp("2010-01-01")

    def make_db(self) -> Database:
        r"""Process the raw files into a database."""
        ################################################################################
        # The raw files are at this URL. You can use any URL or even local files.
        ################################################################################
        url = "https://relbench.stanford.edu/data/relbench-f1-raw.zip"

        path = pooch.retrieve(
            url,
            known_hash="2933348953b30aa9723b4831fea8071b336b74977bbcf1fb059da63a04f06eba",
            progressbar=True,
            processor=unzip_processor,
        )

        path = os.path.join(path, "raw")

        ################################################################################
        # Here, we read from the raw CSV files
        ################################################################################
        circuits = pd.read_csv(os.path.join(path, "circuits.csv"))
        drivers = pd.read_csv(os.path.join(path, "drivers.csv"))
        results = pd.read_csv(os.path.join(path, "results.csv"))
        races = pd.read_csv(os.path.join(path, "races.csv"))
        standings = pd.read_csv(os.path.join(path, "driver_standings.csv"))
        constructors = pd.read_csv(os.path.join(path, "constructors.csv"))
        constructor_results = pd.read_csv(os.path.join(path, "constructor_results.csv"))
        constructor_standings = pd.read_csv(
            os.path.join(path, "constructor_standings.csv")
        )
        qualifying = pd.read_csv(os.path.join(path, "qualifying.csv"))

        ################################################################################
        # It is important to understand the data.
        # Doing so can point to columns which should be removed.
        # The most important of these are columns that leak temporal information,
        # which we will discuss in detail later.
        ################################################################################

        # Remove columns that are irrelevant, leak time,
        # or have too many missing values

        # Drop the Wikipedia URL and some time columns with many missing values
        races.drop(
            columns=[
                "url",
                "fp1_date",
                "fp1_time",
                "fp2_date",
                "fp2_time",
                "fp3_date",
                "fp3_time",
                "quali_date",
                "quali_time",
                "sprint_date",
                "sprint_time",
            ],
            inplace=True,
        )

        # Drop the Wikipedia URL as it is unique for each row
        circuits.drop(
            columns=["url"],
            inplace=True,
        )

        # Drop the Wikipedia URL (unique) and number (803 / 857 are nulls)
        drivers.drop(
            columns=["number", "url"],
            inplace=True,
        )

        # Drop the positionText, time, fastestLapTime and fastestLapSpeed
        results.drop(
            columns=[
                "positionText",
                "time",
                "fastestLapTime",
                "fastestLapSpeed",
            ],
            inplace=True,
        )

        # Drop the positionText
        standings.drop(
            columns=["positionText"],
            inplace=True,
        )

        # Drop the Wikipedia URL
        constructors.drop(
            columns=["url"],
            inplace=True,
        )

        # Drop the positionText
        constructor_standings.drop(
            columns=["positionText"],
            inplace=True,
        )

        # Drop the status as it only contains two categories, and
        # only 17 rows have value 'D' (0.138%)
        constructor_results.drop(
            columns=["status"],
            inplace=True,
        )

        # Drop the time in qualifying 1, 2, and 3
        qualifying.drop(
            columns=["q1", "q2", "q3"],
            inplace=True,
        )

        ################################################################################
        # Make sure to properly process time columns into pd.Timestamp datatype.
        # Sometimes, you might need to handle timezone information carefully to do
        # this correctly.
        # If time information can be inferred for other tables, it might help to add
        # it, as this makes temporal sampling more effective.
        ################################################################################

        # Replace missing times ("\N") with midnight, then combine date and time columns
        races["time"] = races["time"].replace(r"^\\N$", "00:00:00", regex=True)
        races["date"] = races["date"] + " " + races["time"]
        # Convert date column to pd.Timestamp
        races["date"] = pd.to_datetime(races["date"])

        # Add a time column to the other tables by looking up each race's date
        results = results.merge(races[["raceId", "date"]], on="raceId", how="left")
        standings = standings.merge(races[["raceId", "date"]], on="raceId", how="left")
        constructor_results = constructor_results.merge(
            races[["raceId", "date"]], on="raceId", how="left"
        )
        constructor_standings = constructor_standings.merge(
            races[["raceId", "date"]], on="raceId", how="left"
        )

        qualifying = qualifying.merge(
            races[["raceId", "date"]], on="raceId", how="left"
        )

        # Subtract a day from the date to account for the fact
        # that the qualifying time is the day before the main race
        qualifying["date"] = qualifying["date"] - pd.Timedelta(days=1)

        ################################################################################
        # Make sure that the missing data has been parsed properly.
        # Following Pandas, we represent missing values with NaNs in the dataframe.
        ################################################################################

        # Replace "\N" with NaN in the results table
        results = results.replace(r"^\\N$", np.nan, regex=True)

        # Replace "\N" with NaN in the circuits table, especially
        # for the column `alt`, which has 3 rows of "\N"
        circuits = circuits.replace(r"^\\N$", np.nan, regex=True)
        # Convert alt from string to float
        circuits["alt"] = circuits["alt"].astype(float)

        # Convert non-numeric values to NaN in the specified columns
        results["rank"] = pd.to_numeric(results["rank"], errors="coerce")
        results["number"] = pd.to_numeric(results["number"], errors="coerce")
        results["grid"] = pd.to_numeric(results["grid"], errors="coerce")
        results["position"] = pd.to_numeric(results["position"], errors="coerce")
        results["points"] = pd.to_numeric(results["points"], errors="coerce")
        results["laps"] = pd.to_numeric(results["laps"], errors="coerce")
        results["milliseconds"] = pd.to_numeric(
            results["milliseconds"], errors="coerce"
        )
        results["fastestLap"] = pd.to_numeric(results["fastestLap"], errors="coerce")

        # Convert drivers date of birth to datetime
        drivers["dob"] = pd.to_datetime(drivers["dob"])

        ################################################################################
        # Here, we collect all tables in the database as relbench.base.Table objects.
        ################################################################################

        tables = {}

        tables["races"] = Table(
            df=pd.DataFrame(races),
            fkey_col_to_pkey_table={
                "circuitId": "circuits",
            },
            pkey_col="raceId",
            time_col="date",
        )

        tables["circuits"] = Table(
            df=pd.DataFrame(circuits),
            fkey_col_to_pkey_table={},
            pkey_col="circuitId",
            time_col=None,
        )

        tables["drivers"] = Table(
            df=pd.DataFrame(drivers),
            fkey_col_to_pkey_table={},
            pkey_col="driverId",
            time_col=None,
        )

        tables["results"] = Table(
            df=pd.DataFrame(results),
            fkey_col_to_pkey_table={
                "raceId": "races",
                "driverId": "drivers",
                "constructorId": "constructors",
            },
            pkey_col="resultId",
            time_col="date",
        )

        tables["standings"] = Table(
            df=pd.DataFrame(standings),
            fkey_col_to_pkey_table={"raceId": "races", "driverId": "drivers"},
            pkey_col="driverStandingsId",
            time_col="date",
        )

        tables["constructors"] = Table(
            df=pd.DataFrame(constructors),
            fkey_col_to_pkey_table={},
            pkey_col="constructorId",
            time_col=None,
        )

        tables["constructor_results"] = Table(
            df=pd.DataFrame(constructor_results),
            fkey_col_to_pkey_table={"raceId": "races", "constructorId": "constructors"},
            pkey_col="constructorResultsId",
            time_col="date",
        )

        tables["constructor_standings"] = Table(
            df=pd.DataFrame(constructor_standings),
            fkey_col_to_pkey_table={"raceId": "races", "constructorId": "constructors"},
            pkey_col="constructorStandingsId",
            time_col="date",
        )

        tables["qualifying"] = Table(
            df=pd.DataFrame(qualifying),
            fkey_col_to_pkey_table={
                "raceId": "races",
                "driverId": "drivers",
                "constructorId": "constructors",
            },
            pkey_col="qualifyId",
            time_col="date",
        )

        return Database(tables)
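The cleaning steps in `make_db` can be tried out on small toy frames before running them on the full F1 dump. The sketch below uses made-up rows (not the real F1 data) and reproduces three of the steps above: turning the literal `\N` marker into `NaN` with `pd.to_numeric(..., errors="coerce")`, combining `date` and `time` into one timestamp, and copying each race's date onto a child table with a left merge, shifted back one day for qualifying. The `null_fraction` helper is an extra, hypothetical aid for deciding which sparse columns to drop.

```python
import numpy as np
import pandas as pd

# Made-up rows shaped like the raw F1 CSVs; "\N" marks a missing value.
races = pd.DataFrame(
    {
        "raceId": [1, 2],
        "date": ["2009-03-29", "2009-04-05"],
        "time": ["06:00:00", "\\N"],
    }
)
results = pd.DataFrame(
    {"resultId": [100, 101, 102], "raceId": [1, 1, 2], "rank": ["1", "\\N", "2"]}
)
qualifying = pd.DataFrame({"qualifyId": [7, 8], "raceId": [1, 2]})


def null_fraction(df):
    """Hypothetical helper: fraction of missing values per column, counting "\\N"."""
    return df.replace(r"^\\N$", np.nan, regex=True).isna().mean()


# Inspect sparsity before cleaning, e.g. to decide which columns to drop.
missing = null_fraction(results)  # rank: 1 of 3 values missing

# Missing race times default to midnight; date and time become one Timestamp.
races["time"] = races["time"].replace(r"^\\N$", "00:00:00", regex=True)
races["date"] = pd.to_datetime(races["date"] + " " + races["time"])

# Non-numeric markers such as "\N" become NaN.
results["rank"] = pd.to_numeric(results["rank"], errors="coerce")

# Give child tables a time column by looking up the date of their race.
results = results.merge(races[["raceId", "date"]], on="raceId", how="left")
qualifying = qualifying.merge(races[["raceId", "date"]], on="raceId", how="left")
# As in make_db, qualifying is assumed to happen the day before the race.
qualifying["date"] = qualifying["date"] - pd.Timedelta(days=1)

print(qualifying)
```

Using `how="left"` keeps every child row even if its `raceId` had no match, in which case its `date` would be `NaT`; checking for such rows is a cheap way to catch broken foreign keys.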