In [1]:
#import modules and set up environment 
import os
import sys
path = "../../src/"

sys.path.append(path)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import jax.numpy as jnp
from jax import jit, grad, vmap, hessian, scipy, random

#import sgmcmc code 
import models.bivariate_gaussian.bivariate_gaussian as bvg
import samplers.sgd as sgd
import samplers.sgld as sgld
import samplers.sgldps as sgldps

#seed the JAX PRNG; the chain loops below derive their randomness from this key via random.split
key = random.PRNGKey(100)



In [2]:
#expose the model-specific functions of the bivariate Gaussian module under short names
gradf_0, gradf_i_batch, post_var = bvg.gradf_0, bvg.gradf_i_batch, bvg.post_var

### Set up model and data

In [3]:
#read the synthetic data set from disk
file_path = "../../data/synthetic/bvg_synth.csv"
data = pd.read_csv(file_path)
dat_array = data.values[:]
N = dat_array.shape[0]  #number of rows in the data
x = np.array(dat_array)  #copy of the data values handed to the samplers
y = np.array([None for _ in range(N)])  #length-N placeholder of None, passed as y to the samplers

In [4]:
#model parameters
dim = 2
theta_true = jnp.array([0., 1.])  #data mean
xbar = jnp.mean(x, axis=0)  #mean of x along axis 0

#data covariance matrix and its inverse (the data precision)
sigma_x = jnp.array([[10**5, 6*10**4], [6*10**4, 2*10**5]])
sigma_x_inv = jnp.linalg.inv(sigma_x)

#prior mean and prior scale, plus the inverse of the scale
mu_0 = jnp.array([0., 0.])
lambda_0 = jnp.array([[1000, 0.], [0., 1000]])
lambda_0_inv = jnp.linalg.inv(lambda_0)

#posterior precision, covariance and mean
lambda_1_inv = lambda_0_inv + N*sigma_x_inv
lambda_1 = jnp.linalg.inv(lambda_1_inv)
mu_1 = lambda_1 @ (N*(sigma_x_inv @ xbar) + lambda_0_inv @ mu_0)

### Set up sampling framework

In [5]:
# sampler settings shared by every experiment below
step_size = 10**(-4)  # step size handed to the SGLD / SGLD-PS samplers
# batch sizes: 1%, 5% and 10% of N, truncated to integers
n_batch = (N*np.array([0.01, 0.05, 0.1])).astype(np.int64)
N_rep = 11  # number of chains

### SGLD (0.01N)

In [6]:
#batch size for this experiment (first entry of n_batch)
n = n_batch[0]
#burn-in length is N*500/n truncated to an integer; the run is twice that long
burnin = np.int64(N*500/n)
Niter = burnin*2

#one entry per chain is appended to each of these lists
runtime_df, samples_df, grads_df = [], [], []

In [7]:
#run N_rep independent SGLD chains, each from its own random starting point
for i in tqdm(range(N_rep), desc = "Number of chains run"):
    #three-way split: `key` is only carried forward, while the initial draw and the
    #sampler each get a fresh key (previously `key` drew theta_0 and was then split again)
    key, init_key, subkey = random.split(key, 3)
    #starting point drawn from N(mu_0, lambda_0)
    theta_0 = random.multivariate_normal(key=init_key, mean=mu_0, cov=lambda_0)
    samples, grads, runtime = sgld.sgld_sampler(subkey, gradf_0, gradf_i_batch, Niter, step_size, theta_0, x, y, n, replacement=True)
    #collect this chain's outputs; they are stacked column-wise when saved
    runtime_df.append(runtime)
    samples_df.append(samples)
    grads_df.append(grads)

Number of chains run: 100%|██████████| 11/11 [34:24<00:00, 187.67s/it]


In [8]:
#create the output directory if needed so a finished run is not lost at write time
os.makedirs("./out", exist_ok=True)
#stack the per-chain results column-wise and write one CSV per quantity
runtime_df = pd.DataFrame(np.column_stack(runtime_df))
runtime_df.to_csv("./out/bvg_sgld_1_runtime.csv", index=False)
samples_df = pd.DataFrame(np.column_stack(samples_df))
samples_df.to_csv("./out/bvg_sgld_1_samples.csv", index=False)
grads_df = pd.DataFrame(np.column_stack(grads_df))
grads_df.to_csv("./out/bvg_sgld_1_grads.csv", index=False)

