In [None]:
import json
import re
import sys
import warnings
from collections import defaultdict
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List

import aquarel
import ir_datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import trectools
import yaml
from aquarel import Theme
from scipy.stats import pearsonr, ttest_ind, ttest_rel
# from tqdm.autonotebook import tqdm
from tqdm import tqdm

sys.path.append("../..")
from sparse_cross_encoder.data.ir_dataset_utils import DASHED_DATASET_MAP, get_base

In [None]:
# Alternative: load a user-level theme file instead of the inline theme below:
# theme = aquarel.Theme.from_file(str(Path.home() / "aquarel-theme.json"))
# theme.apply()
theme = Theme(name="theme")
theme = theme.set_grid(draw=True)
theme = theme.set_font(family="serif")
theme.apply()
# Marker sequence indexed by approach position in the line plots.
markers = ["o", "s", "v", "X", "P", "*", "D"]

In [None]:
# All inputs live relative to the repository root, two levels above this notebook.
_REPO_ROOT = Path("../..")
BASELINE_DIR = (_REPO_ROOT / "data" / "baseline-runs").resolve()
LOG_DIR = (_REPO_ROOT / "experiments" / "sparse-cross-encoder").resolve()
ARCHIVE_LOG_DIR = (_REPO_ROOT / "logs" / "archive").resolve()

In [None]:
# p-value threshold -> marker string. Stricter levels that are currently
# disabled: 0.01 -> "**", 0.005 -> "***".
LEVELS = dict([(0.05, "*")])


def mean_run_to_full_runs(run, full_run_df):
    """Return the rows of ``full_run_df`` that match one aggregated result row.

    ``run`` is a Series indexed by column name. Entries whose name contains "@"
    (metric columns such as "NDCG@10") are ignored. Every remaining entry is used
    as an inner-join key against ``full_run_df``.
    """
    key_columns = [name for name in run.index if "@" not in name]
    key_row = run.loc[key_columns].to_frame().T
    return full_run_df.merge(key_row, how="inner", on=key_columns)


def significance(
    per_query_df,
    baseline_run_name,
    comparator_run_names,
    metric,
    datasets=None,
    corpora=None,
    bound=1,
):
    """Paired p-values of each comparator run against a baseline run.

    Rows are restricted to ``datasets`` (matched on the "dataset" column) or, if
    that is None, to ``corpora`` (matched on the "base" column). Per-query scores
    are paired by position. This assumes both runs list the same queries in the
    same order. NOTE(review): confirm this holds for every run.

    With a truthy ``bound``, the p-value is the maximum of two one-sided paired
    t-tests with margin ``bound`` (a TOST equivalence test). A test that emits a
    RuntimeWarning (e.g. degenerate input) is reported as p = 0. With
    ``bound=0``, a plain two-sided paired t-test is used.

    Returns:
        A list aligned with ``comparator_run_names``. An entry is NaN for the
        baseline itself, for NaN names, and for names absent from the data.
        Every entry is NaN when the baseline is NaN or absent.

    Raises:
        ValueError: if neither ``datasets`` nor ``corpora`` is given.
    """
    if datasets is not None:
        per_query_df = per_query_df.loc[per_query_df["dataset"].isin(datasets)]
    elif corpora is not None:
        per_query_df = per_query_df.loc[per_query_df["base"].isin(corpora)]
    else:
        raise ValueError("Either dataset or corpus must be specified")
    if pd.isna(baseline_run_name):
        return [np.nan] * len(comparator_run_names)
    scores = per_query_df.set_index("run_name")[metric]
    if baseline_run_name not in scores.index:
        return [np.nan] * len(comparator_run_names)
    # Double brackets keep a Series even when a run has a single query.
    baseline = scores.loc[[baseline_run_name]].fillna(0).values
    results = []
    for comparator_run_name in comparator_run_names:
        if (
            comparator_run_name == baseline_run_name
            or pd.isna(comparator_run_name)
            or comparator_run_name not in scores.index
        ):
            results.append(np.nan)
            continue
        comparator = scores.loc[[comparator_run_name]].fillna(0).values
        if bound:
            # Escalate RuntimeWarnings only inside this block. catch_warnings
            # restores the caller's filters afterwards; resetwarnings() would
            # discard them, including filters installed elsewhere.
            with warnings.catch_warnings():
                warnings.simplefilter("error", RuntimeWarning)
                try:
                    p_greater = ttest_rel(
                        baseline + bound, comparator, alternative="greater"
                    )[1]
                    p_less = ttest_rel(
                        baseline - bound, comparator, alternative="less"
                    )[1]
                    p_value = max(p_greater, p_less)
                except RuntimeWarning:
                    p_value = 0
        else:
            p_value = ttest_rel(baseline, comparator)[1]
        results.append(p_value)
    return results

