In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU, Dense, Add, UpSampling2D, Reshape, Conv2DTranspose, Flatten, Add, Concatenate, Lambda, MaxPool2D
from tensorflow.keras.models import Model
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow_datasets as tfds
import os
import time

In [None]:
from psutil import virtual_memory

# Report total machine RAM (decimal GB) and whether this looks like a high-RAM runtime.
ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')
print('Not using a high-RAM runtime' if ram_gb < 20 else 'You are using a high-RAM runtime!')

Your runtime has 54.8 gigabytes of available RAM

You are using a high-RAM runtime!


In [None]:
# DIV2K x2 bicubic split: a dict of splits ("train", "validation") whose samples carry "lr" and "hr" images.
ds = tfds.load(name='div2k/bicubic_x2', shuffle_files=True)

Downloading and preparing dataset 4.68 GiB (download: 4.68 GiB, generated: Unknown size, total: 4.68 GiB) to /root/tensorflow_datasets/div2k/bicubic_x2/2.0.0...
EXTRACTING {'train_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip', 'valid_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip', 'train_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', 'valid_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip'}


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/800 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/div2k/bicubic_x2/2.0.0.incompleteN8BFQI/div2k-train.tfrecord*...:   0%|   …

Generating validation examples...:   0%|          | 0/100 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/div2k/bicubic_x2/2.0.0.incompleteN8BFQI/div2k-validation.tfrecord*...:   0…

Dataset div2k downloaded and prepared to /root/tensorflow_datasets/div2k/bicubic_x2/2.0.0. Subsequent calls will reuse this data.


In [None]:
LR_SHAPE = (96,96,3)
HR_SHAPE = (192,192,3)
def resize_and_normalize(sample):
    """Resize a DIV2K sample's images to fixed shapes and scale them to [0, 1].

    Params:
        sample: dict with 'hr' and 'lr' image tensors (mutated in place).
    Returns:
        The same dict, with 'hr' resized to HR_SHAPE[:2] and 'lr' to
        LR_SHAPE[:2], both as float32 divided by 255.
    """
    for key, target_shape in (('hr', HR_SHAPE), ('lr', LR_SHAPE)):
        resized = tf.image.resize(sample[key], [target_shape[0], target_shape[1]])
        sample[key] = tf.cast(resized, tf.float32) / 255.0
    return sample

# Batch size shared by the train and validation pipelines.
BATCH_SIZE = 8

# Resize/normalize in parallel and prefetch the next batch so input preparation
# overlaps with the training step instead of running serially before it.
# (`map` with num_parallel_calls still yields elements in the original order.)
ds_train_resized = ds["train"].map(resize_and_normalize, num_parallel_calls=tf.data.AUTOTUNE)
ds_train_batched = ds_train_resized.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

ds_valid_resized = ds["validation"].map(resize_and_normalize, num_parallel_calls=tf.data.AUTOTUNE)
ds_valid_batched = ds_valid_resized.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)



In [None]:
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import get_custom_objects
from tensorflow.nn import depth_to_space

