In [1]:
import json
from collections import defaultdict
from itertools import product

import pandas as pd
import roach
import torch

from relbench.base import TaskType
from relbench.datasets import get_dataset_names
from relbench.tasks import get_task, get_task_names

In [2]:
# Scan the "relbench/2024-07-05" roach project. The cells below treat the
# result as a sequence of per-run stores (len(), [-1], iteration).
# NOTE(review): presumably ordered oldest-to-newest; confirm against roach.
all_stores = roach.scan("relbench/2024-07-05")

In [3]:
# Number of stores returned by the scan.
len(all_stores)

742

In [4]:
# Peek at the "__roach__" bookkeeping entry of the last store
# (project, timestamp, caller_file and done flag).
all_stores[-1]["__roach__"]

{'project': 'relbench/2024-07-05',
 'timestamp': 1720608396648790026,
 'caller_file': 'idgnn_link.py',
 'done': True}

In [5]:
def wrap(name):
    r"""Return ``name`` wrapped in a LaTeX ``\texttt{...}`` command.

    ``name`` must be a string; it is inserted verbatim (no escaping).
    """
    return "".join((r"\texttt{", name, r"}"))

In [6]:
# Display label used in the LaTeX tables for each split key.
txt = dict(val="Val", test="Test")

# classification

In [7]:
# Key of the metric looked up in each run's per-split results, and the
# direction in which that metric improves (used when bolding the best entry).
metric: str = "roc_auc"
higher_is_better: bool = True

In [8]:
# Aggregate `metric` over seeds 0-4 for every classification task, node-level
# script and split. Each entry of `table_data` is a dict with the keys
# "dataset", "task", "script", "split", "mean" and "std".
classification_task_types = [
    TaskType.BINARY_CLASSIFICATION.value,
    TaskType.MULTICLASS_CLASSIFICATION.value,
]

table_data = []
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value not in classification_task_types:
            continue
        for script in ["gnn_node", "lightgbm_node", "hybrid_node"]:
            # Which run belongs to a seed does not depend on the split, so
            # `all_stores` is scanned once per seed rather than once per
            # (split, seed) pair.
            seed_runs = []
            for seed in range(5):
                matching = [
                    store
                    for store in all_stores
                    if store["__roach__"]["caller_file"] == f"{script}.py"
                    and store["args"]["dataset"] == dataset
                    and store["args"]["task"] == task
                    and store["args"]["seed"] == seed
                ]
                # When several stores match, the last one in scan order is
                # used; None marks a seed without any recorded run.
                seed_runs.append(matching[-1] if matching else None)

            for split in ["val", "test"]:
                # A missing run contributes NaN on purpose: the resulting
                # mean/std are then NaN, which flags the incomplete
                # configuration instead of averaging over fewer seeds.
                # Only "no run found" is tolerated; a run that lacks the
                # split or metric raises instead of being masked.
                vals = [
                    float("nan") if run is None else run[split][metric]
                    for run in seed_runs
                ]
                scores = torch.tensor(vals)
                table_data.append(
                    {
                        "dataset": dataset,
                        "task": task,
                        "script": script,
                        "split": split,
                        "mean": scores.mean().item(),
                        "std": scores.std().item(),
                    }
                )

## main paper table

In [38]:
# Main-paper table: one row per (dataset, binary-classification task, split)
# with the mean score of each baseline, the relative gain of gnn_node over
# lightgbm_node, and a final "average" row per split.
baseline_scripts = ["lightgbm_node", "gnn_node"]
column_scripts = baseline_scripts + ["relative"]
table_splits = ["val", "test"]


def _not_beaten(value, rivals):
    """Return True unless some rival is strictly better than ``value``.

    Ties (and NaN comparisons, which are always False) keep the entry bold.
    Standard deviations are deliberately ignored in this table.
    """
    if higher_is_better:
        return not any(value < other for other in rivals)
    return not any(value > other for other in rivals)


def _tex_percent(value, bold=False):
    """Format the fraction ``value`` as a percentage in LaTeX math mode."""
    opt_bm_open = r"\bm{" if bold else ""
    opt_bm_close = r"}" if bold else ""
    return r"$" + opt_bm_open + f"{value * 100:.2f}" + opt_bm_close + r"$"


# This cell stores its "relative" records in `table_data`. Drop the ones left
# by a previous execution so that re-running the cell does not pile up
# duplicates.
table_data[:] = [rec for rec in table_data if rec["script"] != "relative"]