In [None]:
def parse_config(config_path):
    """Load a YAML config and flatten it into a dict keyed by dotted paths.

    Nested mappings become "outer.inner" keys. A mapping with exactly two
    entries, one of which is "init_args", is replaced by its "init_args" mapping,
    and the other entry is dropped. Leaf values that define ``__iter__`` are
    removed from the result. This covers lists and also strings, because
    ``str`` is iterable.
    """

    def _walk(node, prefix):
        # Depth-first, in mapping order, yielding (dotted_key, leaf_value).
        for key, value in node.items():
            dotted = prefix + "." + key if prefix else key
            if not isinstance(value, dict):
                yield dotted, value
            elif len(value) == 2 and "init_args" in value:
                yield from _walk(value["init_args"], dotted)
            else:
                yield from _walk(value, dotted)

    raw = yaml.safe_load(config_path.read_text())
    flattened = dict(_walk(raw, ""))
    return {
        key: value
        for key, value in flattened.items()
        if not hasattr(value, "__iter__")
    }


def load_run(run_file):
    """Load a whitespace-separated TREC run file and describe where it came from.

    Returns ``(run, config)``:

    - ``run`` is a DataFrame with columns query_id, Q0, doc_id, rank, score and
      run_name. query_id and doc_id are strings.
    - ``config`` is a dict. It holds the flattened experiment config (when
      ``pl_config.yaml`` exists two directories up) plus "dataset", "base" and
      "run_name".

    The run name is replaced by "bm25" for BM25 runs. For runs with a config,
    it is replaced by the experiment id.
    """
    run = pd.read_csv(
        run_file,
        sep=r"\s+",
        header=None,
        names=["query_id", "Q0", "doc_id", "rank", "score", "run_name"],
        # Read ids as strings so numeric-looking ids keep leading zeros and match
        # the string ids used by the qrels.
        dtype={"query_id": str, "doc_id": str},
    )
    config = {}
    config_path = run_file.parent.parent / "pl_config.yaml"
    # str() guards against read_csv parsing a purely numeric run name.
    if "bm25" in str(run.iloc[0]["run_name"]).lower():
        run["run_name"] = "bm25"
    if config_path.exists():
        # NOTE(review): assumes the experiment directory name has a 20-character
        # prefix before the run id. Confirm against the logging layout.
        run_id = run_file.parent.parent.parent.name[20:]
        config = parse_config(config_path)
        run["run_name"] = run_id
    # if arguana or quora, remove queries from run
    # https://twitter.com/nandan__thakur/status/1603920955679551488
    # https://github.com/beir-cellar/beir/blob/bc4d2b50b0059c0895282b609ff30b0530ed6648/beir/retrieval/evaluation.py#L49
    dataset = DASHED_DATASET_MAP[run_file.stem]
    if dataset.startswith(("beir/arguana", "beir/quora")):
        # .copy() so the rank update below writes to an owned frame, not to a
        # slice of the unfiltered one (SettingWithCopyWarning).
        run = run.loc[run["query_id"] != run["doc_id"]].copy()
        run["rank"] = run.groupby("query_id")["score"].rank(ascending=False).astype(int)
    config["dataset"] = dataset
    config["base"] = get_base(dataset)
    if "medline" in config["base"]:
        config["base"] = "medline"
    config["run_name"] = run.iloc[0]["run_name"]
    return run, config


def load_qrels(dataset):
    """Build a ``trectools.TrecQrel`` from the qrels of an ir_datasets dataset.

    The ir_datasets column names are renamed to the ones trectools expects
    (query, docid, rel, q0).
    """
    trectools_columns = {
        "query_id": "query",
        "doc_id": "docid",
        "relevance": "rel",
        "iteration": "q0",
    }
    qrels = trectools.TrecQrel()
    raw = pd.DataFrame(ir_datasets.load(dataset).qrels_iter())
    qrels.qrels_data = raw.rename(columns=trectools_columns)
    return qrels


