In [1]:
import os
from glob import glob
import pandas as pd
import numpy as np
from scipy import stats
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import gridspec
import json
import torch
import gpytorch
import h5py
import collections
import scipy
import torch
import math
import seaborn as sns

from bnn_priors import prior

from bnn_priors.exp_utils import load_samples

%matplotlib inline
%config InlineBackend.print_figure_kwargs = {'bbox_inches':None}

In [None]:
# Load the precomputed summaries used for plotting. The plotting code indexes
# this object as mean_covs[key][1] and expects a 9x9 array there (each row is
# reshaped to 3x3).
# NOTE(review): this path differs from the "3.4.1_mean_covs.pkl.gz" file written
# at the end of the notebook -- presumably the file was moved/renamed; confirm.
# Unpickling executes arbitrary code, so only load files you trust.
mean_covs = pd.read_pickle(
    "Plot_MNIST_convnet_covariances_data/mean_covs.pkl.gz"
)

# Plot Figure 4

In [None]:
# --- Global styling -------------------------------------------------------
sns.set(context="paper", style="white", font_scale=1.0)
# Put DejaVu Sans at the front of the sans-serif fallback list.
plt.rcParams["font.sans-serif"].insert(0, "DejaVu Sans")
plt.rcParams.update({
    "font.family": "sans-serif",  # resolve text through the sans-serif list above
    "text.usetex": False,         # render text with matplotlib itself, not external LaTeX
    "pgf.rcfonts": True,          # pgf backend sets up its fonts from these rc parameters
    "font.size": 10,
    "axes.linewidth": 0.5,
    'ytick.major.width': 0.5,
    'ytick.major.size': 0,        # zero-length tick marks: only the tick labels are visible
    'xtick.major.width': 0.5,
    'xtick.major.size': 0,
    "figure.dpi": 300,
})

# --- Figure geometry ------------------------------------------------------
# Target width in TeX points (presumably the column width of the paper -- TODO confirm).
fig_width_pt = 234.8775
inches_per_pt = 1.0/72.27               # Convert pt to inches
fig_width = fig_width_pt*inches_per_pt  # width in inches

# Symmetric colour scale shared by every panel and by the colourbar.
cov_lim = 0.27
norm = Normalize(-cov_lim, cov_lim)

# One entry per 3x3 group of panels, drawn left to right:
# (key into mean_covs, group title, factor applied before plotting, label the y axis?)
# NOTE(review): the origin of the factor 64 for the second layer is not visible
# here -- presumably it brings that layer onto the shared colour scale; confirm.
layers = [
    ("net.module.1.weight_prior.p", "Layer 1 covariance", 1, True),
    ("net.module.4.weight_prior.p", "Layer 2 covariance", 64, False),
]

# All sizes below are fractions of the figure WIDTH until the rescaling step.
margins = dict(
    left=0.04,
    right=0.1,
    top=0.08,
    bottom=0.05)

plots_x = len(layers)   # number of 3x3 panel groups placed side by side
wsep = hsep = 0.015     # gap between neighbouring panels inside a group
w_cov_sep = 0.04        # extra gap between two groups
cbar_width = 0.03
cbar_wsep = 0.01        # gap between the last group and the colourbar
# Panels are square: whatever width is left after margins, gaps and colourbar
# is shared equally by the 3 * plots_x panel columns.
height = width = (1 - w_cov_sep*plots_x - wsep*3*plots_x - cbar_wsep - cbar_width
         - margins['left'] - margins['right'])/plots_x / 3
ttl_marg = 10           # title padding passed to set_title(pad=...)

# Total figure height expressed as a multiple of the figure width.
fig_height_mult = (margins['bottom'] + height*3 + hsep*2 + margins['top'])

# Vertical quantities were measured in units of the figure width; convert them
# to fractions of the figure HEIGHT so the panels stay square on the page.
hsep /= fig_height_mult
height /= fig_height_mult
margins['bottom'] /= fig_height_mult
margins['top'] /= fig_height_mult

fig = plt.figure(figsize=(fig_width, fig_width *fig_height_mult))

cbar_height = height*3 + hsep*2

