In [1]:
import json
import warnings
from collections import defaultdict
from itertools import product

import pandas as pd
import roach
import torch

from relbench.base import TaskType
from relbench.datasets import get_dataset_names
from relbench.tasks import get_task, get_task_names

In [2]:
# Every experiment run logged under this roach project (one store per run).
ROACH_PROJECT = "relbench/2024-07-05"
all_stores = roach.scan(ROACH_PROJECT)

In [3]:
# Sanity check: how many result stores roach.scan found for the project.
len(all_stores)

742

In [4]:
# Inspect the roach bookkeeping of the last scanned store
# (project, timestamp, calling script, completion flag).
all_stores[-1]["__roach__"]

{'project': 'relbench/2024-07-05',
 'timestamp': 1720608396648790026,
 'caller_file': 'idgnn_link.py',
 'done': True}

In [5]:
def wrap(name):
    """Format ``name`` (a dataset or task identifier) as LaTeX ``\\texttt{...}``.

    LaTeX special characters that may occur in identifiers (``_ & % # $``) are
    backslash-escaped so the generated table compiles; names without them
    (e.g. ``rel-amazon``) are wrapped unchanged.
    """
    for char in "_&%#$":
        name = name.replace(char, "\\" + char)
    return r"\texttt{" + name + r"}"

In [6]:
# Split key -> label printed in the third index column of the LaTeX tables.
txt = dict(val="Val", test="Test")

# Link prediction

In [7]:
# Metric reported in every table below and its direction (larger MAP is better).
metric, higher_is_better = "link_prediction_map", True

In [8]:
# Collect the val/test link-prediction metric of every method on every link task
# into `table_data`, a flat list of {dataset, task, script, split, mean, std}.
# Trained models (GNN, ID-GNN, LightGBM) are aggregated over NUM_SEEDS seeds using
# the latest run of each seed; the non-learned baselines come from the latest
# `baseline_link.py` run and get std 0.
NUM_SEEDS = 5  # seeds 0 .. NUM_SEEDS-1 are aggregated per trained model


def _matching_stores(script, dataset, task, seed=None):
    """Return the roach stores written by ``{script}.py`` for ``dataset``/``task``.

    If ``seed`` is given, only stores whose ``args["seed"]`` equals it are kept
    (``args["seed"]`` is not read otherwise, so seedless scripts work too).
    Stores keep ``all_stores`` order; callers treat the last one as the latest
    run (assumes roach.scan yields stores chronologically — TODO confirm).
    """
    return [
        store
        for store in all_stores
        if store["__roach__"]["caller_file"] == f"{script}.py"
        and store["args"]["dataset"] == dataset
        and store["args"]["task"] == task
        and (seed is None or store["args"]["seed"] == seed)
    ]


table_data = []
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value != TaskType.LINK_PREDICTION.value:
            continue
        for script in ["gnn_link", "idgnn_link", "lightgbm_link", "baseline_link"]:
            if script == "baseline_link":
                # One baseline_link.py run reports every baseline; use the latest
                # run, or NaN when the baselines were never run for this task.
                stores = _matching_stores(script, dataset, task)
                for split in ["val", "test"]:
                    for baseline in ["global_popularity", "past_visit"]:
                        val = (
                            stores[-1][baseline][split][metric]
                            if stores
                            else float("nan")
                        )
                        table_data.append(
                            {
                                "dataset": dataset,
                                "task": task,
                                "script": baseline,
                                "split": split,
                                "mean": val,
                                "std": 0.0,
                            }
                        )
                continue

            # Latest run per seed. The lookup does not depend on the split, so it
            # is done once per script rather than once per split.
            latest_per_seed = {}
            for seed in range(NUM_SEEDS):
                seed_stores = _matching_stores(script, dataset, task, seed)
                if seed_stores:
                    latest_per_seed[seed] = seed_stores[-1]
            if len(latest_per_seed) < NUM_SEEDS:
                # Missing seeds are skipped; make the smaller sample visible.
                warnings.warn(
                    f"{script} on {dataset}/{task}: found runs for "
                    f"{len(latest_per_seed)}/{NUM_SEEDS} seeds"
                )
            for split in ["val", "test"]:
                # float64: torch would otherwise store these Python floats as
                # float32 and round the aggregated (leaderboard) numbers.
                vals = torch.tensor(
                    [store[split][metric] for store in latest_per_seed.values()],
                    dtype=torch.float64,
                )
                table_data.append(
                    {
                        "dataset": dataset,
                        "task": task,
                        "script": script,
                        "split": split,
                        "mean": vals.mean().item(),
                        "std": vals.std().item(),
                    }
                )

