# Covertype

In [1]:
%load_ext autoreload
import sys
import os
sys.path.append("../../../learning_particle_gradients/")
from jax import config
config.update("jax_debug_nans", False)
from tqdm import tqdm
from jax import config


import jax.numpy as jnp
import jax.numpy as np
from jax import grad, jit, vmap, random, lax, jacfwd, value_and_grad
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import matplotlib
import numpy as onp
import jax
import pandas as pd
import scipy
import haiku as hk
    
import utils
import plot
import distributions
import stein
import models
import flows
from itertools import cycle, islice

key = random.PRNGKey(0)

from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve

from functools import partial
import kernels
import metrics

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
sns.set(style='white')

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

import optax



In [2]:
# set up exporting
# Select matplotlib's pgf backend and configure it to compile with pdflatex.
import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
#     'font.family': 'serif',
#     'text.usetex': True,
    'pgf.rcfonts': False,
    'axes.unicode_minus': False, # avoid unicode error on saving plots with negative numbers (??)
})

# NOTE(review): absolute, machine-specific path; the notebook will not be
# able to save figures here on another machine.
figure_path = "/home/lauro/documents/msc-thesis/thesis/figures/"
# save figures by using plt.savefig('path/to/fig')
# remember that latex textwidth is 5.4in
# so use figsize=[5.4, 4], for example
printsize = [5.4, 4]

In [3]:
%matplotlib inline

# Data

In [26]:
# Load the covertype matrix. Column 0 of data['covtype'] is used as the label,
# all remaining columns as features.
# NOTE(review): absolute, machine-specific path.
data = scipy.io.loadmat('/home/lauro/code/msc-thesis/wang_svgd/data/covertype.mat')
features = data['covtype'][:, 1:]
features = onp.hstack([features, onp.ones([features.shape[0], 1])]) # add intercept term

labels = data['covtype'][:, 0]
# Relabel class 2 as 0 (presumably the raw labels are in {1, 2}, which makes
# them binary {1, 0} -- TODO confirm against the data file).
labels[labels == 2] = 0

# Hold out 20% of the data as the test set; (xx, yy) is the remaining 80%.
xx, x_test, yy, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)
# Split the remaining data again: 10% validation, 90% training.
x_train, x_val, y_train, y_val = train_test_split(xx, yy, test_size=0.1, random_state=0)

# Number of feature columns, including the appended intercept column.
num_features = features.shape[-1]

In [27]:
# batch_size = 128
# num_datapoints = len(x_train)
# num_batches = num_datapoints // batch_size


# def get_batches(x, y, n_steps=num_batches*2, batch_size=batch_size):
#     """Split x and y into batches"""
#     assert len(x) == len(y)
#     assert x.ndim > y.ndim
#     n = len(x)
#     idxs = onp.random.choice(n, size=(n_steps, batch_size))
#     for idx in idxs:
#         yield x[idx], y[idx]
# #     batch_cycle = cycle(zip(*[onp.array_split(data, len(data)//batch_size) for data in (x, y)]))
# #     return islice(batch_cycle, n_steps)

# Model

In [6]:
%autoreload

In [7]:
# Gamma prior on the precision alpha: a0 is the shape parameter and b0 the
# rate (prior_logp passes scale=1/b0, sample_from_prior divides by b0).
a0, b0 = 1, 0.01 # hyper-parameters

In [8]:
from jax.scipy import stats, special

In [83]:
# alternative model
def sample_from_prior(key, num=100):
    """Draw `num` joint samples (w, log_alpha) from the prior.

    The prior is the one whose density `prior_logp` evaluates:
        alpha     ~ Gamma(shape=a0, rate=b0)
        w | alpha ~ Normal(0, 1/alpha)      (alpha is a precision)

    Args:
        key: jax PRNG key.
        num: number of samples to draw.

    Returns:
        Tuple (w, log_alpha) of arrays with shapes (num, num_features)
        and (num,).
    """
    keya, keyb = random.split(key)
    # random.gamma draws with unit rate; dividing by b0 turns that into rate b0.
    alpha = random.gamma(keya, a0, shape=(num,)) / b0
    # w must be scaled by the per-sample standard deviation 1/sqrt(alpha);
    # unscaled standard-normal draws would not follow the conditional prior
    # w | alpha that prior_logp uses.
    w = random.normal(keyb, shape=(num, num_features)) / np.sqrt(alpha)[:, None]
    return w, np.log(alpha)


