In [1]:
#import modules and set up environment 
import os
import sys
path = "../../src/"

sys.path.append(path)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import jax.numpy as jnp
from jax import jit, grad, vmap, hessian, scipy, random

#import sgmcmc code 
import models.bivariate_gaussian.bivariate_gaussian as bvg
import samplers.sgld as sgld
import samplers.sgldps as sgldps
import samplers.sgldcv as sgldcv
import samplers.sgldcvps as sgldcvps

key = random.PRNGKey(0)



### Set up model and data

In [2]:
#load in the data
file_path = "../../data/synthetic/toy_bvg_synth.csv"
data = pd.read_csv(file_path)
dat_array = data.values[:]
x = np.array(dat_array)
N = dat_array.shape[0]
y = np.array([None]*N) 
y.shape

(1000,)

In [3]:
#set up model parameters
dim = 2 
theta_true = jnp.array([0., 1.]) #data mean
xbar = jnp.mean(x, axis = 0)
sigma_x = jnp.array([[ 2*10**5 , -4*10**4], [-4*10**4,  10**4]]) #data covariance matrix
sigma_x_inv = jnp.linalg.inv(sigma_x) #data precision 
mu_0 = jnp.array([0., 0.]) #prior mean
lambda_0 = jnp.array([[ 1000 , 0.], [0.,  1000]])  #prior scale
lambda_0_inv = jnp.linalg.inv(lambda_0)

#posterior 
lambda_1_inv = lambda_0_inv + N*sigma_x_inv #posterior precision
lambda_1 = jnp.linalg.inv(lambda_1_inv) #posterior covariance
mu_1 = jnp.matmul(lambda_1, N*jnp.matmul(sigma_x_inv, xbar) + jnp.matmul(lambda_0_inv, mu_0)) #posterior mean

In [4]:
#define the mode
theta_hat = mu_1 

# draw candidate thetas
n_candidate = 10
theta_candidates = random.multivariate_normal(key=key, mean=mu_1, cov=lambda_1, shape=(n_candidate,))
theta_candidates

DeviceArray([[ -1.12365294,  10.89599973],
             [ 13.91596074,  11.07020949],
             [  1.14693132,  14.07728093],
             [-16.36011229,  17.51262531],
             [  5.18900029,  10.08220814],
             [ -5.58482982,  12.4932415 ],
             [ -2.16542207,  15.96377277],
             [ -6.88311986,  17.80447463],
             [ -5.7467338 ,  13.21220418],
             [-11.68235157,  15.75879915]], dtype=float64)

### Set up experiment 

In [5]:
batch_sizes = jnp.int32(jnp.linspace(0.01, 0.99, num=50)*N)

In [6]:
#set up model-specific gradient functions
gradf_0 = bvg.gradf_0
gradf_i_batch = bvg.gradf_i_batch
post_var = bvg.post_var

In [7]:
pseudo_var = dict()
reps = 1000

### SGLD gradients (with replacement)