## Main paper table

In [26]:
# Main-paper table: one row per (dataset, task, split), one column per method, plus
# a "relative" column holding the gain of the best RDL model (GNN / ID-GNN) over the
# best non-RDL baseline. Bold marks the best method of each row (std is ignored).
#
# The "relative" records are kept in `relative_records`, NOT appended to
# `table_data`. Appending them made this cell non-idempotent. It also leaked a fake
# "relative" method into the appendix table, where it out-scored every model and
# removed all bolding, and into the leaderboard submission.
method_scripts = [
    "global_popularity",
    "past_visit",
    "lightgbm_link",
    "gnn_link",
    "idgnn_link",
]
rdl_scripts = ["gnn_link", "idgnn_link"]
dt_scripts = ["global_popularity", "past_visit", "lightgbm_link"]


def _best(values):
    """Return the best of ``values`` under the metric direction ``higher_is_better``."""
    return max(values) if higher_is_better else min(values)


def _is_best(mean, competitor_means):
    """Return True if no competitor strictly beats ``mean``; NaN is never best."""
    if pd.isna(mean):
        return False
    if higher_is_better:
        return not any(mean < other for other in competitor_means)
    return not any(mean > other for other in competitor_means)


def _relative_gain(rdl_mean, dt_mean):
    """Relative improvement of ``rdl_mean`` over ``dt_mean`` (positive = RDL better).

    Returns NaN instead of raising ZeroDivisionError when the baseline score is 0.
    """
    if dt_mean == 0:
        return float("nan")
    if higher_is_better:
        return (rdl_mean - dt_mean) / dt_mean
    return (dt_mean - rdl_mean) / dt_mean


def _fmt_mean(mean, bold):
    """Format ``mean`` as a LaTeX math-mode percentage with 2 decimals, optionally bold."""
    body = f"{mean * 100:.2f}"
    if bold:
        body = r"\bm{" + body + r"}"
    return r"$" + body + r"$"


# Group the per-method records by table row, keeping table_data's order.
records_by_row = defaultdict(dict)  # (dataset, task, split) -> {script: record}
for rec in table_data:
    records_by_row[rec["dataset"], rec["task"], rec["split"]][rec["script"]] = rec

tex_tab = defaultdict(dict)
relative_records = []
for (dataset, task, split), by_script in records_by_row.items():
    row_key = (wrap(dataset), wrap(task), txt[split])
    # KeyError here means a method has no record for this row (explicit failure
    # instead of silently formatting another method's number).
    means = {script: by_script[script]["mean"] for script in method_scripts}
    for script, mean in means.items():
        tex_tab[script][row_key] = _fmt_mean(mean, _is_best(mean, means.values()))

    gain = _relative_gain(
        _best([means[s] for s in rdl_scripts]),
        _best([means[s] for s in dt_scripts]),
    )
    relative_records.append(
        {
            "dataset": dataset,
            "task": task,
            "script": "relative",
            "split": split,
            "mean": gain,
            "std": float("nan"),
        }
    )
    # The " " format flag pads non-negative gains with a leading space.
    tex_tab["relative"][row_key] = r"$" + f"{gain * 100: .2f}" + r"$ \%"

# "average" rows: mean of each column over all link tasks. The bold entry is
# computed from these averages rather than hard-coded to one method.
for split in ["val", "test"]:
    avg_key = ("average", "", txt[split])
    averages = {}
    for script in method_scripts:
        vals = [
            rec["mean"]
            for rec in table_data
            if rec["script"] == script and rec["split"] == split
        ]
        averages[script] = sum(vals) / len(vals)
    for script, mean in averages.items():
        tex_tab[script][avg_key] = _fmt_mean(mean, _is_best(mean, averages.values()))
    rel_vals = [rec["mean"] for rec in relative_records if rec["split"] == split]
    tex_tab["relative"][avg_key] = (
        _fmt_mean(sum(rel_vals) / len(rel_vals), bold=False) + r" \%"
    )


tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,global_popularity,past_visit,lightgbm_link,gnn_link,idgnn_link,relative
\texttt{rel-amazon},\texttt{user-item-purchase},Val,$0.31$,$0.07$,$0.18$,$\bm{1.53}$,$0.13$,$ 397.55$ \%
\texttt{rel-amazon},\texttt{user-item-purchase},Test,$0.24$,$0.06$,$0.16$,$\bm{0.74}$,$0.10$,$ 204.74$ \%
\texttt{rel-amazon},\texttt{user-item-rate},Val,$0.16$,$0.09$,$0.22$,$\bm{1.42}$,$0.15$,$ 550.12$ \%
\texttt{rel-amazon},\texttt{user-item-rate},Test,$0.15$,$0.07$,$0.17$,$\bm{0.87}$,$0.12$,$ 395.92$ \%
\texttt{rel-amazon},\texttt{user-item-review},Val,$0.18$,$0.05$,$0.14$,$\bm{1.03}$,$0.11$,$ 476.06$ \%
\texttt{rel-amazon},\texttt{user-item-review},Test,$0.11$,$0.04$,$0.09$,$\bm{0.47}$,$0.09$,$ 313.07$ \%
\texttt{rel-avito},\texttt{user-ad-visit},Val,$0.01$,$3.66$,$0.17$,$0.09$,$\bm{5.40}$,$ 47.37$ \%
\texttt{rel-avito},\texttt{user-ad-visit},Test,$0.00$,$1.95$,$0.06$,$0.02$,$\bm{3.66}$,$ 87.09$ \%
\texttt{rel-hm},\texttt{user-item-purchase},Val,$0.36$,$1.07$,$0.44$,$0.92$,$\bm{2.64}$,$ 145.60$ \%
\texttt{rel-hm},\texttt{user-item-purchase},Test,$0.30$,$0.89$,$0.38$,$0.80$,$\bm{2.81}$,$ 214.49$ \%


In [27]:
# Turn tex_df into booktabs-style LaTeX for the paper.
# pandas (>= 2.0) does not escape column names by default, so the underscores in
# method names are escaped here; otherwise the header row does not compile.
# The column count is derived from the frame instead of being hard-coded.
n_cols = tex_df.index.nlevels + tex_df.shape[1]
tex = tex_df.rename(columns=lambda name: name.replace("_", r"\_")).to_latex()
tex = tex.replace(r"\multirow[t]", r"\multirow[c]")
tex = tex.replace(r"\cline", r"\cmidrule")
# Collapse the doubled rule emitted between outer-index (dataset) groups.
tex = tex.replace(
    rf"\cmidrule{{1-{n_cols}}} \cmidrule{{2-{n_cols}}}", rf"\cmidrule{{1-{n_cols}}}"
)
# Merge the two index cells of the "average" rows into one centred label.
tex = tex.replace(
    r"\multirow[c]{2}{*}{average} & \multirow[c]{2}{*}{}",
    r"\multicolumn{2}{c}{\multirow[c]{2}{*}{Average}}",
)
print(tex)

\begin{tabular}{lllllllll}
\toprule
 &  &  & global_popularity & past_visit & lightgbm_link & gnn_link & idgnn_link & relative \\
\midrule
\multirow[c]{6}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-item-purchase}} & Val & $0.31$ & $0.07$ & $0.18$ & $\bm{1.53}$ & $0.13$ & $ 397.55$ \% \\
 &  & Test & $0.24$ & $0.06$ & $0.16$ & $\bm{0.74}$ & $0.10$ & $ 204.74$ \% \\
\cmidrule{2-9}
 & \multirow[c]{2}{*}{\texttt{user-item-rate}} & Val & $0.16$ & $0.09$ & $0.22$ & $\bm{1.42}$ & $0.15$ & $ 550.12$ \% \\
 &  & Test & $0.15$ & $0.07$ & $0.17$ & $\bm{0.87}$ & $0.12$ & $ 395.92$ \% \\
\cmidrule{2-9}
 & \multirow[c]{2}{*}{\texttt{user-item-review}} & Val & $0.18$ & $0.05$ & $0.14$ & $\bm{1.03}$ & $0.11$ & $ 476.06$ \% \\
 &  & Test & $0.11$ & $0.04$ & $0.09$ & $\bm{0.47}$ & $0.09$ & $ 313.07$ \% \\
\cmidrule{1-9}
\multirow[c]{2}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{\texttt{user-ad-visit}} & Val & $0.01$ & $3.66$ & $0.17$ & $0.09$ & $\bm{5.40}$ & $ 47.37$ \% \\
 &  & Test & $

