In [1]:
import numpy as np
from keras.models import Model
from tensorflow.python.keras.models import load_model
from keras.layers import Input, BatchNormalization, Dense, Add, Activation, Reshape, Permute, Flatten, Conv2DTranspose
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.optimizers import Adam
from keras.applications import VGG19
import datetime
import random
from PIL import Image
import glob
import cv2
import os
import matplotlib.pyplot as plt


# Super-resolution factor and low-resolution input shape (height, width, channels).
ratio = 4
LR_shape = (120, 160, 3)

# The high-resolution shape is the LR shape scaled by `ratio` along both
# spatial axes; the channel count is unchanged.
L_h, L_w, channels = LR_shape
H_h, H_w = L_h * ratio, L_w * ratio
HR_shape = (H_h, H_w, channels)

# Default-configured Adam optimizer.
optimizer = Adam()

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def load_data(batch_size, pattern = "images/train/*.png"):
    """Sample a random batch of training images at HR and LR resolution.

    Args:
        batch_size: number of images to draw (sampled without replacement).
        pattern: glob pattern selecting the training images; the default is
            the location this notebook has always used.

    Returns:
        (hr_imgs, lr_imgs): float arrays scaled from [0, 255] to [-1, 1],
        shaped (batch_size, H_h, H_w, 3) and (batch_size, L_h, L_w, 3).

    Raises:
        ValueError: if fewer than `batch_size` files match `pattern`.
    """
    files = glob.glob(pattern, recursive = True)
    if len(files) < batch_size:
        raise ValueError(
            "need %d images matching %r, found %d" % (batch_size, pattern, len(files)))
    batch_images = random.sample(files, batch_size)

    hr_imgs = []
    lr_imgs = []
    for img_path in batch_images:
        # The context manager closes the underlying file handle. convert("RGB")
        # drops alpha / expands grayscale and palette images so every sample
        # has the 3 channels the models are built for.
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")

        # PIL's resize takes (width, height); the arrays come out (height, width, 3).
        hr_imgs.append(np.array(rgb.resize((H_w, H_h))))
        lr_imgs.append(np.array(rgb.resize((L_w, L_h))))

    hr_imgs = np.array(hr_imgs) / 127.5 - 1.
    lr_imgs = np.array(lr_imgs) / 127.5 - 1.

    return hr_imgs, lr_imgs
      

In [3]:

def pixel_shuffle(in_map, h, w, c, scale = 2):
    """Sub-pixel rearrangement (depth-to-space) of a feature map.

    Turns an (h, w, scale*scale*c) map into a (scale*h, scale*w, c) map. The
    scale x scale sub-pixels produced at input position (i, j) land at output
    rows i*scale .. i*scale+scale-1 and columns j*scale .. j*scale+scale-1.

    Args:
        in_map: Keras tensor with spatial size (h, w) and scale*scale*c channels.
        h, w: spatial size of `in_map`.
        c: number of output channels.
        scale: upsampling factor per spatial axis (default 2, as before).

    Returns:
        Keras tensor of shape (scale*h, scale*w, c).
    """
    x = Reshape((h, w, scale, scale, c))(in_map)
    # Axes are now (h, w, sub_row, sub_col, c); Permute indices are 1-based and
    # exclude the batch axis. Reorder to (h, sub_row, w, sub_col, c) so that
    # flattening (h, sub_row) gives output row i*scale + sub_row, i.e. the
    # sub-pixels are interleaved. The previous Permute((3, 1, 4, 2, 5)) gave
    # (sub_row, h, sub_col, w, c), which tiles scale*scale shifted low-res copies
    # as quadrants instead of interleaving them.
    x = Permute((1, 3, 2, 4, 5))(x)
    out_map = Reshape((scale * h, scale * w, c))(x)

    return out_map


