# SRGAN

In [1]:
import os
import matplotlib.pyplot as plt

from data import DIV2K
from model.srgan import generator, discriminator
from train import SrganTrainer, SrganGeneratorTrainer

%matplotlib inline

In [2]:
# Location of model weights (needed for demo)
weights_dir = 'weights/srgan'


def weights_file(filename):
    """Return the path of `filename` inside `weights_dir`."""
    return os.path.join(weights_dir, filename)


# Make sure the directory exists before any weights are written into it.
os.makedirs(weights_dir, exist_ok=True)

## Datasets

You don't need to download the DIV2K dataset as the required parts are automatically downloaded by the `DIV2K` class. By default, DIV2K images are stored in folder `.div2k` in the project's root directory.

In [3]:
# Build the train and validation subsets with identical settings; only the
# subset name differs. Construction order (train, then valid) is preserved.
div2k_train, div2k_valid = (
    DIV2K(scale=1, subset=subset_name, downgrade='bicubic')
    for subset_name in ('train', 'valid')
)

In [4]:
# Training pipeline: batches of one image with random_transform enabled.
train_ds = div2k_train.dataset(
    batch_size=1,
    random_transform=True,
)

# Validation pipeline: same settings plus repeat_count=1
# (presumably a single pass over the data -- confirm in data.DIV2K).
valid_ds = div2k_valid.dataset(
    batch_size=1,
    random_transform=True,
    repeat_count=1,
)

Tensor("concat:0", shape=(None, None, 4), dtype=uint16)
Tensor("concat:0", shape=(None, None, 4), dtype=uint16)


## Training

### Pre-trained models

If you want to skip training and directly run the demo below, download [weights-srgan.tar.gz](https://drive.google.com/open?id=1u9ituA3ScttN9Vi-UkALmpO0dWQLm8Rv) and extract the archive in the project's root directory. This will create a folder `weights/srgan` containing the weights of the pre-trained models.

### Generator pre-training

In [5]:
# Instantiate a generator so its architecture can be inspected in the next cell.
model = generator()

In [6]:
# Print the layer table using 120-character-wide lines.
model.summary(line_length=120)

Model: "model"
________________________________________________________________________________________________________________________
Layer (type)                           Output Shape               Param #       Connected to                            
input_1 (InputLayer)                   [(None, None, None, 4)]    0                                                     
________________________________________________________________________________________________________________________
lambda (Lambda)                        (None, None, None, 4)      0             input_1[0][0]                           
________________________________________________________________________________________________________________________
conv2d (Conv2D)                        (None, None, None, 64)     20800         lambda[0][0]                            
________________________________________________________________________________________________________________________
p_re_lu (PReLU)  

In [7]:
# Generator pre-training stage (runs before the GAN fine-tuning below).
# A fresh generator is created here; checkpoints go to .ckpt/pre_generator.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')
pre_trainer.train(train_ds,
                  valid_ds.take(10),  # evaluate on the first 10 validation batches only
                  steps=1000,
                  evaluate_every=100,
                  save_best_only=False)

# Persist the pre-trained weights; the fine-tuning and demo cells load this file.
pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

Model restored from checkpoint at step 1000.


### Generator fine-tuning (GAN)

In [8]:
# Re-create the input pipelines for the GAN stage.
# Training pipeline: batches of one image with random_transform enabled.
train_ds = div2k_train.dataset(
    batch_size=1,
    random_transform=True,
)

# Validation pipeline: same settings plus repeat_count=1
# (presumably a single pass over the data -- confirm in data.DIV2K).
valid_ds = div2k_valid.dataset(
    batch_size=1,
    random_transform=True,
    repeat_count=1,
)

Tensor("concat:0", shape=(None, None, 4), dtype=uint16)
Tensor("concat:0", shape=(None, None, 4), dtype=uint16)


In [9]:
# Start GAN fine-tuning from the pre-trained generator weights.
gan_generator = generator()
gan_generator.load_weights(weights_file('pre_generator.h5'))

# NOTE(review): the recorded run of this cell aborts with a ValueError raised
# while calling the discriminator on `hr` in train_step: a conv layer expects
# 4 channels on the last axis but receives a [1, 512, 512, 1] tensor.
# The cause is not visible in this notebook -- check the model/data code
# before re-running.
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=20000)

Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5


