In [1]:
from collections import defaultdict
from functools import lru_cache
from pathlib import Path

import ir_datasets
import ir_measures
import numpy as np
import pandas as pd
import pyterrier as pt
from tqdm import tqdm

DASHED_DATASET_MAP = {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}

In [2]:
@lru_cache(maxsize=1024)
def load_run(run_file):
    run = pd.read_csv(
        run_file,
        sep=r"\s+",
        header=None,
        names=["query_id", "Q0", "doc_id", "rank", "score", "run_name"],
        dtype={"query_id": str, "doc_id": str},
    )
    run = run.sort_values(["query_id", "score"], ascending=[True, False])
    run["rank"] = run.groupby("query_id")["score"].rank(ascending=False, method="first").astype(int)
    config = {}
    config["run_name"] = run_file.parent.stem
    config["dataset"] = DASHED_DATASET_MAP[run_file.stem]

    if "touche" in config["dataset"]:
        config["dataset"] = "beir/webis-touche2020/v2"
    run = clean_run(run)
    return run, config


@lru_cache
def load_qrels(dataset):
    dataset = ir_datasets.load(dataset)
    qrels_df = pd.DataFrame(dataset.qrels_iter())
    return qrels_df


@lru_cache
def load_topics(dataset):
    dataset = ir_datasets.load(dataset)
    topics_df = pd.DataFrame(dataset.queries_iter())
    return topics_df

def clean_run(run):
    run = run.loc[run["doc_id"] != run["query_id"]]
    run = run.groupby("query_id").head(100)
    return run

In [3]:
def evaluate_runs(per_dataset_runs, baseline, ir_measures_metrics, models):
    _results = []
    filtered_per_dataset_runs = defaultdict(dict)
    for dataset, runs_dict in per_dataset_runs.items():
        for run_name, run in runs_dict.items():
            if run_name in models:
                filtered_per_dataset_runs[dataset][run_name] = run
    pg = tqdm(filtered_per_dataset_runs.items())
    for dataset, runs_dict in pg:
        pg.set_description(dataset)
        qrels = load_qrels(dataset).rename(
            columns={"query_id": "qid", "doc_id": "docno", "relevance": "label", "subtopic_id": "iteration"}
        )
        topics = load_topics(dataset).rename(columns={"query_id": "qid"})
        run_names, runs = zip(*runs_dict.items())
        try:
            base_index = run_names.index(baseline)
            kwargs = {"baseline": base_index, "correction": "holm"}
        except ValueError:
            kwargs = {}
        runs = [run.rename(columns={"query_id": "qid", "doc_id": "docno"}).astype({"docno": str}) for run in runs]
        result = pt.Experiment(
            runs, topics, qrels, ir_measures_metrics, names=run_names, **kwargs
        )
        result["dataset"] = dataset
        _results.append(result)

    results = pd.concat(_results)

    # group cqadupstack runs
    cqadupstack = results[results["dataset"].str.contains("cqadupstack")]
    cqadupstack_average = cqadupstack.groupby("name")[metrics].mean().reset_index()
    cqadupstack_average["dataset"] = "beir/cqadupstack"
    results = results.loc[~results["dataset"].str.contains("cqadupstack")]
    results = pd.concat([results, cqadupstack_average])


    # compute beir averages
    def gmean(x):
        return np.exp(np.log(x.prod()) / x.notna().sum())


    beir_results = results.loc[results["dataset"].str.contains("beir")]
    arithmetic_mean = beir_results.groupby("name")[metrics].mean().reset_index()
    arithmetic_mean["dataset"] = "beir/arithmetic-mean"
    geometric_mean = beir_results.groupby("name")[metrics].agg(gmean).reset_index()
    geometric_mean["dataset"] = "beir/geometric-mean"
    beir_summary = pd.concat([arithmetic_mean, geometric_mean])

    results = pd.concat([results, beir_summary])

    return results

In [11]:
EXPERIMENT_DIR = Path.cwd().parent.resolve() / "experiments"

per_dataset_runs = defaultdict(dict)

run_files = list(EXPERIMENT_DIR.glob("*/*.run"))

models = set()
pg = tqdm(run_files)
for run_file in pg:
    pg.set_description(str(run_file.relative_to(EXPERIMENT_DIR)))
    run, config = load_run(run_file)
    models.add(config["run_name"])
    per_dataset_runs[config["dataset"]][config["run_name"]] = run
models

tite-2-late-pre/beir-cqadupstack-webmasters.run: 100%|██████████| 487/487 [16:14<00:00,  2.00s/it]                      


