# 1.0 Load modules and define hyperparameters

In [1]:
"""
Code for training InfoWGAN-GP on Tabula Muris heart data
"""
import os
# The relative paths used below (./data, ./examples, and the local `nets`/`lib`
# modules) are resolved from the parent of the directory the notebook was
# launched in. Remember the launch directory the first time this cell runs so
# that re-running it does not keep climbing the directory tree.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

import math

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import distributions as ds


from nets import *
from lib import *


class Options(object):
    """Hyperparameter container for the single-cell generative models.

    Only the two data dimensions come from the caller; every other field is a
    fixed default. Some fields (``vae_batch_size``, ``TotalCorrelation_lamb``,
    ``latentSample_size``) belong to the VAE-based methods and are not
    referenced directly by the cells of this notebook.

    Args:
        num_cells_train: number of cells in the training data.
        gex_size: number of genes in the training data.
    """

    def __init__(self, num_cells_train, gex_size):
        # Data dimensions supplied by the caller.
        self.num_cells_train = num_cells_train
        self.gex_size = gex_size
        # Small constant value.
        self.epsilon_use = 1e-16
        # Training schedule and mini-batch sizes (GAN-based vs. VAE-based methods).
        self.n_train_epochs = 10
        self.batch_size = 32
        self.vae_batch_size = 128
        # Latent layout: structured codes plus noise variables (10 + 118 = 128 inputs).
        self.code_size = 10
        self.noise_size = 118
        # Layer widths, named "inflate" (neuron counts).
        self.inflate_to_size1 = 256
        self.inflate_to_size2 = 512
        self.inflate_to_size3 = 1024
        # Weight of the total-correlation penalty in beta-TCVAE.
        self.TotalCorrelation_lamb = 100.0
        # Whether the InfoGAN Q network keeps its standard deviation fixed.
        self.InfoGAN_fix_std = True
        # Dropout hyperparameter.
        self.dropout_rate = 0.2
        # Internal layer widths of the discriminator (neuron counts).
        self.disc_internal_size1 = 1024
        self.disc_internal_size2 = 512
        self.disc_internal_size3 = 10
        # How many cells to sample when generating.
        self.num_cells_generate = 3000
        # Gradient-penalty weight for Wasserstein GANs (the attribute name's
        # spelling is relied on by other code, so it is kept as-is).
        self.GradientPenaly_lambda = 10.0
        # Number of samples drawn from the VAE encoder.
        self.latentSample_size = 1
        # Weight of the mutual-information term in InfoGAN.
        self.MutualInformation_lamb = 10.0
        # Discriminator updates per generator update in Wasserstein GANs.
        self.Diters = 5
        # Directory where the model checkpoints are written.
        self.model_path = "./examples/models_infowgangp/"

# 2.0 Load data and instantiate the hyperparameters

In [2]:
# Expression matrix: rows are cells, columns are genes (its shape sizes Options).
data_matrix = np.load('./data/TabulaMurisHeart_Processed.npy')
# Accompanying metadata table (not used by the cells shown below).
data_meta = pd.read_csv("./data/TabulaMurisHeart_MetaInformation.csv")

num_cells, num_genes = data_matrix.shape[0], data_matrix.shape[1]
opt = Options(num_cells_train=num_cells, gex_size=num_genes)

# 3.0 Define the network graph (placeholders, losses and optimizers)

In [3]:
# Inputs: latent vectors (code_size + noise_size columns) and real expression profiles (gex_size columns).
z = tf.placeholder(tf.float32, shape=(None, opt.code_size + opt.noise_size))
X = tf.placeholder(tf.float32, shape=(None, opt.gex_size))

## generator
X_gen_data = gan_generator(z, opt)

## discriminator, applied to real and to generated data; reuse=True on the second
## call presumably shares the variables created by the first (TF variable_scope reuse).
Dx_real, Dx_qc_real = infogan_discriminator(X, opt)
Dx_fake, Dx_qc_fake = infogan_discriminator(X_gen_data, opt, reuse=True)

### compute E(log[q(c|X)]) on generated samples (InfoGAN mutual-information lower bound)
q_vec = MutualInformationLowerBound(Dx_qc_fake, z, opt)
q_mutual = tf.reduce_mean(q_vec)

## discriminator loss (minimised):
##   E[D(real)] - E[D(fake)] + GradientPenaly_lambda * GP - MutualInformation_lamb * E[log q(c|X)]
obj_d_or = tf.reduce_mean(Dx_real) - tf.reduce_mean(Dx_fake)
gradient_penalty = compute_gp(X, X_gen_data, infogan_discriminator, opt)
obj_d_orgp = obj_d_or + opt.GradientPenaly_lambda * gradient_penalty
obj_d = obj_d_orgp - (q_mutual * opt.MutualInformation_lamb)

## generator loss (minimised): E[D(fake)] - MutualInformation_lamb * E[log q(c|X)].
## Under this sign convention the discriminator is driven to score real data low and
## generated data high, and the generator lowers the score of its samples.
obj_g = tf.reduce_mean(Dx_fake) - (q_mutual * opt.MutualInformation_lamb)

## time step: fed on every training sess.run; no op built in this cell consumes it
time_step = tf.placeholder(tf.int32)

