In [None]:
import altair as alt

from pathlib import Path

from src.data import WandbLoader

# Switch off Altair's limit on the number of data rows a chart may carry, so
# the DataFrames plotted further down are not rejected for their size.
alt.data_transformers.disable_max_rows()

In [None]:
# Experiment / run identifiers: passed to the loader and reused for the
# name of the figure folder.
experiment_name = "sigir-cmip"
run_name = "visual-example"

# The first two arguments are placeholders (presumably the W&B entity and
# project -- replace them with your own before running).
loader = WandbLoader("your-entity", "your-project", experiment_name, run_name)

# Folder for exported charts: figures/<experiment>/<run>, created together
# with any missing parents; an already existing folder is fine.
output_directory = Path(f"figures/{experiment_name}/{run_name}")
output_directory.mkdir(exist_ok=True, parents=True)

# Plot visual example for CMIP

In [None]:
# Setting of the single example shown in the figure. These values are used
# both to request the policy scores and to filter the metrics table.
random_state = 43670
temperature = 1

user_model = "GradedPBM"
train_policy = "NoisyOraclePolicy"
test_policy = "UniformPolicy"

# Models compared in the figure.
models = ["DCTR", "PBM"]

In [None]:
# Fetch the policy scores for the configured setting and the metrics table.
policy_df = loader.load_policy_scores(
    user_model,
    train_policy,
    test_policy,
    models,
    temperature,
)
metric_df = loader.load_metrics()

# Preview the metrics table as loaded.
metric_df.head()

In [None]:
# Policy scores: keep only the rows of the seed chosen for this example.
policy_df = policy_df.loc[policy_df["random_state"].eq(random_state)]

# Metrics: narrow down to the same setting (every condition must hold).
# Note that no seed filter is applied here -- all random states are kept.
metric_df = metric_df.loc[
    metric_df["user_model"].eq(user_model)
    & metric_df["train_policy"].eq(train_policy)
    & metric_df["test_policy"].eq(test_policy)
    & metric_df["model"].isin(models)
    & metric_df["temperature"].eq(temperature)
]

## Plot

In [None]:
def theme():
    """Build the Altair theme used for the figure (serif, normal weight).

    Returns:
        dict: A theme specification with a single top-level ``"config"``
        entry that styles chart titles, axes, column headers and text marks.
        A new dictionary is built on every call.
    """
    # Axes and column headers use exactly the same title/label typography.
    title_and_label_style = {
        "titleFont": "serif",
        "titleFontWeight": "normal",
        "titleFontSize": 16,
        "labelFont": "serif",
        "labelFontWeight": "normal",
        "labelFontSize": 16,
    }
    chart_title_style = {
        "font": "serif",
        "fontWeight": "normal",
        "fontSize": 16,
        "dx": 5,
    }
    text_mark_style = {"font": "serif", "fontSize": 14}

    config = {
        "title": chart_title_style,
        # Separate copies so the two entries never alias each other.
        "axis": dict(title_and_label_style),
        "headerColumn": dict(title_and_label_style),
        "text": text_mark_style,
    }
    return {"config": config}

In [None]:
# Give the two metrics their display names, then reshape to long form: one
# row per original row and metric, with the metric name in "metric" and the
# score in melt's default "value" column.
# Caution: this overwrites metric_df, so the cell cannot be run twice in a
# row -- after the first run the "CMIP"/"nDCG" columns no longer exist.
metric_df = (
    metric_df
    .rename(columns={"test/nDCG": "nDCG", "test/cmi": "CMIP"})
    .melt(
        id_vars=["model", "user_model", "train_policy", "test_policy", "random_state"],
        value_vars=["CMIP", "nDCG"],
        var_name="metric",
    )
)

In [None]:
def plot_policy(policy_df, model: str, width: int, height: int, is_first=False):
    """Scatter one model's predictions against the logging policy.

    The points are split into one column per value of ``y``.

    Args:
        policy_df: DataFrame with at least the columns ``model``, ``y``,
            ``y_logging_policy`` and ``y_predict``.
        model: Value of the ``model`` column to plot; also the y-axis title.
        width: Chart width handed to ``alt.Chart``.
        height: Chart height handed to ``alt.Chart``.
        is_first: Whether this is the top row of the figure. The top row
            shows the column header ("True relevance" plus its labels) and
            omits the x-axis title; any other row hides the header and
            titles the x-axis "Logging policy".

    Returns:
        The column-faceted Altair scatter chart.
    """
    # Build the chart under a name and return it once. The original had a
    # second, unreachable `return scatter` that referred to an undefined name.
    scatter = alt.Chart(
        policy_df[policy_df["model"] == model],
        width=width,
        height=height,
    ).mark_point(opacity=0.5).encode(
        column=alt.Column(
            "y:O",
            title="True relevance" if is_first else None,
            spacing=5,
            header=alt.Header(labels=is_first, titlePadding=0, labelPadding=5),
        ),
        x=alt.X(
            "y_logging_policy",
            axis=alt.Axis(values=[0, 2, 4, 6]),
            title="Logging policy" if not is_first else None,
        ),
        y=alt.Y("y_predict", title=model),
    )

    return scatter

