In [1]:
import json
from collections import defaultdict
from itertools import product

import pandas as pd
import roach
import torch

from relbench.base import TaskType
from relbench.datasets import get_dataset_names
from relbench.tasks import get_task, get_task_names

In [2]:
all_stores = roach.scan("relbench/2024-07-05")

In [3]:
len(all_stores)

600

In [4]:
all_stores[-1]["__roach__"]

{'project': 'relbench/2024-07-05',
 'timestamp': 1720395884026350988,
 'caller_file': 'idgnn_link.py',
 'done': True}

In [17]:
def wrap(name):
    # return r"\texttt{" + name + r"}"
    return name

In [6]:
txt = {
    "val": "Val",
    "test": "Test",
}

# link prediction

In [7]:
metric = "link_prediction_map"
higher_is_better = True

In [15]:
table_data = []
for dataset in get_dataset_names():
    for task in get_task_names(dataset):
        task_obj = get_task(dataset, task)
        if task_obj.task_type.value != TaskType.LINK_PREDICTION.value:
            continue
        for script in [
            "gnn_link",
            "idgnn_link",
            "lightgbm_link",
            "baseline_link",
        ]:
            for split in [
                "val",
                "test",
            ]:
                if script == "baseline_link":
                    stores = list(
                        filter(
                            lambda store: store["__roach__"]["caller_file"]
                            == f"{script}.py"
                            and store["args"]["dataset"] == dataset
                            and store["args"]["task"] == task,
                            all_stores,
                        ),
                    )
                    for baseline in [
                        "global_popularity",
                        "past_visit",
                    ]:
                        try:
                            store = stores[-1]
                            val = store[baseline][split][metric]
                        except IndexError:
                            val = float("nan")
                        record = {
                            "dataset": dataset,
                            "task": task,
                            "script": baseline,
                            "split": split,
                            "mean": val,
                            "std": 0.0,
                        }
                        table_data.append(record)
                else:
                    vals = []
                    for seed in range(5):
                        stores = list(
                            filter(
                                lambda store: store["__roach__"]["caller_file"]
                                == f"{script}.py"
                                and store["args"]["dataset"] == dataset
                                and store["args"]["task"] == task
                                and store["args"]["seed"] == seed,
                                all_stores,
                            )
                        )
                        try:
                            store = stores[-1]
                            val = store[split][metric]
                            vals.append(val)
                        except IndexError:
                            # val = float("nan")
                            # vals.append(val)
                            pass
                    val = torch.tensor(vals)
                    mean = val.mean().item()
                    std = val.std().item()
                    record = {
                        "dataset": dataset,
                        "task": task,
                        "script": script,
                        "split": split,
                        "mean": mean,
                        "std": std,
                    }
                    table_data.append(record)

In [18]:
tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    is_best = True
    filter_data = filter(
        lambda r: r["dataset"] == dataset
        and r["task"] == task
        and r["split"] == split
        and r["script"] != "hybrid_node",
        table_data,
    )
    for comp_rec in filter_data:
        comp_mean = comp_rec["mean"]
        comp_std = comp_rec["std"]
        if higher_is_better:
            if mean + std < comp_mean - comp_std:
                is_best = False
        else:
            if mean - std > comp_mean + comp_std:
                is_best = False
    opt_bm_open = r"\bm{" if is_best else ""
    opt_bm_close = r"}" if is_best else ""
    if script in [
        "gnn_link",
        "idgnn_link",
        "lightgbm_link",
    ]:
        tex_val = (
            r"$"
            + opt_bm_open
            + f"{mean * 100:.2f}"
            + opt_bm_close
            + r"_{\pm "
            + f"{std * 100:.2f}"
            + r"}$"
        )
    else:
        tex_val = r"$" + opt_bm_open + f"{mean * 100:.2f}" + opt_bm_close + r"$"

    tex_tab[script][(wrap(dataset), wrap(task), txt[split])] = tex_val
tex_df = pd.DataFrame(tex_tab)
tex_df