def prior_logp(w, log_alpha):
    """
    Returns logp(w, log_alpha) = sum_i(logp(wi, log_alphai)), the joint prior
    log-density in the (w, log_alpha) parameterization.

    Model:
        alpha     ~ Gamma(shape=a0, rate=b0)
        w | alpha ~ Normal(0, 1/alpha)

    w has shape (num_features,), or (n, num_features)
    similarly, log_alpha may have shape () or (n,)

    Raises:
        ValueError: if w is neither 1- nor 2-dimensional.
    """
    if w.ndim not in (1, 2):
        raise ValueError(f"w must have 1 or 2 dimensions, got w.ndim = {w.ndim}")
    if log_alpha.ndim == 0:
        assert w.ndim == 1
    elif log_alpha.ndim == 1:
        assert log_alpha.shape[0] == w.shape[0]

    alpha = np.exp(log_alpha)
    # The density is taken with respect to log_alpha, not alpha, so the
    # change of variables alpha = exp(log_alpha) contributes the Jacobian
    # term log|d alpha / d log_alpha| = log_alpha.
    log_jacobian = np.sum(log_alpha)
    logp_alpha = np.sum(stats.gamma.logpdf(alpha, a0, scale=1/b0)) + log_jacobian
    if w.ndim == 2:
        # one precision alpha_i per row w_i
        logp_w = np.sum(vmap(lambda wi, alphai: stats.norm.logpdf(wi, scale=1/np.sqrt(alphai)))(w, alpha))
    else:
        logp_w = np.sum(stats.norm.logpdf(w, scale=1/np.sqrt(alpha)))
    return logp_alpha + logp_w


def preds(x, w):
    """Predicted probability of the positive class, p(y = 1 | x, w).

    x: features of shape (n, num_features) or (num_features,).
    w: a single weight vector of shape (num_features,).
    """
    logits = x @ w
    return special.expit(logits)


def loglikelihood(y, x, w):
    """
    compute log p(y | x, w) for a single parameter w of
    shape (num_features,) and a batch of data (y, x) of
    shape (m,) and (m, num_features)

    log p(y | x, w) = sum_i(logp(yi| xi, w))

    Labels y are expected in {0, 1}; they are mapped to signs {-1, +1} so
    that p(yi | xi, w) = sigmoid(sign_i * xi @ w).
    """
    y = ((y - 1/2)*2).astype(np.int32)  # {0, 1} -> {-1, +1}
    logits = x @ w
    # log(sigmoid(z)) = -log(1 + exp(-z)) = -logaddexp(0, -z).
    # Evaluating it this way stays finite for large |z|, whereas
    # log(expit(z)) underflows to log(0) = -inf for very negative z.
    return np.sum(-np.logaddexp(0., -logits*y))


def log_posterior_unnormalized(y, x, w, log_alpha):
    """Unnormalized log-posterior summed over a batch of parameters.

    y, x are a data batch; w and log_alpha are batched along their
    leading axis. The result is prior_logp(w, log_alpha) plus the sum of
    loglikelihood(y, x, wi) over all rows wi of w.
    """
    def particle_loglikelihood(wi):
        return loglikelihood(y, x, wi)

    total_loglikelihood = np.sum(vmap(particle_loglikelihood)(w))
    return prior_logp(w, log_alpha) + total_loglikelihood


def log_posterior_unnormalized_single_param(y, x, w, log_alpha):
    """Unnormalized log-posterior of one parameter (w, log_alpha).

    y and x are data batches; w and log_alpha are NOT batched. Use this
    when an unbatched evaluation of the target log-density is needed.
    """
    return prior_logp(w, log_alpha) + loglikelihood(y, x, w)


def compute_probs(y, x, w):
    """Probability that the model assigns to the observed labels.

    Returns P(y_generated == y | x, w), one entry per data point.
    y and x are data batches; w is a single parameter array of shape
    (num_features,).
    """
    signs = ((y - 1/2)*2).astype(np.int32)  # labels {0, 1} -> {-1, +1}
    signed_logits = (x @ w) * signs
    return special.expit(signed_logits)


@jit
def compute_test_accuracy(w):
    """Ensemble accuracy on the test split; w is a batch of weight vectors.

    The probability of the true label is averaged over all rows of w, and
    a test point counts as correct when that average exceeds 0.5.
    Reads the globals x_test and y_test.
    """
    def probs_for_particle(wi):
        return compute_probs(y_test, x_test, wi)

    ensemble_probs = np.mean(vmap(probs_for_particle)(w), axis=0)
    is_correct = ensemble_probs > 0.5
    return np.mean(is_correct)