# --- Panels ---------------------------------------------------------------
# Panel (x, y) of a group shows row y*3+x of that layer's 9x9 covariance,
# reshaped to 3x3; a white 'x' marks position (x+1, y+1) inside the panel,
# i.e. the kernel entry the row belongs to.
for group, (key, title, cov_scale, label_rows) in enumerate(layers):
    # Left edge of this group; for group 0 this is exactly margins['left'].
    group_left = margins['left'] + (width+wsep)*3*group + w_cov_sep*group

    for y in range(3):
        for x in range(3):
            bottom = margins['bottom'] + (height + hsep) * (2-y)
            left = group_left + (width +wsep) * x

            # Tick labels only on the outer edge: the bottom row of every group,
            # and the first column of groups that label their rows.
            yticks = [1, 2, 3] if (label_rows and x == 0) else []
            xticks = [1, 2, 3] if y == 2 else []

            ax = fig.add_axes([left, bottom, width, height], xticks=xticks, yticks=yticks)
            # extent puts pixel centres at 1, 2, 3 with the y axis pointing down.
            mappable = ax.imshow(mean_covs[key][1][y*3+x, :].reshape((3, 3)) * cov_scale,
                                 cmap=plt.get_cmap('RdBu'),
                                 extent=[0.5, 3.5, 3.5, 0.5], norm=norm)
            ax.plot([x+1], [y+1], marker='x', ls='none', color='white', markersize=3)

            # The group title sits above the top-centre panel.
            if y == 0 and x == 1:
                ttl = ax.set_title(title, pad=ttl_marg)

# --- Colourbar and output -------------------------------------------------
# The colourbar sits to the right of the last group and spans all three rows.
# `mappable` is the last image drawn; every image shares `norm` and the colormap.
cbar_ax = fig.add_axes([margins['left'] + (width+wsep)*3*plots_x + w_cov_sep*(plots_x - 1) + cbar_wsep,
                        margins['bottom'], cbar_width, cbar_height])
fig.colorbar(mappable, cax=cbar_ax,
             ticks=[-cov_lim, -0.15, 0, 0.15, cov_lim])
fig.savefig("../figures/210126-mnist-covariances-all.pdf")

# Load weights of the MNIST network that doesn't have batchnorm

In [4]:
# Sub-directories "0" ... "7" under the log folder (presumably one per
# training run -- TODO confirm).
directories = [str(run) for run in range(8)]
# Maps parameter name -> list with one entry per directory.
samples = collections.defaultdict(list)
param_keys = None

for d in directories:
    samples_path = f"../logs/sgd-no-weight-decay/mnist_classificationconvnet/{d}/samples.pt"
    # Despite the .pt suffix the file is opened as HDF5.
    with h5py.File(samples_path, "r") as f:
        # The parameter names (datasets ending in ".p") are read from the first
        # file only and reused for every later file.
        if param_keys is None:
            param_keys = [name for name in f.keys() if name.endswith(".p")]

        # Keep only the last entry stored in each dataset.
        for key in param_keys:
            samples[key].append(f[key][-1])

In [5]:
# Replace each per-directory list with a single array that has the directories
# along a new leading axis. No keys are added or removed, only values swapped.
for k, per_dir in list(samples.items()):
    samples[k] = np.stack(per_dir)

In [6]:
# Display which parameter datasets were collected (rendered as the cell output).
samples.keys()

dict_keys(['net.module.1.bias_prior.p', 'net.module.1.weight_prior.p', 'net.module.4.bias_prior.p', 'net.module.4.weight_prior.p', 'net.module.8.bias_prior.p', 'net.module.8.weight_prior.p'])

In [7]:
# Parameters whose values are flattened into rows of 9 entries (one row per
# 3x3 slice -- assumes the trailing two axes are 3x3; TODO confirm) and
# summarised by a 9-vector mean and a 9x9 covariance.
conv_weight_keys = ("net.module.1.weight_prior.p", "net.module.4.weight_prior.p")

samples_reshaped = {}
mean_covs = {}

for k, values in samples.items():
    if k not in conv_weight_keys:
        # Every other parameter is kept as-is and summarised by one scalar
        # mean and one scalar variance taken over all of its entries.
        samples_reshaped[k] = values
        mean_covs[k] = (np.mean(values), np.var(values))
        continue

    flat = values.reshape((-1, 9))
    samples_reshaped[k] = flat
    # rowvar=False: each column is a variable, each row an observation.
    mean_covs[k] = (np.mean(flat, 0), np.cov(flat, rowvar=False))

In [8]:
# Persist the per-parameter (mean, covariance-or-variance) tuples computed above.
# NOTE(review): the plotting cell reads
# "Plot_MNIST_convnet_covariances_data/mean_covs.pkl.gz", not this path --
# presumably the file is moved/renamed by hand afterwards; confirm.
pd.to_pickle(
    mean_covs,
    "3.4.1_mean_covs.pkl.gz",
)