In [1]:
from collections import defaultdict
from itertools import product

import pandas as pd
import torch

import roach
from relbench.data import TaskType
from relbench.datasets import get_dataset_names
from relbench.tasks import get_task_names, get_task

In [2]:
all_stores = roach.scan("relbench/2024-07-01")

In [3]:
len(all_stores)

369

In [4]:
all_stores[-1]["__roach__"]

{'project': 'relbench/2024-07-01',
 'timestamp': 1719951248879018973,
 'caller_file': 'gnn_node.py',
 'done': True}

In [5]:
def wrap(name):
    return r"\texttt{" + name + r"}"

In [6]:
txt = {
    "val": "Val",
    "test": "Test",
}

# classification

In [18]:
metric = "roc_auc"
higher_is_better = True

In [19]:
table_data = []
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value not in [
            TaskType.BINARY_CLASSIFICATION.value,
            TaskType.MULTICLASS_CLASSIFICATION.value,
        ]:
            continue
        for script in [
            "gnn_node",
            "lightgbm_baseline",
        ]:
            for split in [
                "val",
                "test",
            ]:
                vals = []
                for seed in range(5):
                    stores = list(
                        filter(
                            lambda store: store["__roach__"]["caller_file"] == f"{script}.py"
                            and store["args"]["dataset"] == dataset
                            and store["args"]["task"] == task
                            and store["args"]["seed"] == seed,
                            all_stores,
                        )
                    )
                    try:
                        store = stores[-1]
                        val = store[split][metric]
                        vals.append(val)
                    except IndexError:
                        # val = float("nan")
                        # vals.append(val)
                        pass
                val = torch.tensor(vals)
                mean = val.mean().item()
                std = val.std().item()
                record = {
                    "dataset": dataset,
                    "task": task,
                    "script": script,
                    "split": split,
                    "mean": mean,
                    "std": std,
                }
                table_data.append(record)

In [20]:
tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    is_best = True
    filter_data = filter(
        lambda r: r["dataset"] == dataset and r["task"] == task and r["split"] == split,
        table_data,
    )
    for comp_rec in filter_data:
        comp_mean = comp_rec["mean"]
        comp_std = comp_rec["std"]
        if higher_is_better:
            if mean + std < comp_mean - comp_std:
                is_best = False
        else:
            if mean - std > comp_mean + comp_std:
                is_best = False
    opt_bm_open = r"\bm{" if is_best else ""
    opt_bm_close = r"}" if is_best else ""
    tex_val = (
        r"$"
        + opt_bm_open
        + f"{mean * 100:.2f}"
        + opt_bm_close
        + r"_{\pm "
        + f"{std * 100:.2f}"
        + r"}$"
    )

    tex_tab[script][(wrap(dataset), wrap(task), txt[split])] = tex_val
tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,gnn_node,lightgbm_baseline
\texttt{rel-amazon},\texttt{user-churn},Val,$\bm{70.47}_{\pm 0.02}$,$52.21_{\pm 0.04}$
\texttt{rel-amazon},\texttt{user-churn},Test,$\bm{70.40}_{\pm 0.08}$,$52.25_{\pm 0.08}$
\texttt{rel-amazon},\texttt{item-churn},Val,$\bm{82.41}_{\pm 0.03}$,$62.31_{\pm 0.16}$
\texttt{rel-amazon},\texttt{item-churn},Test,$\bm{82.79}_{\pm 0.01}$,$62.75_{\pm 0.21}$
\texttt{rel-avito},\texttt{user-visits},Val,$\bm{69.66}_{\pm 0.09}$,$53.23_{\pm 0.07}$
\texttt{rel-avito},\texttt{user-visits},Test,$\bm{66.06}_{\pm 0.16}$,$53.03_{\pm 0.26}$
\texttt{rel-avito},\texttt{user-clicks},Val,$\bm{65.48}_{\pm 0.89}$,$55.90_{\pm 0.26}$
\texttt{rel-avito},\texttt{user-clicks},Test,$\bm{67.26}_{\pm 1.18}$,$54.08_{\pm 0.56}$
\texttt{rel-event},\texttt{user-repeat},Val,$\bm{68.72}_{\pm 0.94}$,$\bm{68.19}_{\pm 0.32}$
\texttt{rel-event},\texttt{user-repeat},Test,$\bm{77.55}_{\pm 0.75}$,$67.69_{\pm 2.85}$


