# 1.0 Load modules and define hyperparameters

In [1]:
"""
Code for training beta-TCVAE (beta = 100) on Tabula Muris heart data
"""
import os
# Switch the working directory to the parent folder so that the relative paths used
# later in this notebook ('./data/...' and './examples/models_tcvae/') resolve from there.
# NOTE(review): this call is not idempotent -- re-running this cell without restarting
# the kernel moves the working directory up one more level each time.
os.chdir('..')

import math

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import distributions as ds


from nets import *
from lib import *


class Options(object):
    """Container for the hyperparameters used in this notebook.

    Parameters
    ----------
    num_cells_train : number of cells (rows) in the training matrix.
    gex_size : number of genes (columns) in the training matrix.
    **overrides : optional keyword arguments replacing any of the default
        hyperparameters set below, e.g. ``Options(n, g, TotalCorrelation_lamb=10.0)``.

    Raises
    ------
    TypeError : if an override names a hyperparameter that does not exist,
        so that a misspelled name fails loudly instead of being silently ignored.
    """

    def __init__(self, num_cells_train, gex_size, **overrides):
        self.num_cells_train =  num_cells_train  # number of cells
        self.gex_size = gex_size                 # number of genes
        self.epsilon_use = 1e-16                 # small constant value
        self.n_train_epochs = 100                 # number of epochs for training
        self.batch_size = 32                     # batch size of the GAN-based methods
        self.vae_batch_size = 128                # batch size of the VAE-based methods
        self.code_size = 10                      # number of codes
        self.noise_size = 118                    # number of noise variables
        self.inflate_to_size1 = 256              # number of neurons
        self.inflate_to_size2 = 512              # number of neurons
        self.inflate_to_size3 = 1024             # number of neurons
        self.TotalCorrelation_lamb = 100.0         # weight of the total correlation penalty in beta-TCVAE
        self.InfoGAN_fix_std = True              # whether to fix the standard deviation of the Q network of InfoGAN
        self.dropout_rate = 0.2                  # dropout hyperparameter
        self.disc_internal_size1 = 1024          # number of neurons
        self.disc_internal_size2 = 512           # number of neurons
        self.disc_internal_size3 = 10            # number of neurons
        self.num_cells_generate = 3000           # number of sampled cells
        self.GradientPenaly_lambda = 10.0        # weight of the gradient penalty of Wasserstein GANs
        self.latentSample_size = 1               # number of samples drawn from the encoder of VAEs
        self.MutualInformation_lamb = 10.0       # weight of the mutual information penalty in InfoGAN
        self.Diters = 5                          # discriminator updates per generator update in Wasserstein GANs
        self.model_path = "./examples/models_tcvae/"        # directory in which the model is saved

        # Overrides are applied last so they win over the defaults above.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError("Options got an unexpected hyperparameter '{}'".format(name))
            setattr(self, name, value)


# 2.0 Load data and define hyperparameters

In [2]:
# Expression matrix; its first axis is passed on as the number of cells and its second as the number of genes
data_matrix = np.load('./data/TabulaMurisHeart_Processed.npy')
# Meta information table that accompanies the expression matrix
data_meta = pd.read_csv("./data/TabulaMurisHeart_MetaInformation.csv")
# Hyperparameters sized from the loaded matrix
opt = Options(
    num_cells_train=data_matrix.shape[0],
    gex_size=data_matrix.shape[1],
)


# 3.0 Define network tensors

In [3]:
# Input placeholders: z_v has opt.code_size columns, X_v has opt.gex_size columns (number of genes);
# the leading dimension is left open so any batch size can be fed.
# NOTE(review): z_v is not referenced again in this cell, so it does not enter the loss built below.
z_v = tf.placeholder(tf.float32, shape = (None, opt.code_size))
X_v = tf.placeholder(tf.float32, shape = (None, opt.gex_size))

## encoder
# vaes_encoder is not defined in this notebook (it comes from the star imports);
# presumably it returns the mean and standard deviation of q(z|x), judging by the names -- confirm.
z_gen_mean_v, z_gen_std_v = vaes_encoder(X_v, opt)

### reparameterization of latent space
# z = mean + std * eps, with eps drawn from a standard normal, one row per row of the encoder output
batch_size = tf.shape(z_gen_mean_v)[0]
eps = tf.random_normal(shape=[batch_size, opt.code_size])
z_gen_data_v = z_gen_mean_v + z_gen_std_v * eps

