In [1]:
import os
from glob import glob
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import gridspec
import json
import torch
import gpytorch
import h5py
import collections
import scipy
import torch
import subprocess
import sys

from bnn_priors import prior

from bnn_priors.exp_utils import load_samples
from bnn_priors.notebook_utils import collect_runs, unique_cols, json_dump, json_load

%matplotlib inline
%config InlineBackend.print_figure_kwargs = {'bbox_inches':None}

# Load data and plot figures 3, 5 and A.10

The repository ships with the extracted data needed to plot the figures. The code that *created* the CSVs and JSONs is at the end of this notebook.

In [30]:
# Read the pre-extracted artefacts that are shipped next to this notebook.
runs_with_eval = pd.read_csv("Plot_CIFAR10_resnet_data/sgd_runs.csv", index_col=0)
opt_dfs = json_load("Plot_CIFAR10_resnet_data/opt_dfs.json")
opt_lengthscale = json_load("Plot_CIFAR10_resnet_data/opt_lengthscale.json")
covs, lens, conv_n_channels = pd.read_pickle("Plot_CIFAR10_resnet_data/covs_lens.pkl.gz")

# Order the layer keys by the integer in their third dotted field
# (e.g. "net.module.<N>...."), breaking ties with the full key string.
conv_keys = sorted(covs, key=lambda name: (int(name.split('.')[2]), name))

In [None]:
# Two side-by-side panels sharing the x axis (one x position per entry of conv_keys).
fig, axes = plt.subplots(1, 2, figsize=(9, 2.3), sharex=True)
lengthscale_ax, dof_ax = axes

# Right panel: second element of each opt_dfs entry, on a log scale.
fitted_dof = [opt_dfs[name][1] for name in conv_keys]
dof_ax.set_title("Degrees of freedom")
dof_ax.plot(fitted_dof)
dof_ax.set_xticks([1, 7, 13, 18])
dof_ax.set_xticklabels(["*L2", "*L8", "*L14", "L19"])
dof_ax.set_yscale('log')

# Left panel: fitted lengthscale per layer.
fitted_lengthscale = [opt_lengthscale[name] for name in conv_keys]
lengthscale_ax.set_title("Lengthscale")
lengthscale_ax.plot(fitted_lengthscale)

# A single x label centred under both panels.
fig.text(0.5, 0, 'Layer index (input = L0)', ha='center')
#fig.suptitle("ResNet-20, CIFAR-10 SGD: fitted parameters for T-distribution and Gaussian")
fig.tight_layout()
fig.savefig("../figures/210122_resnet_fitted.pdf", bbox_inches="tight")

In [None]:
# Figure 3: fitted degrees of freedom per layer.

# Thin spines and ticks, high-resolution inline rendering.
line_width = 0.5
plt.rcParams.update({
    "figure.dpi": 300,
    "axes.linewidth": line_width,
    'xtick.major.width': line_width,
    'xtick.minor.width': line_width,
    'ytick.major.width': line_width,
    'ytick.minor.width': line_width,
})

# Figure width given in points and converted to inches
# (presumably a LaTeX column width -- confirm against the paper template).
fig_width_pt = 234.8775
inches_per_pt = 1.0/72.27
fig_width = fig_width_pt*inches_per_pt

# The axes placement is fixed by hand, so no tight_layout is used here.
axes_box = dict(top=1, bottom=0.34, left=0.17, right=1)
fig, axes = plt.subplots(1, 1, figsize=(fig_width, 1.3), sharex=True, gridspec_kw=axes_box)
ax = axes

dof_per_layer = [opt_dfs[name][1] for name in conv_keys]
ax.plot(dof_per_layer)
ax.set_yscale('log')
ax.set_xticks([1, 7, 13, 18])
ax.set_xticklabels(["*L2", "*L8", "*L14", "L19"])
ax.set_xlabel('Layer index (input = L0)')
ax.set_ylabel("Deg. of freedom", horizontalalignment="right", position=(0, 1))

fig.savefig("../figures/210126-resnet-dof.pdf")

