# Plotting the tempering curves

In [None]:
import os
from glob import glob
import pandas as pd
import numpy as np
from scipy import stats
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import json
import torch

from bnn_priors.exp_utils import load_samples

# Global seaborn style for every figure below: "paper" context, white background, enlarged fonts
sns.set(context="paper", style="white", font_scale=1.8)
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

## Setup

In [None]:
# Put in the name of your experiment here
# (it selects the ../results/<exp_name> directory and prefixes the names of the saved figures)
exp_name = "my_experiment"

# Set this to True if you ran evaluations using eval_bnn.py
# Otherwise, if you just want to use the evaluations that ran with the training, set it to False
use_eval_runs = True

# Choose an experiment type from ["mnist", "fashion_mnist", "cifar10"]
# (it determines which datasets are used for the calibration and OOD plots)
exp_type = "mnist"

In [None]:
# We assume that your experiments are saved in ../results
# If that is not the case, you'll have to change it here
# Both values are glob patterns; they are expanded with glob() when the results are loaded
train_files = f"../results/{exp_name}/*/config.json"
eval_files = f"../results/{exp_name}/*/eval/*/config.json"

In [None]:
# Load the separate evaluation runs if requested, otherwise the training runs
files = eval_files if use_eval_runs else train_files

In [None]:
# For every experiment type: (dataset used for calibration, dataset used for OOD detection)
datasets_by_exp_type = {
    "mnist": ("rotated_mnist", "fashion_mnist"),
    "fashion_mnist": ("fashion_mnist", "mnist"),
    "cifar10": ("cifar10c", "svhn"),
}
if exp_type not in datasets_by_exp_type:
    raise ValueError(f"Unknown experiment type {exp_type}")
calibration_data, ood_data = datasets_by_exp_type[exp_type]

In [None]:
# these are the priors we used in our paper
# (the plots below only include runs whose weight_prior is in this list; the list is
# redefined further down once "convcorrnormal" has been renamed to "correlated")
monolithic_priors = ["gaussian", "convcorrnormal", "laplace", "student-t"]