tex_tab = defaultdict(dict)
# (script, split) -> values of the rows actually shown, used for the average
# row so that it covers exactly the tasks listed in the table.
column_values = defaultdict(list)

for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value != TaskType.BINARY_CLASSIFICATION.value:
            continue
        for split in table_splits:
            # First record per script for this (dataset, task, split).
            by_script = {}
            for rec in table_data:
                if (rec["dataset"], rec["task"], rec["split"]) == (
                    dataset,
                    task,
                    split,
                ):
                    by_script.setdefault(rec["script"], rec)

            # Indexing raises KeyError naming the script if a record is
            # missing, instead of silently reusing an unrelated record.
            baseline_means = {
                name: by_script[name]["mean"] for name in baseline_scripts
            }
            row_key = (wrap(dataset), wrap(task), txt[split])

            for script in column_scripts:
                if script == "relative":
                    if not higher_is_better:
                        raise NotImplementedError
                    rdl_mean = baseline_means["gnn_node"]
                    dt_mean = baseline_means["lightgbm_node"]
                    value = (rdl_mean - dt_mean) / dt_mean

                    table_data.append(
                        {
                            "dataset": dataset,
                            "task": task,
                            "script": "relative",
                            "split": split,
                            "mean": value,
                            "std": float("nan"),
                        }
                    )
                    # The " " format flag reserves a sign column so positive
                    # and negative gains line up.
                    tex_val = r"$" + f"{value * 100: .2f}" + r"$ \%"
                else:
                    value = baseline_means[script]
                    tex_val = _tex_percent(
                        value,
                        bold=_not_beaten(value, baseline_means.values()),
                    )

                column_values[script, split].append(value)
                tex_tab[script][row_key] = tex_val

# Average row: mean over the rows shown above, per column and split.
averages = {
    key: sum(values) / len(values)
    for key, values in column_values.items()
    if values
}
for script in column_scripts:
    for split in table_splits:
        if (script, split) not in averages:
            # No binary-classification task was found; nothing to average.
            continue
        mean = averages[script, split]
        if script == "relative":
            tex_val = _tex_percent(mean) + r" \%"
        else:
            # Bold the best baseline average instead of assuming a winner.
            rival_means = [averages[name, split] for name in baseline_scripts]
            tex_val = _tex_percent(mean, bold=_not_beaten(mean, rival_means))

        tex_tab[script]["average", "", txt[split]] = tex_val


tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,lightgbm_node,gnn_node,relative
\texttt{rel-amazon},\texttt{user-churn},Val,$52.05$,$\bm{70.45}$,$ 35.35$ \%
\texttt{rel-amazon},\texttt{user-churn},Test,$52.22$,$\bm{70.42}$,$ 34.86$ \%
\texttt{rel-amazon},\texttt{item-churn},Val,$62.39$,$\bm{82.39}$,$ 32.06$ \%
\texttt{rel-amazon},\texttt{item-churn},Test,$62.54$,$\bm{82.81}$,$ 32.40$ \%
\texttt{rel-avito},\texttt{user-visits},Val,$53.31$,$\bm{69.65}$,$ 30.66$ \%
\texttt{rel-avito},\texttt{user-visits},Test,$53.05$,$\bm{66.20}$,$ 24.78$ \%
\texttt{rel-avito},\texttt{user-clicks},Val,$55.63$,$\bm{64.73}$,$ 16.35$ \%
\texttt{rel-avito},\texttt{user-clicks},Test,$53.60$,$\bm{65.90}$,$ 22.96$ \%
\texttt{rel-event},\texttt{user-repeat},Val,$67.76$,$\bm{71.73}$,$ 5.86$ \%
\texttt{rel-event},\texttt{user-repeat},Test,$69.74$,$\bm{78.31}$,$ 12.28$ \%


In [41]:
# Render the main table as LaTeX, then apply textual fix-ups in order.
# Order matters: the third and fourth patterns match text produced by the
# first two replacements.
main_table_fixups = [
    (r"\multirow[t]", r"\multirow[c]"),
    (r"\cline", r"\cmidrule"),
    # Collapse the doubled rule emitted where a dataset group ends.
    (r"\cmidrule{1-6} \cmidrule{2-6}", r"\cmidrule{1-6}"),
    # Let the "average" label span the dataset and task columns.
    (
        r"\multirow[c]{2}{*}{average} & \multirow[c]{2}{*}{}",
        r"\multicolumn{2}{c}{\multirow[c]{2}{*}{Average}}",
    ),
]