def upsampling(in_map, h, w, c):
    """x2 sub-pixel upsampling stage: 3x3 conv to 4*c channels, pixel shuffle, PReLU.

    Args:
        in_map: Keras tensor with spatial size (h, w).
        h, w: spatial size of `in_map`.
        c: number of output channels.

    Returns:
        Keras tensor of shape (2*h, 2*w, c).
    """
    x = Conv2D(filters = 4 * c,
               kernel_size = 3,
               strides = 1,
               padding = "same")(in_map)
    x = pixel_shuffle(x, h, w, c)
    # Share the PReLU slope across both spatial axes so there is one learnable
    # alpha per channel. With shared_axes=None, Keras would learn one alpha per
    # (row, column, channel), which multiplies parameters by 4*h*w and ties the
    # layer to one fixed spatial size.
    out_map = PReLU(shared_axes = [1, 2])(x)

    return out_map


def residual_block(in_map):
    """Residual unit: conv -> ReLU -> BN -> conv -> BN, added back onto the input.

    Both convolutions are 3x3 with 64 filters and 'same' padding, so `in_map`
    must already carry 64 channels for the final Add to be valid.
    """
    conv_kwargs = dict(filters = 64, kernel_size = 3, strides = 1, padding = "same")

    branch = Conv2D(**conv_kwargs)(in_map)
    branch = LeakyReLU(alpha = 0)(branch)  # alpha=0: behaves as a plain ReLU
    branch = BatchNormalization()(branch)
    branch = Conv2D(**conv_kwargs)(branch)
    branch = BatchNormalization()(branch)

    # Identity skip connection.
    return Add()([branch, in_map])


def d_block(in_map, filters, kernel_size, strides, padding):
    """Discriminator stage: convolution -> LeakyReLU(0.2) -> BatchNorm(momentum=0.8).

    All convolution settings are forwarded unchanged to Conv2D.
    """
    conv = Conv2D(filters = filters, kernel_size = kernel_size,
                  strides = strides, padding = padding)
    out = conv(in_map)
    out = LeakyReLU(alpha = 0.2)(out)
    return BatchNormalization(momentum = 0.8)(out)


def deconv2d(layer_input, filters = 120, kernel_size = 3):
    """x2 upsampling stage: stride-2 transposed convolution followed by ReLU.

    Args:
        layer_input: Keras tensor to upsample.
        filters: number of output channels. Defaults to 120, the value that was
            previously hard-coded.
        kernel_size: transposed-convolution kernel size (default 3, as before).

    Returns:
        Keras tensor with doubled height and width (stride 2, 'same' padding)
        and `filters` channels.
    """
    u = Conv2DTranspose(filters, kernel_size = kernel_size,
                        strides = 2, padding = 'same')(layer_input)
    u = Activation('relu')(u)
    return u



In [4]:
def build_generator():
    """Build the generator mapping an LR image (LR_shape) to an upscaled image.

    Layout:
    - 9x9/64 conv stem with ReLU (LeakyReLU with alpha=0);
    - 5 residual blocks;
    - 3x3 conv + BN with a long skip connection back to the stem;
    - one stride-2 transposed-conv stage per factor of 2 in `ratio`;
    - a final 9x9 conv to 3 channels.

    Raises:
        ValueError: if `ratio` is not a positive power of two. Only powers of
            two can be realised by the x2 stages. Any other value previously
            either looped forever (ratio == 0) or silently produced an output
            smaller than HR_shape.
    """
    # Validate the scale factor before creating any layers.
    n = ratio
    if n < 1:
        raise ValueError("ratio must be a positive power of two, got %r" % (ratio,))
    n_upsample = 0
    while n % 2 == 0:
        n = n // 2
        n_upsample += 1
    if n != 1:
        raise ValueError("ratio must be a positive power of two, got %r" % (ratio,))

    input_img = Input(shape = LR_shape)
    middle = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(input_img)
    middle = LeakyReLU(alpha = 0)(middle)  # alpha=0: behaves as a plain ReLU

    g = residual_block(middle)
    for _ in range(4):
        g = residual_block(g)

    g = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(g)
    g = BatchNormalization()(g)
    g = Add()([g, middle])  # long skip connection around the residual trunk

    # Each deconv2d stage doubles height and width.
    for _ in range(n_upsample):
        g = deconv2d(g)

    # NOTE(review): no output activation, so values are unbounded even though
    # load_data scales training images to [-1, 1]; consider tanh if intended.
    output_img = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(g)

    return Model(input_img, output_img)