In [21]:
tex = tex_df.to_latex()
tex = tex.replace(r"\multirow[t]", r"\multirow[c]")
tex = tex.replace(r"\cline", r"\cmidrule")
tex = tex.replace(r"\cmidrule{1-5} \cmidrule{2-5}", r"\cmidrule{1-5}")
print(tex)

\begin{tabular}{lllll}
\toprule
 &  &  & gnn_node & lightgbm_baseline \\
\midrule
\multirow[c]{4}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-churn}} & Val & $\bm{70.47}_{\pm 0.02}$ & $52.21_{\pm 0.04}$ \\
 &  & Test & $\bm{70.40}_{\pm 0.08}$ & $52.25_{\pm 0.08}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\texttt{item-churn}} & Val & $\bm{82.41}_{\pm 0.03}$ & $62.31_{\pm 0.16}$ \\
 &  & Test & $\bm{82.79}_{\pm 0.01}$ & $62.75_{\pm 0.21}$ \\
\cmidrule{1-5}
\multirow[c]{4}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{\texttt{user-visits}} & Val & $\bm{69.66}_{\pm 0.09}$ & $53.23_{\pm 0.07}$ \\
 &  & Test & $\bm{66.06}_{\pm 0.16}$ & $53.03_{\pm 0.26}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\texttt{user-clicks}} & Val & $\bm{65.48}_{\pm 0.89}$ & $55.90_{\pm 0.26}$ \\
 &  & Test & $\bm{67.26}_{\pm 1.18}$ & $54.08_{\pm 0.56}$ \\
\cmidrule{1-5}
\multirow[c]{4}{*}{\texttt{rel-event}} & \multirow[c]{2}{*}{\texttt{user-repeat}} & Val & $\bm{68.72}_{\pm 0.94}$ & $\bm{68.19}_{\pm 0.32}$

# regression

In [22]:
metric = "mae"
higher_is_better = False

In [23]:
table_data = []
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value != TaskType.REGRESSION.value:
            continue
        for script in [
            "gnn_node",
            "lightgbm_baseline",
            "node_baseline",
        ]:
            for split in [
                "val",
                "test",
            ]:
                if script == "node_baseline":
                    stores = list(
                        filter(
                            lambda store: store["__roach__"]["caller_file"].endswith(f"/{script}.py")
                            and store["args"]["dataset"] == dataset
                            and store["args"]["task"] == task,
                            all_stores,
                        ),
                    )
                    for baseline in [
                        "global_zero",
                        "global_mean",
                        "global_median",
                        "entity_mean",
                        "entity_median",
                    ]:
                        try:
                            store = stores[-1]
                            val = store[baseline][split][metric]
                        except IndexError:
                            val = float("nan")
                        record = {
                            "dataset": dataset,
                            "task": task,
                            "script": baseline,
                            "split": split,
                            "mean": val,
                            "std": 0.0,
                        }
                        table_data.append(record)
                else:
                    vals = []
                    for seed in range(5):
                        stores = list(
                            filter(
                                lambda store: store["__roach__"]["caller_file"] == f"{script}.py"
                                and store["args"]["dataset"] == dataset
                                and store["args"]["task"] == task
                                and store["args"]["seed"] == seed,
                                all_stores,
                            )
                        )
                        try:
                            store = stores[-1]
                            val = store[split][metric]
                            vals.append(val)
                        except IndexError:
                            # val = float("nan")
                            # vals.append(val)
                            pass
                    val = torch.tensor(vals)
                    mean = val.mean().item()
                    std = val.std().item()
                    record = {
                        "dataset": dataset,
                        "task": task,
                        "script": script,
                        "split": split,
                        "mean": mean,
                        "std": std,
                    }
                    table_data.append(record)