Unnamed: 0,Unnamed: 1,Unnamed: 2,gnn_link,idgnn_link,lightgbm_link,global_popularity,past_visit
rel-amazon,user-item-purchase,Val,$\bm{2.49}_{\pm 0.14}$,$0.19_{\pm 0.00}$,$0.47_{\pm 0.12}$,$0.62$,$0.12$
rel-amazon,user-item-purchase,Test,$\bm{1.49}_{\pm 0.13}$,$0.16_{\pm 0.00}$,$0.35_{\pm 0.08}$,$0.49$,$0.10$
rel-amazon,user-item-rate,Val,$\bm{2.40}_{\pm 0.14}$,$0.22_{\pm 0.00}$,$0.59_{\pm 0.13}$,$0.37$,$0.14$
rel-amazon,user-item-rate,Test,$\bm{1.54}_{\pm 0.18}$,$0.18_{\pm 0.00}$,$0.45_{\pm 0.07}$,$0.48$,$0.12$
rel-amazon,user-item-review,Val,$\bm{1.93}_{\pm 0.09}$,$0.15_{\pm 0.00}$,$0.40_{\pm 0.14}$,$0.37$,$0.07$
rel-amazon,user-item-review,Test,$\bm{0.97}_{\pm 0.07}$,$0.13_{\pm 0.00}$,$0.25_{\pm 0.06}$,$0.25$,$0.05$
rel-avito,user-ad-visit,Val,$0.08_{\pm 0.02}$,$\bm{5.40}_{\pm 0.02}$,$0.17_{\pm 0.01}$,$0.01$,$3.66$
rel-avito,user-ad-visit,Test,$0.02_{\pm 0.00}$,$\bm{3.66}_{\pm 0.02}$,$0.06_{\pm 0.01}$,$0.00$,$1.95$
rel-hm,user-item-purchase,Val,$0.92_{\pm 0.04}$,$\bm{2.64}_{\pm 0.00}$,$0.44_{\pm 0.03}$,$0.36$,$1.07$
rel-hm,user-item-purchase,Test,$0.80_{\pm 0.03}$,$\bm{2.81}_{\pm 0.01}$,$0.38_{\pm 0.02}$,$0.30$,$0.89$


In [13]:
tex = tex_df.to_latex()
tex = tex.replace(r"\multirow[t]", r"\multirow[c]")
tex = tex.replace(r"\cline", r"\cmidrule")
tex = tex.replace(r"\cmidrule{1-7} \cmidrule{2-7}", r"\cmidrule{1-7}")
print(tex)

\begin{tabular}{llllllll}
\toprule
 &  &  & gnn_link & idgnn_link & lightgbm_link & global_popularity & past_visit \\
\midrule
\multirow[c]{6}{*}{\texttt{rel-amazon}} & \multirow[c]{2}{*}{\texttt{user-item-purchase}} & Val & $\bm{2.49}_{\pm 0.14}$ & $0.19_{\pm 0.00}$ & $0.47_{\pm 0.12}$ & $0.62$ & $0.12$ \\
 &  & Test & $\bm{1.49}_{\pm 0.13}$ & $0.16_{\pm 0.00}$ & $0.35_{\pm 0.08}$ & $0.49$ & $0.10$ \\
\cmidrule{2-8}
 & \multirow[c]{2}{*}{\texttt{user-item-rate}} & Val & $\bm{2.40}_{\pm 0.14}$ & $0.22_{\pm 0.00}$ & $0.59_{\pm 0.13}$ & $0.37$ & $0.14$ \\
 &  & Test & $\bm{1.54}_{\pm 0.18}$ & $0.18_{\pm 0.00}$ & $0.45_{\pm 0.07}$ & $0.48$ & $0.12$ \\
\cmidrule{2-8}
 & \multirow[c]{2}{*}{\texttt{user-item-review}} & Val & $\bm{1.93}_{\pm 0.09}$ & $0.15_{\pm 0.00}$ & $0.40_{\pm 0.14}$ & $0.37$ & $0.07$ \\
 &  & Test & $\bm{0.97}_{\pm 0.07}$ & $0.13_{\pm 0.00}$ & $0.25_{\pm 0.06}$ & $0.25$ & $0.05$ \\
\cmidrule{1-8} \cmidrule{2-8}
\multirow[c]{2}{*}{\texttt{rel-avito}} & \multirow[c]{2}{*}{

In [14]:
lb_sub = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    if split != "test":
        continue

    lb_sub[script][f"{dataset}/{task}"] = [mean, std]
print(json.dumps(lb_sub, indent=2))

{
  "gnn_link": {
    "rel-amazon/user-item-purchase": [
      0.0148619023444159,
      0.0013048609271184338
    ],
    "rel-amazon/user-item-rate": [
      0.015441215843968164,
      0.001814221070691635
    ],
    "rel-amazon/user-item-review": [
      0.009737415034373373,
      0.0006692742440980297
    ],
    "rel-avito/user-ad-visit": [
      NaN,
      NaN
    ],
    "rel-hm/user-item-purchase": [
      0.008030757277855417,
      0.00029190920355524246
    ],
    "rel-stack/user-post-comment": [
      0.001104691159162458,
      0.0004739540101185633
    ],
    "rel-stack/post-post-related": [
      0.0007475914805332113,
      0.0007792807950739197
    ],
    "rel-trial/condition-sponsor-run": [
      0.028945028883913527,
      0.003860065235997204
    ],
    "rel-trial/site-sponsor-run": [
      0.10698655189731872,
      0.010972726531651095
    ]
  },
  "idgnn_link": {
    "rel-amazon/user-item-purchase": [
      0.0016284186580703213,
      9.544074360059546e-06
    ],