In [1]:
import json
from collections import defaultdict
from itertools import product

import pandas as pd
import roach
import torch

from relbench.base import TaskType
from relbench.datasets import get_dataset_names
from relbench.tasks import get_task, get_task_names

In [2]:
# Collect every result store logged under this roach project.
ROACH_PROJECT = "relbench/2024-07-05"
all_stores = roach.scan(ROACH_PROJECT)

In [3]:
# Sanity check: how many result stores the scan found.
len(all_stores)

600

In [4]:
# Inspect the roach metadata (project, timestamp, caller script, done flag)
# of the last store returned by the scan.
all_stores[-1]["__roach__"]

{'project': 'relbench/2024-07-05',
 'timestamp': 1720395884026350988,
 'caller_file': 'idgnn_link.py',
 'done': True}

In [15]:
def wrap(name, texttt=False):
    r"""Format a dataset/task name for use as a LaTeX table label.

    Args:
        name: Identifier to display (e.g. a dataset or task name).
        texttt: If True, render the name in monospace as ``\texttt{...}``,
            escaping underscores, which are not allowed bare in LaTeX text
            mode. Defaults to False, in which case the name is returned
            unchanged.

    Returns:
        The formatted label string.
    """
    if texttt:
        return r"\texttt{" + name.replace("_", r"\_") + r"}"
    return name

In [6]:
# Display label for each evaluation split, used in the table's row index.
txt = dict(val="Val", test="Test")

# Classification tasks

In [7]:
# Metric to report, and its direction (drives which cells get bolded).
metric, higher_is_better = "roc_auc", True

In [8]:
# Aggregate every classification task's metric over seeds into one record per
# (dataset, task, script, split) holding the mean and std across seeds.
# A seed without a usable result contributes NaN, which makes that record's
# mean/std NaN as well.
scripts = ["gnn_node", "lightgbm_node", "hybrid_node"]
splits = ["val", "test"]
seeds = range(5)
classification_types = [
    TaskType.BINARY_CLASSIFICATION.value,
    TaskType.MULTICLASS_CLASSIFICATION.value,
]


def _store_key(store):
    """Return (caller_file, dataset, task, seed) for a store, or None if any is missing."""
    try:
        args = store["args"]
        return (
            store["__roach__"]["caller_file"],
            args["dataset"],
            args["task"],
            args["seed"],
        )
    except KeyError:
        return None


# Index the stores once instead of rescanning `all_stores` for every
# (dataset, task, script, split, seed) combination. Iterating in scan order
# and overwriting means the LAST matching store wins.
latest_store = {}
for store in all_stores:
    key = _store_key(store)
    if key is not None:
        latest_store[key] = store

table_data = []
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value not in classification_types:
            continue
        for script in scripts:
            caller_file = f"{script}.py"
            for split in splits:
                vals = []
                for seed in seeds:
                    store = latest_store.get((caller_file, dataset, task, seed))
                    if store is None:
                        # No run logged for this seed.
                        vals.append(float("nan"))
                        continue
                    try:
                        vals.append(store[split][metric])
                    except KeyError:
                        # The run exists but never logged this split/metric
                        # (presumably unfinished or crashed); treat as missing
                        # rather than aborting the whole cell.
                        vals.append(float("nan"))
                values = torch.tensor(vals)
                table_data.append(
                    {
                        "dataset": dataset,
                        "task": task,
                        "script": script,
                        "split": split,
                        "mean": values.mean().item(),
                        "std": values.std().item(),
                    }
                )

In [16]:
# Build the LaTeX results table: one column per script, one row per
# (dataset, task, split). Cells read "$mean_{\pm std}$" in percent. A cell is
# bolded unless some competing script's [mean - std, mean + std] interval lies
# strictly beyond it in the "better" direction (a one-std overlap test).
excluded_scripts = {"hybrid_node"}  # aggregated in table_data but not shown here

# Group the competing records once per (dataset, task, split) instead of
# rescanning `table_data` for every record.
competitors = defaultdict(list)
for rec in table_data:
    if rec["script"] not in excluded_scripts:
        competitors[(rec["dataset"], rec["task"], rec["split"])].append(rec)

tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    if script in excluded_scripts:
        continue

    if pd.isna(mean):
        # Missing result (e.g. a seed without a logged value). NaN comparisons
        # are always False, so the dominance test would wrongly bold it; show
        # a dash instead of "$\bm{nan}_{\pm nan}$".
        tex_val = "--"
    else:
        group = competitors[(dataset, task, split)]
        if higher_is_better:
            is_best = not any(mean + std < c["mean"] - c["std"] for c in group)
        else:
            is_best = not any(mean - std > c["mean"] + c["std"] for c in group)
        mean_tex = f"{mean * 100:.2f}"
        if is_best:
            mean_tex = r"\bm{" + mean_tex + "}"
        tex_val = "$" + mean_tex + r"_{\pm " + f"{std * 100:.2f}" + "}$"

    tex_tab[script][(wrap(dataset), wrap(task), txt[split])] = tex_val
tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,gnn_node,lightgbm_node
rel-amazon,user-churn,Val,$\bm{70.06}_{\pm 0.04}$,$52.00_{\pm 0.02}$
rel-amazon,user-churn,Test,$\bm{70.20}_{\pm 0.06}$,$52.34_{\pm 0.07}$
rel-amazon,item-churn,Val,$\bm{81.04}_{\pm 0.05}$,$61.27_{\pm 0.20}$
rel-amazon,item-churn,Test,$\bm{81.46}_{\pm 0.07}$,$61.96_{\pm 0.26}$
rel-avito,user-visits,Val,$\bm{69.65}_{\pm 0.04}$,$53.31_{\pm 0.09}$
rel-avito,user-visits,Test,$\bm{66.20}_{\pm 0.10}$,$53.05_{\pm 0.32}$
rel-avito,user-clicks,Val,$\bm{64.73}_{\pm 0.32}$,$55.63_{\pm 0.31}$
rel-avito,user-clicks,Test,$\bm{65.90}_{\pm 1.95}$,$53.60_{\pm 0.59}$
rel-event,user-repeat,Val,$\bm{71.73}_{\pm 2.21}$,$67.76_{\pm 1.10}$
rel-event,user-repeat,Test,$\bm{78.31}_{\pm 1.26}$,$69.74_{\pm 2.17}$


In [13]:
# Convert the results table to LaTeX and adapt pandas' output to booktabs style.
n_cols = tex_df.index.nlevels + tex_df.shape[1]
# Column labels are script ids such as "gnn_node". to_latex is not escaping
# them here (escaping would also mangle the cells' intentional $...$/\bm
# markup), and a bare "_" in text mode is a LaTeX error, so escape it
# in the labels only.
tex = tex_df.rename(columns=lambda c: str(c).replace("_", r"\_")).to_latex()
tex = tex.replace(r"\multirow[t]", r"\multirow[c]")
tex = tex.replace(r"\cline", r"\cmidrule")
# At dataset boundaries a full-width rule and a level-1 rule appear together;
# keep only the full-width one. The column count is derived, not hard-coded.
tex = tex.replace(
    rf"\cmidrule{{1-{n_cols}}} \cmidrule{{2-{n_cols}}}",
    rf"\cmidrule{{1-{n_cols}}}",
)
print(tex)

\begin{tabular}{lllll}
\toprule
 &  &  & gnn_node & lightgbm_node \\
\midrule
\multirow[c]{4}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-churn}} & Val & $\bm{70.06}_{\pm 0.04}$ & $52.00_{\pm 0.02}$ \\
 &  & Test & $\bm{70.20}_{\pm 0.06}$ & $52.34_{\pm 0.07}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\texttt{item-churn}} & Val & $\bm{81.04}_{\pm 0.05}$ & $61.27_{\pm 0.20}$ \\
 &  & Test & $\bm{81.46}_{\pm 0.07}$ & $61.96_{\pm 0.26}$ \\
\cmidrule{1-5}
\multirow[c]{4}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{\texttt{user-visits}} & Val & $\bm{69.65}_{\pm 0.04}$ & $53.31_{\pm 0.09}$ \\
 &  & Test & $\bm{66.20}_{\pm 0.10}$ & $53.05_{\pm 0.32}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\texttt{user-clicks}} & Val & $\bm{64.73}_{\pm 0.32}$ & $55.63_{\pm 0.31}$ \\
 &  & Test & $\bm{65.90}_{\pm 1.95}$ & $53.60_{\pm 0.59}$ \\
\cmidrule{1-5}
\multirow[c]{4}{*}{\texttt{rel-event}} & \multirow[c]{2}{*}{\texttt{user-repeat}} & Val & $\bm{71.73}_{\pm 2.21}$ & $67.76_{\pm 1.10}$ \\
 &  &

In [14]:
# Collect test-split results per script as {"dataset/task": [mean, std]} and
# dump them as JSON (lb_sub: presumably a leaderboard-submission format).
lb_sub = defaultdict(dict)
for rec in table_data:
    if rec["split"] != "test":
        continue
    # json.dumps writes NaN as the bare token `NaN`, which is not valid JSON;
    # emit null for missing results instead.
    lb_sub[rec["script"]][f'{rec["dataset"]}/{rec["task"]}'] = [
        None if pd.isna(v) else v for v in (rec["mean"], rec["std"])
    ]
print(json.dumps(lb_sub, indent=2))

{
  "gnn_node": {
    "rel-amazon/user-churn": [
      0.7020211990816014,
      0.0006491522052338097
    ],
    "rel-amazon/item-churn": [
      0.8145945669344069,
      0.000681160632374158
    ],
    "rel-avito/user-visits": [
      0.6619693569204161,
      0.0009747970192118968
    ],
    "rel-avito/user-clicks": [
      0.659044131104187,
      0.019496906207840146
    ],
    "rel-event/user-repeat": [
      0.7830882352941175,
      0.012583633824532432
    ],
    "rel-event/user-ignore": [
      0.8037337652027563,
      0.014979613817607264
    ],
    "rel-f1/driver-dnf": [
      0.7262187088274044,
      0.0027175675790203574
    ],
    "rel-f1/driver-top3": [
      0.7554073474080267,
      0.006297637190809179
    ],
    "rel-hm/user-churn": [
      0.6988453175294079,
      0.0020724291135691144
    ],
    "rel-stack/user-engagement": [
      0.9059087278671942,
      0.0009162324704242817
    ],
    "rel-stack/user-badge": [
      0.8885698594534229,
      0.00081617326