In [None]:
# Figure 5: layout constants and figure creation for the grid of per-layer
# covariance images. All lengths below end up being used as figure-fraction
# coordinates (they are passed to fig.add_axes by the plotting code).

plt.rcParams.update({
    "text.usetex": False,
    "font.family": "sans-serif"})

fig_width_pt = 487.8225
inches_per_pt = 1.0/72.27               # Convert pt to inches
fig_width = fig_width_pt*inches_per_pt  # width in inches

# Per-layer 9x9 matrix: the accumulated product in covs[k] divided by the row
# count lens[k], then multiplied by conv_n_channels[k].
# NOTE(review): the reason for the channel-count scaling is not visible here -- confirm.
mean_covs = {k: covs[k]/lens[k] * conv_n_channels[k] for k in conv_keys}

# Number of layer groups per row / number of rows.
plots_x = 7
plots_y = 3

margins = dict(
    left=0.015,
    right=0.01,
    top=0.007,
    bottom=0.02)

# wsep/hsep: gap between the 3x3 images of one layer;
# w_cov_sep/h_cov_sep: gap between neighbouring layer groups.
wsep = hsep = 0.002
w_cov_sep = 0.02
h_cov_sep = 0.03
# Side length of a single image such that plots_x groups of three images,
# their gaps and the left/right margins add up to the full figure width.
height = width = (1 - w_cov_sep*(plots_x-1) - wsep*3*plots_x
         - margins['left'] - margins['right'])/plots_x / 3
ttl_marg=5  # passed as `pad` to set_title by the plotting code

# Total figure height expressed in units of the figure width.
fig_height_mult = (margins['bottom'] + (height*3 + hsep*2)*plots_y + h_cov_sep*plots_y + margins['top'])

# make figure rectangular and correct vertical sizes
# (vertical lengths were computed relative to the width; dividing by the
# height/width ratio turns them into fractions of the figure height)
hsep /= fig_height_mult
height /= fig_height_mult
h_cov_sep /= fig_height_mult
margins['bottom'] /= fig_height_mult
margins['top'] /= fig_height_mult
fig = plt.figure(figsize=(fig_width, fig_width *fig_height_mult))

print("fig height = ", fig_width *fig_height_mult)

# NOTE(review): cbar_height is not read anywhere in the visible code.
cbar_height = height*3 + hsep*2


# Largest absolute entry over all layers; computed, then immediately overridden
# by the hard-coded value below, so it has no effect on the figure.
extreme = max(*(mean_covs[k].abs().max().item() for k in mean_covs.keys()))  #1.68
#assert extreme < 1.7
extreme = 2
# NOTE(review): `norm` is not used by the visible plotting code, which builds
# its own Normalize for every image.
norm = Normalize(-extreme, extreme)
    
