# Training of the autoencoders using the IQA metrics

### Imports and definitions

In [None]:
import os

from IPython.display import clear_output

import numpy as np
import time
import scipy.io
import matplotlib.pyplot as plt
from scipy import stats
import tensorflow.keras.backend as K

import tensorflow as tf

import random
import itertools

import pickle

import numbers
from numpy.lib.stride_tricks import as_strided

# Eager execution is required by the cells below (e.g. `.numpy()` and
# `np.mean` on tensors). It is the default in TF 2.x, where
# `tf.enable_eager_execution` no longer exists, so only enable it explicitly
# when the runtime is not already eager (TF 1.x).
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()

In [None]:
import utils
import models
import importlib

# Re-import the project-local modules so that edits to utils.py / models.py
# take effect without restarting the kernel.
importlib.reload(utils)
importlib.reload(models)

In [None]:
# Root directory of the TID2013 dataset; the cells below read mos.txt,
# reference_images/ and distorted_images/ from it.
tid_loc = "../tid2013"

In [None]:
# Subjective quality scores read from mos.txt, as float32.
scores_tid = utils.read_scores("{0}/mos.txt".format(tid_loc)).astype(np.float32)
# Images from reference_images/, cast to float32 for TensorFlow.
# NOTE(review): the array layout is assumed to be (N, 384, 512, 3) by later
# cells (patch reshapes, crop offsets) — confirm against utils.read_images.
ref_images = utils.read_images("{0}/reference_images".format(tid_loc))
ref_images = ref_images.astype(np.float32)

def key_fun(x):
    """Return a sort key for a TID2013 distorted-image filename like ``i01_02_3.bmp``.

    The file stem is split on ``_`` into the reference-image id (``01``, after
    dropping the leading ``i``/``I``), the distortion type (``02``) and the
    distortion level (``3``).  They are combined as
    ``10000 * ref + 100 * dist + level`` so that sorting orders files by
    reference image, then distortion type, then level.

    The extension is removed with ``os.path.splitext``, so extensions of any
    length (``.bmp``, ``.BMP``, ``.jpeg``) are handled.

    Args:
        x: the filename (no directory part).

    Returns:
        An integer sort key.
    """
    stem = os.path.splitext(x)[0]
    ref_id, dist_type, level = stem.split("_")[:3]
    return 10000 * int(ref_id[1:]) + 100 * int(dist_type) + int(level)
# Filenames in distorted_images/, ordered by key_fun (reference id, then
# distortion type, then level).
# NOTE(review): presumably this order matches the row order of mos.txt so
# that image_names_tid[k] pairs with scores_tid[k] — confirm.
image_names_tid = sorted(os.listdir("{0}/distorted_images".format(tid_loc)), key=key_fun)

### Creating the autoencoder

In [None]:
# IMPORTANT: due to the way TensorFlow traces the @tf.function training steps,
# this cell and the next one (which defines the training-step functions) must
# both be re-run when training a new autoencoder. If the training-step
# functions are not redefined for the new generator, an exception will be
# thrown.
# NOTE(review): make_generator is neither defined nor imported in this file;
# presumably it comes from the `models` module — confirm.
generator = make_generator()
# Adam optimizer shared by every gen_train_step_* function.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
# Mean squared error used as the reconstruction term of the training losses.
generator_loss = tf.keras.losses.MeanSquaredError()

In [None]:
# Reference images as a float32 tensor, converted once; used as the source of
# the reference crops below.
ref_tensor = tf.convert_to_tensor(ref_images, tf.float32)


@tf.function
def gen_train_step_r(metric, _lambda=1, num_patches=64, patch_size=32):
    """Run one generator update using randomly positioned patches.

    The loss is ``-mean(metric score) + _lambda * MSE``.  The metric is
    evaluated on ``num_patches`` square crops of side ``patch_size``, taken at
    the same random positions in the generated and the reference images.

    Args:
        metric: IQA model called as ``metric([distorted, reference])`` where
            both inputs are stacked crops of shape
            (batch, num_patches, patch_size, patch_size, channels).
        _lambda: weight of the MSE reconstruction term.
        num_patches: number of random crops drawn per call.
        patch_size: side length of each square crop.
    """
    with tf.GradientTape() as gen_tape:
        generated_images = generator(tf.convert_to_tensor(ref_images, tf.float32))
        height = tf.shape(generated_images)[1]
        width = tf.shape(generated_images)[2]
        in_list = []
        in_list_ref = []
        for _ in range(num_patches):
            # Offsets must be drawn with tf.random: np.random inside a
            # tf.function runs only once at trace time, which would freeze the
            # same crops for every training step. maxval is exclusive, hence
            # the +1 so a crop may touch the last row/column.
            ul1 = tf.random.uniform((), 0, height - patch_size + 1, dtype=tf.int32)
            ul2 = tf.random.uniform((), 0, width - patch_size + 1, dtype=tf.int32)
            # crop_to_bounding_box keeps a static (patch_size, patch_size)
            # shape even though the offsets are tensors.
            in_list.append(tf.image.crop_to_bounding_box(
                generated_images, ul1, ul2, patch_size, patch_size))
            in_list_ref.append(tf.image.crop_to_bounding_box(
                ref_tensor, ul1, ul2, patch_size, patch_size))

        input_tensor = tf.stack(in_list, 1)
        input_tensor_ref = tf.stack(in_list_ref, 1)
        scores_output = metric([tf.cast(input_tensor, tf.float32), tf.cast(input_tensor_ref, tf.float32)])
        gen_loss = -tf.math.reduce_mean(scores_output) + _lambda * generator_loss(generated_images, ref_images.astype(np.float32))

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_weights))

