In [1]:
import os
import pickle

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from experiments.logisticRegression.mnist.load_mnist import mnist_dataset
from experiments.logisticRegression.utils import get_tgt_log_density
from variational.exponential_family import GenericMeanFieldNormalDistribution, NormalDistribution, MeanFieldNormalDistribution

flipped_predictors = mnist_dataset(path_prefix="..")
_, flipped_predictors_test = mnist_dataset(return_test=True, flip=False, path_prefix="..")
predictors_test, labels_test = flipped_predictors_test
N, dim = flipped_predictors.shape

# Gaussian Prior
my_prior_covariance = 25 * jnp.identity(dim)
#my_prior_covariance = my_prior_covariance.at[0, 0].set(400)
my_prior_log_density = NormalDistribution(jnp.zeros(dim), my_prior_covariance).log_density
tgt_log_density = jax.vmap(get_tgt_log_density(flipped_predictors, my_prior_log_density))

PKLs = []
PKL_titles = []
for file in os.listdir("./"):
    if file.endswith(".pkl"):
        PKLs.append(pickle.load(open(file, "rb")))
        PKL_titles.append(str(file))

full_gaussian = GenericMeanFieldNormalDistribution(dimension=dim)

lossesPKLs = []
lossesPKL_titles = []
for file in os.listdir("./losses/"):
    if file.endswith(".pkl"):
        lossesPKLs.append(pickle.load(open("./losses/" + file, "rb")))
        lossesPKL_titles.append(str(file))

In [3]:
for idx, loss in enumerate(lossesPKLs):
    plt.plot(lossesPKLs[idx][10:500], label=lossesPKL_titles[idx])
        
plt.legend()

In [15]:
np.savetxt('loss_500_10000_heuristic_mnist_seq1em3_u10.csv', np.array([np.arange(1, 500+1), lossesPKLs[0][1:]]).T, delimiter=",")
np.savetxt('loss_500_10000_heuristic_mnist_seq1u10.csv', np.array([np.arange(1, 500+1), lossesPKLs[3][1:]]).T, delimiter=",")
np.savetxt('loss_500_10000_nicolas_mnist_seq1em3.csv', np.array([np.arange(1, 500+1), lossesPKLs[2][1:]]).T, delimiter=",")
np.savetxt('loss_500_10000_blackjax_mnist_1em3_u10.csv', np.array([np.arange(1, 500+1), lossesPKLs[1][1:]]).T, delimiter=",")


In [27]:
mean, cov = full_gaussian.get_mean_cov(PKLs[0]['res'][-1][:-1])
OP_key = jax.random.PRNGKey(0)
samples = jax.random.multivariate_normal(key=OP_key, mean=mean, cov=jnp.diag(cov), shape=(1000,))
plt.hist(tgt_log_density(samples), bins=50)

In [21]:
plt.plot(full_gaussian.get_mean_cov(PKLs[1]['res'][-100][:-1])[1])

In [2]:
def logistic_fun(x):
    return 1/(jnp.exp(-x)+1)
mean_cov_tuple = full_gaussian.get_mean_cov(PKLs[0]['res'][-100][:-1])
learnt_gaussian = MeanFieldNormalDistribution(*mean_cov_tuple)
OP_key = jax.random.PRNGKey(0)
def eval_model(OP_key, gaussian):
    sample_keys = jax.random.split(OP_key, 10)
    sampled_betas = gaussian.sampling_method(sample_keys)
    X = predictors_test@sampled_betas.T
    idxs = jnp.where(logistic_fun(X)[:,0]<=0.5)
    predictions = jnp.zeros(shape=(10, 1954))
    rates = jnp.zeros(shape=(10,))
    for i in range(10):
        _ = jnp.zeros(shape=(1954)) - 1
        #print(jnp.where(logistic_fun(X)[:,i]<=0.5))
        _ = _.at[logistic_fun(X)[:,i]>0.5].set(1)
        predictions = predictions.at[i].set(_)
        rate = (jnp.sum(labels_test == _)/1954)
        print(rate)
        rates = rates.at[i].set(rate)
    print(jnp.mean(rates), jnp.std(rates))
    return rates
