# Training of Glows

## Preparations
* TensorFlow version == 1.x required
* Gast version == 0.2.x, TensorFlow-Gan version == 1.x required
* First few lines of code are for running on Google Colab

In [None]:
import os
# Expose only GPU 0 to CUDA; this is set before TensorFlow is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import numpy as np
import tensorflow as tf
from scipy.stats import norm
import matplotlib.pyplot as plt
tfd = tf.contrib.distributions

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from tensorflow.python.client import device_lib
# Show the devices TensorFlow can see; the cell output is the returned list.
device_lib.list_local_devices()

In [None]:
import utils
import nets
import flow_layers as fl

In [None]:
# Confirm the installed TensorFlow version (this notebook requires 1.x).
print(tf.__version__)

In [None]:
# Session configuration: allow soft device placement and let GPU memory
# grow on demand instead of being reserved up front.
config = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options=tf.GPUOptions(allow_growth=True),
)
# Fix the graph-level random seed before the session is opened.
tf.set_random_seed(0)
# InteractiveSession registers itself as the default session, which is what
# lets the later cells call `.eval()` without passing a session.
sess = tf.InteractiveSession(config=config)

## Model Structuring

### Import Datasets 
* Default resolution set to 256 in utils
* Any lower resolution entered triggers down-sampling

In [None]:
# Batch size and image resolution shared by both input pipelines.
batch_size = 8
image_size = 128

# Keyword arguments common to the training and validation iterators.
_iterator_kwargs = {"batch_size": batch_size, "image_size": image_size}
x_train_samples = utils.create_tfrecord_dataset_iterator(
    "train.tfrecords", **_iterator_kwargs
)
x_valid_samples = utils.create_tfrecord_dataset_iterator(
    "valid.tfrecords", **_iterator_kwargs
)

### Check Shapes
* Input tensor of $[$batch_size, $h, w, c]$ expected
* In this setup (batch_size = 8, image_size = 128) it should be $[$8, 128, 128, 3$]$

In [None]:
# Evaluate one validation batch in the default session and show its shape.
x_valid_samples.eval().shape

In [None]:
# Keep one evaluated validation batch around for ad-hoc inspection.
x_exampled = x_valid_samples.eval()

In [None]:
# Pass a validation batch through utils.plot_grid, evaluate it and display the result.
plt.imshow(utils.plot_grid(x_valid_samples).eval())

In [None]:
# Benchmark dataset reading: time the evaluation of a validation batch,
# with 100 loops per timing run (-n 100).
%timeit -n 100 x_valid_samples.eval()

### Build Forward Flow
* Please refer to Implementation in report for reference
* Scale down accordingly depending on compute power

In [None]:
# Network template passed to nets.create_simple_flow as `template_fn`.
# NOTE(review): `width` presumably sets the hidden channel count - confirm in nets.
nn_template_fn = nets.OpenAITemplate(
    width=128
)

In [None]:
# Create the list of flow layers plus the actnorm layers, which are kept
# separately because they are initialised on their own later on.
# NOTE(review): num_steps / num_scales are presumably steps per scale and the
# number of multi-scale levels - confirm in nets.create_simple_flow.
layers, actnorm_layers = nets.create_simple_flow(
    num_steps=32, 
    num_scales=5, 
    template_fn=nn_template_fn
)

# Feed training batches through the chained layers in the forward direction.
images = x_train_samples
flow = fl.InputLayer(images)
model_flow = fl.ChainLayer(layers)
output_flow = model_flow(flow, forward=True)

### Output Tensors

In [None]:
# The forward pass yields a 3-tuple; unpack it into the two latents and the
# log-determinant term. The bare expression displays the tuple as cell output.
y, logdet, z = output_flow
output_flow

### Loss Function

* $x$ partitioned into $y$, $z$ by affine coupling layers
* Total loss = $-\log p(x)$ (averaged over the batch, per pixel) $+ L_2$ regularization loss

In [None]:
# Scalar placeholder that multiplies the scale of both priors; it is fed as
# 1.0 for training and varied when sampling.
beta_ph = tf.placeholder(tf.float32, [])

