In [None]:
!pip install relbench

import relbench

relbench.__version__

You can define your own tasks over the existing RelBench datasets or over your own custom dataset (see the tutorial at [custom_dataset.ipynb](custom_dataset.ipynb)). This tutorial shows how.

Code in this notebook has been adapted from `relbench/tasks/f1.py`.

# Custom Entity Classification Task

For illustration, we focus on creating an entity-classification task on the `rel-f1` dataset. Creating entity-regression or recommendation tasks is similar. See the RelBench code (e.g., the `relbench.base.EntityTask` and `relbench.base.RecommendationTask` classes) for more documentation and `relbench/tasks/*.py` for examples.

## Imports

In [1]:
import duckdb
import pandas as pd

from relbench.base import Database, EntityTask, Table, TaskType
from relbench.datasets import get_dataset
from relbench.metrics import accuracy, average_precision, f1, roc_auc
from relbench.tasks import get_task, get_task_names, register_task

## Overview

RelBench supports entity tasks, which involve classification or regression over a single entity, and recommendation tasks, which involve predicting the target entity for a given source entity. For the former, subclass `relbench.base.EntityTask`; for the latter, subclass `relbench.base.RecommendationTask`. Both `EntityTask` and `RecommendationTask` are themselves subclasses of `BaseTask`, where you can find documentation for the attributes shared by both kinds of tasks.

To define a custom task, you subclass either `EntityTask` or `RecommendationTask` and provide the attributes required (such as `timedelta`). See the code for these classes or the `BaseTask` class for documentation on the various attributes that need to be set.

### The make_table function

Labels for RelBench tasks can be constructed from the historical records in the database itself. Thus, to define a task, we specify how to construct the labels for any given timestamp. For computational efficiency, we may want to construct labels for many timestamps together. This facility is provided by the `make_table` function, which takes a `Database` and a `pd.Series` of `pd.Timestamp` values and returns the task `Table`.

In RelBench, we express tasks via SQL queries over the database (using `duckdb`) for efficiency (query-optimization, parallelization, etc.). You are free to use something else (e.g., Pandas).
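
As an illustration of the Pandas route, here is a minimal sketch that builds labels for several timestamps at once. The toy `results` frame, the `make_labels` helper, and the `num_results` label are assumptions made up for this example and are not part of RelBench; the sketch only shows the windowing idea of keeping, for each timestamp `t`, the rows whose `date` falls in `(t, t + timedelta]`.

```python
import pandas as pd

# Toy stand-in for a `results` table (hypothetical data, not from rel-f1).
results = pd.DataFrame(
    {
        "driverId": [1, 1, 2, 2],
        "date": pd.to_datetime(
            ["2020-01-10", "2020-01-20", "2020-01-15", "2020-03-01"]
        ),
    }
)
timestamps = pd.Series(pd.to_datetime(["2020-01-01", "2020-02-15"]))
timedelta = pd.Timedelta(days=30)


def make_labels(results, timestamps, timedelta):
    """Count each driver's result rows in the window (t, t + timedelta]."""
    # Pair every timestamp with every result row.
    pairs = pd.DataFrame({"timestamp": timestamps}).merge(results, how="cross")
    # Keep only rows that fall strictly after t and at most timedelta later.
    in_window = (pairs["date"] > pairs["timestamp"]) & (
        pairs["date"] <= pairs["timestamp"] + timedelta
    )
    # One label row per (timestamp, driverId) pair, sorted by the group keys.
    labels = (
        pairs[in_window]
        .groupby(["timestamp", "driverId"])
        .size()
        .reset_index(name="num_results")
    )
    return labels.rename(columns={"timestamp": "date"})


labels = make_labels(results, timestamps, timedelta)
print(labels)
```

The cross join is fine for toy data but grows as (number of timestamps) × (number of rows), which is one reason a SQL engine may be the better choice on real tables.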

## Annotated Sample Implementation

In [2]:
class DriverDNFTask(EntityTask):
    ################################################################################
    # Use docstrings to describe the task
    ################################################################################
    r"""Predict if each driver will DNF (not finish) a race in the next 1 month."""

    ################################################################################
    # Fill out the task attributes
    ################################################################################
    task_type = TaskType.BINARY_CLASSIFICATION
    entity_col = "driverId"
    entity_table = "drivers"
    time_col = "date"
    target_col = "did_not_finish"
    timedelta = pd.Timedelta(days=30)
    metrics = [average_precision, accuracy, f1, roc_auc]
    num_eval_timestamps = 40

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        ################################################################################
        # Dataframes in the local context are accessible in the SQL query via duckdb
        ################################################################################

        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        results = db.table_dict["results"].df
        drivers = db.table_dict["drivers"].df
        races = db.table_dict["races"].df

        ################################################################################
        # This SQL query computes the labels for (date, driverId) pairs
        ################################################################################
        df = duckdb.sql(
            f"""
                SELECT
                    t.timestamp as date,
                    dri.driverId as driverId,
                    CASE
                        WHEN MAX(CASE WHEN re.statusId != 1 THEN 1 ELSE 0 END) = 1 THEN 0
                        ELSE 1
                    END AS did_not_finish
                FROM
                    timestamp_df t
                LEFT JOIN
                    results re
                ON
                    re.date <= t.timestamp + INTERVAL '{self.timedelta}'
                    and re.date  > t.timestamp
                LEFT JOIN
                    drivers dri
                ON
                    re.driverId = dri.driverId
                WHERE
                    dri.driverId IN (
                        SELECT DISTINCT driverId
                        FROM results
                        WHERE date > t.timestamp - INTERVAL '1 year'
                    )
                GROUP BY t.timestamp, dri.driverId

            ;
            """
        ).df()

        ################################################################################
        # The task table is expressed via a Table object, same as used in Database
        ################################################################################
        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        )

## Using the custom task

Tasks are defined on a `Dataset` object. You can use a RelBench dataset or your own custom one. The custom `BaseTask` subclass can be used directly as follows:

In [3]:
f1_dataset = get_dataset("rel-f1")
f1_dataset

F1Dataset()

In [4]:
driver_dnf_task = DriverDNFTask(f1_dataset, cache_dir="./cache/driver_dnf")
driver_dnf_task

DriverDNFTask(dataset=F1Dataset())

In [5]:
driver_dnf_task.get_table("train")

Making task table for train split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Loading Database object from /afs/cs.stanford.edu/u/ranjanr/.cache/relbench/rel-f1/db...
Done in 0.04 seconds.
Loading Database object from /afs/cs.stanford.edu/u/ranjanr/.cache/relbench/rel-f1/db...
Done in 0.02 seconds.
Done in 0.32 seconds.


Table(df=
            date  driverId  did_not_finish
0     2004-08-04        20               0
1     2004-08-04        12               0
2     2004-07-05        10               1
3     2004-07-05        47               1
4     2004-06-05        31               0
...          ...       ...             ...
11406 1977-04-28       228               0
11407 1977-06-27       259               0
11408 1957-05-13       611               1
11409 1955-05-24       611               0
11410 1954-06-28       683               0

[11411 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)

### Development advice

While developing `make_table`, we suggest calling `make_table` directly (instead of `get_table`) so that cached results do not get in the way while debugging. Alternatively, you can call `_get_table`, which is an uncached version of `get_table`.

In [6]:
driver_dnf_task._get_table("train")

Table(df=
            date  driverId  did_not_finish
0     2004-08-04        20               0
1     2004-08-04        12               0
2     2004-07-05        10               1
3     2004-07-05        47               1
4     2004-06-05        31               0
...          ...       ...             ...
11406 1977-04-28       228               0
11407 1977-06-27       259               0
11408 1957-05-13       611               1
11409 1955-05-24       611               0
11410 1954-06-28       683               0

[11411 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)

### Registering your custom task

You can also register your task to make it available to `relbench.tasks.get_task` and use standardized caching locations (`~/.cache/relbench/<dataset-name>/tasks/<task-name>`). Note that the dataset should already be registered for this.

In [7]:
register_task("rel-f1", "custom_driver-dnf", DriverDNFTask)
get_task_names("rel-f1")

['driver-position', 'driver-dnf', 'driver-top3', 'custom_driver-dnf']

In [8]:
get_task("rel-f1", "custom_driver-dnf")

DriverDNFTask(dataset=F1Dataset())

Note that the registry does not persist beyond the running Python process. This means that to run the baseline scripts in `examples/` in the RelBench repo, you will first have to modify the script so that it registers your own task before `get_task` is called.

## Advanced Customization

You can also add an `__init__` method to your `BaseTask` subclass. This allows customizing the returned `BaseTask` object with `args` and `kwargs`.
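
A minimal sketch of this pattern follows. To keep it runnable without RelBench installed, it uses `StandInEntityTask`, a hypothetical stand-in for `EntityTask` that only accepts the `dataset` and `cache_dir` arguments seen earlier in this tutorial; the real base-class constructor may do more. The `horizon_days` keyword and the `ConfigurableDNFTask` name are likewise assumptions for illustration.

```python
import pandas as pd


class StandInEntityTask:
    """Hypothetical stand-in for relbench.base.EntityTask (not the real class)."""

    def __init__(self, dataset, cache_dir=None):
        self.dataset = dataset
        self.cache_dir = cache_dir


class ConfigurableDNFTask(StandInEntityTask):
    # Class-level default, in the same style as DriverDNFTask.
    timedelta = pd.Timedelta(days=30)

    def __init__(self, dataset, cache_dir=None, horizon_days=30):
        # Let the base class handle the arguments it already knows about.
        super().__init__(dataset, cache_dir=cache_dir)
        # The instance attribute shadows the class-level default.
        self.timedelta = pd.Timedelta(days=horizon_days)


# `dataset=None` is a placeholder; a real task would receive a Dataset object.
short_task = ConfigurableDNFTask(dataset=None, horizon_days=7)
default_task = ConfigurableDNFTask(dataset=None)
print(short_task.timedelta, default_task.timedelta)
```

If you parameterize a task this way, it is probably wise to give each configuration its own `cache_dir` so that cached tables built with different settings are not mixed up.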

Another form of customization is overriding the `_get_table` method directly. This can allow expressing tasks where the labels cannot be computed from the database alone, or splitting is not purely temporal.

## Next steps

Please also consider sharing your custom task with the community by getting it added to the RelBench task repository. Check out our [CONTRIBUTING.md](https://github.com/snap-stanford/relbench/blob/main/CONTRIBUTING.md) for how to do this.