### Imports

In [37]:
import os, time, pickle, random, time
from datetime import datetime
import numpy as np
from time import localtime, strftime
import logging, scipy
import tensorflow as tf
import tensorlayer as tl
from model import SRGAN_g, SRGAN_d, Vgg19_simple_api
from utils import *
from config import config, log_config

### Hyper Params

In [38]:
## Adam
batch_size = config.TRAIN.batch_size
lr_init = config.TRAIN.lr_init
beta1 = config.TRAIN.beta1
## initialize G
n_epoch_init = config.TRAIN.n_epoch_init
## adversarial learning (SRGAN)
n_epoch = config.TRAIN.n_epoch
lr_decay = config.TRAIN.lr_decay
decay_every = config.TRAIN.decay_every
ni = int(np.sqrt(batch_size))

### Dirs and Stuff

In [3]:
## create folders to save result images and trained model
save_dir_ginit = "samples/{}_ginit".format("srgan")
save_dir_gan = "samples/{}_gan".format("srgan")
tl.files.exists_or_mkdir(save_dir_ginit)
tl.files.exists_or_mkdir(save_dir_gan)
checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
tl.files.exists_or_mkdir(checkpoint_dir)

[TL] [!] samples/srgan_ginit exists ...
[TL] [!] samples/srgan_gan exists ...
[TL] [!] checkpoint exists ...


True

### Preload Data because AWS

In [4]:
train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

In [5]:
train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
train_lr_img = tl.vis.read_images(train_lr_img_list, path=config.TRAIN.lr_img_path, n_threads=32)
valid_hr_img = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
valid_lr_img = tl.vis.read_images(train_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)

[TL] read 32 from devData/HR/
[TL] read 64 from devData/HR/
[TL] read 96 from devData/HR/
[TL] read 114 from devData/HR/
[TL] read 32 from devData/LR/
[TL] read 64 from devData/LR/
[TL] read 96 from devData/LR/
[TL] read 114 from devData/LR/
[TL] read 32 from devData/HR/
[TL] read 64 from devData/HR/
[TL] read 96 from devData/HR/
[TL] read 114 from devData/HR/
[TL] read 32 from devData/LR/
[TL] read 64 from devData/LR/
[TL] read 96 from devData/LR/
[TL] read 114 from devData/LR/


In [6]:
# Removing B&W images since it causes problems in Channel
for i,img in enumerate(train_hr_imgs):
    if len(img.shape) == 2:
        train_hr_imgs.pop(i)
for i,img in enumerate(train_lr_img):
    if len(img.shape) == 2:
        train_hr_imgs.pop(i)
for i,img in enumerate(valid_hr_img):
    if len(img.shape) == 2:
        train_hr_imgs.pop(i)
for i,img in enumerate(valid_lr_img):
    if len(img.shape) == 2:
        train_hr_imgs.pop(i)   

In [7]:
cropDim = 5000
for img in train_hr_imgs:
    minDim = min(img.shape[0:2])
    if cropDim > minDim :
        cropDim = minDim

In [8]:
cropDim -= 1+(cropDim-1)%4

In [9]:
sample_imgs = train_hr_imgs[0:batch_size]
#sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn,dim=cropDim,is_random=False)
print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384,dim=int(cropDim/4),fn=downsample_fn)
print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

sample HR sub-image: (25, 284, 284, 3) -1.0 1.0
sample LR sub-image: (25, 71, 71, 3) -1.0 1.0


In [10]:
t_image = tf.placeholder('float32', [batch_size, int(cropDim/4), int(cropDim/4), 3], name='t_image_input_to_SRGAN_generator')

In [11]:
t_target_image = tf.placeholder('float32', [batch_size, cropDim, cropDim, 3], name='t_target_image')

In [12]:
net_g = SRGAN_g(t_image, is_train=True, reuse=False)

[TL] InputLayer  SRGAN_g/in: (25, 71, 71, 3)
[TL] Conv2d SRGAN_g/n64s1/c: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
Instructions for updating:
Colocations handled automatically by placer.
[TL] Conv2d SRGAN_g/n64s1/c1/0: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b1/0: decay: 0.900000 epsilon: 0.000010 act: relu is_train: True
[TL] Conv2d SRGAN_g/n64s1/c2/0: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b2/0: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: True
[TL] ElementwiseLayer SRGAN_g/b_residual_add/0: size: (25, 71, 71, 64) fn: add
[TL] Conv2d SRGAN_g/n64s1/c1/1: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b1/1: decay: 0.900000 epsilon: 0.000010 act: relu is_train: True
[TL] Conv2d SRGAN_g/n64s1/c2/1: n_filter: 64 filter_size: (3, 3) strid