def evaluate_run(run_df, qrels, metrics, metric_kwargs):
    """Compute per-query metrics for one run with trectools.

    Each entry of ``metrics`` has the form "<family>@<depth>[_suffix]", where
    family is NDCG, recip_rank or UNJ. Only the digits before the first "_" form
    the depth. The matching entry of ``metric_kwargs`` is forwarded to the
    trectools getter.

    Returns:
        One column per metric, named by the full metric string, with NaN
        replaced by 0. Columns are concatenated side by side on trectools'
        per-query index.
    """
    getter_names = {
        "NDCG": "get_ndcg",
        "recip_rank": "get_reciprocal_rank",
        "UNJ": "get_unjudged",
    }
    trec_run = trectools.TrecRun()
    trec_run.run_data = run_df.rename(
        columns={"query_id": "query", "Q0": "q0", "doc_id": "docid", "run_name": "system"}
    )
    evaluator = trectools.TrecEval(trec_run, qrels)
    per_metric = []
    for name, extra in zip(metrics, metric_kwargs):
        family, cutoff_spec = name.split("@")
        cutoff = int(cutoff_spec.split("_")[0])
        getter = getattr(evaluator, getter_names[family])
        scores = getter(cutoff, per_query=True, **extra)
        per_metric.append(scores.rename(columns=lambda _: name).fillna(0))
    return pd.concat(per_metric, axis=1)


def concat_df_config(df, config):
    """Append every ``config`` entry to ``df`` as a constant column.

    Returns a new DataFrame with object-dtype columns. It contains the columns
    of ``df``, followed by one column per ``config`` key (in insertion order)
    that repeats that key's value on every row.
    """
    n_rows, n_cols = df.shape
    values = np.empty((n_rows, n_cols + len(config)), dtype="object")
    values[:, :n_cols] = df.values
    # dtype=object keeps each value's own type. Without it, numpy coerces a
    # mix such as [64, False, "msmarco-passage"] to strings ("64", "False", ...).
    values[:, n_cols:] = np.array(list(config.values()), dtype=object)
    columns = list(df.columns) + list(config.keys())
    return pd.DataFrame(values, columns=columns)

def get_query_id_doc_id_pairs(df):
    """Series of ``(query_id, doc_id)`` tuples, one per row, aligned with ``df.index``."""
    pairs = list(zip(df["query_id"], df["doc_id"]))
    return pd.Series(pairs, index=df.index)

def clean_run_df(df, bm25_df):
    """Keep only rows of ``df`` whose (query_id, doc_id) pair also occurs in ``bm25_df``.

    The "rank" column is recomputed per query from descending score. It uses
    pandas' default (average) tie handling, truncated to int. Returns a copy.
    """

    def _pairs(frame):
        # One (query_id, doc_id) tuple per row, aligned with the frame's index.
        return pd.Series(list(zip(frame["query_id"], frame["doc_id"])), index=frame.index)

    keep = _pairs(df).isin(_pairs(bm25_df))
    cleaned = df.loc[keep].copy()
    by_query = cleaned.groupby("query_id")["score"]
    cleaned["rank"] = by_query.rank(ascending=False).astype(int)
    return cleaned

def parse_run_type(df):
    """Label each row of ``df`` with a coarse model family.

    Rules, applied in order (later rules win):

    1. "full attention" if ``model.config.query_doc_attention`` is True,
       otherwise "sparse cross encoder".
    2. "qds transformer" if ``model.qds_transformer`` is True.
    3. "bm25" / "colbert" if the lower-cased run_name contains that word.

    Missing columns or missing values count as False.
    """

    def _flag(column):
        # .eq(True) gives a clean boolean mask for bool, object (True/False/None)
        # and float columns, without relying on fillna downcasting object to bool.
        if column not in df.columns:
            return pd.Series(False, index=df.index)
        return df[column].eq(True)

    query_doc_attention = _flag("model.config.query_doc_attention")
    run_names = df["run_name"].astype(str).str.lower()
    run_type = pd.Series(index=df.index, dtype="object")
    run_type.loc[query_doc_attention] = "full attention"
    run_type.loc[~query_doc_attention] = "sparse cross encoder"
    # Applied after the attention-based rule so QDS runs are not relabelled as
    # "sparse cross encoder".
    run_type.loc[_flag("model.qds_transformer")] = "qds transformer"
    run_type.loc[run_names.str.contains("bm25", na=False)] = "bm25"
    run_type.loc[run_names.str.contains("colbert", na=False)] = "colbert"
    return run_type