eval_model(OP_key, learnt_gaussian)

0.9912999
0.9861822
0.986694
0.9907881
0.9882293
0.9943705
0.9907881
0.9907881
0.9928352
0.98311156
0.9895086 0.0032306449


Array([0.9912999 , 0.9861822 , 0.986694  , 0.9907881 , 0.9882293 ,
       0.9943705 , 0.9907881 , 0.9907881 , 0.9928352 , 0.98311156],      dtype=float32)

In [None]:
mean_cov_tuple = PKLs[1]['states'][0][-1], jnp.exp(PKLs[1]['states'][1][-1])
learnt_gaussian = MeanFieldNormalDistribution(*mean_cov_tuple)
eval_model(OP_key, learnt_gaussian)

In [4]:
PKL_titles

['gaussianMeanField_Nicolas_500_10000_Seq4_[0 0].pkl.pkl',
 'heuristic_gaussian_Nicolas_5_10000_Seq1_u10_[0 0].pkl',
 'heuristic_gaussian_Nicolas_500_10000_Seq3_u10_[0 0].pkl',
 'res_mfg_advi_blackjax_500_10000_0.001.pkl',
 'heuristic_gaussian_Nicolas_500_10000_Seq1_u10_[0 0].pkl']

In [3]:
learnt_gaussian.sampling_method(keys)

NameError: name 'keys' is not defined

In [28]:
s = lambda k: jax.random.multivariate_normal(k,  jnp.zeros(10), jnp.diag(jnp.ones(10)))

In [4]:
g = MeanFieldNormalDistribution(*mean_cov_tuple)

In [7]:
g.sampling_method(OP_key)

Array([ 0.313475  ,  0.441408  ,  0.55896103, -0.23439744, -0.02954961,
        0.05229205,  0.26854497,  0.7707008 ,  0.07143676, -0.53043294,
        0.2203336 , -0.76579785,  0.3901741 ,  0.01817752, -1.2995163 ,
        0.00795902,  0.04268111,  0.7264216 , -0.05677755, -0.7872324 ,
        0.81912494, -0.21688269,  0.55443156, -0.5571157 , -0.3117542 ,
        0.7961377 , -0.00888165, -0.39085254,  0.7081708 , -1.0156626 ,
       -0.19626148,  0.6325965 ,  0.03863332, -0.384978  ,  0.09495692,
        0.44962713, -0.7146758 , -0.20064822,  0.32527444, -0.17603156,
        0.50992996,  0.08772411, -0.22887367, -0.2101209 , -0.24070154,
       -0.0281838 , -0.6426368 , -1.1692498 ,  0.48823455, -0.07868916,
        0.2440154 ,  1.0271785 , -1.7143891 , -0.35170212, -0.01248175,
       -0.1805342 , -0.24478248,  1.0311053 ,  0.31660888,  0.7952634 ,
        0.38795912,  0.13418059, -0.52439713,  0.6559715 ,  0.2184682 ,
       -0.38576838,  0.71528476,  0.21952145,  0.4822137 , -0.60

In [10]:
jax.vmap(g.sampling_method)(jax.random.split(OP_key, 10))

Array([[ 0.15115173, -0.2574636 ,  0.5681663 , ...,  0.71017325,
         0.04409748, -0.09633517],
       [-0.10271245,  0.07634994, -0.22728637, ...,  0.31605375,
        -0.6273328 , -0.38868678],
       [-0.0704969 ,  0.07580637, -0.5471061 , ...,  1.0101757 ,
        -0.07262857, -0.01127673],
       ...,
       [-0.02094457, -0.07891387,  0.05349177, ...,  0.35268953,
         0.47787023, -0.76062846],
       [ 0.02996048,  0.59578544,  0.45737356, ..., -1.232112  ,
         0.2867137 ,  0.13808155],
       [-0.32137984, -0.38593614, -0.03017567, ..., -0.39324567,
         0.49051586, -0.19054556]], dtype=float32)