In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

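# Plotting Code for Figure 1

Loads the CIFAR-10 MLP checkpoint and draws the three panels of Figure 1:
- **Left:** a histogram of $E[\theta]$ over all pruning-layer entries.
- **Middle:** the original log-uniform prior versus the reduced (truncated log-normal) prior.
- **Right:** a histogram of $E[\theta]$ restricted to the entries whose pruning mask is 1.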

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import numpy as np
from scipy.stats import lognorm
import math
from scipy.special import erf
from scipy.stats import norm
sns.set_style("white")
sns.set_palette(sns.color_palette("colorblind"))

from modules.classifiers import make_mlp
from modules.layers import LogUniformAllBMRPruningLayer
from modules.layers import gather_pruning_layers


def loguniform(low=0, high=1, size=None):
    """Sample values whose natural logarithm is uniform on [low, high).

    Draws ``u ~ Uniform(low, high)`` from NumPy's global RNG and returns
    ``exp(u)``, so ``low``/``high`` are bounds in log space and the samples
    lie in [e**low, e**high). ``size`` has the same meaning as in
    ``np.random.uniform``.
    """
    log_samples = np.random.uniform(low=low, high=high, size=size)
    return np.exp(log_samples)

def trunc_logu_pdf(x, a, b):
    """Density of a log-uniform variable X with log(X) ~ Uniform(a, b).

    The support is [e**a, e**b] and the density there is 1 / (x * (b - a)).
    ``x`` is not checked against the support; the formula is evaluated as-is
    (scalar or array ``x``).
    """
    log_width = b - a
    return 1 / (x * log_width)

def trunc_exp_pdf(x, a, b, lam):
    """Density of an exponential distribution with rate ``lam`` truncated to [a, b].

        f(x) = lam * exp(-lam * x) / (exp(-lam * a) - exp(-lam * b)),  a <= x <= b

    The leading ``lam`` is what makes the density integrate to 1 over [a, b]:
    the integral of exp(-lam * x) over [a, b] equals
    (exp(-lam * a) - exp(-lam * b)) / lam. ``b`` may be ``np.inf``, in which
    case this is the ordinary exponential density shifted to start at ``a``.
    ``x`` is not checked against [a, b]; the formula is evaluated as-is.
    """
    return lam * np.exp(-lam * x) / (np.exp(-lam * a) - np.exp(-lam * b))

def truncated_lognormal_pdf(x, mu=0, var=1, a=-20, b=0):
    """Density of X where log(X) ~ Normal(mu, var) truncated to [a, b].

    Equivalently, X is log-normal restricted to [e**a, e**b] and renormalized:

        f(x) = exp(-(log x - mu)**2 / (2 var)) / (Z * x * sqrt(2 pi var)),
        Z    = Phi((b - mu) / sigma) - Phi((a - mu) / sigma).

    Raises AssertionError if any entry of ``x`` lies outside [e**a, e**b].
    """
    lower, upper = np.exp(a), np.exp(b)
    assert np.all(x >= lower) and np.all(x <= upper)

    sig = np.sqrt(var)
    # Probability mass of the untruncated Normal(mu, var) inside [a, b].
    mass = norm.cdf((b - mu) / sig) - norm.cdf((a - mu) / sig)

    denom = mass * x * np.sqrt(2 * np.pi * var)
    gaussian_kernel = np.exp(-(np.log(x) - mu) ** 2 / (2 * var))
    return 1 / denom * gaussian_kernel


# Checkpoint of the CIFAR-10 MLP whose pruning layers are visualised below.
CHECKPOINT_PATH = 'latex/paper_items/cifar10_mlp_1000.pth'

# The architecture must match the one the checkpoint was saved from;
# load_state_dict (strict by default) raises on missing/unexpected keys.
model = make_mlp(
    in_dim=32*32*3,  # flattened 32x32x3 input
    out_dim=10,
    n_layers=5,
    hidden_dim=150,
    pruning_class=LogUniformAllBMRPruningLayer,
    enable_pruning=True,
)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. On torch >= 1.13, passing weights_only=True restricts loading to tensors.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu')))


In [None]:
# Figure 1. The left and right panels are histograms of E[theta] from the loaded
# model's pruning layers: all entries, and entries with mask == 1, respectively.
# The middle panel compares the original log-uniform prior with the reduced
# truncated log-normal prior.
FONT_SIZE = 40    # every piece of text in the figure
LINE_WIDTH = 8    # prior curves and KDE lines
SPINE_WIDTH = 3   # axes frame

# Set rc parameters *before* creating the figure. Matplotlib reads some of them
# (axes title/label font sizes, text.usetex) when the axes' Text objects are
# created, so setting them afterwards only affects text created later.
plt.rc('axes', titlesize=FONT_SIZE, labelsize=FONT_SIZE)
plt.rc('xtick', labelsize=FONT_SIZE)
plt.rc('ytick', labelsize=FONT_SIZE)
plt.rc('legend', fontsize=FONT_SIZE)
plt.rcParams['text.usetex'] = True  # requires a working LaTeX installation

fig, ax = plt.subplots(1, 3, figsize=(30, 7))


def _plot_theta_hist(axc, values, label):
    """Draw a 20-bin histogram with KDE of `values` on `axc`, then apply the
    styling shared by both histogram panels (fixed limits, no y ticks/label)."""
    sns.histplot(values, bins=20, kde=True, color='black', ax=axc,
                 line_kws={'linewidth': LINE_WIDTH}, label=label)
    axc.set_ylim((0., 30.))
    axc.legend()
    axc.set_yticks([], labels=[])
    axc.set_ylabel("")
    axc.set_xlim((0., 1.))


# --- Middle panel: original vs. reduced prior ---------------------------------
# Both densities are evaluated on [e^a, e^b]. The reduced prior is a log-normal
# with log-mean mu and log-std sigma, truncated to that interval.
a = -20
b = 0
x = np.linspace(np.exp(a), np.exp(b), 100)
y = trunc_logu_pdf(x, a, b)
mu = -20
sigma = 10
y2 = truncated_lognormal_pdf(x, mu, sigma**2, a, b)

sns.lineplot(x=x, y=y, color='black', label="Original Prior", linewidth=LINE_WIDTH, ax=ax[1])
sns.lineplot(x=x, y=y2, color='black', linestyle='--', label="Reduced Prior", linewidth=LINE_WIDTH, ax=ax[1])
logn_color = ax[1].lines[-1].get_color()  # colour of the reduced-prior curve

ax[1].set_ylim(0., 5.)
ax[1].fill_between(x, y2, color=logn_color, alpha=0.3)
ax[1].set_yticks([], labels=[])
ax[1].set_xlim((0., 1.))
ax[1].set_facecolor((0.082, 0.376, 0.510, 0.05))

# --- Outer panels: E[theta] of the pruning layers -----------------------------
layers = gather_pruning_layers(model)
Enoise = np.concatenate([layer.Etheta().detach().cpu().numpy() for layer in layers])
# Only the entries whose mask equals 1 (presumably those kept after pruning).
Enoise_pruned = np.concatenate(
    [layer.Etheta()[layer.mask == 1].detach().cpu().numpy() for layer in layers]
)

_plot_theta_hist(ax[0], Enoise, '$E[\\theta]$')
_plot_theta_hist(ax[2], Enoise_pruned, '$E[\\theta]$ BMRS$_{\\mathcal{N}}$')

# --- Shared x ticks and frame --------------------------------------------------
# The exponent needs braces: "$e^-20$" would superscript only the minus sign.
for axc in ax:
    axc.set_xticks([np.exp(a), 0.5, 1], labels=[f"$e^{{{a}}}$", "0.5", "1"])
    for axis in ['bottom', 'left', 'top', 'right']:
        axc.spines[axis].set_linewidth(SPINE_WIDTH)

fig.tight_layout(w_pad=20.0)
fig.savefig('./latex/figures/figure1.png')