In [None]:
def plot_tempering_curve(runs, y, yerr=None, ylabel="performance", ylim=None, x="weight_prior",
              title=None, baseline=None, baseline_err=None, log_x=True, legend=True, legend_loc="best",
                        invert_y=False):
    """Plot `y` as a function of the temperature, with one curve per value of column `x`.

    A separate subplot (stacked vertically, shared x axis) is drawn for every unique
    `weight_scale` in `runs`. If a curve has several runs at the same temperature, the
    mean over those runs is plotted with a band of +/- one standard error.

    Args:
        runs: DataFrame with the columns "temperature", "weight_scale", `x` and `y`.
        y: name of the column plotted on the y axis.
        yerr: optional name of a column holding the error of `y`; used for the band of
            curves that have exactly one run per temperature. Curves with replicated
            temperatures use the computed standard error instead.
        ylabel: label of the y axis.
        ylim: optional y-axis limits, passed to `ax.set_ylim`.
        x: name of the column whose unique values define the individual curves.
        title: optional figure title.
        baseline: optional y value drawn as a dashed horizontal line labelled "SGD".
        baseline_err: optional half-width of a gray band drawn around `baseline`.
        log_x: if True, use a logarithmic temperature axis.
        legend: whether to draw a legend on every subplot.
        legend_loc: legend location, passed to `ax.legend`.
        invert_y: if True, invert the y axis.

    Returns:
        The matplotlib figure containing all subplots.

    Raises:
        ValueError: if `runs` contains no weight scale, i.e. there is nothing to plot.
    """
    scales = sorted(runs.weight_scale.unique())
    if not scales:
        raise ValueError("`runs` is empty: there is nothing to plot")
    # Curves are drawn in descending order of their `x` value.
    curve_values = runs.sort_values([x], ascending=False)[x].unique()

    fig, axes = plt.subplots(len(scales), 1, sharex=True, figsize=(3*2+2, 3*len(scales)+2))

    # plt.subplots returns a bare Axes (not an array) when there is a single row
    if len(scales) == 1:
        axes = [axes]

    for scale, ax in zip(scales, axes):
        # Boolean masks instead of a string query: robust to quotes and non-string values
        scale_runs = runs[runs["weight_scale"] == scale]
        plotted_temps = []
        for x_val in curve_values:
            df = scale_runs[scale_runs[x] == x_val].sort_values("temperature")
            if df.empty:
                # this curve has no runs at this weight scale
                continue
            # Keep the error column per curve, so that replicates in one curve
            # do not change how the following curves are drawn.
            err_col = yerr
            if df["temperature"].duplicated().any():
                # Several runs per temperature: aggregate only the plotted column
                err_col = f"{y}_stderr"
                per_temp = df[y].astype(float).groupby(df["temperature"])
                df = pd.DataFrame({y: per_temp.mean(), err_col: per_temp.sem()}).reset_index()
            df.plot(x="temperature", y=y, kind="line", legend=legend, ax=ax, label=x_val, linewidth=3)
            if err_col is not None:
                # reuse the colour of the line that was just drawn for its error band
                ax.fill_between(df["temperature"], df[y] - df[err_col], df[y] + df[err_col],
                                color=ax.get_lines()[-1].get_color(), alpha=0.3)
            plotted_temps.extend(df["temperature"])
        ax.set_ylabel(ylabel)
        if ylim is not None:
            ax.set_ylim(ylim)
        if baseline is not None:
            ax.axhline(y=baseline, color="gray", linestyle="dashed", label="SGD", linewidth=2)
            if baseline_err is not None:
                # horizontal band across the full axis width, independent of any single curve
                ax.axhspan(baseline - baseline_err, baseline + baseline_err, color="gray", alpha=0.3)
        if legend:
            # draw the legend on this subplot (plt.legend would target the current axes only)
            ax.legend(frameon=False, loc=legend_loc)
        if log_x:
            ax.set(xscale="log")
        if invert_y:
            ax.invert_yaxis()
        # Fit the x axis to the temperatures of all curves on this subplot
        temps = np.asarray(plotted_temps, dtype=float)
        temps = temps[np.isfinite(temps)]
        if log_x:
            # a logarithmic axis cannot show temperatures <= 0
            temps = temps[temps > 0]
        if temps.size:
            ax.set_xlim(temps.min(), temps.max())

    if title is not None:
        fig.suptitle(title)
        # leave room for the suptitle above the subplots
        fig.tight_layout(rect=[0, 0, 1, 0.97])
    else:
        fig.tight_layout()

    return fig

## Load results

In [None]:
# Build one row per run: the run's configuration joined with its recorded results
runs = []
for config_file in glob(files):
    # run.json is read from the same directory as the matched config.json
    run_file = os.path.join(os.path.dirname(config_file), "run.json")
    with open(config_file) as infile:
        config = pd.Series(json.load(infile))
    with open(run_file) as infile:
        result = pd.Series(json.load(infile)["result"], dtype=np.float32)
    run_data = pd.concat([config, result])
    runs.append(run_data)
    if not use_eval_runs and run_data["weight_prior"] == "improper":
        print(run_data["weight_prior"], run_data["temperature"], run_data["weight_scale"], config_file)
if not runs:
    # pd.concat would otherwise fail with an unhelpful "No objects to concatenate"
    raise ValueError(f"No config files match {files!r}; check exp_name and the results directory")
# each run is a column after concat; transpose so that every run becomes a row
runs_all = pd.concat(runs, axis=1).T

In [None]:
# Use the *_ensemble column as a stand-in whenever the matching *_mean column is missing
for mean_col, ensemble_col in (("acc_mean", "acc_ensemble"), ("lp_mean", "lp_ensemble")):
    if mean_col not in runs_all.columns:
        runs_all[mean_col] = runs_all[ensemble_col]

In [None]:
# error rate is the complement of the accuracy
runs_all["error_mean"] = 1. - runs_all["acc_mean"]
# negated copies of the lp_mean and auroc columns
runs_all["nll_mean"] = - runs_all["lp_mean"]
runs_all["neg_auroc"] = - runs_all["auroc"]

In [None]:
# Preview the first rows of the loaded results
runs_all.head()

In [None]:
# Convert every column that parses cleanly as numeric and leave all other columns untouched.
# (errors="ignore" is deprecated in pd.to_numeric, so the failure is handled explicitly.)
for col in runs_all.columns:
    try:
        runs_all[col] = pd.to_numeric(runs_all[col])
    except (ValueError, TypeError):
        # not a numeric column (e.g. it holds strings): keep it as loaded
        continue