[TL] Conv2d SRGAN_g/n64s1/c2/15: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b2/15: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: True
[TL] ElementwiseLayer SRGAN_g/b_residual_add/15: size: (25, 71, 71, 64) fn: add
[TL] Conv2d SRGAN_g/n64s1/c/m: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b/m: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: True
[TL] ElementwiseLayer SRGAN_g/add3: size: (25, 71, 71, 64) fn: add
[TL] Conv2d SRGAN_g/n256s1/1: n_filter: 256 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] SubpixelConv2d  SRGAN_g/pixelshufflerx2/1: scale: 2 n_out_channel: 64 act: relu
[TL] Conv2d SRGAN_g/n256s1/2: n_filter: 256 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] SubpixelConv2d  SRGAN_g/pixelshufflerx2/2: scale: 2 n_out_channel: 64 act: relu
[TL] Conv2d SRGAN_g/out: n_fi

In [13]:
net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)

Instructions for updating: TensorLayer relies on TensorFlow to check name reusing

[TL] InputLayer  SRGAN_d/input/images: (25, 284, 284, 3)
[TL] Conv2d SRGAN_d/h0/c: n_filter: 64 filter_size: (4, 4) strides: (2, 2) pad: SAME act: <lambda>
Instructions for updating: This API is deprecated. Please use as `tf.nn.leaky_relu`

[TL] Conv2d SRGAN_d/h1/c: n_filter: 128 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_d/h1/bn: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: True
[TL] Conv2d SRGAN_d/h2/c: n_filter: 256 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_d/h2/bn: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: True
[TL] Conv2d SRGAN_d/h3/c: n_filter: 512 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_d/h3/bn: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: True
[TL] Conv2d SRGAN_d/h4/c: n_filter: 1024 filter_size: (4, 4) stride

In [14]:
_, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

[TL] InputLayer  SRGAN_d/input/images: (25, 284, 284, 3)
[TL] Conv2d SRGAN_d/h0/c: n_filter: 64 filter_size: (4, 4) strides: (2, 2) pad: SAME act: <lambda>
[TL] Conv2d SRGAN_d/h1/c: n_filter: 128 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_d/h1/bn: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: True
[TL] Conv2d SRGAN_d/h2/c: n_filter: 256 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_d/h2/bn: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: True
[TL] Conv2d SRGAN_d/h3/c: n_filter: 512 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_d/h3/bn: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: True
[TL] Conv2d SRGAN_d/h4/c: n_filter: 1024 filter_size: (4, 4) strides: (2, 2) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_d/h4/bn: decay: 0.900000 epsilon: 0.000010 act: <lambda> is_train: True
[TL] Conv2d SRGAN_d/h5/c: n_fil

In [15]:
net_g.print_params(False)
net_g.print_layers()
net_d.print_params(False)
net_d.print_layers()

[TL]   param   0: SRGAN_g/n64s1/c/kernel:0 (3, 3, 3, 64)      float32_ref
[TL]   param   1: SRGAN_g/n64s1/c/bias:0 (64,)              float32_ref
[TL]   param   2: SRGAN_g/n64s1/c1/0/kernel:0 (3, 3, 64, 64)     float32_ref
[TL]   param   3: SRGAN_g/n64s1/b1/0/beta:0 (64,)              float32_ref
[TL]   param   4: SRGAN_g/n64s1/b1/0/gamma:0 (64,)              float32_ref
[TL]   param   5: SRGAN_g/n64s1/b1/0/moving_mean:0 (64,)              float32_ref
[TL]   param   6: SRGAN_g/n64s1/b1/0/moving_variance:0 (64,)              float32_ref
[TL]   param   7: SRGAN_g/n64s1/c2/0/kernel:0 (3, 3, 64, 64)     float32_ref
[TL]   param   8: SRGAN_g/n64s1/b2/0/beta:0 (64,)              float32_ref
[TL]   param   9: SRGAN_g/n64s1/b2/0/gamma:0 (64,)              float32_ref
[TL]   param  10: SRGAN_g/n64s1/b2/0/moving_mean:0 (64,)              float32_ref
[TL]   param  11: SRGAN_g/n64s1/b2/0/moving_variance:0 (64,)              float32_ref
[TL]   param  12: SRGAN_g/n64s1/c1/1/kernel:0 (3, 3, 64, 64)  