def build_discriminator():
    """Build the patch discriminator for HR_shape images.

    A 3x3/64 conv + LeakyReLU stem is followed by seven d_block stages. Four
    of those stages have stride 2, so the spatial size shrinks by 16. There is
    deliberately no Flatten: Dense layers act on the last axis of the 4-D map,
    so the output is one sigmoid score per spatial patch rather than a single
    scalar per image.
    """
    hr_input = Input(shape = HR_shape)

    features = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(hr_input)
    features = LeakyReLU(alpha = 0.2)(features)

    # (filters, strides) for each successive d_block; the filter count
    # doubles after every stride-2 stage.
    block_specs = [
        (64, 2),
        (128, 1), (128, 2),
        (256, 1), (256, 2),
        (512, 1), (512, 2),
    ]
    for n_filters, stride in block_specs:
        features = d_block(features, filters = n_filters, kernel_size = 3,
                           strides = stride, padding = "same")

    features = Dense(512)(features)
    features = LeakyReLU(alpha = 0.2)(features)
    patch_scores = Dense(1, activation = "sigmoid")(features)

    return Model(hr_input, patch_scores)


def build_vgg(layer_name = "block3_conv3"):
    """Build a VGG19 feature extractor for the perceptual (content) loss.

    Args:
        layer_name: name of the VGG19 layer whose activations are returned.
            The default, block3_conv3, is the layer the previous code selected
            positionally with `vgg.layers[9]`. Looking it up by name stays
            correct even if the layer list gains or loses entries.

    Returns:
        Model mapping an image to the activations of `layer_name` in a VGG19
        built without its classifier head (include_top=False) and with the
        default pretrained weights.
    """
    vgg = VGG19(include_top = False)
    return Model(vgg.input, vgg.get_layer(layer_name).output)
    

def combined(generator, discriminator, vgg):
    """Chain generator, discriminator and VGG into one model on LR inputs.

    The generated image is scored by `discriminator` and embedded by `vgg`.
    The two outputs [validity, features] line up, in that order, with the two
    losses passed to srgan.compile.
    """
    lr_input = Input(shape = LR_shape)
    sr_image = generator(lr_input)
    return Model(lr_input, [discriminator(sr_image), vgg(sr_image)])

In [5]:
# Loss history filled in by train() every `interval` steps: `losses` holds
# (d_loss, g_loss) pairs and `epochs_checkpoint` the matching step numbers.
losses = []
epochs_checkpoint = []

def train(epochs, batch_size, interval):
    """Alternate discriminator and generator updates.

    Each "epoch" is a single mini-batch step on a fresh random batch from
    load_data, not a full pass over the dataset. Uses the module-level
    `generator`, `discriminator`, `vgg` and `srgan` models, which must be
    built (and discriminator / srgan compiled) beforehand.

    Args:
        epochs: number of training steps.
        batch_size: images per step.
        interval: every `interval` steps, record losses and save the generator.
            The generator is saved both to a per-checkpoint file in the dated
            weights directory and to the fixed path weights/weightname.h5.
    """
    start_time = datetime.datetime.now()
    run_tag = start_time.strftime('%m%d')
    weights_dir = "weights/" + run_tag
    os.makedirs(weights_dir, exist_ok = True)
    os.makedirs("images/train_result/" + run_tag, exist_ok = True)

    # The label maps must match the discriminator's patch-shaped output, so
    # take the shape from the model instead of hard-coding the /16 downsampling.
    patch_shape = tuple(discriminator.output_shape[1:])
    real = np.ones((batch_size,) + patch_shape)
    fake = np.zeros((batch_size,) + patch_shape)

    for epoch in range(epochs):
        step = epoch + 1
        real_imgs, lr_imgs = load_data(batch_size)
        fake_imgs = generator.predict(lr_imgs)

        # Train D: real HR images -> 1, generated images -> 0.
        d_loss_real = discriminator.train_on_batch(real_imgs, real)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train G through the combined model. The targets are "real" for the
        # discriminator output and the VGG features of the real HR images.
        vgg_features = vgg.predict(real_imgs)
        g_loss = srgan.train_on_batch(lr_imgs, [real, vgg_features])

        elapsed = datetime.datetime.now() - start_time
        print("%d time: %s" % (step, elapsed))

        if step % interval == 0:
            losses.append((d_loss, g_loss))
            epochs_checkpoint.append(step)
            # Keep every checkpoint in the dated directory created above, which
            # was previously created but never written to...
            generator.save(os.path.join(weights_dir, "generator_%d.h5" % step))
            # ...and keep refreshing the original fixed path for existing loaders.
            generator.save("weights/weightname.h5")
            print("save weights")
    

