In [41]:
import pandas as pd
from src.experiments.common import get_run_from_path, EXPERIMENT_PATH
from src.experiments.wrangle import get_multirun_statistics
from src.visualization.common import setup_altair
import altair as alt
from altair import expr, datum

# Project helper imported from src.visualization.common; presumably applies the
# notebook's Altair theme/renderer settings — see its definition for details.
setup_altair()

In [42]:
# SGHMC multirun that, besides the usual sampler/model knobs, also sweeps the
# variance estimator (`variance_estimator._target_`) and its estimation margin.
sghmc_var_est = get_multirun_statistics(
    multirun=get_run_from_path(EXPERIMENT_PATH / "mnist/2021-10-31/13-41-02"),
    config_values=[
        "sampler.lr",
        "sampler.alpha",
        "sampler.estimation_margin",
        "model.activation_func._target_",
        "variance_estimator._target_",
    ],
)

In [43]:
# Shared line-chart spec for the variance-estimator sweep; the per-margin
# layers below build on it. Dots in column names become "/" (presumably
# because Vega-Lite reads "." in a field name as nested-object access).
base = alt.Chart(
    sghmc_var_est.reset_index().rename(columns=lambda col: col.replace(".", "/"))
).mark_line().encode(
    x="step",
    y="err/val",
    tooltip=[
        "step",
        "err/val",
        "sampler/lr",
        "sampler/estimation_margin",
        "model/activation_func/_target_",
    ],
)


In [44]:
# One layer per estimation margin, each coloured by learning rate with its own
# colour scheme so the margins remain visually separable once overlaid.
layers = [
    base.transform_filter(datum["sampler/estimation_margin"] == margin).encode(
        color=alt.Color(
            "sampler/lr:O",
            scale=alt.Scale(scheme=scheme),
            title=f"[margin={margin}] lr:",
        )
    )
    for margin, scheme in (
        (1.3, "yelloworangebrown"),
        (10, "goldgreen"),
        (100, "redpurple"),
    )
]

In [45]:
# fmt: off
# Overlay the per-margin layers, keep one independent colour legend per layer,
# and facet: rows = sampler alpha, columns = activation function.
(
    alt.layer(*layers)
    .resolve_scale(color="independent")
    .facet(row="sampler/alpha", column="model/activation_func/_target_")
    .interactive()
)
# fmt: on


In [60]:
def _load_multirun(relative_path, config_values):
    """Load aggregated statistics for one multirun.

    Args:
        relative_path: run directory, relative to ``EXPERIMENT_PATH``.
        config_values: config keys forwarded to ``get_multirun_statistics``
            (they show up as columns of the returned frame in later cells).

    Returns:
        The result of ``get_multirun_statistics`` for that run directory.
    """
    return get_multirun_statistics(
        multirun=get_run_from_path(EXPERIMENT_PATH / relative_path),
        config_values=config_values,
    )


dropout_data = _load_multirun(
    "mnist/2021-10-28/16-48-09_sgd_dropout", ["model.dropout", "inference.lr"]
)
map_data = _load_multirun("mnist/2021-10-28/16-49-20-sgd_map", ["inference.lr"])
vi_data = _load_multirun(
    "mnist/2021-10-28/16-49-59-vi",
    ["inference.lr", "inference.kl_weighting_scheme._target_"],
)
sghmc_data = _load_multirun(
    "mnist/2021-11-04/15-49-57",
    ["sampler.lr", "model.activation_func._target_", "sampler.alpha"],
)

# Dropout

In [61]:
# fmt: off
# One facet column per dropout rate, one coloured line per learning rate;
# marks outside the y-domain [0, 0.09] are clipped.
alt.Chart(
    dropout_data.reset_index().rename(columns=lambda col: col.replace(".", "/"))
).mark_line(clip=True).encode(
    x="step",
    y=alt.Y("err/val", scale=alt.Scale(domain=(0, 0.09))),
    color=alt.Color(
        "inference/lr:O", title="Learning Rate", scale=alt.Scale(scheme="turbo")
    ),
    column=alt.Column("model/dropout"),
)
# fmt: on


# SGD - MAP

In [62]:
# fmt: off
# SGD/MAP runs: one coloured line per learning rate, clipped to the same
# y-domain [0, 0.09] as the dropout plot.
alt.Chart(
    map_data.reset_index().rename(columns=lambda col: col.replace(".", "/"))
).mark_line(clip=True).encode(
    x="step",
    y=alt.Y("err/val", scale=alt.Scale(domain=(0, 0.09))),
    color=alt.Color("inference/lr:O"),
)
# fmt: on


# VI

In [63]:
# fmt: off
# VI runs faceted by KL-weighting scheme, one line per learning rate. The
# y-domain [0, 1] is wider than the other panels; scales are zoom/pan-able.
(
    alt.Chart(
        vi_data.reset_index().rename(columns=lambda col: col.replace(".", "/"))
    )
    .mark_line(clip=True)
    .encode(
        x="step",
        y=alt.Y("err/val", scale=alt.Scale(domain=(0, 1.0))),
        color=alt.Color("inference/lr:O"),
        column=alt.Column("inference/kl_weighting_scheme/_target_"),
    )
    .interactive()
)
# fmt: on