# Flatten each latent to [batch_size, features].
y_flatten = tf.reshape(y, [batch_size, -1])
z_flatten = tf.reshape(z, [batch_size, -1])

# Zero-mean diagonal Gaussian priors whose per-dimension scale is beta_ph.
prior_y = tfd.MultivariateNormalDiag(loc=tf.zeros_like(y_flatten), scale_diag=beta_ph * tf.ones_like(y_flatten))
prior_z = tfd.MultivariateNormalDiag(loc=tf.zeros_like(z_flatten), scale_diag=beta_ph * tf.ones_like(z_flatten))

# Log-density of each flattened latent under its prior.
log_prob_y =  prior_y.log_prob(y_flatten)
log_prob_z =  prior_z.log_prob(z_flatten)

* Main loss

In [None]:
# Log-likelihood: prior log-densities of both latents plus the flow's
# log-determinant term.
log_likelihood = log_prob_y + log_prob_z + logdet
# The training loss is the negated mean of that log-likelihood.
loss = tf.negative(tf.reduce_mean(log_likelihood))

* $L_2$ Regularization Loss 

In [None]:
# Snapshot of the trainable variables that exist in the graph at this point.
trainable_variables = tf.trainable_variables()
# Regularisation coefficient.
l2_reg = 1e-5
# Sum tf.nn.l2_loss over every variable in the snapshot, then scale it.
l2_terms = [tf.nn.l2_loss(variable) for variable in trainable_variables]
l2_loss = l2_reg * tf.add_n(l2_terms)

### Print Trainable Variables 

In [None]:
# Total number of scalar parameters: the product of each variable's static
# shape, summed over the captured trainable variables.
total_params = sum(
    np.prod(variable.shape.as_list()) for variable in trainable_variables
)

print(f"total_params: {total_params/1e6} M")

### Total Loss

In [None]:
# Initialise all variables so the loss tensors below can be evaluated.
sess.run(tf.global_variables_initializer())

In [None]:
# Normalise the main loss by the pixel count (image_size * image_size).
loss_per_pixel = loss / image_size / image_size  

In [None]:
# Training objective: L2 penalty plus the per-pixel main loss.
total_loss = l2_loss + loss_per_pixel 

In [None]:
# Sanity check: evaluate both loss terms with the prior scale set to 1.0.
l2_loss.eval(feed_dict={beta_ph: 1.0}), loss_per_pixel.eval(feed_dict={beta_ph: 1.0})

### Build Backward Flow

In [None]:
# Draw latents from the priors and reshape them to the static shapes of the
# forward-pass outputs y and z.
sample_y_flatten = prior_y.sample()
sample_y = tf.reshape(sample_y_flatten, y.shape.as_list())
sample_z = tf.reshape(prior_z.sample(), z.shape.as_list())
# Log-density of the drawn y sample; it takes the log-determinant slot of the
# tuple handed to the inverse pass.
sampled_logdet = prior_y.log_prob(sample_y_flatten)

In [None]:
# Same (y, logdet, z) tuple layout as the forward output, built from prior samples.
inverse_flow = sample_y, sampled_logdet, sample_z
# Run the same chained layers in the inverse direction.
sampled_flow = model_flow(inverse_flow, forward=False)

In [None]:
# Recount after building the backward flow. The count is taken from a fresh
# query of the graph: reusing the list captured before the backward flow was
# built would always reproduce the earlier total and could not reveal any
# variables added since then.
current_variables = tf.trainable_variables()
total_params = sum(
    np.prod(variable.shape.as_list()) for variable in current_variables
)
print(f"total_params: {total_params/1e6} M")

In [None]:
# Only the first element of the inverse-pass output (the generated x) is kept.
x_flow_sampled, _, _ = sampled_flow

In [None]:
# Generate one batch of samples at prior scale 1.0 and show its shape.
x_flow_sampled.eval({beta_ph: 1.0}).shape

## Model Training
### Optimizer & Learning Rate