[TL]   param 104: SRGAN_g/n64s1/b1/10/gamma:0 (64,)              float32_ref
[TL]   param 105: SRGAN_g/n64s1/b1/10/moving_mean:0 (64,)              float32_ref
[TL]   param 106: SRGAN_g/n64s1/b1/10/moving_variance:0 (64,)              float32_ref
[TL]   param 107: SRGAN_g/n64s1/c2/10/kernel:0 (3, 3, 64, 64)     float32_ref
[TL]   param 108: SRGAN_g/n64s1/b2/10/beta:0 (64,)              float32_ref
[TL]   param 109: SRGAN_g/n64s1/b2/10/gamma:0 (64,)              float32_ref
[TL]   param 110: SRGAN_g/n64s1/b2/10/moving_mean:0 (64,)              float32_ref
[TL]   param 111: SRGAN_g/n64s1/b2/10/moving_variance:0 (64,)              float32_ref
[TL]   param 112: SRGAN_g/n64s1/c1/11/kernel:0 (3, 3, 64, 64)     float32_ref
[TL]   param 113: SRGAN_g/n64s1/b1/11/beta:0 (64,)              float32_ref
[TL]   param 114: SRGAN_g/n64s1/b1/11/gamma:0 (64,)              float32_ref
[TL]   param 115: SRGAN_g/n64s1/b1/11/moving_mean:0 (64,)              float32_ref
[TL]   param 116: SRGAN_g/n64s1/b1/11/

[TL]   layer  36: SRGAN_g/b_residual_add/6:0 (25, 71, 71, 64)    float32
[TL]   layer  37: SRGAN_g/n64s1/c1/7/Conv2D:0 (25, 71, 71, 64)    float32
[TL]   layer  38: SRGAN_g/n64s1/b1/7/Relu:0 (25, 71, 71, 64)    float32
[TL]   layer  39: SRGAN_g/n64s1/c2/7/Conv2D:0 (25, 71, 71, 64)    float32
[TL]   layer  40: SRGAN_g/n64s1/b2/7/batchnorm/Add_1:0 (25, 71, 71, 64)    float32
[TL]   layer  41: SRGAN_g/b_residual_add/7:0 (25, 71, 71, 64)    float32
[TL]   layer  42: SRGAN_g/n64s1/c1/8/Conv2D:0 (25, 71, 71, 64)    float32
[TL]   layer  43: SRGAN_g/n64s1/b1/8/Relu:0 (25, 71, 71, 64)    float32
[TL]   layer  44: SRGAN_g/n64s1/c2/8/Conv2D:0 (25, 71, 71, 64)    float32
[TL]   layer  45: SRGAN_g/n64s1/b2/8/batchnorm/Add_1:0 (25, 71, 71, 64)    float32
[TL]   layer  46: SRGAN_g/b_residual_add/8:0 (25, 71, 71, 64)    float32
[TL]   layer  47: SRGAN_g/n64s1/c1/9/Conv2D:0 (25, 71, 71, 64)    float32
[TL]   layer  48: SRGAN_g/n64s1/b1/9/Relu:0 (25, 71, 71, 64)    float32
[TL]   layer  49: SRGAN_g/n64

[TL]   layer   1: SRGAN_d/h0/c/leaky_relu:0 (25, 142, 142, 64)    float32
[TL]   layer   2: SRGAN_d/h1/c/Conv2D:0 (25, 71, 71, 128)    float32
[TL]   layer   3: SRGAN_d/h1/bn/leaky_relu:0 (25, 71, 71, 128)    float32
[TL]   layer   4: SRGAN_d/h2/c/Conv2D:0 (25, 36, 36, 256)    float32
[TL]   layer   5: SRGAN_d/h2/bn/leaky_relu:0 (25, 36, 36, 256)    float32
[TL]   layer   6: SRGAN_d/h3/c/Conv2D:0 (25, 18, 18, 512)    float32
[TL]   layer   7: SRGAN_d/h3/bn/leaky_relu:0 (25, 18, 18, 512)    float32
[TL]   layer   8: SRGAN_d/h4/c/Conv2D:0 (25, 9, 9, 1024)    float32
[TL]   layer   9: SRGAN_d/h4/bn/leaky_relu:0 (25, 9, 9, 1024)    float32
[TL]   layer  10: SRGAN_d/h5/c/Conv2D:0 (25, 5, 5, 2048)    float32
[TL]   layer  11: SRGAN_d/h5/bn/leaky_relu:0 (25, 5, 5, 2048)    float32
[TL]   layer  12: SRGAN_d/h6/c/Conv2D:0 (25, 5, 5, 1024)    float32
[TL]   layer  13: SRGAN_d/h6/bn/leaky_relu:0 (25, 5, 5, 1024)    float32
[TL]   layer  14: SRGAN_d/h7/c/Conv2D:0 (25, 5, 5, 512)    float32
[TL]   

