In [0]:
pip install tensorview

In [0]:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
from math import ceil
import numpy as np
import argparse
from functools import partial
import os
import tensorview as tv
import cv2
import tensorflow as tf
import matplotlib.pyplot as  plt

In [0]:
# Download the DIV2K validation high-resolution archive; it is unzipped into div2k/
# and later globbed from div2k/DIV2K_valid_HR.
!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip

In [0]:
# Download the DIV2K validation bicubic x2 low-resolution archive.
# NOTE(review): no visible cell reads these files (only div2k/DIV2K_valid_HR is globbed) — confirm they are needed.
!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip

In [0]:
# Quietly (-q) extract the HR archive into div2k/.
!unzip -q DIV2K_valid_HR.zip -d div2k

In [0]:
# Quietly (-q) extract the LR x2 archive into div2k/.
# NOTE(review): the extracted LR images are not read by any visible cell.
!unzip -q DIV2K_valid_LR_bicubic_X2.zip -d div2k

In [0]:
# Render figures inline in the notebook.
# NOTE(review): this replaces the "Agg" backend selected by matplotlib.use("Agg") in the import cell.
%matplotlib inline

In [0]:
import glob,os
# Collect the DIV2K HR validation images; the "*g" pattern matches .png/.jpg/.jpeg.
# glob returns paths in arbitrary, filesystem-dependent order, so sort them to get
# the same ordering of `data` / `data2` on every run and machine.
img_dir = "div2k/DIV2K_valid_HR"
data_path = os.path.join(img_dir, '*g')
files = sorted(glob.glob(data_path))

In [0]:
data2 = list()  # saturation-boosted images, appended by the preprocessing loop

In [0]:
data = list()  # padded/resized original images, appended by the preprocessing loop

In [0]:
def visualize(original, augmented):
  """Show `original` and `augmented` side by side in a new 1x2 pyplot figure.

  NOTE(review): plt.imshow treats 3-channel arrays as RGB, while arrays loaded
  with cv2.imread are BGR, so cv2 images appear with red/blue swapped.
  """
  plt.figure()
  panels = (('Original image', original), ('Augmented image', augmented))
  for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(caption)
    plt.imshow(picture)

In [0]:

# Pad each image to a centred square with a zero border and resize it to
# TARGET_SIZE x TARGET_SIZE. Also build a saturation-boosted copy (factor 5).
# `data` holds the originals and `data2` the saturated versions, index-aligned.
#
# Changes from the previous version:
#  * cv2.resize's third positional parameter is `dst`, not the interpolation
#    flag, so INTER_AREA was silently ignored; it is now passed as a keyword.
#  * The square-image special case was dead code (its result was always
#    overwritten below); padding a square image is already a no-op.
#  * cv2.imread returns None for unreadable files; fail with a clear message
#    instead of an AttributeError on `.shape`.
TARGET_SIZE = 512
for f in files:
  img = cv2.imread(f)
  if img is None:
    raise IOError(f"could not read image: {f}")
  h, w, c = img.shape
  dif = max(h, w)
  x_pos = (dif - w) // 2
  y_pos = (dif - h) // 2
  mask = np.zeros((dif, dif, c), dtype=img.dtype)
  mask[y_pos:y_pos + h, x_pos:x_pos + w, :] = img
  img_new = cv2.resize(mask, (TARGET_SIZE, TARGET_SIZE), interpolation=cv2.INTER_AREA)
  # NOTE(review): img_new is BGR (cv2) while adjust_saturation assumes RGB; this
  # should be harmless for a pure saturation change — confirm if the augmentation changes.
  saturated = tf.image.adjust_saturation(img_new, 5).numpy()
  data2.append(saturated)
  data.append(img_new)

In [0]:
# Sanity check: shape of the stacked original images.
np.shape(data)

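## DualGAN: translating between the original images (domain B) and their saturation-boosted copies (domain A)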


In [0]:
import math
def psnr(img1, img2, max_val=255.0):
    """Peak signal-to-noise ratio, in dB, between two equally shaped arrays.

    Args:
      img1, img2: array-likes of the same shape. Both are promoted to float64
        before differencing, so uint8 images do not wrap around on
        subtraction or squaring (which previously produced a wrong MSE).
      max_val: largest possible pixel value (255 for 8-bit images).

    Returns:
      PSNR in decibels, or ``math.inf`` when the inputs are identical
      (MSE of zero). Previously this case printed "same image" and returned None.
    """
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return math.inf
    return 20 * math.log10(max_val / math.sqrt(mse))

In [0]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein critic loss: the mean of ``y_true * y_pred`` over all elements.

    Training uses targets of -1 for real ("valid") and +1 for generated
    ("fake") samples. Minimising this loss therefore raises the critic score on
    real samples and lowers it on fakes.
    """
    signed_scores = y_true * y_pred
    return tf.keras.backend.mean(signed_scores)

In [0]:
def conv2d_block(input_tensor, n_filters, kernel_size=2, batchnorm=True):
    """Downsampling block: strided Conv2D -> optional BatchNorm -> SELU.

    The convolution uses stride 2, "valid" padding and He-normal init. With
    the default kernel_size=2, each (even) spatial dimension is halved.
    """
    conv = tf.keras.layers.Conv2D(
        filters=n_filters,
        kernel_size=(kernel_size, kernel_size),
        strides=2,
        kernel_initializer="he_normal",
        padding="valid",
    )
    out = conv(input_tensor)
    out = tf.keras.layers.BatchNormalization()(out) if batchnorm else out
    return tf.keras.activations.selu(out)

In [0]:
def get_unet(input_img,n_filters=16, dropout=0.5, batchnorm=True):
    """Encoder-decoder (U-Net-like) generator, used for both translation directions.

    Args:
      input_img: Keras input tensor. NOTE(review): the hard-coded repeat
        factor of 32 below lines up with the 32x32 skip tensor p4 only for
        512x512 inputs (as at the call sites) — confirm before other sizes.
      n_filters: base filter count; encoder blocks use 2x, 4x and 8x this.
      dropout: dropout rate after each block (the first block uses half).
      batchnorm: whether to add BatchNormalization in encoder/decoder blocks.

    Returns:
      A 3-channel tensor with no activation (linear output). For 512x512
      inputs it has the same spatial size as the input.
    """
    #Contracting path
    # 1x1 conv lifts the input to n_filters channels without changing H or W.
    input_img = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(1,1),strides = 1,padding="valid")(input_img)
    # Each conv2d_block (kernel 2, stride 2, "valid") halves H and W; for a
    # 512x512 input: c1 256, c2 128, c3 64, c4 32, c6 16, c7 8.
    c1 = conv2d_block(input_img, n_filters=n_filters*2, kernel_size=2, batchnorm=batchnorm)
    p1 = tf.keras.layers.Dropout(dropout*0.5)(c1)
    c2 = conv2d_block(p1, n_filters=n_filters*4, kernel_size=2, batchnorm=batchnorm)
    p2 = tf.keras.layers.Dropout(dropout)(c2)
    c3 = conv2d_block(p2, n_filters=n_filters*8, kernel_size=2, batchnorm=batchnorm)
    p3 = tf.keras.layers.Dropout(dropout)(c3)
    c4 = conv2d_block(p3, n_filters=n_filters*8, kernel_size=2, batchnorm=batchnorm)
    p4 = tf.keras.layers.Dropout(dropout)(c4)  
    c6 = conv2d_block(p4, n_filters=n_filters*8, kernel_size=2, batchnorm=batchnorm)
    p6 = tf.keras.layers.Dropout(dropout)(c6)
    c7 = conv2d_block(p6, n_filters=n_filters*8, kernel_size=2, batchnorm=batchnorm)
    p7 = tf.keras.layers.Dropout(dropout)(c7)
    # Kernel 8 / stride 2 / "valid" on the 8x8 map collapses it to a 1x1 bottleneck.
    c8 = conv2d_block(p7, n_filters=n_filters*8, kernel_size=8, batchnorm=batchnorm)
    p8 = tf.keras.layers.Dropout(dropout)(c8)
    # Tile the 1x1 bottleneck to 32x32 along H and W, then concatenate it with
    # the p4 skip on the channel axis. NOTE(review): 32 matches p4 only for 512x512 inputs.
    new_p8 = tf.keras.backend.repeat_elements(p8,rep = 32,axis = 1)
    new_p8 = tf.keras.backend.repeat_elements(new_p8,rep = 32,axis = 2)
    out = tf.keras.backend.concatenate((new_p8,p4),axis = -1)
    #Expansive path
    # Each 2x2 / stride-2 transposed conv doubles H and W. It is concatenated with
    # the encoder tensor of matching size (p3, p2, p1), then Dropout, optional BN, SELU.
    u6 = tf.keras.layers.Conv2DTranspose(n_filters*4, (2, 2), strides=2, padding='valid') (out)
    u6 = tf.keras.backend.concatenate([u6,p3],axis = -1)
    u6 = tf.keras.layers.Dropout(dropout)(u6)
    if batchnorm:
        u6 = tf.keras.layers.BatchNormalization()(u6)
    u6 = tf.keras.activations.selu(u6)   
    u7 = tf.keras.layers.Conv2DTranspose(n_filters*2, (2, 2), strides=2, padding='valid') (u6)
    u7 = tf.keras.backend.concatenate([u7,p2],axis = -1)
    u7 = tf.keras.layers.Dropout(dropout)(u7)
    if batchnorm:
        u7 = tf.keras.layers.BatchNormalization()(u7)
    u7 = tf.keras.activations.selu(u7)    
    u8 = tf.keras.layers.Conv2DTranspose(n_filters, (2, 2), strides=2, padding='valid') (u7)
    u8 = tf.keras.backend.concatenate([u8,p1],axis = -1)
    u8 = tf.keras.layers.Dropout(dropout)(u8)
    if batchnorm:
        u8 = tf.keras.layers.BatchNormalization()(u8)
    u8 = tf.keras.activations.selu(u8)    
    # Final upsample back to the input resolution (no skip connection here).
    u9 = tf.keras.layers.Conv2DTranspose(n_filters, (2, 2), strides=2, padding='valid') (u8)
    if batchnorm:
        u9 = tf.keras.layers.BatchNormalization()(u9)
    u9 = tf.keras.activations.selu(u9)    
    # 1x1 transposed conv projects to 3 channels with no activation (linear output).
    u10 = tf.keras.layers.Conv2DTranspose(3, (1, 1), strides=1, padding='valid') (u9)
    
    
    return u10

In [0]:
def get_disc(input_img,n_filters = 16,dropout = 0.5):
  """Image critic: a 1x1 projection, six stride-2 conv stages, then a dense head.

  Each 2x2 / stride-2 "valid" Conv2D is followed by LeakyReLU and halves H and
  W. The stages use 2x, 4x, 8x, 8x, 8x and 8x `n_filters` filters. The
  flattened features go through Dense(512, relu) and a single sigmoid unit.

  `dropout` is accepted for signature compatibility but is not used.

  NOTE(review): the output passes through a sigmoid even though the critic is
  compiled with wasserstein_loss, where critics are usually linear — confirm.
  """
  features = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(1, 1), strides=1, padding="valid")(input_img)
  for multiplier in (2, 4, 8, 8, 8, 8):
    features = tf.keras.layers.Conv2D(
        filters=n_filters * multiplier, kernel_size=(2, 2), strides=2, padding="valid"
    )(features)
    features = tf.keras.layers.LeakyReLU()(features)
  flat = tf.keras.layers.Flatten()(features)
  hidden = tf.keras.layers.Dense(512, activation='relu')(flat)
  return tf.keras.layers.Dense(1, activation='sigmoid')(hidden)


In [0]:
def _build_critic():
  """Create a 512x512x3 Input, a get_disc critic on it, and compile the Model.

  Returns the (input tensor, output tensor, compiled model) triple.
  NOTE(review): the 'accuracy' metric has little meaning with the -1/+1
  targets used in training — confirm it is wanted.
  """
  critic_input = tf.keras.Input(shape=(512, 512, 3))
  critic_output = get_disc(critic_input)
  critic = tf.keras.Model(inputs=critic_input, outputs=critic_output)
  critic.compile(loss=wasserstein_loss,
                 optimizer=tf.keras.optimizers.Adam(),
                 metrics=['accuracy'])
  return critic_input, critic_output, critic

# In the training loop, disc_mod_1 scores domain A (saturated) and disc_mod_2 domain B (original).
inp1, disc_out_1, disc_mod_1 = _build_critic()
inp2, disc_out_2, disc_mod_2 = _build_critic()

In [0]:
# Generator used in training to map domain A (saturated) -> domain B (original).
image1 = tf.keras.Input(shape=(512, 512, 3))
gen_out_1 = get_unet(input_img=image1)
gen_mod_1 = tf.keras.Model(outputs=gen_out_1, inputs=image1)

In [0]:
# Generator used in training to map domain B (original) -> domain A (saturated).
image2 = tf.keras.Input(shape=(512, 512, 3))
gen_out_2 = get_unet(input_img=image2)
gen_mod_2 = tf.keras.Model(outputs=gen_out_2, inputs=image2)

In [0]:
# Symbolic inputs of the combined model. image_real_1 receives domain-A
# (saturated) batches and image_real_2 domain-B (original) batches, in the
# order passed to two_way_gan.train_on_batch.
# NOTE: despite the suffixes, image_gen_a = gen_mod_1(A) lies in domain B, and
# image_gen_b = gen_mod_2(B) lies in domain A.
image_real_1, image_real_2 = (tf.keras.Input(shape=(512, 512, 3)) for _ in range(2))
image_gen_a, image_gen_b = gen_mod_1(image_real_1), gen_mod_2(image_real_2)

In [0]:
def _frozen_view(model):
  """Wrap `model`'s graph (shared layers) in a new Model marked trainable=False.

  The wrappers are used inside two_way_gan so that its optimizer is not meant
  to update the critics. The original critics were already compiled while
  trainable.
  """
  wrapper = tf.keras.Model(model.inputs, model.outputs)
  wrapper.trainable = False
  return wrapper

frozen_a = _frozen_view(disc_mod_1)
frozen_b = _frozen_view(disc_mod_2)

In [0]:
# Adversarial scores: each fake is judged by the critic of the domain it was
# translated into (frozen_a scores gen_mod_2's A-domain output, frozen_b gen_mod_1's).
logit_a, logit_b = frozen_a(image_gen_b), frozen_b(image_gen_a)
# Cycle reconstructions: translate each fake back into its source domain.
recov_b, recov_a = gen_mod_1(image_gen_b), gen_mod_2(image_gen_a)

In [0]:
# Combined model: two real batches in; two adversarial scores and two cycle reconstructions out.
two_way_gan = tf.keras.Model(
    inputs=[image_real_1, image_real_2],
    outputs=[logit_a, logit_b, recov_a, recov_b],
)


In [0]:

# Adversarial outputs use the Wasserstein loss; reconstruction outputs use MAE weighted 100x.
gan_losses = [wasserstein_loss, wasserstein_loss, 'mae', 'mae']
gan_loss_weights = [1, 1, 100, 100]
two_way_gan.compile(
    loss=gan_losses,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    loss_weights=gan_loss_weights,
)


In [0]:
# Live metric plot used by the training loop (2 columns; wait_num=5).
tv_plot = tv.train.PlotMetrics(wait_num=5, columns=2)

In [0]:
# Stack per-image lists into (num_images, 512, 512, channels) arrays:
# domain B = originals (train_real), domain A = saturated copies (train_sat).
train_real, train_sat = np.asarray(data), np.asarray(data2)

In [0]:
# Images per batch, outer training iterations (one generator update each),
# and critic updates per generator update.
batch_size, epochs, n_critic = 2, 10, 4

In [0]:
# WGAN targets: -1 for real ("valid") samples, +1 for generated ("fake") ones.
valid = np.full((batch_size, 1), -1.0)
fake = np.full((batch_size, 1), 1.0)

TRAINING

In [0]:
# DualGAN / WGAN-style training. Each outer iteration (an "epoch" here is a
# single generator update, not a pass over the data) does two things:
#  1. n_critic critic steps on freshly sampled batches, clipping the critic
#     weights after each step;
#  2. one step of the combined model on the last sampled batches.
# NOTE(review): no random seed is set, so batch sampling is not reproducible.
clip_value = 0.01  # critic weights are clipped to [-clip_value, clip_value]
for epoch in range(epochs):
  for _ in range(n_critic):
    # Domain A = saturated images, domain B = originals; indices drawn without replacement.
    batch_image_a = train_sat[np.random.choice(range(train_sat.shape[0]), batch_size, False)]
    batch_image_b = train_real[np.random.choice(range(train_real.shape[0]), batch_size, False)]
    # Translate each batch into the other domain (gen_mod_2: B->A, gen_mod_1: A->B).
    batch_image_gen_a = gen_mod_2(batch_image_b)
    batch_image_gen_b = gen_mod_1(batch_image_a)
    # Critic A: real A -> valid (-1), translated A -> fake (+1); likewise for critic B.
    d_a_loss_real = disc_mod_1.train_on_batch(batch_image_a,valid)
    d_a_loss_fake = disc_mod_1.train_on_batch(batch_image_gen_a,fake)
    d_b_loss_real = disc_mod_2.train_on_batch(batch_image_b,valid)
    d_b_loss_fake = disc_mod_2.train_on_batch(batch_image_gen_b,fake) 
    # Element-wise mean of the real/fake [loss, accuracy] results, for logging only.
    d_a_loss = 0.5 * np.add(d_a_loss_real,d_a_loss_fake)
    d_b_loss = 0.5 * np.add(d_b_loss_real,d_b_loss_fake)
    # Clip every weight of every critic layer to [-clip_value, clip_value].
    for d in [disc_mod_1, disc_mod_2]:
      for l in d.layers:
        weights = l.get_weights()
        weights = [np.clip(w, -clip_value, clip_value) for w in weights]
        l.set_weights(weights)
  # Generator step: push both frozen critics' scores on the fakes toward `valid`
  # and reconstruct each input batch through the A->B->A and B->A->B cycles.
  g_loss = two_way_gan.train_on_batch([batch_image_a, batch_image_b], [valid, valid, batch_image_a, batch_image_b])
  # Presumably g_loss = [total, adv_a, adv_b, cycle_a, cycle_b]; indices 1 and 2 are
  # plotted as the adversarial terms. The *_binary_acc entries come from the
  # critics' 'accuracy' metric.
  tv_plot.update({'D_a_loss': d_a_loss[0], 'D_a_binary_acc': d_a_loss[1],
                    'D_b_loss': d_b_loss[0], 'D_b_binary_acc': d_b_loss[1],
                    'G_a_loss': g_loss[1],  'G_b_loss': g_loss[2]})
  print(g_loss)
  tv_plot.draw()
tv_plot.visual()
tv_plot.visual(name='model_visual_gif', gif=True)    


## Testing (PSNR on the training images — no held-out set is used)

In [0]:
def psnr_test(test_LR,test_HR, batch_size=8):
  """PSNR between gen_mod_2's translation of `test_LR` and `test_HR`.

  gen_mod_2 maps domain B (original images) to domain A (saturated images).
  So `test_LR` should hold originals and `test_HR` the matching saturated
  images; the LR/HR names are historical, and no resolution change happens.
  NOTE: no held-out test set exists — the notebook evaluates on training data.

  The generator runs via Model.predict in mini-batches. The previous single
  eager call on the whole array built activations for every image at once,
  which can exhaust (GPU) memory.

  Args:
    test_LR: generator inputs, shape (N, 512, 512, 3).
    test_HR: targets aligned with test_LR.
    batch_size: images per forward pass.

  Returns:
    PSNR in dB as computed by psnr().
  """
  image_gen_a = gen_mod_2.predict(test_LR, batch_size=batch_size)
  return psnr(np.asarray(image_gen_a), test_HR)


In [0]:
# PSNR of gen_mod_2(originals) against the saturated targets (training data).
psnr_test(test_LR=train_real, test_HR=train_sat)