# Video GAN


In [0]:
import cv2
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import time

In [0]:
# Mount Google Drive in the Colab runtime; VIDEO_DIR below lives under
# this mount point.
from google.colab import drive
drive.mount('/content/drive/')

## Settings

In [0]:
# Video read settings
VIDEO_DIR = '/content/drive/My Drive/Colab Data/video-gan'  # project root; videos are globbed from <VIDEO_DIR>/inputs
INPUT_SIZE = (240, 320)  # (height, width) of one source frame; reassigned during frame extraction
VIDEO_SIZE = (64, 64)  # size every frame is resized to before training
FRAME_INT = 1  # step added to the capture position between extracted frames
FRAME_CAP = 32  # frames extracted per video; also passed to the model as num_frames

# Training parameters
BUFFER_SIZE = 1000  # shuffle buffer size for the tf.data pipeline
BATCH_SIZE = 32  # clips per batch; the file list is trimmed to a multiple of this
EPOCHS = 1000  # full passes over the dataset
Z_DIM = 100  # length of the generator's noise vector
DISC_ITERATIONS = 5  # discriminator optimizer updates per training step
GEN_ITERATIONS = 1  # generator optimizer updates per training step

# Adam optimizer
LEARNING_RATE = 0.0002  # shared by the generator and discriminator optimizers
BETA1 = 0.5  # Adam beta1 for both optimizers

# Output frequency
SAMPLE_RATE = 100  # NOTE(review): not referenced in the visible code
SAVE_RATE = 50  # write samples (and a checkpoint) every SAVE_RATE epochs
NUM_OUT = 1  # number of clips generated each time samples are written

# Use eager execution (TF 1.x API)
tf.enable_eager_execution()

## Video Processing

### Extract frames

In [0]:
videos = glob.glob(os.path.join(VIDEO_DIR, 'inputs/*.avi'))
all_image_paths = glob.glob(os.path.join(VIDEO_DIR, 'inputs/*.jpg'))

# Only extract when there are more source videos than frame-strip JPEGs.
if len(videos) > len(all_image_paths):
    # For each video, capture FRAME_CAP frames spaced FRAME_INT frames apart
    # and store them in one image with the frames stacked vertically (along
    # rows), which is the layout parse_video() reshapes back into frames.
    for vnum, video in enumerate(videos):
        description = os.path.splitext(video)[0]
        vidcap = cv2.VideoCapture(os.path.join(VIDEO_DIR, video))
        try:
            success, image = vidcap.read()
            if not success:
                # Unreadable or empty video: there is no frame to size the
                # output from, so skip it instead of crashing on None.
                continue
            # Remember the shape now; `image` is None once a read fails.
            frame_shape = image.shape
            frame_height = frame_shape[0]
            output = np.zeros((FRAME_CAP * frame_height,) + frame_shape[1:],
                              dtype=np.float32)
            loc, frames = 0, 0
            while success and frames < FRAME_CAP:
                output[frames * frame_height:(frames + 1) * frame_height] = image
                loc += FRAME_INT
                frames += 1
                # Seek by frame index. Seeking by CAP_PROP_POS_MSEC with a
                # step of a few milliseconds never leaves the first frame(s).
                vidcap.set(cv2.CAP_PROP_POS_FRAMES, loc)
                success, image = vidcap.read()
        finally:
            # Release every capture, not just the last one opened.
            vidcap.release()
        # Videos with fewer than FRAME_CAP frames are dropped.
        if frames == FRAME_CAP:
            INPUT_SIZE = frame_shape[:2]
            cv2.imwrite(os.path.join(VIDEO_DIR, 'inputs', description + str(vnum) + '.jpg'), output)
    all_image_paths = glob.glob(os.path.join(VIDEO_DIR, 'inputs/*.jpg'))

### Read frames into tf data object