In [16]:
t_target_image_224 = tf.image.resize_images(t_target_image, size=[224, 224], method=0,align_corners=False)

In [17]:
t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)

In [18]:
net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
_, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

build model started
[TL] InputLayer  VGG19/input: (25, 224, 224, 3)
[TL] Conv2d VGG19/conv1_1: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d VGG19/conv1_2: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] MaxPool2d VGG19/pool1: filter_size: (2, 2) strides: (2, 2) padding: SAME
Instructions for updating:
Use keras.layers.max_pooling2d instead.
[TL] Conv2d VGG19/conv2_1: n_filter: 128 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d VGG19/conv2_2: n_filter: 128 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] MaxPool2d VGG19/pool2: filter_size: (2, 2) strides: (2, 2) padding: SAME
[TL] Conv2d VGG19/conv3_1: n_filter: 256 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d VGG19/conv3_2: n_filter: 256 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d VGG19/conv3_3: n_filter: 256 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d VGG19/conv3_4: n_fil

In [19]:
## test inference
net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

[TL] InputLayer  SRGAN_g/in: (25, 71, 71, 3)
[TL] Conv2d SRGAN_g/n64s1/c: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d SRGAN_g/n64s1/c1/0: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b1/0: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d SRGAN_g/n64s1/c2/0: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b2/0: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] ElementwiseLayer SRGAN_g/b_residual_add/0: size: (25, 71, 71, 64) fn: add
[TL] Conv2d SRGAN_g/n64s1/c1/1: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b1/1: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d SRGAN_g/n64s1/c2/1: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n

[TL] ElementwiseLayer SRGAN_g/b_residual_add/15: size: (25, 71, 71, 64) fn: add
[TL] Conv2d SRGAN_g/n64s1/c/m: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNormLayer SRGAN_g/n64s1/b/m: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] ElementwiseLayer SRGAN_g/add3: size: (25, 71, 71, 64) fn: add
[TL] Conv2d SRGAN_g/n256s1/1: n_filter: 256 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] SubpixelConv2d  SRGAN_g/pixelshufflerx2/1: scale: 2 n_out_channel: 64 act: relu
[TL] Conv2d SRGAN_g/n256s1/2: n_filter: 256 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] SubpixelConv2d  SRGAN_g/pixelshufflerx2/2: scale: 2 n_out_channel: 64 act: relu
[TL] Conv2d SRGAN_g/out: n_filter: 3 filter_size: (1, 1) strides: (1, 1) pad: SAME act: tanh


In [20]:
# ###========================== DEFINE TRAIN OPS ==========================###
d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
d_loss = d_loss1 + d_loss2

g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

g_loss = mse_loss + vgg_loss + g_gan_loss

g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

with tf.variable_scope('learning_rate'):
    lr_v = tf.Variable(lr_init, trainable=False)

[TL]   [*] geting variables with SRGAN_g
[TL]   got   0: SRGAN_g/n64s1/c/kernel:0   (3, 3, 3, 64)
[TL]   got   1: SRGAN_g/n64s1/c/bias:0   (64,)
[TL]   got   2: SRGAN_g/n64s1/c1/0/kernel:0   (3, 3, 64, 64)
[TL]   got   3: SRGAN_g/n64s1/b1/0/beta:0   (64,)
[TL]   got   4: SRGAN_g/n64s1/b1/0/gamma:0   (64,)
[TL]   got   5: SRGAN_g/n64s1/c2/0/kernel:0   (3, 3, 64, 64)
[TL]   got   6: SRGAN_g/n64s1/b2/0/beta:0   (64,)
[TL]   got   7: SRGAN_g/n64s1/b2/0/gamma:0   (64,)
[TL]   got   8: SRGAN_g/n64s1/c1/1/kernel:0   (3, 3, 64, 64)
[TL]   got   9: SRGAN_g/n64s1/b1/1/beta:0   (64,)
[TL]   got  10: SRGAN_g/n64s1/b1/1/gamma:0   (64,)
[TL]   got  11: SRGAN_g/n64s1/c2/1/kernel:0   (3, 3, 64, 64)
[TL]   got  12: SRGAN_g/n64s1/b2/1/beta:0   (64,)
[TL]   got  13: SRGAN_g/n64s1/b2/1/gamma:0   (64,)
[TL]   got  14: SRGAN_g/n64s1/c1/2/kernel:0   (3, 3, 64, 64)
[TL]   got  15: SRGAN_g/n64s1/b1/2/beta:0   (64,)
[TL]   got  16: SRGAN_g/n64s1/b1/2/gamma:0   (64,)
[TL]   got  17: SRGAN_g/n64s1/c2/2/kernel:0  

