In [1]:
import json
import re
import warnings
from functools import lru_cache
from pathlib import Path

import aquarel
import ir_datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ir_datasets import docs_parent_id

import yaml
from scipy.stats import pearsonr, ttest_ind, ttest_rel
from tqdm import tqdm

from lightning_ir.lightning_utils.validation_utils import evaluate_run
from lightning_ir.data.dataset import DASHED_DATASET_MAP

In [2]:
theme = aquarel.Theme.from_file(Path.home() / "aquarel-theme.json")
theme.apply()
markers = ["o", "X", "s", "v", "P", "*", "D"]
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [3]:
BASELINE_DIR = Path.cwd().parent / "data/baseline-runs"
LOG_DIR = Path.cwd().parent / "experiments/wandb"

In [29]:
def parse_config(config_path):
    config = yaml.safe_load(config_path.read_text())

    flat_config = {}

    def flatten_config(config, prefix=""):
        for k, v in config.items():
            key = prefix + "." + k if prefix else k
            if isinstance(v, dict):
                if len(v) == 2 and "init_args" in v:
                    flat_config[key + f".class_path"] = v["class_path"]
                    flatten_config(v["init_args"], key)
                else:
                    flatten_config(v, key)
            if isinstance(v, (list, tuple)):
                for i, _v in enumerate(v):
                    if isinstance(_v, dict):
                        flatten_config(_v, f"{key}.{i}")
                    else:
                        flat_config[f"{key}.{i}"] = _v
            else:
                flat_config[key] = v
        return config

    flatten_config(config)

    config = json.loads(
        (config_path.parent / "huggingface_checkpoint" / "config.json").read_text()
    )
    for k, v in config.items():
        flat_config[f"model.config.{k}"] = v

    for k, v in list(flat_config.items()):
        if isinstance(v, (list, tuple)):
            try:
                flat_config[k] = ".".join(_v for _v in v if _v)
            except:
                del flat_config[k]
        elif isinstance(v, dict):
            del flat_config[k]

    return flat_config


def load_run(run_file):
    run = pd.read_csv(
        run_file,
        sep=r"\s+",
        header=None,
        names=["query_id", "Q0", "doc_id", "rank", "score", "run_name"],
    )
    run = run.sort_values(["query_id", "score"], ascending=[True, False])
    run["rank"] = run.groupby("query_id")["score"].rank(ascending=False, method="first")
    
    if run_file.parent.parent.name == "tirex" or "rank-" in run_file.parent.name:
        run["run_name"] = run_file.parent.name

    config = {}
    config_path = run_file.parent.parent.parent / "pl_config.yaml"
    if config_path.exists() and "wandb" in str(run_file):
        run_id = run_file.parent.parent.parent.parent.name.split("-")[-1]
        config = parse_config(config_path)
        run["run_name"] = run_id
        
    config["run_name"] = run.iloc[0]["run_name"]
    dataset_name = run_file.name.split(".")[0].strip("_")
    dataset_name = re.sub(r"_+", "_", dataset_name)
    try:
        dataset_id = DASHED_DATASET_MAP[dataset_name]
        first_stage = "tirex"
        if run_file.parent.name in ("bm25", "chatnoir"):
            config["run_name"] = "first_stage"
        if run_file.parent.name == "colbert":
            config["run_name"] = "first_stage"
            first_stage = "colbert"
    except KeyError:
        split = dataset_name.split("_")
        if len(split) == 2:
            first_stage, dataset_name = split
        else:
            raise ValueError(f"Unknown dataset name: {dataset_name}")
        if "tirex" in first_stage:
            first_stage = "tirex"
        dataset_id = DASHED_DATASET_MAP[dataset_name]

    config["dataset"] = dataset_id
    config["first_stage"] = first_stage
    config["base"] = docs_parent_id(dataset_id)
    run = run.astype({"query_id": str, "doc_id": str})
    return run, config


@lru_cache
def load_qrels(dataset):
    dataset = ir_datasets.load(dataset)
    qrels_df = pd.DataFrame(dataset.qrels_iter())
    return qrels_df


In [23]:
metrics = {
    "nDCG@10": {},
    # "NDCG@10_UNJ" : {"removeUnjudged": True},
}

In [31]:
# trectools throws annoying warnings because of pandas
warnings.simplefilter(action="ignore", category=FutureWarning)
values = []

run_files = (
    list(BASELINE_DIR.glob("tirex/*/*.run"))
    + list(BASELINE_DIR.glob("colbert/*.run"))
    + list(BASELINE_DIR.glob("rank-*/*.run"))
    + list(LOG_DIR.glob("**/*.run"))
)