class SubpixelConv2D(Layer):
    """ Subpixel Conv2D Layer (pixel shuffle)

    Upsamples a tensor from (h, w, c) to (h*r, w*r, c/(r*r)) with
    `tf.nn.depth_to_space`, where r is the scaling factor (default 4).
    The layer itself has no weights; the learnable convolution is whatever
    layer precedes it.

    # Arguments
        upsampling_factor: the scaling factor r

    # Input shape
        4-D channels-last tensor (batch, height, width, channels).
        `channels` must be an integer multiple of upsampling_factor^2.
        Use the keyword argument `input_shape` (tuple of integers, does not
        include the samples axis) when using this layer as the first layer
        in a model.

    # Output shape
        The second and third dimensions are multiplied by
        `upsampling_factor`; the last dimension is divided by
        `upsampling_factor^2`.

    # References
        Real-Time Single Image and Video Super-Resolution Using an Efficient
        Sub-Pixel Convolutional Neural Network Shi et Al. https://arxiv.org/abs/1609.05158
    """

    def __init__(self, upsampling_factor=4, **kwargs):
        super(SubpixelConv2D, self).__init__(**kwargs)
        self.upsampling_factor = upsampling_factor

    def build(self, input_shape):
        """Validate the channel count (when known) and mark the layer built.

        Raises:
            ValueError: if the static channel count is not a multiple of
                upsampling_factor^2.
        """
        last_dim = input_shape[-1]
        factor = self.upsampling_factor * self.upsampling_factor
        # The channel dim may be unknown (None) at build time; previously
        # `None % factor` raised a TypeError. depth_to_space validates at run time.
        if last_dim is not None and last_dim % factor != 0:
            raise ValueError('Channel ' + str(last_dim) + ' should be of '
                             'integer times of upsampling_factor^2: ' +
                             str(factor) + '.')
        super(SubpixelConv2D, self).build(input_shape)

    def call(self, inputs, **kwargs):
        return depth_to_space( inputs, self.upsampling_factor )

    def get_config(self):
        config = { 'upsampling_factor': self.upsampling_factor, }
        base_config = super(SubpixelConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Return (batch, h*r, w*r, c // r^2), propagating unknown (None) dims."""
        r = self.upsampling_factor
        batch, height, width, channels = tuple(input_shape)
        return (
            batch,
            None if height is None else height * r,
            None if width is None else width * r,
            # Integer division: channels is validated to be a multiple of r^2.
            None if channels is None else channels // (r * r),
        )

# Register the custom layer under its class name in Keras' custom-object registry,
# so configs/saved models referring to 'SubpixelConv2D' can resolve it by name.
get_custom_objects().update({'SubpixelConv2D': SubpixelConv2D})

In [None]:


def dense_block(inpt, filters=128, residual_scale=0.2):
    """
    Residual dense block: 4 densely connected conv stages
    (Conv2D -> BatchNormalization -> LeakyReLU(0.2)) followed by a post conv
    whose output is scaled by `residual_scale` and added back to `inpt`.

    Stage k receives the concatenation of the block input and the outputs of
    stages 1..k-1, each exactly once. The previous version concatenated the
    *already concatenated* tensors, duplicating `inpt` and earlier features.
    That grew the channels 256 -> 512 -> 1024 -> 2048 instead of
    256 -> 384 -> 512 -> 640, a likely cause of the out-of-memory error during
    training. Layers are still created in the same order as before; however,
    weights/checkpoints saved by the old version have different shapes.

    Params:
        inpt: 4-D Keras tensor; its channel count must equal `filters`
            for the residual Add.
        filters: filters in every conv (default 128, as before).
        residual_scale: scale applied to the post conv before the skip Add.
    Returns: tensorflow layer with the same shape as `inpt`
    """
    features = [inpt]
    x = inpt
    for _ in range(4):
        y = Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
        y = BatchNormalization()(y)
        y = LeakyReLU(0.2)(y)
        features.append(y)
        # Dense connectivity without duplication: input + every stage's own output.
        x = Concatenate()(features)

    out = Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
    out = Lambda(lambda t: t * residual_scale)(out)
    out = Add()([out, inpt])

    return out

def RRDB(inpt):
    """Residual-in-residual dense block.

    Chains three `dense_block`s, scales the result by 0.2 and adds it back to
    the block input, so the output has the same shape as `inpt`.
    """
    h = inpt
    for _ in range(3):
        h = dense_block(h)
    h = Lambda(lambda t: t * 0.2)(h)
    return Add()([h, inpt])


def make_generator_model():
    """Build the super-resolution generator.

    LR_SHAPE input -> 128-filter conv -> 3 RRDBs -> 2x sub-pixel upsampling
    -> 3-channel sigmoid conv, giving an RGB image in [0, 1] at twice the
    input resolution.
    """
    lr_input = Input(shape=LR_SHAPE)
    features = Conv2D(128, 3, 1, padding="same")(lr_input)
    for _ in range(3):
        features = RRDB(features)
    upscaled = SubpixelConv2D(upsampling_factor=2)(features)
    sr_image = Conv2D(3, 3, 1, padding="same", activation="sigmoid")(upscaled)
    return Model(lr_input, sr_image)

def make_baseline_generator():
    """Build a minimal generator for sanity checks.

    LR_SHAPE input -> 128-filter conv -> one RRDB -> 2x sub-pixel upsampling
    -> 3-channel sigmoid conv (RGB in [0, 1]).

    The final projection was missing before. SubpixelConv2D(2) turns 128
    channels into 128 / 4 = 32, so the model emitted 32-channel images.
    Those did not match HR_SHAPE, the discriminator input or the losses.
    The added conv mirrors make_generator_model.
    """
    lr_input = Input(shape=LR_SHAPE)
    x = Conv2D(128, 3, 1, padding="same")(lr_input)
    x = RRDB(x)
    x = SubpixelConv2D(upsampling_factor=2)(x)
    x = Conv2D(3, 3, 1, padding="same", activation="sigmoid")(x)
    return Model(lr_input, x)




In [None]:

# Shared binary cross-entropy on raw logits: the discriminators below end in
# Dense(1) with no activation, hence from_logits=True.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def disc_block(inpt):
    """VGG-style discriminator stage that divides height/width by 4.

    Two 64-filter 3x3 convs + LeakyReLU, 2x2 max-pool, then two 128-filter
    3x3 convs + LeakyReLU, 2x2 max-pool.

    Previously the first 128-filter conv read `inpt` instead of `x`. The whole
    64-filter stage was therefore disconnected from the model (the model
    summary skipped conv2d_1/conv2d_2), and each block pooled only once.
    """
    x = Conv2D(64, (3,3), strides=(1,1), padding="same")(inpt)
    x = LeakyReLU()(x)
    x = Conv2D(64, (3,3), strides=(1,1), padding="same" )(x)
    x = LeakyReLU()(x)
    x = MaxPool2D()(x)

    x = Conv2D(128, (3,3), strides=(1,1), padding="same")(x)
    x = LeakyReLU()(x)
    x = Conv2D(128, (3,3), strides=(1,1), padding="same")(x)
    x = LeakyReLU()(x)
    x = MaxPool2D()(x)

    return x


def make_discriminator_model(num_blocks=3):
    """Build the HR-image discriminator, returning one raw logit per image.

    Stem conv (64 filters) + LeakyReLU, then `num_blocks` `disc_block`s
    (each divides height/width by 4), then Flatten -> Dense(1) without
    activation (to be used with `cross_entropy`, from_logits=True).

    Now that `disc_block` really pools twice, the old 6 blocks would shrink
    192 px to 0. With HR_SHAPE 192x192 and the default 3 blocks the final map
    is 3x3x128, the same Flatten size the previous version produced.

    Params:
        num_blocks: number of `disc_block` stages.
    Raises:
        ValueError: if `num_blocks` would pool HR_SHAPE below 1x1.
    """
    if min(HR_SHAPE[0], HR_SHAPE[1]) // (4 ** num_blocks) < 1:
        raise ValueError('num_blocks={} pools HR_SHAPE {} below 1x1'.format(num_blocks, HR_SHAPE))

    hr_input = Input(shape=HR_SHAPE)
    x = Conv2D(64, (3,3), strides=(1,1), padding="same")(hr_input)
    x = LeakyReLU()(x)

    for _ in range(num_blocks):
        x = disc_block(x)

    x = Flatten()(x)
    x = Dense(1)(x)

    return Model(hr_input, x)

def make_baseline_discriminator():
    """Build a tiny discriminator for sanity checks.

    HR_SHAPE input -> 64-filter conv + LeakyReLU -> 2x2 max-pool -> Flatten
    -> Dense(1) producing a raw logit (no activation).
    """
    hr_input = Input(shape=HR_SHAPE)
    h = Conv2D(64, (3,3), strides=(1,1), padding="same")(hr_input)
    h = LeakyReLU()(h)
    h = MaxPool2D()(h)
    h = Flatten()(h)
    logit = Dense(1)(h)
    return Model(hr_input, logit)

# Build the discriminator once; this instance is the one used by train_step and
# tracked by the checkpoint.
discriminator_model = make_discriminator_model()
discriminator_model.summary()



Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 192, 192, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 192, 192, 64)      1792      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 192, 192, 64)      0         
                                                                 
 conv2d_3 (Conv2D)           (None, 192, 192, 128)     73856     
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 192, 192, 128)     0         
                                                                 
 conv2d_4 (Conv2D)           (None, 192, 192, 128)     147584    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 192, 192, 128)     0     

