In [1]:
from __future__ import print_function, division
import os
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import *
from keras.layers import Conv2D, UpSampling2D
from keras.layers import MaxPooling2D
from keras import *
from keras.models import Model,Sequential
from keras.utils import to_categorical
import keras.backend as k

In [2]:
from keras.optimizers import legacy

In [3]:
from keras.datasets import cifar10


In [4]:
from keras.applications import VGG19
import sys
import datetime
from keras.optimizers import legacy

In [5]:
from keras.preprocessing import image
from keras.preprocessing.image import img_to_array
import tensorflow as tf

In [6]:
from keras.layers import *

In [15]:
!mkdir Images2

In [16]:
!mkdir saved_models2

In [9]:
pip install scipy



In [10]:
from keras.datasets import cifar10

In [11]:
from skimage.metrics import peak_signal_noise_ratio as psnr

In [12]:
from re import I
import scipy
from glob import  glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
#from skimage.transform import resize
class DataLoader():
  def __init__(self,dataset_name, img_res = (128,128)):
    self.dataset_name = dataset_name
    self.img_res = img_res
    #self.image_dir = 'Images'
  def load_data(self,batch_size=1,is_testing = False):
    (x,y),(_,_) = cifar10.load_data()
    data_type = 'train' if not is_testing else 'test'
    batch_images = np.random.choice(range(x.shape[0]),size = batch_size)
    imgs_hr = []
    imgs_lr = []
    for img_index in batch_images:
      img= x[img_index,:,:,:]
      h,w = self.img_res
      low_h, low_w = int(h/4),int(w/4)
      img_pil1 =Image.fromarray(img)
      resized_img_pil1 = img_pil1.resize((low_h,low_w))
      img_lr= np.array(resized_img_pil1)
      img_pil2 = Image.fromarray(img)
      resized_img_pil2 = img_pil2.resize((self.img_res))
      img_hr = np.array(resized_img_pil2)


      if not is_testing and np.random.random()<0.5:
        img_lr = np.fliplr(img_lr)
        img_hr = np.fliplr(img_hr)

      imgs_hr.append(img_hr)
      imgs_lr.append(img_lr)
    imgs_hr = np.array(imgs_hr)/127.5 -1
    imgs_lr = np.array(imgs_lr)/127.5 -1

    return (imgs_hr,imgs_lr)