## Appendix table

In [22]:
# Appendix table: like the main table, but trained models (LightGBM, GNN, ID-GNN)
# also show the std over seeds. A method counts as best unless some competitor
# beats it by more than their combined std.
# Only the five method columns are compared. Stray records in `table_data` (e.g. a
# derived "relative" column) therefore cannot take the "best" slot and un-bold
# every cell.
appendix_scripts = [
    "global_popularity",
    "past_visit",
    "lightgbm_link",
    "gnn_link",
    "idgnn_link",
]
seeded_scripts = {"lightgbm_link", "gnn_link", "idgnn_link"}

# Group records by table row, keeping table_data's order.
records_by_row = defaultdict(dict)  # (dataset, task, split) -> {script: record}
for rec in table_data:
    records_by_row[rec["dataset"], rec["task"], rec["split"]][rec["script"]] = rec

tex_tab = defaultdict(dict)
for (dataset, task, split), by_script in records_by_row.items():
    row_key = (wrap(dataset), wrap(task), txt[split])
    # KeyError here means a method has no record for this row (explicit failure
    # instead of silently formatting another method's number).
    stats = [(by_script[s]["mean"], by_script[s]["std"]) for s in appendix_scripts]
    for script, (mean, std) in zip(appendix_scripts, stats):
        if higher_is_better:
            beaten = any(mean + std < m - s for m, s in stats)
        else:
            beaten = any(mean - std > m + s for m, s in stats)
        is_best = not pd.isna(mean) and not beaten  # never bold a missing result

        body = f"{mean * 100:.2f}"
        if is_best:
            body = r"\bm{" + body + r"}"
        if script in seeded_scripts:
            tex_val = r"$" + body + r"_{\pm " + f"{std * 100:.2f}" + r"}$"
        else:
            tex_val = r"$" + body + r"$"
        tex_tab[script][row_key] = tex_val

tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,global_popularity,past_visit,lightgbm_link,gnn_link,idgnn_link
\texttt{rel-amazon},\texttt{user-item-purchase},Val,$0.31$,$0.07$,$0.18_{\pm 0.07}$,$\bm{1.53}_{\pm 0.05}$,$0.13_{\pm 0.00}$
\texttt{rel-amazon},\texttt{user-item-purchase},Test,$0.24$,$0.06$,$0.16_{\pm 0.05}$,$\bm{0.74}_{\pm 0.08}$,$0.10_{\pm 0.00}$
\texttt{rel-amazon},\texttt{user-item-rate},Val,$0.16$,$0.09$,$0.22_{\pm 0.02}$,$\bm{1.42}_{\pm 0.06}$,$0.15_{\pm 0.00}$
\texttt{rel-amazon},\texttt{user-item-rate},Test,$0.15$,$0.07$,$0.17_{\pm 0.01}$,$\bm{0.87}_{\pm 0.05}$,$0.12_{\pm 0.00}$
\texttt{rel-amazon},\texttt{user-item-review},Val,$0.18$,$0.05$,$0.14_{\pm 0.03}$,$\bm{1.03}_{\pm 0.03}$,$0.11_{\pm 0.00}$
\texttt{rel-amazon},\texttt{user-item-review},Test,$0.11$,$0.04$,$0.09_{\pm 0.01}$,$\bm{0.47}_{\pm 0.05}$,$0.09_{\pm 0.00}$
\texttt{rel-avito},\texttt{user-ad-visit},Val,$0.01$,$3.66$,$0.17_{\pm 0.01}$,$0.09_{\pm 0.01}$,$\bm{5.40}_{\pm 0.02}$
\texttt{rel-avito},\texttt{user-ad-visit},Test,$0.00$,$1.95$,$0.06_{\pm 0.01}$,$0.02_{\pm 0.00}$,$\bm{3.66}_{\pm 0.02}$
\texttt{rel-hm},\texttt{user-item-purchase},Val,$0.36$,$1.07$,$0.44_{\pm 0.03}$,$0.92_{\pm 0.04}$,$\bm{2.64}_{\pm 0.00}$
\texttt{rel-hm},\texttt{user-item-purchase},Test,$0.30$,$0.89$,$0.38_{\pm 0.02}$,$0.80_{\pm 0.03}$,$\bm{2.81}_{\pm 0.01}$