pg = tqdm(run_files)
for run_file in pg:
    if (
        "train" in run_file.name
        or "msmarco-passage-v2" in run_file.name
        or "beir" in run_file.name
        or "dev" in str(run_file)
        or "tripclick" in run_file.name
        or run_file.name.startswith("__")
        or "msmarco-document" in run_file.name
        # or "tirex" in str(run_file)
        # or "orcas" in run_file.name
        # or ("dev" in run_file.name and "small" not in run_file.name)
    ):
        continue
    if "diversity" in run_file.name:
        continue
    pg.set_description(
        str(
            run_file.relative_to(BASELINE_DIR)
            if run_file.is_relative_to(BASELINE_DIR)
            else run_file.relative_to(LOG_DIR)
        )
    )
    run_df, config = load_run(run_file)
    if run_df is None:
        continue
    qrels_df = load_qrels(config["dataset"])
    qrel_qids = set(qrels_df["query_id"])
    run_qids = set(run_df["query_id"].drop_duplicates())
    missing_qids = set(qrel_qids) - set(run_qids)
    if missing_qids:
        # print(f"found {len(missing_qids)} missing qids in {run_file}")
        qrels_df = qrels_df.loc[qrels_df["query_id"].isin(run_qids)]
    _metrics = evaluate_run(run_df, qrels_df, metrics)
    values.append({**config, **_metrics})

results = pd.DataFrame(values)
# per_query_results = pd.concat(per_query_dfs).infer_objects().reset_index(drop=True)

# per_query_results = per_query_results.loc[
#     ~(
#         per_query_results["first_stage"].str.contains("tirex")
#         & per_query_results["dataset"].str.contains("msmarco")
#     )
# ]

# del per_query_dfs

# per_query_results["finetuned"] = per_query_results[
#     "model.model_name_or_path"
# ].str.contains("experiments")

# per_query_results.to_json("per_query_results.jsonl", orient="records", lines=True)

tirex/monot5-large/clueweb12-touche-2020-task-2.run:   1%|          | 8/1055 [00:08<12:35,  1.39it/s]       

/var/tmp/fschlatt/.ir_datasets/touche/2020/task-2/qrels.qrels


tirex/monot5-large/argsme-2020-04-01-touche-2020-task-1.run:   2%|▏         | 16/1055 [00:14<15:42,  1.10it/s]

/var/tmp/fschlatt/.ir_datasets/touche/2020/task-1/qrels.qrels


run-20240424_120056-008cf7pt/files/huggingface_checkpoint/runs/tirex-rerank_gov2-trec-tb-2005.run: 100%|██████████| 1055/1055 [05:11<00:00,  3.39it/s]                          


In [16]:
def run_type(row):
    if not pd.isna(row["model.model_name_or_path"]):
        return "monoELECTRA"
    if "rank-gpt" in row["run_name"]:
        if "turbo" in row["run_name"]:
            return "RankGPT-4 Turbo"
        return "RankGPT-4"
    if "rank-zephyr" in row["run_name"]:
        return "RankZephyr"
    if "monot5" in row["run_name"]:
        return row["run_name"].replace("t5", "T5").replace("-", " ")
    if "monobert" in row["run_name"]:
        return row["run_name"].replace("bert", "BERT").replace("-", " ")
    if "sparse" in row["run_name"]:
        return "Sparse monoMiniLM"
    if "list-in-t5" in row["run_name"]:
        return "LiT5-Distill"
    return "First Stage"


def num_params(row):
    if "monot5-base" in row["run_name"]:
        return "220M"
    if "monot5-large" in row["run_name"]:
        return "770M"
    if "monot5-3b" in row["run_name"]:
        return "3B"
    if "rank-gpt-4" in row["run_name"]:
        return "?"
    if "rank-zephyr" in row["run_name"]:
        return "7B"
    if "list-in-t5" in row["run_name"]:
        return "220M"
    if "sparse" in row["run_name"]:
        return "11M"
    if "monobert-base" in row["run_name"]:
        return "110M"
    if "monobert-large" in row["run_name"]:
        return "340M"
    if (
        not pd.isna(row["model.model_name_or_path"])
        and row["model.config.num_hidden_layers"] == 12
    ):
        return "110M"
    if (
        not pd.isna(row["model.model_name_or_path"])
        and row["model.config.num_hidden_layers"] == 24
    ):
        return "340M"
    return "--"