In [17]:
class SuperRes_GAN():
  def __init__(self):
    self.channels =3
    self.lr_height = 64
    self.lr_width = 64
    self.lr_shape = (self.lr_height,self.lr_width,self.channels)

    self.hr_height = self.lr_height*4
    self.hr_width = self.lr_width*4
    self.hr_shape = (self.hr_height,self.hr_width,self.channels)

    self.n_residual_blocks = 16
    optimizer = legacy.Adam(0.0002,0.5)

    self.vgg = self.build_vgg()
    self.vgg.trainable = False
    self.vgg.compile(loss='mse',optimizer=optimizer,metrics=['accuracy'])

    self.dataset_name = 'cifar_dataset'
    self.data_loader = DataLoader(dataset_name=self.dataset_name,img_res=(self.hr_height,self.hr_width))

    patch = int(self.hr_height/2 **4)
    self.disc_patch=(patch,patch,1)
    self.gf = 64
    self.df = 64
    self.generator = self.build_generator()
    self.discriminator = self.build_discriminator()
    self.discriminator.compile(loss='mse',optimizer=optimizer,metrics=['accuracy'])

    img_hr = Input(shape=self.hr_shape)
    img_lr = Input(shape=self.lr_shape)

    fake_hr = self.generator(img_lr)
    fake_features = self.vgg(fake_hr)
    self.discriminator.trainable = False
    validity = self.discriminator(fake_hr)
    self.combined = Model(inputs=[img_lr,img_hr],outputs=[validity,fake_features])
    self.combined.compile(loss=['binary_crossentropy','mse'],loss_weights=[1e-3,1],optimizer=optimizer)

  def build_vgg(self):
    vgg = VGG19(weights='imagenet', include_top = False,input_shape =(256,256,3))


    img = Input(shape=(self.hr_shape))


    img_features = vgg(img)
    return Model(img,img_features)


  def build_generator(self):
    def residual_block(layer_input,filters):
      d = Conv2D(filters,kernel_size=3,strides=1,padding='same')(layer_input)
      d = Activation('relu')(d)
      d = BatchNormalization(momentum=0.8)(d)
      d = Conv2D(filters,kernel_size=3,strides=1,padding='same')(d)
      d = Add()([d,layer_input])
      return d
    def deconv2d(layer_input):
      u = UpSampling2D(size=2)(layer_input)
      u = Conv2D(256,kernel_size=3,strides=1,padding='same')(u)
      u= Activation('relu')(u)
      return u

    img_lr = Input(shape=self.lr_shape)
    c1=Conv2D(64,kernel_size=9,strides=1,padding='same')(img_lr)
    c1= Activation('relu')(c1)

    r = residual_block(c1,self.gf)
    for _ in range(self.n_residual_blocks-1):
      r = residual_block(r,self.gf)

    c2 = Conv2D(64,kernel_size=3,strides=1,padding='same')(r)
    c2 = BatchNormalization(momentum=0.8)(c2)
    c=Add()([c2,c1])

    u1 = deconv2d(c2)
    u2 = deconv2d(u1)


    gen_hr = Conv2D(self.channels,kernel_size=9,strides=1,padding='same',activation='tanh')(u2)

    return Model(img_lr,gen_hr)

  def build_discriminator(self):
    def d_block(layer_input, filters, strides=1, bn=True):
      d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
      d = LeakyReLU(alpha=0.2)(d)
      if bn:
        d = BatchNormalization(momentum=0.8)(d)
      return d

    d0 = Input(shape=self.hr_shape)
    d1 = d_block(d0, self.df, bn=False)
    d2 = d_block(d1, self.df, strides=2)
    d3 = d_block(d2, self.df*2)
    d4 = d_block(d3, self.df*2, strides=2)
    d5 = d_block(d4, self.df*4)
    d6 = d_block(d5, self.df*4, strides=2)
    d7 = d_block(d6, self.df*8)
    d8 = d_block(d7, self.df*8, strides=2)
    d9 = Dense(self.df*16)(d8)
    d10 = LeakyReLU(alpha=0.2)(d9)
    validity = Dense(1, activation='sigmoid')(d10)
    return Model(d0, validity)

  def train(self,epochs,batch_size=1,sample_interval=50):
    start_time = datetime.datetime.now()

    for epoch in range(epochs):

      imgs_hr,imgs_lr = self.data_loader.load_data(batch_size)

      fake_hr = self.generator.predict(imgs_lr)

      valid = np.ones((batch_size,)+self.disc_patch)
      fake = np.zeros((batch_size,)+self.disc_patch)

      d_loss_real = self.discriminator.train_on_batch(imgs_hr,valid)
      d_loss_fake = self.discriminator.train_on_batch(fake_hr,fake)
      d_loss = 0.5 * np.add(d_loss_real,d_loss_fake)

      imgs_hr,imgs_lr = self.data_loader.load_data(batch_size)

      valid = np.ones((batch_size,)+self.disc_patch)

      image_features = self.vgg.predict(imgs_hr)

      g_loss = self.combined.train_on_batch([imgs_lr,imgs_hr],[valid,image_features])
      psnr_val = np.mean([psnr(imgs_hr[i],fake_hr[i]) for i in range(batch_size)])


      elapsed_time = datetime.datetime.now() - start_time

      print('%d time: %s | PSNR: %.2f'%(epoch,elapsed_time, psnr_val))

      if epoch % sample_interval == 0:
        self.sample_images(epoch)


  def sample_images(self,epoch):
    r,c = 2,2
    os.makedirs('Images2/%s'%self.dataset_name, exist_ok = True)
    imgs_hr,imgs_lr = self.data_loader.load_data(batch_size=2, is_testing = True)
    fake_hr = self.generator.predict(imgs_lr)

    imgs_lr = 0.5 * imgs_lr + 0.5
    fake_hr = 0.5 * fake_hr + 0.5
    imgs_hr = 0.5 * imgs_hr + 0.5

    titles = ["generated", "original"]
    fig, axs = plt.subplots(r,c)
    cnt = 0
    for row in range(r):
      for col,image in enumerate([fake_hr,imgs_hr]):
        axs[row,col].imshow(image[row])
        axs[row,col].set_title(titles[col])
        axs[row,col].axis('off')
      cnt += 1
    plt.savefig("Images2/%s/%d.png"%(self.dataset_name,epoch))
    plt.close()

    for i in range(r):
      fig = plt.figure()
      plt.imshow(imgs_lr[i])
      fig.savefig('Images2/%s/%d_%d.png'%(self.dataset_name,epoch,i))
      plt.close()
  def save_model(self):
    def save(model, model_name):
      model_path = "saved_models2/%s.json" % model_name
      weights_path = "saved_models2/%s_weights.hdf5" % model_name
      options = {"file_arch": model_path,
                 "file_weight": weights_path}
      json_string = model.to_json()
      open(options['file_arch'], 'w').write(json_string)
      model.save_weights(options['file_weight'])
    save(self.generator, "generator")
    save(self.discriminator, "discriminator")


In [18]:
if __name__ == '__main__':
  gan = SuperRes_GAN()
  gan.train(epochs=2601,batch_size=1,sample_interval=200)
  gan.save_model()



  psnr_val = np.mean([psnr(imgs_hr[i],fake_hr[i]) for i in range(batch_size)])


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
937 time: 0:30:52.270743 | PSNR: 11.15
938 time: 0:30:54.272028 | PSNR: 9.38
939 time: 0:30:56.265326 | PSNR: 6.58
940 time: 0:30:58.213422 | PSNR: 12.53
941 time: 0:31:00.123729 | PSNR: 9.99
942 time: 0:31:02.218994 | PSNR: 4.59
943 time: 0:31:04.135968 | PSNR: 5.09
944 time: 0:31:06.126755 | PSNR: 10.41
945 time: 0:31:08.114074 | PSNR: 14.38
946 time: 0:31:10.126708 | PSNR: 6.64
947 time: 0:31:12.091737 | PSNR: 7.86
948 time: 0:31:14.037577 | PSNR: 9.82
949 time: 0:31:16.047997 | PSNR: 12.66
950 time: 0:31:17.980333 | PSNR: 7.76
951 time: 0:31:19.961772 | PSNR: 8.18
952 time: 0:31:21.908662 | PSNR: 7.38
953 time: 0:31:23.828573 | PSNR: 7.95
954 time: 0:31:25.812708 | PSNR: 13.41
955 time: 0:31:27.807971 | PSNR: 12.02
956 time: 0:31:29.806394 | PSNR: 11.28
957 time: 0:31:31.806028 | PSNR: 12.23
958 time: 0:31:33.798130 | PSNR: 12.78
959 time: 0:31:35.681770 | PSNR: 8.93
960 time: 0:31:37.664719 | PSNR: 10.08
961 time: 0: