In [None]:
import pandas as pd
from src.experiments.common import get_run_from_path, EXPERIMENT_PATH
from src.experiments.wrangle import get_multirun_statistics
from src.visualization.common import setup_altair
import altair as alt

# Project-level Altair setup imported from src.visualization.common; it is
# called here, before any chart in the notebook is built.
# NOTE(review): presumably configures renderer/theme/data limits — confirm in
# the helper itself.
setup_altair()

In [None]:
def _load_statistics(run_dir, config_values):
    """Load the multirun statistics stored under ``EXPERIMENT_PATH / run_dir``.

    Args:
        run_dir: Experiment directory, relative to ``EXPERIMENT_PATH``.
        config_values: Config keys forwarded to ``get_multirun_statistics``.

    Returns:
        Whatever ``get_multirun_statistics`` returns for that multirun.
    """
    return get_multirun_statistics(
        multirun=get_run_from_path(EXPERIMENT_PATH / run_dir),
        config_values=config_values,
    )


# NOTE(review): this directory name has "_" after the timestamp while every
# other run below has "-"; confirm it matches the folder on disk.
dropout_data = _load_statistics(
    "mnist/2021-10-28/16-48-09_sgd_dropout",
    ["model.dropout", "inference.lr"],
)
map_data = _load_statistics(
    "mnist/2021-10-28/16-49-20-sgd_map",
    ["inference.lr"],
)
vi_data = _load_statistics(
    "mnist/2021-10-28/16-49-59-vi",
    ["inference.lr", "inference.kl_weighting_scheme._target_"],
)

# The three SGHMC variants are loaded with the same pair of config keys.
sghmc_data = _load_statistics(
    "mnist/2021-10-28/16-50-39-sghmc",
    ["inference.sampler.lr", "model.activation_func._target_"],
)
sghmc_adam_data = _load_statistics(
    "mnist/2021-10-28/16-51-15-sghmc_adam",
    ["inference.sampler.lr", "model.activation_func._target_"],
)
sghmc_interbatch_data = _load_statistics(
    "mnist/2021-10-28/16-52-37-sghmc_interbatch",
    ["inference.sampler.lr", "model.activation_func._target_"],
)


# Dropout

In [None]:
# Line chart of the `err/val` column against `step` for the dropout runs:
# one line color per learning rate, one facet column per dropout value.
# Column names have "." swapped for "/" before charting (presumably because
# Vega-Lite reads "." in a field name as nested access — confirm).
alt.Chart(
    dropout_data.reset_index().rename(columns=lambda name: name.replace(".", "/"))
).mark_line(clip=True).encode(
    alt.X("step"),
    # Lines leaving the fixed y-range are clipped by mark_line(clip=True).
    alt.Y("err/val", scale=alt.Scale(domain=[0, 0.09])),
    alt.Color(
        "inference/lr:O",
        scale=alt.Scale(scheme="plasma"),
        title="Learning Rate",
    ),
    alt.Column("model/dropout"),
)


# SGD - MAP

In [None]:
# Line chart of the `err/val` column against `step` for the MAP runs, one line
# per learning rate.
# The color channel uses the same plasma scheme and "Learning Rate" legend
# title as the Dropout chart, which encodes the same `inference/lr:O` field,
# so the legend no longer shows the raw column name.
# fmt: off
(
    map_data
    .reset_index()
    # "." in column names is replaced by "/" before the frame reaches Altair.
    .rename(columns=lambda x: x.replace(".", "/"))
    .pipe(alt.Chart)
    .mark_line(clip=True)
    .encode(
        x=alt.X("step"),
        y=alt.Y("err/val", scale=alt.Scale(domain=(0, 0.09))),
        color=alt.Color(
            "inference/lr:O", scale=alt.Scale(scheme="plasma"), title="Learning Rate"
        ),
    )
)
# fmt: on


# VI

In [None]:
# Line chart of the `err/val` column against `step` for the VI runs: one line
# per learning rate, one facet column per KL weighting scheme target.
# The color channel uses the same plasma scheme and "Learning Rate" legend
# title as the Dropout chart, which encodes the same `inference/lr:O` field.
# fmt: off
(
    vi_data
    .reset_index()
    # "." in column names is replaced by "/" before the frame reaches Altair.
    .rename(columns=lambda x: x.replace(".", "/"))
    .pipe(alt.Chart)
    .mark_line(clip=True)
    .encode(
        x=alt.X("step"),
        # Wider y-range than the other charts: the full 0..1 interval.
        y=alt.Y("err/val", scale=alt.Scale(domain=(0, 1.))),
        color=alt.Color(
            "inference/lr:O", scale=alt.Scale(scheme="plasma"), title="Learning Rate"
        ),
        column="inference/kl_weighting_scheme/_target_",
    )
)
# fmt: on


