In [None]:
%matplotlib inline
from google.colab import drive
import sys
drive.mount('/content/drive')
sys.path.append('/content/drive/My Drive/Colab Notebooks/src')
from utils import Bootstrap
from collections import defaultdict
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import jit, grad, lax, random, value_and_grad, vmap
from jax.experimental import stax, optimizers
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import pandas as pd
import seaborn as sns
import logging 
# Logging: INFO-level root configuration plus a named logger for this notebook.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Autoencoder')

# Same 15pt label size for axis labels and for both sets of tick labels.
for rc_group in ('axes', 'xtick', 'ytick'):
    plt.rc(rc_group, labelsize=15)

# Hyperparameters read by later cells.
ndim = 2          # width of the encoder's final Dense layer
lr = 0.01         # step size handed to the Adam optimizer
batch_size = 60   # rows per bootstrap minibatch

# Iris data: keep the class codes and their names, and rescale the features
# with MinMaxScaler (default range) before they are fed to the network.
iris = datasets.load_iris()
y = iris.target
target_names = iris.target_names
X = MinMaxScaler().fit_transform(iris.data)
X.shape

In [None]:
def plot_dim_reduction(ax, predicted, target, title='', pca=False, names=None):
    """Scatter-plot a two-column embedding on ``ax``, coloured by class name.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the scatter plot is drawn on.
    predicted : array-like with two columns
        Embedding to plot; the columns become 'Dim 1' and 'Dim 2'.
    target : array-like
        Class code for each row of ``predicted``. Assumed to be integer
        codes that index into ``names`` (0, 1, 2, ...).
    title : str
        Title placed above the axes.
    pca : bool
        When True the axes are relabelled 'PC1' / 'PC2'.
    names : sequence of str, optional
        ``names[i]`` is the legend label for class code ``i``. Defaults to
        the notebook-level ``target_names``.
    """
    if names is None:
        names = target_names

    # Pair each class code with its name by position. This depends only on
    # the arguments, not on the global `y` or on set iteration order.
    code_to_name = dict(enumerate(names))

    dim_reduced = (
        pd.DataFrame(predicted, columns=['Dim 1', 'Dim 2'])
        .assign(target=target)
        .assign(target=lambda d: d.target.map(code_to_name))
    )

    sns.scatterplot(data=dim_reduced, x='Dim 1', y='Dim 2', hue='target', ax=ax)
    ax.set_title(title, size=15)
    if pca:
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')

In [None]:
# Encoder: Dense(5) + ReLU, then Dense(ndim) + sigmoid.
encoder_layers = [
    stax.Dense(5),
    stax.Relu,
    stax.Dense(ndim),
    stax.Sigmoid,
]
encoder_init, encode = stax.serial(*encoder_layers)

# Decoder: one Dense layer as wide as X has columns, followed by a sigmoid.
decoder_layers = [
    stax.Dense(X.shape[1]),
    stax.Sigmoid,
]
decoder_init, decode = stax.serial(*decoder_layers)

In [None]:
# Fixed seed so weight initialisation is the same on every run; the root key
# is split into one sub-key for the encoder and one for the decoder.
rng = random.PRNGKey(1)
subkeys = random.split(rng, num=2)
encode_rng, decode_rng = subkeys[0], subkeys[1]

In [None]:
@jit
def loss(predicted, Y):
    """Mean squared error between ``predicted`` and ``Y``, averaged over all elements."""
    residual = Y - predicted
    return jnp.mean(jnp.square(residual))

@jit
def VAE(params, x):
    """Reconstruction loss of the encoder/decoder pair on ``x``.

    ``x`` is passed through ``encode`` with ``params['encoder']``, the result
    through ``decode`` with ``params['decoder']``, and the decoded output is
    compared with ``x`` itself using ``loss``.

    Parameters
    ----------
    params : dict
        Must hold the keys 'encoder' and 'decoder' with the matching weights.
    x : array
        Batch of inputs to reconstruct.

    Returns
    -------
    Scalar reconstruction error as returned by ``loss``.
    """
    # The surrounding @jit already traces through both networks, so the
    # apply functions are called directly instead of being re-wrapped in jit.
    encoded = encode(params['encoder'], x)
    decoding = decode(params['decoder'], encoded)
    return loss(decoding, x)

In [None]:
# Train the encoder/decoder with Adam, one bootstrap minibatch per step.
epoch = 1000                 # number of optimisation steps
losses = np.zeros(epoch)     # loss recorded at every step, plotted below

# Initial weights. Each init function returns a pair; only the second element
# (the weights) is kept.
params = {}
_, params['encoder'] = encoder_init(encode_rng, (batch_size, X.shape[1]))
_, params['decoder'] = decoder_init(decode_rng, (batch_size, ndim))

opt_init, opt_update, get_params = optimizers.adam(step_size = lr)
opt_state = opt_init(params)

# NOTE(review): Bootstrap lives in the project-local `utils` module; it is
# presumably an iterator of row indices into X — confirm against utils.py.
bootstrap = Bootstrap()
minibatches = bootstrap.bootstrap(X, group_size = batch_size, n_boots = epoch)

# Build the loss-and-gradient function once rather than on every iteration.
loss_and_grad = value_and_grad(VAE)

# Report progress five times over the run; max() avoids a modulo-by-zero
# when epoch is smaller than 5.
log_every = max(epoch // 5, 1)

for i in range(epoch):
    minibatch = X[next(minibatches)]
    mse, gradients = loss_and_grad(get_params(opt_state), minibatch)
    losses[i] = float(mse)
    opt_state = opt_update(i, gradients, opt_state)
    if (i+1) % log_every == 0:
        # `loss` is a mean squared error, so the value is labelled MSE.
        logger.info('%i iteration: MSE = %.2f' %(i+1, float(mse)))

# The optimizer state holds the trained weights. Copy them back into `params`
# so later cells that read params['encoder'] use the trained encoder rather
# than the random initial weights.
params = get_params(opt_state)

plt.plot(losses)

In [None]:
# Two-panel figure comparing two 2-D views of X:
# the encoder output (left) and the first two principal components (right).
vae = encode(params['encoder'], X)            # encoder applied to every row of X
pca = PCA(n_components=2).fit_transform(X)    # 2-component PCA projection of X

fig = plt.figure(figsize=(8, 4))

# Left panel: encoder output. Its legend is hidden because the right panel
# carries the legend for the whole figure.
ax = fig.add_subplot(1, 2, 1)
plot_dim_reduction(ax, vae, y, title='Autoencoder', pca=False)
ax.legend().set_visible(False)

# Right panel: PCA projection.
ax = fig.add_subplot(1, 2, 2)
plot_dim_reduction(ax, pca, y, title='PCA', pca=True)

# Lay out the panels first, then anchor the legend at the right panel's
# upper-right corner.
fig.tight_layout()
ax.legend(fontsize=15, title='', bbox_to_anchor=(1, 1), frameon=False)
sns.despine()

In [None]:
from sklearn.svm import SVC

# Fit one default-parameter SVC on each 2-D embedding, using y as labels.
vae_model = SVC().fit(vae, y)
pca_model = SVC().fit(pca, y)

# NOTE(review): each model is scored on the same rows it was fitted on, and
# the printed label says "Logit" although the estimator is an SVC.
vae_score = vae_model.score(vae, y)
pca_score = pca_model.score(pca, y)
print('Logit for VAE: %.3f vs Logit for PCA: %.3f' %(vae_score, pca_score))