In [1]:
from collections import defaultdict
from itertools import product

import pandas as pd
import torch

import roach

In [2]:
# Load every run record ("store") logged under this roach project.
# NOTE(review): later cells pick `stores[-1]` as the run to report, which
# presumably relies on scan() returning stores in chronological order — confirm.
all_stores = roach.scan("relbench/2024-06-03")

In [3]:
# Sanity check: how many stores the scan returned.
len(all_stores)

495

In [4]:
# Peek at the "__roach__" metadata of the last store (project, timestamp,
# calling script and completion flag).
all_stores[-1]["__roach__"]

{'project': 'relbench/2024-06-03',
 'timestamp': 1717574133150034166,
 'caller_file': 'gnn_node.py',
 'done': True}

In [5]:
# Display strings used when rendering the LaTeX tables. Keys are the raw
# identifiers found in the stores (splits, script names, baseline names,
# dataset names, task names); values are the text that ends up in the table.
# Values starting with a backslash are LaTeX macros, presumably defined in the
# document that includes the generated tables.

# Data splits.
_split_labels = {
    "train": "Train",
    "val": "Val",
    "test": "Test",
}

# Methods, keyed by the name of the script that produced the results.
_method_labels = {
    "human": r"\makecell{Data\\Scientist}",
    "gnn_node": "GNN",
    "gnn_link": "GNN",
    "idgnn_link": "ID-GNN",
    "lightgbm_baseline": "LightGBM",
    "lightgbm_gnn_features_node": "GNN+LightGBM",
}

# Simple baselines.
_baseline_labels = {
    "random": "Random",
    "majority": "Majority",
    "random_multilabel": "Random",
    "majority_multilabel": "Majority",
    "global_zero": r"\makecell{Global\\Zero}",
    "global_mean": r"\makecell{Global\\Mean}",
    "global_median": r"\makecell{Global\\Median}",
    "entity_mean": r"\makecell{Entity\\Mean}",
    "entity_median": r"\makecell{Entity\\Median}",
    "past_visit": r"\makecell{Past\\Visit}",
    "global_popularity": r"\makecell{Global\\Popularity}",
}

# Datasets.
_dataset_labels = {
    "rel-amazon": r"\amazon",
    "rel-avito": r"\avito",
    "rel-event": r"\event",
    "rel-f1": r"\fone",
    "rel-hm": r"\handm",
    "rel-stack": r"\stackex",
    "rel-trial": r"\trials",
}

# Tasks.
_task_labels = {
    "user-churn": r"\userChurn",
    "item-churn": r"\itemChurn",
    "driver-dnf": r"\driverDNF",
    "driver-top3": r"\driverTopThree",
    "user-engagement": r"\userEngage",
    "user-badge": r"\userBadge",
    "study-outcome": r"\studyOutcome",
    "study-withdrawal": r"\studyWithdrawal",
    "user-ltv": r"\userLtv",
    "item-ltv": r"\itemLtv",
    "driver-position": r"\driverPosition",
    "item-sales": r"\itemSales",
    "post-votes": r"\postVotes",
    "study-adverse": r"\studyAdverse",
    "site-success": r"\facilitySuccess",
    "user-item-purchase": r"\userItemPurchase",
    "user-item-rate": r"\userItemRate",
    "user-item-review": r"\userItemReview",
    "user-ad-click": r"\userAdClick",
    "user-post-comment": r"\userPostComment",
    "post-post-related": r"\postPostLinked",
    "condition-sponsor-rec": r"\sponsorConditionRec",
    "site-sponsor-rec": r"\sponsorFacilityRec",
    "user-clicks": r"\userClicks",
    "user-attendance": r"\userAttendance",
}

# Single lookup table used by the rendering cells.
txt = {
    **_split_labels,
    **_method_labels,
    **_baseline_labels,
    **_dataset_labels,
    **_task_labels,
}

# Node classification

In [6]:
# Metric read from each store for the classification table, and its direction
# (consulted when deciding which entries are marked as best).
metric = "roc_auc"
higher_is_better = True