{'bert-mae-bow',
 'bert-mlm',
 'bm25-flat',
 'colbertv2',
 'msmarco-bert-base-dot-v5',
 'msmarco-distilbert-dot-v5',
 'retromae',
 'retromae-repro',
 'splade-pp-ed',
 'tite-2-late-intra',
 'tite-2-late-intra-bow',
 'tite-2-late-intra-higher-dims',
 'tite-2-late-intra-mae',
 'tite-2-late-post',
 'tite-2-late-pre',
 'tite-2-staggered',
 'tite-3-late',
 'tite-3-staggered'}

In [29]:
result_models = [
    "bm25-flat",
    "bert-mlm",
    "msmarco-bert-base-dot-v5",
    "msmarco-distilbert-dot-v5",
    "retromae-repro",
    "retromae",
    "colbertv2",
    "splade-pp-ed",
    "tite-2-late-intra",
    "tite-2-staggered",
    "tite-3-late",
    "tite-3-staggered",
    "tite-2-late-pre",
    "tite-2-late-post",
    "tite-2-late-intra-higher-dims",
]

ablation_models = [
    "bert-mlm",
    "retromae-repro",
    "bert-mae-bow",
    "tite-2-late-intra",
    "tite-2-late-intra-bow",
    "tite-2-late-intra-mae",
]

In [30]:
ir_measures_metrics = [
    ir_measures.nDCG(cutoff=10),
    ir_measures.Recall(cutoff=100),
]
metrics = [str(metric) for metric in ir_measures_metrics]

results = evaluate_runs(per_dataset_runs, "bert-mlm", ir_measures_metrics, result_models)
ablation_results = evaluate_runs(per_dataset_runs, "bert-mlm", ir_measures_metrics, ablation_models)
results


beir/cqadupstack/webmasters: 100%|██████████| 27/27 [00:57<00:00,  2.11s/it]        


Unnamed: 0,name,R@100,nDCG@10,R@100 +,R@100 -,R@100 p-value,R@100 reject,R@100 p-value corrected,nDCG@10 +,nDCG@10 -,nDCG@10 p-value,nDCG@10 reject,nDCG@10 p-value corrected,dataset
0,tite-3-staggered,0.913925,0.660462,188.0,247.0,4.619205e-04,True,2.771523e-03,1099.0,1973.0,1.016752e-63,True,1.118427e-62,beir/fever/test
1,tite-2-late-intra,0.924770,0.698559,215.0,202.0,4.569090e-01,False,9.138180e-01,1265.0,1697.0,1.167278e-13,True,3.501833e-13,beir/fever/test
2,bm25-flat,0.918516,0.651321,321.0,319.0,1.536085e-01,False,4.608256e-01,1416.0,2354.0,9.773184e-51,True,8.795865e-50,beir/fever/test
3,msmarco-distilbert-dot-v5,0.940828,0.773761,247.0,145.0,5.342597e-14,True,6.411117e-13,1684.0,1068.0,1.370241e-38,True,8.221445e-38,beir/fever/test
4,tite-3-late,0.904699,0.642826,163.0,291.0,5.311499e-12,True,5.842649e-11,1018.0,2137.0,2.614387e-96,True,3.660142e-95,beir/fever/test
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10,tite-2-late-post,0.546619,0.399472,,,,,,,,,,,beir/geometric-mean
11,tite-2-late-pre,0.546899,0.400425,,,,,,,,,,,beir/geometric-mean
12,tite-2-staggered,0.543888,0.396615,,,,,,,,,,,beir/geometric-mean
13,tite-3-late,0.548902,0.400309,,,,,,,,,,,beir/geometric-mean


In [21]:
datasets = results["dataset"].drop_duplicates().sort_values()
datasets = datasets.loc[~datasets.str.contains("mean")].tolist()
sig_datasets = datasets[:]
sig_datasets.remove("beir/cqadupstack")
datasets.append("beir/arithmetic-mean")
datasets.append("beir/geometric-mean")
non_sig_datasets = list(set(datasets) - set(sig_datasets))
datasets

['beir/arguana',
 'beir/climate-fever',
 'beir/cqadupstack',
 'beir/dbpedia-entity/test',
 'beir/fever/test',
 'beir/fiqa/test',
 'beir/hotpotqa/test',
 'beir/nfcorpus/test',
 'beir/nq',
 'beir/quora/test',
 'beir/scidocs',
 'beir/scifact/test',
 'beir/trec-covid',
 'beir/webis-touche2020/v2',
 'msmarco-passage/trec-dl-2019/judged',
 'msmarco-passage/trec-dl-2020/judged',
 'beir/arithmetic-mean',
 'beir/geometric-mean']