tex = tex_df.to_latex()
for old, new in main_table_fixups:
    tex = tex.replace(old, new)
print(tex)

\begin{tabular}{llllll}
\toprule
 &  &  & lightgbm_node & gnn_node & relative \\
\midrule
\multirow[c]{4}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-churn}} & Val & $52.05$ & $\bm{70.45}$ & $ 35.35$ \% \\
 &  & Test & $52.22$ & $\bm{70.42}$ & $ 34.86$ \% \\
\cmidrule{2-6}
 & \multirow[c]{2}{*}{\texttt{item-churn}} & Val & $62.39$ & $\bm{82.39}$ & $ 32.06$ \% \\
 &  & Test & $62.54$ & $\bm{82.81}$ & $ 32.40$ \% \\
\cmidrule{1-6}
\multirow[c]{4}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{\texttt{user-visits}} & Val & $53.31$ & $\bm{69.65}$ & $ 30.66$ \% \\
 &  & Test & $53.05$ & $\bm{66.20}$ & $ 24.78$ \% \\
\cmidrule{2-6}
 & \multirow[c]{2}{*}{\texttt{user-clicks}} & Val & $55.63$ & $\bm{64.73}$ & $ 16.35$ \% \\
 &  & Test & $53.60$ & $\bm{65.90}$ & $ 22.96$ \% \\
\cmidrule{1-6}
\multirow[c]{4}{*}{\texttt{rel-event}} & \multirow[c]{2}{*}{\texttt{user-repeat}} & Val & $67.76$ & $\bm{71.73}$ & $ 5.86$ \% \\
 &  & Test & $69.74$ & $\bm{78.31}$ & $ 12.28$ \% \\
\cmidrule{2-6

## appendix table

In [32]:
# Appendix table: mean with a "± std" subscript for each baseline, one row per
# (dataset, binary-classification task, split). An entry is bold unless some
# other shown baseline beats it by more than the combined standard deviations.
appendix_scripts = ["lightgbm_node", "gnn_node"]

tex_tab = defaultdict(dict)
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value != TaskType.BINARY_CLASSIFICATION.value:
            continue
        for split in ["val", "test"]:
            # First record per script for this (dataset, task, split).
            by_script = {}
            for rec in table_data:
                if (rec["dataset"], rec["task"], rec["split"]) == (
                    dataset,
                    task,
                    split,
                ):
                    by_script.setdefault(rec["script"], rec)

            # Compare only the baselines shown in this table, so extra
            # records in `table_data` (e.g. "hybrid_node" or derived
            # "relative" entries) cannot influence the bolding. Indexing
            # raises KeyError naming the script if its record is missing.
            shown = [by_script[name] for name in appendix_scripts]

            for rec in shown:
                script = rec["script"]
                mean = rec["mean"]
                std = rec["std"]

                if higher_is_better:
                    is_best = not any(
                        mean + std < other["mean"] - other["std"]
                        for other in shown
                    )
                else:
                    is_best = not any(
                        mean - std > other["mean"] + other["std"]
                        for other in shown
                    )

                opt_bm_open = r"\bm{" if is_best else ""
                opt_bm_close = r"}" if is_best else ""
                tex_val = (
                    r"$"
                    + opt_bm_open
                    + f"{mean * 100:.2f}"
                    + opt_bm_close
                    + r"_{\pm "
                    + f"{std * 100:.2f}"
                    + r"}$"
                )

                tex_tab[script][(wrap(dataset), wrap(task), txt[split])] = tex_val

tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,lightgbm_node,gnn_node
\texttt{rel-amazon},\texttt{user-churn},Val,$52.05_{\pm 0.06}$,$\bm{70.45}_{\pm 0.06}$
\texttt{rel-amazon},\texttt{user-churn},Test,$52.22_{\pm 0.06}$,$\bm{70.42}_{\pm 0.05}$
\texttt{rel-amazon},\texttt{item-churn},Val,$62.39_{\pm 0.20}$,$\bm{82.39}_{\pm 0.02}$
\texttt{rel-amazon},\texttt{item-churn},Test,$62.54_{\pm 0.18}$,$\bm{82.81}_{\pm 0.03}$
\texttt{rel-avito},\texttt{user-visits},Val,$53.31_{\pm 0.09}$,$\bm{69.65}_{\pm 0.04}$
\texttt{rel-avito},\texttt{user-visits},Test,$53.05_{\pm 0.32}$,$\bm{66.20}_{\pm 0.10}$
\texttt{rel-avito},\texttt{user-clicks},Val,$55.63_{\pm 0.31}$,$\bm{64.73}_{\pm 0.32}$
\texttt{rel-avito},\texttt{user-clicks},Test,$53.60_{\pm 0.59}$,$\bm{65.90}_{\pm 1.95}$
\texttt{rel-event},\texttt{user-repeat},Val,$67.76_{\pm 1.10}$,$\bm{71.73}_{\pm 2.21}$
\texttt{rel-event},\texttt{user-repeat},Test,$69.74_{\pm 2.17}$,$\bm{78.31}_{\pm 1.26}$


In [34]:
# Render the appendix table as LaTeX, then apply textual fix-ups in order.
# The last pattern matches text produced by the \cline replacement, so the
# order must be kept.
appendix_fixups = [
    (r"\multirow[t]", r"\multirow[c]"),
    (r"\cline", r"\cmidrule"),
    # Collapse the doubled rule emitted where a dataset group ends.
    (r"\cmidrule{1-5} \cmidrule{2-5}", r"\cmidrule{1-5}"),
]

tex = tex_df.to_latex()
for old, new in appendix_fixups:
    tex = tex.replace(old, new)
print(tex)

\begin{tabular}{lllll}
\toprule
 &  &  & lightgbm_node & gnn_node \\
\midrule
\multirow[c]{4}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-churn}} & Val & $52.05_{\pm 0.06}$ & $\bm{70.45}_{\pm 0.06}$ \\
 &  & Test & $52.22_{\pm 0.06}$ & $\bm{70.42}_{\pm 0.05}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\texttt{item-churn}} & Val & $62.39_{\pm 0.20}$ & $\bm{82.39}_{\pm 0.02}$ \\
 &  & Test & $62.54_{\pm 0.18}$ & $\bm{82.81}_{\pm 0.03}$ \\
