In [None]:
!pip install relbench

import relbench

relbench.__version__

You can work with your own data in the RelBench framework. This tutorial walks you through the steps to create custom datasets. Code in this notebook has been adapted from `relbench/datasets/f1.py` and `relbench/datasets/amazon.py`.

# Custom Dataset

## Imports

In [1]:
import os
import time

import numpy as np
import pandas as pd
import pooch
import pyarrow as pa
import pyarrow.json

from relbench.base import Database, Dataset, Table
from relbench.datasets import get_dataset, get_dataset_names, register_dataset
from relbench.utils import unzip_processor

## Overview

To define a custom dataset, we subclass the `relbench.base.Dataset` class. This requires specifying 3 things:
1. `val_timestamp` of type `pd.Timestamp`
2. `test_timestamp` of type `pd.Timestamp`
3. a `make_db` function which returns a `relbench.base.Database` object

These are described in further detail below.

### Temporal splitting

`val_timestamp` and `test_timestamp` define a unified temporal splitting for the dataset. All tasks defined on the dataset will provide train/val/test sets based on this splitting.

Only database rows up to `val_timestamp` should be used by the model to make predictions on the val set. Similarly, only database rows up to `test_timestamp` should be used by the model to make predictions on the test set. Rows after `test_timestamp` are used only to compute ground truth labels for the test set. Importantly, only rows up to `val_timestamp` can be used to obtain ground truth labels for the train set, and only rows up to `test_timestamp` can be used to obtain ground truth labels for the val set.

This prevents temporal leakage of information, which, as we will see, is a key consideration in many aspects of Relational Deep Learning.
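
A minimal pandas sketch of these rules on a toy table. The column names, dates, and the helper `rows_upto` are illustrative assumptions, and reading "up to" as an inclusive `<=` cutoff is also an assumption; RelBench's exact boundary handling may differ.

```python
import pandas as pd

val_timestamp = pd.Timestamp("2005-01-01")
test_timestamp = pd.Timestamp("2010-01-01")

# Toy event table: one row before val_timestamp, one between the two
# cutoffs, and one after test_timestamp.
results = pd.DataFrame(
    {
        "resultId": [0, 1, 2],
        "date": pd.to_datetime(["2003-05-01", "2007-05-01", "2012-05-01"]),
    }
)


def rows_upto(df: pd.DataFrame, time_col: str, cutoff: pd.Timestamp) -> pd.DataFrame:
    """Return the rows created at or before `cutoff` (inclusive cutoff assumed)."""
    return df[df[time_col] <= cutoff]


# Model inputs for the val set; also the only rows allowed for train labels.
val_visible = rows_upto(results, "date", val_timestamp)
# Model inputs for the test set; also the only rows allowed for val labels.
test_visible = rows_upto(results, "date", test_timestamp)
# Rows after test_timestamp may only feed the test-set ground truth labels.
test_label_only = results[results["date"] > test_timestamp]

print(val_visible["resultId"].tolist())      # [0]
print(test_visible["resultId"].tolist())     # [0, 1]
print(test_label_only["resultId"].tolist())  # [2]
```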

### The make_db function

The raw data for your dataset can be in any form. You will need to preprocess your dataset into the RelBench format to work with it in RelBench. The `make_db` function is the place to do this.

Inside the `make_db` function, we first download the raw files (or read them from the local filesystem) and then create a `relbench.base.Database` object from them. Thus, the `make_db` function serves as documentation for your preprocessing steps, while also conveniently allowing you to develop and debug them within the RelBench framework.

#### Pkey/Fkey Reindexing

The intended usage is not to call the `make_db` function directly but to use the `get_db` function which internally calls `make_db` and adds a layer of other functionality such as caching.

Another important thing that `get_db` does is call `db.reindex_pkeys_and_fkeys()` on the database `db` returned by `make_db`. This reindexes the primary- and foreign-key columns so that every primary key column holds consecutive integers starting from 0, with the foreign key columns remapped to match. This makes some downstream logic in RelBench convenient to implement, since it can assume uniformly that pkeys and fkeys are integers and, moreover, consecutive.

If you want to preserve the original pkey values, either because you believe they can be used as features for predictive tasks or because you would like to cross-reference the prediction results with the original data source, simply add a duplicate of the pkey column under a different name and do not mark the duplicate as `pkey_col`. The model designer is free to decide whether to include this duplicate column as input to the model.
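
The reindexing idea, together with the duplicate-column trick, can be sketched in plain pandas. This is not RelBench's `reindex_pkeys_and_fkeys` implementation, only an illustration that assumes pkeys are unique and that new keys are assigned in row order; the toy tables, the helper `reindex_keys`, and the column `original_circuitId` are hypothetical.

```python
import numpy as np
import pandas as pd

# Toy tables with arbitrary original keys (values are illustrative).
circuits = pd.DataFrame(
    {"circuitId": [8, 5, 18], "name": ["Silverstone", "Monaco", "Indianapolis"]}
)
races = pd.DataFrame({"raceId": [101, 102, 103], "circuitId": [5, 18, 8]})

# Keep the original key as an ordinary column (not declared as pkey_col)
# so that it survives reindexing. The column name is hypothetical.
circuits["original_circuitId"] = circuits["circuitId"]

tables = {"circuits": circuits, "races": races}
pkey_col = {"circuits": "circuitId", "races": "raceId"}
fkey_col_to_pkey_table = {"circuits": {}, "races": {"circuitId": "circuits"}}


def reindex_keys(tables, pkey_col, fkey_col_to_pkey_table):
    """Map every pkey to 0..n-1 and rewrite fkeys through the same mapping."""
    # Old pkey value -> new position, per table (assumes unique pkeys).
    mapping = {
        name: pd.Series(np.arange(len(df)), index=df[pkey_col[name]].to_numpy())
        for name, df in tables.items()
        if pkey_col[name] is not None
    }
    out = {}
    for name, df in tables.items():
        df = df.copy()
        if pkey_col[name] is not None:
            df[pkey_col[name]] = np.arange(len(df))
        for fkey, target in fkey_col_to_pkey_table[name].items():
            # Fkeys pointing at a pkey that does not exist become NaN.
            df[fkey] = df[fkey].map(mapping[target])
        out[name] = df
    return out


new_tables = reindex_keys(tables, pkey_col, fkey_col_to_pkey_table)
print(new_tables["circuits"])
print(new_tables["races"])
```

After reindexing, `races["circuitId"]` still points at the same circuits, and `original_circuitId` keeps the raw IDs for cross-referencing.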

### The Database object

The `relbench.base.Database` object is simply a collection of named `relbench.base.Table` objects. A `relbench.base.Table` object is instantiated by providing:
1. `df` of type `pd.DataFrame` representing the table content.
2. `fkey_col_to_pkey_table` dict to map foreign-key columns in this table to their primary-key table names.
3. `pkey_col`, the name of the primary key column, if any (else `None`).
4. `time_col`, the name of the time column, if available (else `None`).

The `time_col` denotes the creation time of a row in the database. If absent or `None`, the row is treated as if it were created at time `-inf`. Note that there can be other columns of type `pd.Timestamp` which are not necessarily the creation time. For example, there can be a `CreationDate` column in a `Users` table which would be suitable for `time_col`, but also a `DateOfBirth` column, which is better treated as an ordinary feature column.
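
A small pandas sketch of these semantics, using the `Users` example above. The helper `table_upto` and the toy rows are hypothetical, and the inclusive cutoff is an assumption; the point is only how a missing `time_col` (rows created at `-inf`) and a non-creation timestamp column behave when a table is cut at a given time.

```python
import pandas as pd


def table_upto(df, time_col, cutoff):
    """Rows visible at `cutoff` (inclusive cutoff assumed)."""
    if time_col is None:
        # No creation time: every row counts as created at -inf, so all are kept.
        return df
    return df[df[time_col] <= cutoff]


users = pd.DataFrame(
    {
        "Id": [0, 1],
        "CreationDate": pd.to_datetime(["2010-05-01", "2012-08-15"]),
        "DateOfBirth": pd.to_datetime(["1990-02-03", "1985-11-20"]),
    }
)
cutoff = pd.Timestamp("2011-01-01")

# CreationDate is the creation time, so it serves as time_col;
# DateOfBirth is just a feature and plays no role in filtering.
visible_users = table_upto(users, "CreationDate", cutoff)
# The same table declared with time_col=None is never filtered.
all_users = table_upto(users, None, cutoff)

print(visible_users["Id"].tolist())  # [0]
print(all_users["Id"].tolist())      # [0, 1]
```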

## Annotated Sample Implementation

In [2]:
class F1Dataset(Dataset):
    ################################################################################
    # Choose the val_timestamp and test_timestamp carefully
    ################################################################################
    val_timestamp = pd.Timestamp("2005-01-01")
    test_timestamp = pd.Timestamp("2010-01-01")

    def make_db(self) -> Database:
        r"""Process the raw files into a database."""
        ################################################################################
        # The raw files are at this URL. You can use any URL or even local files.
        ################################################################################
        url = "https://relbench.stanford.edu/data/relbench-f1-raw.zip"

        path = pooch.retrieve(
            url,
            known_hash="2933348953b30aa9723b4831fea8071b336b74977bbcf1fb059da63a04f06eba",
            progressbar=True,
            processor=unzip_processor,
        )

        path = os.path.join(path, "raw")

        ################################################################################
        # Here, we read from the raw CSV files
        ################################################################################
        circuits = pd.read_csv(os.path.join(path, "circuits.csv"))
        drivers = pd.read_csv(os.path.join(path, "drivers.csv"))
        results = pd.read_csv(os.path.join(path, "results.csv"))
        races = pd.read_csv(os.path.join(path, "races.csv"))
        standings = pd.read_csv(os.path.join(path, "driver_standings.csv"))
        constructors = pd.read_csv(os.path.join(path, "constructors.csv"))
        constructor_results = pd.read_csv(os.path.join(path, "constructor_results.csv"))
        constructor_standings = pd.read_csv(
            os.path.join(path, "constructor_standings.csv")
        )
        qualifying = pd.read_csv(os.path.join(path, "qualifying.csv"))

        ################################################################################
        # It is important to understand the data.
        # This can point to columns which should be removed.
        # The most important of these is temporal leakage columns, which we will
        # discuss in detail later.
        ################################################################################

        # Remove columns that are irrelevant, leak time,
        # or have too many missing values

        # Drop the Wikipedia URL and some time columns with many missing values
        races.drop(
            columns=[
                "url",
                "fp1_date",
                "fp1_time",
                "fp2_date",
                "fp2_time",
                "fp3_date",
                "fp3_time",
                "quali_date",
                "quali_time",
                "sprint_date",
                "sprint_time",
            ],
            inplace=True,
        )

        # Drop the Wikipedia URL as it is unique for each row
        circuits.drop(
            columns=["url"],
            inplace=True,
        )

        # Drop the Wikipedia URL (unique) and number (803 / 857 are nulls)
        drivers.drop(
            columns=["number", "url"],
            inplace=True,
        )

        # Drop the positionText, time, fastestLapTime and fastestLapSpeed
        results.drop(
            columns=[
                "positionText",
                "time",
                "fastestLapTime",
                "fastestLapSpeed",
            ],
            inplace=True,
        )

        # Drop the positionText
        standings.drop(
            columns=["positionText"],
            inplace=True,
        )

        # Drop the Wikipedia URL
        constructors.drop(
            columns=["url"],
            inplace=True,
        )

        # Drop the positionText
        constructor_standings.drop(
            columns=["positionText"],
            inplace=True,
        )

        # Drop the status as it only contains two categories, and
        # only 17 rows have value 'D' (0.138%)
        constructor_results.drop(
            columns=["status"],
            inplace=True,
        )

        # Drop the time in qualifying 1, 2, and 3
        qualifying.drop(
            columns=["q1", "q2", "q3"],
            inplace=True,
        )

        ################################################################################
        # Make sure to properly process time columns into pd.Timestamp datatype.
        # Sometimes, you might need to handle timezone information carefully to do
        # this correctly.
        # If time information can be inferred for other tables, it might help to add
        # as it makes temporal sampling more effective.
        ################################################################################

        # Replace missing times with midnight and combine the date and time columns
        races["time"] = races["time"].replace(r"^\\N$", "00:00:00", regex=True)
        races["date"] = races["date"] + " " + races["time"]
        # Convert date column to pd.Timestamp
        races["date"] = pd.to_datetime(races["date"])

        # add time column to other tables
        results = results.merge(races[["raceId", "date"]], on="raceId", how="left")
        standings = standings.merge(races[["raceId", "date"]], on="raceId", how="left")
        constructor_results = constructor_results.merge(
            races[["raceId", "date"]], on="raceId", how="left"
        )
        constructor_standings = constructor_standings.merge(
            races[["raceId", "date"]], on="raceId", how="left"
        )

        qualifying = qualifying.merge(
            races[["raceId", "date"]], on="raceId", how="left"
        )

        # Subtract a day from the date to account for the fact
        # that the qualifying time is the day before the main race
        qualifying["date"] = qualifying["date"] - pd.Timedelta(days=1)

        ################################################################################
        # Make sure that the missing data has been parsed properly.
        # Following Pandas, we represent missing values with NaNs in the dataframe.
        ################################################################################

        # Replace "\N" with NaN in the results table
        results = results.replace(r"^\\N$", np.nan, regex=True)

        # Replace "\N" with NaN in the circuits table, especially
        # for the column `alt` which has 3 rows of "\N"
        circuits = circuits.replace(r"^\\N$", np.nan, regex=True)
        # Convert alt from string to float
        circuits["alt"] = circuits["alt"].astype(float)

        # Convert non-numeric values to NaN in the numeric columns of results
        for col in [
            "rank",
            "number",
            "grid",
            "position",
            "points",
            "laps",
            "milliseconds",
            "fastestLap",
        ]:
            results[col] = pd.to_numeric(results[col], errors="coerce")

        # Convert drivers date of birth to datetime
        drivers["dob"] = pd.to_datetime(drivers["dob"])

        ################################################################################
        # Here, we collect all tables in the database as relbench.base.Table objects.
        ################################################################################

        tables = {}

        tables["races"] = Table(
            df=pd.DataFrame(races),
            fkey_col_to_pkey_table={
                "circuitId": "circuits",
            },
            pkey_col="raceId",
            time_col="date",
        )

        tables["circuits"] = Table(
            df=pd.DataFrame(circuits),
            fkey_col_to_pkey_table={},
            pkey_col="circuitId",
            time_col=None,
        )

        tables["drivers"] = Table(
            df=pd.DataFrame(drivers),
            fkey_col_to_pkey_table={},
            pkey_col="driverId",
            time_col=None,
        )

        tables["results"] = Table(
            df=pd.DataFrame(results),
            fkey_col_to_pkey_table={
                "raceId": "races",
                "driverId": "drivers",
                "constructorId": "constructors",
            },
            pkey_col="resultId",
            time_col="date",
        )

        tables["standings"] = Table(
            df=pd.DataFrame(standings),
            fkey_col_to_pkey_table={"raceId": "races", "driverId": "drivers"},
            pkey_col="driverStandingsId",
            time_col="date",
        )

        tables["constructors"] = Table(
            df=pd.DataFrame(constructors),
            fkey_col_to_pkey_table={},
            pkey_col="constructorId",
            time_col=None,
        )

        tables["constructor_results"] = Table(
            df=pd.DataFrame(constructor_results),
            fkey_col_to_pkey_table={"raceId": "races", "constructorId": "constructors"},
            pkey_col="constructorResultsId",
            time_col="date",
        )

        tables["constructor_standings"] = Table(
            df=pd.DataFrame(constructor_standings),
            fkey_col_to_pkey_table={"raceId": "races", "constructorId": "constructors"},
            pkey_col="constructorStandingsId",
            time_col="date",
        )

        tables["qualifying"] = Table(
            df=pd.DataFrame(qualifying),
            fkey_col_to_pkey_table={
                "raceId": "races",
                "driverId": "drivers",
                "constructorId": "constructors",
            },
            pkey_col="qualifyId",
            time_col="date",
        )

        return Database(tables)
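
The time handling and missing-value steps in `make_db` can be tried out in isolation. This is a minimal sketch on hand-made rows (the values are illustrative, not taken from the real CSV files) that mirrors the same pandas calls: filling the `\N` sentinel, combining `date` and `time`, inheriting the parent race's timestamp through a left merge, and shifting qualifying one day earlier.

```python
import numpy as np
import pandas as pd

# Toy stand-ins for races.csv and qualifying.csv; "\N" marks missing values.
races = pd.DataFrame(
    {
        "raceId": [1, 2],
        "date": ["1950-05-13", "2009-03-29"],
        "time": ["\\N", "06:00:00"],
    }
)
qualifying = pd.DataFrame(
    {"qualifyId": [10, 11], "raceId": [2, 2], "position": ["1", "\\N"]}
)

# Fill missing race times with midnight, then build a full timestamp.
races["time"] = races["time"].replace(r"^\\N$", "00:00:00", regex=True)
races["date"] = pd.to_datetime(races["date"] + " " + races["time"])

# Child rows inherit the time of the race they belong to...
qualifying = qualifying.merge(races[["raceId", "date"]], on="raceId", how="left")
# ...shifted one day earlier, because qualifying happens before the race.
qualifying["date"] = qualifying["date"] - pd.Timedelta(days=1)

# Remaining "\N" sentinels become NaN, and the column is made numeric.
qualifying = qualifying.replace(r"^\\N$", np.nan, regex=True)
qualifying["position"] = pd.to_numeric(qualifying["position"], errors="coerce")

print(races)
print(qualifying)
```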

## Using the custom dataset

You can use the custom `Dataset` subclass directly as follows:

In [3]:
f1_dataset = F1Dataset(cache_dir="./cache/f1")
f1_dataset

F1Dataset()

In [4]:
f1_db = f1_dataset.get_db()
f1_db

Making Database object from scratch...
(You can also use `get_dataset(..., download=True)` for datasets prepared by the RelBench team.)
Done in 0.24 seconds.
Caching Database object to ./cache/f1/db...
Done in 0.05 seconds.


Database()

In [5]:
f1_db.table_dict["races"]

Table(df=
     raceId  year  round  circuitId                  name                date  \
0         0  1950      1          8    British Grand Prix 1950-05-13 00:00:00   
1         1  1950      2          5     Monaco Grand Prix 1950-05-21 00:00:00   
2         2  1950      3         18      Indianapolis 500 1950-05-30 00:00:00   
3         3  1950      4         65      Swiss Grand Prix 1950-06-04 00:00:00   
4         4  1950      5         12    Belgian Grand Prix 1950-06-18 00:00:00   
..      ...   ...    ...        ...                   ...                 ...   
815     815  2009     13         13    Italian Grand Prix 2009-09-13 12:00:00   
816     816  2009     14         14  Singapore Grand Prix 2009-09-27 12:00:00   
817     817  2009     15         21   Japanese Grand Prix 2009-10-04 05:00:00   
818     818  2009     16         17  Brazilian Grand Prix 2009-10-18 16:00:00   
819     819  2009     17         23  Abu Dhabi Grand Prix 2009-11-01 11:00:00   

         time  
0

### Development advice

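While developing `make_db`, we suggest calling it directly to avoid caching artifacts while debugging. Note that `make_db` returns the full database, unlike `get_db()`, which removes rows after `test_timestamp`. `make_db` also skips the pkey/fkey reindexing done by `get_db()`, which is why the original `raceId` values appear in the output below.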

In [6]:
f1_full_db = f1_dataset.make_db()
f1_full_db

Database()

In [7]:
f1_full_db.table_dict["races"]

Table(df=
      raceId  year  round  circuitId                      name  \
0          1  2009      1          1     Australian Grand Prix   
1          2  2009      2          2      Malaysian Grand Prix   
2          3  2009      3         17        Chinese Grand Prix   
3          4  2009      4          3        Bahrain Grand Prix   
4          5  2009      5          4        Spanish Grand Prix   
...      ...   ...    ...        ...                       ...   
1096    1116  2023     18         69  United States Grand Prix   
1097    1117  2023     19         32    Mexico City Grand Prix   
1098    1118  2023     20         18      São Paulo Grand Prix   
1099    1119  2023     21         80      Las Vegas Grand Prix   
1100    1120  2023     22         24      Abu Dhabi Grand Prix   

                    date      time  
0    2009-03-29 06:00:00  06:00:00  
1    2009-04-05 09:00:00  09:00:00  
2    2009-04-19 07:00:00  07:00:00  
3    2009-04-26 12:00:00  12:00:00  
4    2009-05

### Registering your custom dataset

You can also register your dataset to make it available to `relbench.datasets.get_dataset` and use standardized caching locations (`~/.cache/relbench/<dataset-name>`).

In [8]:
register_dataset("rel-custom_f1", F1Dataset)
get_dataset_names()

['rel-amazon',
 'rel-avito',
 'rel-event',
 'rel-f1',
 'rel-hm',
 'rel-stack',
 'rel-trial',
 'rel-custom_f1']

In [9]:
reg_f1_dataset = get_dataset("rel-custom_f1")
reg_f1_dataset

F1Dataset()

## Advanced Customization

You can also add an `__init__` method to the `Dataset` subclass. This allows customizing the returned `Dataset` object with `args` and `kwargs`. An example can be seen in `relbench/datasets/amazon.py`, where the same `AmazonDataset` class can be used for either the Books or the Fashion subset of the raw Amazon Reviews dataset (and could easily be extended to the other subsets).

We reproduce the code snippet below:

In [10]:
class AmazonDataset(Dataset):
    val_timestamp = pd.Timestamp("2015-10-01")
    test_timestamp = pd.Timestamp("2016-01-01")

    url_prefix = "https://datarepo.eng.ucsd.edu/mcauley_group/data/amazon_v2"
    _category_to_url_key = {"books": "Books", "fashion": "AMAZON_FASHION"}

    known_hashes = {
        "meta_Books.json.gz": "80ed7ac64f5967a140401e8d7bf0587d2e5087492de9e94077a7f554ef6b18f0",
        "Books_5.json.gz": "ded924d1d1a22bae499f1a1c2b39397104304bfdb24232a2dd0aa50e89cd37bb",
    }

    def __init__(
        self,
        category: str = "books",
        use_5_core: bool = True,
        cache_dir: str = None,
    ):
        self.category = category
        self.use_5_core = use_5_core
        super().__init__(cache_dir=cache_dir)

    def make_db(self) -> Database:
        r"""Process the raw files into a database."""

        ### product table ###

        url_key = self._category_to_url_key[self.category]
        url = f"{self.url_prefix}/metaFiles2/meta_{url_key}.json.gz"
        path = pooch.retrieve(
            url,
            known_hash=self.known_hashes.get(url.split("/")[-1], None),
            progressbar=True,
            processor=pooch.Decompress(),
        )
        print(f"reading product info from {path}...")
        tic = time.time()
        ptable = pa.json.read_json(
            path,
            parse_options=pa.json.ParseOptions(
                explicit_schema=pa.schema(
                    [
                        ("asin", pa.string()),
                        ("category", pa.list_(pa.string())),
                        ("brand", pa.string()),
                        ("title", pa.string()),
                        ("description", pa.list_(pa.string())),
                        ("price", pa.string()),
                    ]
                ),
                unexpected_field_behavior="ignore",
            ),
        )
        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        print("converting to pandas dataframe...")
        tic = time.time()
        pdf = ptable.to_pandas()
        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        print("processing product info...")
        tic = time.time()

        # asin is not intuitive / recognizable
        pdf.rename(columns={"asin": "product_id"}, inplace=True)

        # somehow the raw data has duplicate product_id's
        pdf.drop_duplicates(subset=["product_id"], inplace=True)

        # price is like "$x,xxx.xx", "$xx.xx", or "$xx.xx - $xx.xx", or garbage html
        # if it's a range, we take the first value
        pdf.loc[:, "price"] = pdf["price"].apply(
            lambda x: (
                None
                if x is None or x == "" or x[0] != "$"
                else float(x.split(" ")[0][1:].replace(",", ""))
            )
        )

        # remove products with missing price
        pdf = pdf.dropna(subset=["price"])

        pdf.loc[:, "category"] = pdf["category"].apply(
            lambda x: None if x is None or len(x) == 0 else x
        )

        # some rows are stored as ['cat1' 'cat2' 'cat3' ...]
        # this function maps them to ['cat1', 'cat2', 'cat3', ...] (list of strings)
        # since otherwise pytorch-frame breaks
        def fix_column(value):
            if isinstance(value, str):
                return value  # Already a string
            elif value is None:
                return None
            else:
                return list(value)

        pdf["category"] = pdf["category"].apply(fix_column)

        # description is either [] or ["some description"]
        pdf.loc[:, "description"] = pdf["description"].apply(
            lambda x: None if x is None or len(x) == 0 else x[0]
        )

        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        ### review table ###

        if self.use_5_core:
            url = f"{self.url_prefix}/categoryFilesSmall/{url_key}_5.json.gz"
        else:
            url = f"{self.url_prefix}/categoryFiles/{url_key}.json.gz"
        path = pooch.retrieve(
            url,
            known_hash=self.known_hashes.get(url.split("/")[-1], None),
            progressbar=True,
            processor=pooch.Decompress(),
        )
        print(f"reading review and customer info from {path}...")
        tic = time.time()
        rtable = pa.json.read_json(
            path,
            parse_options=pa.json.ParseOptions(
                explicit_schema=pa.schema(
                    [
                        ("unixReviewTime", pa.int32()),
                        ("reviewerID", pa.string()),
                        ("reviewerName", pa.string()),
                        ("asin", pa.string()),
                        ("overall", pa.float32()),
                        ("verified", pa.bool_()),
                        ("reviewText", pa.string()),
                        ("summary", pa.string()),
                    ]
                ),
                unexpected_field_behavior="ignore",
            ),
        )
        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        print("converting to pandas dataframe...")
        tic = time.time()
        rdf = rtable.to_pandas()
        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        print("processing review and customer info...")
        tic = time.time()

        rdf.rename(
            columns={
                "unixReviewTime": "review_time",
                "reviewerID": "customer_id",
                "reviewerName": "customer_name",
                "asin": "product_id",
                "overall": "rating",
                "reviewText": "review_text",
            },
            inplace=True,
        )

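        # Plain column assignment replaces the int32 column with a datetime64 one;
        # setting through .loc[:, ...] may keep or upcast the old dtype instead.
        rdf["review_time"] = pd.to_datetime(rdf["review_time"], unit="s")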

        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        print("keeping only products common to product and review tables...")
        tic = time.time()
        plist = list(set(pdf["product_id"]) & set(rdf["product_id"]))
        pdf.query("product_id == @plist", inplace=True)
        rdf.query("product_id == @plist", inplace=True)
        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        print("extracting customer table...")
        tic = time.time()
        cdf = (
            rdf[["customer_id", "customer_name"]]
            .drop_duplicates(subset=["customer_id"])
            .copy()
        )
        rdf.drop(columns=["customer_name"], inplace=True)
        toc = time.time()
        print(f"done in {toc - tic:.2f} seconds.")

        db = Database(
            table_dict={
                "product": Table(
                    df=pdf,
                    fkey_col_to_pkey_table={},
                    pkey_col="product_id",
                    time_col=None,
                ),
                "customer": Table(
                    df=cdf,
                    fkey_col_to_pkey_table={},
                    pkey_col="customer_id",
                    time_col=None,
                ),
                "review": Table(
                    df=rdf,
                    fkey_col_to_pkey_table={
                        "customer_id": "customer",
                        "product_id": "product",
                    },
                    pkey_col=None,
                    time_col="review_time",
                ),
            }
        )

        db = db.from_(pd.Timestamp("2008-01-01"))

        return db
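
The two value-cleaning steps in `make_db` are easiest to check on a few hand-written inputs. This sketch copies their logic into standalone functions (`parse_price` is a hypothetical name for the inline lambda; `fix_column` is reproduced as is), and the sample strings are made up to cover each case the comments describe.

```python
import numpy as np
import pandas as pd


def parse_price(x):
    """Same logic as the price lambda: "$x,xxx.xx" -> float; ranges keep the first value."""
    if x is None or x == "" or x[0] != "$":
        return None  # missing, empty, or garbage such as HTML
    return float(x.split(" ")[0][1:].replace(",", ""))


def fix_column(value):
    """Turn array-like category values into plain Python lists."""
    if isinstance(value, str):
        return value
    elif value is None:
        return None
    else:
        return list(value)


prices = pd.Series(["$1,299.99", "$12.50 - $19.99", "<span>n/a</span>", "", None])
parsed = prices.apply(parse_price)
print(parsed.tolist())

# A NumPy array of strings (shown as ['cat1' 'cat2']) becomes a list.
fixed = [fix_column(v) for v in [np.array(["Books", "Fiction"]), "Books", None]]
print(fixed)
```

Unparseable prices come back as missing values, which is why `make_db` can drop them with `dropna(subset=["price"])`.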

Now, to make the Fashion dataset:

In [11]:
amazon_fashion_dataset = AmazonDataset(category="fashion", use_5_core=True)
amazon_fashion_dataset

AmazonDataset()

In [12]:
amazon_fashion_db = amazon_fashion_dataset.get_db()
amazon_fashion_db

Making Database object from scratch...
(You can also use `get_dataset(..., download=True)` for datasets prepared by the RelBench team.)
reading product info from /lfs/ampere4/0/ranjanr/.cache/pooch/b70e8d295f37e2465ea17803b6d1e11d-meta_AMAZON_FASHION.json.gz.decomp...
done in 0.07 seconds.
converting to pandas dataframe...
done in 0.16 seconds.
processing product info...
done in 0.11 seconds.
reading review and customer info from /lfs/ampere4/0/ranjanr/.cache/pooch/26323778935ec86761e2c260cea27601-AMAZON_FASHION_5.json.gz.decomp...
done in 0.01 seconds.
converting to pandas dataframe...
done in 0.00 seconds.
processing review and customer info...
done in 0.00 seconds.
keeping only products common to product and review tables...
done in 0.01 seconds.
extracting customer table...
done in 0.00 seconds.
Done in 0.38 seconds.


Database()

To register this dataset, pass the constructor `args` and `kwargs` to `register_dataset` as well.

In [13]:
register_dataset(
    "rel-amazon_fashion", AmazonDataset, category="fashion", use_5_core=True
)
get_dataset_names()

['rel-amazon',
 'rel-avito',
 'rel-event',
 'rel-f1',
 'rel-hm',
 'rel-stack',
 'rel-trial',
 'rel-custom_f1',
 'rel-amazon_fashion']

In [14]:
registered_amazon_fashion_dataset = get_dataset("rel-amazon_fashion")
registered_amazon_fashion_dataset

AmazonDataset()

Note that the registry does not persist beyond the running Python process. This means that to run the baseline scripts in `examples/` of the RelBench repo on your dataset, you will first have to modify the script to register your dataset before `get_dataset` is called.

## Next steps

To use the custom dataset within the RelBench framework, you would want to define custom tasks on it. See this tutorial for how to do that: [custom_task.ipynb](custom_task.ipynb)

Please also consider sharing your dataset with the community by getting it added to the RelBench dataset repository. Check out our [CONTRIBUTING.md](https://github.com/snap-stanford/relbench/blob/main/CONTRIBUTING.md) for how to do this.