# 1.0 Load modules and define hyperparameters

In [1]:
"""
Code for training beta-TCVAE (beta = 100) on Tabula Muris heart data
"""
import os
os.chdir('..')

import math

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import distributions as ds
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from nets import *
from lib import *


class Options(object):
    """Container for the hyperparameters and paths shared by the models.

    The defaults reproduce the beta-TCVAE (beta = 100) configuration used in
    this notebook. Fields such as the GAN / InfoGAN settings are not read by
    any cell here; the whole object is passed to helpers imported from
    ``nets`` / ``lib``.

    Args:
        num_cells_train: number of cells (rows of the training matrix).
        gex_size: number of genes (columns of the training matrix).
        **overrides: optional ``field=value`` pairs that replace any default
            below, e.g. ``Options(n_cells, n_genes, n_train_epochs=10)``.

    Raises:
        TypeError: if an override names a field that is not defined below
            (this catches typos such as ``GradientPenalty_lambda``).
    """

    def __init__(self, num_cells_train, gex_size, **overrides):
        self.num_cells_train =  num_cells_train  # number of cells
        self.gex_size = gex_size                 # number of genes
        self.epsilon_use = 1e-16                 # small constant value
        self.n_train_epochs = 100                # number of epochs for training
        self.batch_size = 32                     # batch size of the GAN-based methods
        self.vae_batch_size = 128                # batch size of the VAE-based methods
        self.code_size = 10                      # number of codes
        self.noise_size = 118                    # number of noise variables
        self.inflate_to_size1 = 256              # number of neurons
        self.inflate_to_size2 = 512              # number of neurons
        self.inflate_to_size3 = 1024             # number of neurons
        self.TotalCorrelation_lamb = 100.0       # hyperparameter for the total correlation penalty in beta-TCVAE
        self.InfoGAN_fix_std = True              # fixing the standard deviation or not for the Q network of InfoGAN
        self.dropout_rate = 0.2                  # dropout hyperparameter
        self.disc_internal_size1 = 1024          # number of neurons
        self.disc_internal_size2 = 512           # number of neurons
        self.disc_internal_size3 = 10            # number of neurons
        self.num_cells_generate = 3000           # number of sampled cells
        # (sic) spelling kept as-is: helpers outside this notebook may read this exact name
        self.GradientPenaly_lambda = 10.0        # hyperparameter for the gradient penalty of Wasserstein GANs
        self.latentSample_size = 1               # number of samples of the encoder of VAEs
        self.MutualInformation_lamb = 10.0       # hyperparameter for the mutual information penalty in InfoGAN
        self.Diters = 5                          # number of training discriminator network per training of generator network of Wasserstein GANs
        self.model_path = "./examples/models_tcvae/"        # path saving the model

        # Apply user overrides only to fields that exist, so a misspelled
        # name fails loudly instead of silently creating a new attribute.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError("Unknown Options field(s): {}".format(", ".join(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)


# 2.0 Load data and instantiate the hyperparameters

In [2]:
# Processed expression matrix. Everything downstream assumes it is 2-D with
# cells along axis 0 and genes along axis 1 (see the Options call below).
data_matrix = np.load('./data/TabulaMurisHeart_Processed.npy')
# Metadata table (presumably one row per cell; not used by the cells below).
data_meta = pd.read_csv("./data/TabulaMurisHeart_MetaInformation.csv")

# Fail fast with a readable message instead of an IndexError (1-D input) or a
# silently ignored extra axis (>2-D input).
if data_matrix.ndim != 2:
    raise ValueError(
        "Expected a 2-D (cells x genes) matrix, got shape {}".format(data_matrix.shape))
opt = Options(data_matrix.shape[0], data_matrix.shape[1])


# 3.0 Build the beta-TCVAE graph, optimizer and session

In [3]:
# Re-running this cell in the same kernel would otherwise build the encoder and
# decoder a second time in the same default graph (either failing on duplicate
# variable names or creating copies that the name-prefix filters below would
# also match). Close the previous session and start from an empty graph.
if 'sess' in globals():
    sess.close()
    tf.reset_default_graph()

## input placeholders
# z_v and time_step (below) are fed by the training loop, but no tensor in
# this graph depends on them.
z_v = tf.placeholder(tf.float32, shape = (None, opt.code_size))
X_v = tf.placeholder(tf.float32, shape = (None, opt.gex_size))

