In [1]:
import numpy as np
from keras.models import Model
from tensorflow.python.keras.models import load_model
from keras.layers import Input, BatchNormalization, Dense, Add, Activation, Reshape, Permute, Flatten, Conv2DTranspose
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.optimizers import Adam
from keras.applications import VGG19
import datetime
import random
from PIL import Image
import glob
import cv2
import os
import matplotlib.pyplot as plt


# Super-resolution configuration shared by the cells below.
ratio = 4                  # upscaling factor applied to both height and width
LR_shape = (120, 160, 3)   # low-resolution input as (height, width, channels)

# Derive the high-resolution geometry from the low-resolution one.
L_h, L_w, channels = LR_shape
H_h, H_w = L_h * ratio, L_w * ratio
HR_shape = (H_h, H_w, channels)

# One Adam instance with library-default hyperparameters, used when the
# models are compiled.
optimizer = Adam()

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def load_data(batch_size):
    """Sample a random batch of training images as HR/LR pairs.

    Picks ``batch_size`` distinct PNG files from ``images/train/``, resizes
    each one to the high-resolution size ``(H_w, H_h)`` and to the
    low-resolution size ``(L_w, L_h)``, and rescales pixel values from
    [0, 255] to [-1, 1].

    Args:
        batch_size: Number of images to draw (sampled without replacement).

    Returns:
        Tuple ``(hr_imgs, lr_imgs)`` of float arrays shaped
        ``(batch_size, H_h, H_w, 3)`` and ``(batch_size, L_h, L_w, 3)``.

    Raises:
        ValueError: If fewer than ``batch_size`` PNG files are available.
    """
    # No "**" in the pattern, so a recursive glob would not change the result.
    files = glob.glob("images/train/*.png")
    if len(files) < batch_size:
        raise ValueError(
            "load_data needs %d images but found only %d in images/train/"
            % (batch_size, len(files)))
    batch_images = random.sample(files, batch_size)

    hr_imgs = []
    lr_imgs = []
    for img_path in batch_images:
        # The context manager releases the file handle. convert("RGB") forces
        # three channels so grayscale / palette / RGBA PNGs still match the
        # 3-channel model input.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        # PIL's resize takes (width, height), hence the swapped order.
        hr_imgs.append(np.array(img.resize((H_w, H_h))))
        lr_imgs.append(np.array(img.resize((L_w, L_h))))

    # Map uint8 [0, 255] to float [-1, 1].
    hr_imgs = np.array(hr_imgs) / 127.5 - 1.
    lr_imgs = np.array(lr_imgs) / 127.5 - 1.

    return hr_imgs, lr_imgs
      

In [3]:

def pixel_shuffle(in_map, h, w, c):
    """Rearrange a ``(h, w, 4*c)`` feature map into ``(2*h, 2*w, c)``.

    This is a 2x depth-to-space (sub-pixel) rearrangement: every group of
    four channels at one spatial location becomes a 2x2 patch of pixels in
    the output.

    Args:
        in_map: Keras tensor whose per-sample shape is ``(h, w, 4*c)``.
        h: Height of ``in_map``.
        w: Width of ``in_map``.
        c: Number of output channels.

    Returns:
        Keras tensor with per-sample shape ``(2*h, 2*w, c)``.
    """
    # Split the channel axis into a 2x2 sub-pixel grid: (h, w, r_h, r_w, c).
    x = Reshape((h, w, 2, 2, c))(in_map)
    # Permute indices are 1-based and exclude the batch axis. Each sub-pixel
    # axis must sit directly after its spatial axis -> (h, r_h, w, r_w, c),
    # so that the reshape below interleaves the sub-pixels. The previous
    # pattern (3, 1, 4, 2, 5) gave (r_h, h, r_w, w, c), which the reshape
    # turned into four tiled sub-images instead of an upsampled map.
    x = Permute((1, 3, 2, 4, 5))(x)
    out_map = Reshape((2 * h, 2 * w, c))(x)

    return out_map


def upsampling(in_map, h, w, c):
    """2x upsampling stage: Conv2D to ``4*c`` channels, pixel shuffle, PReLU.

    Args:
        in_map: Keras tensor with spatial size ``(h, w)``.
        h: Height of ``in_map``.
        w: Width of ``in_map``.
        c: Channel count passed to ``pixel_shuffle`` (the convolution
            produces ``4 * c`` channels for it to consume).

    Returns:
        Keras tensor produced by ``PReLU`` applied to the shuffled map,
        which ``pixel_shuffle`` reshapes to ``(2*h, 2*w, c)``.
    """
    expanded = Conv2D(filters=4 * c, kernel_size=3, strides=1, padding="same")(in_map)
    shuffled = pixel_shuffle(expanded, h, w, c)
    return PReLU()(shuffled)


