# Super-resolution

## Setup

In [None]:
%matplotlib inline

import warnings
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    import h5py

import bcolz
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf

from keras import layers
from keras.engine import InputSpec
from keras.layers import Input, InputLayer, Lambda, Layer
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.models import Model, Sequential
from keras.utils.data_utils import get_file

from PIL import Image

### Limit memory used by Tensorflow

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving the whole
# device up front, and make Keras run in that session. Uses the directly
# imported `tf` rather than `K.tf`, which is only the Keras backend module's
# internal reference to tensorflow, not a public Keras API.
cfg = tf.ConfigProto()
cfg.gpu_options.allow_growth = True
K.set_session(tf.Session(config=cfg))

### Pre/post processing functions

In [None]:
vgg_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape((1, 1, 3))
# ImageNet per-channel means in RGB order, shaped to broadcast over a trailing
# channel axis.


def preproc(x):
    """Subtract the ImageNet channel means and convert RGB to BGR.

    Args:
        x: array or tensor whose last axis holds RGB channels, e.g. a batch
            ``(N, H, W, 3)`` or a single image ``(H, W, 3)``.

    Returns:
        Same shape as ``x``: mean-subtracted, with the channel axis reversed
        (RGB -> BGR).
    """
    # `...` selects the last (channel) axis whatever the rank, so this works
    # for single images as well as batches.
    return (x - vgg_mean)[..., ::-1]

### Define the convolutional part of the VGG16 model, with a lambda layer to pre-process the input

In [None]:
# Input layer of the VGG16 feature extractor: 288x288 images with 3 channels.
vgg_input_layer = InputLayer(input_shape=(288, 288, 3))

def add_convolutional_layers(model, blocks=None):
    """Append a VGG-style convolutional stack (no dense top) to ``model``.

    Each block is ``num_convs`` 3x3 'same' ReLU convolutions named
    ``block{b}_conv{i}``, followed by a 2x2 / stride-2 max-pool named
    ``block{b}_pool`` (both 1-based).

    Args:
        model: model with an ``add`` method (e.g. ``Sequential``); layers are
            appended in order.
        blocks: sequence of ``(num_convs, filters)`` pairs, one per block.
            Defaults to the VGG16 configuration.
    """
    if blocks is None:
        blocks = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))
    # `num_convs` (not `layers`) so the module-level `keras.layers` import is
    # not shadowed inside this function.
    for block_idx, (num_convs, filters) in enumerate(blocks, start=1):
        prefix = 'block' + str(block_idx)
        for conv_idx in range(1, num_convs + 1):
            model.add(Conv2D(filters, (3, 3), activation='relu', padding='same',
                             name=prefix + '_conv' + str(conv_idx)))
        # The stride-2 pool halves the spatial resolution after every block.
        model.add(MaxPooling2D((2, 2), strides=(2, 2), name=prefix + '_pool'))

# VGG16 convolutional part: input -> mean/BGR preprocessing -> conv blocks.
vgg = Sequential([vgg_input_layer, Lambda(preproc, name='lambda')])
add_convolutional_layers(vgg)
# Freeze every layer so training never updates the VGG weights.
for vgg_layer in vgg.layers:
    vgg_layer.trainable = False
layer = vgg_layer
vgg.summary()

### Load weights

In [None]:
# Fetch the "notop" VGG16 weights (get_file reuses a cached copy under the
# Keras `models` cache directory if present) and load them into `vgg`.
local_name = 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
repo = 'https://github.com/fchollet/deep-learning-models'
weights_url = repo + '/releases/download/v0.1/' + local_name
weights_path = get_file(local_name, weights_url, cache_subdir='models')
vgg.load_weights(weights_path)

### Load training data

In [None]:
# Load only the first `num_images` high-resolution samples into memory.
num_images = 2000
# The carray is only read here, so open it read-only (mode='r') rather than
# with bcolz's default append mode, which would allow writes to the on-disk data.
arr_hr = bcolz.open('data/imagenet/sample/resized-288.bc', mode='r')[:num_images]
shp = arr_hr.shape[1:]
shp

### Load style

In [None]:
# Load the style image as 3-channel RGB and crop it to the training-image shape.
# - `convert('RGB')` drops an alpha channel (as the channel slice alone did)
#   and also handles greyscale/palette PNGs, which np.array would return as 2-D.
# - The `with` block closes the underlying file handle.
with Image.open('data/neural-style/starry_night.png') as style:
    style_arr = np.array(style.convert('RGB'))[:shp[0], :shp[1], :shp[2]]