In [14]:
# Aggregate `metric` over seeds for every (dataset, task) x script x split
# combination; one record holding the mean and std is appended to
# `table_data` per combination (per baseline for "node_baseline").
table_data = []
for (dataset, task), script, split in product(
    [
        ("rel-amazon", "user-churn"),
        ("rel-amazon", "item-churn"),
        ("rel-f1", "driver-dnf"),
        ("rel-f1", "driver-top3"),
        ("rel-hm", "user-churn"),
        ("rel-stack", "user-engagement"),
        ("rel-stack", "user-badge"),
        ("rel-trial", "study-outcome"),
    ],
    [
        "gnn_node",
        "lightgbm_baseline",
        # "node_baseline",
    ],
    [
        "val",
        "test",
    ],
):
    # Stores logged by this script for this dataset/task, in `all_stores`
    # order. Filtered once here instead of once per seed.
    caller_file = f"{script}.py"
    stores = [
        store
        for store in all_stores
        if store["__roach__"]["caller_file"] == caller_file
        and store["args"]["dataset"] == dataset
        and store["args"]["task"] == task
    ]

    if script == "node_baseline":
        # Only reached when "node_baseline" is enabled in the script list
        # above. Each store holds several baselines; the last five matching
        # stores are averaged.
        for baseline in [
            "random",
            "majority",
        ]:
            vals = [store[baseline][split][metric] for store in stores[-5:]]
            val = torch.tensor(vals)
            mean = val.mean().item()
            std = val.std().item()
            record = {
                "dataset": dataset,
                "task": task,
                "script": baseline,
                "split": split,
                "mean": mean,
                "std": std,
            }
            table_data.append(record)
    else:
        vals = []
        missing_seeds = []
        for seed in range(5):
            seed_stores = [
                store for store in stores if store["args"]["seed"] == seed
            ]
            if not seed_stores:
                # Still best-effort (the seed is skipped), but no longer silent.
                missing_seeds.append(seed)
                continue
            # The last matching store is the one reported for this seed.
            vals.append(seed_stores[-1][split][metric])
        if missing_seeds:
            print(
                f"[missing runs] {dataset} / {task} / {script} / {split}: "
                f"no store for seed(s) {missing_seeds}; "
                f"aggregating over {len(vals)} run(s)"
            )
        if vals:
            val = torch.tensor(vals)
            mean = val.mean().item()
            # torch's default std is the sample std, so a single run gives NaN.
            std = val.std().item()
        else:
            # No run at all: record NaN explicitly instead of reducing an
            # empty tensor.
            mean = std = float("nan")
        ### DANGER ZONE ###
        # NOTE(review): manual override. The aggregated rel-trial /
        # study-outcome GNN numbers are replaced by hard-coded values, so they
        # cannot be reproduced from `all_stores`. The origin of these values
        # is not recorded in this notebook.
        if (
            dataset == "rel-trial"
            and task == "study-outcome"
            and script == "gnn_node"
        ):
            if split == "val":
                mean = 0.6659
                std = 0.0052
            elif split == "test":
                mean = 0.7097
                std = 0.0091
        ### ###
        record = {
            "dataset": dataset,
            "task": task,
            "script": script,
            "split": split,
            "mean": mean,
            "std": std,
        }
        table_data.append(record)

In [17]:
# Render every record of `table_data` as a LaTeX math cell (percentages with
# two decimals) and pivot the cells into a (Dataset, Task, Split) x method
# table. An entry is bolded unless some entry of the same dataset/task/split
# is better by more than the two standard deviations combined.

# Group the records once so each one is compared only against its peers
# instead of rescanning `table_data` for every record.
peers_by_key = defaultdict(list)
for rec in table_data:
    peers_by_key[(rec["dataset"], rec["task"], rec["split"])].append(rec)

tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    peers = peers_by_key[(dataset, task, split)]
    if pd.isna(mean):
        # No result was aggregated for this entry. Every comparison with NaN
        # is False, which previously left such entries bolded as "best".
        is_best = False
    elif higher_is_better:
        is_best = not any(mean + std < peer["mean"] - peer["std"] for peer in peers)
    else:
        is_best = not any(mean - std > peer["mean"] + peer["std"] for peer in peers)

    tex_mean = f"{mean * 100:.2f}"
    if is_best:
        tex_mean = r"\bm{" + tex_mean + r"}"
    if script in ("gnn_node", "lightgbm_baseline", "random"):
        # These entries also show their standard deviation as a subscript.
        tex_val = r"$" + tex_mean + r"_{\pm " + f"{std * 100:.2f}" + r"}$"
    else:
        tex_val = r"$" + tex_mean + r"$"

    tex_tab[txt[script]][(txt[dataset], txt[task], txt[split])] = tex_val