In [None]:
depth = 100
# (metric name, extra keyword arguments forwarded to the trectools getter).
# The "_UNJ" variant passes removeUnjudged=True.
_metric_specs = [
    ("NDCG@10", {}),
    ("NDCG@10_UNJ", {"removeUnjudged": True}),
    (f"recip_rank@{depth}", {}),
    ("UNJ@10", {}),
]
metrics = [name for name, _ in _metric_specs]
metric_kwargs = [kwargs for _, kwargs in _metric_specs]

In [None]:
# Evaluate every TREC DL run file per query, attach its config, add pooled
# "passage"/"document" pseudo-datasets, and cache the result as JSON.
per_query_dfs = []

run_files = (
    list(BASELINE_DIR.glob("bm25/*.run"))
    # + list(BASELINE_DIR.glob("colbert/*.run"))
    + list(LOG_DIR.glob("**/*.run"))
    # + list(ARCHIVE_LOG_DIR.glob("run-*/**/*.run"))
)
bm25_runs = {}
# Qrels depend only on the dataset, so load each dataset's qrels once instead of
# once per run file. Assumes trectools.TrecEval does not mutate the qrels object.
qrels_by_dataset = {}

pg = tqdm(run_files)
for run_file in pg:
    pg.set_description(run_file.name)
    if "trec-dl" not in str(run_file):
        continue
    # if (
    #     "beir" in str(run_file)
    #     or "tripclick" in str(run_file)
    #     or "orcas" in str(run_file)
    #     or "msmarco-passage-v2" in str(run_file)
    #     or "train" in str(run_file)
    #     or "dev" in str(run_file)
    # ):
    #     continue
    run_df, config = load_run(run_file)
    # if "bm25" in str(run_file):
    #     run_df = run_df.groupby("query_id").head(depth)
    #     bm25_runs[config["dataset"]] = run_df
    # else:
    #     run_df = clean_run_df(run_df, bm25_runs[config["dataset"]])
    dataset = config["dataset"]
    if dataset not in qrels_by_dataset:
        qrels_by_dataset[dataset] = load_qrels(dataset)
    metric_df = evaluate_run(run_df, qrels_by_dataset[dataset], metrics, metric_kwargs)
    eval_df = concat_df_config(metric_df.reset_index().astype({"query": str}), config)
    per_query_dfs.append(eval_df)

per_query_results = pd.concat(per_query_dfs).infer_objects().reset_index(drop=True)
per_query_results["run_type"] = parse_run_type(per_query_results)


def _pool_datasets(frame, keyword):
    """Copy rows whose dataset name contains ``keyword``, relabelled as one pooled dataset."""
    # .loc[...].copy() (not .copy().loc[...]) so the assignments below act on
    # an owned frame and do not trigger SettingWithCopyWarning.
    pooled = frame.loc[frame["dataset"].str.contains(keyword)].copy()
    pooled["base"] = keyword
    pooled["dataset"] = keyword
    return pooled


passage_query_results = _pool_datasets(per_query_results, "passage")
document_query_results = _pool_datasets(per_query_results, "document")

per_query_results = pd.concat(
    [per_query_results, passage_query_results, document_query_results]
)

del per_query_dfs
# per_query_results.loc[
#     per_query_results["dataset"].str.startswith("beir/cqadupstack"), "dataset"
# ] = "beir/cqadupstack"
per_query_results.reset_index().to_json("per_query_results.json")

In [None]:
# Reload the cached per-query evaluation and drop the "index" column that
# reset_index() added before it was saved.
per_query_results = (
    pd.read_json("per_query_results.json")
    .reset_index(drop=True)
    .drop(columns="index")
)

In [None]:
warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)