In [6]:
# Build and compile the discriminator as a standalone model so train() can
# update it with train_on_batch. It is compiled here, before its `trainable`
# flag is switched off for the combined model, so this compiled model still
# updates its weights.
discriminator = build_discriminator()
discriminator.compile(optimizer = optimizer,
                      loss = "mse",
                      metrics = ["accuracy"])
discriminator.summary()

# The generator is not compiled on its own; it is trained via the combined model.
generator = build_generator()
generator.summary()





_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 480, 640, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 480, 640, 64)      1792      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 480, 640, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 240, 320, 64)      36928     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 240, 320, 64)      0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 240, 320, 64)      256       
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 240, 320, 128)     73856     
______

In [7]:
# VGG is only a fixed feature extractor for the content loss.
vgg = build_vgg()
vgg.trainable = False
# Freeze D before compiling the combined model so srgan updates only G; the
# discriminator's own model, compiled earlier, still updates D.
discriminator.trainable = False
srgan = combined(generator, discriminator, vgg)
# Give the combined model its own Adam instance. The shared `optimizer` is
# already bound to the discriminator's compiled model, and reusing one
# optimizer object across two compiled models shares its internal state
# (in standalone Keras, e.g. the iteration counter used for bias correction).
srgan.compile(loss=['binary_crossentropy', 'mse'],
              loss_weights=[1e-3, 1],
              optimizer=Adam())
srgan.summary()




Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_4 (InputLayer)             (None, 120, 160, 3)   0                                            
____________________________________________________________________________________________________
model_2 (Model)                  (None, 480, 640, 3)   652763                                       
____________________________________________________________________________________________________
model_1 (Model)                  (None, 30, 40, 1)     4955969                                      
____________________________________________________________________________________________________
model_3 (Model)                  multiple              1735488                                 

In [8]:
# 3000 single-batch steps of 2 images, checkpointing every 500 steps.
train(epochs=3000, batch_size=2, interval=500)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
1 time: 0:00:14.319958
2 time: 0:00:15.245596
3 time: 0:00:16.151572
4 time: 0:00:17.076362
5 time: 0:00:17.993367
6 time: 0:00:18.910634
7 time: 0:00:19.834302
8 time: 0:00:20.742789
9 time: 0:00:21.667795
10 time: 0:00:22.576198
11 time: 0:00:23.501771
12 time: 0:00:24.410169
13 time: 0:00:25.334094
14 time: 0:00:26.241984
15 time: 0:00:27.166915
16 time: 0:00:28.091826
17 time: 0:00:29.013045
18 time: 0:00:29.934432
19 time: 0:00:30.846483
20 time: 0:00:31.770580
21 time: 0:00:32.677463
22 time: 0:00:33.600403
23 time: 0:00:34.505856
24 time: 0:00:35.429263
25 time: 0:00:36.334662
26 time: 0:00:37.258591
27 time: 0:00:38.164239
28 time: 0:00:39.086945
29 time: 0:00:40.002121
30 time: 0:00:40.920889
31 time: 0:00:41.846450
32 time: 0:00:42.759152
33 time: 0:00:43.682607
34 time: 0:00:44.592165
35 time: 0:00:45.517601
36 time: 0:00:46.422477
37 time: 0:00:47.346959
38 

