In [1]:
#import modules and set up environment 
import os
import sys
path = "../../src/"

sys.path.append(path)
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import jax.numpy as jnp
from jax import jit, grad, vmap, hessian, scipy, random

#import sgmcmc code 

import models.logistic_regression.logistic_regression as lr
from samplers import sgd as sgd
from samplers import sgld as sgld
from samplers import sgldps as sgldps
from samplers import sgldcv as sgldcv
from samplers import sgldcvps as sgldcvps

# Root PRNG key for the notebook; each chain run below splits its randomness off this key.
PRNG_SEED = 11
key = random.PRNGKey(PRNG_SEED)

  import pandas.util.testing as tm


In [2]:
#set up model-specific gradient functions
# Bind the logistic-regression model callables to the short names the sampler calls below expect
# (post_var is only needed by the control-variate + preconditioned/PS sampler).
gradf_0, gradf_i_batch, post_var = lr.gradf_0, lr.gradf_i_batch, lr.post_var

### Set up model and data

In [3]:
#load in the data
file_path = "../../data/synthetic/lr_balance_train_synth.csv"
data = pd.read_csv(file_path)
dat_array = data.values[:]
N = dat_array.shape[0]
x = dat_array[:, 1:]
y = dat_array[:,0]

#set up model parameters
dim = x.shape[1] 

### Set up sampling framework

In [4]:
# step-size
step_size = 1e-4
step_sgd = 1e-3
# batch sizes
n_batch = np.int64(N*0.001)
# number of chains (first chain discarded due to slower runtime)
N_rep = 11
#iterations
burnin = np.int64(10/0.001)
Niter = 2*burnin

### Calculating V_0 for ASGLD-CV sampler

In [5]:
# Reload the stored SGLD-CV samples and drop the first `burnin` rows
# (presumably the optimisation phase that precedes sampling — confirm against the producer).
cv_samples = (
    pd.read_csv("./out/lrb_sgldcv_samples.csv")
    .iloc[burnin:]
    .reset_index(drop=True)
)
# The first remaining row is used as the per-chain reference point ("mode") below.
modes_cv = cv_samples.to_numpy()[0]

In [6]:
# For each chain, compute the 95% quantile of the squared Euclidean distance between its
# samples and its reference point. Chains sit side by side in the columns, `dim` columns each.
# NOTE(review): this covers column blocks 0..N_rep-2, so it includes the first chain and
# omits the last; N_rep's comment says the first chain is discarded (for runtime) — confirm.
cv_values = cv_samples.to_numpy()
quantiles_cv = np.zeros(N_rep-1)
for chain in range(N_rep - 1):
    cols = slice(chain*dim, (chain + 1)*dim)
    sq_dist = ((cv_values[:, cols] - modes_cv[cols])**2).sum(axis=1)
    quantiles_cv[chain] = np.quantile(sq_dist, q=0.95)

In [7]:
# Per-chain 95% quantiles of squared distance to the reference point (SGLD-CV samples).
quantiles_cv

array([0.04352772, 0.04238889, 0.05069752, 0.04276705, 0.04196886,
       0.03771035, 0.03906198, 0.04459311, 0.04425378, 0.04224128])

In [8]:
# Per-observation constant L_i = largest eigenvalue of x_i x_i^T + 1e-10 * I.
# x_i x_i^T is rank one with non-zero eigenvalue ||x_i||^2, so the largest eigenvalue is
# ||x_i||^2 + 1e-10. Computing this in closed form replaces N dense eigendecompositions
# (O(N * dim^3)) with a single O(N * dim) row-wise sum.
lipschitz_cv = np.sum(np.asarray(x, dtype=float)**2, axis=1) + 1e-10

In [9]:
# V_0 for each chain = (quantile / n_batch) * N * sum_i L_i^2.
V_0_cv = (1/n_batch)*quantiles_cv*N*np.sum(lipschitz_cv**2)
# The worst case (largest V_0) over chains sets the sampler's scaling constant.
max_V_0_cv = np.max(V_0_cv)
# cons_cv = N * sum_i L_i^2 / max V_0. Algebraically this is n_batch / max(quantiles_cv):
# the L_i cancel, so their overall scale does not affect cons_cv (only V_0_cv itself).
cons_cv = 1/max_V_0_cv * N*np.sum(lipschitz_cv**2)