def _mean_over(excluded):
    """Group per_query_results by every column not in ``excluded``; return (keys, metric means)."""
    keys = [column for column in per_query_results.columns if column not in excluded]
    grouped = per_query_results.groupby(keys, dropna=False)[metrics].mean()
    return keys, grouped.reset_index().copy()


# Per-dataset means, then per-corpus ("base") means across datasets.
groupby_columns, results = _mean_over(metrics + ["query"])
groupby_columns, base_results = _mean_over(metrics + ["dataset", "query"])
groupby_columns = [
    column
    for column in per_query_results.columns
    if column not in metrics + ["dataset", "query", "base"]
]
base_results.head(5)

In [None]:
_attention_index = [
    "model.config.attention_window_size",
    "model.config.query_cls_attention",
    "model.config.cls_query_attention",
    "model.config.doc_query_attention",
    "run_name",
]
# Per-query NDCG@10 per attention configuration: one column per (query, dataset).
per_query_results.pivot(
    index=_attention_index, columns=["query", "dataset"], values=["NDCG@10"]
)

In [None]:
# Mean NDCG@10 (x100, one decimal) per dataset, by window size and query->CLS attention.
_ndcg_by_window = results.pivot(
    index=[
        "model.config.attention_window_size",
        "model.config.query_cls_attention",
        "run_name",
    ],
    columns=["dataset"],
    values=["NDCG@10"],
)
(_ndcg_by_window * 100).round(1)

In [None]:
# Other useful index levels: model.config.max_position_embeddings,
# model.config.query_cls_attention, model.config.doc_query_attention.
_ndcg_by_run_type = results.pivot(
    index=["run_type", "model.config.attention_window_size", "run_name"],
    columns=["dataset"],
    values=["NDCG@10"],
)
# Datasets as rows, runs as columns; NDCG@10 x100 with one decimal.
(_ndcg_by_run_type * 100).round(1).T

In [None]:
# Per-corpus mean NDCG@10 (x100, one decimal). Scale before rounding, as in the
# other result cells, to avoid artifacts such as 45.300000000000004.
base_results.pivot(
    index=[
        "run_type",
        "model.config.max_position_embeddings",
        # "model.config.query_cls_attention",
        # "model.config.doc_query_attention",
        "model.config.attention_window_size",
        "run_name",
    ],
    columns="base",
    values=["NDCG@10"],
).multiply(100).round(1)

In [None]:
# Build the NDCG@10 table (rows: run type x window size, columns: datasets or
# corpora) and run TOST / paired t-tests of every run against a baseline.
base = False
# Full-attention run ids used as the significance baseline, by collection type.
BASELINE_RUN_IDS = {"passage": "g94mcy7f", "document": "3u6n318u"}
# TOST equivalence margin on the NDCG@10 scale (0.02 = 2 points in the x100 tables).
TOST_BOUND = 0.02

table = (base_results if base else results).copy()
# Runs without a window size (NaN) are labelled with an infinite window.
table["model.config.attention_window_size"] = table[
    "model.config.attention_window_size"
].fillna(float("inf"))
table = table.pivot(
    index=[
        "run_type",
        "model.config.attention_window_size",
    ],
    columns=["base" if base else "dataset"],
    values=["NDCG@10", "run_name"],
)
# table = table.drop("colbert", level="run_type")
# One entry per index level: run type ascending, window size descending. A
# longer tuple is not validated, and depending on the pandas version its
# entries may be matched to the levels differently.
table = table.sort_index(ascending=(True, False))

run_name_table = table.loc[:, "run_name"]
table = table.loc[:, "NDCG@10"]

key = "corpora" if base else "datasets"
tost = []
ttest = []
for corpus in table.columns:
    baseline = BASELINE_RUN_IDS["passage" if "passage" in corpus else "document"]
    run_names = run_name_table.loc[:, corpus].values.tolist()
    if corpus == "out-of-domain":
        # Pool every other column except in-domain MS MARCO passage. Compare
        # against the columns (datasets/corpora), not the run-type index.
        corpora = [
            column
            for column in table.columns
            if column not in ("msmarco-passage", "out-of-domain")
        ]
    else:
        corpora = [corpus]
    kwargs = {key: corpora}
    tost.append(
        significance(
            per_query_results,
            baseline,
            run_names,
            "NDCG@10",
            bound=TOST_BOUND,
            **kwargs,
        )
    )
    ttest.append(
        significance(
            per_query_results,
            baseline,
            run_names,
            "NDCG@10",
            bound=0,
            **kwargs,
        )
    )