def residual_block(in_map):
    """Residual unit: Conv -> LeakyReLU -> BN -> Conv -> BN, plus a skip add.

    Both convolutions use 64 filters with "same" padding, so ``in_map`` must
    already carry 64 channels for the final ``Add`` to be valid.

    Args:
        in_map: Keras tensor entering the block.

    Returns:
        Keras tensor: the element-wise sum of the branch output and ``in_map``.
    """
    conv_kwargs = dict(filters=64, kernel_size=3, strides=1, padding="same")

    branch = Conv2D(**conv_kwargs)(in_map)
    # alpha = 0 gives a zero slope for negative inputs, i.e. a plain ReLU.
    branch = LeakyReLU(alpha=0)(branch)
    branch = BatchNormalization()(branch)
    branch = Conv2D(**conv_kwargs)(branch)
    branch = BatchNormalization()(branch)

    return Add()([branch, in_map])


def d_block(in_map, filters, kernel_size, strides, padding):
    """Discriminator unit: Conv2D -> LeakyReLU(0.2) -> BatchNormalization(0.8).

    Args:
        in_map: Keras tensor entering the block.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        padding: Convolution padding mode.

    Returns:
        Keras tensor after the batch-normalization layer.
    """
    features = Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding=padding)(in_map)
    activated = LeakyReLU(alpha=0.2)(features)
    return BatchNormalization(momentum=0.8)(activated)


def deconv2d(layer_input):
    """Upsampling stage: stride-2 transposed convolution followed by ReLU.

    The transposed convolution uses 120 filters, a 3x3 kernel and 'same'
    padding with stride 2, so each call doubles the spatial size.

    Args:
        layer_input: Keras tensor to upsample.

    Returns:
        Keras tensor after the ReLU activation.
    """
    upsampled = Conv2DTranspose(filters=120,
                                kernel_size=3,
                                strides=2,
                                padding='same')(layer_input)
    return Activation('relu')(upsampled)



In [4]:
def build_generator():
    """Build the generator that upscales an ``LR_shape`` image by ``ratio``.

    Architecture: 9x9 Conv + ReLU stem, five residual blocks, a 3x3 Conv + BN
    merged with the stem through a skip connection, one ``deconv2d`` stage
    per factor of two in ``ratio``, and a final 9x9 Conv to 3 channels with
    no activation.

    Returns:
        Keras ``Model`` mapping the low-resolution input to the upscaled image.

    Raises:
        ValueError: If ``ratio`` is not a positive power of two. Each
            ``deconv2d`` stage doubles the resolution, so no other factor can
            be reached; previously ``ratio == 0`` looped forever and other
            values silently produced an output that did not match ``HR_shape``.
    """
    # Count the required 2x stages up front so an invalid ratio fails before
    # any layer is created.
    n_stages = 0
    remaining = ratio
    while remaining > 1 and remaining % 2 == 0:
        remaining //= 2
        n_stages += 1
    if remaining != 1:
        raise ValueError("ratio must be a positive power of two, got %r" % (ratio,))

    input_img = Input(shape = LR_shape)
    middle = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(input_img)
    # alpha = 0 makes this LeakyReLU behave as a plain ReLU.
    middle = LeakyReLU(alpha = 0)(middle)

    # Five residual blocks in sequence.
    g = middle
    for _ in range(5):
        g = residual_block(g)

    g = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(g)
    g = BatchNormalization()(g)
    # Long skip connection around the whole residual stack.
    g = Add()([g, middle])

    for _ in range(n_stages):
        g = deconv2d(g)

    output_img = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(g)

    return Model(input_img, output_img)


def build_discriminator():
    """Build the discriminator that scores ``HR_shape`` images patch by patch.

    A 3x3 Conv + LeakyReLU stem is followed by seven ``d_block`` stages, four
    of which use stride 2 (16x total downsampling), then two Dense layers.

    Returns:
        Keras ``Model`` mapping an image to a sigmoid score map with one
        value per spatial location of the downsampled feature map.
    """
    input_img = Input(shape=HR_shape)

    d = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(input_img)
    d = LeakyReLU(alpha=0.2)(d)

    # (filters, strides) of each successive d_block.
    stages = [
        (64, 2),
        (128, 1), (128, 2),
        (256, 1), (256, 2),
        (512, 1), (512, 2),
    ]
    for n_filters, stride in stages:
        d = d_block(d, filters=n_filters, kernel_size=3, strides=stride, padding="same")

    # No Flatten here: the Dense layers act on the last axis only, so the
    # spatial grid is kept and the output is a map of per-patch scores.
    d = Dense(512)(d)
    d = LeakyReLU(alpha=0.2)(d)
    output = Dense(1, activation="sigmoid")(d)

    return Model(input_img, output)