In [None]:
from tensorflow.keras.applications import VGG19

def discriminator_loss(real_output, fake_output):
    """GAN discriminator loss on logits: real images -> label 1, generated -> label 0.

    Returns the sum of both binary cross-entropy terms.
    """
    real_term = cross_entropy(tf.ones_like(real_output), real_output)
    fake_term = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total = real_term + fake_term
    return total

def content_loss(y_true, y_pred):
    """Pixel-wise mean squared error between target and generated images."""
    diff = y_true - y_pred
    return tf.reduce_mean(tf.square(diff))

# Frozen ImageNet VGG19 (no classifier head) at HR resolution; `model` exposes the
# block5_conv4 activations used by perceptual_loss.
vgg = VGG19(include_top=False, weights='imagenet', input_shape=HR_SHAPE)
vgg.trainable = False

# Freeze every layer explicitly as well.
for vgg_layer in vgg.layers:
    vgg_layer.trainable = False

feature_layer = vgg.get_layer('block5_conv4')
model = Model(inputs=vgg.input, outputs=feature_layer.output)
model.trainable = False

def ssim_loss(y_true, y_pred):
    """1 - mean SSIM over the batch, for images with dynamic range 1.0."""
    mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
    return 1 - mean_ssim


def perceptual_loss(y_true, y_pred):
    """Mean squared distance between VGG19 block5_conv4 features of two image batches.

    Inputs are expected to be RGB in [0, 1]. HR targets are divided by 255 in
    resize_and_normalize, and the generator ends in a sigmoid. Keras' VGG19
    weights expect `vgg19.preprocess_input` inputs (0-255 scale, BGR,
    ImageNet-mean subtracted), so the images are converted before feature
    extraction. Feeding raw [0, 1] images yields features far from what VGG
    was trained on.

    NOTE: the preprocessing changes the loss magnitude, so the perceptual
    weight in generator_loss may need retuning.
    """
    preprocess = tf.keras.applications.vgg19.preprocess_input
    true_features = model(preprocess(y_true * 255.0))
    pred_features = model(preprocess(y_pred * 255.0))
    return tf.reduce_mean(tf.square(true_features - pred_features))