tost_df = pd.DataFrame(tost, index=table.columns, columns=table.index)
# ttest_df = pd.DataFrame(ttest, index=table.columns, columns=table.index)


def format_row(series, tost_df):
    """Format one row of NDCG scores (one dataset) as LaTeX cell strings.

    Args:
        series: scores for one dataset. ``series.name`` is the dataset label
            and ``series.index`` holds the compared runs. Values are multiplied
            by 100, NaN becomes 0, and the result is rounded to one decimal.
        tost_df: TOST p-values with datasets as rows and runs as columns. Pass
            ``None`` or ``False`` to omit the equivalence markers.

    Returns:
        A Series of strings with ``series.index``. The row maximum is wrapped in
        ``\\textbf{}``. Cells whose p-value is below ``0.05 / (n - 1)`` get
        ``\\kernSigtost`` appended (Bonferroni, presumably over the n - 1
        comparisons against the baseline).
    """
    rounded = (series * 100).fillna(0).round(1)
    max_val = rounded.max()
    if tost_df is None or tost_df is False:
        tost_row = [np.nan] * len(rounded)
    else:
        tost_row = tost_df.loc[series.name, series.index]
    # max(..., 1) avoids dividing by zero for a single-entry row.
    threshold = 0.05 / max(series.shape[0] - 1, 1)
    out_values = []
    for tost_p_val, val in zip(tost_row, rounded):
        out_val = f"{val:.1f}"
        # NOTE(review): values below 1 get an extra leading "0" ("0.5" -> "00.5"),
        # presumably for column alignment. Confirm this is intended.
        if out_val.startswith("0."):
            out_val = "0" + out_val
        if val == max_val:
            out_val = "\\textbf{" + out_val + "}"
        # NaN p-values compare False, so they never get a marker.
        if tost_p_val < threshold:
            out_val = out_val + "\\kernSigtost"
        out_values.append(out_val)
    return pd.Series(out_values, index=series.index)


# Keep the three model families in a fixed order, then put datasets on the rows.
# NOTE: this rebinds `table`, so re-running the cell without the cell above fails.
_run_type_order = ["full attention", "sparse cross encoder", "qds transformer"]
table = table.loc[_run_type_order].T

(table.fillna(0) * 100).round(1)

In [None]:
def _format_rows_like(keyword):
    """Format every table row whose dataset label contains ``keyword``."""
    matching = table.filter(like=keyword, axis=0)
    return matching.apply(lambda row: format_row(row, tost_df), axis=1)


# Row order in the LaTeX table: passage datasets, pooled passage, then documents.
_row_order = [
    "msmarco-passage/trec-dl-2019/judged",
    "msmarco-passage/trec-dl-2020/judged",
    "msmarco-passage-v2/trec-dl-2021/judged",
    "msmarco-passage-v2/trec-dl-2022/judged",
    "passage",
    "msmarco-document/trec-dl-2019/judged",
    "msmarco-document/trec-dl-2020/judged",
    "msmarco-document-v2/trec-dl-2021/judged",
    "msmarco-document-v2/trec-dl-2022/judged",
    "document",
]
pretty_table = pd.concat(
    [_format_rows_like("passage"), _format_rows_like("document")]
).loc[_row_order]
print(pretty_table.to_latex(escape=False))

In [None]:
idx = pd.IndexSlice
# NOTE(review): this cell looks stale relative to the cells above. By now
# `table` has been transposed (rows = datasets, columns = run type x window
# size) and run ids were split off into `run_name_table`, so a 4-level row
# IndexSlice over run ids is unlikely to match. Passing `False` as the second
# argument also requires a format_row that accepts a non-DataFrame `tost_df`;
# check format_row's definition. Confirm before running.
# (table.loc[idx[:, :, :, ["g94mcy7f", "p4v9923r", "cg4a0ke7", "uatdxcst", "zzcodw0f", "y5pcbt5n"]],]
#     .transpose()
#     .filter(like="document", axis=0))
print(
    table.loc[idx[:, :, :, ["bm25", "0gyv091s", "g94mcy7f", "p4v9923r", "cg4a0ke7", "tmj1empz", "y5pcbt5n"]],]
    .transpose()
    .filter(like="document", axis=0)
    .apply(lambda x: (format_row(x, False)), axis=1)
    .to_latex(escape=False)
)