def plot_at(key, base_bottom, base_left, is_bottom_row=False, is_left_col=False, title="title"):
    """Draw the 3x3 block of covariance images for one layer.

    Image (row, col) shows row ``row*3 + col`` of ``mean_covs[key]`` reshaped to
    3x3 and divided by the largest absolute entry of ``mean_covs[key]``, so the
    colour scale of every image spans [-1, 1]. A white cross marks the position
    (col+1, row+1) inside each image.

    Relies on the module-level ``fig``, ``mean_covs``, ``width``, ``height``,
    ``wsep``, ``hsep`` and ``ttl_marg``.

    Args:
        key: key into ``mean_covs``.
        base_bottom: figure-fraction y coordinate of the bottom of the block.
        base_left: figure-fraction x coordinate of the left of the block.
        is_bottom_row: if true, the bottom row of images gets x tick labels.
        is_left_col: if true, the left column of images gets y tick labels.
        title: title placed above the top-centre image. The exact title
            "Layer 15" additionally switches on x ticks for all nine images.

    Returns:
        ``(top, right, mappable)``: the top and right edges of the block
        (including one trailing separator each) and the last image drawn.
    """
    # Loop-invariant normalisation constant for this layer.
    layer_scale = mean_covs[key].abs().max().item()
    xticks_everywhere = (title == "Layer 15")

    top_edge = base_bottom
    right_edge = base_left

    for row in range(3):
        # Row 0 is drawn at the top of the block, row 2 at the bottom.
        bottom = base_bottom + (height+hsep) * (2-row)
        top_edge = max(top_edge, bottom+height+hsep)

        for col in range(3):
            left = base_left + (width+wsep) * col
            right_edge = max(right_edge, left+width+wsep)

            yticks = [1, 2, 3] if (col == 0 and is_left_col) else []
            xticks = [1, 2, 3] if ((row == 2 and is_bottom_row) or xticks_everywhere) else []

            ax = fig.add_axes([left, bottom, width, height], xticks=xticks, yticks=yticks)
            image = mean_covs[key][row*3+col, :].reshape((3, 3)) / layer_scale
            mappable = ax.imshow(
                image,
                cmap=plt.get_cmap('RdBu'),
                extent=[0.5, 3.5, 3.5, 0.5],
                norm=Normalize(-1, 1))
            ax.plot([col+1], [row+1], marker='x', ls='none', color='white',
                    ms=3, markeredgewidth=0.5)
            # Keep the tick labels but hide the tick marks themselves.
            ax.tick_params(left=False, bottom=False, labelsize="xx-small", pad=0)

            if row == 0 and col == 1:
                ax.set_title(title, pad=ttl_marg, size="x-small")

    return top_edge, right_edge, mappable

# Iterate over the indices for axes, starting from the bottom-left of the plots
# (rows are visited in reverse, so the last row of layers is placed first, at
# the bottom margin, and earlier rows are stacked above it).
cur_bottom = margins['bottom']
for y_idx in reversed(range(0, plots_y)):
    cur_left = margins['left']
    # The last row may hold fewer than plots_x layers.
    for x_idx in range(min(len(conv_keys)-y_idx*plots_x, plots_x)):
        key = conv_keys[y_idx*plots_x + x_idx]
        # These three keys get a "*" prefix in their title.
        # NOTE(review): what the star denotes is not stated in this notebook -- confirm.
        if key in ['net.module.3.main.0.weight_prior.p',
                   'net.module.6.main.0.weight_prior.p',
                   'net.module.9.main.0.weight_prior.p',]:
            marker = "*"
        else:
            marker = ""


        # plot_at returns the top edge of the block, the x position right of
        # the block, and the last image drawn (used for the colorbar).
        next_bottom, cur_left, mappable = plot_at(
            key, cur_bottom, cur_left,
            is_bottom_row=(y_idx == plots_y-1),
            is_left_col=(x_idx == 0),
            title=f"{marker}Layer {conv_keys.index(key) + 1}")
        cur_left += w_cov_sep

    # Only true for the first row placed (the bottom one): use the free space
    # to its right for the colorbar and the max-variance line plot.
    if cur_bottom == margins["bottom"]:
        cbar_width = width/3
        cbar_ax = fig.add_axes([cur_left, cur_bottom, cbar_width, next_bottom-cur_bottom])
        # NOTE(review): the ticks use the module-level `extreme`, while the
        # images are normalised to [-1, 1] inside plot_at; ticks outside that
        # range will not show on the colorbar -- confirm this is intended.
        fig.colorbar(mappable, cax=cbar_ax, ticks=[-extreme, -1, 0, 1, extreme])

        # plot absolute variance
        bottom = cur_bottom+0.02
        # Horizontal offset between the colorbar and the line plot.
        lmarg = 2.6667*width + 2*wsep
        ax = fig.add_axes([cur_left+lmarg, bottom,
                           1-(cur_left+lmarg + margins["right"] ), next_bottom-bottom])
        ax.set_ylabel("Max. variance", size="x-small") #, horizontalalignment="right", position=(0, 1))
        # Largest absolute entry of each layer's matrix, in mean_covs order.
        ax.plot([mean_covs[k].abs().max().item() for k in mean_covs.keys()])
        ax.set_xticks([1, 7, 13, 18])
        ax.set_xticklabels(["*L2", "*L8", "*L14", "L19"])
        ax.set_ylim((0, 6))
        #ax.set_yscale('log')
        #ax.set_xlabel('Layer index')
        ax.tick_params(labelsize="x-small")


        cbar_ax.tick_params(labelsize="x-small")
    # Move up by one block plus the inter-row gap.
    cur_bottom = next_bottom + h_cov_sep

