In [None]:
import altair as alt

from pathlib import Path

from src.data import WandbLoader

In [None]:
experiment_name = "sigir-cmip"
run_name = "pbm"

# W&B entity / project are placeholders: fill them in before running.
loader = WandbLoader("your-entity", "your-project", experiment_name, run_name)

# All figures for this experiment/run are written under this directory.
output_directory = Path(f"figures/{experiment_name}/{run_name}")
output_directory.mkdir(exist_ok=True, parents=True)

In [None]:
# Load the logged run metrics through `loader` and preview the first rows.
df = loader.load_metrics()
df.head(n=5)

# Plot metric bar charts

In [None]:
# Select which runs to plot: the user model, the training policy, and the
# (different) test policy used for the out-of-distribution evaluation.
user_model, train_policy, test_policy = (
    "GradedPBM",
    "NoisyOraclePolicy",
    "UniformPolicy",
)

## Fetch in-distribution PPL and out-of-distribution metrics (PPL, nDCG, CMIP)

In [None]:
columns = ["model", "user_model", "train_policy", "random_state", "test/ppl"]

# In-distribution: runs evaluated on the same policy they were trained on.
ind_df = df.loc[
    df["user_model"].eq(user_model)
    & df["train_policy"].eq(train_policy)
    & df["test_policy"].eq(train_policy),
    columns,
].rename(columns={"test/ppl": "in-distribution PPL"})

In [None]:
columns = ["model", "user_model", "train_policy", "test_policy", "random_state", "test/nDCG", "test/cmi", "test/ppl"]

# Out-of-distribution: same training policy, but evaluated on `test_policy`.
ood_df = df.loc[
    df["user_model"].eq(user_model)
    & df["train_policy"].eq(train_policy)
    & df["test_policy"].eq(test_policy),
    columns,
].rename(
    columns={
        "test/nDCG": "nDCG",
        "test/cmi": "CMIP",
        "test/ppl": "out-of-distribution PPL",
    }
)

In [None]:
# Attach each run's in-distribution PPL to its out-of-distribution metrics.
# validate="one_to_one" raises if a (model, random_state) pair occurs more than
# once on either side. Without it, duplicated runs would silently yield a cross
# product, inflating the per-model sample count behind the CI error bars.
chart_df = ind_df.merge(
    ood_df,
    on=["model", "user_model", "train_policy", "random_state"],
    validate="one_to_one",
)
# Long format: one row per (run, metric), with the number in the "value" column.
chart_df = chart_df.melt(
    ["model", "user_model", "train_policy", "test_policy", "random_state"],
    var_name="metric",
)
chart_df.head()

## Plot

In [None]:
def theme():
    """Altair theme config: serif, normal-weight text for titles and axes.

    Registered below under the name "latex" so the exported PDF matches the
    look of a LaTeX document.
    """
    serif, normal = "serif", "normal"

    title_style = {
        "font": serif,
        "fontWeight": normal,
        "fontSize": 20,
    }
    axis_style = {
        "titleFont": serif,
        "titleFontWeight": normal,
        "titleFontSize": 20,
        "labelFont": serif,
        "labelFontWeight": normal,
        "labelFontSize": 16,
    }
    return {"config": {"title": title_style, "axis": axis_style}}

In [None]:
def plot_metric(df, metric, title, width, height, y_domain, x_title=""):
    """Bar chart of one metric's per-model mean, overlaid with CI error bars.

    Args:
        df: long-format frame with "model", "metric" and "value" columns.
        metric: which entry of the "metric" column to plot.
        title: panel title.
        width: panel width in pixels.
        height: panel height in pixels.
        y_domain: (min, max) of the y scale. It may exclude zero to zoom in on
            small differences between models.
        x_title: x-axis title; the empty string shows no title.

    Returns:
        A layered Altair chart (bars + error bars).
    """
    chart = alt.Chart(
        df[df["metric"] == metric],
        title=title,
        width=width,
        height=height,
    )

    # clip=True: the y-domain usually excludes zero, and without clipping the
    # bars (which start at 0) would be drawn outside the plot area.
    bars = chart.mark_bar(clip=True).encode(
        x=alt.X("model", title=x_title),
        y=alt.Y("mean(value)", title="", scale=alt.Scale(domain=y_domain)),
        color=alt.Color("model", legend=None),
    )

    # extent="ci": Vega-Lite's bootstrapped 95% confidence interval of the mean.
    # The x title is set explicitly so both layers agree on the axis title.
    ci = chart.mark_errorbar(extent="ci").encode(
        x=alt.X("model", title=x_title),
        y=alt.Y("value", title=""),
        strokeWidth=alt.value(3),
    )

    return bars + ci

def plot(df, width, height):
    """Assemble the 2x2 grid of metric panels.

    Top row: in- and out-of-distribution PPL. Bottom row: nDCG and CMIP.
    The arrow in each title marks whether lower (⭣) or higher (⭡) is better.

    Args:
        df: long-format frame with "model", "metric" and "value" columns
            (i.e. `chart_df`, not the raw metrics export).
        width: width of each panel in pixels.
        height: height of each panel in pixels.

    Returns:
        The vertically concatenated chart of the two panel rows.
    """
    def panel(metric, title, y_domain, x_title=""):
        # One panel drawn from the frame passed in, never from a global.
        return plot_metric(
            df,
            metric=metric,
            title=title,
            x_title=x_title,
            width=width,
            height=height,
            y_domain=y_domain,
        )

    # NOTE: the y-domains below are hard-coded for this experiment's values;
    # revisit them when plotting a different experiment or run.
    top = alt.hconcat(
        panel("in-distribution PPL", "in-distribution PPL⭣", (1.195, 1.21)),
        panel("out-of-distribution PPL", "out-of-distribution PPL⭣", (1.18, 1.24)),
    )
    bottom = alt.hconcat(
        panel("nDCG", "nDCG⭡", (0.4, 1), x_title="models"),
        panel("CMIP", "CMIP⭣", (-0.03, 0.2), x_title="models"),
    )

    return alt.vconcat(top, bottom)

alt.themes.register("latex", theme)
alt.themes.enable("latex")
# Plot the long-format merged table; the raw `df` has no "metric"/"value" columns.
chart = plot(chart_df, 175, 150)
chart.save(output_directory / f"{user_model}-{train_policy}-{test_policy}.pdf")
chart