# Install VegaFusion (>= 1.4.0) from conda-forge; it backs the "vegafusion" data
# transformer that is enabled right after the imports. `-y` skips the confirmation prompt.
!mamba install -c conda-forge "vegafusion-python-embed>=1.4.0" "vegafusion>=1.4.0" -y

In [1]:
import glob
import os

import altair as alt
import pandas as pd

from pathlib import Path
from theme import theme

In [2]:
# Use the "vegafusion" data transformer for the Altair charts in this notebook.
alt.data_transformers.enable("vegafusion")
# Register the `theme` object imported from the `theme` module under the name "latex" ...
alt.themes.register("latex", theme)
# ... and activate it. This must run after the registration above.
alt.themes.enable("latex")

ThemeRegistry.enable('latex')

# Load Results

In [3]:
# Root folder of the experiment outputs; `load_data` looks for runs in `directory / <dataset>`.
directory = Path("outputs")

# Dataset identifiers; each one names a sub-folder of `directory`.
data = [
    "baidu",
    "ltr",
    "uva",
]

In [4]:
def run_complete(file: Path):
    """Tell whether ``file`` is the folder of a finished run.

    A run counts as finished when the path is a directory that contains all
    three result files: ``val.parquet``, ``test_click.parquet`` and
    ``test_rel.parquet``.

    Args:
        file: Candidate run directory.

    Returns:
        ``True`` if the directory and every result file exist, else ``False``.
    """
    if not file.is_dir():
        return False

    required_files = ("val.parquet", "test_click.parquet", "test_rel.parquet")
    return all((file / name).exists() for name in required_files)

def parse_model_name(path: Path):
    """Parse the run options encoded in the last component of ``path``.

    The name is expected to look like ``key1=value1,key2=value2,...``.

    Args:
        path: Path of a run directory; only ``path.name`` is inspected.

    Returns:
        Dict mapping each option name to its value. Values are kept as strings.

    Raises:
        ValueError: If a comma-separated part contains no ``=``.
    """
    options = {}

    for option in path.name.split(","):
        # Split on the first "=" only, so a value that itself contains "="
        # is kept intact instead of failing to unpack.
        k, v = option.split("=", 1)
        options[k] = v

    return options

def parse_result_file(run: Path, file: str):
    """Read one result file of a run and tag its rows with the run's settings.

    Args:
        run: Directory of the run; its name encodes the run options.
        file: Name of the parquet file inside ``run`` to load.

    Returns:
        The loaded DataFrame with four extra columns: ``run`` (directory name),
        ``model``, ``data`` and ``random_state`` (option values, as strings).
    """
    settings = parse_model_name(run)

    results = pd.read_parquet(run / file)

    # Constant columns identifying where each row came from.
    tags = {
        "run": run.name,
        "model": settings["model"],
        "data": settings["data"],
        "random_state": settings["random_state"],
    }
    for column, value in tags.items():
        results[column] = value

    return results

def load_data(data, file: str):
    """Load one result file from every completed run of a dataset.

    Args:
        data: Dataset name; runs are searched in ``directory / data``.
        file: Name of the parquet file to read from each run directory.

    Returns:
        A single DataFrame with the rows of all completed runs, in sorted
        run-directory order.

    Raises:
        ValueError: If the dataset folder contains no completed run.
    """
    data_path = directory / data
    # iterdir() yields entries in arbitrary order; sort so the row order of the
    # result is the same on every machine and every re-run.
    runs = sorted(f for f in data_path.iterdir() if run_complete(f))
    print(f"Loaded {len(runs)} run(s) for {data}")

    if not runs:
        # pd.concat([]) would raise the opaque "No objects to concatenate".
        raise ValueError(f"No completed runs found in {data_path}")

    return pd.concat([parse_result_file(run, file) for run in runs])

In [11]:
# Relevance results ("test_rel.parquet") of every completed run, stacked over all datasets.
rel_df = pd.concat(
    [load_data(dataset, "test_rel.parquet") for dataset in data]
)
rel_df.head()

Loaded 45 run(s) for baidu
Loaded 45 run(s) for ltr
Loaded 45 run(s) for uva


Unnamed: 0,dcg@01,dcg@03,dcg@05,dcg@10,frequency_bucket,mrr@10,ndcg@10,query_id,run,model,data,random_state
0,0.0,4.13093,5.291488,10.634423,8,0.5,0.361168,1,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
1,0.0,0.0,0.0,0.0,9,0.0,0.0,2,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
2,7.0,7.63093,8.791489,13.294619,3,1.0,0.360048,3,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
3,0.0,0.0,0.430677,2.133662,8,0.25,0.18052,4,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
4,0.0,0.0,2.70797,2.70797,6,0.2,0.153858,5,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906


# Plot Ranking Results