In [8]:
pseudo_var['sgld-wr'] = np.zeros((n_candidate, batch_sizes.shape[0])) 
#sgld (wr) gradients
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgld_gradients1 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgld.sgld_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, x, y, n, replacement = True)
            sgld_gradients1[k] = grad
        pseudo_var['sgld-wr'][i,j] = np.sum(np.var(sgld_gradients1, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [06:09<00:00, 36.97s/it]


### SGLD gradients (without replacement)

In [9]:
pseudo_var['sgld-wor'] = np.zeros((n_candidate, batch_sizes.shape[0])) 
#sgld (wor) gradients
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgld_gradients2 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgld.sgld_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, x, y, n,replacement = False)
            sgld_gradients2[k] = grad
        pseudo_var['sgld-wor'][i,j] = np.sum(np.var(sgld_gradients2, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [15:09<00:00, 90.94s/it]


### SGLD-PS (exact probabilities)

In [10]:
#sgld-ps (exact) gradients
pseudo_var['sgld-ps-exact'] = np.zeros((n_candidate, batch_sizes.shape[0])) 

In [12]:
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    probs1 = sgldps.exact_probs(theta_candidate, gradf_i_batch, x, y)
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldps_gradients1 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldps.sgldps_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, probs1, x, y, n)
            sgldps_gradients1[k] = grad
        pseudo_var['sgld-ps-exact'][i,j] = np.sum(np.var(sgldps_gradients1, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [17:15<00:00, 103.52s/it]


### SGLD-PS (approx)

In [13]:
#sgld-ps (approx) gradients
f_i_grad_list = sgldps.ps_preliminaries(theta_hat, gradf_i_batch, x, y)
probs2 = sgldps.approx_probs(theta_hat, f_i_grad_list)
pseudo_var['sgld-ps-approx'] = np.zeros((n_candidate, batch_sizes.shape[0])) 

In [14]:
for i in tqdm(range(n_candidate), desc = "Theta candidates tested"):
    theta_candidate = theta_candidates[i]
    for j in range(batch_sizes.shape[0]):
        n = batch_sizes[j]
        sgldps_gradients2 = np.zeros((reps, dim))
        for k in range(reps):
            key, subkey = random.split(key)
            grad = sgldps.sgldps_grad(subkey, theta_candidate, gradf_0, gradf_i_batch, probs2, x, y, n)
            sgldps_gradients2[k] = grad
        pseudo_var['sgld-ps-approx'][i, j] = np.sum(np.var(sgldps_gradients2, axis = 0))

Theta candidates tested: 100%|██████████| 10/10 [15:57<00:00, 95.76s/it]


### Save results 

In [15]:
plot_data = []
for key in pseudo_var:
    smooth_path = pseudo_var[key].mean(axis = 0)
    plot_data.append(smooth_path)
    path_deviation = 2*pseudo_var[key].std(axis = 0)
    under_line = (smooth_path-path_deviation)
    plot_data.append(under_line)
    over_line = (smooth_path+path_deviation)
    plot_data.append(over_line) 
    
df = pd.DataFrame(plot_data).T
df['proportion'] = np.linspace(0.01, 0.99, num = 50)

In [16]:
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,proportion
0,42866420000.0,39101500000.0,46631350000.0,43336660000.0,39592710000.0,47080620000.0,29493970000.0,27824440000.0,31163500000.0,29410730000.0,27055410000.0,31766060000.0,0.01
1,14264130000.0,13283030000.0,15245240000.0,14515060000.0,13799580000.0,15230540000.0,9588912000.0,8971133000.0,10206690000.0,9804989000.0,8963483000.0,10646500000.0,0.03
2,8874715000.0,8178782000.0,9570647000.0,8177979000.0,7705573000.0,8650384000.0,5914396000.0,5533784000.0,6295007000.0,5967556000.0,5621921000.0,6313192000.0,0.05
3,6415751000.0,6060470000.0,6771032000.0,5896994000.0,5515039000.0,6278949000.0,4266415000.0,3877600000.0,4655230000.0,4284063000.0,4070769000.0,4497357000.0,0.07
4,4781324000.0,4437021000.0,5125628000.0,4424186000.0,4063527000.0,4784845000.0,3232648000.0,3090673000.0,3374624000.0,3387598000.0,3177728000.0,3597469000.0,0.09
5,3943642000.0,3572115000.0,4315170000.0,3500167000.0,3246913000.0,3753421000.0,2743706000.0,2581850000.0,2905561000.0,2651712000.0,2475610000.0,2827814000.0,0.11
6,3367711000.0,3107039000.0,3628383000.0,2896790000.0,2711896000.0,3081684000.0,2223470000.0,2105141000.0,2341800000.0,2247786000.0,2094718000.0,2400853000.0,0.13
7,2866736000.0,2596438000.0,3137033000.0,2535390000.0,2330624000.0,2740156000.0,1978673000.0,1777719000.0,2179628000.0,1936755000.0,1839984000.0,2033526000.0,0.15
8,2563694000.0,2415863000.0,2711525000.0,2157974000.0,2075493000.0,2240456000.0,1767786000.0,1588929000.0,1946644000.0,1737485000.0,1652560000.0,1822410000.0,0.17
9,2287941000.0,2151275000.0,2424607000.0,1838647000.0,1699153000.0,1978141000.0,1540613000.0,1459801000.0,1621425000.0,1555159000.0,1449663000.0,1660654000.0,0.19


In [17]:
path_out = "./out/toy_bvg_gradient_comp.csv"
df.to_csv(path_out, index = False) #save csv