def plot_metric(metric_df, model, metric, title, width, height, domain, text_spacing=-10):
    """Draw the mean of one metric for one model as a labelled bar.

    Args:
        metric_df: Long-form DataFrame with at least the columns ``model``,
            ``metric`` and ``value``.
        model: Value of the ``model`` column to select.
        metric: Value of the ``metric`` column to select.
        title: Chart title.
        width: Chart width handed to ``alt.Chart``.
        height: Chart height handed to ``alt.Chart``.
        domain: Fixed domain of the y scale.
        text_spacing: ``dy`` offset of the value label; the negative default
            moves the label upwards.

    Returns:
        A layered Altair chart: the bar plus a text mark showing the mean
        with three decimals.
    """
    source = metric_df[
        (metric_df["model"] == model)
        & (metric_df["metric"] == metric)
    ].copy()

    # One shared base chart. The original built `base` but never used it and
    # repeated the identical constructor call for the bar.
    base = alt.Chart(source, title=title, width=width, height=height)

    bar = base.mark_bar().encode(
        y=alt.Y(
            "mean(value)",
            title="",
            # Use the given domain as is: no rounding, no forced zero.
            scale=alt.Scale(domain=domain, nice=False, zero=False),
        ),
        color=alt.Color("metric", legend=None),
    )

    # The label is derived from the bar, so it keeps the bar's encodings and
    # only adds the formatted mean as text.
    text = bar.mark_text(
        align="center",
        baseline="middle",
        dy=text_spacing,
    ).encode(
        text=alt.Text("mean(value):Q", format=",.3f")
    )

    return bar + text

def plot(metric_df, policy_df, scatter_width=100, bar_width=50, height=100):
    """Assemble the two-row figure: DCTR on top, PBM below.

    Each row places the policy scatter next to a CMIP bar and an nDCG bar.

    Args:
        metric_df: Long-form metrics, forwarded to ``plot_metric``.
        policy_df: Policy scores, forwarded to ``plot_policy``.
        scatter_width: Width passed to the scatter charts.
        bar_width: Width passed to the bar charts.
        height: Height passed to every chart.

    Returns:
        The vertical concatenation of both rows.
    """
    # Both rows share the y-domains so the bars are directly comparable.
    cmip_domain = (0, 0.22)
    ndcg_domain = (0, 1.2)

    # Top row: carries the column header of the scatter facets.
    dctr_scatter = plot_policy(policy_df, "DCTR", scatter_width, height, is_first=True)
    dctr_cmip = plot_metric(metric_df, "DCTR", "CMIP", "CMIP⭣", bar_width, height, cmip_domain)
    dctr_ndcg = plot_metric(metric_df, "DCTR", "nDCG", "nDCG⭡", bar_width, height, ndcg_domain)
    top_row = dctr_scatter | dctr_cmip | dctr_ndcg

    # Bottom row: the CMIP label uses a larger upward offset than the default.
    pbm_scatter = plot_policy(policy_df, "PBM", scatter_width, height)
    pbm_cmip = plot_metric(
        metric_df, "PBM", "CMIP", "CMIP⭣", bar_width, height, cmip_domain, text_spacing=-15
    )
    pbm_ndcg = plot_metric(metric_df, "PBM", "nDCG", "nDCG⭡", bar_width, height, ndcg_domain)
    bottom_row = pbm_scatter | pbm_cmip | pbm_ndcg

    # Negative spacing pulls the two rows slightly together.
    return alt.vconcat(top_row, bottom_row, spacing=-5)

# Register the custom theme under the name "latex" and switch to it.
alt.themes.register("latex", theme)
alt.themes.enable("latex")

# Plot at most the first 300 rows of every (model, y) group.
policy_sample_df = policy_df.groupby(["model", "y"]).head(300)
chart = plot(metric_df, policy_sample_df)

# Uncomment to export the figure into the output folder:
# chart.save(output_directory / "CMIP-DCTR-PBM.pdf")
chart