# Stochastic Video Generation 
### Using a variational recurrent structure

### Load packages

In [1]:
# Import modules
from __future__ import print_function
import tensorflow as tf
import tensorflow_datasets as tfds

# Enumerate the GPUs visible to TensorFlow.
physical_devices = tf.config.list_physical_devices('GPU')
# Disabled: would turn on on-demand GPU memory growth for the first GPU. Note that it
# indexes physical_devices[0], which raises IndexError on a machine without a GPU.
#tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

import math
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numpy.random import shuffle
from skimage.metrics import structural_similarity as ssim_metric

from files.model_svg import *
from files.utils import *

# Plot configurations: render matplotlib figures inline in the cell output
%matplotlib inline

# Notebook auto reloads code. (Ref: http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython)
# Mode 2 reloads imported modules before each cell is executed, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Show the TensorFlow version this notebook is running against
print(tf.__version__)

2.4.0


### Load data

In [3]:
# Batch size handed to the data loader here and to model.train further down
batch_size = 256

# load_smmnist is not defined in this notebook; presumably it comes from one of the
# star imports in the first cell (files.utils?) — confirm.
train_dataset, test_dataset = load_smmnist(batch_size)

KeyboardInterrupt: 

In [None]:
# Visualize our data: visualize one video from the training data
# (visualize_one_video is not defined in this notebook; presumably provided by the star imports — confirm)
visualize_one_video(train_dataset)

### Build model and train

In [None]:
# Positional arguments of SVG, in order:
# encoder_dim, lstm_q_dim, lstm_prior_dim, latent_dim, lstm_dec_dim, lstm_dec_out_dim
# NOTE(review): dim_l is not defined anywhere in this notebook; presumably it is exported
# by one of the star imports (files.model_svg / files.utils) — confirm.
model = SVG(128, 256, 256, 10, 256, 128, use_skip=True, num_frame=dim_l)
# Train on the datasets loaded above, reusing the same batch_size that was given to the loader
model.train(train_dataset, test_dataset, epochs=1000, lr=5e-4, batch_size=batch_size, test_generate_batch_size=10)



### Load trained models

In [None]:
# Load previously saved weights into `model` (the SVG instance built in the cell above)
# from the checkpoint path below.
model.load_weights('trained_models/model3/model')

In [None]:
# Generate once on all test videos
#
# The model is conditioned on the first NUM_COND_FRAMES frames of every test
# video and then generates NUM_GEN_FRAMES further frames one step at a time.
# One example video is plotted (ground truth vs. generation) and the SSIM of
# every generated frame of every video is stored in `ssim`, whose shape is
# (num_test_videos, NUM_GEN_FRAMES).

num_test_videos = 2000
NUM_COND_FRAMES = 5     # frames the model is conditioned on (initial frame included)
NUM_GEN_FRAMES = 15     # frames the model has to generate on its own
EXAMPLE_VIDEO = 1300    # index of the test video shown in the figures

# Fail early instead of deep inside the plotting loop if the example is out of range
assert 0 <= EXAMPLE_VIDEO < num_test_videos, 'EXAMPLE_VIDEO must index one of the test videos'

ssim = np.zeros((num_test_videos, NUM_GEN_FRAMES))