# Tuple row keys become a three-level index; name the levels without mutating
# the frame in place.
tex_df = pd.DataFrame(tex_tab).rename_axis(index=["Dataset", "Task", "Split"])
tex_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,GNN,LightGBM
Dataset,Task,Split,Unnamed: 3_level_1,Unnamed: 4_level_1
\amazon,\userChurn,Val,$\bm{70.54}_{\pm 0.03}$,$52.09_{\pm 0.07}$
\amazon,\userChurn,Test,$\bm{70.58}_{\pm 0.12}$,$52.13_{\pm 0.15}$
\amazon,\itemChurn,Val,$\bm{82.59}_{\pm 0.04}$,$62.25_{\pm 0.23}$
\amazon,\itemChurn,Test,$\bm{82.96}_{\pm 0.03}$,$62.44_{\pm 0.24}$
\fone,\driverDNF,Val,$\bm{72.00}_{\pm 0.43}$,$67.00_{\pm 1.57}$
\fone,\driverDNF,Test,$\bm{72.30}_{\pm 1.67}$,$\bm{68.69}_{\pm 3.08}$
\fone,\driverTopThree,Val,$\bm{75.80}_{\pm 2.43}$,$\bm{71.25}_{\pm 2.84}$
\fone,\driverTopThree,Test,$\bm{79.04}_{\pm 1.72}$,$\bm{77.61}_{\pm 4.01}$
\handm,\userChurn,Val,$\bm{69.99}_{\pm 0.14}$,$56.84_{\pm 0.02}$
\handm,\userChurn,Test,$\bm{69.48}_{\pm 0.12}$,$56.12_{\pm 0.02}$


In [20]:
# Export the classification table to LaTeX, then patch the generated markup:
# centre the multirow cells, turn every \cline into \cmidrule, and collapse
# the doubled rule emitted at dataset boundaries into a single full-width one.
# The result is shown and written to the tables directory.
tex = (
    tex_df.to_latex()
    .replace(r"\multirow[t]", r"\multirow[c]")
    .replace(r"\cline", r"\cmidrule")
    .replace(r"\cmidrule{1-5} \cmidrule{2-5}", r"\cmidrule{1-5}")
)
print(tex)
with open("../tables/node_classification.tex", "w") as tex_file:
    tex_file.write(tex)

\begin{tabular}{lllll}
\toprule
 &  &  & GNN & LightGBM \\
Dataset & Task & Split &  &  \\
\midrule
\multirow[c]{4}{*}{\amazon} & \multirow[c]{2}{*}{\userChurn} & Val & $\bm{70.54}_{\pm 0.03}$ & $52.09_{\pm 0.07}$ \\
 &  & Test & $\bm{70.58}_{\pm 0.12}$ & $52.13_{\pm 0.15}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\itemChurn} & Val & $\bm{82.59}_{\pm 0.04}$ & $62.25_{\pm 0.23}$ \\
 &  & Test & $\bm{82.96}_{\pm 0.03}$ & $62.44_{\pm 0.24}$ \\
\cmidrule{1-5}
\multirow[c]{4}{*}{\fone} & \multirow[c]{2}{*}{\driverDNF} & Val & $\bm{72.00}_{\pm 0.43}$ & $67.00_{\pm 1.57}$ \\
 &  & Test & $\bm{72.30}_{\pm 1.67}$ & $\bm{68.69}_{\pm 3.08}$ \\
\cmidrule{2-5}
 & \multirow[c]{2}{*}{\driverTopThree} & Val & $\bm{75.80}_{\pm 2.43}$ & $\bm{71.25}_{\pm 2.84}$ \\
 &  & Test & $\bm{79.04}_{\pm 1.72}$ & $\bm{77.61}_{\pm 4.01}$ \\