In [21]:
g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
## SRGAN
g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

Instructions for updating:
Use tf.cast instead.


In [22]:
###========================== RESTORE MODEL =============================###
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
tl.layers.initialize_global_variables(sess)
if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format("srgan"), network=net_g) is False:
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format("srgan"), network=net_g)
tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format("srgan"), network=net_d)

Instructions for updating: This API is deprecated in favor of `tf.global_variables_initializer`

[TL] ERROR: file checkpoint/g_srgan.npz doesn't exist.
[TL] ERROR: file checkpoint/d_srgan.npz doesn't exist.


In [23]:
###============================= LOAD VGG ===============================###
vgg19_npy_path = "../vgg19.npy"
if not os.path.isfile(vgg19_npy_path):
    print("Please download vgg19.npz from : https://github.com/machrisaa/tensorflow-vgg")
    exit()
npz = np.load(vgg19_npy_path, encoding='latin1').item()

params = []
for val in sorted(npz.items()):
    W = np.asarray(val[1][0])
    b = np.asarray(val[1][1])
    print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
    params.extend([W, b])
tl.files.assign_params(sess, params, net_vgg)
# net_vgg.print_params(False)
# net_vgg.print_layers()

  Loading conv1_1: (3, 3, 3, 64), (64,)
  Loading conv1_2: (3, 3, 64, 64), (64,)
  Loading conv2_1: (3, 3, 64, 128), (128,)
  Loading conv2_2: (3, 3, 128, 128), (128,)
  Loading conv3_1: (3, 3, 128, 256), (256,)
  Loading conv3_2: (3, 3, 256, 256), (256,)
  Loading conv3_3: (3, 3, 256, 256), (256,)
  Loading conv3_4: (3, 3, 256, 256), (256,)
  Loading conv4_1: (3, 3, 256, 512), (512,)
  Loading conv4_2: (3, 3, 512, 512), (512,)
  Loading conv4_3: (3, 3, 512, 512), (512,)
  Loading conv4_4: (3, 3, 512, 512), (512,)
  Loading conv5_1: (3, 3, 512, 512), (512,)
  Loading conv5_2: (3, 3, 512, 512), (512,)
  Loading conv5_3: (3, 3, 512, 512), (512,)
  Loading conv5_4: (3, 3, 512, 512), (512,)
  Loading fc6: (25088, 4096), (4096,)
  Loading fc7: (4096, 4096), (4096,)
  Loading fc8: (4096, 1000), (1000,)


[<tf.Tensor 'Assign:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Tensor 'Assign_1:0' shape=(64,) dtype=float32_ref>,
 <tf.Tensor 'Assign_2:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Tensor 'Assign_3:0' shape=(64,) dtype=float32_ref>,
 <tf.Tensor 'Assign_4:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
 <tf.Tensor 'Assign_5:0' shape=(128,) dtype=float32_ref>,
 <tf.Tensor 'Assign_6:0' shape=(3, 3, 128, 128) dtype=float32_ref>,
 <tf.Tensor 'Assign_7:0' shape=(128,) dtype=float32_ref>,
 <tf.Tensor 'Assign_8:0' shape=(3, 3, 128, 256) dtype=float32_ref>,
 <tf.Tensor 'Assign_9:0' shape=(256,) dtype=float32_ref>,
 <tf.Tensor 'Assign_10:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Tensor 'Assign_11:0' shape=(256,) dtype=float32_ref>,
 <tf.Tensor 'Assign_12:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Tensor 'Assign_13:0' shape=(256,) dtype=float32_ref>,
 <tf.Tensor 'Assign_14:0' shape=(3, 3, 256, 256) dtype=float32_ref>,
 <tf.Tensor 'Assign_15:0' shape=(256,) dtype=float32_re

In [24]:
def train_hr_imgs_datagen(train_hr_imgs,batch_size):
    while True:
        idx = np.random.randint(0,len(train_hr_imgs),batch_size)
        yield [train_hr_imgs[x] for x in idx]

In [25]:
sess.run(tf.assign(lr_v, lr_init))
print(" ** fixed learning rate: %f (for init G)" % lr_init)
datagen = train_hr_imgs_datagen(train_hr_imgs,batch_size)
for epoch in range(0, n_epoch_init + 1):
    epoch_time = time.time()
    total_mse_loss, n_iter = 0, 0

    ## If your machine cannot load all images into memory, you should use
    ## this one to load batch of images while training.
    # random.shuffle(train_hr_img_list)
    # for idx in range(0, len(train_hr_img_list), batch_size):
    #     step_time = time.time()
    #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
    #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
    #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
    #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

    ## If your machine have enough memory, please pre-load the whole train set.
    random.shuffle(train_hr_img_list)
    for idx in range(0, len(train_hr_imgs), batch_size):
        step_time = time.time()
        '''        
        b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, dim=cropDim, is_random=True)
        b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn,dim=int(cropDim/4))'''
        b_imgs_384 = tl.prepro.threading_data(next(datagen), fn=crop_sub_imgs_fn,dim=cropDim, is_random=True)
        b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn,dim=int(cropDim/4))
        ## update G
        errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
        print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
        total_mse_loss += errM
        n_iter += 1
    log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
    print(log)

    ## quick evaluation on train set
    if (epoch != 0) and (epoch % 10 == 0):
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
        print("[*] save images")
        tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)

    ## save model
    if (epoch != 0) and (epoch % 10 == 0):
        tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init.npz'.format("train"), sess=sess)


 ** fixed learning rate: 0.000100 (for init G)
