In [1]:
import sys
# Put the parent directory on the import path so the local `hamiltonian`
# package imported by later cells can be resolved from this notebook.
sys.path.append("../") 

In [2]:
import argparse, time, logging, random, math

import numpy as np
import mxnet as mx

from mxnet import gluon, nd
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms

In [3]:
# Per-sample preprocessing for the MNIST images: tensor conversion followed
# by Normalize(0., 1.) (positional arguments presumably mean and std — confirm).
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(0., 1.),
]
transform = transforms.Compose(_preprocessing_steps)

In [4]:
# Execution context: every model below is created on the default GPU.
num_gpus = 1
model_ctx = mx.gpu()

# Settings shared by the training and validation loaders.
num_workers = 4
batch_size = 256
_shared_loader_args = dict(batch_size=batch_size, num_workers=num_workers)

# Training split: shuffled, with last_batch='discard'.
_train_set = gluon.data.vision.MNIST(train=True).transform_first(transform)
train_data = gluon.data.DataLoader(
    _train_set, shuffle=True, last_batch='discard', **_shared_loader_args)

# Validation split: fixed order.
_val_set = gluon.data.vision.MNIST(train=False).transform_first(transform)
val_data = gluon.data.DataLoader(
    _val_set, shuffle=False, **_shared_loader_args)

In [5]:
# Sanity check: report the container type and the shapes of the first
# validation batch, then stop iterating.
for X, y in val_data:
    for batch_info in (type(X), X.shape, y.shape):
        print(batch_info)
    break

<class 'mxnet.ndarray.ndarray.NDArray'>
(256, 1, 28, 28)
(256,)


### Bayesian inference for MNIST