In [0]:
# Reads video image, decodes into tensor, resized to desired shape.
def parse_video(filename):
    """Load one frame-strip JPEG and return it as a normalised clip tensor.

    The decoded image is reshaped into individual frames of
    INPUT_SIZE (height, width) with 3 channels, each frame is resized to
    VIDEO_SIZE, and pixel values are mapped with x / 127.5 - 1.

    Args:
        filename: Path of the JPEG file to read.

    Returns:
        A float32 tensor of shape [frames, VIDEO_SIZE[0], VIDEO_SIZE[1], 3].
    """
    raw_bytes = tf.read_file(filename)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    pixels = tf.cast(pixels, tf.float32)
    # -1 lets TensorFlow infer how many frames the strip contains.
    clip = tf.reshape(pixels, [-1, INPUT_SIZE[0], INPUT_SIZE[1], 3])
    clip = tf.image.resize_images(clip, VIDEO_SIZE)
    return clip / 127.5 - 1.0

# File name vector.
# Drop the trailing files that would not fill a complete batch.
num_use = (len(all_image_paths) // BATCH_SIZE) * BATCH_SIZE
all_image_paths = list(map(str, all_image_paths[:num_use]))

# Pipeline: shuffle file names, decode each into a clip, then batch.
dataset = tf.data.Dataset.from_tensor_slices(all_image_paths)
dataset = dataset.shuffle(BUFFER_SIZE)
dataset = dataset.map(parse_video)
dataset = dataset.batch(BATCH_SIZE)

## Utilities

In [0]:
def write_avi(frames, directory, name='', frate=24):
    """Write a generated clip as an XVID .avi plus a side-by-side JPEG strip.

    Args:
        frames: Sequence with one entry per time step; each entry is itself
            a sequence of images and only its first image (`entry[0]`) is
            written. Images are uint8 arrays of shape (height, width, 3) in
            RGB order.
        directory: Directory that receives both output files.
        name: File name of the video. The JPEG strip is written as
            'epoch' + name + '.jpg'.
        frate: Frames per second of the video.

    Raises:
        ValueError: If `frames` is empty.
    """
    images = [frame[0] for frame in frames]
    if not images:
        raise ValueError('write_avi requires at least one frame')

    # OpenCV writers expect BGR channel order; without this conversion red
    # and blue are swapped in both output files.
    images = [
        cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        if img.ndim == 3 and img.shape[2] == 3 else img
        for img in images
    ]

    # VideoWriter takes (width, height) and silently drops frames whose size
    # differs, so derive the size from the frames being written.
    height, width = images[0].shape[:2]
    writer = cv2.VideoWriter(os.path.join(directory, name),
                             cv2.VideoWriter_fourcc('X', 'V', 'I', 'D'),
                             frate, (width, height))
    try:
        for img in images:
            writer.write(img)
    finally:
        # Always finalise the file, even if a write fails.
        writer.release()

    # Single hstack over all frames instead of growing the strip in a loop.
    cv2.imwrite(os.path.join(directory, 'epoch' + name + '.jpg'), np.hstack(images))

    
def convert_image(images, samples):
    """Convert a batch of [-1, 1] float images into a list of uint8 images.

    Values are mapped with (x + 1) * 127.5, clipped to [0, 255] and
    truncated to uint8. The batch is then split along axis 0 (the sample
    axis) into `samples` equal parts and singleton dimensions are squeezed
    from each part.

    Args:
        images: Array-like (NumPy array or eager tensor) whose first axis
            indexes samples, e.g. (samples, height, width, channels).
        samples: Number of equal parts to split the first axis into;
            normally the number of samples in the batch.

    Returns:
        A list of `samples` NumPy uint8 arrays.
    """
    pixels = np.clip((np.asarray(images) + 1.0) * 127.5, 0, 255).astype(np.uint8)
    # Split on the sample axis. Splitting on axis 1 cut each image's height
    # into pieces and only produced whole images when samples == 1.
    return [np.squeeze(part) for part in np.split(pixels, samples, axis=0)]

## Model

In [0]:
class VideoGAN():
    """3D-convolutional video GAN trained with the WGAN-GP objective.

    The generator starts from a 2x4x4 volume and doubles every dimension
    four times, so it always emits clips of 32 frames at 64x64 with 3
    channels. `num_frames` and `crop_size` define the discriminator's input
    shape and therefore have to be 32 and 64 for the two networks to fit
    together. `crop_size` also scales the number of convolution filters.
    """

    def __init__(self,
                 input,
                 batch_size,
                 num_frames,
                 crop_size,
                 learning_rate,
                 z_dim,
                 conv_init,
                 beta1,
                 disc_iterations,
                 gen_iterations,
                 num_out,
                 epochs,
                 save_int,
                 save_dir,
                 save_checkpts=True):
        """Build both networks, their Adam optimizers and the checkpoint.

        Args:
            input: Iterable of real-video batches (iterated once per epoch).
            batch_size: Number of clips per batch; also the number of noise
                vectors drawn per update.
            num_frames: Frames per clip seen by the discriminator.
            crop_size: Frame height/width and base filter count.
            learning_rate: Adam learning rate for both optimizers.
            z_dim: Length of the generator's noise vector.
            conv_init: Kernel initializer for all convolution layers.
            beta1: Adam beta1 for both optimizers.
            disc_iterations: Discriminator updates per training step.
            gen_iterations: Generator updates per training step.
            num_out: Number of clips generated when samples are written.
            epochs: Number of passes over `input`.
            save_int: Write samples/checkpoint every `save_int` epochs.
            save_dir: Directory for model plots, samples and checkpoints.
            save_checkpts: Whether to save a checkpoint with the samples.
        """
        self.videos = input
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.crop_size = crop_size
        self.learning_rate = learning_rate
        self.z_dim = z_dim
        self.conv_init = conv_init
        self.beta1 = beta1
        self.disc_iterations = disc_iterations
        self.gen_iterations = gen_iterations
        self.num_out = num_out
        self.epochs = epochs
        self.save_int = save_int
        self.save_dir = save_dir
        self.save_checkpts = save_checkpts

        self.generator = self.generator_model()
        self.discriminator = self.discriminator_model()

        self.gen_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1, beta2=0.999)
        self.disc_optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.beta1, beta2=0.999)

        self.checkpoint = tf.train.Checkpoint(generator_optimizer=self.gen_optimizer,
                                              discriminator_optimizer=self.disc_optimizer,
                                              generator=self.generator,
                                              discriminator=self.discriminator)

    def train_step(self, videos):
        """Run one WGAN-GP training step on a batch of real clips.

        The discriminator (critic) is updated `disc_iterations` times, then
        the generator `gen_iterations` times. Every update draws fresh noise
        and recomputes its own gradients; re-applying one stale gradient
        several times only enlarges a single step and is not the repeated
        critic training the WGAN-GP procedure calls for. The same real batch
        is reused for all critic updates of the step.

        Args:
            videos: Batch of real clips with `batch_size` entries.
        """
        disc_vars = self.discriminator.trainable_variables
        for _ in range(self.disc_iterations):
            # Generate noise from normal distribution
            noise = tf.random_normal([self.batch_size, self.z_dim])
            with tf.GradientTape() as disc_tape:
                generated_videos = self.generator(noise, training=True)
                real_disc = self.discriminator(videos, training=True)
                generated_disc = self.discriminator(generated_videos, training=True)
                # Must run inside the tape: the gradient penalty is itself a
                # gradient that has to be differentiated.
                disc_loss = self.discriminator_loss(videos, generated_videos, real_disc, generated_disc)
            disc_grads = disc_tape.gradient(disc_loss, disc_vars)
            self.disc_optimizer.apply_gradients(zip(disc_grads, disc_vars))
            print('Disc scores real/fake:', tf.reduce_mean(real_disc), tf.reduce_mean(generated_disc))
            print('Disc loss:', disc_loss)

        gen_vars = self.generator.trainable_variables
        for _ in range(self.gen_iterations):
            noise = tf.random_normal([self.batch_size, self.z_dim])
            with tf.GradientTape() as gen_tape:
                generated_videos = self.generator(noise, training=True)
                generated_disc = self.discriminator(generated_videos, training=True)
                gen_loss = self.generator_loss(generated_disc)
            gen_grads = gen_tape.gradient(gen_loss, gen_vars)
            self.gen_optimizer.apply_gradients(zip(gen_grads, gen_vars))
            print('Gen loss:', gen_loss)

    def train(self):
        """Train for `epochs` epochs, periodically writing samples.

        Plots both model structures into `save_dir`, then for each epoch
        runs `train_step` on every batch. Every `save_int` epochs a sample
        is generated and, if `save_checkpts` is set, a checkpoint is saved
        under `save_dir`. A final sample is written after the last epoch.
        """
        # Plot model structure
        tf.keras.utils.plot_model(self.generator, show_shapes=True,
                                  to_file=os.path.join(self.save_dir, 'gen.jpg'))
        tf.keras.utils.plot_model(self.discriminator, show_shapes=True,
                                  to_file=os.path.join(self.save_dir, 'disc.jpg'))

        for epoch in range(self.epochs):
            start = time.time()

            for batch in self.videos:
                self.train_step(batch)

            # Save every n intervals
            if (epoch + 1) % self.save_int == 0:
                self.generate(self.generator, epoch + 1, self.num_out)
                if self.save_checkpts:
                    # Use the instance's save_dir rather than a notebook global.
                    self.checkpoint.save(file_prefix=os.path.join(self.save_dir, "ckpt"))

            print('Time taken for epoch {} is {} sec'.format(epoch + 1,
                                                             time.time() - start))
        # Generate samples after final epoch
        self.generate(self.generator, self.epochs, self.num_out)

    def generator_model(self):
        """Build the generator: noise vector -> (32, 64, 64, 3) clip in [-1, 1].

        A dense layer is reshaped to a 2x4x4 volume with crop_size*8
        channels, then four stride-2 transposed 3D convolutions double each
        dimension. The last layer has 3 filters and a tanh activation.

        Returns:
            A `tf.keras.Sequential` model.
        """
        model = tf.keras.Sequential()

        # Linear block
        model.add(tf.keras.layers.Dense(self.crop_size*8 * 4 * 4 * 2, input_shape=(self.z_dim,),
                                        kernel_initializer=tf.keras.initializers.random_normal(stddev=0.01)))
        model.add(tf.keras.layers.Reshape((2, 4, 4, self.crop_size*8)))
        model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.ReLU())

        # Convolution blocks 1-3: halve the filter count while upsampling
        for width_mult in (4, 2, 1):
            model.add(tf.keras.layers.Conv3DTranspose(filters=self.crop_size * width_mult,
                                                      kernel_size=4, strides=2, padding='same',
                                                      kernel_initializer=self.conv_init, use_bias=True))
            model.add(tf.keras.layers.BatchNormalization())
            model.add(tf.keras.layers.ReLU())

        # Convolution block 4: project to 3 channels in [-1, 1]
        model.add(tf.keras.layers.Conv3DTranspose(filters=3, kernel_size=4, strides=2, padding='same',
                                                  kernel_initializer=self.conv_init, use_bias=True, activation='tanh'))

        return model

    def discriminator_model(self):
        """Build the critic: (num_frames, crop_size, crop_size, 3) -> score.

        Five stride-2 3D convolutions (the first four followed by layer
        normalisation, all by LeakyReLU(0.2)) feed a flatten and a single
        linear unit, so the output is one unbounded score per clip.

        Returns:
            A `tf.keras.Sequential` model.
        """
        # NOTE(review): layer_norm is wrapped in a Lambda, so any variables
        # it creates are not owned by a Keras layer; confirm they are
        # tracked and trained as intended.
        model = tf.keras.Sequential()

        # Convolution block 1
        model.add(tf.keras.layers.Conv3D(filters=self.crop_size,
                                         input_shape=(self.num_frames, self.crop_size, self.crop_size, 3),
                                         kernel_size=4, strides=2, padding='same', kernel_initializer=self.conv_init))
        model.add(tf.keras.layers.Lambda(lambda x: tf.contrib.layers.layer_norm(x)))
        model.add(tf.keras.layers.LeakyReLU(.2))

        # Convolution blocks 2-4: double the filter count while downsampling
        for width_mult in (2, 4, 8):
            model.add(tf.keras.layers.Conv3D(filters=self.crop_size * width_mult,
                                             kernel_size=4, strides=2, padding='same',
                                             kernel_initializer=self.conv_init))
            model.add(tf.keras.layers.Lambda(lambda x: tf.contrib.layers.layer_norm(x)))
            model.add(tf.keras.layers.LeakyReLU(.2))

        # Convolution block 5
        model.add(tf.keras.layers.Conv3D(filters=1, kernel_size=4, strides=2, padding='same',
                                         kernel_initializer=self.conv_init))
        model.add(tf.keras.layers.LeakyReLU(.2))

        # Linear block
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(1, kernel_initializer=tf.keras.initializers.random_normal(stddev=0.01)))

        return model

    def generator_loss(self, generated_disc):
        """Return the WGAN-GP generator loss.

        Args:
            generated_disc: Critic scores of generated clips.

        Returns:
            The negated mean score, so minimising it raises the critic score
            of generated output.
        """
        # WGAN-GP loss: https://arxiv.org/pdf/1704.00028.pdf
        return -tf.reduce_mean(generated_disc)

    def discriminator_loss(self, real_videos, generated_videos, real_disc, generated_disc):
        """Return the WGAN-GP critic loss including the gradient penalty.

        Args:
            real_videos: Batch of real clips (`batch_size` entries).
            generated_videos: Batch of generated clips of the same shape.
            real_disc: Critic scores of `real_videos`.
            generated_disc: Critic scores of `generated_videos`.

        Returns:
            mean(generated_disc) - mean(real_disc) + 10 * gradient penalty.
        """
        # WGAN-GP loss: https://arxiv.org/pdf/1704.00028.pdf
        # Lower values mean real clips score higher than generated ones, so
        # minimising this trains the critic to separate the two.
        d_cost = tf.reduce_mean(generated_disc) - tf.reduce_mean(real_disc)
        # One interpolation weight per sample, broadcast over all pixels.
        alpha = tf.random_uniform(
            shape=[self.batch_size, 1],
            minval=0.,
            maxval=1.
        )
        dim = self.num_frames * self.crop_size * self.crop_size * 3
        real = tf.reshape(real_videos, [self.batch_size, dim])
        fake = tf.reshape(generated_videos, [self.batch_size, dim])
        diff = fake - real
        # Random points on the lines between real and generated clips
        interpolates = real + (alpha*diff)
        with tf.GradientTape() as tape:
            # `interpolates` is not a variable, so it must be watched.
            tape.watch(interpolates)
            interpolates_reshaped = tf.reshape(interpolates,
                                               [self.batch_size, self.num_frames,
                                                self.crop_size, self.crop_size, 3])
            interpolates_disc = self.discriminator(interpolates_reshaped)
        # Gradient of critic score wrt interpolated videos
        gradients = tape.gradient(interpolates_disc, [interpolates])[0]
        # Euclidean norm of gradient for each sample
        norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]))
        # Penalise the squared distance of each gradient norm from 1
        gradient_penalty = tf.reduce_mean((norm - 1.) ** 2)

        return d_cost + 10 * gradient_penalty

    def generate(self, model, epoch, num_out):
        """Generate clips with `model` and write them to `save_dir`.

        Args:
            model: Generator to sample from.
            epoch: Epoch number used in the output file name
                ('<epoch>.avi').
            num_out: Number of clips to generate. `write_avi` writes only
                the first clip of each time step.
        """
        gen_noise = tf.random_normal([num_out, self.z_dim])
        output = model(gen_noise, training=False)
        # Use the num_out argument, not self.num_out, so the split count
        # always matches the number of clips just generated.
        frames = [convert_image(output[:, i, :, :, :], num_out) for i in range(self.num_frames)]
        write_avi(frames, self.save_dir, name=str(epoch) + '.avi')

## Train

In [0]:
# Build the GAN from the notebook-level settings, then run the training loop.
model = VideoGAN(
    input=dataset,
    # Data shape
    batch_size=BATCH_SIZE,
    num_frames=FRAME_CAP,
    crop_size=VIDEO_SIZE[0],
    # Network and optimizer
    z_dim=Z_DIM,
    conv_init='he_normal',
    learning_rate=LEARNING_RATE,
    beta1=BETA1,
    # Schedule
    epochs=EPOCHS,
    disc_iterations=DISC_ITERATIONS,
    gen_iterations=GEN_ITERATIONS,
    # Output
    num_out=NUM_OUT,
    save_int=SAVE_RATE,
    save_dir=VIDEO_DIR,
)
model.train()