@jit
def compute_train_accuracy(w):
    """Ensemble accuracy on the training split; w is a batch of weight vectors.

    The probability of the true label is averaged over all rows of w, and
    a training point counts as correct when that average exceeds 0.5.
    Reads the globals x_train and y_train.
    """
    def probs_for_particle(wi):
        return compute_probs(y_train, x_train, wi)

    ensemble_probs = np.mean(vmap(probs_for_particle)(w), axis=0)
    is_correct = ensemble_probs > 0.5
    return np.mean(is_correct)


def ravel(w, log_alpha):
    """Pack (w, log_alpha) into one array by appending log_alpha as the
    last entry (unbatched) or last column (batched) of w."""
    alpha_column = np.expand_dims(log_alpha, axis=-1)
    return np.hstack((w, alpha_column))


def unravel(params):
    """Split a packed parameter array into (w, log_alpha).

    The last entry along the final axis is log_alpha, everything before it
    is w.

    Args:
        params: array of shape (num_features + 1,) or (n, num_features + 1).

    Returns:
        (w, log_alpha) with shapes (num_features,), () for 1-d input and
        (n, num_features), (n,) for 2-d input.

    Raises:
        ValueError: if params is neither 1- nor 2-dimensional.
    """
    if params.ndim == 1:
        return params[:-1], params[-1]
    if params.ndim == 2:
        # Plain indexing already yields shape (n,). Squeezing would collapse
        # a batch of one (n == 1) to a scalar and break the pairing with w.
        return params[:, :-1], params[:, -1]
    raise ValueError(f"params must have 1 or 2 dimensions, got params.ndim = {params.ndim}")

In [197]:
# NOTE(review): `n` is not referenced anywhere in the visible part of this notebook.
n = 10**5

In [223]:
def target_log_prob(raveled_params):
    """Unnormalized log-posterior of one packed parameter vector.

    Unpacks raveled_params into (w, log_alpha) and evaluates the
    single-parameter log-posterior on the global data (xx, yy).
    """
    w, log_alpha = unravel(raveled_params)
    return log_posterior_unnormalized_single_param(yy, xx, w, log_alpha)

# Sampling (NUTS)

In [224]:
# NOTE(review): this subkey is not consumed before `key` and `subkey` are
# reassigned in the next cell.
key, subkey = random.split(key)

In [225]:
# Number of independent chains to run.
num_chains = 100
# NOTE: this restarts the PRNG stream from seed 0, discarding whatever
# state the global `key` had before this cell.
key, subkey = random.split(random.PRNGKey(0))
# One initial state per chain, drawn from the prior and packed as
# [w, log_alpha] rows.
init_state = ravel(*sample_from_prior(subkey, num_chains))

@jit
def run_chain(key, state):
    """Run a single NUTS chain of 1000 steps on target_log_prob.

    key seeds the chain and state is its initial packed parameter vector.
    Returns the output of tfp.mcmc.sample_chain with the final kernel
    results included and no per-step trace.
    """
    nuts_kernel = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=target_log_prob,
        step_size=1e-8,
    )
#     kernel = tfp.mcmc.UncalibratedLangevin(target_log_prob, 1e-6)
    chain_output = tfp.mcmc.sample_chain(
        num_results=1000,
        current_state=state,
        kernel=nuts_kernel,
        trace_fn=None,
        seed=key,
        return_final_kernel_results=True,
    )
    return chain_output

In [226]:
# Run all chains in parallel: vmap run_chain over one PRNG key and one
# initial state per chain.
key, subkey = random.split(key)
sample_output = vmap(run_chain)(random.split(subkey, num_chains), init_state)



In [227]:
# logacc = sample_output.final_kernel_results.log_accept_ratio

In [None]:
# print(np.mean(np.exp(logacc)))

In [228]:
# States visited by every chain.
samples = sample_output.all_states

In [229]:
# s = unravel(np.reshape(samples, newshape=(num_chains*50, 56)))
# Keep only the last state along axis 1 and unpack it into (w, log_alpha).
# Assumes samples has shape (num_chains, num_steps, dim) -- TODO confirm.
s = unravel(samples[:, -1, :])

In [230]:
# wsa: weights, asa: log_alpha values (the two outputs of unravel).
wsa, asa = s

In [None]:
# NOTE(review): the chains target the posterior given (xx, yy), while this
# accuracy is evaluated on (x_train, y_train), which is a subset of (xx, yy).
compute_train_accuracy(wsa)