# RCAN

In [1]:
import os
import matplotlib.pyplot as plt

from data import DIV2K
from model.rcan import rcan
from train import RcanTrainer

%matplotlib inline

In [2]:
# RCAN architecture hyper-parameters passed to rcan(): number of residual
# groups, number of residual blocks (presumably per group — confirm in
# model.rcan) and the reduction ratio.
num_res_groups, num_res_blocks, reduction = 10, 20, 16

# Super-resolution (upscaling) factor
scale = 2

# Downgrade operator passed to DIV2K
downgrade = 'bicubic'

In [3]:
# Final model weights (loaded by the demo section) live in
# weights/rcan-<num_res_groups>-x<scale>/weights.h5. The directory is
# created up front so saving the weights later can write into it.
weights_dir = f'weights/rcan-{num_res_groups}-x{scale}'
os.makedirs(weights_dir, exist_ok=True)
weights_file = os.path.join(weights_dir, 'weights.h5')

## Datasets

You don't need to download the DIV2K dataset as the required parts are automatically downloaded by the `DIV2K` class. By default, DIV2K images are stored in folder `.div2k` in the project's root directory.

In [4]:
# One DIV2K loader per subset, both using the configured scale and downgrade
# operator; 'train' is constructed first, then 'valid'.
div2k_train, div2k_valid = (
    DIV2K(scale=scale, subset=subset, downgrade=downgrade)
    for subset in ('train', 'valid')
)

In [5]:
# Training pipeline: batches of 8 with random transforms enabled.
train_ds = div2k_train.dataset(
    batch_size=8,
    random_transform=True,
)
# Validation pipeline: one image per batch, no random transforms;
# repeat_count=1 presumably means a single pass — confirm in data.DIV2K.
valid_ds = div2k_valid.dataset(
    batch_size=1,
    random_transform=False,
    repeat_count=1,
)

## Training

### Pre-trained models

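If you want to skip training and directly run the demo below, you need pre-trained RCAN weights at the `weights_file` path defined above (`weights/rcan-10-x2/weights.h5` with the default settings). Note that the archive [weights-edsr-16-x4.tar.gz](https://martin-krasser.de/sisr/weights-edsr-16-x4.tar.gz), inherited from the EDSR version of this notebook, contains EDSR x4 weights: extracting it in the project's root directory creates a `weights/edsr-16-x4` directory, and those weights do not match the RCAN x2 model built here.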

In [6]:
# Wrap a freshly built RCAN model in a trainer whose checkpoints go to
# .ckpt/rcan-<num_res_groups>-x<scale>.
trainer = RcanTrainer(
    model=rcan(
        scale=scale,
        num_res_groups=num_res_groups,
        num_res_blocks=num_res_blocks,
        reduction=reduction,
    ),
    checkpoint_dir=f'.ckpt/rcan-{num_res_groups}-x{scale}',
)

In [None]:
# Train the RCAN model for 300,000 steps and evaluate it every
# 100 steps (evaluate_every=100) on the first 10 images of the
# DIV2K validation set. A checkpoint is saved only if the
# evaluation PSNR has improved (save_best_only=True).
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000,
              evaluate_every=100,
              save_best_only=True)

100/300000: loss = 60.038, PSNR = 19.191755 (91.64s)
200/300000: loss = 20.544, PSNR = 22.447956 (73.58s)
300/300000: loss = 15.383, PSNR = 24.013599 (71.89s)
400/300000: loss = 12.067, PSNR = 26.232159 (72.09s)
500/300000: loss = 10.324, PSNR = 24.266876 (72.30s)
600/300000: loss = 9.807, PSNR = 27.037313 (72.48s)
700/300000: loss = 8.649, PSNR = 27.878475 (72.23s)
800/300000: loss = 7.916, PSNR = 29.124737 (72.40s)
900/300000: loss = 7.503, PSNR = 28.659077 (72.89s)
1000/300000: loss = 6.926, PSNR = 29.769510 (71.73s)
1100/300000: loss = 6.604, PSNR = 30.689838 (72.24s)
1200/300000: loss = 6.490, PSNR = 31.058743 (73.29s)
1300/300000: loss = 6.198, PSNR = 31.230560 (71.86s)
1400/300000: loss = 5.994, PSNR = 31.641571 (72.06s)
1500/300000: loss = 5.706, PSNR = 31.197399 (72.16s)
1600/300000: loss = 5.766, PSNR = 31.374786 (72.54s)
1700/300000: loss = 5.833, PSNR = 31.794195 (72.39s)
1800/300000: loss = 5.521, PSNR = 31.596537 (70.99s)
1900/300000: loss = 5.483, PSNR = 31.516510 (72.61

15500/300000: loss = 4.027, PSNR = 34.399986 (13.75s)
15600/300000: loss = 3.799, PSNR = 34.543385 (13.64s)
15700/300000: loss = 3.920, PSNR = 33.619179 (14.04s)
15800/300000: loss = 4.005, PSNR = 34.299309 (13.99s)
15900/300000: loss = 4.250, PSNR = 34.493389 (13.66s)
16000/300000: loss = 4.152, PSNR = 34.472939 (14.15s)
16100/300000: loss = 4.023, PSNR = 34.592316 (13.66s)
16200/300000: loss = 3.778, PSNR = 34.242882 (13.97s)
16300/300000: loss = 3.788, PSNR = 34.567432 (15.79s)
16400/300000: loss = 3.895, PSNR = 34.513355 (15.80s)
16500/300000: loss = 3.809, PSNR = 34.526775 (14.45s)
16600/300000: loss = 3.733, PSNR = 34.662342 (15.79s)
16700/300000: loss = 3.650, PSNR = 34.487518 (16.00s)
16800/300000: loss = 3.896, PSNR = 34.666599 (13.50s)
16900/300000: loss = 3.760, PSNR = 34.688522 (13.63s)


In [None]:
# Reload the best checkpoint written during training, i.e. the one with the
# highest evaluation PSNR (training ran with save_best_only=True).
trainer.restore()

In [None]:
# Evaluate the model on the full validation set and report the PSNR.
# `.3f` prints three decimal places; a bare `3f` would only set a minimum
# field width of 3 and keep the default six decimals.
psnrv = trainer.evaluate(valid_ds)
print(f'PSNR = {psnrv.numpy():.3f}')

In [None]:
# Export the model's weights to weights_file so the demo section can load
# them into a standalone model without the trainer's checkpoints.
trainer.model.save_weights(weights_file)

## Demo

In [None]:
# Rebuild the same RCAN architecture that was trained above (the weights only
# load into a model with identical hyper-parameters), then load the exported
# weights. The previous `edsr(..., num_res_blocks=depth)` referenced an
# unimported function and an undefined variable.
model = rcan(scale=scale,
             num_res_groups=num_res_groups,
             num_res_blocks=num_res_blocks,
             reduction=reduction)
model.load_weights(weights_file)

In [None]:
from model import resolve_single
from utils import load_image, plot_sample
import tensorflow as tf

def resolve_and_plot(lr_image_path):
    """Super-resolve one image with the global `model` and plot the result.

    Loads the low-resolution image from `lr_image_path`, runs it through
    `resolve_single`, and hands the input and output to `plot_sample`.
    """
    low_res = load_image(lr_image_path)
    plot_sample(low_res, resolve_single(model, low_res))
    
def save_output(lr_image_path, output_dir='output', suffix='_edsr'):
    """Super-resolve an image with the global `model` and save it as a PNG.

    The result is written to ``<output_dir>/<stem><suffix>.png``, where
    ``<stem>`` is the input file name without its extension.

    Args:
        lr_image_path: Path of the low-resolution input image.
        output_dir: Destination directory; created if it does not exist.
        suffix: Appended to the file stem. The default ``'_edsr'`` keeps the
            original output file names. NOTE(review): the demo model is RCAN,
            so callers may prefer ``suffix='_rcan'``.
    """
    # basename/splitext handle OS-specific separators, multi-dot names
    # ("a.b.jpg" -> "a.b") and names without an extension, all of which the
    # manual split("/") / split(".") parsing got wrong or crashed on.
    stem = os.path.splitext(os.path.basename(lr_image_path))[0]
    lr = load_image(lr_image_path)
    sr = resolve_single(model, lr)
    # save_img does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, stem + suffix + '.png')
    tf.keras.preprocessing.image.save_img(out_path, sr)

In [None]:
# Super-resolve the first demo image and write it into the output folder.
save_output(lr_image_path='demo/raw.jpg')

In [None]:
# Plot the low-resolution input and super-resolved output for demo/raw.jpg.
resolve_and_plot(lr_image_path='demo/raw.jpg')

In [None]:
# Plot the low-resolution input and super-resolved output for demo/raw2.jpg.
resolve_and_plot(lr_image_path='demo/raw2.jpg')

In [None]:
# Plot the low-resolution input and super-resolved output for the DIV2K crop.
resolve_and_plot(lr_image_path='demo/0851x4-crop.png')