In [None]:
# The learning rate is a placeholder so it can be changed between training runs.
lr_ph = tf.placeholder(tf.float32)
optimizer = tf.train.AdamOptimizer(lr_ph)
# Minimise the combined objective (L2 penalty + per-pixel main loss).
train_op = optimizer.minimize(total_loss)

### Data-Dependent Initialization (DDI) of Actnorms

In [None]:
# Run the initialiser again now that the optimizer has been added to the
# graph; note that this also resets every previously initialised variable.
sess.run(tf.global_variables_initializer())
# Initialise the actnorm layers with the prior scale fed as 1.0.
# NOTE(review): num_steps=10 is presumably the number of data batches used for
# the initialisation - confirm in nets.initialize_actnorms.
nets.initialize_actnorms(
    sess,
    feed_dict_fn=lambda: {beta_ph: 1.0},
    actnorm_layers=actnorm_layers,
    num_steps=10,
)

### Save First

In [None]:
# Make sure the checkpoint directory exists first: Saver.save can fail when
# the parent directory of the save path is missing.
os.makedirs("t1-4-saves", exist_ok=True)

# Checkpoint the freshly initialised model before any training step.
saver = tf.train.Saver()
save_path = saver.save(sess, "t1-4-saves/steps.ckpt")
print("Model saved in path: %s" % save_path)

### Metrics & Trainer

In [None]:
# Track the three loss tensors during training.
# NOTE(review): the first argument (50) and `step=1000` are presumably the
# recording and plotting intervals in training steps - confirm in utils.
metrics = utils.Metrics(50, metrics_tensors={"total_loss": total_loss, "loss_per_pixel": loss_per_pixel, "l2_loss": l2_loss})
plot_metrics_hook = utils.PlotMetricsHook(metrics, step=1000)

In [None]:
# Smoke test: run a single training step with the learning rate fed as 0.0.
sess.run(train_op, feed_dict={lr_ph: 0.0, beta_ph: 1.0})

In [None]:
# Evaluate the combined objective once before real training starts.
total_loss.eval(feed_dict={lr_ph: 0.0, beta_ph: 1.0})

### Check Initial Samples

In [None]:
# The first flow layer is used as the image quantisation layer; its
# `to_uint8` method is called on generated samples in the next cell.
quantize_image_layer = layers[0]
# Feed values for sampling: zero learning rate, prior scale 1.0.
aux_feed_dict = {lr_ph: 0.0, beta_ph: 1.0}

In [None]:
# Convert the generated samples with the quantisation layer's to_uint8, then
# display them through utils.plot_grid using the sampling feed values.
x_flow_sampled_uint = quantize_image_layer.to_uint8(x_flow_sampled)
plt.imshow(utils.plot_grid(x_flow_sampled_uint).eval(aux_feed_dict))

### Training Starts HERE
* Training this model is time-consuming
* Very prone to gradient explosion, so a very small learning rate is required
* Adjust the settings according to your actual configuration

In [None]:
def _warmup_feed_dict():
    """Return the feed values for this run: learning rate 0.000005, prior scale 1.0."""
    return {lr_ph: 0.000005, beta_ph: 1.0}


# First short training run of 100 steps.
utils.trainer(
    sess,
    num_steps=100,
    train_op=train_op,
    feed_dict_fn=_warmup_feed_dict,
    metrics=[metrics],
    hooks=[plot_metrics_hook],
)

# Checkpoint the weights reached after this run.
saver = tf.train.Saver()
save_path = saver.save(sess, "t1-4-saves/steps.ckpt")
print("Model saved in path: %s" % save_path)

In [None]:
# Build the sample-grid op and the saver once, before the loop. Creating them
# inside the loop adds a new set of ops to the graph on every iteration, so
# the graph keeps growing for no benefit.
sample_grid = utils.plot_grid(x_flow_sampled_uint)
saver = tf.train.Saver()