In [12]:
# Display names of the models shown in the plots. The ranking plot keeps only the
# models listed here, and the insertion order of this dict sets the x-axis order
# (the charts sort by `list(model2name.values())`).
model2name = {
    "naive-pointwise": "Point. Naive",
    "pbm-pointwise": "Point. PBM",
    "regression-em": "RegressionEM",
    "ips-pointwise": "Point. IPS",
    "naive-listwise": "List. Naive",
    "ips-listwise": "List. IPS",
    "dla": "Dual Learning Algorithm",
# NOTE(review): the entry below is disabled; remove it or document why
# pbm-listwise is excluded.
#    "pbm-listwise": "Listwise PBM",
    "pairwise-debias": "Pairwise Debias",
}

# Display names of the datasets. The insertion order sets the facet (column)
# order of the ranking plot.
data2name = {
    "baidu": "Baidu BERT Embeddings",
    "uva": "Our BERT Embeddings",
    "ltr": "LTR Features"
}

# Column of `rel_df` that is averaged and shown in the ranking plot.
metric = "mrr@10"

In [13]:
# One row per (data, model, random_state): the metric averaged over all rows of a run,
# restricted to models with a display name and relabelled for plotting.
metric_df = (
    rel_df
    .groupby(["data", "model", "random_state"])
    .agg({metric: "mean"})
    .reset_index()
    .loc[lambda df: df["model"].isin(list(model2name))]
    .assign(
        model=lambda df: df["model"].map(model2name),
        data=lambda df: df["data"].map(data2name),
    )
)

base = alt.Chart(metric_df).properties(width=300)

# Bar height: mean of the per-run values for each model.
bars = base.mark_bar().encode(
    x=alt.X(
        "model",
        title=None,
        sort=list(model2name.values()),
        axis=alt.Axis(labelAngle=45),
    ),
    y=alt.Y(f"mean({metric})", scale=alt.Scale(zero=False)),
    color=alt.Color("model", title=None, legend=None),
    tooltip=["model", f"mean({metric}):Q"],
)

# Error bars: confidence interval ("ci" extent) over the per-run values of each model.
error = base.mark_errorbar(extent="ci").encode(
    x=alt.X("model", sort=list(model2name.values())),
    y=alt.Y(metric, title=metric.upper()),
    strokeWidth=alt.value(4),
)

# One facet column per dataset, ordered as in `data2name`.
alt.layer(bars, error).facet(
    column=alt.Column("data", title="", sort=list(data2name.values())),
).configure_legend(orient="top")

# Plot Click Prediction

In [14]:
# Click-prediction results ("test_click.parquet") of every completed run, stacked over all datasets.
click_df = pd.concat(
    [load_data(dataset, "test_click.parquet") for dataset in data]
)
click_df.head()

Loaded 45 run(s) for baidu
Loaded 45 run(s) for ltr
Loaded 45 run(s) for uva


Unnamed: 0,BC_dcg@01,BC_dcg@03,BC_dcg@05,BC_dcg@10,BC_mrr@10,BC_ndcg@10,loss,nll,query_id,run,model,data,random_state
0,1.0,1.335689,1.482644,1.552952,1.0,0.994436,0.173028,0.173028,22618,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
1,0.414214,1.078205,1.173218,1.339991,0.333333,0.8066,0.262523,0.262523,572293,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
2,0.414214,1.175104,1.268113,1.430666,0.5,0.861182,0.090926,0.090926,516399,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
3,0.414214,0.57362,0.683886,0.832818,0.0,0.949916,0.081797,0.081797,551606,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906
4,0.090508,0.461607,0.565915,1.047312,0.166667,0.630424,0.047556,0.047556,285250,"data=baidu,es_patience=5,logging=True,max_epoc...",naive-pointwise,baidu,1906


In [15]:
# One row per (data, model, random_state): the mean `nll` over all rows of a run.
metric_df = click_df.groupby(["data", "model", "random_state"]).aggregate({"nll": "mean"}).reset_index()
# Drop runs without an `nll` value and every model whose name contains "list".
metric_df = metric_df[metric_df["nll"].notna() & ~metric_df.model.str.contains("list")]
metric_df["model"] = metric_df["model"].map(model2name)
# Label the datasets with the same display names as the ranking plot.
metric_df["data"] = metric_df["data"].map(data2name)

base = alt.Chart(metric_df, width=150)

# Bar height: mean of the per-run `nll` values for each model.
bars = base.mark_bar().encode(
    x=alt.X("model", title=None, sort=list(model2name.values())).axis(labelAngle=45),
    y=alt.Y("mean(nll)").scale(zero=False),
    color=alt.Color("model", title=None, legend=None),
    tooltip=["model", "mean(nll):Q"],
)

# Error bars: confidence interval ("ci" extent) over the per-run values of each model.
error = base.mark_errorbar(extent="ci").encode(
    x=alt.X("model", title=None, sort=list(model2name.values())).axis(labelAngle=45),
    y=alt.Y("nll"),
    strokeWidth=alt.value(4)
)

# Facet per dataset in the order of `data2name`, so both figures list the
# datasets in the same order.
(bars + error).facet(
    column=alt.Column("data", title="", sort=list(data2name.values())),
).configure_legend(
    orient="top",
)