# Super-resolution

## Setup

In [None]:
%matplotlib inline

import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    import h5py

import bcolz
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import scipy

from keras import layers
from keras.layers import Input, InputLayer, Lambda
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.models import Model, Sequential
from keras.utils.data_utils import get_file

### Limit memory used by Tensorflow

In [None]:
# Close the session Keras is currently holding, then register a new one
# created with `allow_growth` enabled in its GPU options.
K.get_session().close()
cfg = K.tf.ConfigProto(gpu_options=K.tf.GPUOptions(allow_growth=True))
K.set_session(K.tf.Session(config=cfg))

### Pre/post processing functions

In [None]:
# Per-channel ImageNet means, shaped (1, 1, 3) so they broadcast over the
# trailing channel axis of an image batch.
vgg_mean = np.array([[[123.68, 116.779, 103.939]]], dtype=np.float32)


def preproc(x):
    """Subtract the ImageNet channel means, then reverse the channel order.

    `x` is indexed with four axes and the last one holds the channels, so the
    reversal turns RGB into BGR.
    """
    centered = x - vgg_mean
    return centered[:, :, :, ::-1]


def deproc(x, s):
    """Invert `preproc` on an array.

    Reshapes `x` to `s`, reverses the channel order (BGR back to RGB), adds
    the ImageNet means and clips the result to the range [0, 255].
    """
    restored = x.reshape(s)[:, :, :, ::-1] + vgg_mean
    return np.clip(restored, 0, 255)

### Define convolutional part of VGG16 model, with lambda layer to pre-process input

In [None]:
# Entry point of the VGG feature extractor: 288x288 images with 3 channels.
vgg_input_layer = InputLayer(input_shape=(288, 288, 3))

def add_convolutional_layers(model):
    """Append the five VGG16 convolutional blocks to `model` in place.

    Every block is a run of 3x3, same-padded, ReLU-activated convolutions
    followed by a 2x2 max-pool with stride 2. Layers are named
    `block<B>_conv<I>` and `block<B>_pool`, both counted from 1.
    """
    # (number of conv layers, filters per conv layer) for each block, in order.
    vgg16_blocks = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))
    for block_no, (conv_count, n_filters) in enumerate(vgg16_blocks, start=1):
        block_name = 'block' + str(block_no)
        for conv_no in range(1, conv_count + 1):
            conv_name = block_name + '_conv' + str(conv_no)
            model.add(Conv2D(n_filters, (3, 3), activation='relu', padding='same', name=conv_name))
        model.add(MaxPooling2D((2, 2), strides=(2, 2), name=block_name + '_pool'))

# Assemble the feature extractor: input -> mean/channel preprocessing ->
# VGG16 convolutional blocks.
vgg = Sequential([vgg_input_layer, Lambda(preproc, name='lambda')])
add_convolutional_layers(vgg)

# Mark every layer as non-trainable.
for layer in vgg.layers:
    layer.trainable = False

vgg.summary()

### Load weights

In [None]:
# Fetch the "notop" VGG16 weights via get_file (using the 'models' cache
# sub-directory) and load them into the network.
local_name = 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
repo = 'https://github.com/fchollet/deep-learning-models'
weights_url = '{}/releases/download/v0.1/{}'.format(repo, local_name)
weights_path = get_file(fname=local_name, origin=weights_url, cache_subdir='models')
vgg.load_weights(weights_path)

## Use content loss to create a super-resolution network

### Define upsampling network

In [None]:
def conv_block(x, filters, size, stride=(2,2), mode='same', act=True):
    """Apply Conv2D followed by BatchNormalization, optionally followed by ReLU.

    x: input tensor.
    filters: number of convolution filters.
    size: side length of the square convolution kernel.
    stride: convolution strides.
    mode: padding mode passed to Conv2D.
    act: when false, the ReLU is skipped and the normalized tensor is returned.
    """
    convolved = Conv2D(filters, (size, size), strides=stride, padding=mode)(x)
    normalized = BatchNormalization()(convolved)
    if not act:
        return normalized
    return Activation('relu')(normalized)

def res_block(ip, nf=64):
    """Residual block: two 3x3, stride-1 conv blocks added back onto `ip`.

    The second conv block omits its ReLU, so the sum is formed from the
    batch-normalized output rather than from an activated one.
    """
    unit_stride = (1, 1)
    branch = conv_block(ip, nf, 3, unit_stride)
    branch = conv_block(branch, nf, 3, unit_stride, act=False)
    return layers.add([branch, ip])