In [24]:
tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    is_best = True
    filter_data = filter(
        lambda r: r["dataset"] == dataset and r["task"] == task and r["split"] == split,
        table_data,
    )
    for comp_rec in filter_data:
        comp_mean = comp_rec["mean"]
        comp_std = comp_rec["std"]
        if higher_is_better:
            if mean + std < comp_mean - comp_std:
                is_best = False
        else:
            if mean - std > comp_mean + comp_std:
                is_best = False
    opt_bm_open = r"\bm{" if is_best else ""
    opt_bm_close = r"}" if is_best else ""
    if script in [
        "gnn_node",
        "lightgbm_baseline",
    ]:
        tex_val = (
            r"$"
            + opt_bm_open
            + f"{mean:.3f}"
            + opt_bm_close
            + r"_{\pm "
            + f"{std:.3f}"
            + r"}$"
        )
    else:
        tex_val = r"$" + opt_bm_open + f"{mean:.3f}" + opt_bm_close + r"$"

    rec["tex"] = tex_val

    tex_tab[script][(wrap(dataset), wrap(task), txt[split])] = tex_val

tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,gnn_node,lightgbm_baseline,global_zero,global_mean,global_median,entity_mean,entity_median
\texttt{rel-amazon},\texttt{user-ltv},Val,$\bm{12.135}_{\pm 0.014}$,$14.141_{\pm 0.000}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-amazon},\texttt{user-ltv},Test,$\bm{14.333}_{\pm 0.043}$,$16.783_{\pm 0.000}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-amazon},\texttt{item-ltv},Val,$\bm{45.237}_{\pm 0.138}$,$55.800_{\pm 0.052}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-amazon},\texttt{item-ltv},Test,$\bm{50.330}_{\pm 0.193}$,$60.639_{\pm 0.041}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-avito},\texttt{ads-clicks},Val,$\bm{0.037}_{\pm 0.001}$,$\bm{0.037}_{\pm 0.000}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-avito},\texttt{ads-clicks},Test,$\bm{0.041}_{\pm 0.001}$,$\bm{0.041}_{\pm 0.000}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-event},\texttt{user-attendance},Val,$\bm{0.261}_{\pm 0.002}$,$0.264_{\pm 0.000}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-event},\texttt{user-attendance},Test,$\bm{0.265}_{\pm 0.006}$,$\bm{0.266}_{\pm 0.000}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-f1},\texttt{driver-position},Val,$\bm{3.181}_{\pm 0.007}$,$3.477_{\pm 0.022}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$
\texttt{rel-f1},\texttt{driver-position},Test,$\bm{4.096}_{\pm 0.165}$,$\bm{4.129}_{\pm 0.073}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$,$\bm{nan}$


In [25]:
tex = tex_df.to_latex()
tex = tex.replace(r"\multirow[t]", r"\multirow[c]")
tex = tex.replace(r"\cline", r"\cmidrule")
tex = tex.replace(r"\cmidrule{1-10} \cmidrule{2-10}", r"\cmidrule{1-10}")
print(tex)

\begin{tabular}{llllllllll}
\toprule
 &  &  & gnn_node & lightgbm_baseline & global_zero & global_mean & global_median & entity_mean & entity_median \\
\midrule
\multirow[c]{4}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-ltv}} & Val & $\bm{12.135}_{\pm 0.014}$ & $14.141_{\pm 0.000}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ \\
 &  & Test & $\bm{14.333}_{\pm 0.043}$ & $16.783_{\pm 0.000}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ \\
\cmidrule{2-10}
 & \multirow[c]{2}{*}{\texttt{item-ltv}} & Val & $\bm{45.237}_{\pm 0.138}$ & $55.800_{\pm 0.052}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ \\
 &  & Test & $\bm{50.330}_{\pm 0.193}$ & $60.639_{\pm 0.041}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{nan}$ \\