# An image smaller than `shp` would give a smaller crop; fail here with a clear
# message instead of with a shape error deep inside Keras.
assert style_arr.shape == tuple(shp), (
    'style image is smaller than the training images: got %s, need %s'
    % (style_arr.shape, tuple(shp)))
plt.imshow(style_arr);

## Use perceptual losses (content + style) to train a style-transfer network

### Reflection padding layer

In [None]:
class ReflectionPadding2D(Layer):
    """Reflection-pad the two spatial axes of a 4-D channels-last tensor.

    Args:
        padding: ``(pad_rows, pad_cols)``. ``pad_rows`` rows are added at both
            the top and the bottom; ``pad_cols`` columns at both the left and
            the right. This is the ``(height, width)`` order used by
            ``keras.layers.ZeroPadding2D``. A single int pads both axes equally.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        if isinstance(padding, int):
            padding = (padding, padding)
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)
        # Assigned after the base constructor: the Keras 2 `Layer.__init__`
        # sets `input_spec = None`, which would discard an earlier assignment.
        self.input_spec = [InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        h_pad, w_pad = self.padding
        # Unknown (None) spatial dimensions stay unknown.
        rows = None if s[1] is None else s[1] + 2 * h_pad
        cols = None if s[2] is None else s[2] + 2 * w_pad
        return (s[0], rows, cols, s[3])

    def call(self, x, mask=None):
        # Unpack in the same (rows, cols) order as compute_output_shape so the
        # declared output shape matches the tensor actually produced.
        h_pad, w_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        # Include `padding` so the layer can be rebuilt from a saved config.
        config = super(ReflectionPadding2D, self).get_config()
        config['padding'] = self.padding
        return config
    
# Visual check of the padding layer on one training image. The padding amounts
# are deliberately different so it is easy to see which axis receives which.
# Uses its own input tensor (`ref_inp`) instead of overwriting `inp`, which is
# the transformation network's input later on.
ref_inp = Input(shp)
ref_model = Model(ref_inp, ReflectionPadding2D((40, 10))(ref_inp))
# No compile() needed: predict() only runs the forward pass.
ref_padded = ref_model.predict(arr_hr[10:11])
plt.imshow(ref_padded[0].astype('uint8'));

### Define the image-transformation network (downsampling, residual blocks, upsampling)

In [None]:
def conv_block(x, filters, size, stride=(2,2), mode='same', act=True):
    """Conv2D -> BatchNormalization -> optional ReLU.

    ``mode`` is passed to Conv2D as its ``padding`` ('same' or 'valid'); the
    default ``stride`` of (2, 2) halves the spatial size. With ``act=False``
    the batch-normalised output is returned without an activation.
    """
    conved = Conv2D(filters, kernel_size=(size, size), strides=stride, padding=mode)(x)
    normed = BatchNormalization()(conved)
    if not act:
        return normed
    return Activation('relu')(normed)

def res_crop_block(ip, nf=64):
    """Residual block made of two stride-1 'valid' 3x3 conv blocks.

    Each valid 3x3 convolution trims one pixel from every border, so the
    shortcut branch is centre-cropped by 2 pixels per side before the two
    branches are summed. The second conv block has no activation.
    """
    branch = conv_block(ip, nf, 3, (1,1), 'valid')
    branch = conv_block(branch, nf, 3, (1,1), 'valid', act=False)
    shortcut = Lambda(lambda t: t[:, 2:-2, 2:-2])(ip)
    return layers.add([branch, shortcut])

def up_block(x, filters, size):
    """UpSampling2D (default size), then 'same' Conv2D -> BatchNorm -> ReLU."""
    upsampled = layers.UpSampling2D()(x)
    conved = Conv2D(filters, kernel_size=(size, size), padding='same')(upsampled)
    normed = BatchNormalization()(conved)
    return Activation('relu')(normed)

inp = Input(shape=shp)
# Reflection-pad 40 px per side. After the 4x downsampling below that is 10 px,
# exactly what the five residual blocks crop away (2 px per side each), so the
# final output has the same spatial size as the input.
x = ReflectionPadding2D(padding=(40, 40))(inp)
# One full-resolution 9x9 conv, then two stride-2 convs (4x downsampling).
x = conv_block(x, 64, 9, stride=(1, 1))
x = conv_block(x, 64, 3)
x = conv_block(x, 64, 3)
for i in range(5):
    x = res_crop_block(x, nf=64)
# Two upsampling blocks bring the resolution back up.
x = up_block(x, filters=64, size=3)
x = up_block(x, filters=64, size=3)
# tanh output lies in [-1, 1]; rescale it to the [0, 255] pixel range.
x = Conv2D(3, kernel_size=(9, 9), activation='tanh', padding='same')(x)
outp = Lambda(lambda t: (t + 1) * 127.5)(x)
outp

In [None]:
def get_outp(m, ln):
    """Return the output tensor of layer ``block{ln}_conv1`` in model ``m``."""
    layer_name = 'block' + str(ln) + '_conv1'
    return m.get_layer(layer_name).output

# Multi-output feature extractor: the first conv layer of VGG blocks 2 to 5.
vgg_content = Model(inputs=vgg_input_layer.input,
                    outputs=[get_outp(vgg, block) for block in range(2, 6)])
vgg_content.summary()

In [None]:
# Features of the VGG input images (vgg1) and of the transformation network's
# output (vgg2); each is a list with one tensor per extracted layer.
vgg1, vgg2 = [vgg_content(t) for t in (vgg_input_layer.input, outp)]

### Define loss function

In [None]:
# Style targets: VGG features of the style image (as a batch of one), stored as
# backend variables so they can be used inside the loss graph.
style_targs = list(map(K.variable, vgg_content.predict(style_arr[np.newaxis])))

[K.eval(K.shape(t)) for t in style_targs]

In [None]:
def mean_sqr_b(diff):
    """Per-sample root-mean-square of ``diff``.

    Squares ``diff``, averages over every axis except the first (the batch
    axis in this notebook's uses) and takes the square root. Despite the
    name, this is an RMS, not a mean squared error.

    Returns:
        Tensor of shape ``(batch, 1)``: one value per sample.
    """
    dims = list(range(1, K.ndim(diff)))
    # Add a trailing axis (not a leading one) so per-sample values run along
    # the batch axis. The previous axis-0 version produced (1, batch), which
    # only lined up with (batch, 1) targets through broadcasting.
    return K.expand_dims(K.sqrt(K.mean(diff ** 2, dims)), -1)

def gram_matrix_b(x):
    """Batched Gram matrix of channels-last feature maps.

    For ``x`` of shape ``(batch, H, W, C)`` returns ``(batch, C, C)``. Entry
    ``[b, i, j]`` is the dot product of channels ``i`` and ``j`` over all
    spatial positions, divided by ``C * H * W``.
    """
    channels_first = K.permute_dimensions(x, (0, 3, 1, 2))
    dims = K.shape(channels_first)
    flat = K.reshape(channels_first, (dims[0], dims[1], dims[2] * dims[3]))
    gram = K.batch_dot(flat, K.permute_dimensions(flat, (0, 2, 1)))
    normaliser = K.prod(K.cast(dims[1:], K.floatx()))
    return gram / normaliser

# Content-loss weight for each extracted VGG layer.
w = [0.1, 0.2, 0.6, 0.1]

def tot_loss(x):
    """Combined style + content loss.

    ``x`` is a list of ``2 * n`` feature tensors, where ``n = len(style_targs)``.
    The first ``n`` are features of the VGG input images, the last ``n`` are
    features of the generated images. For each layer ``i`` the loss adds:

    - half the RMS difference between the Gram matrices of the generated
      features and of ``style_targs[i]``;
    - the RMS difference between input and generated features, times ``w[i]``.
    """
    n = len(style_targs)
    total = 0
    for i in range(n):
        content_i, generated_i = x[i], x[i + n]
        total += mean_sqr_b(gram_matrix_b(generated_i) - gram_matrix_b(style_targs[i])) / 2.
        total += mean_sqr_b(content_i - generated_i) * w[i]
    return total

# Training model: two inputs (the transformation network's input and the VGG
# input). Its only output is the loss itself, driven towards zero targets via MAE.
loss = Lambda(tot_loss)(vgg1 + vgg2)
m_style = Model(inputs=[inp, vgg_input_layer.input], outputs=loss)
target = np.zeros((len(arr_hr), 1))
m_style.compile(optimizer='adam', loss='mae')
epochs = 8
batch_size = 12

In [None]:
# The same images feed both inputs: the transformation network and the VGG branch.
m_style.fit(x=[arr_hr, arr_hr], y=target, batch_size=batch_size, epochs=epochs, verbose=2)

### Extract the trained image-transformation network

In [None]:
# Inference model: only the trained transformation network (image in, image out).
top_model = Model(inputs=inp, outputs=outp)

In [None]:
# Run the first 20 training images through the network and show the first
# result. The network ends with (tanh + 1) * 127.5, so a uint8 cast suffices.
p = top_model.predict(x=arr_hr[:20])
plt.imshow(p[0, ...].astype(dtype='uint8'));

In [None]:
# The unmodified input image, for comparison with the network output.
plt.imshow(np.asarray(arr_hr[0]).astype('uint8'));