### latent entropies in a minibatch
margin_entropy_mss, joint_entropy_mss = estimate_minibatch_mss_entropy(z_gen_mean_v, z_gen_std_v, z_gen_data_v, opt)
### total correlation in a minibatch
# computed as (sum of the marginal-entropy estimates) - (sum of the joint-entropy estimates)
TotalCorre_mss = tf.reduce_sum(margin_entropy_mss) - tf.reduce_sum(joint_entropy_mss)

## decoder
# The returned object is used as a distribution below: it must provide .sample() and .log_prob()
z_gen_decoder = vaes_decoder(z_gen_data_v, opt)

### generated data
# Draw opt.latentSample_size samples and then drop the leading sample axis.
# The reshape to shape[1:] keeps the element count only when that leading axis has size 1.
X_gen_data = z_gen_decoder.sample(opt.latentSample_size)
X_gen_data = tf.reshape(X_gen_data , tf.shape(X_gen_data)[1:])

## loss elements
### reconstruction error
# Negative log-likelihood of the input under the decoder: summed over axis 1, averaged over the batch
z_gen_de = z_gen_decoder.log_prob(X_v)
z_gen_de_value = tf.reduce_sum(z_gen_de, [1])
rec_x_loss = - tf.reduce_mean(z_gen_de_value)

### latent prior and posterior probabilities
# Both are evaluated at the sampled codes z_gen_data_v
stg_prior = tf_standardGaussian_prior(tf.shape(X_v)[0], opt.code_size)
latent_prior = stg_prior.log_prob(z_gen_data_v)
latent_posterior = c_mutual_mu_var_entropy(z_gen_mean_v, z_gen_std_v, z_gen_data_v, opt)

### latent joint prior and posterior probabilities
# Sum the per-code terms over axis 1 to obtain one value per row
latent_prior_joint = tf.reduce_sum(latent_prior, [1])
latent_posterior_joint = tf.reduce_sum(latent_posterior, [1])

### KL divergence
# Single-sample estimate: batch mean of the posterior term minus batch mean of the prior term
kl_latent = - tf.reduce_mean(latent_prior_joint) + tf.reduce_mean(latent_posterior_joint)

### VAE/beta-TCVAE loss function
# reconstruction loss + KL term + opt.TotalCorrelation_lamb * total-correlation estimate
obj_vae = rec_x_loss  + kl_latent + opt.TotalCorrelation_lamb * TotalCorre_mss

## time step
# NOTE(review): time_step is not consumed by any tensor defined in this cell.
time_step = tf.placeholder(tf.int32)

## training tensors 
# Trainable variables are split by their name prefix into encoder and decoder groups
tf_all_vars = tf.trainable_variables()
encodervar  = [var for var in tf_all_vars if var.name.startswith("EncoderX2Z")]
decodervar  = [var for var in tf_all_vars if var.name.startswith("DecoderZ2X")]

# Adam with learning rate 1e-4, updating only the encoder and decoder variables.
# NOTE(review): the deprecation output of this cell mentions batch normalization; if the networks use
# tf.layers.batch_normalization, its update ops are not attached to opt_vae here -- confirm.
optimizer_vae = tf.train.AdamOptimizer(1e-4)
opt_vae = optimizer_vae.minimize(obj_vae, var_list = encodervar + decodervar)

# NOTE(review): the Saver is constructed before global_step is defined. A Saver built without an
# explicit var_list collects the variables existing at construction time, so global_step is
# presumably not part of the saved checkpoints -- confirm whether that is intended.
saver = tf.train.Saver()
global_step = tf.Variable(0, name = 'global_step', trainable = False, dtype = tf.int32)