def deconv_block(x, filters, size, shape, stride=(2,2)):
    """Apply a transposed convolution, then BatchNormalization and ReLU.

    x: input tensor.
    filters: number of transposed-convolution filters.
    size: side length of the square kernel.
    shape: accepted for backward compatibility and ignored. The Keras 2
        transposed-convolution layer has no `output_shape` argument.
    stride: strides of the transposed convolution.
    """
    # `Deconv2D` was never imported in this file; use the layer through the
    # already-imported `layers` module under its Keras 2 name.
    x = layers.Conv2DTranspose(filters, (size, size), strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    return Activation('relu')(x)

def up_block(x, filters, size):
    """Upsample `x` (UpSampling2D with its default settings), then apply a
    same-padded convolution, BatchNormalization and ReLU.

    filters: number of convolution filters.
    size: side length of the square convolution kernel.
    """
    upsampled = layers.UpSampling2D()(x)
    convolved = Conv2D(filters, (size, size), padding='same')(upsampled)
    normalized = BatchNormalization()(convolved)
    return Activation('relu')(normalized)

# Upsampling network: 72x72x3 input -> 9x9 conv block -> four residual
# blocks -> two upsampling blocks -> 9x9 tanh convolution.
inp = Input(shape=(72, 72, 3))
x = conv_block(inp, 64, 9, stride=(1, 1))
for i in range(4):
    x = res_block(x)
x = up_block(x, 64, 3)
x = up_block(x, 64, 3)
x = Conv2D(filters=3, kernel_size=(9, 9), activation='tanh', padding='same')(x)
# tanh outputs lie in [-1, 1]; rescale them to the [0, 255] pixel range.
outp = Lambda(lambda t: (t + 1) * 127.5)(x)
outp

### Define loss function

In [None]:
def get_outp(m, ln):
    """Return the output tensor of the layer named `block<ln>_conv1` in model `m`."""
    layer_name = 'block{}_conv1'.format(ln)
    return m.get_layer(layer_name).output

def mean_sqr_b(diff):
    """Root-mean-square of `diff` over every axis except the first.

    The reduced tensor gets a new leading axis of size 1 before it is returned.
    """
    reduce_axes = list(range(1, K.ndim(diff)))
    mean_square = K.mean(diff**2, reduce_axes)
    return K.expand_dims(K.sqrt(mean_square), 0)

# Weight given to each compared feature level in the content loss.
w = [0.1, 0.8, 0.1]


def content_fn(x):
    """Weighted sum of per-level feature distances.

    `x` is a sequence of 2 * len(w) tensors. Element i is compared with
    element i + len(w) through `mean_sqr_b`, and the result is scaled by w[i].
    """
    n = len(w)
    total = 0
    for i, weight in enumerate(w):
        total += mean_sqr_b(x[i] - x[i + n]) * weight
    return total

# Sub-model of the VGG network that returns the conv1 activations of
# blocks 1, 2 and 3 as a list of feature tensors.
vgg_content = Model(vgg_input_layer.input, [get_outp(vgg, o) for o in [1,2,3]])

# Input for the high-resolution target image. It must be defined before use
# (otherwise the lines below raise NameError); its shape is taken from the
# VGG input layer so the two always agree.
vgg_input = Input(K.int_shape(vgg_input_layer.input)[1:])

# Features of the target image and of the upsampling network's output.
vgg1 = vgg_content(vgg_input)
vgg2 = vgg_content(outp)

# Training model: takes the low-resolution and high-resolution images and
# outputs the content loss directly. `vgg1 + vgg2` concatenates the two
# feature lists, which is the layout content_fn expects.
m_sr = Model(inputs=[inp, vgg_input], outputs=Lambda(content_fn)(vgg1 + vgg2))
m_sr.summary()

### Load training data

In [None]:
# Read the first `num_images` entries of the low-resolution and
# high-resolution bcolz arrays.
num_images = 2000
arr_lr = bcolz.open(rootdir='data/super-resolution/trn_resized_72.bc')[:num_images]
arr_hr = bcolz.open(rootdir='data/super-resolution/trn_resized_288.bc')[:num_images]
arr_lr.shape

In [None]:
# The model's output is itself the content loss, so it is trained with MSE
# against an all-zero target.
m_sr.compile(optimizer='adam', loss='mse')
target = np.zeros(shape=(num_images, 1))
epochs = 8
# Report whether each of the two model inputs is a placeholder.
print(m_sr.input_layers[0].is_placeholder)
print(m_sr.input_layers[1].is_placeholder)

In [None]:
# Train on (low-res, high-res) pairs against the zero target.
m_sr.fit(x=[arr_lr, arr_hr], y=target, epochs=epochs, verbose=2)

In [None]:
# Same fit call again: runs another `epochs` epochs on the same data.
m_sr.fit(x=[arr_lr, arr_hr], y=target, epochs=epochs, verbose=2)

### Extract part of the model that we want

In [None]:
# Keep only the upsampling network: low-resolution input -> upscaled image.
top_model = Model(inputs=inp, outputs=outp)

In [None]:
# Upscale a single image; the slice keeps the leading batch axis.
p = top_model.predict(x=arr_lr[500:501])

In [None]:
# Show the low-resolution input image.
plt.imshow(X=arr_lr[500].astype('uint8'));

In [None]:
# Show the network's upscaled output for the same image.
plt.imshow(X=p[0].astype('uint8'));