# Five rounds of 1000 training steps; after each round, show samples and checkpoint.
for _ in range(5):
    utils.trainer(
        sess,
        num_steps=1000,
        train_op=train_op,
        feed_dict_fn=lambda: {lr_ph: 0.00002, beta_ph: 1.0},
        metrics=[metrics],
        hooks=[plot_metrics_hook],
    )

    # Side-by-side samples at prior scale 0.9 (left) and 1.0 (right).
    for subplot_position, beta in ((121, 0.9), (122, 1.0)):
        plt.subplot(subplot_position)
        plt.imshow(sample_grid.eval({lr_ph: 0.0, beta_ph: beta}))
    plt.show()

    save_path = saver.save(sess, "t1-4-saves/steps.ckpt")
    print("Model saved in path: %s" % save_path)

In [None]:
# Build the sample-grid op and the saver once, before the loop. Creating them
# inside the loop adds a new set of ops to the graph on every iteration, so
# the graph keeps growing for no benefit.
sample_grid = utils.plot_grid(x_flow_sampled_uint)
saver = tf.train.Saver()

# Five more rounds of 1000 training steps at a lower learning rate; after each
# round, show samples and checkpoint.
for _ in range(5):
    utils.trainer(
        sess,
        num_steps=1000,
        train_op=train_op,
        feed_dict_fn=lambda: {lr_ph: 0.00001, beta_ph: 1.0},
        metrics=[metrics],
        hooks=[plot_metrics_hook],
    )

    # Side-by-side samples at prior scale 0.9 (left) and 1.0 (right).
    for subplot_position, beta in ((121, 0.9), (122, 1.0)):
        plt.subplot(subplot_position)
        plt.imshow(sample_grid.eval({lr_ph: 0.0, beta_ph: beta}))
    plt.show()

    save_path = saver.save(sess, "t1-4-saves/steps.ckpt")
    print("Model saved in path: %s" % save_path)

### Optional: Test Effects of Different Temperatures

In [None]:
# Reload the weights from the checkpoint path used by the training cells.
saver = tf.train.Saver()
saver.restore(sess, "t1-4-saves/steps.ckpt")

In [None]:
# Build the grid op once; calling utils.plot_grid inside the loop would add a
# new op to the graph for every temperature.
sample_grid = utils.plot_grid(x_flow_sampled_uint)

# Sweep the shared prior scale over 10 evenly spaced values from 0.0 to 1.5.
# At beta = 0.0 the prior scale is zero, so that first figure is not a random draw.
for beta in np.linspace(0.0, 1.5, 10):
    print(f"beta={beta:10.4f}")
    plt.figure(figsize=(5, 5))
    plt.imshow(sample_grid.eval({lr_ph: 0.0, beta_ph: beta}))
    plt.show()

In [None]:
# Invoke the metrics plotting hook manually, outside of utils.trainer.
plot_metrics_hook.run()

## Model Evaluation 
### Augment $y_{a}$ for Next Step Comparison

In [None]:
# Evaluate the flattened y latent 100 times and stack the results along the
# first axis.
eval_feed = {lr_ph: 0.0, beta_ph: 1.0}
y_batches = [y_flatten.eval(eval_feed) for _ in range(100)]
y_flatten_np = np.concatenate(y_batches)
y_flatten_np.shape

### Check Gaussianization of $p(y_{a})$
* $p(y_{a})$ expected to be Gaussianized
* Two distributions are expected to match

In [None]:
# Compare the first two dimensions of the collected y latents against 1000
# draws from a standard normal.
reference_x, reference_y = np.random.randn(2, 1000)
plt.scatter(y_flatten_np[:, 0], y_flatten_np[:, 1], label="sampled")
plt.scatter(reference_x, reference_y, alpha=0.7, label="N(0, 1)")
plt.axis("equal")
plt.legend()

### Augment $y_{b}$ for Next Step Comparison

In [None]:
# Evaluate the flattened z latent 100 times and stack the results along the
# first axis.
eval_feed = {lr_ph: 0.0, beta_ph: 1.0}
z_batches = [z_flatten.eval(eval_feed) for _ in range(100)]
z_flatten_np = np.concatenate(z_batches)
z_flatten_np.shape