for x_batch in test_dataset.take(1):
    x_batch = x_batch[:num_test_videos,:,:,:,:]
    # Initialize with the first frame and condition on the rest like in training
    model.svg_cell.x_tm1 = x_batch[:,0,:,:,:]
    x_batch = x_batch[:,1:,:,:,:]

    # Split x_batch into conditioned ones, and ones that model doesn't know.
    # The first frame was consumed above, hence NUM_COND_FRAMES - 1.
    x_batch_cond, x_batch_gen = tf.split(
        x_batch,
        num_or_size_splits=[NUM_COND_FRAMES - 1, NUM_GEN_FRAMES],
        axis=1)
    x_batch_gen_numpy = x_batch_gen.numpy()

    # Run on the conditioned frames:
    out = model(x_batch_cond)
    mean, logvar, mean_0, logvar_0, z, x_recons, \
    lstm_q_states, lstm_prior_states, \
    lstm_dec_1_states, lstm_dec_2_states = out
    states = [lstm_q_states, lstm_prior_states, lstm_dec_1_states, lstm_dec_2_states]
    # The last reconstructed conditioning frame seeds the generation loop
    x_out = x_recons[:,NUM_COND_FRAMES - 2,:,:,:]

    # Generate:
    print('\n------- Generating -------+\n')
    print(f'Original video ({NUM_GEN_FRAMES} frames)')
    plt.figure(figsize=(16, 4))
    plt.subplots_adjust(wspace=0.1, hspace=0)
    for t in range(NUM_GEN_FRAMES):
        plt.subplot(1, NUM_GEN_FRAMES, t+1)
        # Plot the single channel as a 2-D array, the same slice the SSIM below uses
        plt.imshow(x_batch_gen_numpy[EXAMPLE_VIDEO,t,:,:,0], cmap='gray')
        plt.axis('off')
    plt.show()
    print(f'Generated video ({NUM_GEN_FRAMES} frames)')
    plt.figure(figsize=(16, 4))
    plt.subplots_adjust(wspace=0.1, hspace=0)
    for t in range(NUM_GEN_FRAMES):
        # Each call advances the model by one frame and returns the updated states
        x_out, states = model.svg_cell.generate(x_out, states)
        x_out_numpy = x_out.numpy()

        plt.subplot(1, NUM_GEN_FRAMES, t+1)
        plt.imshow(x_out_numpy[EXAMPLE_VIDEO,:,:,0], cmap='gray')
        plt.axis('off')
        # NOTE(review): ssim_metric is called without data_range; scikit-image recommends
        # passing it explicitly for floating-point images — confirm the pixel range.
        for i in range(num_test_videos):
            ssim[i, t] = ssim_metric(x_batch_gen_numpy[i,t,:,:,0],
                                     x_out_numpy[i,:,:,0])
    plt.show()
    # NOTE(review): resets a flag on the SVG cell; its meaning is defined in the model code, not here
    model.svg_cell.batch_starts = False

In [None]:
# Index of the test video with the highest SSIM averaged over the generated frames
# (rows of `ssim` are videos, columns are time steps, as filled by the previous cell)
idx = np.argmax(np.mean(ssim, axis=1))
print(idx)

In [None]:
# We record SSIM on all test videos 100 times
#
# Every sampling run conditions the model on the first NUM_COND_FRAMES frames of
# each test video and generates NUM_GEN_FRAMES frames. `ssim[samp, i, t]` holds
# the SSIM of generated frame t of video i in run samp. For each run, the
# generated frames of the video with the best mean SSIM are kept in
# `x_out_record`. Both arrays are pickled into RESULTS_DIR at the end.

num_test_videos = 2000
NUM_SAMPLES = 100          # number of independent sampling runs
NUM_COND_FRAMES = 5        # frames the model is conditioned on (initial frame included)
NUM_GEN_FRAMES = 15        # frames the model has to generate on its own
FRAME_SHAPE = (64, 64, 1)  # height, width, channels of one frame
RESULTS_DIR = 'results'

# Create the output directory up front so this long run cannot fail at the final save step
os.makedirs(RESULTS_DIR, exist_ok=True)

ssim = np.zeros((NUM_SAMPLES, num_test_videos, NUM_GEN_FRAMES))

x_out_record = np.zeros((NUM_SAMPLES, 1, NUM_GEN_FRAMES) + FRAME_SHAPE)

# Scratch buffer for the frames generated in one run. Every entry is overwritten
# in each run, so it is allocated once instead of once per run.
x_out_record_tmp = np.zeros((num_test_videos, NUM_GEN_FRAMES) + FRAME_SHAPE)