318 time: 0:05:06.201180
319 time: 0:05:07.131128
320 time: 0:05:08.048508
321 time: 0:05:08.966799
322 time: 0:05:09.895427
323 time: 0:05:10.807642
324 time: 0:05:11.735083
325 time: 0:05:12.655675
326 time: 0:05:13.588499
327 time: 0:05:14.506746
328 time: 0:05:15.442352
329 time: 0:05:16.358406
330 time: 0:05:17.291405
331 time: 0:05:18.201843
332 time: 0:05:19.133673
333 time: 0:05:20.055416
334 time: 0:05:20.988353
335 time: 0:05:21.932484
336 time: 0:05:22.867448
337 time: 0:05:23.811323
338 time: 0:05:24.739853
339 time: 0:05:25.687799
340 time: 0:05:26.618769
341 time: 0:05:27.565661
342 time: 0:05:28.495558
343 time: 0:05:29.442586
344 time: 0:05:30.374029
345 time: 0:05:31.322373
346 time: 0:05:32.247838
347 time: 0:05:33.196130
348 time: 0:05:34.125210
349 time: 0:05:35.068678
350 time: 0:05:36.007729
351 time: 0:05:36.940734
352 time: 0:05:37.882023
353 time: 0:05:38.803807
354 time: 0:05:39.740566
355 time: 0:05:40.661948
356 time: 0:05:41.600258
357 time: 0:05:42.523184


646 time: 0:10:17.254595
647 time: 0:10:18.166607
648 time: 0:10:19.097868
649 time: 0:10:20.021693
650 time: 0:10:20.944640
651 time: 0:10:21.878605
652 time: 0:10:22.792553
653 time: 0:10:23.725804
654 time: 0:10:24.639346
655 time: 0:10:25.573065
656 time: 0:10:26.486069
657 time: 0:10:27.416981
658 time: 0:10:28.330445
659 time: 0:10:29.262864
660 time: 0:10:30.175654
661 time: 0:10:31.109091
662 time: 0:10:32.029920
663 time: 0:10:32.948786
664 time: 0:10:33.879358
665 time: 0:10:34.794725
666 time: 0:10:35.728562
667 time: 0:10:36.644589
668 time: 0:10:37.578376
669 time: 0:10:38.496211
670 time: 0:10:39.429535
671 time: 0:10:40.345379
672 time: 0:10:41.280208
673 time: 0:10:42.191808
674 time: 0:10:43.125296
675 time: 0:10:44.050235
676 time: 0:10:44.980443
677 time: 0:10:45.916740
678 time: 0:10:46.834909
679 time: 0:10:47.765455
680 time: 0:10:48.681808
681 time: 0:10:49.615855
682 time: 0:10:50.534717
683 time: 0:10:51.468528
684 time: 0:10:52.385111
685 time: 0:10:53.318976


974 time: 0:15:10.222382
975 time: 0:15:11.110190
976 time: 0:15:11.995996
977 time: 0:15:12.882308
978 time: 0:15:13.770624
979 time: 0:15:14.654931
980 time: 0:15:15.539737
981 time: 0:15:16.423046
982 time: 0:15:17.309853
983 time: 0:15:18.192656
984 time: 0:15:19.079968
985 time: 0:15:19.966775
986 time: 0:15:20.851581
987 time: 0:15:21.739389
988 time: 0:15:22.626701
989 time: 0:15:23.511505
990 time: 0:15:24.395816
991 time: 0:15:25.282623
992 time: 0:15:26.166932
993 time: 0:15:27.055752
994 time: 0:15:27.942559
995 time: 0:15:28.826867
996 time: 0:15:29.712673
997 time: 0:15:30.594979
998 time: 0:15:31.482291
999 time: 0:15:32.363601
1000 time: 0:15:33.250408
save weights
1001 time: 0:15:34.379446
1002 time: 0:15:35.266756
1003 time: 0:15:36.152562
1004 time: 0:15:37.037873
1005 time: 0:15:37.926186
1006 time: 0:15:38.817502
1007 time: 0:15:39.704430
1008 time: 0:15:40.589739
1009 time: 0:15:41.476546
1010 time: 0:15:42.361351
1011 time: 0:15:43.250160
1012 time: 0:15:44.134964