def first_train_dataset(row):
    if pd.isna(row["data.train_dataset"]):
        return None
    if "google/electra" not in row["model.model_name_or_path"]:
        return "CBv2"
    if "rank-gpt-3-turbo" in row["data.train_dataset"]:
        return "RGPT3.5"
    if "rank-gpt-4-turbo" in row["data.train_dataset"]:
        return "RGPT4-T"
    if "msmarco-passage-train" in row["data.train_dataset"]:
        return "CBv2"
    if "twolar" in row["data.train_dataset"]:
        return "TWOLAR"
    raise ValueError(f"Unknown dataset: {row['data.train_dataset']}")


def second_train_dataset(row):
    if pd.isna(row["data.train_dataset"]):
        return None
    if row["first_train_dataset"] != "CBv2":
        return None
    if "rank-gpt-3-turbo" in row["data.train_dataset"]:
        return "RGPT3.5"
    if "rank-gpt-4-turbo" in row["data.train_dataset"]:
        return "RGPT4-T"
    if "msmarco-passage-train" in row["data.train_dataset"]:
        return "CBv2"
    if "twolar" in row["data.train_dataset"]:
        return "TWOLAR"
    raise ValueError(f"Unknown dataset: {row['data.train_dataset']}")


results["run_type"] = results.apply(run_type, axis=1)
results["first_train_dataset"] = results.apply(first_train_dataset, axis=1)
results["second_train_dataset"] = results.apply(second_train_dataset, axis=1)
results["num_params"] = results.apply(num_params, axis=1)

In [17]:
warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
groupby_columns = list(
    filter(
        lambda x: x not in list(metrics) + ["dataset"],
        results.columns,
    )
)
columns = results.columns.intersection(list(metrics)).tolist()
base_results = (
    results.groupby(groupby_columns, dropna=False)[columns]
    .mean()
    .reset_index()
    .copy()
)
base_results

Unnamed: 0,run_name,first_stage,base,seed_everything,trainer.accelerator,trainer.strategy,trainer.devices,trainer.num_nodes,trainer.precision,trainer.logger.class_path,...,model.config.architectures,model.config.model_type,model.config.save_step,model.config.torch_dtype,model.config.transformers_version,run_type,first_train_dataset,second_train_dataset,num_params,nDCG@10
0,008cf7pt,colbert,msmarco-passage,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,590.0,float32,4.39.3,monoELECTRA,CBv2,TWOLAR,110M,0.761399
1,008cf7pt,tirex,antique,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,590.0,float32,4.39.3,monoELECTRA,CBv2,TWOLAR,110M,0.576209
2,008cf7pt,tirex,argsme/2020-04-01,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,590.0,float32,4.39.3,monoELECTRA,CBv2,TWOLAR,110M,0.305200
3,008cf7pt,tirex,clueweb09/en,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,590.0,float32,4.39.3,monoELECTRA,CBv2,TWOLAR,110M,0.185766
4,008cf7pt,tirex,clueweb12,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,590.0,float32,4.39.3,monoELECTRA,CBv2,TWOLAR,110M,0.292334
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
481,zhub2hz2,tirex,medline/2017,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,189.0,float32,4.39.3,monoELECTRA,CBv2,RGPT4-T,110M,0.309761
482,zhub2hz2,tirex,msmarco-passage,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,189.0,float32,4.39.3,monoELECTRA,CBv2,RGPT4-T,110M,0.719828
483,zhub2hz2,tirex,nfcorpus,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,189.0,float32,4.39.3,monoELECTRA,CBv2,RGPT4-T,110M,0.298179
484,zhub2hz2,tirex,vaswani,42.0,auto,auto,auto,1.0,bf16-mixed,lightning_ir.main.CustomWandbLogger,...,FlashMonoElectraModel,mono-electra,189.0,float32,4.39.3,monoELECTRA,CBv2,RGPT4-T,110M,0.525003


In [21]:
results.filter(like="loss", axis=1)

0
1
2
3
4
...
997
998
999
1000
1001


In [8]:
index_columns = [
    "data.train_dataset",
    "trainer.max_epochs",
    "run_name",
]
columns = ["first_stage", "dataset"]
values = ["nDCG@10"]

filter_series = (
    results["dataset"].str.contains("msmarco") &
    results["first_stage"].isin(["tirex", "colbert"])
    & results["model.config.loss_functions"]
)

table = (
    results.loc[filter_series]
    .pivot(
        index=index_columns,
        columns=columns,
        values=values,
    )
    .droplevel(axis=1, level=0)
)