# The InteractiveSession becomes the default session, which the bare .run() call below relies on
sess = tf.InteractiveSession()	
init = tf.global_variables_initializer().run()
# Explicitly reset global_step to 0; nothing else in this cell updates it
assign_step_zero = tf.assign(global_step, 0)
init_step = sess.run(assign_step_zero)


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use keras.layers.BatchNormalization instead.  In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.batch_normalization` documentation).
Instructions for updating:
Use keras.layers.dropout instead.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Instructions for updating:
Use `tf.cast` instead.




Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


# 4.0 Training the networks

In [4]:
# Work on a copy so that the per-epoch shuffling never touches data_matrix itself
x_input = data_matrix.copy()
index_shuffle = list(range(opt.num_cells_train))
current_step = 0

# Number of minibatch updates per epoch (loop-invariant, so computed once)
n_batches_per_epoch = opt.num_cells_train // opt.vae_batch_size
if n_batches_per_epoch < 1:
    # Without this guard the inner loop would never run and the logging step below
    # would fail with a NameError on the undefined minibatch variables.
    raise ValueError(
        "num_cells_train ({}) is smaller than vae_batch_size ({}); no minibatch would be trained".format(
            opt.num_cells_train, opt.vae_batch_size))

for epoch in range(opt.n_train_epochs):
    # shuffling the data per epoch
    np.random.shuffle(index_shuffle)
    x_input = x_input[index_shuffle, :]

    for i in range(n_batches_per_epoch):

        # train VAE/beta-TCVAE in each minibatch
        x_data = sample_X(x_input, opt.vae_batch_size)
        # NOTE(review): z_v and time_step are fed although the graph-definition cell does not use them
        # in the loss; they are kept so that the run (including any random draws made by noise_prior,
        # presumably from NumPy's global generator) stays unchanged -- confirm before removing.
        z_data = noise_prior(opt.vae_batch_size, opt.code_size)
        sess.run([opt_vae], {X_v : x_data, z_v: z_data, time_step : current_step})

        current_step += 1

    # Report the loss and the total-correlation estimate on the last minibatch of the epoch,
    # evaluated after that minibatch's parameter update
    obj_vae_value, TC_value = sess.run([obj_vae, TotalCorre_mss], {X_v : x_data, z_v: z_data, time_step : current_step})

    print('epoch: {}; iteration: {}; beta-TCVAE loss:{}; Total Correlation:{}'.format(epoch, current_step, obj_vae_value, TC_value))

epoch: 0; iteration: 32; beta-TCVAE loss:4570.86328125; Total Correlation:-0.057712554931640625
epoch: 1; iteration: 64; beta-TCVAE loss:4504.205078125; Total Correlation:-0.1436929702758789
epoch: 2; iteration: 96; beta-TCVAE loss:4450.568359375; Total Correlation:-0.06525707244873047
epoch: 3; iteration: 128; beta-TCVAE loss:4516.85888671875; Total Correlation:-0.05398750305175781
epoch: 4; iteration: 160; beta-TCVAE loss:4458.2919921875; Total Correlation:-0.24315643310546875
epoch: 5; iteration: 192; beta-TCVAE loss:4540.57177734375; Total Correlation:-0.2803611755371094
epoch: 6; iteration: 224; beta-TCVAE loss:4491.02197265625; Total Correlation:-0.077880859375
epoch: 7; iteration: 256; beta-TCVAE loss:4473.6103515625; Total Correlation:-0.3150959014892578
epoch: 8; iteration: 288; beta-TCVAE loss:4387.6845703125; Total Correlation:-0.3378334045410156
epoch: 9; iteration: 320; beta-TCVAE loss:4430.58447265625; Total Correlation:-0.461151123046875
epoch: 10; iteration: 352; beta-T

epoch: 84; iteration: 2720; beta-TCVAE loss:4373.41845703125; Total Correlation:-0.9572696685791016
epoch: 85; iteration: 2752; beta-TCVAE loss:4354.94482421875; Total Correlation:-0.5725669860839844
epoch: 86; iteration: 2784; beta-TCVAE loss:4294.759765625; Total Correlation:-0.8115329742431641
epoch: 87; iteration: 2816; beta-TCVAE loss:4360.32470703125; Total Correlation:-0.9910449981689453
epoch: 88; iteration: 2848; beta-TCVAE loss:4450.7373046875; Total Correlation:-0.4662361145019531
epoch: 89; iteration: 2880; beta-TCVAE loss:4309.81689453125; Total Correlation:-1.151773452758789
epoch: 90; iteration: 2912; beta-TCVAE loss:4334.31982421875; Total Correlation:-0.7199783325195312
epoch: 91; iteration: 2944; beta-TCVAE loss:4364.9560546875; Total Correlation:-0.6222820281982422
epoch: 92; iteration: 2976; beta-TCVAE loss:4406.6953125; Total Correlation:-0.8430805206298828
epoch: 93; iteration: 3008; beta-TCVAE loss:4297.58740234375; Total Correlation:-1.0058116912841797
epoch: 94

# 5.0 Saving the trained model

In [5]:
# Saver.save fails when the parent directory of the checkpoint path is missing, which would
# discard the whole training run, so create the target directory first if needed.
if opt.model_path and not os.path.isdir(opt.model_path):
    os.makedirs(opt.model_path)

# os.path.join gives the same path as the plain concatenation for the default model_path
# (which ends with '/') and also works when model_path has no trailing separator.
model_file_path = os.path.join(opt.model_path, "models_tcvae")
# global_step is passed as the plain Python step counter; Saver.save appends it to the checkpoint file name
saving_model = saver.save(sess, model_file_path, global_step = current_step)