### SGLD (0.05N)

In [9]:
#batch size for this experiment (second entry of n_batch)
n = n_batch[1]
#burn-in length is N*500/n truncated to an integer; the run is twice that long
burnin = np.int64(N*500/n)
Niter = burnin*2

#one entry per chain is appended to each of these lists
runtime_df, samples_df, grads_df = [], [], []

In [10]:
#run N_rep independent SGLD chains, each from its own random starting point
for i in tqdm(range(N_rep), desc = "Number of chains run"):
    #three-way split: `key` is only carried forward, while the initial draw and the
    #sampler each get a fresh key (previously `key` drew theta_0 and was then split again)
    key, init_key, subkey = random.split(key, 3)
    #starting point drawn from N(mu_0, lambda_0)
    theta_0 = random.multivariate_normal(key=init_key, mean=mu_0, cov=lambda_0)
    samples, grads, runtime = sgld.sgld_sampler(subkey, gradf_0, gradf_i_batch, Niter, step_size, theta_0, x, y, n, replacement=True)
    #collect this chain's outputs; they are stacked column-wise when saved
    runtime_df.append(runtime)
    samples_df.append(samples)
    grads_df.append(grads)

Number of chains run: 100%|██████████| 11/11 [07:17<00:00, 39.78s/it]


In [11]:
#create the output directory if needed so a finished run is not lost at write time
os.makedirs("./out", exist_ok=True)
#stack the per-chain results column-wise and write one CSV per quantity
runtime_df = pd.DataFrame(np.column_stack(runtime_df))
runtime_df.to_csv("./out/bvg_sgld_5_runtime.csv", index=False)
samples_df = pd.DataFrame(np.column_stack(samples_df))
samples_df.to_csv("./out/bvg_sgld_5_samples.csv", index=False)
grads_df = pd.DataFrame(np.column_stack(grads_df))
grads_df.to_csv("./out/bvg_sgld_5_grads.csv", index=False)

### SGLD (0.1N)

In [12]:
#batch size for this experiment (third entry of n_batch)
n = n_batch[2]
#burn-in length is N*500/n truncated to an integer; the run is twice that long
burnin = np.int64(N*500/n)
Niter = burnin*2

#one entry per chain is appended to each of these lists
runtime_df, samples_df, grads_df = [], [], []

In [13]:
#run N_rep independent SGLD chains, each from its own random starting point
for i in tqdm(range(N_rep), desc = "Number of chains run"):
    #three-way split: `key` is only carried forward, while the initial draw and the
    #sampler each get a fresh key (previously `key` drew theta_0 and was then split again)
    key, init_key, subkey = random.split(key, 3)
    #starting point drawn from N(mu_0, lambda_0)
    theta_0 = random.multivariate_normal(key=init_key, mean=mu_0, cov=lambda_0)
    samples, grads, runtime = sgld.sgld_sampler(subkey, gradf_0, gradf_i_batch, Niter, step_size, theta_0, x, y, n, replacement=True)
    #collect this chain's outputs; they are stacked column-wise when saved
    runtime_df.append(runtime)
    samples_df.append(samples)
    grads_df.append(grads)

Number of chains run: 100%|██████████| 11/11 [03:41<00:00, 20.13s/it]


In [14]:
#create the output directory if needed so a finished run is not lost at write time
os.makedirs("./out", exist_ok=True)
#stack the per-chain results column-wise and write one CSV per quantity
runtime_df = pd.DataFrame(np.column_stack(runtime_df))
runtime_df.to_csv("./out/bvg_sgld_10_runtime.csv", index=False)
samples_df = pd.DataFrame(np.column_stack(samples_df))
samples_df.to_csv("./out/bvg_sgld_10_samples.csv", index=False)
grads_df = pd.DataFrame(np.column_stack(grads_df))
grads_df.to_csv("./out/bvg_sgld_10_grads.csv", index=False)

### SGLD-PS (0.01N)