for samp in range(NUM_SAMPLES):
    for x_batch in test_dataset.take(1):
        x_batch = x_batch[:num_test_videos,:,:,:,:]
        # Initialize with the first frame and condition on the rest like in training
        model.svg_cell.x_tm1 = x_batch[:,0,:,:,:]
        x_batch = x_batch[:,1:,:,:,:]

        # Split x_batch into conditioned ones, and ones that model doesn't know.
        # The first frame was consumed above, hence NUM_COND_FRAMES - 1.
        x_batch_cond, x_batch_gen = tf.split(
            x_batch,
            num_or_size_splits=[NUM_COND_FRAMES - 1, NUM_GEN_FRAMES],
            axis=1)
        x_batch_gen_numpy = x_batch_gen.numpy()

        # Run on the conditioned frames:
        out = model(x_batch_cond)
        mean, logvar, mean_0, logvar_0, z, x_recons, \
        lstm_q_states, lstm_prior_states, \
        lstm_dec_1_states, lstm_dec_2_states = out
        states = [lstm_q_states, lstm_prior_states, lstm_dec_1_states, lstm_dec_2_states]
        # The last reconstructed conditioning frame seeds the generation loop
        x_out = x_recons[:,NUM_COND_FRAMES - 2,:,:,:]
        for t in range(NUM_GEN_FRAMES):
            x_out, states = model.svg_cell.generate(x_out, states)
            x_out_numpy = x_out.numpy()
            x_out_record_tmp[:,t,:,:,:] = x_out_numpy

            # NOTE(review): ssim_metric is called without data_range; scikit-image recommends
            # passing it explicitly for floating-point images — confirm the pixel range.
            for i in range(num_test_videos):
                ssim[samp, i, t] = ssim_metric(x_batch_gen_numpy[i,t,:,:,0],
                                               x_out_numpy[i,:,:,0])
        # NOTE(review): resets a flag on the SVG cell; its meaning is defined in the model code, not here
        model.svg_cell.batch_starts = False
    # Keep only the best-scoring video of this run (assignment copies out of the scratch buffer)
    idx = np.argmax(np.mean(ssim[samp,:,:], axis=1))
    x_out_record[samp,:,:,:,:,:] = x_out_record_tmp[idx,:,:,:,:]
    print(samp)