In [22]:
datasets_ordered = [
    "msmarco-passage/trec-dl-2019/judged",
    "msmarco-passage/trec-dl-2020/judged",
    "beir/arguana",
    "beir/climate-fever",
    "beir/cqadupstack",
    "beir/dbpedia-entity/test",
    "beir/fever/test",
    "beir/fiqa/test",
    "beir/hotpotqa/test",
    "beir/nfcorpus/test",
    "beir/nq",
    "beir/quora/test",
    "beir/scidocs",
    "beir/scifact/test",
    "beir/trec-covid",
    "beir/webis-touche2020/v2",
    "beir/arithmetic-mean",
    "beir/geometric-mean",
]
datasets_ablation = [
    "msmarco-passage/trec-dl-2019/judged",
    "msmarco-passage/trec-dl-2020/judged",
    "beir/arithmetic-mean",
    "beir/geometric-mean",
]

In [23]:
def format_columns(series, round_to=3):
    drop_rows = ["cohere-embed-english-v3", "bge-base-en-v1"]
    rounded = series.round(round_to)
    values = sorted(set(rounded.drop(drop_rows, axis=0, errors="ignore").dropna()), reverse=True)
    best = values[0]
    out_values = []
    for val in rounded:
        if pd.isna(val):
            out_val = "--"
        else:
            out_val = f"{val:.{round_to}f}"
        if val == best:
            out_val = "\\textbf{" + out_val + "}"
        out_values.append(out_val)
    out = pd.Series(out_values, index=series.index)
    return out

In [28]:
index = "name"
columns = "dataset"
values = ["nDCG@10"]

pd.set_option("display.max_columns", None)

table = (
    results.pivot_table(index=index, columns=columns, values=values).apply(format_columns, axis=0).droplevel(0, axis=1)
)
try:
    sig = (
        results.pivot_table(index=index, columns=columns, values=[f"{value} reject" for value in values])
        .loc[:, pd.IndexSlice["nDCG@10 reject", sig_datasets]]
        .replace({1: "*", 0: ""})
        .droplevel(0, axis=1)
    )
    # sig.set_axis(table.columns, axis=1)
    sig = sig.assign(**{non_sig_dataset: "" for non_sig_dataset in non_sig_datasets})
except:
    sig = pd.DataFrame("", index=table.index, columns=table.columns)
table_sig = (table + sig).loc[:, datasets]
table_sig = table_sig.reindex(labels=result_models, columns=datasets_ordered)
missing_models = table_sig.loc[table_sig.isna().any(axis=1)].index
if len(missing_models) > 0:
    print(f"Warning: The following models are missing in the table: {', '.join(missing_models)}")
table_sig = table_sig.dropna()
table_sig

dataset,msmarco-passage/trec-dl-2019/judged,msmarco-passage/trec-dl-2020/judged,beir/arguana,beir/climate-fever,beir/cqadupstack,beir/dbpedia-entity/test,beir/fever/test,beir/fiqa/test,beir/hotpotqa/test,beir/nfcorpus/test,beir/nq,beir/quora/test,beir/scidocs,beir/scifact/test,beir/trec-covid,beir/webis-touche2020/v2,beir/arithmetic-mean,beir/geometric-mean
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
bm25-flat,0.506*,0.480*,0.397*,0.165*,0.302,0.318*,0.651*,0.236*,0.633*,0.322,0.305*,0.789*,0.149,0.679*,0.595*,\textbf{0.442}*,0.427,0.379
bert-mlm,0.700,0.688,0.336,0.224,0.319,0.369,0.727,0.317,0.574,0.303,0.510,0.844,0.146,0.603,0.756,0.256,0.449,0.399
msmarco-bert-base-dot-v5,0.705,0.726,0.384*,0.221,0.337,0.385,0.762*,0.323,0.585*,0.315,0.522*,0.844,0.146,0.606,0.744,0.237,0.458,0.407
msmarco-distilbert-dot-v5,0.705,0.699,0.355*,0.233,0.322,0.375,0.774*,0.286*,0.571,0.298,0.497*,0.833*,0.140,0.596,0.666*,0.224,0.441,0.391
retromae-repro,0.723,0.711,0.375*,\textbf{0.242}*,0.340,0.406*,0.737*,0.340*,0.624*,0.336*,0.539*,0.844,\textbf{0.163}*,0.663*,\textbf{0.780},0.273,0.476,0.428
retromae,0.712,\textbf{0.730},0.367*,0.240*,0.342,0.428*,0.777*,0.343*,0.668*,0.325*,\textbf{0.573}*,\textbf{0.853}*,0.160*,0.638,0.759,0.280,0.482,0.432
colbertv2,\textbf{0.732},0.724,0.453*,0.176*,\textbf{0.359},\textbf{0.441}*,0.774*,0.346*,0.665*,0.330*,0.547*,0.851*,0.150,0.691*,0.732,0.257,0.484,0.427
splade-pp-ed,0.731,0.720,\textbf{0.520}*,0.230,0.334,0.437*,\textbf{0.788}*,\textbf{0.347}*,\textbf{0.687}*,\textbf{0.347}*,0.538*,0.834*,0.159*,\textbf{0.704}*,0.727,0.247,\textbf{0.493},\textbf{0.440}
tite-2-late-intra,0.705,0.670,0.391*,0.204*,0.312,0.376,0.699*,0.302,0.604*,0.334*,0.484*,0.818*,0.156,0.647*,0.691,0.271,0.449,0.403
tite-2-staggered,0.675,0.663,0.387*,0.199*,0.309,0.351,0.675*,0.288*,0.599*,0.324*,0.470*,0.823*,0.154,0.632,0.708,0.283,0.443,0.397