In [15]:
#batch size for this experiment (first entry of n_batch)
n = n_batch[0]
#burn-in length is N*500/n truncated to an integer
burnin = np.int64(N*500/n)
#step size handed to sgd.adam_x; the later SGLD-PS cells reuse this value
step_sgd = 1e-03
Niter = burnin*2

#one entry per chain is appended to each of these lists
runtime_df, samples_df, grads_df = [], [], []

In [16]:
#run N_rep chains: `burnin` iterations of sgd.adam_x, then `burnin` iterations of SGLD-PS
for i in tqdm(range(N_rep), desc = "Number of chains run"):
    #four-way split: `key` is only carried forward, and the initial draw, the optimiser and
    #the sampler each get a fresh key (previously `key` was used for both theta_0 and
    #adam_x and was then split again on the next iteration)
    key, init_key, sgd_key, subkey = random.split(key, 4)
    #starting point drawn from N(mu_0, lambda_0)
    theta_0 = random.multivariate_normal(key=init_key, mean=mu_0, cov=lambda_0)
    theta_hat, samples_sgd, runtime_sgd = sgd.adam_x(sgd_key, gradf_0, gradf_i_batch, burnin, theta_0, x, n, step_sgd, replacement=True)
    #theta_hat is passed twice to the sampler; the argument meanings are defined in sgldps
    samples, grads, runtime = sgldps.sgldps_sampler(subkey, gradf_0, gradf_i_batch, burnin, step_size, theta_hat, theta_hat, x, y, n,  prob_type='approx')

    #sampler runtimes are offset by the last optimiser runtime entry
    #NOTE(review): presumably both runtimes are cumulative, giving one timeline - confirm
    runtime_df.append(np.concatenate((runtime_sgd, runtime_sgd[burnin-1]+runtime)))
    #the first sampler row is dropped before joining
    #NOTE(review): presumably it repeats the starting point theta_hat - confirm
    samples_df.append(np.concatenate((samples_sgd, samples[1:]), axis=0))
    #only the SGLD-PS phase contributes gradients
    grads_df.append(grads)

Number of chains run: 100%|██████████| 11/11 [40:11<00:00, 219.26s/it]


In [17]:
#create the output directory if needed so a finished run is not lost at write time
os.makedirs("./out", exist_ok=True)
#stack the per-chain results column-wise and write one CSV per quantity
runtime_df = pd.DataFrame(np.column_stack(runtime_df))
runtime_df.to_csv("./out/bvg_sgldps_1_runtime.csv", index=False)
samples_df = pd.DataFrame(np.column_stack(samples_df))
samples_df.to_csv("./out/bvg_sgldps_1_samples.csv", index=False)
grads_df = pd.DataFrame(np.column_stack(grads_df))
grads_df.to_csv("./out/bvg_sgldps_1_grads.csv", index=False)

### SGLD-PS (0.05N)

In [18]:
#batch size for this experiment (second entry of n_batch)
n = n_batch[1]
#burn-in length is N*500/n truncated to an integer
burnin = np.int64(N*500/n)
Niter = burnin*2

#one entry per chain is appended to each of these lists
runtime_df, samples_df, grads_df = [], [], []

In [19]:
#run N_rep chains: `burnin` iterations of sgd.adam_x, then `burnin` iterations of SGLD-PS
for i in tqdm(range(N_rep), desc = "Number of chains run"):
    #four-way split: `key` is only carried forward, and the initial draw, the optimiser and
    #the sampler each get a fresh key (previously `key` was used for both theta_0 and
    #adam_x and was then split again on the next iteration)
    key, init_key, sgd_key, subkey = random.split(key, 4)
    #starting point drawn from N(mu_0, lambda_0)
    theta_0 = random.multivariate_normal(key=init_key, mean=mu_0, cov=lambda_0)
    theta_hat, samples_sgd, runtime_sgd = sgd.adam_x(sgd_key, gradf_0, gradf_i_batch, burnin, theta_0, x, n, step_sgd, replacement=True)
    #theta_hat is passed twice to the sampler; the argument meanings are defined in sgldps
    samples, grads, runtime = sgldps.sgldps_sampler(subkey, gradf_0, gradf_i_batch, burnin, step_size, theta_hat, theta_hat, x, y, n,  prob_type='approx')

    #sampler runtimes are offset by the last optimiser runtime entry
    #NOTE(review): presumably both runtimes are cumulative, giving one timeline - confirm
    runtime_df.append(np.concatenate((runtime_sgd, runtime_sgd[burnin-1]+runtime)))
    #the first sampler row is dropped before joining
    #NOTE(review): presumably it repeats the starting point theta_hat - confirm
    samples_df.append(np.concatenate((samples_sgd, samples[1:]), axis=0))
    #only the SGLD-PS phase contributes gradients
    grads_df.append(grads)