@tf.function
def gen_train_step_error():
    """Run one generator update whose loss is only the reconstruction MSE.

    The generator is applied to the reference images, and its trainable
    weights are updated by `generator_optimizer` along the gradient of
    `generator_loss` between the output and the reference images.
    """
    target_images = ref_images.astype(np.float32)
    with tf.GradientTape() as tape:
        reconstruction = generator(tf.convert_to_tensor(ref_images, tf.float32))
        mse = generator_loss(reconstruction, target_images)
    # Read the weights only after the forward pass so they exist even if this
    # call is what builds the generator.
    weights = generator.trainable_weights
    generator_optimizer.apply_gradients(zip(tape.gradient(mse, weights), weights))

def _grid_patches(images, patch_size=32):
    """Split a batch of images into a grid of non-overlapping square patches.

    Args:
        images: 4-D batch (batch, height, width, channels); the spatial and
            channel dimensions must be statically known.
        patch_size: side length of each patch.

    Returns:
        float32 tensor of shape (batch, rows * cols, patch_size, patch_size,
        channels), where rows/cols are the patch counts along height/width
        produced by tf.image.extract_patches with 'SAME' padding.
    """
    images = tf.cast(images, tf.float32)
    channels = images.shape.as_list()[-1]
    size = [1, patch_size, patch_size, 1]
    patches = tf.image.extract_patches(images, size, size, [1, 1, 1, 1], 'SAME')
    _, rows, cols, _ = patches.shape.as_list()
    # -1 lets the batch size follow the input instead of a hard-coded count.
    return tf.reshape(patches, (-1, rows * cols, patch_size, patch_size, channels))


def _mean_metric_score(metrics, distorted, reference):
    """Average the outputs of several IQA models on the same (distorted, reference) pair.

    Raises:
        ValueError: if `metrics` is empty.
    """
    if not metrics:
        raise ValueError("metrics must contain at least one IQA model")
    scores = [metric([tf.cast(distorted, tf.float32), tf.cast(reference, tf.float32)])
              for metric in metrics]
    # Starting the sum from the first score keeps left-to-right accumulation.
    return sum(scores[1:], scores[0]) / len(metrics)


def _apply_generator_gradients(gen_tape, gen_loss):
    """Backpropagate `gen_loss` recorded on `gen_tape` and update the generator."""
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_weights))


# The training steps are wrapped in @tf.function so they are traced into a
# graph once and then reused on later calls.
@tf.function
def gen_train_step(metrics, _lambda=1):
    """Run one generator update on grid patches with loss -mean(IQA) + _lambda * MSE.

    Args:
        metrics: non-empty list of IQA models, each called as
            ``metric([generated_patches, reference_patches])``; their scores
            are averaged.
        _lambda: weight of the MSE reconstruction term.
    """
    with tf.GradientTape() as gen_tape:
        generated_images = generator(tf.convert_to_tensor(ref_images))
        scores_output = _mean_metric_score(
            metrics, _grid_patches(generated_images), _grid_patches(ref_images))
        gen_loss = -tf.math.reduce_mean(scores_output) + _lambda * generator_loss(generated_images, ref_images.astype(np.float32))
    _apply_generator_gradients(gen_tape, gen_loss)


@tf.function
def gen_train_step_full(metrics, _lambda=1):
    """Run one generator update on whole images with loss -mean(IQA) + _lambda * MSE.

    Args:
        metrics: non-empty list of IQA models, each called as
            ``metric([generated_images, reference_images])``; their scores are
            averaged.
        _lambda: weight of the MSE reconstruction term.
    """
    with tf.GradientTape() as gen_tape:
        generated_images = generator(tf.convert_to_tensor(ref_images, tf.float32))
        scores_output = _mean_metric_score(metrics, generated_images, ref_images)
        gen_loss = -tf.math.reduce_mean(scores_output) + _lambda * generator_loss(generated_images, ref_images.astype(np.float32))
    _apply_generator_gradients(gen_tape, gen_loss)