In [25]:
print(table_sig.to_latex().replace("0.", ".").replace("*", "\\kernSig"))

\begin{tabular}{lllllllllllllllllll}
\toprule
dataset & msmarco-passage/trec-dl-2019/judged & msmarco-passage/trec-dl-2020/judged & beir/arguana & beir/climate-fever & beir/cqadupstack & beir/dbpedia-entity/test & beir/fever/test & beir/fiqa/test & beir/hotpotqa/test & beir/nfcorpus/test & beir/nq & beir/quora/test & beir/scidocs & beir/scifact/test & beir/trec-covid & beir/webis-touche2020/v2 & beir/arithmetic-mean & beir/geometric-mean \\
name &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
\midrule
bm25-flat & .506\kernSig & .480\kernSig & .397\kernSig & .165\kernSig & .302 & .318\kernSig & .651\kernSig & .236\kernSig & .633\kernSig & .322 & .305\kernSig & .789\kernSig & .149 & .679\kernSig & .595\kernSig & \textbf{.442}\kernSig & .427 & .379 \\
bert-mlm & .700 & .688 & .336 & .224 & .319 & .369 & .727 & .317 & .574 & .303 & .510 & .844 & .146 & .603 & .756 & .256 & .449 & .399 \\
msmarco-bert-base-dot-v5 & .705 & .726 & .384\kernSig & .221 & .337 & .385 & .762\kernSig & .3

In [31]:
index = "name"
columns = "dataset"
values = ["nDCG@10"]

pd.set_option("display.max_columns", None)

table = ablation_results.pivot_table(index=index, columns=columns, values=values).map(lambda x: f"{x:.3f}").droplevel(0, axis=1)
try:
    sig = (
        ablation_results.pivot_table(index=index, columns=columns, values=[f"{value} reject" for value in values])
        .loc[:, pd.IndexSlice["nDCG@10 reject", sig_datasets]]
        .replace({1: "*", 0: ""})
        .droplevel(0, axis=1)
    )
    # sig.set_axis(table.columns, axis=1)
    sig = sig.assign(**{non_sig_dataset: "" for non_sig_dataset in non_sig_datasets})
except:
    sig = pd.DataFrame("", index=table.index, columns=table.columns)
table_sig = (table + sig).loc[:, datasets]
table_sig = table_sig.reindex(labels=ablation_models, columns=datasets_ablation)
missing_models = table_sig.loc[table_sig.isna().any(axis=1)].index
if len(missing_models) > 0:
    print(f"Warning: The following models are missing in the table: {', '.join(missing_models)}")
table_sig

  .replace({1: "*", 0: ""})


dataset,msmarco-passage/trec-dl-2019/judged,msmarco-passage/trec-dl-2020/judged,beir/arithmetic-mean,beir/geometric-mean
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
bert-mlm,0.7,0.688,0.449,0.399
retromae-repro,0.723,0.711,0.476,0.428
bert-mae-bow,0.704,0.674,0.444,0.4
tite-2-late-intra,0.705,0.67,0.449,0.403
tite-2-late-intra-bow,0.657,0.657,0.4,0.353
tite-2-late-intra-mae,0.66,0.676,0.426,0.38


In [14]:
print(table_sig.to_latex().replace("0.", ".").replace("*", "\\kernSig"))

\begin{tabular}{lllll}
\toprule
dataset & msmarco-passage/trec-dl-2019/judged & msmarco-passage/trec-dl-2020/judged & beir/arithmetic-mean & beir/geometric-mean \\
name &  &  &  &  \\
\midrule
bert-mlm & .700 & .688 & .449 & .399 \\
retromae-repro & .723 & .711 & .476 & .428 \\
bert-mae-bow & .704 & .674 & .444 & .400 \\
tite-2-late & .705 & .670 & .449 & .403 \\
tite-2-late-bow & .657 & .657 & .400 & .353 \\
tite-2-late-mae & .660 & .676 & .426 & .380 \\
\bottomrule
\end{tabular}