In [None]:
# NDCG@10 versus attention window size, per approach. Each approach's
# no-window (infinite window) result is drawn as a horizontal reference line.
# NOTE(review): this cell expects `table` rows "msmarco-passage"/"out-of-domain",
# a "Baseline" column and a "model_name" column level. The table built above has
# TREC DL dataset rows and run-type/window-size columns. Confirm which table
# this cell is meant to plot.
datasets = ["msmarco-passage", "out-of-domain"]
dataset_name_map = {"msmarco-passage": "MS MARCO", "out-of-domain": "Out-of-Domain"}
NO_WINDOW = float("inf")
fig, ax = plt.subplots(1, 2, figsize=(6, 2.9))
plot_data = table.loc[datasets].drop("Baseline", axis=1)
approaches = plot_data.columns.get_level_values("model_name").unique().tolist()
approaches.remove("CLS Interaction")

colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

for ax_idx, dataset in enumerate(datasets):
    axis = ax[ax_idx]
    axis.set_axisbelow(True)
    dataset_data = plot_data.loc[dataset].droplevel("run_name")
    axis.set_title(dataset_name_map[dataset])
    axis.set_xlabel("Attention Window Size")
    axis.set_ylabel("NDCG@10")
    finite_window_sizes = set(
        dataset_data.index.get_level_values("model.config.attention_window_size").values
    ) - {NO_WINDOW}
    min_window_size = min(finite_window_sizes)
    max_window_size = max(finite_window_sizes)
    for approach_idx, approach in enumerate(approaches):
        color = colors[approach_idx]
        approach_data = dataset_data.loc[approach]
        no_window = approach_data.loc[NO_WINDOW]
        windowed = approach_data.loc[
            approach_data.index.get_level_values("model.config.attention_window_size")
            != NO_WINDOW
        ]
        baseline_label = approach.replace("/ Longformer", "")
        if baseline_label == "Independent Query":
            # Raw string: "\i" is an invalid escape sequence in a normal literal.
            baseline_label = r"Independent Query (w=$\infty$)"
        # Draw each curve once, without the infinite-window point (it has no
        # x position). zorder keeps it above the dashed reference line.
        axis.plot(
            windowed.index.values,
            windowed.values,
            label=approach.replace("Full Attention / ", "") if ax_idx == 0 else None,
            marker=markers[approach_idx],
            color=color,
            zorder=3,
        )
        axis.plot(
            [min_window_size, max_window_size],
            [no_window, no_window],
            color=color,
            label=baseline_label if ax_idx == 0 else None,
            linestyle="--" if approach_idx == 0 else ":",
        )
# fig.legend(loc="center", bbox_to_anchor=(1.15, 0.5))
fig.legend(ncols=2, bbox_to_anchor=(0.5, -0.05), loc="center")
fig.tight_layout()
fig.savefig("domain-effectiveness.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Attention-related config flags of two specific runs, one row per distinct configuration.
_selected_runs = base_results["run_name"].isin(("hfclj9k8", "vtwk7mqt"))
(
    base_results.loc[_selected_runs]
    .filter(regex=r"run_name|model\.config\..*_attention$", axis=1)
    .drop_duplicates()
)

In [None]:
# NDCG@10 with unjudged documents removed, per corpus, for every attention configuration.
_unj_by_base = base_results.pivot(
    index=[
        "model.config.attention_window_size",
        "model.config.query_cls_attention",
        "model.config.cls_query_attention",
        "model.config.doc_query_attention",
        "model.config.query_doc_embedding_attention",
        "run_name",
    ],
    columns="base",
    values=["NDCG@10_UNJ"],
)
# Drop the in-domain corpus. Select on the "base" column level: the columns
# are (metric, base) tuples, so filter(items=<base names>) matches nothing,
# and set("msmarco-passage") would be a set of characters.
_unj_by_base.loc[
    :, _unj_by_base.columns.get_level_values("base") != "msmarco-passage"
]