In [None]:
# filter out the failed runs, i.e. the rows without an acc_mean value
runs_all = runs_all.dropna(subset=["acc_mean"])

In [None]:
# this is just to use the nicer label "correlated" in the plots instead of "convcorrnormal"
# (the value is replaced wherever it occurs in the table, and the prior list is updated to match)
runs_all = runs_all.replace("convcorrnormal", "correlated")
monolithic_priors = [
    "gaussian",
    "correlated",
    "laplace",
    "student-t",
]

## Load SGD baselines

In [None]:
# if you ran SGD baselines with train_sgd.py you can load the results in here
# otherwise don't run these cells
# NOTE(review): read_pickle unpickles the file, which can execute arbitrary code --
# only load pickle files that you created yourself or otherwise trust
sgd_runs = pd.read_pickle("../results/4.1_sgd_runs.pkl.gz", compression="gzip")

In [None]:
# Same derived columns as for the BNN runs: error rate as the complement of the accuracy,
# plus negated copies of the lp_ensemble and auroc columns
sgd_runs["result.error_ensemble"] = 1. - sgd_runs["result.acc_ensemble"]
sgd_runs["result.nll_ensemble"] = - sgd_runs["result.lp_ensemble"]
sgd_runs["ood.neg_auroc"] = - sgd_runs["ood.auroc"]

In [None]:
def get_sgd_results(model_type, data, measure):
    """Return the mean and standard error of one SGD baseline metric.

    Rows of the module-level `sgd_runs` table are selected where the "model" column
    equals `model_type` and the "data" column equals `data`; the column `measure` of
    those rows is then aggregated.

    Args:
        model_type: value to match in the "model" column.
        data: value to match in the "data" column.
        measure: name of the column to aggregate.

    Returns:
        A tuple `(mean, stderr)`, where `stderr` is the sample standard deviation
        divided by the square root of the number of selected rows. `mean` is NaN if
        no row matches, and `stderr` is NaN with fewer than two matching rows.
    """
    # Boolean masks instead of an interpolated query string, so values containing
    # quotes (or non-string values) are matched correctly.
    selected = (sgd_runs["model"] == model_type) & (sgd_runs["data"] == data)
    results = sgd_runs.loc[selected, measure]
    mean = results.mean()
    stderr = results.std() / np.sqrt(len(results))
    return mean, stderr

## Evaluate predictive performance

### Tempering curves

In [None]:
if use_eval_runs:
    # keep only the rows whose eval_data is missing (None/NaN)
    runs_selected = runs_all[runs_all["eval_data"].isna()]
else:
    runs_selected = runs_all

In [None]:
# restrict the plots to the priors listed in monolithic_priors
runs_subselected = runs_selected[runs_selected["weight_prior"].isin(monolithic_priors)]

In [None]:
# again, just run this if you have an SGD baseline
# (the baseline is looked up for the model and dataset of the first loaded run)
sgd_mean, sgd_stderr = get_sgd_results(
    model_type=runs_all["model"].iloc[0],
    data=runs_all["data"].iloc[0],
    measure="result.error_ensemble",
)

In [None]:
# if you don't have an SGD baseline, remove the last two arguments
fig = plot_tempering_curve(
    runs_subselected,
    y="error_mean",
    ylabel="error",
    legend=True,
    baseline=sgd_mean,
    baseline_err=sgd_stderr,
)

In [None]:
# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("../figures", exist_ok=True)
fig.savefig(f"../figures/{exp_name}_acc_tempering_curve.pdf")

In [None]:
# Restyle the same figure as a compact version and save it under a separate file name
top_ax = fig.axes[0]
top_ax.set_title("")
top_ax.legend(frameon=False, labelspacing=0.2)
fig.set_size_inches(5, 3)
fig.tight_layout()
fig.savefig(
    f"../figures/{exp_name}_acc_tempering_curve_small.pdf",
    bbox_inches='tight',
    pad_inches=0.1,
)
# last expression of the cell: display the resized figure
fig

In [None]:
# again, just run this if you have an SGD baseline
# (the baseline is looked up for the model and dataset of the first loaded run)
sgd_mean, sgd_stderr = get_sgd_results(
    model_type=runs_all["model"].iloc[0],
    data=runs_all["data"].iloc[0],
    measure="result.nll_ensemble",
)