@tf.function
def gen_train_step_metric(metrics):
    """Run one generator update on grid patches with loss -mean(IQA) only (no MSE term).

    Args:
        metrics: non-empty list of IQA models, each called as
            ``metric([generated_patches, reference_patches])``; their scores
            are averaged.
    """
    with tf.GradientTape() as gen_tape:
        generated_images = generator(tf.convert_to_tensor(ref_images))
        scores_output = _mean_metric_score(
            metrics, _grid_patches(generated_images), _grid_patches(ref_images))
        gen_loss = -tf.math.reduce_mean(scores_output)
    _apply_generator_gradients(gen_tape, gen_loss)

## Grid patches

The number of iterations used in the following cells is too small to reach convergence; these cells only demonstrate how the code works and are not meant to produce final results.

In [None]:
# Training for grid patches using loss L = -IQA and a WaDIQaM model with sigmoid
# (no MSE reconstruction term; see gen_train_step_metric).
model = tf.keras.models.load_model("saved_weights/metric_wadiqam.h5")
for i in range(500):
    # Print progress every 50 steps.
    if i % 50 == 0:
        print("Iteration: " + str(i))
    gen_train_step_metric([model])

In [None]:
# Training using loss L = lambda * MSE - IQA and a WaDIQaM model with sigmoid.
# NOTE(review): this calls gen_train_step_full, which scores whole images
# rather than grid patches, despite the "grid patches" wording — confirm
# which step was intended.
model = tf.keras.models.load_model("saved_weights/metric_wadiqam.h5")
# Weight of the MSE term (passed as _lambda). Presumably it is an np.float32
# rather than a Python float to keep tf.function from retracing per value —
# confirm.
l = np.float32(0.01)
for i in range(1000):
    # Print progress every 500 steps.
    if i % 500 == 0:
        print("Iteration: " + str(i))
    gen_train_step_full([model], l)

# Generator output for the images held in the global ref_tensor.
predictions = generator(ref_tensor)

## Adversarial training

In [None]:
# Adversarial examples: for each distorted TID2013 image, run gradient ascent on
# its pixels until the WaDIQaM score exceeds the image's subjective score by
# more than `score_margin`, then save the image with its perturbation magnified.
gen = utils.tid_batch_generator(tid_loc, ref_images, scores_tid, image_names_tid, False, full_images=True)
counter = 0
i = 0
model = models.get_wadiqam_model()
model.load_weights("saved_weights/wadiqam.hdf5")

step_size = 35           # gradient-ascent step applied to the distorted pixels
score_margin = 1         # stop once mean(score) > mean(subjective score) + margin
amplification = 10       # saved images show the perturbation scaled by this
max_ascent_steps = 1000  # safety cap so an image that never improves cannot loop forever
os.makedirs("adversarial_ims", exist_ok=True)


def _to_grid_patches(image_tensor):
    """Split a (1, 384, 512, 3) image into its 192 non-overlapping 32x32 patches."""
    patches = tf.image.extract_patches(image_tensor, [1, 32, 32, 1], [1, 32, 32, 1], [1, 1, 1, 1], 'SAME')
    return tf.reshape(patches, (1, 192, 32, 32, 3))


while i < 3000:
    im, score = next(gen)
    # Untouched copy of the distorted image, used to measure the perturbation.
    # float32 is required so the tape can watch and differentiate it.
    orig_dist = np.reshape(im["x"], (384, 512, 3)).astype(np.float32)
    dist_tensor = tf.convert_to_tensor(orig_dist[np.newaxis])
    # Distinct name so the global `ref_tensor` (all reference images, used by
    # other cells) is not overwritten.
    adv_ref_tensor = tf.convert_to_tensor(np.reshape(im["y"], (1, 384, 512, 3)))
    patches_ref = tf.cast(_to_grid_patches(adv_ref_tensor), tf.float32)
    target_score = np.mean(score) + score_margin

    for _ in range(max_ascent_steps):
        clear_output(wait=True)
        with tf.GradientTape() as gen_tape:
            gen_tape.watch(dist_tensor)
            # Patches must be re-extracted from the current dist_tensor inside
            # the tape; otherwise the gradient is None and the score never changes.
            patches_gen = _to_grid_patches(dist_tensor)
            scores_output = model([tf.cast(patches_gen, tf.float32), patches_ref])
        mean_score = np.mean(scores_output)
        print(scores_output)
        print("{0}: {1}".format(i, mean_score))
        if mean_score > target_score:
            break
        dist_tensor += step_size * gen_tape.gradient(scores_output, dist_tensor)
    else:
        print("Score target not reached after {0} steps".format(max_ascent_steps))

    # Store the new image, with the perturbation magnified for visibility.
    diff_im = dist_tensor.numpy()[0] - orig_dist
    mae = np.mean(np.abs(diff_im))
    res = np.clip((orig_dist + amplification * diff_im) / 255.0, 0, 1)
    plt.imsave("adversarial_ims/test{0}.png".format(counter), res)
    counter += 1
    print(mae)
    i += 1

## Random patches