## encoder: mean and standard deviation of the latent posterior q(z|x)
z_gen_mean_v, z_gen_std_v = vaes_encoder(X_v, opt)

### reparameterization of latent space: z = mean + std * eps, eps ~ N(0, I)
batch_size = tf.shape(z_gen_mean_v)[0]
eps = tf.random_normal(shape=[batch_size, opt.code_size])
z_gen_data_v = z_gen_mean_v + z_gen_std_v * eps

### latent entropies in a minibatch
margin_entropy_mss, joint_entropy_mss = estimate_minibatch_mss_entropy(z_gen_mean_v, z_gen_std_v, z_gen_data_v, opt)
### total correlation in a minibatch: sum of marginal entropies minus joint entropy
TotalCorre_mss = tf.reduce_sum(margin_entropy_mss) - tf.reduce_sum(joint_entropy_mss)

## decoder: returns a distribution object (used via .sample / .log_prob below)
z_gen_decoder = vaes_decoder(z_gen_data_v, opt)

### generated data
# sample() prepends a sample axis of length opt.latentSample_size; reshaping
# to shape[1:] removes it, which is only valid when latentSample_size == 1.
X_gen_data = z_gen_decoder.sample(opt.latentSample_size)
X_gen_data = tf.reshape(X_gen_data , tf.shape(X_gen_data)[1:])

## loss elements
### reconstruction error: negative log-likelihood summed over axis 1 (genes,
### assuming an element-wise log_prob) and averaged over the minibatch
z_gen_de = z_gen_decoder.log_prob(X_v)
z_gen_de_value = tf.reduce_sum(z_gen_de, [1])
rec_x_loss = - tf.reduce_mean(z_gen_de_value)

### latent prior and posterior log-densities at the sampled z
stg_prior = tf_standardGaussian_prior(tf.shape(X_v)[0], opt.code_size)
latent_prior = stg_prior.log_prob(z_gen_data_v)
latent_posterior = c_mutual_mu_var_entropy(z_gen_mean_v, z_gen_std_v, z_gen_data_v, opt)

### latent joint prior and posterior log-densities (summed over latent dims)
latent_prior_joint = tf.reduce_sum(latent_prior, [1])
latent_posterior_joint = tf.reduce_sum(latent_posterior, [1])

### KL divergence: single-sample estimate of E[log q(z|x) - log p(z)]
kl_latent = - tf.reduce_mean(latent_prior_joint) + tf.reduce_mean(latent_posterior_joint)

### VAE/beta-TCVAE loss function
# NOTE(review): in expectation the KL term above already contains the total
# correlation, so the effective TC weight here is 1 + TotalCorrelation_lamb.
# Confirm this is the intended parameterisation of beta.
obj_vae = rec_x_loss  + kl_latent + opt.TotalCorrelation_lamb * TotalCorre_mss

## time step (fed by the training loop; unused by the graph)
time_step = tf.placeholder(tf.int32)

## global step: created before the optimizer so minimize() advances it on every
## update, and before the Saver so that it is included in checkpoints
global_step = tf.Variable(0, name = 'global_step', trainable = False, dtype = tf.int32)

## training tensors
tf_all_vars = tf.trainable_variables()
encodervar  = [var for var in tf_all_vars if var.name.startswith("EncoderX2Z")]
decodervar  = [var for var in tf_all_vars if var.name.startswith("DecoderZ2X")]

optimizer_vae = tf.train.AdamOptimizer(1e-4)
opt_vae = optimizer_vae.minimize(obj_vae, var_list = encodervar + decodervar, global_step = global_step)

# A Saver built without var_list captures only the variables that exist at
# construction time, so it is created after every variable above.
saver = tf.train.Saver()

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# Redundant right after the initializer; kept so the reset op stays available.
assign_step_zero = tf.assign(global_step, 0)
sess.run(assign_step_zero)

# 4.0 Training the networks

In [4]:
x_input = data_matrix.copy()
index_shuffle = list(range(opt.num_cells_train))
current_step = 0

# Number of minibatch updates per epoch.
n_batches = opt.num_cells_train // opt.vae_batch_size
if n_batches < 1:
    # Without this guard the inner loop never runs and the per-epoch
    # evaluation below fails with a NameError on `x_data`.
    raise ValueError(
        "num_cells_train ({}) is smaller than vae_batch_size ({}); "
        "no minibatch can be formed".format(opt.num_cells_train, opt.vae_batch_size))