In [10]:
# Scaling constant passed to the ASGLD-CV sampler (equals n_batch / max(quantiles_cv)).
cons_cv

197.24832543695533

### Calculating V_0 for ASGLD-CV-PS sampler

In [11]:
# Reload the stored SGLD-CV-PS samples and drop the first `burnin` rows
# (presumably the optimisation phase that precedes sampling — confirm against the producer).
cvps_samples = (
    pd.read_csv("./out/lrb_sgldcvps_samples.csv")
    .iloc[burnin:]
    .reset_index(drop=True)
)
# The first remaining row is used as the per-chain reference point ("mode") below.
modes_cvps = cvps_samples.to_numpy()[0]

In [12]:
# For each chain, compute the 95% quantile of the squared Euclidean distance between its
# samples and its reference point. Chains sit side by side in the columns, `dim` columns each.
# NOTE(review): this covers column blocks 0..N_rep-2, so it includes the first chain and
# omits the last; N_rep's comment says the first chain is discarded (for runtime) — confirm.
cvps_values = cvps_samples.to_numpy()
quantiles_cvps = np.zeros(N_rep-1)
for chain in range(N_rep - 1):
    cols = slice(chain*dim, (chain + 1)*dim)
    sq_dist = ((cvps_values[:, cols] - modes_cvps[cols])**2).sum(axis=1)
    quantiles_cvps[chain] = np.quantile(sq_dist, q=0.95)

In [13]:
# Per-chain 95% quantiles of squared distance to the reference point (SGLD-CV-PS samples).
quantiles_cvps

array([0.04591819, 0.04678468, 0.04261566, 0.0439665 , 0.04541675,
       0.04435076, 0.04942472, 0.04619804, 0.04464653, 0.04376622])

In [14]:
# Per-observation constant L_i = 0.25 * largest eigenvalue of x_i x_i^T + 1e-10 * I.
# The 0.25 is presumably the bound sigma(z)(1 - sigma(z)) <= 1/4 of the logistic link — confirm.
# x_i x_i^T is rank one with non-zero eigenvalue ||x_i||^2, so the largest eigenvalue is
# ||x_i||^2 + 1e-10. Computing this in closed form replaces N dense eigendecompositions
# (O(N * dim^3)) with a single O(N * dim) row-wise sum.
lipschitz_cvps = 0.25 * (np.sum(np.asarray(x, dtype=float)**2, axis=1) + 1e-10)

In [15]:
# V_0 for each chain = (quantile / n_batch) * N * sum_i L_i^2.
V_0_cvps = (1/n_batch)*quantiles_cvps*N*np.sum(lipschitz_cvps**2)
# The worst case (largest V_0) over chains sets the sampler's scaling constant.
max_V_0_cvps = np.max(V_0_cvps)
# cons_cvps = N * sum_i L_i^2 / max V_0. Algebraically this is n_batch / max(quantiles_cvps):
# the L_i (including their 0.25 factor) cancel, so they do not affect cons_cvps.
cons_cvps = 1/max_V_0_cvps * N*np.sum(lipschitz_cvps**2)

In [16]:
# Scaling constant passed to the ASGLD-CV-PS sampler (equals n_batch / max(quantiles_cvps)).
cons_cvps

202.3279062402471

### Running the adaptive SGLD-CV sampler

In [17]:
# Fresh per-chain accumulators for the ASGLD-CV runs (four independent lists).
runtime_df, samples_df, grads_df, n_df = [], [], [], []