# SGHMC (beta=0)

In [64]:
# fmt: off
# SGHMC grid: rows = activation function, columns = sampler alpha, one
# coloured line per sampler learning rate; y clipped to [0, 0.1].
(
    alt.Chart(
        sghmc_data.reset_index().rename(columns=lambda col: col.replace(".", "/"))
    )
    .mark_line(clip=True)
    .encode(
        x="step",
        y=alt.Y("err/val", scale=alt.Scale(domain=(0, 0.1))),
        color=alt.Color("sampler/lr:O"),
        row=alt.Row("model/activation_func/_target_:N"),
        column=alt.Column("sampler/alpha:N"),
    )
    .interactive()
)
# fmt: on

# Compare best models

In [65]:
# Number of trailing evaluations whose validation error is averaged to score a
# run (reduces the influence of a single noisy final checkpoint).
N_FINAL_EVALS = 10

# Stack every experiment into one frame. The dict keys become a new outermost
# index level "inference"; the inner levels are those returned by
# get_multirun_statistics (`id`, `step` in the displayed output).
combined = pd.concat(
    {
        "dropout": dropout_data,
        "map": map_data,
        "vi": vi_data,
        "sghmc": sghmc_data,
        "sghmc_var_est": sghmc_var_est,
    },
    names=["inference"],
)

# Mean validation error over the last N_FINAL_EVALS steps of every run.
# - sort_index() makes "last" mean "highest step" even if a run's rows are not
#   stored in step order.
# - level= grouping is unambiguous even if a column shares a level's name.
final_err = (
    combined["err/val"]
    .sort_index()
    .groupby(level=["inference", "id"])
    .tail(N_FINAL_EVALS)
    .groupby(level=["inference", "id"])
    .mean()
)

# Run id with the lowest final error, indexed by inference method.
# idxmin returns (inference, id) tuples, so keep only the id component.
best_runs = (
    final_err.groupby(level="inference")
    .idxmin()
    .map(lambda key: key[1])
    .rename("id")
)


16-48-09_sgd_dropout/1          dropout
16-49-20-sgd_map/1                  map
15-49-57/0                        sghmc
13-41-02/4                sghmc_var_est
16-49-59-vi/0                        vi
Name: inference, dtype: object

In [74]:
# All rows (every step and config column) of the best run per inference method.
combined.droplevel("inference").loc[best_runs]

Unnamed: 0_level_0,Unnamed: 1_level_0,err/val,model.dropout,inference.lr,inference.kl_weighting_scheme._target_,sampler.lr,model.activation_func._target_,sampler.alpha,sampler.estimation_margin,variance_estimator._target_,inference
id,step,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
16-48-09_sgd_dropout/1,390,0.0409,0.5,0.001,,,,,,,dropout
16-48-09_sgd_dropout/1,781,0.0301,0.5,0.001,,,,,,,dropout
16-48-09_sgd_dropout/1,1172,0.0275,0.5,0.001,,,,,,,dropout
16-48-09_sgd_dropout/1,1563,0.0210,0.5,0.001,,,,,,,dropout
16-48-09_sgd_dropout/1,1954,0.0227,0.5,0.001,,,,,,,dropout
...,...,...,...,...,...,...,...,...,...,...,...
16-49-59-vi/0,311235,0.0308,,0.001,src.inference.vi.ExponentialKLWeight,,,,,,vi
16-49-59-vi/0,311626,0.0299,,0.001,src.inference.vi.ExponentialKLWeight,,,,,,vi
16-49-59-vi/0,312017,0.0291,,0.001,src.inference.vi.ExponentialKLWeight,,,,,,vi
16-49-59-vi/0,312408,0.0308,,0.001,src.inference.vi.ExponentialKLWeight,,,,,,vi


In [77]:
# Compare the validation-error curves of the best run of each inference method.
# `combined` is indexed by (inference, id, step). Dropping "inference" lets us
# select runs by id. An id -> inference mapping is then joined back on the "id"
# level so lines can be coloured by method.
best_run_inference = pd.Series(
    best_runs.index.to_numpy(),
    index=pd.Index(best_runs.to_numpy(), name="id"),
    # Explicit name: join() requires a named Series, and the chart encodes
    # this column, so do not rely on inheriting the index name.
    name="inference",
)
(
    combined
    .droplevel("inference")
    .loc[best_runs]
    .join(best_run_inference, how="left")
    .reset_index()
    .pipe(alt.Chart)
    .mark_line(clip=True)
    .encode(
        x=alt.X("step", title="Step"),
        y=alt.Y(
            "err/val",
            title="Validation error",
            scale=alt.Scale(domain=(0, 0.09)),
        ),
        color=alt.Color("inference", title="Inference method"),
        tooltip=["inference", "id", "step", "err/val"],
    )
    .properties(title="Best run per inference method")
    .interactive()
)

In [None]:
# Same rows as the `.loc[best_runs]` table, with `id`/`step` moved into columns.
combined.droplevel("inference").loc[best_runs].reset_index()