\cmidrule{1-5}
\multirow[c]{4}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{\texttt{user-visits}} & Val & $53.31_{\pm 0.09}$ & $\bm{69.65}_{\pm 0.04}$ \\
 &  & Test & $53.05_{\pm 0.32}$ & $\bm{66.20}_{\pm 0.10}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\texttt{user-clicks}} & Val & $55.63_{\pm 0.31}$ & $\bm{64.73}_{\pm 0.32}$ \\
 &  & Test & $53.60_{\pm 0.59}$ & $\bm{65.90}_{\pm 1.95}$ \\
\cmidrule{1-5}
\multirow[c]{4}{*}{\texttt{rel-event}} & \multirow[c]{2}{*}{\texttt{user-repeat}} & Val & $67.76_{\pm 1.10}$ & $\bm{71.73}_{\pm 2.21}$ \\
 &  &

## leaderboard submission

In [42]:
# Leaderboard payload: script -> split -> "dataset/task" -> [mean, std].
# NOTE(review): every record in `table_data` is exported, including the
# derived "relative" records if the main-table cell has been run; their std is
# NaN, which json.dumps writes as a bare NaN token (not strict JSON). Confirm
# the leaderboard accepts this.
lb_sub = defaultdict(lambda: defaultdict(dict))
for rec in table_data:
    entry_key = f'{rec["dataset"]}/{rec["task"]}'
    lb_sub[rec["script"]][rec["split"]][entry_key] = [rec["mean"], rec["std"]]
print(json.dumps(lb_sub, indent=2))

{
  "gnn_node": {
    "val": {
      "rel-amazon/user-churn": [
        0.7044906356287329,
        0.0006415979808004229
      ],
      "rel-amazon/item-churn": [
        0.8238862620478642,
        0.00024490444792151027
      ],
      "rel-avito/user-visits": [
        0.6964859033965352,
        0.000393475245965563
      ],
      "rel-avito/user-clicks": [
        0.6472513957748135,
        0.0032469814432288606
      ],
      "rel-event/user-repeat": [
        0.7173021181716834,
        0.02208112913214512
      ],
      "rel-event/user-ignore": [
        0.9073610142062524,
        0.006697674853008263
      ],
      "rel-f1/driver-dnf": [
        0.7136290249433108,
        0.01538835001954269
      ],
      "rel-f1/driver-top3": [
        0.7763738331153357,
        0.031579689099974516
      ],
      "rel-hm/user-churn": [
        0.7041904537946169,
        0.0009418573540523883
      ],
      "rel-stack/user-engagement": [
        0.9021041908413816,
        0.00068874534