fig.savefig("../figures/210204_googleresnet_covariances_all_capped.pdf")

In [None]:
# Appears to be a deliberate barrier: running the notebook top to bottom stops
# here, so the data-regeneration cells below are only executed on purpose.
raise RuntimeError("Do you want to continue?")

# Collect SGD Runs for various data sets

Here we will only read the relevant CSV file. The cells enclosed in `if False` below were used to create it.

You need to run `jug/0_31_googleresnet_cifar10_sgd.py` to be able to run the following.

Run `eval_bnn.py` and construct the overall dataframe.

In [7]:
# Table of the runs found under the SGD experiment's log directory.
df = collect_runs("../logs/0_31_googleresnet_cifar10_sgd")

# Keep only runs with n_epochs == 600 whose status is COMPLETED.
is_full_length = df["n_epochs"] == 600
is_completed = df["status"] == "COMPLETED"
good_runs = df[is_full_length & is_completed]

In [8]:
def eval_bnn(**config):
    """Run ``eval_bnn.py`` in a subprocess with the given settings.

    Every keyword argument is passed on the command line as ``key=value``
    after the literal word ``with``; the full command is printed first.

    Raises:
        SystemError: if the subprocess exits with a non-zero return code.
    """
    overrides = [f"{name}={value}" for name, value in config.items()]
    command = [sys.executable, "eval_bnn.py", "with"] + overrides
    print(" ".join(command))

    finished = subprocess.run(command)
    if finished.returncode == 0:
        return
    raise SystemError(f"Process returned with code {finished.returncode}")

# Evaluation data set used for the calibration evaluation of each training set.
CALIBRATION_DATA = {
    "mnist": "rotated_mnist",
    "fashion_mnist": "fashion_mnist",
    "cifar10": "cifar10c-gaussian_blur",
    "cifar10_augmented": "cifar10c-gaussian_blur",
}

# Evaluation data set used for the out-of-distribution evaluation of each training set.
OOD_DATA = {
    "mnist": "fashion_mnist",
    "fashion_mnist": "mnist",
    "cifar10": "svhn",
    "cifar10_augmented": "svhn",
}

# Set to True to run eval_bnn.py on every directory in good_runs.
# Only needs to be run once, so it is disabled by default.
RUN_EVAL_BNN = False

if RUN_EVAL_BNN:
    # The loop header used to be commented out and replaced by `if False:`,
    # which left `i` and `run` undefined; with the loop restored, flipping the
    # flag is enough to re-run the evaluation.
    for i, (_, run) in enumerate(good_runs.iterrows()):
        print(f"run {i}/{len(good_runs)}")
        config_file = str(run["the_dir"]/"config.json")

        calibration_data = CALIBRATION_DATA[run["data"]]
        eval_bnn(is_run_sgd=True, calibration_eval=True, eval_data=calibration_data,
                 config_file=config_file, skip_first=2, batch_size=128)

        ood_data = OOD_DATA[run["data"]]
        eval_bnn(is_run_sgd=True, ood_eval=True, eval_data=ood_data,
                 config_file=config_file, skip_first=2, batch_size=128)

In [9]:
# For every good run, append the "result.*" entries of its evaluation runs,
# renamed to "calibration.*" or "ood.*", to the run's own row.
merged_rows = []