In [18]:
# One ASGLD-CV chain per iteration: Adam (`burnin` iterations, starting at the origin)
# produces theta_hat, then the adaptive SGLD-CV sampler runs from theta_hat.
# The start point is loop-invariant and JAX arrays are immutable, so build it once.
theta_0 = jnp.zeros(dim)
for _ in tqdm(range(N_rep), desc = "Number of chains run"):
    # Split two fresh subkeys per chain and carry only `key` forward, so no key is both
    # consumed by a random routine and split again on the next iteration.
    key, sgd_key, subkey = random.split(key, 3)
    theta_hat, samples_sgd, runtime_sgd = sgd.adam(sgd_key, gradf_0, gradf_i_batch, burnin, theta_0, x, y,
                                                   n_batch, step_sgd, replacement=True)
    # theta_hat is passed twice, as before: presumably the start point and the
    # control-variate centre; confirm against sgldcv.asgld_cv_sampler.
    samples, grads, runtime, n_vec = sgldcv.asgld_cv_sampler(subkey, gradf_0, gradf_i_batch, burnin, step_size, theta_hat,
                                                             theta_hat, x, y, cons_cv, replacement=True)
    # Join both phases. Sampler runtimes are offset by the last Adam runtime entry (assumes
    # cumulative runtimes — TODO confirm), and samples[0] is dropped (presumably a copy of
    # theta_hat — TODO confirm).
    runtime_df.append(np.concatenate((runtime_sgd, runtime_sgd[burnin-1]+runtime)))
    samples_df.append(np.concatenate((samples_sgd, samples[1:]), axis=0))
    grads_df.append(grads)
    n_df.append(n_vec)

Number of chains run: 100%|██████████| 11/11 [09:58<00:00, 54.44s/it]


In [19]:
# Stack each list of per-chain arrays side by side into one wide DataFrame (the four
# names are rebound to these frames), then write each frame to its CSV file.
runtime_df, samples_df, grads_df, n_df = (
    pd.DataFrame(np.column_stack(per_chain))
    for per_chain in (runtime_df, samples_df, grads_df, n_df)
)
for frame, csv_path in (
    (runtime_df, "./out/lrb_asgldcv_runtime.csv"),
    (samples_df, "./out/lrb_asgldcv_samples.csv"),
    (grads_df, "./out/lrb_asgldcv_grads.csv"),
    (n_df, "./out/lrb_asgldcv_n.csv"),
):
    frame.to_csv(csv_path, index=False)

### Running the ASGLD-CV-PS sampler

In [20]:
# Fresh per-chain accumulators for the ASGLD-CV-PS runs (four independent lists).
runtime_df, samples_df, grads_df, n_df = [], [], [], []

In [21]:
# One ASGLD-CV-PS chain per iteration: Adam (`burnin` iterations, starting at the origin)
# produces theta_hat, then the adaptive SGLD-CV-PS sampler runs from theta_hat.
# The start point is loop-invariant and JAX arrays are immutable, so build it once.
theta_0 = jnp.zeros(dim)
for _ in tqdm(range(N_rep), desc = "Number of chains run"):
    # Split two fresh subkeys per chain and carry only `key` forward, so no key is both
    # consumed by a random routine and split again on the next iteration.
    key, sgd_key, subkey = random.split(key, 3)
    theta_hat, samples_sgd, runtime_sgd = sgd.adam(sgd_key, gradf_0, gradf_i_batch, burnin, theta_0, x, y,
                                                   n_batch, step_sgd, replacement=True)
    # `probs` is returned by the sampler but not stored here.
    samples, grads, runtime, n_vec, probs = sgldcvps.asgld_cv_ps_sampler(subkey, gradf_0, gradf_i_batch, post_var, burnin,
                                                                         step_size, theta_hat, theta_hat, x, y, cons_cvps)
    # Join both phases. Sampler runtimes are offset by the last Adam runtime entry (assumes
    # cumulative runtimes — TODO confirm), and samples[0] is dropped (presumably a copy of
    # theta_hat — TODO confirm).
    runtime_df.append(np.concatenate((runtime_sgd, runtime_sgd[burnin-1]+runtime)))
    samples_df.append(np.concatenate((samples_sgd, samples[1:]), axis=0))
    grads_df.append(grads)
    n_df.append(n_vec)

Number of chains run: 100%|██████████| 11/11 [14:47<00:00, 80.70s/it]


In [22]:
# Stack each list of per-chain arrays side by side into one wide DataFrame (the four
# names are rebound to these frames), then write each frame to its CSV file.
runtime_df, samples_df, grads_df, n_df = (
    pd.DataFrame(np.column_stack(per_chain))
    for per_chain in (runtime_df, samples_df, grads_df, n_df)
)
for frame, csv_path in (
    (runtime_df, "./out/lrb_asgldcvps_runtime.csv"),
    (samples_df, "./out/lrb_asgldcvps_samples.csv"),
    (grads_df, "./out/lrb_asgldcvps_grads.csv"),
    (n_df, "./out/lrb_asgldcvps_n.csv"),
):
    frame.to_csv(csv_path, index=False)