In [None]:
# if you don't have an SGD baseline, remove the last two arguments
fig = plot_tempering_curve(
    runs_subselected,
    y="nll_mean",
    ylabel="NLL",
    legend=True,
    baseline=sgd_mean,
    baseline_err=sgd_stderr,
)

In [None]:
# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("../figures", exist_ok=True)
fig.savefig(f"../figures/{exp_name}_nll_tempering_curve.pdf")

In [None]:
# Restyle the same figure as a compact version and save it under a separate file name
top_ax = fig.axes[0]
top_ax.set_title("")
top_ax.legend(frameon=False, labelspacing=0.2)
fig.set_size_inches(5, 3)
fig.tight_layout()
fig.savefig(
    f"../figures/{exp_name}_nll_tempering_curve_small.pdf",
    bbox_inches='tight',
    pad_inches=0.1,
)
# last expression of the cell: display the resized figure
fig

In [None]:
# Keep the evaluations whose eval_data name contains the calibration dataset name;
# rows with a missing eval_data are excluded (na=False)
runs_selected = runs_all[runs_all['eval_data'].str.contains(calibration_data, na=False)]

In [None]:
# restrict the plots to the priors listed in monolithic_priors
runs_subselected = runs_selected[runs_selected["weight_prior"].isin(monolithic_priors)]

In [None]:
# again, just run this if you have an SGD baseline
# (the baseline is looked up for the model and dataset of the first loaded run)
sgd_mean, sgd_stderr = get_sgd_results(
    model_type=runs_all["model"].iloc[0],
    data=runs_all["data"].iloc[0],
    measure="calibration.ece",
)

In [None]:
# if you don't have an SGD baseline, remove the last two arguments
fig = plot_tempering_curve(
    runs_subselected,
    y="ece",
    ylabel="ECE",
    legend=True,
    baseline=sgd_mean,
    baseline_err=sgd_stderr,
)

In [None]:
# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("../figures", exist_ok=True)
fig.savefig(f"../figures/{exp_name}_ece_tempering_curve.pdf")

In [None]:
# Restyle the same figure as a compact version and save it under a separate file name
top_ax = fig.axes[0]
top_ax.set_title("")
top_ax.legend(frameon=False, labelspacing=0.2)
fig.set_size_inches(5, 3)
fig.tight_layout()
fig.savefig(
    f"../figures/{exp_name}_ece_tempering_curve_small.pdf",
    bbox_inches='tight',
    pad_inches=0.1,
)
# last expression of the cell: display the resized figure
fig

In [None]:
# Exact match on the dataset name (not a substring test): with exp_type == "fashion_mnist"
# the OOD dataset is "mnist", which is also a substring of "fashion_mnist"
runs_selected = runs_all[runs_all["eval_data"].isin([ood_data])]

In [None]:
# restrict the plots to the priors listed in monolithic_priors
runs_subselected = runs_selected[runs_selected["weight_prior"].isin(monolithic_priors)]

In [None]:
# again, just run this if you have an SGD baseline
# (the baseline is looked up for the model and dataset of the first loaded run)
sgd_mean, sgd_stderr = get_sgd_results(
    model_type=runs_all["model"].iloc[0],
    data=runs_all["data"].iloc[0],
    measure="ood.auroc",
)

In [None]:
# if you don't have an SGD baseline, remove the last two arguments
fig = plot_tempering_curve(
    runs_subselected,
    y="auroc",
    ylabel="OOD AUROC",
    legend=True,
    invert_y=True,
    baseline=sgd_mean,
    baseline_err=sgd_stderr,
)

In [None]:
# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("../figures", exist_ok=True)
fig.savefig(f"../figures/{exp_name}_ood_auroc_tempering_curve.pdf")

In [None]:
# Restyle the same figure as a compact version and save it under a separate file name
top_ax = fig.axes[0]
top_ax.set_title("")
top_ax.legend(frameon=False, labelspacing=0.2)
fig.set_size_inches(5, 3)
fig.tight_layout()
fig.savefig(
    f"../figures/{exp_name}_ood_auroc_tempering_curve_small.pdf",
    bbox_inches='tight',
    pad_inches=0.1,
)
# last expression of the cell: display the resized figure
fig