for _, run in good_runs.iterrows():
    eval_runs = collect_runs(run["the_dir"]/"eval", metrics_must_exist=False)

    pieces = [run]
    for _, corr in eval_runs.iterrows():
        result_keys = [name for name in corr.index if name.startswith("result.")]

        # An evaluation run must be exactly one of calibration / OOD.
        if corr["calibration_eval"]:
            purpose = "calibration"
            assert not corr["ood_eval"]
        elif corr["ood_eval"]:
            purpose = "ood"
        else:
            raise ValueError("unknown purpose")

        prefixed_keys = [name.replace("result.", purpose+".") for name in result_keys]
        # The renamed entries must not collide with columns of the training run.
        assert all(name not in run.index for name in prefixed_keys)

        renamed = corr[result_keys]
        renamed.index = prefixed_keys
        pieces.append(renamed)

    merged_rows.append(pd.concat(pieces))

runs_with_eval = pd.DataFrame(merged_rows)

# Get the lengthscales and df's from each layer

In [10]:
def collect_weights(df, loader=None):
    """Stack the last stored sample of every run in ``df``, key by key.

    Args:
        df: DataFrame with a ``"the_dir"`` column; each entry is combined with
            ``"samples.pt"`` via the ``/`` operator (so it must be path-like).
        loader: callable used to read one samples file, invoked as
            ``loader(path, idx=-1, keep_steps=False)``. Defaults to
            ``load_samples``.

    Returns:
        dict mapping every key of the loaded samples to
        ``torch.stack(values, dim=0)``, one entry per successfully loaded run,
        in the row order of ``df``. Empty if no run could be loaded.

    Raises:
        AssertionError: if two runs do not expose the same set of keys.
    """
    # `pickle` is not imported at the top of the notebook; without this import
    # the `except` clause below raised NameError instead of skipping the run.
    import pickle

    if loader is None:
        loader = load_samples

    samples = collections.defaultdict(list)
    for _, row in df.iterrows():
        samples_path = row["the_dir"]/"samples.pt"
        try:
            s = loader(samples_path, idx=-1, keep_steps=False)
        except pickle.UnpicklingError:
            # Unreadable samples file: leave this run out, but say so.
            print(f"skipping unreadable samples file: {samples_path}")
            continue
        assert len(samples) == 0 or set(s.keys()) == set(samples.keys())
        for k in s.keys():
            samples[k].append(s[k])
    return {k: torch.stack(v, dim=0) for k, v in samples.items()}

In [11]:
# Restrict to the runs whose "data" column is "cifar10_augmented" and stack
# their final weight samples.
augmented_runs = good_runs[good_runs["data"] == "cifar10_augmented"]
samples = collect_weights(augmented_runs)

In [12]:
# Sanity check: shape of the stacked first-layer weights. The leading axis comes
# from torch.stack over runs (the recorded output is torch.Size([10, 16, 3, 3, 3])).
samples['net.module.0.weight_prior.p'].shape

torch.Size([10, 16, 3, 3, 3])

In [None]:
# List every stacked tensor whose key ends in ".p", together with its shape.
for name, stacked in samples.items():
    if not name.endswith(".p"):
        continue
    print(name, tuple(stacked.shape))

In [14]:
# The first conv layer plus every ".p" key that contains "main",
# ordered by the integer in the third dotted field and then by the full key.
main_branch_keys = [name for name in samples.keys() if name.endswith(".p") and "main" in name]
conv_keys = sorted(
    ["net.module.0.weight_prior.p"] + main_branch_keys,
    key=lambda name: (int(name.split('.')[2]), name))

In [15]:
# For each layer: view the weights as rows of 9 entries (one 3x3 filter per row),
# then accumulate the unnormalised 9x9 product of those rows and the row count.
covs, lens = {}, {}
for k in conv_keys:
    filters = samples[k].view(-1, 9)
    covs[k] = filters.t() @ filters
    lens[k] = filters.size(0)

# Size of the third-from-last axis of each weight tensor.
conv_n_channels = {k: samples[k].shape[-3] for k in conv_keys}

pd.to_pickle((covs, lens, conv_n_channels), "4.1_covs_lens.pkl.gz")

In [16]:
# Integer coordinates of the 3x3 filter grid as a (9, 2) float64 tensor,
# enumerated row by row: (0, 0), (0, 1), (0, 2), (1, 0), ...
grid_rows, grid_cols = np.mgrid[:3, :3]
grid_coords = np.stack([grid_rows.ravel(), grid_cols.ravel()], axis=1)
points = torch.from_numpy(grid_coords).to(torch.float64)


