In [1]:
#import modules and set up environment 
import os
import sys
path = "../../src/"

sys.path.append(path)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import jax.numpy as jnp
from jax import jit, grad, vmap, hessian, scipy, random

In [34]:
#import sgmcmc code 
import models.logistic_regression.logistic_regression as lr
import samplers.sgld as sgld
import samplers.sgldps as sgldps
import samplers.sgldcv as sgldcv
import samplers.sgldcvps as sgldcvps
import samplers.sgd as sgd

key = random.PRNGKey(0)

### Set up model and data

In [3]:
#load in the data
file_path = "../../data/synthetic/toy_lr_balance_train_synth.csv"
data = pd.read_csv(file_path)
dat_array = data.values[:]
x = dat_array[:, 1:]
y = dat_array[:,0]

In [4]:
#set up model parameters
N = x.shape[0]
dim = x.shape[1] 

#priors
mu_0 = np.zeros(dim) #prior mean
lambda_0 = 10.0*np.eye(dim)  #prior covariance matrix

In [5]:
#set up model-specific gradient functions
gradf_0 = lr.gradf_0
gradf_i_batch = lr.gradf_i_batch
post_var = lr.post_var

### Find mode using SGD

In [6]:
#hyperparameters for SGD
n_sgd = int(0.01*N)
Niter_sgd = 10**4
step_sgd = 1e-3

In [7]:
#draw theta_0 from the prior
theta_start = jnp.zeros(dim) #random starting value
print("Initalising SGD to find mode at theta_0:\n", theta_start)

Initalising SGD to find mode at theta_0:
 [0. 0. 0. 0. 0.]


In [8]:
theta_hat, samples_sgd, run_time = sgd.adam(key, gradf_0, gradf_i_batch, Niter_sgd, theta_start, x, y, n_sgd, step_sgd, replacement=True)

In [9]:
theta_hat

array([-0.9348976 , -0.41609308, -1.22440743,  0.36571023, -0.8317759 ])

In [10]:
# draw candidate thetas
n_candidate = 10
sigma_hat = post_var(theta_hat, x,y)[0]
theta_candidates = random.multivariate_normal(key=key, mean = theta_hat, cov= sigma_hat, shape=(n_candidate,))
theta_candidates

DeviceArray([[-0.9683532 , -0.49170724, -1.3405797 ,  0.38402447,
              -0.9073683 ],
             [-1.0635264 , -0.4388886 , -1.2785772 ,  0.38747898,
              -0.8468758 ],
             [-0.9769268 , -0.45872605, -1.2457107 ,  0.39358127,
              -0.93533677],
             [-1.3089921 , -0.34126902, -1.2269322 ,  0.47543913,
              -0.89157206],
             [-0.9981764 , -0.38148248, -1.1538057 ,  0.35715896,
              -0.9293198 ],
             [-0.66766846, -0.51894903, -1.1942793 ,  0.21381217,
              -0.9099646 ],
             [-0.83467466, -0.40973818, -1.1823505 ,  0.3357298 ,
              -0.76768553],
             [-0.9066595 , -0.4267365 , -1.1557379 ,  0.4286306 ,
              -0.7331321 ],
             [-1.0963364 , -0.3834407 , -1.2176023 ,  0.45045567,
              -0.8945911 ],
             [-0.989809  , -0.49958813, -1.3074896 ,  0.32732075,
              -0.96823853]], dtype=float32)

### Set up experiment 

In [11]:
batch_sizes = jnp.int32(jnp.linspace(0.01, 0.99, num=50)*N)

In [12]:
pseudo_var = dict()
reps = 1000

### SGLD gradients (with replacement)