\cmidrule{1-10}
\multirow[c]{2}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{\texttt{ads-clicks}} & Val & $\bm{0.037}_{\pm 0.001}$ & $\bm{0.037}_{\pm 0.000}$ & $\bm{nan}$ & $\bm{nan}$ & $\bm{na

# link prediction

In [30]:
metric = "link_prediction_map"
higher_is_better = True

In [31]:
table_data = []
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value != TaskType.LINK_PREDICTION.value:
            continue
        for script in [
            "gnn_link",
            "idgnn_link",
            "lightgbm_link_baseline",
            "link_baseline",
        ]:
            for split in [
                "val",
                "test",
            ]:
                if script == "link_baseline":
                    stores = list(
                        filter(
                            lambda store: store["__roach__"]["caller_file"] == f"{script}.py"
                            and store["args"]["dataset"] == dataset
                            and store["args"]["task"] == task,
                            all_stores,
                        ),
                    )
                    for baseline in [
                        "global_popularity",
                        "past_visit",
                    ]:
                        try:
                            store = stores[-1]
                            val = store[baseline][split][metric]
                        except IndexError:
                            val = float("nan")
                        record = {
                            "dataset": dataset,
                            "task": task,
                            "script": baseline,
                            "split": split,
                            "mean": val,
                            "std": 0.0,
                        }
                        table_data.append(record)
                else:
                    vals = []
                    for seed in range(5):
                        stores = list(
                            filter(
                                lambda store: store["__roach__"]["caller_file"] == f"{script}.py"
                                and store["args"]["dataset"] == dataset
                                and store["args"]["task"] == task
                                and store["args"]["seed"] == seed,
                                all_stores,
                            )
                        )
                        try:
                            store = stores[-1]
                            val = store[split][metric]
                            vals.append(val)
                        except IndexError:
                            # val = float("nan")
                            # vals.append(val)
                            pass
                    val = torch.tensor(vals)
                    mean = val.mean().item()
                    std = val.std().item()
                    record = {
                        "dataset": dataset,
                        "task": task,
                        "script": script,
                        "split": split,
                        "mean": mean,
                        "std": std,
                    }
                    table_data.append(record)

  std = val.std().item()


In [32]:
tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    is_best = True
    filter_data = filter(
        lambda r: r["dataset"] == dataset and r["task"] == task and r["split"] == split,
        table_data,
    )
    for comp_rec in filter_data:
        comp_mean = comp_rec["mean"]
        comp_std = comp_rec["std"]
        if higher_is_better:
            if mean + std < comp_mean - comp_std:
                is_best = False
        else:
            if mean - std > comp_mean + comp_std:
                is_best = False
    opt_bm_open = r"\bm{" if is_best else ""
    opt_bm_close = r"}" if is_best else ""
    if script in [
        "gnn_link",
        "idgnn_link",
        "lightgbm_link_baseline",
    ]:
        tex_val = (
            r"$"
            + opt_bm_open
            + f"{mean * 100:.2f}"
            + opt_bm_close
            + r"_{\pm "
            + f"{std * 100:.2f}"
            + r"}$"
        )
    else:
        tex_val = r"$" + opt_bm_open + f"{mean * 100:.2f}" + opt_bm_close + r"$"

    tex_tab[script][(wrap(dataset), wrap(task), txt[split])] = tex_val
tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,gnn_link,idgnn_link,lightgbm_link_baseline,global_popularity,past_visit
\texttt{rel-amazon},\texttt{user-item-purchase},Val,$\bm{1.48}_{\pm 0.01}$,$0.13_{\pm 0.00}$,$\bm{nan}_{\pm nan}$,$0.31$,$0.08$
\texttt{rel-amazon},\texttt{user-item-purchase},Test,$\bm{0.69}_{\pm 0.04}$,$0.10_{\pm 0.00}$,$\bm{nan}_{\pm nan}$,$0.24$,$0.06$
\texttt{rel-amazon},\texttt{user-item-rate},Val,$\bm{1.48}_{\pm 0.04}$,$0.15_{\pm 0.00}$,$\bm{nan}_{\pm nan}$,$0.16$,$0.09$
\texttt{rel-amazon},\texttt{user-item-rate},Test,$\bm{0.89}_{\pm 0.10}$,$0.12_{\pm 0.00}$,$\bm{nan}_{\pm nan}$,$0.15$,$0.07$
\texttt{rel-amazon},\texttt{user-item-review},Val,$\bm{1.07}_{\pm 0.07}$,$0.11_{\pm 0.00}$,$\bm{nan}_{\pm nan}$,$0.18$,$0.04$
\texttt{rel-amazon},\texttt{user-item-review},Test,$\bm{0.40}_{\pm 0.03}$,$0.09_{\pm 0.00}$,$\bm{nan}_{\pm nan}$,$0.11$,$0.04$
\texttt{rel-avito},\texttt{user-ad-visit},Val,$\bm{nan}_{\pm nan}$,$\bm{5.43}_{\pm 0.02}$,$\bm{nan}_{\pm nan}$,$0.01$,$3.67$
\texttt{rel-avito},\texttt{user-ad-visit},Test,$\bm{nan}_{\pm nan}$,$\bm{3.66}_{\pm 0.02}$,$\bm{nan}_{\pm nan}$,$0.00$,$1.93$
\texttt{rel-hm},\texttt{user-item-purchase},Val,$0.82_{\pm 0.04}$,$\bm{2.61}_{\pm 0.00}$,$\bm{nan}_{\pm nan}$,$0.36$,$1.10$
\texttt{rel-hm},\texttt{user-item-purchase},Test,$0.78_{\pm 0.02}$,$\bm{2.80}_{\pm 0.01}$,$\bm{nan}_{\pm nan}$,$0.30$,$0.87$


In [33]:
tex = tex_df.to_latex()
tex = tex.replace(r"\multirow[t]", r"\multirow[c]")
tex = tex.replace(r"\cline", r"\cmidrule")
tex = tex.replace(r"\cmidrule{1-7} \cmidrule{2-7}", r"\cmidrule{1-7}")
print(tex)

\begin{tabular}{llllllll}
\toprule
 &  &  & gnn_link & idgnn_link & lightgbm_link_baseline & global_popularity & past_visit \\
\midrule
\multirow[c]{6}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-item-purchase}} & Val & $\bm{1.48}_{\pm 0.01}$ & $0.13_{\pm 0.00}$ & $\bm{nan}_{\pm nan}$ & $0.31$ & $0.08$ \\
 &  & Test & $\bm{0.69}_{\pm 0.04}$ & $0.10_{\pm 0.00}$ & $\bm{nan}_{\pm nan}$ & $0.24$ & $0.06$ \\
\cmidrule{2-8}
 & \multirow[c]{2}{*}{\texttt{user-item-rate}} & Val & $\bm{1.48}_{\pm 0.04}$ & $0.15_{\pm 0.00}$ & $\bm{nan}_{\pm nan}$ & $0.16$ & $0.09$ \\
 &  & Test & $\bm{0.89}_{\pm 0.10}$ & $0.12_{\pm 0.00}$ & $\bm{nan}_{\pm nan}$ & $0.15$ & $0.07$ \\
\cmidrule{2-8}
 & \multirow[c]{2}{*}{\texttt{user-item-review}} & Val & $\bm{1.07}_{\pm 0.07}$ & $0.11_{\pm 0.00}$ & $\bm{nan}_{\pm nan}$ & $0.18$ & $0.04$ \\
 &  & Test & $\bm{0.40}_{\pm 0.03}$ & $0.09_{\pm 0.00}$ & $\bm{nan}_{\pm nan}$ & $0.11$ & $0.04$ \\
\cmidrule{1-8} \cmidrule{2-8}
\multirow[c]{2}{*}{\texttt{rel-av