In [17]:
import gpytorch
import math
# Everything from here on is created in double precision.
torch.set_default_dtype(torch.float64)

# A batch of 1000 RBF kernels, one per candidate lengthscale. The candidates
# are evenly spaced in square-root space between 0.001 and 30.
n_lengthscales = 1000
kern = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([n_lengthscales]))
sqrt_lengthscales = torch.linspace(0.001**.5, 30**.5, n_lengthscales)
kern.lengthscale = sqrt_lengthscales.pow(2).unsqueeze(-1)

In [18]:
# Evaluate the batch of kernels on the 9 grid points and precompute, for every
# candidate lengthscale, the two ingredients of the zero-mean Gaussian
# log-likelihood used below: the inverse kernel matrix (obtained by solving
# against the 9x9 identity) and its log-determinant.
S_inverse = kern(points).inv_matmul(torch.eye(9))
S_logdet = kern(points).logdet()

In [19]:
# Grid search: for every layer, evaluate the Gaussian log-likelihood (up to the
# constant 2*pi term) at each candidate lengthscale and keep the best one.
log_liks = {}
opt_lengthscale = {}
with torch.no_grad():
    # The loop variable is deliberately called `k`: a later cell reads it.
    for k, scatter in covs.items():
        quadratic_term = S_inverse.mul(scatter).sum((-2, -1))
        log_liks[k] = S_logdet.mul(lens[k] / -2) - 0.5 * quadratic_term
        best_index = torch.argmax(log_liks[k])
        opt_lengthscale[k] = kern.lengthscale[best_index].item()

json_dump(opt_lengthscale, "4.1_opt_lengthscale.json")

In [None]:
# Log-likelihood as a function of the lengthscale for the first layer.
k = next(iter(log_liks))

candidate_lengthscales = kern.lengthscale.detach().squeeze(-1)
plt.plot(candidate_lengthscales, log_liks[k])
plt.ylim(-10000, 0)

In [None]:
# Best lengthscale found for each layer, in conv_keys order.
layer_positions = np.arange(len(conv_keys))
lengthscale_per_layer = [opt_lengthscale[name] for name in conv_keys]
plt.plot(layer_positions, lengthscale_per_layer)

In [22]:
# Check that the log-likelihoods aren't buggy: for candidate number 100, the
# exact multivariate-normal log-probability of layer `k` must equal the grid
# value once the constant 2*pi term (left out of log_liks) is subtracted.
dist = gpytorch.distributions.MultivariateNormal(torch.zeros(9), kern[100](points))
exact_log_prob = dist.log_prob(samples[k].view(-1, 9)).sum()
two_pi_term = math.log(2*math.pi) * 9 * lens[k]/2
exact_log_prob, log_liks[k][100] - two_pi_term

(tensor(-5513.4573, grad_fn=<SumBackward0>), tensor(-5513.4573))

In [23]:
# Display the lengthscale selected for every layer key.
opt_lengthscale