## training tensors: split trainable variables by each network's name prefix
tf_all_vars = tf.trainable_variables()
dvar = [var for var in tf_all_vars if var.name.startswith("InfoGANDiscriminator")]
gvar = [var for var in tf_all_vars if var.name.startswith("Generator")]


### Thanks to taki0112 for the TF StableGAN implementation https://github.com/taki0112/StableGAN-Tensorflow
from Adam_prediction import Adam_Prediction_Optimizer

## The networks call tf.layers.batch_normalization (see the deprecation warning this
## cell prints). Its moving-mean/variance updates are registered in UPDATE_OPS and
## never run unless explicitly executed, so group them with each network's train op.
## Collecting by name prefix assumes the update ops live under the same prefixes as
## the variables above (TODO confirm in nets). If there are none, the grouping is a no-op.
g_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="Generator")
d_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="InfoGANDiscriminator")
opt_g = tf.group(
    Adam_Prediction_Optimizer(learning_rate=1e-4, beta1=0.9, beta2=0.999, prediction=True).minimize(obj_g, var_list=gvar),
    *g_update_ops)
opt_d = tf.group(
    Adam_Prediction_Optimizer(learning_rate=1e-4, beta1=0.9, beta2=0.999, prediction=False).minimize(obj_d, var_list=dvar),
    *d_update_ops)


## Checkpointing. NOTE(review): tf.train.Saver() captures the variables that exist when
## it is constructed, so global_step (created afterwards) is not checkpointed. Nothing
## here increments it either; checkpoints are tagged with the Python-side current_step.
saver = tf.train.Saver()
global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)

sess = tf.InteractiveSession()
# Keep the initializer op itself; calling .run() on it would return None.
init = tf.global_variables_initializer()
sess.run(init)
# Redundant with the initializer (global_step starts at 0); kept so init_step stays defined.
assign_step_zero = tf.assign(global_step, 0)
init_step = sess.run(assign_step_zero)


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use keras.layers.BatchNormalization instead.  In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.batch_normalization` documentation).
Instructions for updating:
Use keras.layers.dropout instead.


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor




Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.



Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


# 4.0 Train the networks

In [4]:
# Training loop. Each epoch reorders the rows with a new permutation. Each of the
# num_cells_train // batch_size iterations runs opt.Diters discriminator updates,
# then one generator update; every update samples its own mini-batch.
x_input = data_matrix.copy()
index_shuffle = list(range(opt.num_cells_train))
current_step = 0
latent_dim = opt.code_size + opt.noise_size
iters_per_epoch = opt.num_cells_train // opt.batch_size


def _feed(x_batch, z_batch):
    """Build the feed dict for one sess.run call; time_step reads the global current_step."""
    return {X: x_batch, z: z_batch, time_step: current_step}


for epoch in range(opt.n_train_epochs):
    # Reorder the rows of the data with a new permutation every epoch.
    np.random.shuffle(index_shuffle)
    x_input = x_input[index_shuffle, :]

    for i in range(iters_per_epoch):
        # Discriminator first: opt.Diters updates per generator update.
        for k in range(opt.Diters):
            x_data = sample_X(x_input, opt.batch_size)
            z_data = noise_prior(opt.batch_size, latent_dim)
            sess.run([opt_d], _feed(x_data, z_data))

        # Then a single generator update.
        x_data = sample_X(x_input, opt.batch_size)
        z_data = noise_prior(opt.batch_size, latent_dim)
        sess.run([opt_g], _feed(x_data, z_data))

        current_step += 1

    # Logged values come from the last mini-batch of the epoch, not an epoch average.
    obj_d_value, obj_g_value, obj_q_value = sess.run([obj_d, obj_g, q_mutual], _feed(x_data, z_data))

    print('epoch: {}; iteration: {}; generator loss: {}; discriminator loss:{}; mutual information lower bound:{}'.format(epoch, current_step, obj_g_value, obj_d_value, obj_q_value))


epoch: 0; iteration: 131; generator loss: 118.9593505859375; discriminator loss:84.12515258789062; mutual information lower bound:-11.614373207092285
epoch: 1; iteration: 262; generator loss: 119.06475067138672; discriminator loss:90.00970458984375; mutual information lower bound:-11.576972007751465
epoch: 2; iteration: 393; generator loss: 106.97241973876953; discriminator loss:89.49098205566406; mutual information lower bound:-11.18041706085205
epoch: 3; iteration: 524; generator loss: 103.58719635009766; discriminator loss:88.03169250488281; mutual information lower bound:-10.870536804199219
epoch: 4; iteration: 655; generator loss: 100.53746795654297; discriminator loss:78.76022338867188; mutual information lower bound:-10.644318580627441
epoch: 5; iteration: 786; generator loss: 98.02869415283203; discriminator loss:75.3740005493164; mutual information lower bound:-10.496047019958496
epoch: 6; iteration: 917; generator loss: 95.2208251953125; discriminator loss:69.76828002929688; 

# 5.0 Saving the trained model

In [5]:
# Write a checkpoint tagged with the number of generator steps taken (current_step).
# tf.train.Saver.save raises if the parent directory does not exist, so create it first.
# os.path.join also copes with a model_path that lacks a trailing slash.
os.makedirs(opt.model_path, exist_ok=True)
model_file_path = os.path.join(opt.model_path, "models_infowgangp")
# saver.save returns the path prefix of the checkpoint it wrote.
saving_model = saver.save(sess, model_file_path, global_step=current_step)