table.sort_index(ascending=False).dropna().round(3)

KeyError: 'model.config.loss_functions'

In [None]:
index_columns = [
    "data.train_dataset",
    "trainer.max_epochs",
    "run_name",
]
columns = ["first_stage", "base"]
values = ["nDCG@10"]

filter_series = (
    base_results["first_stage"].eq("tirex")
)

table = (
    base_results.loc[filter_series]
    .pivot(
        index=index_columns,
        columns=columns,
        values=values,
    )
    .droplevel(axis=1, level=[0, 1])
)

def hmean(x):
    return 1 / np.mean(1 / x)


def qmean(x):
    return np.sqrt(np.mean(x**2))


def gmean(x):
    return np.exp(np.log(x.prod()) / x.notna().sum(1))


arithmetic_mean = table.mean(axis=1).rename("Arithmetic Mean").to_frame()
# harmonic_mean = table.apply(hmean, axis=1).rename("Harmonic Mean").to_frame()
# quadratic_mean = table.apply(qmean, axis=1).rename("Quadratic Mean").to_frame()
geoemtric_mean = table.apply(qmean, axis=1).rename("Geometric Mean").to_frame()

table = pd.concat(
    [
        table,
        arithmetic_mean,
        # harmonic_mean,
        # quadratic_mean,
        geoemtric_mean,
    ],
    axis=1,
)


table.sort_index(ascending=False).dropna().round(3).sort_values("Geometric Mean", ascending=False)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,antique,argsme/2020-04-01,clueweb09/en,clueweb12,cord19/fulltext,cranfield,disks45/nocr,gov,gov2,medline/2004,medline/2017,msmarco-passage,nfcorpus,vaswani,wapo/v2,Arithmetic Mean,Geometric Mean
data.train_dataset,trainer.max_epochs,run_name,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
,,rank-zephyr,0.534,0.363,0.213,0.303,0.767,0.009,0.556,0.294,0.56,0.413,0.5,0.72,0.314,0.512,0.508,0.438,0.477
data/baseline-runs/rank-zephyr/__colbert-10000-sampled-100__msmarco-passage-train-judged.run,3.0,q7d4u8sj,0.581,0.335,0.191,0.296,0.703,0.007,0.517,0.233,0.53,0.384,0.407,0.706,0.306,0.526,0.476,0.413,0.453
data/baseline-runs/rank-zephyr/__colbert-10000-sampled-100__msmarco-passage-train-judged.run,3.0,lb5en990,0.584,0.333,0.19,0.299,0.707,0.008,0.515,0.236,0.535,0.385,0.391,0.707,0.305,0.53,0.473,0.413,0.453
data/baseline-runs/rank-zephyr/__colbert-10000-sampled-100__msmarco-passage-train-judged.run,3.0,8ntakfv1,0.583,0.328,0.193,0.298,0.706,0.008,0.512,0.234,0.53,0.382,0.391,0.705,0.307,0.528,0.471,0.412,0.452
data/baseline-runs/rank-gpt-3-turbo/msmarco-passage-train-judged.run,1.0,0seh1a6k,0.47,0.42,0.201,0.311,0.665,0.008,0.541,0.256,0.531,0.371,0.448,0.681,0.296,0.5,0.482,0.412,0.447
data/baseline-runs/rank-zephyr/__colbert-10000-sampled-100__msmarco-passage-train-judged.run,1.0,s7n25zed,0.588,0.345,0.199,0.295,0.682,0.009,0.512,0.249,0.521,0.389,0.297,0.711,0.3,0.516,0.449,0.404,0.444
,,monot5-3b,0.543,0.391,0.199,0.279,0.603,0.011,0.569,0.289,0.513,0.395,0.301,0.736,0.324,0.458,0.476,0.406,0.442
,,castorini-list-in-t5-300,0.576,0.394,0.214,0.275,0.686,0.011,0.509,0.266,0.534,0.389,0.278,0.687,0.293,0.429,0.47,0.401,0.439
data/baseline-runs/rank-zephyr/__colbert-10000-sampled-100__msmarco-passage-train-judged.run,1.0,kw5kywed,0.573,0.344,0.148,0.264,0.679,0.007,0.5,0.206,0.511,0.362,0.412,0.693,0.3,0.501,0.449,0.397,0.438
data/baseline-runs/colbert/msmarco-passage-train-judged.run,,monoelectralarge,0.492,0.329,0.202,0.312,0.7,0.008,0.515,0.29,0.522,0.407,0.323,0.685,0.291,0.473,0.46,0.401,0.437