Number of chains run: 100%|██████████| 11/11 [09:44<00:00, 53.12s/it]


In [20]:
#create the output directory if needed so a finished run is not lost at write time
os.makedirs("./out", exist_ok=True)
#stack the per-chain results column-wise and write one CSV per quantity
runtime_df = pd.DataFrame(np.column_stack(runtime_df))
runtime_df.to_csv("./out/bvg_sgldps_5_runtime.csv", index=False)
samples_df = pd.DataFrame(np.column_stack(samples_df))
samples_df.to_csv("./out/bvg_sgldps_5_samples.csv", index=False)
grads_df = pd.DataFrame(np.column_stack(grads_df))
grads_df.to_csv("./out/bvg_sgldps_5_grads.csv", index=False)

### SGLD-PS (0.1N)

In [21]:
#batch size for this experiment (third entry of n_batch)
n = n_batch[2]
#burn-in length is N*500/n truncated to an integer
burnin = np.int64(N*500/n)
Niter = burnin*2

#one entry per chain is appended to each of these lists
runtime_df, samples_df, grads_df = [], [], []

In [22]:
#run N_rep chains: `burnin` iterations of sgd.adam_x, then `burnin` iterations of SGLD-PS
for i in tqdm(range(N_rep), desc = "Number of chains run"):
    #four-way split: `key` is only carried forward, and the initial draw, the optimiser and
    #the sampler each get a fresh key (previously `key` was used for both theta_0 and
    #adam_x and was then split again on the next iteration)
    key, init_key, sgd_key, subkey = random.split(key, 4)
    #starting point drawn from N(mu_0, lambda_0)
    theta_0 = random.multivariate_normal(key=init_key, mean=mu_0, cov=lambda_0)
    theta_hat, samples_sgd, runtime_sgd = sgd.adam_x(sgd_key, gradf_0, gradf_i_batch, burnin, theta_0, x, n, step_sgd, replacement=True)
    #theta_hat is passed twice to the sampler; the argument meanings are defined in sgldps
    samples, grads, runtime = sgldps.sgldps_sampler(subkey, gradf_0, gradf_i_batch, burnin, step_size, theta_hat, theta_hat, x, y, n,  prob_type='approx')

    #sampler runtimes are offset by the last optimiser runtime entry
    #NOTE(review): presumably both runtimes are cumulative, giving one timeline - confirm
    runtime_df.append(np.concatenate((runtime_sgd, runtime_sgd[burnin-1]+runtime)))
    #the first sampler row is dropped before joining
    #NOTE(review): presumably it repeats the starting point theta_hat - confirm
    samples_df.append(np.concatenate((samples_sgd, samples[1:]), axis=0))
    #only the SGLD-PS phase contributes gradients
    grads_df.append(grads)

Number of chains run: 100%|██████████| 11/11 [05:42<00:00, 31.12s/it]


In [23]:
#create the output directory if needed so a finished run is not lost at write time
os.makedirs("./out", exist_ok=True)
#stack the per-chain results column-wise and write one CSV per quantity
runtime_df = pd.DataFrame(np.column_stack(runtime_df))
runtime_df.to_csv("./out/bvg_sgldps_10_runtime.csv", index=False)
samples_df = pd.DataFrame(np.column_stack(samples_df))
samples_df.to_csv("./out/bvg_sgldps_10_samples.csv", index=False)
grads_df = pd.DataFrame(np.column_stack(grads_df))
grads_df.to_csv("./out/bvg_sgldps_10_grads.csv", index=False)