with open(os.path.join(RESULTS_DIR, 'x_out.pickle'), 'wb') as handle:
    pickle.dump(x_out_record, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(os.path.join(RESULTS_DIR, 'ssim_test.pickle'), 'wb') as handle:
    pickle.dump(ssim, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Plot 100 times on a chosen video
#
# Draws NUM_SAMPLES independent generations for the single test video
# `video_idx`, plotting ground truth and generation for every run and storing
# the per-frame SSIM in `ssim`, whose shape is (NUM_SAMPLES, NUM_GEN_FRAMES).

video_idx = 257
NUM_SAMPLES = 100      # number of independent sampling runs
NUM_COND_FRAMES = 5    # frames the model is conditioned on (initial frame included)
NUM_GEN_FRAMES = 15    # frames the model has to generate on its own

ssim = np.zeros((NUM_SAMPLES, NUM_GEN_FRAMES))
# NOTE(review): resets a flag on the SVG cell; its meaning is defined in the model code, not here
model.svg_cell.batch_starts = False

for samp in range(NUM_SAMPLES):

    for x_batch in test_dataset.take(1):
        # Keep a leading batch axis of size 1 so the model sees the usual 5-D input
        x_batch = tf.expand_dims(x_batch[video_idx,:,:,:,:], axis=0)

        model.svg_cell.x_tm1 = x_batch[:,0,:,:,:]
        x_batch = x_batch[:,1:,:,:,:]

        # Split x_batch into conditioned ones, and ones that model doesn't know.
        # The first frame was consumed above, hence NUM_COND_FRAMES - 1.
        x_batch_cond, x_batch_gen = tf.split(
            x_batch,
            num_or_size_splits=[NUM_COND_FRAMES - 1, NUM_GEN_FRAMES],
            axis=1)
        x_batch_gen_numpy = x_batch_gen.numpy()

        # Run on the conditioned frames:
        out = model(x_batch_cond)
        mean, logvar, mean_0, logvar_0, z, x_recons, \
        lstm_q_states, lstm_prior_states, \
        lstm_dec_1_states, lstm_dec_2_states = out
        states = [lstm_q_states, lstm_prior_states, lstm_dec_1_states, lstm_dec_2_states]
        # The last reconstructed conditioning frame seeds the generation loop
        x_out = x_recons[:,NUM_COND_FRAMES - 2,:,:,:]

        # Generate:
        print('\n------- Generating -------+', samp)
        print(f'Original video ({NUM_GEN_FRAMES} frames)')
        plt.figure(figsize=(16, 4))
        plt.subplots_adjust(wspace=0.1, hspace=0)
        for t in range(NUM_GEN_FRAMES):
            plt.subplot(1, NUM_GEN_FRAMES, t+1)
            # Plot the single channel as a 2-D array, the same slice the SSIM below uses
            plt.imshow(x_batch_gen_numpy[0,t,:,:,0], cmap='gray')
            plt.axis('off')
        plt.show()
        print(f'Generated video ({NUM_GEN_FRAMES} frames)')
        plt.figure(figsize=(16, 4))
        plt.subplots_adjust(wspace=0.1, hspace=0)
        for t in range(NUM_GEN_FRAMES):
            # Each call advances the model by one frame and returns the updated states
            x_out, states = model.svg_cell.generate(x_out, states)
            x_out_numpy = x_out.numpy()
            # NOTE(review): ssim_metric is called without data_range; scikit-image recommends
            # passing it explicitly for floating-point images — confirm the pixel range.
            ssim[samp, t] = ssim_metric(x_batch_gen_numpy[0,t,:,:,0],
                                        x_out_numpy[0,:,:,0])

            plt.subplot(1, NUM_GEN_FRAMES, t+1)
            plt.imshow(x_out_numpy[0,:,:,0], cmap='gray')
            plt.axis('off')
        plt.show()
        model.svg_cell.batch_starts = False


In [None]:
# Index of the sampling run with the highest SSIM averaged over the generated frames
# (rows of `ssim` are sampling runs here, as filled by the previous cell)
np.argmax(np.mean(ssim, axis=1))

In [None]:
# Plot the loss values recorded in model.loss_list and save the figure.
# The x-axis is derived from the number of recorded values rather than a fixed
# 1000, so the cell also works when training ran for a different number of
# epochs or was interrupted.
loss_values = model.loss_list

# savefig does not create missing directories
os.makedirs('figs', exist_ok=True)

plt.plot(list(range(len(loss_values))), loss_values)
plt.title('Loss value across epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.savefig('figs/loss.png')
plt.show()

In [None]:
# Plot the SSIM per time step, averaged over the first two axes of the stored array.
# NOTE(review): the recording cell above saves its SSIM array as results/ssim_test.pickle,
# while this cell reads a differently named file — presumably a renamed copy; confirm.
# Unpickling executes arbitrary code, so only load files you produced yourself.
ssim = pd.read_pickle('ssim_test_model3.pickle')

FIRST_GEN_STEP = 5  # generation starts after the conditioning frames

mean_ssim = np.mean(ssim, axis=(0,1))
# Derive the x-axis from the data so it always matches the number of stored time steps
time_steps = list(range(FIRST_GEN_STEP, FIRST_GEN_STEP + len(mean_ssim)))

# savefig does not create missing directories
os.makedirs('figs', exist_ok=True)

plt.plot(time_steps, mean_ssim)
plt.title('Mean SSIM across time-steps')
plt.xlabel('Time-steps')
plt.ylabel('Mean SSIM')
plt.savefig('figs/ssim.png')
plt.show()

In [None]:
# Look at the test videos
#
# Shows, for each video in the first test batch, the NUM_GEN_FRAMES frames that
# follow the NUM_COND_FRAMES conditioning frames.

NUM_COND_FRAMES = 5    # frames skipped at the start of every video
NUM_GEN_FRAMES = 15    # frames shown per video
MAX_VIDEOS = 2000      # upper bound on the number of videos to display

for x_batch in test_dataset.take(1):
    # Never index past the end of the batch
    num_videos = min(MAX_VIDEOS, int(x_batch.shape[0]))
    for vid in range(num_videos):
        # Plot:
        print(f'Original video ({NUM_GEN_FRAMES} frames)', vid)
        plt.figure(figsize=(16, 2))
        plt.subplots_adjust(wspace=0.1, hspace=0)
        for t in range(NUM_GEN_FRAMES):
            plt.subplot(1, NUM_GEN_FRAMES, t+1)
            plt.imshow(x_batch[vid, t+NUM_COND_FRAMES,:,:,:], cmap='gray')
            plt.axis('off')
        plt.show()
 


In [4]:
# Install the `tree` command-line tool through apt.
# NOTE(review): needs sudo and an apt-based system; not reproducible elsewhere.
!sudo apt-get install tree
# Append the directory listing to README.md.
# NOTE(review): `>>` appends, so every re-run adds another copy of the listing to the file.
!tree ./ >> README.md

Reading package lists... Done
Building dependency tree       
Reading state information... Done
tree is already the newest version (1.7.0-5).
The following package was automatically installed and is no longer required:
  libnuma1
Use 'sudo apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.