def generator_loss(fake_output, real_images, fake_images,
                   adv_weight=0.7, pct_weight=2.0, cont_weight=0.4):
    """Weighted sum of adversarial, perceptual and pixel (content) losses.

    Params:
        fake_output: discriminator logits for the generated images.
        real_images: HR target batch.
        fake_images: generator output batch.
        adv_weight, pct_weight, cont_weight: term weights (defaults are the
            previously hard-coded 0.7, 2 and 0.4).
    Returns:
        Scalar loss tensor.
    """
    # The generator wants the discriminator to label its images as real (1).
    adv_loss = cross_entropy(tf.ones_like(fake_output), fake_output)
    pct_loss = perceptual_loss(real_images, fake_images)
    cont_loss = content_loss(real_images, fake_images)
    return adv_weight*adv_loss + pct_weight*pct_loss + cont_weight*cont_loss

# Adam for both networks; the discriminator learns 10x more slowly than the generator.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
gen_model = make_generator_model()

# Checkpoints are written as ./training_checkpoints/ckpt-<N> and track both
# networks together with their optimizers.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=gen_model,
    discriminator=discriminator_model,
)


In [None]:
EPOCHS = 5

def train_step(lr_image_batch, hr_image_batch, step):
    """Run one adversarial update of the generator and the discriminator.

    Params:
        lr_image_batch: low-resolution batch fed to gen_model.
        hr_image_batch: matching high-resolution target batch.
        step: step index attached to the TensorBoard scalars.
    Returns:
        The generator loss for this batch as a NumPy value.

    Uses the notebook globals gen_model, discriminator_model, the two
    optimizers and summary_writer (the latter is created in a later cell).
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        sr_images = gen_model(lr_image_batch, training=True)

        real_logits = discriminator_model(hr_image_batch, training=True)
        fake_logits = discriminator_model(sr_images, training=True)

        disc_loss = discriminator_loss(real_logits, fake_logits)
        gen_loss = generator_loss(fake_logits, hr_image_batch, sr_images)

        print("gen_loss: ", gen_loss.numpy())
        print("disc_loss", disc_loss.numpy())

    with summary_writer.as_default():
        tf.summary.scalar('Generator Loss', gen_loss, step=step)
        tf.summary.scalar("Discriminator Loss", disc_loss, step=step)

    gen_grads = gen_tape.gradient(gen_loss, gen_model.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator_model.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_model.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator_model.trainable_variables))
    return gen_loss.numpy()



In [None]:
last_losses = []
def train(dataset, epochs, callbacks=None, earlystopping=(False, 0), max_batches=None):
    """Run the adversarial training loop.

    Params:
        dataset: iterable of dicts with "hr" and "lr" image batches.
        epochs: number of passes over `dataset`.
        callbacks: optional objects whose `on_batch_end(batch_index)` is
            called after each training step (batch index within the epoch).
        earlystopping: `(enabled, patience)`. Only when `enabled` is true,
            training stops once a window of `patience` generator losses
            contains no value below the window's oldest loss. Previously the
            flag was ignored.
        max_batches: optional cap on batches per epoch; None = full epoch.
            Replaces a leftover hard-coded `if n > 2: break` that limited
            every epoch to 3 batches.

    A checkpoint is saved after every completed epoch. TensorBoard scalars use
    a step counter that keeps increasing across epochs; previously it restarted
    at 0 each epoch, so later epochs overwrote earlier points.
    """
    global last_losses
    # Fresh early-stopping window per call (it used to leak between calls).
    last_losses = []
    use_early_stopping, patience = earlystopping
    global_step = 0
    for epoch in range(epochs):
        start = time.time()
        early_stopped = False
        for n, iter_item in enumerate(dataset):
            if max_batches is not None and n >= max_batches:
                break
            hr_image_batch = iter_item["hr"]
            lr_image_batch = iter_item["lr"]
            loss = train_step(lr_image_batch, hr_image_batch, global_step)
            global_step += 1

            if use_early_stopping and patience > 0:
                last_losses.append(loss)
                no_improvement = all(other >= last_losses[0] for other in last_losses[1:])
                if no_improvement and len(last_losses) == patience:
                    early_stopped = True
                    break
                # Slide the window so it never exceeds `patience` entries.
                if len(last_losses) >= patience:
                    last_losses = last_losses[1:]

            if callbacks:
                for callback in callbacks:
                    callback.on_batch_end(n)
            print(n+1, "batches complete in epoch", epoch)
        if early_stopped:
            break

        checkpoint.save(checkpoint_prefix)
        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))




In [None]:

# One timestamped run directory per notebook session, so TensorBoard runs don't mix.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_log_dir = 'logs/gradient_tape/' + current_time
train_log_dir = run_log_dir + '/train'
# train_step writes its loss scalars here. Previously train_log_dir was computed
# but unused, and every run wrote into the shared "./logs".
summary_writer = tf.summary.create_file_writer(train_log_dir)

from tensorflow.keras.callbacks import TensorBoard
# Previously log_dir="/logs" (filesystem root, inconsistent with the writer);
# keep it next to the writer's run directory instead.
tensorboard_callback = TensorBoard(log_dir=run_log_dir)

gen_model.summary()


Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 96, 96, 3)]          0         []                            
                                                                                                  
 conv2d_25 (Conv2D)          (None, 96, 96, 128)          3584      ['input_3[0][0]']             
                                                                                                  
 conv2d_26 (Conv2D)          (None, 96, 96, 128)          147584    ['conv2d_25[0][0]']           
                                                                                                  
 batch_normalization (Batch  (None, 96, 96, 128)          512       ['conv2d_26[0][0]']           
 Normalization)                                                                             

In [None]:
# Re-display the discriminator architecture (also printed when it was built).
discriminator_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 192, 192, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 192, 192, 64)      1792      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 192, 192, 64)      0         
                                                                 
 conv2d_3 (Conv2D)           (None, 192, 192, 128)     73856     
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 192, 192, 128)     0         
                                                                 
 conv2d_4 (Conv2D)           (None, 192, 192, 128)     147584    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 192, 192, 128)     0     

In [None]:
# Train for 10 epochs (note: the EPOCHS constant defined earlier is 5 and is not used here).
# NOTE(review): the recorded run below raised ResourceExhaustedError right after
# the first step's losses were printed.
train(ds_train_batched, 10, callbacks=[tensorboard_callback])

gen_loss:  0.5179947
disc_loss 1.3841791


ResourceExhaustedError: ignored

In [None]:
# Build a "visualising" model: tap an intermediate generator layer by name and
# project it to 3 channels with a freshly created (untrained) sigmoid conv.
# NOTE(review): "conv2d_18" is an auto-generated name that depends on how many
# Conv2D layers were created earlier in the kernel session. In the recorded
# gen_model summary the generator's convs start at conv2d_25, so this name may
# not exist in gen_model. Prefer selecting the layer via gen_model.layers[...].
# NOTE(review): no padding="same" here, so the output is 2 px smaller per spatial dim.
x = gen_model.get_layer("conv2d_18").output
output_layer = Conv2D(3,3,1,activation="sigmoid")(x)
visualising_model = Model(gen_model.input, output_layer)

In [None]:

# Take the first training batch and keep its first LR/HR pair as running examples.
first_batch = next(iter(ds_train_batched))
hr_image_batch = first_batch["hr"]
lr_image_batch = first_batch["lr"]
example_image = lr_image_batch[0]
example_hr_image = hr_image_batch[0]
plt.imshow(example_image)



In [None]:
# High-resolution target matching the LR example shown above.
plt.imshow(example_hr_image)


In [None]:
# Assuming the latest checkpoint is ckpt-5 and you want to go back to ckpt-3
# Restore a specific saved checkpoint and run the generator on the LR example.
# checkpoint.save(checkpoint_prefix) numbers saves ckpt-1, ckpt-2, ...; change the
# name to inspect another epoch. Previously tf.train.latest_checkpoint (which takes a
# *directory*) was given '.../ckpt-1', so it returned None, and restore(None)
# silently loaded nothing. expect_partial() silences warnings about optimizer
# state that inference does not use.
restore_path = os.path.join(checkpoint_dir, 'ckpt-1')
checkpoint.restore(restore_path).expect_partial()
test_gen_image = gen_model(tf.expand_dims(example_image, axis=0), training=False)[0]

plt.imshow(test_gen_image)


In [None]:
# Switch the running LR example to the first validation image.
valid_batch = next(iter(ds_valid_batched))
example_image = valid_batch["lr"][0]
plt.imshow(example_image)


In [None]:
# Super-resolve the validation example (add a batch axis, take the single output back out).
single_image_batch = tf.expand_dims(example_image, axis=0)
test_gen_image = gen_model(single_image_batch, training=False)[0]

plt.imshow(test_gen_image)