Epoch [ 0/50]    0 time: 7.5343s, mse: 0.25088289 
Epoch [ 0/50]    1 time: 1.2858s, mse: 0.29249138 
Epoch [ 0/50]    2 time: 1.2861s, mse: 0.26170868 
Epoch [ 0/50]    3 time: 1.2875s, mse: 0.24054869 
Epoch [ 0/50]    4 time: 1.2855s, mse: 0.20309448 
[*] Epoch: [ 0/50] time: 12.6798s, mse: 0.24974522
Epoch [ 1/50]    0 time: 1.2849s, mse: 0.27957079 
Epoch [ 1/50]    1 time: 1.2913s, mse: 0.27528676 
Epoch [ 1/50]    2 time: 1.2868s, mse: 0.21564987 
Epoch [ 1/50]    3 time: 1.2829s, mse: 0.21859474 
Epoch [ 1/50]    4 time: 1.2870s, mse: 0.22136886 
[*] Epoch: [ 1/50] time: 6.4339s, mse: 0.24209421
Epoch [ 2/50]    0 time: 1.2870s, mse: 0.22367394 
Epoch [ 2/50]    1 time: 1.2845s, mse: 0.18373340 
Epoch [ 2/50]    2 time: 1.2897s, mse: 0.17565523 
Epoch [ 2/50]    3 time: 1.2885s, mse: 0.18766643 
Epoch [ 2/50]    4 time: 1.2883s, mse: 0.18341057 
[*] Epoch: [ 2/50] time: 6.4387s, mse: 0.19082792
Epoch [ 3/50]    0 time: 1.2844s, mse

Epoch [26/50]    1 time: 1.3194s, mse: 0.03321599 
Epoch [26/50]    2 time: 1.3252s, mse: 0.03772750 
Epoch [26/50]    3 time: 1.3178s, mse: 0.03430004 
Epoch [26/50]    4 time: 1.3299s, mse: 0.04560937 
[*] Epoch: [26/50] time: 6.6188s, mse: 0.03918353
Epoch [27/50]    0 time: 1.3279s, mse: 0.04544378 
Epoch [27/50]    1 time: 1.3241s, mse: 0.04976065 
Epoch [27/50]    2 time: 1.3263s, mse: 0.03320111 
Epoch [27/50]    3 time: 1.3239s, mse: 0.02935722 
Epoch [27/50]    4 time: 1.3201s, mse: 0.03793742 
[*] Epoch: [27/50] time: 6.6229s, mse: 0.03914003
Epoch [28/50]    0 time: 1.3261s, mse: 0.03565137 
Epoch [28/50]    1 time: 1.3212s, mse: 0.03629584 
Epoch [28/50]    2 time: 1.3243s, mse: 0.03781015 
Epoch [28/50]    3 time: 1.3255s, mse: 0.04568729 
Epoch [28/50]    4 time: 1.3250s, mse: 0.03955882 
[*] Epoch: [28/50] time: 6.6228s, mse: 0.03900070
Epoch [29/50]    0 time: 1.3168s, mse: 0.03494084 
Epoch [29/50]    1 time: 1.3241s, mse: 0.02992238 
Epoch [29/50]    2 time: 1.3269s, 

