## Preparations
* Requires TensorFlow 1.x
* Requires gast 0.2.x and TensorFlow-GAN 1.x

In [None]:
import os
# Restrict CUDA to device index 0. This is set in the first cell, before
# TensorFlow is imported further down in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import numpy as np
import tensorflow as tf
from scipy.stats import norm
import matplotlib.pyplot as plt
tfd = tf.contrib.distributions

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from tensorflow.python.client import device_lib
# Display the devices TensorFlow can see (sanity check for the GPU selection).
device_lib.list_local_devices()

In [None]:
import utils
import nets
import flow_layers as fl

In [None]:
# Show the installed TensorFlow version; this notebook expects a 1.x release.
print(tf.__version__)

In [None]:
# Fix the graph-level random seed before anything else is created.
tf.set_random_seed(0)

# Session options: let TensorFlow fall back to another device when an op
# cannot be placed as requested, and grow GPU memory on demand.
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True

# InteractiveSession installs itself as the default session, which is what
# the bare `.eval()` calls later in the notebook rely on.
sess = tf.InteractiveSession(config=config)

## Main Part

### Import Dataset

In [None]:
# Evaluation settings: a single image per batch, 128 pixels per side.
batch_size, image_size = 1, 128

# Validation input read from the TFRecord file through the project helper.
valid_data = utils.create_tfrecord_dataset_iterator(
    "test_2.tfrecords",
    batch_size=batch_size,
    image_size=image_size,
)

### Build Decoder Forward Flow

In [None]:
# Network template handed to the flow builder below; width is fixed at 128.
nn_template_fn = nets.OpenAITemplate(width=128)

In [None]:
# Float32 RGB input placeholder for the forward flow. The batch dimension
# follows `batch_size` (currently 1) instead of repeating the literal, so the
# placeholder stays in step with the dataset settings.
image = tf.placeholder(tf.float32, [batch_size, image_size, image_size, 3])

In [None]:
# Build the list of flow layers: 32 steps at each of 5 scales, all using the
# template function defined above. `actnorm_layers` is returned as well but is
# not used in the visible part of this notebook.
layers, actnorm_layers = nets.create_simple_flow(
    num_steps=32, 
    num_scales=5, 
    template_fn=nn_template_fn
)

# Wrap the placeholder as the flow input and chain all layers into one model.
flow = fl.InputLayer(image)
model_flow = fl.ChainLayer(layers)
# Run the chain in the forward direction on the placeholder.
# NOTE(review): presumably this pass is what creates the model variables that
# the Saver restores later — confirm against flow_layers.
output_flow = model_flow(flow, forward=True)

### Read Stored Tensors

In [None]:
# Each stored tensor lives in its own checkpoint file under a single name.
reader_a = tf.train.NewCheckpointReader("aux-saves/a.ckpt")
reader_b = tf.train.NewCheckpointReader("aux-saves/b.ckpt")
reader_c = tf.train.NewCheckpointReader("aux-saves/c.ckpt")

# Pull the raw NumPy arrays out of the checkpoints.
a = reader_a.get_tensor("a")
b = reader_b.get_tensor("b")
c = reader_c.get_tensor("c")

# Normalise to float32 so the decoder inputs match the float32 model.
a_np = np.asarray(a, np.float32)
b_np = np.asarray(b, np.float32)
c_np = np.asarray(c, np.float32)

# Build the graph constants from the float32 copies. Previously these were
# built from the raw arrays and the float32 copies above were never used.
data_a = tf.convert_to_tensor(a_np)
data_b = tf.convert_to_tensor(b_np)
data_c = tf.convert_to_tensor(c_np)

# The three tensors are passed to the flow together as one tuple.
decoder_input = data_a, data_b, data_c
decoder_input

In [None]:
# Apply the same chained model to the stored tensors with forward=False
# (the forward=True pass above was applied to the image placeholder).
decoder_output = model_flow(decoder_input, forward=False)

### Restore Weights

In [None]:
# Create the Saver only after the graph has been built, then load the trained
# weights from the checkpoint into the interactive session.
saver = tf.train.Saver()
saver.restore(sess, "t2-2-saves/steps.ckpt")

### Get Results

In [None]:
# The forward=False pass yields three tensors; only `xx` is used below.
xx, yy, zz = decoder_output

### Plot the Output and Choose a Proper ROI

In [None]:
# Evaluate the decoded tensor in the default (interactive) session.
x_exampled = xx.eval()
# Show the first batch element. The row/column slices are the region of
# interest; adjust them by hand after inspecting the plot.
plt.imshow(x_exampled[0,0:128,0:128,:])

In [None]:
# Compare the spatial shape of the decoded image with that of a validation image.
print(x_exampled[0,:,:,0].shape)
# NOTE(review): if `valid_data` is an iterator's get_next() tensor, each
# .eval() pulls a fresh batch, so this call would consume one record —
# confirm against utils.create_tfrecord_dataset_iterator.
print(valid_data.eval()[0,:,:,0].shape)

### Dice

In [None]:
from scipy.spatial import distance

# Flatten channel index 2 of the decoded output and of one validation batch.
# Both operands are already NumPy arrays, so they are flattened with NumPy
# directly. Wrapping them in tf.reshape and evaluating again gives the same
# values but adds new constant ops to the graph on every run of this cell.
test_unwrapped = x_exampled[:, :, :, 2].reshape(-1)
valid_unwrapped = valid_data.eval()[:, :, :, 2].reshape(-1)

# scipy's `dice` returns the Dice *dissimilarity*, not the similarity
# coefficient.
# NOTE(review): scipy documents this function for boolean vectors; the inputs
# here are image channels taken as-is, without thresholding — confirm that
# this is the intended metric.
distance.dice(test_unwrapped, valid_unwrapped)