1290 time: 0:19:50.779041
1291 time: 0:19:51.665350
1292 time: 0:19:52.550436
1293 time: 0:19:53.438244
1294 time: 0:19:54.327559
1295 time: 0:19:55.215871
1296 time: 0:19:56.100180
1297 time: 0:19:56.986492
1298 time: 0:19:57.874301
1299 time: 0:19:58.758104
1300 time: 0:19:59.646418
1301 time: 0:20:00.530222
1302 time: 0:20:01.419031
1303 time: 0:20:02.303340
1304 time: 0:20:03.190650
1305 time: 0:20:04.074647
1306 time: 0:20:04.962961
1307 time: 0:20:05.848767
1308 time: 0:20:06.732571
1309 time: 0:20:07.618886
1310 time: 0:20:08.503193
1311 time: 0:20:09.388503
1312 time: 0:20:10.273769
1313 time: 0:20:11.162083
1314 time: 0:20:12.045888
1315 time: 0:20:12.931821
1316 time: 0:20:13.820134
1317 time: 0:20:14.702938
1318 time: 0:20:15.592253
1319 time: 0:20:16.477058
1320 time: 0:20:17.368870
1321 time: 0:20:18.251673
1322 time: 0:20:19.138986
1323 time: 0:20:20.026794
1324 time: 0:20:20.911905
1325 time: 0:20:21.799713
1326 time: 0:20:22.685519
1327 time: 0:20:23.574832
1328 time: 0

1605 time: 0:24:30.118448
1606 time: 0:24:31.001241
1607 time: 0:24:31.888907
1608 time: 0:24:32.773216
1609 time: 0:24:33.657525
1610 time: 0:24:34.539832
1611 time: 0:24:35.426649
1612 time: 0:24:36.306963
1613 time: 0:24:37.190758
1614 time: 0:24:38.072075
1615 time: 0:24:38.956870
1616 time: 0:24:39.842181
1617 time: 0:24:40.722982
1618 time: 0:24:41.610294
1619 time: 0:24:42.493098
1620 time: 0:24:43.376902
1621 time: 0:24:44.258218
1622 time: 0:24:45.142546
1623 time: 0:24:46.023854
1624 time: 0:24:46.907172
1625 time: 0:24:47.790985
1626 time: 0:24:48.673789
1627 time: 0:24:49.557096
1628 time: 0:24:50.438898
1629 time: 0:24:51.325706
1630 time: 0:24:52.208021
1631 time: 0:24:53.093818
1632 time: 0:24:53.979128
1633 time: 0:24:54.862932
1634 time: 0:24:55.747241
1635 time: 0:24:56.628043
1636 time: 0:24:57.512352
1637 time: 0:24:58.395663
1638 time: 0:24:59.280468
1639 time: 0:25:00.160784
1640 time: 0:25:01.044579
1641 time: 0:25:01.928891
1642 time: 0:25:02.812705
1643 time: 0

1921 time: 0:29:09.407731
1922 time: 0:29:10.290039
1923 time: 0:29:11.175846
1924 time: 0:29:12.058152
1925 time: 0:29:12.942957
1926 time: 0:29:13.827266
1927 time: 0:29:14.712578
1928 time: 0:29:15.598384
1929 time: 0:29:16.480691
1930 time: 0:29:17.367650
1931 time: 0:29:18.247955
1932 time: 0:29:19.134267
1933 time: 0:29:20.018577
1934 time: 0:29:20.904889
1935 time: 0:29:21.787692
1936 time: 0:29:22.669503
1937 time: 0:29:23.554308
1938 time: 0:29:24.435615
1939 time: 0:29:25.322422
1940 time: 0:29:26.206740
1941 time: 0:29:27.090535
1942 time: 0:29:27.975341
1943 time: 0:29:28.856647
1944 time: 0:29:29.740956
1945 time: 0:29:30.625267
1946 time: 0:29:31.511073
1947 time: 0:29:32.392380
1948 time: 0:29:33.277185
1949 time: 0:29:34.159997
1950 time: 0:29:35.045804
1951 time: 0:29:35.929608
1952 time: 0:29:36.811916
1953 time: 0:29:37.694720
1954 time: 0:29:38.576026
1955 time: 0:29:39.462336
1956 time: 0:29:40.345644
1957 time: 0:29:41.230449
1958 time: 0:29:42.110754
1959 time: 0

