In [None]:
import numpy as np
import matplotlib.pyplot as plt

sys.path.append("../code/")
from pi_vae import *
from util import *
from keras.callbacks import ModelCheckpoint

## import plot packages
from matplotlib.ticker import FormatStrFormatter
from matplotlib import ticker

%load_ext autoreload
%autoreload 2
%matplotlib inline

# Discrete-label data

## Load the discrete-label simulated data

In [None]:
# Read the simulated arrays stored under keys 'u' (labels), 'z' (latents)
# and 'x' (observations) from the .npz archive.
dat = np.load('../data/sim/sim_100d_poisson_disc_label.npz')
u_true, z_true, x_true = (dat[key] for key in ('u', 'z', 'x'))

In [None]:
# Regroup the flat simulated arrays into 50 trials x 200 samples each
# (assumes the .npz stores samples trial-by-trial — TODO confirm).
x_all = x_true.reshape(50, 200, -1)
u_all = u_true.reshape(50, 200, -1)

# Disjoint trial-level split: 40 train / 5 validation / 5 test.
n_train, n_valid = 40, 5
valid_end = n_train + n_valid

x_train = x_all[:n_train]
u_train = u_all[:n_train]

# Validation trials drive checkpoint selection (ModelCheckpoint monitors
# val_loss), so they must not double as the held-out test set.
x_valid = x_all[n_train:valid_end]
u_valid = u_all[n_train:valid_end]

# Test set = the remaining trials (45-49), disjoint from validation.
x_test = x_all[valid_end:]
u_test = u_all[valid_end:]

assert len(x_train) + len(x_valid) + len(x_test) == len(x_all)

## Fit the pi-VAE model

In [None]:
# Seed numpy's global RNG before building the model.
# NOTE(review): this seeds numpy only; Keras/TensorFlow weight initialisation
# may need its own seed for full reproducibility — confirm.
np.random.seed(999)
vae = vae_mdl(
    dim_x=x_all.shape[-1],          # observation dimension per sample
    dim_z=2,                        # 2-D latent space
    dim_u=np.unique(u_all).size,    # number of distinct discrete labels
    gen_nodes=60,
    n_blk=2,
    mdl='poisson',
    disc=True,
    learning_rate=5e-4,
)

In [None]:
# Save only the weights from the epoch with the lowest validation loss.
model_chk_path = '../results/sim_disc_nflow_2d_999.h5'  # seeds used: 999, 777
mcp = ModelCheckpoint(
    model_chk_path,
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
)
# NOTE(review): fit_generator is deprecated in TF2 Keras (Model.fit accepts
# generators there); kept because the targeted Keras version is not visible here.
s_n = vae.fit_generator(
    custom_data_generator(x_train, u_train),
    steps_per_epoch=len(x_train),   # presumably one trial per step — confirm in util
    epochs=600,
    verbose=1,
    validation_data=custom_data_generator(x_valid, u_valid),
    validation_steps=len(x_valid),
    callbacks=[mcp],
)

In [None]:
# Validation loss per epoch; the checkpoint callback keeps the epoch with the minimum.
plt.plot(s_n.history['val_loss'])
plt.xlabel('Epoch')
plt.ylabel('Validation loss')
plt.title('pi-VAE validation loss');

In [None]:
# Restore the best-validation weights written by the ModelCheckpoint callback
# (path repeated so this cell also works without re-running training).
model_chk_path = '../results/sim_disc_nflow_2d_999.h5'
vae.load_weights(model_chk_path);

In [None]:
# Run the trained model over every trial; the plots below index the outputs
# against u_true/z_true, which presumes the generator preserves trial order.
# Output order: post_mean, post_log_var, z_sample, fire_rate,
#               lam_mean, lam_log_var, z_mean, z_log_var
outputs = vae.predict_generator(
    custom_data_generator(x_all, u_all),
    steps=len(x_all),
)

In [None]:
# One colour per discrete label value (label k -> c_vec[k]).
c_vec = np.array(['red', 'orange', 'pink', 'green', 'indigo'])
fsz = 14   # font size for axis and tick labels
ll = 5000  # number of leading samples drawn in each panel

# `c` must have exactly one entry per plotted point, so slice the labels to the
# same `ll` samples as the coordinates (c_vec[u_true] covered every sample).
# ravel() also handles labels stored with a trailing singleton axis.
u_plot = np.ravel(u_true)[:ll].astype(int)
assert u_plot.max() < len(c_vec), "more label values than colours in c_vec"
point_colors = c_vec[u_plot]


def _plot_latent_panel(ax, latent, colors, n_points, fontsize,
                       flip_y=False, label_axes=False):
    """Scatter the first `n_points` rows of a 2-column latent array onto `ax`.

    ax         : matplotlib Axes to draw on
    latent     : array indexed as latent[i, 0] (x) and latent[i, 1] (y)
    colors     : one colour per plotted point (length n_points)
    n_points   : number of leading rows to draw
    fontsize   : size for axis and tick labels
    flip_y     : negate the second coordinate (display orientation only)
    label_axes : draw 'Latent 1' / 'Latent 2' axis labels
    """
    x = latent[:n_points, 0]
    y = latent[:n_points, 1]
    if flip_y:
        y = -y
    ax.scatter(x, y, c=colors, s=1, alpha=0.5)
    if label_axes:
        ax.set_xlabel('Latent 1', fontsize=fontsize, fontweight='normal')
        ax.set_ylabel('Latent 2', fontsize=fontsize, fontweight='normal')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # tick_params also applies to ticks created later by the locator below,
    # unlike setp() on the current tick-label objects.
    ax.tick_params(axis='both', labelsize=fontsize)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(ticker.MaxNLocator(nbins=4, min_n_ticks=4, prune=None))


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
# Left: ground-truth latent.
_plot_latent_panel(ax1, z_true, point_colors, ll, fsz, label_axes=True)
# Middle: outputs[0] (post_mean); right: outputs[6] (z_mean), per the output
# order listed in the prediction cell. The second coordinate is negated as in
# the original figure, presumably to match the true latent's orientation.
_plot_latent_panel(ax2, outputs[0], point_colors, ll, fsz, flip_y=True)
_plot_latent_panel(ax3, outputs[6], point_colors, ll, fsz, flip_y=True)
fig.tight_layout();