\cmidrule{1-5}
\multirow[c]{2}{*}{\handm} & \multirow[c]{2}{*}{\userChurn} & Val & $\bm{69.99}_{\pm 0.14}$ & $56.84_{\pm 0.02}$ \\
 &  & Test & $\bm{69.48}_{\pm 0.12}$ & $56.12_{\

# Node regression

In [33]:
# Metric read from each store for the regression table, and its direction.
# With "mae", lower values are better, hence higher_is_better = False.
# Alternative configuration kept for reference:
# metric = "r2"
# higher_is_better = True
metric = "mae"
higher_is_better = False

In [34]:
# Aggregate `metric` for every (dataset, task) x script x split combination.
# Seeded scripts get the mean/std over seeds 0-4; "node_baseline" contributes
# one record per baseline, taken from a single store with std 0.
table_data = []
for (dataset, task), script, split in product(
    [
        ("rel-amazon", "user-ltv"),
        ("rel-amazon", "item-ltv"),
        ("rel-avito", "user-clicks"),
        # ("rel-event", "user-attendance"),
        ("rel-f1", "driver-position"),
        ("rel-hm", "item-sales"),
        ("rel-stack", "post-votes"),
        ("rel-trial", "study-adverse"),
        ("rel-trial", "site-success"),
    ],
    [
        "gnn_node",
        "lightgbm_baseline",
        "node_baseline",
    ],
    [
        "val",
        "test",
    ],
):
    # Stores logged by this script for this dataset/task, in `all_stores`
    # order. Filtered once here instead of once per seed.
    caller_file = f"{script}.py"
    stores = [
        store
        for store in all_stores
        if store["__roach__"]["caller_file"] == caller_file
        and store["args"]["dataset"] == dataset
        and store["args"]["task"] == task
    ]

    if script == "node_baseline":
        if not stores:
            # Same exception type as indexing the empty list, but with context.
            raise IndexError(f"no {caller_file} store found for {dataset} / {task}")
        # The last matching store holds the results of all baselines.
        store = stores[-1]
        for baseline in [
            "global_zero",
            "global_mean",
            "global_median",
            "entity_mean",
            "entity_median",
        ]:
            val = store[baseline][split][metric]
            record = {
                "dataset": dataset,
                "task": task,
                "script": baseline,
                "split": split,
                "mean": val,
                "std": 0.0,
            }
            table_data.append(record)
    else:
        vals = []
        missing_seeds = []
        for seed in range(5):
            seed_stores = [
                store for store in stores if store["args"]["seed"] == seed
            ]
            if not seed_stores:
                # Still best-effort (the seed is skipped), but no longer silent.
                missing_seeds.append(seed)
                continue
            # The last matching store is the one reported for this seed.
            vals.append(seed_stores[-1][split][metric])
        if missing_seeds:
            print(
                f"[missing runs] {dataset} / {task} / {script} / {split}: "
                f"no store for seed(s) {missing_seeds}; "
                f"aggregating over {len(vals)} run(s)"
            )
        if vals:
            val = torch.tensor(vals)
            mean = val.mean().item()
            # torch's default std is the sample std, so a single run gives NaN.
            std = val.std().item()
        else:
            # No run at all: record NaN explicitly instead of reducing an
            # empty tensor.
            mean = std = float("nan")
        record = {
            "dataset": dataset,
            "task": task,
            "script": script,
            "split": split,
            "mean": mean,
            "std": std,
        }
        table_data.append(record)

### DANGER ZONE ###
# Disabled manual entry of hard-coded rel-event / user-attendance numbers
# (not derived from `all_stores`); kept for reference only.
# raw = [0.255, 0.262, 0.262, 0.457, 0.262, 0.296, 0.268,
#        0.256, 0.264, 0.264, 0.470, 0.264, 0.304, 0.269]
# dataset = "rel-event"
# task = "user-attendance"
# i = 0
# for split in ["val", "test"]:
#     for script in ["gnn_node", "lightgbm_baseline", "global_zero", "global_mean", "global_median", "entity_mean", "entity_median"]:
#         table_data.append(
#             {
#                 "dataset": "rel-event",
#                 "task": "user-attendance",
#                 "script": script,
#                 "split": split,
#                 "mean": raw[i],
#                 "std": 0.005 if script == "gnn_node" else 0.000,
#             }
#         )
#         i += 1
### ###

In [35]:
# Render every record of `table_data` as a LaTeX math cell (raw values with
# three decimals), store the rendered string back on the record under "tex",
# and pivot the cells into a (Dataset, Task, Split) x method table. An entry is
# bolded unless some entry of the same dataset/task/split is better by more
# than the two standard deviations combined.

# Group the records once so each one is compared only against its peers
# instead of rescanning `table_data` for every record.
peers_by_key = defaultdict(list)
for rec in table_data:
    peers_by_key[(rec["dataset"], rec["task"], rec["split"])].append(rec)

tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    peers = peers_by_key[(dataset, task, split)]
    if pd.isna(mean):
        # No result was aggregated for this entry. Every comparison with NaN
        # is False, which previously left such entries bolded as "best".
        is_best = False
    elif higher_is_better:
        is_best = not any(mean + std < peer["mean"] - peer["std"] for peer in peers)
    else:
        is_best = not any(mean - std > peer["mean"] + peer["std"] for peer in peers)

    tex_mean = f"{mean:.3f}"
    if is_best:
        tex_mean = r"\bm{" + tex_mean + r"}"
    if script in ("gnn_node", "lightgbm_baseline"):
        # Seeded methods also show their standard deviation as a subscript.
        tex_val = r"$" + tex_mean + r"_{\pm " + f"{std:.3f}" + r"}$"
    else:
        tex_val = r"$" + tex_mean + r"$"

    # Keep the rendered cell on the record; a later cell reads rec["tex"].
    rec["tex"] = tex_val

    tex_tab[txt[script]][
        # Alternative two-level row key kept for reference:
        # (r"\makecell[tl]{" + txt[dataset] + r"~/\\~" + txt[task] + r"}", txt[split])
        (txt[dataset], txt[task], txt[split])
    ] = tex_val

# Tuple row keys become a three-level index; name the levels without mutating
# the frame in place. For the two-level alternative above the names would be
# ["Task (Dataset)", "Split"].
tex_df = pd.DataFrame(tex_tab).rename_axis(index=["Dataset", "Task", "Split"])
tex_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,GNN,LightGBM,\makecell{Global\\Zero},\makecell{Global\\Mean},\makecell{Global\\Median},\makecell{Entity\\Mean},\makecell{Entity\\Median}
Dataset,Task,Split,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
\amazon,\userLtv,Val,$\bm{12.157}_{\pm 0.010}$,$14.141_{\pm 0.000}$,$14.141$,$20.717$,$14.141$,$17.680$,$15.972$
\amazon,\userLtv,Test,$\bm{14.310}_{\pm 0.028}$,$16.783_{\pm 0.000}$,$16.783$,$22.103$,$16.783$,$19.051$,$17.419$
\amazon,\itemLtv,Val,$\bm{44.956}_{\pm 0.109}$,$55.739_{\pm 0.042}$,$72.096$,$78.110$,$59.471$,$80.466$,$68.922$
\amazon,\itemLtv,Test,$\bm{49.737}_{\pm 0.553}$,$60.601_{\pm 0.026}$,$77.126$,$81.852$,$64.234$,$78.423$,$66.436$
\avito,\userClicks,Val,$\bm{0.453}_{\pm 0.000}$,$\bm{0.453}_{\pm 0.000}$,$1.453$,$0.667$,$\bm{0.453}$,$1.203$,$1.202$
\avito,\userClicks,Test,$\bm{0.343}_{\pm 0.000}$,$\bm{0.343}_{\pm 0.000}$,$1.343$,$0.606$,$\bm{0.343}$,$1.163$,$1.161$
\fone,\driverPosition,Val,$\bm{3.170}_{\pm 0.058}$,$3.450_{\pm 0.044}$,$11.083$,$4.334$,$4.136$,$7.181$,$7.114$
\fone,\driverPosition,Test,$\bm{4.173}_{\pm 0.178}$,$\bm{4.117}_{\pm 0.117}$,$11.926$,$4.513$,$4.399$,$8.501$,$8.519$
\handm,\itemSales,Val,$\bm{0.064}_{\pm 0.000}$,$0.086_{\pm 0.000}$,$0.086$,$0.142$,$0.086$,$0.117$,$0.086$
\handm,\itemSales,Test,$\bm{0.055}_{\pm 0.000}$,$0.076_{\pm 0.000}$,$0.076$,$0.134$,$0.076$,$0.111$,$0.078$


In [36]:
# Export the regression table to LaTeX, then patch the generated markup:
# centre the multirow cells, turn every \cline into \cmidrule, and collapse
# the doubled rule emitted at dataset boundaries into a single full-width one.
# The result is shown and written to the tables directory.
tex = (
    tex_df.to_latex()
    .replace(r"\multirow[t]", r"\multirow[c]")
    .replace(r"\cline", r"\cmidrule")
    .replace(r"\cmidrule{1-10} \cmidrule{2-10}", r"\cmidrule{1-10}")
)
print(tex)
with open("../tables/node_regression.tex", "w") as tex_file:
    tex_file.write(tex)

\begin{tabular}{llllllllll}
\toprule
 &  &  & GNN & LightGBM & \makecell{Global\\Zero} & \makecell{Global\\Mean} & \makecell{Global\\Median} & \makecell{Entity\\Mean} & \makecell{Entity\\Median} \\
Dataset & Task & Split &  &  &  &  &  &  &  \\
\midrule
\multirow[c]{4}{*}{\amazon} & \multirow[c]{2}{*}{\userLtv} & Val & $\bm{12.157}_{\pm 0.010}$ & $14.141_{\pm 0.000}$ & $14.141$ & $20.717$ & $14.141$ & $17.680$ & $15.972$ \\
 &  & Test & $\bm{14.310}_{\pm 0.028}$ & $16.783_{\pm 0.000}$ & $16.783$ & $22.103$ & $16.783$ & $19.051$ & $17.419$ \\
\cmidrule{2-10}
 & \multirow[c]{2}{*}{\itemLtv} & Val & $\bm{44.956}_{\pm 0.109}$ & $55.739_{\pm 0.042}$ & $72.096$ & $78.110$ & $59.471$ & $80.466$ & $68.922$ \\
 &  & Test & $\bm{49.737}_{\pm 0.553}$ & $60.601_{\pm 0.026}$ & $77.126$ & $81.852$ & $64.234$ & $78.423$ & $66.436$ \\
\cmidrule{1-10}
\multirow[c]{2}{*}{\avito} & \multirow[c]{2}{*}{\userClicks} & Val & $\bm{0.453}_{\pm 0.000}$ & $\bm{0.453}_{\pm 0.000}$ & $1.453$ & $0.667$ & $\bm{0.453

In [23]:
# Inspect the aggregated records (including the rendered "tex" strings) as a
# DataFrame.
pd.DataFrame(table_data)

Unnamed: 0,dataset,task,script,split,mean,std,tex
0,rel-amazon,user-ltv,gnn_node,val,12.157219,0.010439,$\bm{12.157}_{\pm 0.010}$
1,rel-amazon,user-ltv,gnn_node,test,14.309937,0.027722,$\bm{14.310}_{\pm 0.028}$
2,rel-amazon,user-ltv,lightgbm_baseline,val,14.140716,0.000003,$14.141_{\pm 0.000}$
3,rel-amazon,user-ltv,lightgbm_baseline,test,16.782981,0.000003,$16.783_{\pm 0.000}$
4,rel-amazon,user-ltv,global_zero,val,14.140715,0.000000,$14.141$
...,...,...,...,...,...,...,...
121,rel-trial,site-success,global_zero,test,0.462222,0.000000,$0.462$
122,rel-trial,site-success,global_mean,test,0.467649,0.000000,$0.468$
123,rel-trial,site-success,global_median,test,0.462222,0.000000,$0.462$
124,rel-trial,site-success,entity_mean,test,0.448016,0.000000,$0.448$


In [None]:
# Assemble, for every (dataset, task, split), the list of cells of one row of
# a hand-laid-out regression table from the "tex" strings stored on the
# `table_data` records. The "val" row carries the two-row dataset/task label;
# the "test" row leaves that column blank.
#
# Fixes relative to the previous version of this cell:
#   * the label branch tested `task == "val"`, which is never true; it must
#     test `split`;
#   * the script name "lightgbm" never matched any record (they are stored
#     under "lightgbm_baseline"), so `next()` raised StopIteration;
#   * combinations without records (e.g. rel-event while its data is disabled)
#     no longer abort the loop; they get a "--" placeholder cell;
#   * the assembled rows are collected in `manual_rows` instead of discarded.
manual_rows = []
for dataset, task in [
    ("rel-amazon", "user-ltv"),
    ("rel-amazon", "item-ltv"),
    ("rel-avito", "user-clicks"),
    ("rel-event", "user-attendance"),
    ("rel-f1", "driver-position"),
    ("rel-hm", "item-sales"),
    ("rel-stack", "post-votes"),
    ("rel-trial", "study-adverse"),
    ("rel-trial", "site-success"),
]:
    for split in [
        "val",
        "test",
    ]:
        if split == "val":
            cols = [
                r"\multirow[t]{2}{*}{\makecell[tl]{\texttt{"
                + dataset
                + r"} /\\ \texttt{"
                + task
                + r"}}}",
                "Val",
            ]
        else:
            cols = [r" ", "Test"]
        for script in [
            "gnn_node",
            "lightgbm_baseline",
            "global_zero",
            "global_mean",
            "global_median",
            "entity_mean",
            "entity_median",
        ]:
            # First record matching this cell, or None when nothing was
            # aggregated for it.
            rec = next(
                (
                    r
                    for r in table_data
                    if r["dataset"] == dataset
                    and r["task"] == task
                    and r["split"] == split
                    and r["script"] == script
                ),
                None,
            )
            cols.append(rec["tex"] if rec is not None else "--")
        manual_rows.append(cols)

# Link prediction

In [44]:
# Metric read from each store for the link prediction table, and its direction
# (consulted when deciding which entries are marked as best).
metric = "link_prediction_map"
higher_is_better = True

In [45]:
# Aggregate `metric` for every (dataset, task) x script x split combination.
# Seeded scripts get the mean/std over seeds 0-4; "link_baseline" contributes
# one record per baseline, taken from a single store with std 0.
table_data = []
for (dataset, task), script, split in product(
    [
        ("rel-amazon", "user-item-purchase"),
        ("rel-amazon", "user-item-rate"),
        ("rel-amazon", "user-item-review"),
        # ("rel-avito", "user-ad-click"),
        ("rel-hm", "user-item-purchase"),
        ("rel-stack", "user-post-comment"),
        ("rel-stack", "post-post-related"),
        # ("rel-trial", "condition-sponsor-rec"),
        # ("rel-trial", "site-sponsor-rec"),
    ],
    [
        "gnn_link",
        "idgnn_link",
        "link_baseline",
    ],
    [
        "val",
        "test",
    ],
):
    # Stores logged by this script for this dataset/task, in `all_stores`
    # order. Filtered once here instead of once per seed.
    caller_file = f"{script}.py"
    stores = [
        store
        for store in all_stores
        if store["__roach__"]["caller_file"] == caller_file
        and store["args"]["dataset"] == dataset
        and store["args"]["task"] == task
    ]

    if script == "link_baseline":
        if not stores:
            # Same exception type as indexing the empty list, but with context.
            raise IndexError(f"no {caller_file} store found for {dataset} / {task}")
        # The last matching store holds the results of all baselines.
        store = stores[-1]
        for baseline in [
            "global_popularity",
            "past_visit",
        ]:
            val = store[baseline][split][metric]
            record = {
                "dataset": dataset,
                "task": task,
                "script": baseline,
                "split": split,
                "mean": val,
                "std": 0.0,
            }
            table_data.append(record)
    else:
        vals = []
        missing_seeds = []
        for seed in range(5):
            seed_stores = [
                store for store in stores if store["args"]["seed"] == seed
            ]
            if not seed_stores:
                # Still best-effort (the seed is skipped), but no longer silent.
                missing_seeds.append(seed)
                continue
            # The last matching store is the one reported for this seed.
            vals.append(seed_stores[-1][split][metric])
        if missing_seeds:
            print(
                f"[missing runs] {dataset} / {task} / {script} / {split}: "
                f"no store for seed(s) {missing_seeds}; "
                f"aggregating over {len(vals)} run(s)"
            )
        if vals:
            val = torch.tensor(vals)
            mean = val.mean().item()
            # torch's default std is the sample std, so a single run gives NaN.
            std = val.std().item()
        else:
            # No run at all: record NaN explicitly instead of reducing an
            # empty tensor.
            mean = std = float("nan")
        record = {
            "dataset": dataset,
            "task": task,
            "script": script,
            "split": split,
            "mean": mean,
            "std": std,
        }
        table_data.append(record)

In [46]:
# Render every record of `table_data` as a LaTeX math cell (percentages with
# two decimals) and pivot the cells into a (Dataset, Task, Split) x method
# table. An entry is bolded unless some entry of the same dataset/task/split
# is better by more than the two standard deviations combined.

# Group the records once so each one is compared only against its peers
# instead of rescanning `table_data` for every record.
peers_by_key = defaultdict(list)
for rec in table_data:
    peers_by_key[(rec["dataset"], rec["task"], rec["split"])].append(rec)

tex_tab = defaultdict(dict)
for rec in table_data:
    dataset = rec["dataset"]
    task = rec["task"]
    script = rec["script"]
    split = rec["split"]
    mean = rec["mean"]
    std = rec["std"]

    peers = peers_by_key[(dataset, task, split)]
    if pd.isna(mean):
        # No result was aggregated for this entry. Every comparison with NaN
        # is False, which previously left such entries bolded as "best".
        is_best = False
    elif higher_is_better:
        is_best = not any(mean + std < peer["mean"] - peer["std"] for peer in peers)
    else:
        is_best = not any(mean - std > peer["mean"] + peer["std"] for peer in peers)

    tex_mean = f"{mean * 100:.2f}"
    if is_best:
        tex_mean = r"\bm{" + tex_mean + r"}"
    if script in ("gnn_link", "idgnn_link"):
        # Seeded methods also show their standard deviation as a subscript.
        tex_val = r"$" + tex_mean + r"_{\pm " + f"{std * 100:.2f}" + r"}$"
    else:
        tex_val = r"$" + tex_mean + r"$"

    tex_tab[txt[script]][(txt[dataset], txt[task], txt[split])] = tex_val

# Tuple row keys become a three-level index; name the levels without mutating
# the frame in place.
tex_df = pd.DataFrame(tex_tab).rename_axis(index=["Dataset", "Task", "Split"])
tex_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,GNN,ID-GNN,\makecell{Global\\Popularity},\makecell{Past\\Visit}
Dataset,Task,Split,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
\amazon,\userItemPurchase,Val,$\bm{1.58}_{\pm 0.10}$,$0.13_{\pm 0.00}$,$0.31$,$0.08$
\amazon,\userItemPurchase,Test,$\bm{0.77}_{\pm 0.06}$,$0.10_{\pm 0.00}$,$0.24$,$0.06$
\amazon,\userItemRate,Val,$\bm{1.48}_{\pm 0.09}$,$0.15_{\pm 0.00}$,$0.16$,$0.09$
\amazon,\userItemRate,Test,$\bm{0.88}_{\pm 0.05}$,$0.12_{\pm 0.00}$,$0.15$,$0.07$
\amazon,\userItemReview,Val,$\bm{1.06}_{\pm 0.05}$,$0.11_{\pm 0.00}$,$0.18$,$0.05$
\amazon,\userItemReview,Test,$\bm{0.46}_{\pm 0.05}$,$0.09_{\pm 0.00}$,$0.11$,$0.04$
\handm,\userItemPurchase,Val,$1.21_{\pm 0.05}$,$\bm{2.72}_{\pm 0.01}$,$0.36$,$1.08$
\handm,\userItemPurchase,Test,$1.12_{\pm 0.07}$,$\bm{2.86}_{\pm 0.02}$,$0.30$,$0.89$
\stackex,\userPostComment,Val,$0.46_{\pm 0.20}$,$\bm{15.18}_{\pm 0.08}$,$0.03$,$2.14$
\stackex,\userPostComment,Test,$0.17_{\pm 0.10}$,$\bm{12.71}_{\pm 0.22}$,$0.02$,$1.32$


In [47]:
# Export the link prediction table to LaTeX, then patch the generated markup:
# centre the multirow cells, turn every \cline into \cmidrule, and collapse
# the doubled rule emitted at dataset boundaries into a single full-width one.
# The result is shown and written to the tables directory.
tex = (
    tex_df.to_latex()
    .replace(r"\multirow[t]", r"\multirow[c]")
    .replace(r"\cline", r"\cmidrule")
    .replace(r"\cmidrule{1-7} \cmidrule{2-7}", r"\cmidrule{1-7}")
)
print(tex)
with open("../tables/link_prediction.tex", "w") as tex_file:
    tex_file.write(tex)

\begin{tabular}{lllllll}
\toprule
 &  &  & GNN & ID-GNN & \makecell{Global\\Popularity} & \makecell{Past\\Visit} \\
Dataset & Task & Split &  &  &  &  \\
\midrule
\multirow[c]{6}{*}{\amazon} & \multirow[c]{2}{*}{\userItemPurchase} & Val & $\bm{1.58}_{\pm 0.10}$ & $0.13_{\pm 0.00}$ & $0.31$ & $0.08$ \\
 &  & Test & $\bm{0.77}_{\pm 0.06}$ & $0.10_{\pm 0.00}$ & $0.24$ & $0.06$ \\
\cmidrule{2-7}
 & \multirow[c]{2}{*}{\userItemRate} & Val & $\bm{1.48}_{\pm 0.09}$ & $0.15_{\pm 0.00}$ & $0.16$ & $0.09$ \\
 &  & Test & $\bm{0.88}_{\pm 0.05}$ & $0.12_{\pm 0.00}$ & $0.15$ & $0.07$ \\
\cmidrule{2-7}
 & \multirow[c]{2}{*}{\userItemReview} & Val & $\bm{1.06}_{\pm 0.05}$ & $0.11_{\pm 0.00}$ & $0.18$ & $0.05$ \\
 &  & Test & $\bm{0.46}_{\pm 0.05}$ & $0.09_{\pm 0.00}$ & $0.11$ & $0.04$ \\
\cmidrule{1-7}
\multirow[c]{2}{*}{\handm} & \multirow[c]{2}{*}{\userItemPurchase} & Val & $1.21_{\pm 0.05}$ & $\bm{2.72}_{\pm 0.01}$ & $0.36$ & $1.08$ \\
 &  & Test & $1.12_{\pm 0.07}$ & $\bm{2.86}_{\pm 0.02}$ & $0.3