In [13]:
pseudo_var['sgld-wr'] = np.zeros((n_candidate, batch_sizes.shape[0])) 
#sgld (wr) gradients
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgld_gradients1 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgld.sgld_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, x, y, n, replacement=True)
            sgld_gradients1[k] = grad
        pseudo_var['sgld-wr'][i,j] = np.sum(np.var(sgld_gradients1, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [04:50<00:00, 29.07s/it]


### SGLD gradients (without replacement)

In [14]:
pseudo_var['sgld-wor'] = np.zeros((n_candidate, batch_sizes.shape[0])) 
#sgld (wor) gradients
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgld_gradients2 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgld.sgld_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, x, y, n, replacement=False)
            sgld_gradients2[k] = grad
        pseudo_var['sgld-wor'][i,j] = np.sum(np.var(sgld_gradients2, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [13:08<00:00, 78.89s/it]


### SGLD-PS (exact probabilities)

In [15]:
#sgld-ps (exact) gradients
pseudo_var['sgld-ps-exact'] = np.zeros((n_candidate, batch_sizes.shape[0])) 

In [16]:
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    probs1 = sgldps.exact_probs(theta_candidate, gradf_i_batch, x, y)
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldps_gradients1 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldps.sgldps_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, probs1, x, y, n)
            sgldps_gradients1[k] = grad
        pseudo_var['sgld-ps-exact'][i,j] = np.sum(np.var(sgldps_gradients1, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [13:43<00:00, 82.34s/it]


### SGLD-PS (approximate probabilities)

In [17]:
#sgld-ps (approx) gradients
f_i_grad_list = sgldps.ps_preliminaries(theta_hat, gradf_i_batch, x, y)
probs2 = sgldps.approx_probs(theta_hat, f_i_grad_list)
pseudo_var['sgld-ps-approx'] = np.zeros((n_candidate, batch_sizes.shape[0])) 

In [18]:
for i in tqdm(range(n_candidate), desc = " Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldps_gradients2 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldps.sgldps_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, probs2, x, y, n)
            sgldps_gradients2[k] = grad
        pseudo_var['sgld-ps-approx'][i, j] = np.sum(np.var(sgldps_gradients2, axis = 0))

 Theta candidates tested: 100%|██████████| 10/10 [12:58<00:00, 77.88s/it]


### SGLD-CV gradients (with replacement)

In [19]:
gradf_0_hat, grad_full_hat, f_i_grad_list = sgldcv.cv_preliminaries(theta_hat, gradf_0, gradf_i_batch, x, y)

In [20]:
pseudo_var['sgld-cv-wr'] = np.zeros((n_candidate, batch_sizes.shape[0])) 

#sgldcv gradients
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldcv_gradients1 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldcv.sgld_cv_grad(subkey, theta_candidate, theta_hat, gradf_0, gradf_i_batch, grad_full_hat, f_i_grad_list, x, y, n, replacement=True)
            sgldcv_gradients1[k] = grad
        pseudo_var['sgld-cv-wr'][i, j] = np.sum(np.var(sgldcv_gradients1, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [08:34<00:00, 51.40s/it]


### SGLD-CV gradients (without replacement)

In [21]:
pseudo_var['sgld-cv-wor'] = np.zeros((n_candidate, batch_sizes.shape[0])) 

#sgldcv gradients
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldcv_gradients2 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldcv.sgld_cv_grad(subkey, theta_candidate, theta_hat, gradf_0, gradf_i_batch, grad_full_hat, f_i_grad_list, x, y, n, replacement=False)
            sgldcv_gradients2[k] = grad
        pseudo_var['sgld-cv-wor'][i, j] = np.sum(np.var(sgldcv_gradients2, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [17:23<00:00, 104.32s/it]


### SGLD-CV-PS (exact probabilities) gradients

In [22]:
cov_mat = sgldcvps.cvps_preliminaries(theta_hat, gradf_0, gradf_i_batch, post_var, x, y)[0]
f_i_hess_list = sgldcvps.cvps_preliminaries(theta_hat, gradf_0, gradf_i_batch, post_var, x, y)[4]

In [23]:
#sgld-cv-ps (exact) gradients
pseudo_var['sgld-cv-exact'] = np.zeros((n_candidate, batch_sizes.shape[0])) 

for i in tqdm(range(n_candidate), desc = " Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    probs3 = sgldcvps.exact_probs(theta_candidate, theta_hat, gradf_i_batch, f_i_grad_list, x, y)
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldcvps_gradients1 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldcvps.sgldcv_ps_grad(subkey, theta_candidate, theta_hat, gradf_0, gradf_i_batch, grad_full_hat, f_i_grad_list, probs3, x, y, n)
            sgldcvps_gradients1[k] = grad
        pseudo_var['sgld-cv-exact'][i,j] = np.sum(np.var(sgldcvps_gradients1, axis = 0))

 Theta candidates tested: 100%|██████████| 10/10 [17:38<00:00, 105.81s/it]


### SGLD-CV-PS (approximate probabilities) gradients

In [39]:
#sgldps (approx) gradients
probs4 = np.zeros(N)
for i in range(N):
    #iterate over datapoints
    hess_mat_i = f_i_hess_list[i]
    p = jnp.linalg.norm(hess_mat_i, ord='fro')
    probs4[i] = np.array(p)
probs4 /= np.sum(probs4)

In [41]:
pseudo_var['sgld-cv-approx'] = np.zeros((n_candidate, batch_sizes.shape[0])) 
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldcvps_gradients2 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldcvps.sgldcv_ps_grad(subkey, theta_candidate, theta_hat, gradf_0, gradf_i_batch, grad_full_hat, f_i_grad_list, probs4, x, y, n)
            sgldcvps_gradients2[k] = grad
        pseudo_var['sgld-cv-approx'][i,j] = np.sum(np.var(sgldcvps_gradients2, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [17:34<00:00, 105.49s/it]


### Save results 

In [42]:
plot_data = []
for key in pseudo_var:
    smooth_path = pseudo_var[key].mean(axis = 0)
    plot_data.append(smooth_path)
    path_deviation = 2*pseudo_var[key].std(axis = 0)
    under_line = (smooth_path-path_deviation)
    plot_data.append(under_line)
    over_line = (smooth_path+path_deviation)
    plot_data.append(over_line) 
    
df = pd.DataFrame(plot_data).T
df['proportion'] = np.linspace(0.01, 0.99, num = 50)

In [43]:
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,15,16,17,18,19,20,21,22,23,proportion
0,170490.580331,160802.218749,180178.941913,166706.518603,152004.181418,181408.855787,104629.999363,95821.112848,113438.885879,104843.718534,...,793.231907,-36.594398,1623.058212,316.515017,4.355582,628.674451,541.70243,-61.140797,1144.545658,0.01
1,55183.369661,52245.001658,58121.737663,53942.583206,51601.848547,56283.317864,33886.538382,31695.712576,36077.364188,34869.870903,...,261.316534,-7.772608,530.405677,108.022473,-0.152038,216.196985,179.839032,-21.721807,381.399872,0.03
2,33934.767586,32224.296546,35645.238625,31854.337719,30390.057741,33318.617696,20633.043827,19604.591111,21661.496542,21002.464101,...,153.153671,-8.121381,314.428723,64.284284,0.501132,128.067437,107.629971,-12.134174,227.394117,0.05
3,24125.367373,22531.546303,25719.188443,22486.214203,21539.409234,23433.019173,15075.246901,13812.151144,16338.342657,14966.020616,...,107.142677,-4.052652,218.338005,45.468897,0.137479,90.800315,76.80131,-9.323937,162.926557,0.07
4,18825.956775,17753.767556,19898.145993,16821.976506,16160.904616,17483.048395,11479.515283,10805.00461,12154.025956,11665.116064,...,79.495818,-3.099064,162.090699,35.032735,1.068599,68.99687,59.756524,-7.818499,127.331546,0.09
5,15406.481258,14680.711826,16132.250691,13642.439427,13056.501489,14228.377364,9374.412948,8457.329047,10291.496848,9474.046802,...,65.476949,-3.323894,134.277791,29.280451,-0.602628,59.16353,49.495787,-5.314161,104.305735,0.11
6,12923.629407,12335.988838,13511.269976,11377.963182,10697.365622,12058.560742,8015.675236,7459.494042,8571.856431,8063.851225,...,53.765469,-1.415229,108.946168,24.523121,0.215919,48.830323,41.856987,-4.779717,88.493691,0.13
7,11156.481943,10517.301005,11795.662882,9495.138905,9097.380969,9892.896842,6855.42751,6255.849775,7455.005244,7002.810176,...,45.27442,-2.169149,92.717988,21.229922,0.519855,41.93999,35.851748,-3.417355,75.120851,0.15
8,9919.882286,9565.016112,10274.748461,8149.212034,7661.062703,8637.361366,6080.84136,5678.782285,6482.900434,6046.706292,...,38.569367,-1.058916,78.197651,18.81496,0.479404,37.150517,32.015084,-4.509483,68.539652,0.17
9,8669.92941,8233.136377,9106.722443,7100.658098,6761.866091,7439.450104,5439.032839,5042.985411,5835.080267,5517.912792,...,34.294633,-1.12275,69.712017,17.097091,-0.262347,34.45653,28.682077,-3.692017,61.05617,0.19


In [44]:
path_out = "./out/toy_lrb_gradient_comp.csv"
df.to_csv(path_out, index = False) #save csv