def build_vgg():
    """Build the feature extractor: VGG19 without its top, cut at ``layers[9]``.

    Returns:
        Keras ``Model`` from the VGG19 input to the output of ``layers[9]``.
    """
    backbone = VGG19(include_top=False)
    # NOTE(review): layers[9] is presumably block3_conv3 -- confirm against
    # the installed Keras version.
    feature_map = backbone.layers[9].output
    return Model(backbone.input, feature_map)
    

def combined(generator, discriminator, vgg):
    """Stack generator, discriminator and VGG into one two-output model.

    Args:
        generator: Model turning an ``LR_shape`` image into an upscaled image.
        discriminator: Model scoring the generated image.
        vgg: Model extracting features from the generated image.

    Returns:
        Keras ``Model`` mapping a low-resolution input to
        ``[validity, features]`` computed on the generator's output.
    """
    lr_input = Input(shape=LR_shape)
    generated = generator(lr_input)

    outputs = [discriminator(generated), vgg(generated)]
    return Model(lr_input, outputs)

In [5]:
# Training history filled in by train(): (d_loss, g_loss) pairs and the
# iteration numbers at which they were recorded.
losses, epochs_checkpoint = [], []

def train(epochs, batch_size, interval):
    """Alternate discriminator and generator updates for ``epochs`` iterations.

    Each iteration draws one random batch with ``load_data``, trains the
    discriminator on real and generated images, then trains the combined
    ``srgan`` model against the "real" targets and the VGG features of the
    real images. Every ``interval`` iterations the losses are appended to the
    module-level ``losses`` / ``epochs_checkpoint`` lists and the generator is
    saved to ``weights/weight.h5``.

    Args:
        epochs: Number of training iterations (one batch each).
        batch_size: Number of images per batch.
        interval: Checkpoint period, in iterations.
    """
    start_time = datetime.datetime.now()

    # Per-patch targets matching the discriminator's score map, which is
    # 16x smaller than the high-resolution image in each spatial dimension.
    patch_shape = (batch_size,) + (H_h // 16, H_w // 16, 1)
    real = np.ones(patch_shape)
    fake = np.zeros(patch_shape)

    # Create the checkpoint directory up front; otherwise the first save
    # would fail after `interval` iterations of training.
    os.makedirs("weights", exist_ok=True)

    for epoch in range(epochs):
        real_imgs, lr_imgs = load_data(batch_size)
        fake_imgs = generator.predict(lr_imgs)

        # Train the discriminator
        d_loss_real = discriminator.train_on_batch(real_imgs, real)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator
        vgg_features = vgg.predict(real_imgs)
        g_loss = srgan.train_on_batch(lr_imgs, [real, vgg_features])

        elapsed = datetime.datetime.now() - start_time
        print("%d time: %s" % (epoch+1, elapsed))

        if (epoch+1) % interval == 0:
            losses.append((d_loss, g_loss))
            epochs_checkpoint.append(epoch+1)
            generator.save("weights/weight.h5")
            print("save weights")
    

In [6]:
# Discriminator: compiled on its own so it can be trained directly with
# train_on_batch.
discriminator = build_discriminator()
discriminator.compile(optimizer=optimizer, loss="mse", metrics=["accuracy"])
discriminator.summary()

# Generator: left uncompiled here; it is only used for predict() and as part
# of the combined model.
generator = build_generator()
generator.summary()





_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 480, 640, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 480, 640, 64)      1792      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 480, 640, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 240, 320, 64)      36928     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 240, 320, 64)      0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 240, 320, 64)      256       
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 240, 320, 128)     73856     
______

In [7]:
# Feature extractor used for the content loss; its weights are frozen.
vgg = build_vgg()
vgg.trainable = False

# Freeze the discriminator before the combined model is compiled.
discriminator.trainable = False

# Combined model: low-res image -> [discriminator validity, VGG features].
# The adversarial term is weighted 1e-3 against the feature (mse) term.
srgan = combined(generator, discriminator, vgg)
srgan.compile(
    optimizer=optimizer,
    loss=['binary_crossentropy', 'mse'],
    loss_weights=[1e-3, 1],
)
srgan.summary()




Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_4 (InputLayer)             (None, 120, 160, 3)   0                                            
____________________________________________________________________________________________________
model_2 (Model)                  (None, 480, 640, 3)   652763                                       
____________________________________________________________________________________________________
model_1 (Model)                  (None, 30, 40, 1)     4955969                                      
____________________________________________________________________________________________________
model_3 (Model)                  multiple              1735488                                 

In [9]:
# Run 3000 training iterations with one image per batch, recording losses and
# saving the generator every 300 iterations.
train(epochs = 3000, batch_size = 1, interval = 300)