* [Stochastic Gradient Descent](#chapter1)
* [Stochastic Gradient Langevin Dynamics](#chapter2)
* [Hierarchical Softmax](#chapter3)
* [Diagnostics](#chapter4)


# Stochastic Gradient Descent <a class="anchor" id="chapter1"></a>

In [6]:
import mxnet as mx
from mxnet import nd, autograd, gluon

# Configuration shared by all models in this notebook: the hyper-parameter
# dict and the input / output dimensions passed to the model constructors.
hyper = dict(alpha=10.)
in_units, out_units = (28, 28), 10

In [7]:
import mxnet as mx
from hamiltonian.inference.sgd import sgd
from hamiltonian.models.softmax import softmax

# Build the softmax model on the configured context and wrap it in the
# project's SGD inference object (step size 0.001).
model=softmax(hyper,in_units,out_units,ctx=model_ctx)
inference=sgd(model,model.par,step_size=0.001,ctx=model_ctx)

In [8]:
import hamiltonian
import importlib

# Development helper: re-import the project modules used for the SGD section.
# NOTE: names bound earlier with `from ... import ...` (and objects built from
# them, such as `model` / `inference`) are not rebound by importlib.reload.
try:
    importlib.reload(hamiltonian.models.softmax)
    importlib.reload(hamiltonian.inference.sgd)
    print('modules re-loaded')
except AttributeError:
    # The submodules only become attributes of `hamiltonian` once imported.
    print('no modules loaded yet')
except Exception as err:
    # Stay best-effort, but report genuine reload failures instead of
    # mislabelling them; KeyboardInterrupt/SystemExit are no longer swallowed.
    print('module reload failed: {0!r}'.format(err))


modules re-loaded


In [9]:
# Display the parameters registered on the model's network.
model.net.collect_params()

{'1.weight': Parameter (shape=(10, 784), dtype=float32),
 '1.bias': Parameter (shape=(10,), dtype=float32)}

In [10]:
import matplotlib.pyplot as plt 

train_sgd=False
num_epochs=100
# Parameter file written after training and read back when training is skipped.
sgd_params_file='../scripts/results/softmax/softmax_sgd_'+str(num_epochs)+'_epochs.params'

if not train_sgd:
    # Reuse the stored weights and expose them as a {name: value} dict.
    model.net.load_parameters(sgd_params_file,ctx=model_ctx)
    par={name:gluon_par.data() for name,gluon_par in model.net.collect_params().items()}
else:
    par,loss=inference.fit(epochs=num_epochs,batch_size=batch_size,data_loader=train_data,verbose=True)

    # Plot the loss returned by fit and save the figure as a PDF.
    fig=plt.figure(figsize=[5,5])
    plt.plot(loss,color='blue',lw=3)
    for set_text,text in ((plt.xlabel,'Epoch'),(plt.ylabel,'Loss'),(plt.title,'SGD Softmax MNIST')):
        set_text(text, size=18)
    for set_ticks in (plt.xticks,plt.yticks):
        set_ticks(size=14)
    plt.savefig('sgd_softmax.pdf', bbox_inches='tight')
    model.net.save_parameters(sgd_params_file)
               


In [11]:
from sklearn.metrics import classification_report

# Run prediction on the validation loader with the fitted parameters.
prediction=inference.predict(par,batch_size=batch_size,num_samples=100,data_loader=val_data)
total_samples,total_labels,log_like=prediction

# Point prediction: the 0.5 quantile along axis 0 (presumably the sample
# axis — confirm), cast to integer class labels for the report.
y_hat=np.quantile(total_samples,.5,axis=0)
y_true=np.int32(total_labels)
y_pred=np.int32(y_hat)
print(classification_report(y_true,y_pred))

              precision    recall  f1-score   support

           0       0.95      0.97      0.96       979
           1       0.98      0.97      0.97      1133
           2       0.92      0.87      0.90      1030
           3       0.88      0.91      0.89      1008
           4       0.91      0.93      0.92       980
           5       0.84      0.87      0.85       890
           6       0.92      0.95      0.93       956
           7       0.90      0.92      0.91      1027
           8       0.87      0.86      0.86       973
           9       0.92      0.85      0.88      1008

    accuracy                           0.91      9984
   macro avg       0.91      0.91      0.91      9984
weighted avg       0.91      0.91      0.91      9984



# Stochastic Gradient Langevin Dynamics <a class="anchor" id="chapter2"></a>

In [257]:
from hamiltonian.inference.sgld import sgld

# Fresh softmax model wrapped in the project's SGLD inference object
# (step size 0.01); rebinds `model` and `inference` from the SGD section.
model=softmax(hyper,in_units,out_units,ctx=model_ctx)
inference=sgld(model,model.par,step_size=0.01,ctx=model_ctx)

In [258]:
import hamiltonian
import importlib

# Development helper: re-import the project modules used for the SGLD section.
# NOTE: names bound earlier with `from ... import ...` (and objects built from
# them, such as `model` / `inference`) are not rebound by importlib.reload.
try:
    importlib.reload(hamiltonian.models.softmax)
    importlib.reload(hamiltonian.inference.sgld)
    print('modules re-loaded')
except AttributeError:
    # The submodules only become attributes of `hamiltonian` once imported.
    print('no modules loaded yet')
except Exception as err:
    # Stay best-effort, but report genuine reload failures instead of
    # mislabelling them; KeyboardInterrupt/SystemExit are no longer swallowed.
    print('module reload failed: {0!r}'.format(err))

modules re-loaded


In [259]:
import matplotlib.pyplot as plt
import seaborn as sns
import glob

train_sgld=False
num_epochs=100

if not train_sgld:
    # Reuse stored chains: collect the paths matching each chain's pattern,
    # in sorted order, as one list per chain.
    chain1=sorted(glob.glob("../scripts/results/softmax/chain_nonhierarchical_0_1_sgld*"))
    chain2=sorted(glob.glob("../scripts/results/softmax/chain_nonhierarchical_0_sgld*"))
    posterior_samples=[chain1,chain2]
else:
    loss,posterior_samples=inference.sample(
        epochs=num_epochs,batch_size=batch_size,
        data_loader=train_data,
        verbose=True,chain_name='chain_nonhierarchical')

    plt.rcParams['figure.dpi'] = 360
    sns.set_style("whitegrid")
    fig=plt.figure(figsize=[5,5])
    # loss[0] and loss[1] are drawn in blue and red (presumably one trace per chain — confirm).
    for trace_idx,trace_color in enumerate(('blue','red')):
        plt.plot(loss[trace_idx],color=trace_color,lw=3)
    for set_text,text in ((plt.xlabel,'Epoch'),(plt.ylabel,'Loss'),(plt.title,'SGLD Softmax MNIST')):
        set_text(text, size=18)
    for set_ticks in (plt.xticks,plt.yticks):
        set_ticks(size=14)
    plt.savefig('sgld_softmax.pdf', bbox_inches='tight')

In [260]:
# Flatten the per-chain collections into one list, preserving chain order.
posterior_samples_flat=[]
for chain_items in posterior_samples:
    posterior_samples_flat.extend(chain_items)

In [261]:
# Run prediction on the validation loader from the flattened posterior samples.
# NOTE(review): the meaning of the positional `5` depends on sgld.predict's signature — confirm.
total_samples,total_labels,log_like=inference.predict(posterior_samples_flat,5,data_loader=val_data)

In [262]:
from sklearn.metrics import f1_score

# Macro-averaged F1 of the quantile-based point prediction, evaluated for
# each quantile level produced by np.arange(.35,.75,.1).
y_true=np.int32(total_labels)
quantile_levels=np.arange(.35,.75,.1)
score=[]
for q in quantile_levels:
    y_hat=np.quantile(total_samples,q,axis=0)
    score.append(f1_score(y_true,np.int32(y_hat), average='macro'))
# NOTE(review): the value printed under 'std f-1' is two standard deviations.
score_mean,score_spread=np.mean(score),2*np.std(score)
print('mean f-1 : {0}, std f-1 : {1}'.format(score_mean,score_spread))

mean f-1 : 0.9163494521048552, std f-1 : 0.009028858596384338


In [263]:
import tensorflow as tf
import tensorflow_probability as tfp

posterior_samples_multiple_chains=inference.posterior_diagnostics(posterior_samples)

# For every model variable: concatenate the per-chain arrays along axis 0,
# then swap the first two axes.
# presumably this yields (draws, chains, ...) as the tfp diagnostics expect — TODO confirm
samples={}
for var in model.par:
    per_chain=[chain[var] for chain in posterior_samples_multiple_chains]
    samples[var]=np.swapaxes(np.concatenate(per_chain),0,1)

def r_hat_estimate(samples):
    """Return tfp's split-chain potential scale reduction of `samples` as a NumPy array."""
    return tfp.mcmc.diagnostic.potential_scale_reduction(
        samples, independent_chain_ndims=1, split_chains=True).numpy()

# One number per variable: the median R-hat over all of its entries.
rhat={var:np.median(r_hat_estimate(samples[var])) for var in model.par}

In [264]:
# Per-variable median R-hat from the tfp diagnostic computed in the cell above.
rhat

{'1.weight': 3.9110374, '1.bias': 1.7014743}

In [265]:
def ess_estimate(samples):
    """Return tfp's cross-chain effective sample size of `samples` as a NumPy array."""
    return tfp.mcmc.diagnostic.effective_sample_size(
        samples, filter_beyond_positive_pairs=True, cross_chain_dims=1).numpy()

# One number per variable: the median ESS over all of its entries.
ess={}
for var in model.par:
    ess[var]=np.median(ess_estimate(samples[var]))

In [266]:
# Per-variable median effective sample size from the tfp diagnostic above.
ess

{'1.weight': 3.085601, '1.bias': 8.5339775}

In [269]:
# NOTE(review): star import; `h5py`, `potential_scale_reduction` and
# `effective_sample_size` used in the next cell presumably come from here — confirm.
from hamiltonian.utils.diagnostics import *

# Same diagnostics call as before, now with serialize=True
# (presumably writes 'posterior_samples.h5', which the next cell opens — confirm).
posterior_samples_multiple_chains=inference.posterior_diagnostics(posterior_samples,serialize=True)

In [270]:
import dask.array as da 

# NOTE(review): the HDF5 handle `df` is never closed in this notebook.
df=h5py.File('posterior_samples.h5','r')
variable_names=list(df.keys())

# Wrap every stored dataset as a dask array, then run the project's R-hat
# and ESS helpers on each one.
samples={var:da.from_array(df[var]) for var in variable_names}
r_hat_estimate={var:potential_scale_reduction(arr) for var,arr in samples.items()}
ess_estimate={var:effective_sample_size(arr) for var,arr in samples.items()}

In [271]:
# Collapse each variable's diagnostics to their median.
# Assumes the keys stored in the HDF5 file match the names in model.par.
ess={}
rhat={}
for var in model.par:
    ess[var]=np.median(ess_estimate[var])
    rhat[var]=np.median(r_hat_estimate[var])

In [272]:
# Per-variable median ESS from the project's own diagnostics helpers.
ess

{'1.weight': 1.6918622967224088, '1.bias': 6.755600823881087}

In [29]:
# Per-variable median R-hat from the project's own diagnostics helpers.
rhat

{'1.weight': 1.6774608, '1.bias': 1.106829}

In [30]:
import arviz as az

posterior_samples_multiple_chains=inference.posterior_diagnostics(posterior_samples)
# Convert each chain to ArviZ inference data, then merge along the "chain" dimension.
datasets=list(map(az.convert_to_inference_data,posterior_samples_multiple_chains))
dataset=az.concat(datasets,dim="chain")

In [31]:
# Compute each ArviZ diagnostic once for the whole dataset; previously
# az.rhat / az.ess / az.mcse were re-run for every variable in the loop.
_rhat_by_var=az.rhat(dataset)
_ess_by_var=az.ess(dataset)
_mcse_by_var=az.mcse(dataset)

# NOTE: despite the `mean_` prefix (kept because later cells use these names),
# each value is the median of the diagnostic over the variable's entries.
mean_r_hat_values={var:float(_rhat_by_var[var].median().data) for var in model.par}
mean_ess_values={var:float(_ess_by_var[var].median().data) for var in model.par}
mean_mcse_values={var:float(_mcse_by_var[var].median().data) for var in model.par}

In [32]:
# Per-variable median R-hat from ArviZ (the name says "mean", the value is a median).
print(mean_r_hat_values)

{'1.weight': 1.6655441421861479, '1.bias': 1.2943451036823195}


In [33]:
# Per-variable median ESS from ArviZ (the name says "mean", the value is a median).
mean_ess_values

{'1.weight': 3.5239685309484114, '1.bias': 5.9626626923120085}

In [None]:
# Per-variable median MCSE from ArviZ (the name says "mean", the value is a median).
mean_mcse_values

In [None]:
# NOTE(review): star import; `psisloo` presumably comes from here — confirm.
from hamiltonian.utils.psis import *

# Run the project's PSIS-LOO helper on the log-likelihood from the latest predict call.
# `ks` is plotted below as the per-point Pareto shape k.
loo,loos,ks=psisloo(log_like)

In [None]:
# Finite stand-in used in the next cell to replace infinite Pareto k values.
max_ks=5

In [None]:
# Replace infinite k values by max_ks. This mutates `ks` in place, so
# re-running the PSIS cell is required to recover the raw values.
ks[np.isinf(ks)]=max_ks

In [None]:
# Histogram of the (capped) Pareto k values.
plt.hist(ks)

In [None]:
# Weighted F1 of the quantile-based prediction where each point is weighted
# by 1 - clip(k, 0, 1), so points with k >= 1 receive zero weight.
y_true=np.int32(total_labels)
point_weights=1-np.clip(ks,0,1)
score=[]
for q in np.arange(.1,.9,.1):
    y_hat=np.quantile(total_samples,q,axis=0)
    score.append(f1_score(y_true,np.int32(y_hat), sample_weight=point_weights,average='weighted'))
# NOTE(review): the value printed under 'std f-1' is two standard deviations.
score_mean,score_spread=np.mean(score),2*np.std(score)
print('mean f-1 : {0}, std f-1 : {1}'.format(score_mean,score_spread))

In [None]:
plt.rcParams['figure.dpi'] = 360
sns.set_style("whitegrid")
fig=plt.figure(figsize=[5,5])

# Pareto k per data point, with a red horizontal reference line at k = 0.7.
point_index=list(range(len(ks)))
reference_level=0.7*(np.ones(len(ks)))
plt.scatter(point_index,ks)
plt.plot(point_index, reference_level, linestyle='-',color='red')

for set_text,text in ((plt.xlabel,'Data point'),(plt.ylabel,'Pareto shape k'),(plt.title,'Non-hierarchical model')):
    set_text(text, size=18)
plt.savefig('psis_sgld_softmax.pdf', bbox_inches='tight')

In [None]:
# Number of data points with Pareto k above 1.
np.sum(ks>1)

In [None]:
# Number of data points with Pareto k in [0.7, 1]. The previous expression
# summed the k values themselves, unlike the neighbouring cells, which count
# points; the bounds are inclusive so k == 0.7 and k == 1 fall into this bin.
np.sum(np.logical_and(ks>=0.7,ks<=1))

In [None]:
# Number of data points with Pareto k in [0.5, 0.7). The previous expression
# summed the k values themselves, unlike the neighbouring cells, which count
# points; the lower bound is inclusive so k == 0.5 falls into this bin.
np.sum(np.logical_and(ks>=0.5,ks<0.7))

In [None]:
# Number of data points with Pareto k below 0.5.
np.sum(ks<0.5)

# Hierarchical Softmax <a class="anchor" id="chapter3"></a>

In [34]:
from hamiltonian.models.softmax import hierarchical_softmax
from hamiltonian.inference.sgld import sgld

# Hierarchical softmax model wrapped in the project's SGLD inference object
# (step size 0.01); rebinds `model` and `inference` from the previous section.
model=hierarchical_softmax(hyper,in_units,out_units,ctx=model_ctx)
inference=sgld(model,model.par,step_size=0.01,ctx=model_ctx)

In [52]:
train_sgld=False
num_epochs=100

if not train_sgld:
    # Reuse stored chains: collect the paths matching each chain's pattern,
    # in sorted order, as one list per chain.
    chain1=sorted(glob.glob("../scripts/results/softmax/chain_hierarchical_0_1_sgld*"))
    chain2=sorted(glob.glob("../scripts/results/softmax/chain_hierarchical_0_sgld*"))
    posterior_samples=[chain1,chain2]
else:
    loss,posterior_samples=inference.sample(
        epochs=num_epochs,batch_size=batch_size,
        data_loader=train_data,
        verbose=True,chain_name='chain_hierarchical')

    plt.rcParams['figure.dpi'] = 360
    sns.set_style("whitegrid")
    fig=plt.figure(figsize=[5,5])
    # loss[0] and loss[1] are drawn in blue and red (presumably one trace per chain — confirm).
    for trace_idx,trace_color in enumerate(('blue','red')):
        plt.plot(loss[trace_idx],color=trace_color,lw=3)
    for set_text,text in ((plt.xlabel,'Epoch'),(plt.ylabel,'Loss'),(plt.title,'SGLD Hierarchical Softmax MNIST')):
        set_text(text, size=18)
    for set_ticks in (plt.xticks,plt.yticks):
        set_ticks(size=14)
    plt.savefig('sgld_hierarchical_softmax.pdf', bbox_inches='tight')

In [36]:
# Flatten the per-chain collections into one list, preserving chain order.
posterior_samples_flat=[]
for chain_items in posterior_samples:
    posterior_samples_flat.extend(chain_items)

In [37]:
# Run prediction on the validation loader from the flattened posterior samples.
# NOTE(review): the meaning of the positional `5` depends on sgld.predict's signature — confirm.
total_samples,total_labels,log_like=inference.predict(posterior_samples_flat,5,data_loader=val_data)

In [38]:
from sklearn.metrics import f1_score

# Micro-averaged F1 of the quantile-based point prediction, evaluated for
# each quantile level produced by np.arange(.35,.75,.1).
# NOTE(review): the corresponding non-hierarchical cell uses average='macro',
# so the two reported scores are not directly comparable.
y_true=np.int32(total_labels)
quantile_levels=np.arange(.35,.75,.1)
score=[]
for q in quantile_levels:
    y_hat=np.quantile(total_samples,q,axis=0)
    score.append(f1_score(y_true,np.int32(y_hat), average='micro'))
# NOTE(review): the value printed under 'std f-1' is two standard deviations.
score_mean,score_spread=np.mean(score),2*np.std(score)
print('mean f-1 : {0}, std f-1 : {1}'.format(score_mean,score_spread))

mean f-1 : 0.91585, std f-1 : 0.006073713855624048


In [39]:
import tensorflow as tf
import tensorflow_probability as tfp

posterior_samples_multiple_chains=inference.posterior_diagnostics(posterior_samples)

# For every model variable: concatenate the per-chain arrays along axis 0,
# then swap the first two axes.
# presumably this yields (draws, chains, ...) as the tfp diagnostics expect — TODO confirm
samples={}
for var in model.par:
    per_chain=[chain[var] for chain in posterior_samples_multiple_chains]
    samples[var]=np.swapaxes(np.concatenate(per_chain),0,1)

def r_hat_estimate(samples):
    """Return tfp's split-chain potential scale reduction of `samples` as a NumPy array."""
    return tfp.mcmc.diagnostic.potential_scale_reduction(
        samples, independent_chain_ndims=1, split_chains=True).numpy()

# One number per variable: the median R-hat over all of its entries.
rhat={var:np.median(r_hat_estimate(samples[var])) for var in model.par}

In [40]:
# Per-variable median R-hat from the tfp diagnostic computed in the cell above.
rhat

{'1.weight': 3.7939584, '1.bias': 1.1021233}

In [41]:
def ess_estimate(samples):
    """Return tfp's cross-chain effective sample size of `samples` as a NumPy array."""
    return tfp.mcmc.diagnostic.effective_sample_size(
        samples, filter_beyond_positive_pairs=True, cross_chain_dims=1).numpy()

# One number per variable: the median ESS over all of its entries.
ess={}
for var in model.par:
    ess[var]=np.median(ess_estimate(samples[var]))

In [45]:
# Per-variable median effective sample size from the tfp diagnostic above.
ess

{'1.weight': 3.1676974, '1.bias': 52.368546}

In [268]:
import os

# Remove the serialized samples left by a previous run. A missing file is
# fine: the shell `rm` used before reported an error in that case and only
# works where an `rm` command is available.
try:
    os.remove('posterior_samples.h5')
except FileNotFoundError:
    pass

In [250]:
# Re-import the project's diagnostics module (presumably after editing it on disk).
# NOTE(review): names pulled in via `from ... import *` are not rebound by
# reload; the star import has to be executed again afterwards.
importlib.reload(hamiltonian.utils.diagnostics)

<module 'hamiltonian.utils.diagnostics' from '/home/sergio/code/mxprob/benchmarks/../hamiltonian/utils/diagnostics.py'>

In [251]:
# NOTE(review): star import; `h5py`, `potential_scale_reduction` and
# `effective_sample_size` used in the next cell presumably come from here — confirm.
from hamiltonian.utils.diagnostics import *

# Diagnostics call with serialize=True
# (presumably writes 'posterior_samples.h5', which the next cell opens — confirm).
posterior_samples_multiple_chains=inference.posterior_diagnostics(posterior_samples,serialize=True)

In [252]:
import dask.array as da 

df=h5py.File('posterior_samples.h5','r')
samples={var:da.from_array(df[var]) for var in df.keys()}
r_hat_estimate={var:potential_scale_reduction(samples[var]) for var in samples.keys()}
ess_estimate={var:effective_sample_size(samples[var]) for var in samples.keys()}

In [253]:
# Collapse each variable's diagnostics to their median.
# Assumes the keys stored in the HDF5 file match the names in model.par.
ess={}
rhat={}
for var in model.par:
    ess[var]=np.median(ess_estimate[var])
    rhat[var]=np.median(r_hat_estimate[var])

In [254]:
# Per-variable median ESS from the project's own diagnostics helpers.
ess

{'1.weight': 1.7391584476858615, '1.bias': 34.996306436053935}

In [255]:
# Per-variable median R-hat from the project's own diagnostics helpers.
rhat

{'1.weight': 1.6387938, '1.bias': 1.0146148}

In [256]:
import arviz as az

# Convert each chain to ArviZ inference data and merge along the "chain" dimension.
posterior_samples_multiple_chains=inference.posterior_diagnostics(posterior_samples)
datasets=[az.convert_to_inference_data(sample) for sample in posterior_samples_multiple_chains]
dataset = az.concat(datasets, dim="chain")

# Compute each ArviZ diagnostic once for the whole dataset; previously
# az.rhat / az.ess / az.mcse were re-run for every variable in the loop.
_rhat_by_var=az.rhat(dataset)
_ess_by_var=az.ess(dataset)
_mcse_by_var=az.mcse(dataset)

# Per-variable mean of each diagnostic over the variable's entries.
mean_r_hat_values={var:float(_rhat_by_var[var].mean().data) for var in model.par}
mean_ess_values={var:float(_ess_by_var[var].mean().data) for var in model.par}
mean_mcse_values={var:float(_mcse_by_var[var].mean().data) for var in model.par}

In [248]:
# Per-entry ArviZ summary table for the combined chains.
az.summary(dataset)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
"1.weight[0,0]",-0.435,0.559,-1.392,0.534,0.233,0.174,6.0,54.0,1.28
"1.weight[0,1]",-0.538,0.454,-1.575,0.331,0.169,0.125,7.0,12.0,1.27
"1.weight[0,2]",-0.024,1.951,-2.739,2.450,1.267,1.037,3.0,12.0,1.89
"1.weight[0,3]",-0.790,0.887,-2.373,0.713,0.447,0.343,4.0,12.0,1.51
"1.weight[0,4]",-0.142,0.515,-0.970,0.821,0.230,0.173,5.0,21.0,1.35
...,...,...,...,...,...,...,...,...,...
1.bias[5],0.733,0.204,0.352,1.079,0.037,0.026,29.0,106.0,1.05
1.bias[6],-0.097,0.168,-0.413,0.209,0.025,0.021,47.0,104.0,1.04
1.bias[7],0.222,0.191,-0.121,0.554,0.020,0.014,87.0,75.0,1.06
1.bias[8],-0.577,0.197,-0.903,-0.223,0.031,0.022,39.0,131.0,1.08


In [67]:
# Per-variable mean R-hat from ArviZ for the hierarchical model.
print(mean_r_hat_values)

{'1.weight': 1.6711408591815198, '1.bias': 1.0529862404615036}


In [68]:
# Per-variable mean ESS from ArviZ for the hierarchical model.
print(mean_ess_values)

{'1.weight': 4.9645475053530435, '1.bias': 57.65651747515627}


In [None]:
# Per-variable mean MCSE from ArviZ for the hierarchical model.
print(mean_mcse_values)

In [None]:
# Run the project's PSIS-LOO helper on the log-likelihood from the latest predict call.
# `psisloo` is not imported in this cell; it relies on the earlier star import
# from hamiltonian.utils.psis. `ks` is plotted below as the Pareto shape k.
loo,loos,ks=psisloo(log_like)

In [None]:
# Macro F1 restricted to the data points whose Pareto k is above 0.7.
high_k=ks>0.7
y_true_high_k=np.int32(total_labels[high_k])
score=[]
for q in np.arange(.1,.9,.1):
    y_hat=np.quantile(total_samples,q,axis=0)
    score.append(f1_score(y_true_high_k,np.int32(y_hat[high_k]), average='macro'))
# NOTE(review): the value printed under 'std f-1' is two standard deviations.
score_mean,score_spread=np.mean(score),2*np.std(score)
print('mean f-1 : {0}, std f-1 : {1}'.format(score_mean,score_spread))

In [None]:
# Macro F1 restricted to the data points whose Pareto k is below 0.7.
low_k=ks<0.7
y_true_low_k=np.int32(total_labels[low_k])
score=[]
for q in np.arange(.1,.9,.1):
    y_hat=np.quantile(total_samples,q,axis=0)
    score.append(f1_score(y_true_low_k,np.int32(y_hat[low_k]), average='macro'))
# NOTE(review): the value printed under 'std f-1' is two standard deviations.
score_mean,score_spread=np.mean(score),2*np.std(score)
print('mean f-1 : {0}, std f-1 : {1}'.format(score_mean,score_spread))

In [None]:
# Weighted F1 of the quantile-based prediction where each point is weighted
# by 1 - clip(k, 0, 1), so points with k >= 1 receive zero weight.
y_true=np.int32(total_labels)
point_weights=1-np.clip(ks,0,1)
score=[]
for q in np.arange(.1,.9,.1):
    y_hat=np.quantile(total_samples,q,axis=0)
    score.append(f1_score(y_true,np.int32(y_hat), sample_weight=point_weights,average='weighted'))
# NOTE(review): the value printed under 'std f-1' is two standard deviations.
score_mean,score_spread=np.mean(score),2*np.std(score)
print('mean f-1 : {0}, std f-1 : {1}'.format(score_mean,score_spread))

In [None]:
plt.rcParams['figure.dpi'] = 360
sns.set_style("whitegrid")
fig=plt.figure(figsize=[5,5])

# Pareto k per data point, with a red horizontal reference line at k = 0.7.
point_index=list(range(len(ks)))
reference_level=0.7*(np.ones(len(ks)))
plt.scatter(point_index,ks)
plt.plot(point_index, reference_level, linestyle='-',color='red')

for set_text,text in ((plt.xlabel,'Data point'),(plt.ylabel,'Pareto shape k'),(plt.title,'Hierarchical model')):
    set_text(text, size=18)
plt.savefig('psis_sgld_hierarchical_softmax.pdf', bbox_inches='tight')

In [None]:
# Number of data points with Pareto k above 1.
np.sum(ks>1)

In [None]:
# Number of data points with Pareto k in [0.7, 1]. The previous expression
# summed the k values themselves, unlike the neighbouring cells, which count
# points; the bounds are inclusive so k == 0.7 and k == 1 fall into this bin.
np.sum(np.logical_and(ks>=0.7,ks<=1))

In [None]:
# Number of data points with Pareto k in [0.5, 0.7). The previous expression
# summed the k values themselves, unlike the neighbouring cells, which count
# points; the lower bound is inclusive so k == 0.5 falls into this bin.
np.sum(np.logical_and(ks>=0.5,ks<0.7))

In [None]:
# Number of data points with Pareto k below 0.5.
np.sum(ks<0.5)