In [26]:
###========================= train GAN (SRGAN) =========================###
for epoch in range(0, n_epoch + 1):
    ## update learning rate
    if epoch != 0 and (epoch % decay_every == 0):
        new_lr_decay = lr_decay**(epoch // decay_every)
        sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
        log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
        print(log)
    elif epoch == 0:
        sess.run(tf.assign(lr_v, lr_init))
        log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
        print(log)

    epoch_time = time.time()
    total_d_loss, total_g_loss, n_iter = 0, 0, 0

    ## If your machine cannot load all images into memory, you should use
    ## this one to load batch of images while training.
    # random.shuffle(train_hr_img_list)
    # for idx in range(0, len(train_hr_img_list), batch_size):
    #     step_time = time.time()
    #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
    #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
    #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
    #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

    ## If your machine have enough memory, please pre-load the whole train set.
    for idx in range(0, len(train_hr_imgs), batch_size):
        step_time = time.time()
        b_imgs_384 = tl.prepro.threading_data(next(datagen), fn=crop_sub_imgs_fn,dim = cropDim, is_random=True)
        b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn,dim=int(cropDim/4))
        ## update D
        errD, _ = sess.run([d_loss, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
        ## update G
        errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
        print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
              (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
        total_d_loss += errD
        total_g_loss += errG
        n_iter += 1

    log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
                                                                            total_g_loss / n_iter)
    print(log)

    ## quick evaluation on train set
    if (epoch != 0) and (epoch % 10 == 0):
        out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
        print("[*] save images")
        tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)

    ## save model
    if (epoch != 0) and (epoch % 10 == 0):
        tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format("train"), sess=sess)
        tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format("train"), sess=sess)



 ** init lr: 0.000100  decay_every_init: 25, lr_decay: 0.010000 (for GAN)
Epoch [ 0/50]    0 time: 30.6684s, d_loss: 2.31000900 g_loss: 0.12472424 (mse: 0.028939 vgg: 0.084180 adv: 0.011606)
Epoch [ 0/50]    1 time: 4.2152s, d_loss: 2.59656048 g_loss: 0.11696380 (mse: 0.037786 vgg: 0.076921 adv: 0.002256)
Epoch [ 0/50]    2 time: 4.2310s, d_loss: 2.69071388 g_loss: 0.11864414 (mse: 0.029488 vgg: 0.080540 adv: 0.008616)
Epoch [ 0/50]    3 time: 4.2799s, d_loss: 2.39786530 g_loss: 0.11202874 (mse: 0.028076 vgg: 0.080614 adv: 0.003339)
Epoch [ 0/50]    4 time: 4.2314s, d_loss: 2.01488829 g_loss: 0.10545665 (mse: 0.028081 vgg: 0.074797 adv: 0.002579)
[*] Epoch: [ 0/50] time: 47.6265s, d_loss: 2.40200739 g_loss: 0.11556351
Epoch [ 1/50]    0 time: 4.2443s, d_loss: 1.94235516 g_loss: 0.11480162 (mse: 0.033344 vgg: 0.079797 adv: 0.001660)
Epoch [ 1/50]    1 time: 4.2471s, d_loss: 1.95841336 g_loss: 0.11257444 (mse: 0.032669 vgg: 0.077319 adv: 0.002586)
Epoch [ 1/50]    2 time: 4.4252s, d_loss

Epoch [12/50]    2 time: 4.3487s, d_loss: 1.64342451 g_loss: 0.08770262 (mse: 0.024631 vgg: 0.061515 adv: 0.001556)
Epoch [12/50]    3 time: 4.5188s, d_loss: 1.67053366 g_loss: 0.07927490 (mse: 0.023674 vgg: 0.053970 adv: 0.001631)
Epoch [12/50]    4 time: 4.2857s, d_loss: 1.29154813 g_loss: 0.09529214 (mse: 0.029527 vgg: 0.062385 adv: 0.003380)
[*] Epoch: [12/50] time: 21.7405s, d_loss: 1.64894598 g_loss: 0.08981529
Epoch [13/50]    0 time: 4.2433s, d_loss: 0.92347705 g_loss: 0.07892036 (mse: 0.022265 vgg: 0.052997 adv: 0.003659)
Epoch [13/50]    1 time: 4.3089s, d_loss: 1.10077953 g_loss: 0.10127513 (mse: 0.031571 vgg: 0.065865 adv: 0.003839)
Epoch [13/50]    2 time: 4.5384s, d_loss: 1.05681586 g_loss: 0.09296098 (mse: 0.031019 vgg: 0.058538 adv: 0.003403)
Epoch [13/50]    3 time: 4.2236s, d_loss: 1.42791426 g_loss: 0.08357523 (mse: 0.024290 vgg: 0.057203 adv: 0.002082)
Epoch [13/50]    4 time: 4.3536s, d_loss: 1.38350391 g_loss: 0.07987764 (mse: 0.025842 vgg: 0.052666 adv: 0.001370)

Epoch [24/50]    4 time: 4.4191s, d_loss: 0.82270873 g_loss: 0.07850844 (mse: 0.023829 vgg: 0.051810 adv: 0.002869)
[*] Epoch: [24/50] time: 22.1455s, d_loss: 0.66151026 g_loss: 0.08433631
 ** new learning rate: 0.000001 (for GAN)