### Check $p(y_{b})$ (Trivial)

In [None]:
# Compare the first two dimensions of the collected z latents against 1000
# draws from a standard normal.
reference_x, reference_y = np.random.randn(2, 1000)
plt.scatter(z_flatten_np[:, 0], z_flatten_np[:, 1], label="sampled")
plt.scatter(reference_x, reference_y, alpha=0.7, label="N(0, 1)")
plt.axis("equal")
plt.legend()

## Control Variable Experiment

In [None]:
# Separate scalar scale placeholders so the y and z priors can be sampled at
# different temperatures.
beta_y_ph = tf.placeholder(tf.float32, [])
beta_z_ph = tf.placeholder(tf.float32, [])
# Zero-mean diagonal Gaussians shaped like the flattened latents, each scaled
# by its own placeholder.
prior_y_prim = tfd.MultivariateNormalDiag(
    loc=tf.zeros_like(y_flatten), scale_diag=beta_y_ph * tf.ones_like(y_flatten))
prior_z_prim = tfd.MultivariateNormalDiag(
    loc=tf.zeros_like(z_flatten), scale_diag=beta_z_ph * tf.ones_like(z_flatten))

In [None]:
# Draw latents from the independently scaled priors and reshape them to the
# static shapes of the forward-pass outputs y and z.
sample_y_flatten_prim = prior_y_prim.sample()
sample_y_prim = tf.reshape(sample_y_flatten_prim, y.shape.as_list())
sample_z_prim = tf.reshape(prior_z_prim.sample(), z.shape.as_list())
# Log-density of the drawn y sample; it takes the log-determinant slot of the
# tuple handed to the inverse pass.
sampled_logdet_prim = prior_y_prim.log_prob(sample_y_flatten_prim)

In [None]:
# Same (y, logdet, z) tuple layout as the forward output, built from the
# independently scaled prior samples.
inverse_flow_prim = sample_y_prim, sampled_logdet_prim, sample_z_prim
# Run the same chained layers in the inverse direction.
sampled_flow_prim = model_flow(inverse_flow_prim, forward=False)

In [None]:
# Only the first element of the inverse-pass output (the generated x) is kept.
x_flow_sampled_prim, _, _ = sampled_flow_prim

### Effect of $T_{b}$ with $T_{a}=1.0$

In [None]:
# Convert the generated samples with the quantisation layer's to_uint8 before
# plotting, exactly as is done for x_flow_sampled_uint; the raw inverse-flow
# output was previously plotted directly. The grid op is also built once here
# rather than on every loop iteration, which would keep growing the graph.
sample_grid_prim = utils.plot_grid(
    quantize_image_layer.to_uint8(x_flow_sampled_prim)
)

# Vary the z-prior scale from 1.0 to 2.0 while the y-prior scale stays at 1.0.
for beta in np.linspace(1.0, 2.0, 10):
    print(f"beta={beta:10.4f}")
    plt.figure(figsize=(5, 5))
    plt.imshow(sample_grid_prim.eval({
        lr_ph: 0.0,
        beta_y_ph: 1.0,
        beta_z_ph: beta,
    }))
    plt.show()

### Effect of $T_{a}$ with $T_{b}=1.0$

In [None]:
# Convert the generated samples with the quantisation layer's to_uint8 before
# plotting, exactly as is done for x_flow_sampled_uint; the raw inverse-flow
# output was previously plotted directly. The grid op is also built once here
# rather than on every loop iteration, which would keep growing the graph.
sample_grid_prim = utils.plot_grid(
    quantize_image_layer.to_uint8(x_flow_sampled_prim)
)

# Vary the y-prior scale from 1.0 to 2.0 while the z-prior scale stays at 1.0.
for beta in np.linspace(1.0, 2.0, 10):
    print(f"beta={beta:10.4f}")
    plt.figure(figsize=(7, 7))
    plt.imshow(sample_grid_prim.eval({
        lr_ph: 0.0,
        beta_y_ph: beta,
        beta_z_ph: 1.0,
    }))
    plt.show()