In [23]:
# Turn the appendix tex_df into booktabs-style LaTeX.
# pandas (>= 2.0) does not escape column names by default, so the underscores in
# method names are escaped here; otherwise the header row does not compile.
# The column count is derived from the frame instead of being hard-coded.
n_cols = tex_df.index.nlevels + tex_df.shape[1]
tex = tex_df.rename(columns=lambda name: name.replace("_", r"\_")).to_latex()
tex = tex.replace(r"\multirow[t]", r"\multirow[c]")
tex = tex.replace(r"\cline", r"\cmidrule")
# Collapse the doubled rule emitted between outer-index (dataset) groups.
tex = tex.replace(
    rf"\cmidrule{{1-{n_cols}}} \cmidrule{{2-{n_cols}}}", rf"\cmidrule{{1-{n_cols}}}"
)
print(tex)

\begin{tabular}{llllllll}
\toprule
 &  &  & global_popularity & past_visit & lightgbm_link & gnn_link & idgnn_link \\
\midrule
\multirow[c]{6}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-item-purchase}} & Val & $0.31$ & $0.07$ & $0.18_{\pm 0.07}$ & $\bm{1.53}_{\pm 0.05}$ & $0.13_{\pm 0.00}$ \\
 &  & Test & $0.24$ & $0.06$ & $0.16_{\pm 0.05}$ & $\bm{0.74}_{\pm 0.08}$ & $0.10_{\pm 0.00}$ \\
\cmidrule{2-8}
 & \multirow[c]{2}{*}{\texttt{user-item-rate}} & Val & $0.16$ & $0.09$ & $0.22_{\pm 0.02}$ & $\bm{1.42}_{\pm 0.06}$ & $0.15_{\pm 0.00}$ \\
 &  & Test & $0.15$ & $0.07$ & $0.17_{\pm 0.01}$ & $\bm{0.87}_{\pm 0.05}$ & $0.12_{\pm 0.00}$ \\
\cmidrule{2-8}
 & \multirow[c]{2}{*}{\texttt{user-item-review}} & Val & $0.18$ & $0.05$ & $0.14_{\pm 0.03}$ & $\bm{1.03}_{\pm 0.03}$ & $0.11_{\pm 0.00}$ \\
 &  & Test & $0.11$ & $0.04$ & $0.09_{\pm 0.01}$ & $\bm{0.47}_{\pm 0.05}$ & $0.09_{\pm 0.00}$ \\
\cmidrule{1-8}
\multirow[c]{2}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{\texttt{user-ad

## Leaderboard submission

In [17]:
# Leaderboard submission: {script: {split: {"dataset/task": [mean, std]}}}.
# Only real methods are submitted. A derived "relative" column, if one was added
# to table_data, is skipped.
# Missing values (NaN) become null, because json.dumps would otherwise emit the
# non-standard `NaN` token, which strict JSON parsers reject.


def _nan_to_none(x):
    """Return None for NaN so it serializes as JSON null; other values unchanged."""
    return None if pd.isna(x) else x


lb_sub = defaultdict(lambda: defaultdict(dict))
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]
    if script == "relative":
        continue

    lb_sub[script][split][f"{dataset}/{task}"] = [_nan_to_none(mean), _nan_to_none(std)]
print(json.dumps(lb_sub, indent=2))

{
  "gnn_link": {
    "val": {
      "rel-amazon/user-item-purchase": [
        0.015283927917893046,
        0.00047989226257355865
      ],
      "rel-amazon/user-item-rate": [
        0.014150046798912927,
        0.0005896798920469688
      ],
      "rel-amazon/user-item-review": [
        0.010303252825183407,
        0.00030910400631337895
      ],
      "rel-avito/user-ad-visit": [
        0.0008804251753438116,
        0.00014153948955364384
      ],
      "rel-hm/user-item-purchase": [
        0.009152570570154709,
        0.000412390874606843
      ],
      "rel-stack/user-post-comment": [
        0.004282093811104037,
        0.0007942152203460955
      ],
      "rel-stack/post-post-related": [
        3.4622602100648414e-05,
        5.062257226463135e-05
      ],
      "rel-trial/condition-sponsor-run": [
        0.03124073766851678,
        0.0023553218658563736
      ],
      "rel-trial/site-sponsor-run": [
        0.14086476440215906,
        0.0077401140659883695
      