ValueError: in converted code:

    E:\Anna\super-resolution-master\super-resolution-master\train.py:443 train_step  *
        hr_output = self.discriminator(hr, training=True)
    C:\Users\DL\AppData\Local\conda\conda\envs\sisr\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py:778 __call__
        outputs = call_fn(cast_inputs, *args, **kwargs)
    C:\Users\DL\AppData\Local\conda\conda\envs\sisr\lib\site-packages\tensorflow_core\python\keras\engine\network.py:717 call
        convert_kwargs_to_constants=base_layer_utils.call_context().saving)
    C:\Users\DL\AppData\Local\conda\conda\envs\sisr\lib\site-packages\tensorflow_core\python\keras\engine\network.py:891 _run_internal_graph
        output_tensors = layer(computed_tensors, **kwargs)
    C:\Users\DL\AppData\Local\conda\conda\envs\sisr\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py:737 __call__
        self.name)
    C:\Users\DL\AppData\Local\conda\conda\envs\sisr\lib\site-packages\tensorflow_core\python\keras\engine\input_spec.py:213 assert_input_compatibility
        ' but received input with shape ' + str(shape))

    ValueError: Input 0 of layer conv2d_111 is incompatible with the layer: expected axis -1 of input shape to have value 4 but received input with shape [1, 512, 512, 1]


In [None]:
# Save both fine-tuned networks into weights_dir; the demo cells below load
# gan_generator.h5.
gan_trainer.generator.save_weights(weights_file('gan_generator.h5'))
gan_trainer.discriminator.save_weights(weights_file('gan_discriminator.h5'))

## Demo

In [None]:
# Fresh generator instances for the demo. Note that this rebinds
# `gan_generator`, which the fine-tuning cell also assigns.
pre_generator = generator()
gan_generator = generator()

# Weights are read from weights_dir: either written by the training cells
# above or extracted from the pre-trained archive.
pre_generator.load_weights(weights_file('pre_generator.h5'))
gan_generator.load_weights(weights_file('gan_generator.h5'))

In [None]:
from model import resolve_single
from utils import load_image

def resolve_and_plot(lr_image_path):
    """Super-resolve one image with both generators and plot the results.

    Loads the image at `lr_image_path`, runs it through the notebook-level
    `pre_generator` and `gan_generator` via `resolve_single`, and draws the
    input and both outputs into a new 20x20-inch figure laid out on a 2x2
    grid (LR top-left, pre-trained result bottom-left, GAN result
    bottom-right).

    Args:
        lr_image_path: Path of the low-resolution image, passed to `load_image`.
    """
    lr = load_image(lr_image_path)

    pre_sr = resolve_single(pre_generator, lr)
    gan_sr = resolve_single(gan_generator, lr)

    plt.figure(figsize=(20, 20))

    # (image, title, subplot position); position 2 (top-right) stays empty.
    panels = [
        (lr, 'LR', 1),
        (pre_sr, 'SR (PRE)', 3),
        (gan_sr, 'SR (GAN)', 4),
    ]

    for img, title, pos in panels:
        plt.subplot(2, 2, pos)
        plt.imshow(img)
        plt.title(title)
        # Hide axis ticks: the panels show images, not data plots.
        plt.xticks([])
        plt.yticks([])

In [None]:
# Plot the LR input next to both generators' outputs for this demo crop.
resolve_and_plot('demo/0869x4-crop.png')

In [None]:
# Plot the LR input next to both generators' outputs for this demo crop.
resolve_and_plot('demo/0829x4-crop.png')

In [None]:
# Plot the LR input next to both generators' outputs for this demo crop.
resolve_and_plot('demo/0851x4-crop.png')