for epoch in range(opt.n_train_epochs):
    # shuffling the data per epoch
    # NOTE(review): if sample_X draws rows at random, this shuffle does not
    # change which rows are used; confirm against lib.sample_X.
    np.random.shuffle(index_shuffle)
    x_input = x_input[index_shuffle, :]

    for _ in range(n_batches):
        # train VAE/beta-TCVAE in each minibatch
        x_data = sample_X(x_input, opt.vae_batch_size)
        z_data = noise_prior(opt.vae_batch_size, opt.code_size)
        sess.run([opt_vae], {X_v : x_data, z_v: z_data, time_step : current_step})

        current_step += 1

    # Loss and total correlation are evaluated on the epoch's last minibatch
    # (after its update), so the printed values are noisy.
    obj_vae_value, TC_value = sess.run([obj_vae, TotalCorre_mss], {X_v : x_data, z_v: z_data, time_step : current_step})

    print('epoch: {}; iteration: {}; beta-TCVAE loss:{}; Total Correlation:{}'.format(epoch, current_step, obj_vae_value, TC_value))

epoch: 0; iteration: 32; beta-TCVAE loss:4432.302734375; Total Correlation:-0.04923057556152344
epoch: 1; iteration: 64; beta-TCVAE loss:4463.00927734375; Total Correlation:-0.1279582977294922
epoch: 2; iteration: 96; beta-TCVAE loss:4451.14404296875; Total Correlation:-0.1334972381591797
epoch: 3; iteration: 128; beta-TCVAE loss:4526.353515625; Total Correlation:-0.1233367919921875
epoch: 4; iteration: 160; beta-TCVAE loss:4477.26513671875; Total Correlation:-0.3037757873535156
epoch: 5; iteration: 192; beta-TCVAE loss:4468.34814453125; Total Correlation:-0.15694332122802734
epoch: 6; iteration: 224; beta-TCVAE loss:4524.65771484375; Total Correlation:-0.2891826629638672
epoch: 7; iteration: 256; beta-TCVAE loss:4391.0283203125; Total Correlation:-0.44116687774658203
epoch: 8; iteration: 288; beta-TCVAE loss:4443.611328125; Total Correlation:-0.3621349334716797
epoch: 9; iteration: 320; beta-TCVAE loss:4424.42236328125; Total Correlation:-0.3827381134033203
epoch: 10; iteration: 352; 

epoch: 84; iteration: 2720; beta-TCVAE loss:4362.6318359375; Total Correlation:-0.8049888610839844
epoch: 85; iteration: 2752; beta-TCVAE loss:4324.76025390625; Total Correlation:-0.9450263977050781
epoch: 86; iteration: 2784; beta-TCVAE loss:4391.02392578125; Total Correlation:-0.6192569732666016
epoch: 87; iteration: 2816; beta-TCVAE loss:4379.9794921875; Total Correlation:-0.8304710388183594
epoch: 88; iteration: 2848; beta-TCVAE loss:4335.74609375; Total Correlation:-0.8605022430419922
epoch: 89; iteration: 2880; beta-TCVAE loss:4323.84423828125; Total Correlation:-0.8708629608154297
epoch: 90; iteration: 2912; beta-TCVAE loss:4375.587890625; Total Correlation:-1.0577564239501953
epoch: 91; iteration: 2944; beta-TCVAE loss:4264.642578125; Total Correlation:-1.043935775756836
epoch: 92; iteration: 2976; beta-TCVAE loss:4278.5234375; Total Correlation:-1.09222412109375
epoch: 93; iteration: 3008; beta-TCVAE loss:4342.97998046875; Total Correlation:-1.0565204620361328
epoch: 94; itera

# 5.0 Saving the trained model

In [5]:
# tf.train.Saver.save refuses to write into a missing parent directory, so
# create it first (no-op when it already exists).
os.makedirs(opt.model_path, exist_ok=True)
# os.path.join works whether or not model_path ends with a separator.
model_file_path = os.path.join(opt.model_path, "models_tcvae")
# global_step appends "-<current_step>" to the checkpoint file prefix.
saving_model = saver.save(sess, model_file_path, global_step = current_step)