2236 time: 0:33:48.095737
2237 time: 0:33:48.983573
2238 time: 0:33:49.869380
2239 time: 0:33:50.751182
2240 time: 0:33:51.637989
2241 time: 0:33:52.521002
2242 time: 0:33:53.408810
2243 time: 0:33:54.292063
2244 time: 0:33:55.178364
2245 time: 0:33:56.060166
2246 time: 0:33:56.943475
2247 time: 0:33:57.827288
2248 time: 0:33:58.708595
2249 time: 0:33:59.591903
2250 time: 0:34:00.473705
2251 time: 0:34:01.358014
2252 time: 0:34:02.237815
2253 time: 0:34:03.125623
2254 time: 0:34:04.008931
2255 time: 0:34:04.891237
2256 time: 0:34:05.776043
2257 time: 0:34:06.657845
2258 time: 0:34:07.541659
2259 time: 0:34:08.426464
2260 time: 0:34:09.311270
2261 time: 0:34:10.190574
2262 time: 0:34:11.073883
2263 time: 0:34:11.959690
2264 time: 0:34:12.842502
2265 time: 0:34:13.727307
2266 time: 0:34:14.611111
2267 time: 0:34:15.496421
2268 time: 0:34:16.376221
2269 time: 0:34:17.259529
2270 time: 0:34:18.142333
2271 time: 0:34:19.028148
2272 time: 0:34:19.912457
2273 time: 0:34:20.794764
2274 time: 0

2551 time: 0:38:26.720943
2552 time: 0:38:27.605252
2553 time: 0:38:28.487559
2554 time: 0:38:29.373870
2555 time: 0:38:30.255672
2556 time: 0:38:31.141478
2557 time: 0:38:32.025293
2558 time: 0:38:32.909097
2559 time: 0:38:33.795411
2560 time: 0:38:34.680216
2561 time: 0:38:35.567023
2562 time: 0:38:36.449331
2563 time: 0:38:37.334640
2564 time: 0:38:38.216443
2565 time: 0:38:39.103756
2566 time: 0:38:39.988065
2567 time: 0:38:40.869867
2568 time: 0:38:41.755178
2569 time: 0:38:42.636981
2570 time: 0:38:43.524291
2571 time: 0:38:44.406116
2572 time: 0:38:45.291425
2573 time: 0:38:46.174228
2574 time: 0:38:47.058539
2575 time: 0:38:47.943848
2576 time: 0:38:48.829158
2577 time: 0:38:49.714964
2578 time: 0:38:50.600771
2579 time: 0:38:51.486080
2580 time: 0:38:52.369389
2581 time: 0:38:53.254195
2582 time: 0:38:54.136502
2583 time: 0:38:55.022308
2584 time: 0:38:55.907113
2585 time: 0:38:56.789917
2586 time: 0:38:57.676228
2587 time: 0:38:58.557029
2588 time: 0:38:59.442845
2589 time: 0

2867 time: 0:43:06.206425
2868 time: 0:43:07.095740
2869 time: 0:43:07.984052
2870 time: 0:43:08.866856
2871 time: 0:43:09.752662
2872 time: 0:43:10.635969
2873 time: 0:43:11.521279
2874 time: 0:43:12.409175
2875 time: 0:43:13.293980
2876 time: 0:43:14.176296
2877 time: 0:43:15.062093
2878 time: 0:43:15.948404
2879 time: 0:43:16.834210
2880 time: 0:43:17.719016
2881 time: 0:43:18.603325
2882 time: 0:43:19.490132
2883 time: 0:43:20.372439
2884 time: 0:43:21.258245
2885 time: 0:43:22.139551
2886 time: 0:43:23.029866
2887 time: 0:43:23.915682
2888 time: 0:43:24.799485
2889 time: 0:43:25.685291
2890 time: 0:43:26.568600
2891 time: 0:43:27.454406
2892 time: 0:43:28.335712
2893 time: 0:43:29.222519
2894 time: 0:43:30.105837
2895 time: 0:43:30.989631
2896 time: 0:43:31.874941
2897 time: 0:43:32.757745
2898 time: 0:43:33.644057
2899 time: 0:43:34.525868
2900 time: 0:43:35.412676
2901 time: 0:43:36.294983
2902 time: 0:43:37.179789
2903 time: 0:43:38.063096
2904 time: 0:43:38.950904
2905 time: 0