{'net.module.0.weight_prior.p': 0.6403542406466662,
 'net.module.3.main.0.weight_prior.p': 1.0704316857604412,
 'net.module.3.main.3.weight_prior.p': 1.0931095467878191,
 'net.module.4.main.0.weight_prior.p': 1.1508441889397887,
 'net.module.4.main.3.weight_prior.p': 1.1625694012987868,
 'net.module.5.main.0.weight_prior.p': 1.036860604040884,
 'net.module.5.main.3.weight_prior.p': 1.32040510780009,
 'net.module.6.main.0.weight_prior.p': 1.3079073317504737,
 'net.module.6.main.3.weight_prior.p': 1.3582550038061472,
 'net.module.7.main.0.weight_prior.p': 1.409553523481043,
 'net.module.7.main.3.weight_prior.p': 1.58284050379471,
 'net.module.8.main.0.weight_prior.p': 1.5150031056885025,
 'net.module.8.main.3.weight_prior.p': 1.7086917828865713,
 'net.module.9.main.0.weight_prior.p': 1.6521636013059522,
 'net.module.9.main.3.weight_prior.p': 1.4225267233402705,
 'net.module.10.main.0.weight_prior.p': 1.5016139099958647,
 'net.module.10.main.3.weight_prior.p': 1.680308836143859,
 'net.mod

# Get max df of multivariate-T

In [24]:
class MVTFitter(torch.nn.Module):
    """Wrap a ``prior.MultivariateT`` whose only trainable quantity is ``df``.

    The weights ``p`` are viewed as rows of 9 entries; the Cholesky factor of
    their second-moment matrix (rows' outer products averaged over rows) is
    handed to ``prior.MultivariateT`` together with a zero vector of length 9.
    NOTE(review): these two positional arguments are presumably the location
    and the scale factor -- confirm against ``prior.MultivariateT``.

    Args:
        p: weight tensor whose number of elements is a multiple of 9.
        df: initial degrees of freedom (a Python float), stored as a Parameter.
        permute: forwarded to ``prior.MultivariateT``.
        event_dim: forwarded to ``prior.MultivariateT``.
    """

    def __init__(self, p, df, permute=None, event_dim=2):
        super().__init__()

        rows = p.view(-1, 9)
        second_moment = rows.t().matmul(rows) / rows.size(0)
        scale_factor = second_moment.cholesky().detach().to(torch.get_default_dtype())
        df_param = torch.nn.Parameter(torch.tensor(df, requires_grad=True))

        self.dist = prior.MultivariateT(
            p.size(), torch.zeros(9), scale_factor,
            df=df_param, event_dim=event_dim, permute=permute)

        # The weights are data here, not something to optimise: freeze the
        # distribution's own `p` and overwrite it with the given values.
        self.dist.p.requires_grad_(False)
        self.dist.p[...] = p

    def closure(self):
        """Zero the gradients, back-propagate the negative log-probability and return it."""
        self.zero_grad()
        neg_log_prob = -self.dist.log_prob()
        neg_log_prob.backward()
        return neg_log_prob

In [25]:
# Grid search over the degrees of freedom of a multivariate-T for every layer.
# opt_dfs[key] = (best log-likelihood, df at the optimum, (permute, event_dim)).
opt_dfs = {}

# 300 candidate values, log-spaced between 2.1 and 1000.
try_df_inits = torch.linspace(math.log(2.1), math.log(1000), 300).exp()

# Use the GPU when one is available instead of failing on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for key in conv_keys:
    max_lik = -np.inf

    # Try several ways of grouping the weight axes into one multivariate event.
    for permute, event_dim in [(None, 2), (None, 3), (None, 4), ((0, 2, 1, 3, 4), 3)]:
        mvt = MVTFitter(samples[key], 3., permute=permute, event_dim=event_dim).to(device)
        for df_init in try_df_inits:
            # Only the value is needed, so no autograd graph is built.
            with torch.no_grad():
                mvt.dist.df[...] = df_init
                lik = mvt.dist.log_prob().item()

            # Named `df_value` so that the runs DataFrame `df` defined earlier
            # in the notebook is not overwritten by a float.
            df_value = mvt.dist.df.item()
            if np.isnan(lik) or np.isnan(df_value):
                print("key", key, "saw a nan with lik", lik)

            # NaN never compares greater, so NaN likelihoods are never selected.
            if lik > max_lik:
                opt_dfs[key] = (lik, df_value, (permute, event_dim))
                max_lik = lik
json_dump(opt_dfs, "4.1_opt_dfs.json")

# Explore degrees of freedom of MNIST weights

You need to run `jug/0_12_mnist_no_weight_decay.py` for this.

In [31]:
# df of MVT in MNIST: collect the last stored sample of every ".p" tensor from
# the 8 MNIST conv-net runs and stack them along a new leading axis.

mnist_weights = collections.defaultdict(list)
for run_idx in range(8):
    run_samples = load_samples(
        f"../logs/sgd-no-weight-decay/mnist_classificationconvnet/{run_idx}/samples.pt")
    for name in run_samples.keys():
        if not name.endswith(".p"):
            continue
        mnist_weights[name].append(run_samples[name][-1])
mnist_weights = {name: torch.stack(per_run, 0) for (name, per_run) in mnist_weights.items()}



In [27]:
# Keys of the MNIST network's weight tensors that are fitted below.
mnist_conv_keys = [
    'net.module.1.weight_prior.p',
    'net.module.4.weight_prior.p',
    'net.module.8.weight_prior.p',
]

In [28]:
# Grid search over the degrees of freedom of a multivariate-T for the MNIST
# weights, with a univariate Student-t fit as fallback.
# opt_mnist_dfs[key] = (best log-likelihood, df, (permute, event_dim)),
# or (None, df, None) when the fallback was used.
opt_mnist_dfs = {}

# 300 candidate values, log-spaced between 2.1 and 1000.
try_df_inits = torch.linspace(math.log(2.1), math.log(1000), 300).exp()

# Use the GPU when one is available. With a hard-coded `.cuda()`, a missing
# GPU could raise RuntimeError and silently trigger the fallback for every key.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for key in mnist_conv_keys:
    max_lik = -np.inf

    for permute, event_dim in [(None, 2), (None, 3), (None, 4), ((0, 2, 1, 3, 4), 3)]:
        try:
            mvt = MVTFitter(mnist_weights[key], 3., permute=permute, event_dim=event_dim).to(device)
            for df_init in try_df_inits:
                # Only the value is needed, so no autograd graph is built.
                with torch.no_grad():
                    mvt.dist.df[...] = df_init
                    lik = mvt.dist.log_prob().item()

                # Named `df_value` so that the runs DataFrame `df` defined
                # earlier in the notebook is not overwritten by a float.
                df_value = mvt.dist.df.item()
                if np.isnan(lik) or np.isnan(df_value):
                    print("key", key, "saw a nan with lik", lik)

                if lik > max_lik:
                    opt_mnist_dfs[key] = (lik, df_value, (permute, event_dim))
                    max_lik = lik
        except RuntimeError as err:
            # The multivariate fit is not possible for this configuration; fit a
            # univariate Student-t to all entries instead, and say so.
            # NOTE(review): this overwrites any multivariate result already
            # stored for `key` by an earlier configuration -- confirm intended.
            print("key", key, "falling back to a univariate Student-t fit:", err)
            t_params = scipy.stats.t.fit(mnist_weights[key].numpy())
            opt_mnist_dfs[key] = (None, t_params[0], None)

opt_mnist_dfs

{'net.module.1.weight_prior.p': (-1663.31383774297,
  107.83875426199431,
  (None, 2)),
 'net.module.4.weight_prior.p': (461357.6454892001,
  3.5897961780383376,
  (None, 2)),
 'net.module.8.weight_prior.p': (None, 1.4334627049114599, None)}

In [29]:
# df of a univariate Student-t fitted to the weights of the MNIST dense network.

fcnn_weights = collections.defaultdict(list)
# Run 5 is left out (the reason is not recorded in this notebook).
for run_idx in (idx for idx in range(10) if idx != 5):
    run_samples = load_samples(
        f"../logs/sgd-no-weight-decay/mnist_classificationdensenet/{run_idx}/samples.pt")
    for name in run_samples.keys():
        if not name.endswith("weight_prior.p"):
            continue
        fcnn_weights[name].append(run_samples[name][-1])
fcnn_weights = {name: torch.stack(per_run, 0) for (name, per_run) in fcnn_weights.items()}

# First fitted parameter (the degrees of freedom) of scipy's Student-t fit, per layer.
{name: scipy.stats.t.fit(stacked)[0] for name, stacked in fcnn_weights.items()}

{'net.module.0.weight_prior.p': 2.1691800464281084,
 'net.module.2.weight_prior.p': 6.292621228487064,
 'net.module.4.weight_prior.p': 11.825714788660914}