# SGHMC (beta=0)

In [None]:
# Interactive line chart of the `err/val` column against `step` for the SGHMC
# runs: sampler learning rate as a nominal color, one facet column per
# activation-function target.
alt.Chart(
    sghmc_data.reset_index().rename(columns=lambda name: name.replace(".", "/"))
).mark_line(clip=True).encode(
    alt.X("step"),
    alt.Y("err/val", scale=alt.Scale(domain=[0, 0.1])),
    alt.Color("inference/sampler/lr:N"),
    alt.Column("model/activation_func/_target_:N"),
).interactive()

# SGHMC (beta=adam)

In [None]:
# Interactive line chart of the `err/val` column against `step` for the
# SGHMC-Adam runs: sampler learning rate as a nominal color, one facet column
# per activation-function target.
alt.Chart(
    sghmc_adam_data.reset_index().rename(columns=lambda name: name.replace(".", "/"))
).mark_line(clip=True).encode(
    alt.X("step"),
    alt.Y("err/val", scale=alt.Scale(domain=[0, 0.1])),
    alt.Color("inference/sampler/lr:N"),
    alt.Column("model/activation_func/_target_:N"),
).interactive()

# SGHMC Interbatch

In [None]:
# Interactive line chart of the `err/val` column against `step` for the
# SGHMC interbatch runs: sampler learning rate as a nominal color, one facet
# column per activation-function target.
alt.Chart(
    sghmc_interbatch_data.reset_index().rename(
        columns=lambda name: name.replace(".", "/")
    )
).mark_line(clip=True).encode(
    alt.X("step"),
    alt.Y("err/val", scale=alt.Scale(domain=[0, 0.1])),
    alt.Color("inference/sampler/lr:N"),
    alt.Column("model/activation_func/_target_:N"),
).interactive()

# Compare best models

In [None]:
# Stack the per-method statistics into one frame; the dict keys become a new
# outermost index level named "inference".
runs_by_method = {
    "dropout": dropout_data,
    "map": map_data,
    "vi": vi_data,
    "sghmc": sghmc_data,
    "sghmc_adam": sghmc_adam_data,
    "sghmc_interbatch": sghmc_interbatch_data,
}
combined = pd.concat(runs_by_method, names=["inference"])


def _final_val_error(run):
    """Return the mean of the last ten `err/val` entries of one run.

    "Last" means last in stored row order.
    NOTE(review): assumes rows within a run are ordered by step — confirm.
    """
    return run["err/val"].tail(10).mean()


# One score per (inference, id) group.
final_error_per_run = combined.groupby(["inference", "id"]).apply(_final_val_error)

# idxmin gives, per inference method, the (inference, id) label of the run
# with the lowest score; only the id component is kept.
best_run_labels = final_error_per_run.groupby(level="inference").idxmin()
best_runs = best_run_labels.apply(lambda label: label[1])


In [None]:
# Compare the best run of every inference method in one interactive chart.
#
# `best_runs` maps each inference method (its index) to the id of that
# method's best run (its values). Rows are selected on the (inference, id)
# pair rather than on the id alone, so an id that also occurs under another
# method cannot pull in that method's rows, and the "inference" label is kept
# so the legend names the method instead of an opaque run id.
(
    combined
    .reset_index()
    # Inner join keeps exactly the rows whose (inference, id) is a best run.
    .merge(best_runs.rename("id").reset_index(), on=["inference", "id"])
    .pipe(alt.Chart)
    .mark_line(clip=True)
    .encode(
        x="step",
        y=alt.Y("err/val", scale=alt.Scale(domain=(0, 0.09))),
        color=alt.Color("inference:N", title="Inference"),
        tooltip=["inference", "id", "step", "err/val"],
    )
    .interactive()
)

In [None]:
# Tabular view of the rows behind the best run of every inference method.
# `best_runs` maps each method (index) to its best run id (values); joining on
# the (inference, id) pair keeps the method label in the output and avoids
# picking up rows of other methods that happen to share a run id.
(
    combined
    .reset_index()
    .merge(best_runs.rename("id").reset_index(), on=["inference", "id"])
)