Epoch [25/50]    0 time: 4.2333s, d_loss: 1.10484350 g_loss: 0.08964056 (mse: 0.034479 vgg: 0.053519 adv: 0.001643)
Epoch [25/50]    1 time: 4.4932s, d_loss: 0.85400772 g_loss: 0.08094756 (mse: 0.026793 vgg: 0.052459 adv: 0.001696)
Epoch [25/50]    2 time: 4.3838s, d_loss: 1.25939345 g_loss: 0.09107750 (mse: 0.033683 vgg: 0.055296 adv: 0.002098)
Epoch [25/50]    3 time: 4.3269s, d_loss: 1.18149984 g_loss: 0.08254436 (mse: 0.025601 vgg: 0.054773 adv: 0.002171)
Epoch [25/50]    4 time: 4.5105s, d_loss: 1.12531614 g_loss: 0.07449906 (mse: 0.023002 vgg: 0.049062 adv: 0.002435)
[*] Epoch: [25/50] time: 21.9483s, d_loss: 1.10501213 g_loss: 0.08374181
Epoch [26/50]    0 time: 4.6492s, d_loss: 0.77662289 g_loss: 0.07778311 (mse: 0.023158 vgg: 0.052100 adv: 0.002525)


Epoch [37/50]    0 time: 4.2308s, d_loss: 0.38068652 g_loss: 0.07782017 (mse: 0.024216 vgg: 0.050325 adv: 0.003279)
Epoch [37/50]    1 time: 4.5573s, d_loss: 0.12864280 g_loss: 0.08227716 (mse: 0.030515 vgg: 0.047832 adv: 0.003930)
Epoch [37/50]    2 time: 4.3086s, d_loss: 0.18234748 g_loss: 0.07273778 (mse: 0.023001 vgg: 0.045014 adv: 0.004724)
Epoch [37/50]    3 time: 4.3493s, d_loss: 0.24420893 g_loss: 0.08335929 (mse: 0.024468 vgg: 0.055286 adv: 0.003606)
Epoch [37/50]    4 time: 4.2088s, d_loss: 0.22148442 g_loss: 0.08959463 (mse: 0.033266 vgg: 0.051748 adv: 0.004580)
[*] Epoch: [37/50] time: 21.6557s, d_loss: 0.23147403 g_loss: 0.08115781
Epoch [38/50]    0 time: 4.5414s, d_loss: 0.52099806 g_loss: 0.08177376 (mse: 0.029637 vgg: 0.048223 adv: 0.003914)
Epoch [38/50]    1 time: 4.4494s, d_loss: 0.31779876 g_loss: 0.08623774 (mse: 0.026088 vgg: 0.055667 adv: 0.004483)
Epoch [38/50]    2 time: 4.4109s, d_loss: 0.20356175 g_loss: 0.08276644 (mse: 0.027581 vgg: 0.050224 adv: 0.004961)

Epoch [49/50]    2 time: 4.2497s, d_loss: 0.29090148 g_loss: 0.07398535 (mse: 0.024474 vgg: 0.045762 adv: 0.003749)
Epoch [49/50]    3 time: 4.4431s, d_loss: 0.19192170 g_loss: 0.07623458 (mse: 0.025755 vgg: 0.047149 adv: 0.003331)
Epoch [49/50]    4 time: 4.2710s, d_loss: 0.17504884 g_loss: 0.08532988 (mse: 0.028574 vgg: 0.052674 adv: 0.004082)
[*] Epoch: [49/50] time: 21.9679s, d_loss: 0.30584143 g_loss: 0.07824300
 ** new learning rate: 0.000000 (for GAN)
Epoch [50/50]    0 time: 4.2559s, d_loss: 0.22437957 g_loss: 0.08734860 (mse: 0.026028 vgg: 0.057979 adv: 0.003341)
Epoch [50/50]    1 time: 4.3485s, d_loss: 0.26060623 g_loss: 0.08253016 (mse: 0.024038 vgg: 0.055287 adv: 0.003205)
Epoch [50/50]    2 time: 4.3966s, d_loss: 0.46992776 g_loss: 0.07427488 (mse: 0.023521 vgg: 0.045362 adv: 0.005392)
Epoch [50/50]    3 time: 4.6597s, d_loss: 0.31103992 g_loss: 0.08336151 (mse: 0.022181 vgg: 0.057006 adv: 0.004175)
Epoch [50/50]    4 time: 4.2159s, d_loss: 0.13414802 g_loss: 0.08380102 (