In [None]:
# import libs
import tensorflow as tf
from tensorflow import keras
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tensorflow.keras.applications import VGG19
from tensorflow.keras.losses import BinaryCrossentropy

In [None]:
# Generator (low-res) input shape and discriminator (high-res) input shape.
# The generator upsamples by UPSCALE_FACTOR (two x2 pixel-shuffle stages), so the
# high-res shape is derived from the low-res one rather than hard-coded separately.
UPSCALE_FACTOR = 4
input_shape = (32, 32, 3)
input_shape_disc = (input_shape[0] * UPSCALE_FACTOR, input_shape[1] * UPSCALE_FACTOR, input_shape[2])
## generator model
def residualBlock(input_):
  """SRResNet residual block: Conv3x3 -> BN -> PReLU -> Conv3x3 -> BN, plus identity skip.

  Args:
    input_: 4-D Keras tensor (presumably NHWC, given PReLU's shared_axes=[1, 2]).
  Returns:
    Tensor with the same shape as `input_`.
  """
  # The identity Add needs the block to emit exactly as many channels as it
  # receives, so derive the filter count from the input instead of hard-coding 64
  # (falls back to 64 if the channel dimension is unknown).
  filters = input_.shape[-1] or 64
  x = keras.layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(input_)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.PReLU(shared_axes=[1, 2])(x)
  x = keras.layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  return keras.layers.Add()([x, input_])

def pixelShuffle(scale):
  """Return a sub-pixel (depth-to-space) upsampling function for `keras.layers.Lambda`.

  The returned callable moves channel blocks into `scale` x `scale` spatial tiles
  via `tf.nn.depth_to_space`. It is deliberately a lambda: Keras' `Lambda` layer
  serializes lambdas by bytecode, whereas a named nested function would be
  serialized by name and could not be reloaded.
  """
  shuffle = lambda tensor: tf.nn.depth_to_space(tensor, scale)
  return shuffle

def generator(residual_numbers=16):
  """Build the SRGAN generator (SRResNet) mapping `input_shape` images to 4x larger RGB images.

  Layout: Conv9x9 + PReLU -> `residual_numbers` residual blocks -> Conv3x3 + BN
  with a long skip connection -> two (Conv3x3 -> pixel shuffle x2 -> PReLU)
  upsampling stages -> Conv9x9 to 3 channels.

  Args:
    residual_numbers: number of residual blocks in the trunk.
  Returns:
    keras.Model whose output pixels lie in [0, 1], matching the HR targets.
  """
  inputs = keras.layers.Input(shape=input_shape)
  x = keras.layers.Conv2D(64, kernel_size=9, strides=1, padding='same')(inputs)
  x = keras.layers.PReLU(shared_axes=[1, 2])(x)
  skip_connection = x

  for _ in range(residual_numbers):
    x = residualBlock(x)

  x = keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same')(x)
  x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Add()([x, skip_connection])

  # Two x2 sub-pixel upsampling stages (x4 overall): 256 = 64 * 2 * 2 channels
  # are folded into space by each pixel shuffle.
  for _ in range(2):
    x = keras.layers.Conv2D(256, kernel_size=3, strides=1, padding='same')(x)
    x = keras.layers.Lambda(pixelShuffle(2))(x)
    x = keras.layers.PReLU(shared_axes=[1, 2])(x)

  x = keras.layers.Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(x)
  # tanh yields [-1, 1], but the HR targets (preprocess_image divides by 255) and
  # the real images the discriminator sees are in [0, 1]; map the output onto
  # that range so real and generated images share one pixel scale.
  outputs = keras.layers.Rescaling(scale=0.5, offset=0.5)(x)

  return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
## discriminator model
def block(input_,filter_numbers=128,strides=1):
  """Discriminator stage: Conv2D(3x3, `strides`) -> BatchNorm -> LeakyReLU(0.2).

  Args:
    input_: Keras tensor to transform.
    filter_numbers: number of convolution filters.
    strides: convolution stride (2 halves the spatial size with 'same' padding).
  """
  conv = keras.layers.Conv2D(filters=filter_numbers, kernel_size=3, strides=strides, padding='same')
  norm = keras.layers.BatchNormalization()
  activation = keras.layers.LeakyReLU(alpha=0.2)
  return activation(norm(conv(input_)))

def discriminator():
  """Build the SRGAN discriminator: images of `input_shape_disc` -> real/fake probability.

  A Conv + LeakyReLU stem, seven Conv-BN-LeakyReLU stages (`block`), then
  Dense(1024) + LeakyReLU and a sigmoid Dense(1) head.
  """
  inputs = keras.layers.Input(shape=input_shape_disc)
  features = keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same')(inputs)
  features = keras.layers.LeakyReLU(alpha=0.2)(features)

  # (filters, stride) per stage; the four stride-2 stages each halve height/width.
  stage_specs = ((64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2))
  for stage_filters, stage_stride in stage_specs:
    features = block(features, stage_filters, stage_stride)

  features = keras.layers.Flatten()(features)
  features = keras.layers.Dense(1024)(features)
  features = keras.layers.LeakyReLU(alpha=0.2)(features)
  probability = keras.layers.Dense(1, activation='sigmoid')(features)

  return keras.Model(inputs=inputs, outputs=probability)

In [None]:
## training class
class SRGAN(keras.Model):
  """Trainer coupling an SRGAN generator with its discriminator.

  `train_step` performs one adversarial update of both networks and returns the
  running means held in `d_loss_tracker` / `g_loss_tracker`; the `metrics`
  property exposes those trackers so they can be reset between epochs.
  """

  # VGG19 layer whose activations define the content (perceptual) loss.
  content_layer = 'block5_conv4'
  # The SRGAN paper rescales VGG feature maps by 1/12.75 so the VGG loss is on a
  # scale comparable to a pixel-space MSE loss.
  vgg_feature_scale = 1.0 / 12.75
  # Weight of the adversarial term in the generator loss.
  adversarial_weight = 1e-3
  # Frozen VGG19 feature extractor, built once and shared. It lives on the class,
  # not the instance, so Keras does not track it as a sub-layer (that would write
  # ~80 MB of fixed ImageNet weights into every saved checkpoint).
  _vgg_extractor = None

  def __init__(self,discriminator,generator):
    super().__init__()
    self.discriminator = discriminator
    self.generator = generator
    self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
    self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
    self.binaryCrossentropy = BinaryCrossentropy(from_logits=False)
    # Build VGG19 eagerly and once; constructing it inside content_loss would
    # rebuild the network and reload its weights on every training batch.
    self._get_vgg_extractor()

  @classmethod
  def _get_vgg_extractor(cls):
    """Return the shared frozen VGG19 feature model, creating it on first use."""
    if cls._vgg_extractor is None:
      vgg = VGG19(include_top=False, weights='imagenet')
      vgg.trainable = False
      cls._vgg_extractor = keras.Model(
          inputs=vgg.input, outputs=vgg.get_layer(cls.content_layer).output)
    return cls._vgg_extractor

  def compile(self,d_optimizor,g_optimizor):
    """Store the discriminator and generator optimizers used by `train_step`."""
    super().compile()
    self.d_optimizor = d_optimizor
    self.g_optimizor = g_optimizor

  def content_loss(self,hr_image,sr_image):
    """Mean squared difference between VGG19 `content_layer` features of two batches.

    Args:
      hr_image: ground-truth batch, RGB pixels in [0, 1] (as built by preprocess_image).
      sr_image: generated batch on the same [0, 1] pixel scale.
    Returns:
      Scalar loss tensor.
    """
    extractor = self._get_vgg_extractor()
    # vgg19.preprocess_input expects RGB pixels in [0, 255] (it subtracts ImageNet
    # channel means of ~100); [0, 1] inputs would be almost constant after that.
    hr_features = extractor(keras.applications.vgg19.preprocess_input(hr_image * 255.0))
    sr_features = extractor(keras.applications.vgg19.preprocess_input(sr_image * 255.0))
    hr_features = hr_features * self.vgg_feature_scale
    sr_features = sr_features * self.vgg_feature_scale
    return tf.reduce_mean(tf.square(hr_features - sr_features))

  def adversarial_loss(self,disc_output):
    """Generator's adversarial loss: BCE pushing discriminator outputs on fakes towards 1."""
    return self.binaryCrossentropy(tf.ones_like(disc_output),disc_output)

  def discriminator_loss(self,disc_real_output,disc_fake_output):
    """BCE with real images labelled 1 and generated images labelled 0."""
    real_loss = self.binaryCrossentropy(tf.ones_like(disc_real_output),disc_real_output)
    fake_loss = self.binaryCrossentropy(tf.zeros_like(disc_fake_output),disc_fake_output)
    return real_loss + fake_loss

  def generator_loss(self,disc_fake_output,hr_image,sr_image):
    """Perceptual loss: VGG content loss plus `adversarial_weight` x adversarial loss."""
    content_loss = self.content_loss(hr_image,sr_image)
    adversarial_loss = self.adversarial_loss(disc_fake_output)
    return content_loss + self.adversarial_weight * adversarial_loss

  # NOTE: runs eagerly (the original @tf.function decorator was commented out).
  def train_step(self,data):
    """Run one optimisation step for both networks.

    Args:
      data: (lr_images, hr_images) batch; hr_images are resized here to the
        generator's output resolution.
    Returns:
      Dict with the running-mean "d_loss" and "g_loss".
    """
    lr_images,hr_images = data
    with tf.GradientTape() as dTape, tf.GradientTape() as gTape:
      sr_images = self.generator(lr_images,training=True)
      # Bring the targets to the generator's output size (H, W).
      hr_images = tf.image.resize(hr_images, tf.shape(sr_images)[1:3])
      disc_fake_output = self.discriminator(sr_images,training=True)
      disc_real_output = self.discriminator(hr_images,training=True)

      generator_loss = self.generator_loss(disc_fake_output,hr_images,sr_images)
      disc_loss = self.discriminator_loss(disc_real_output,disc_fake_output)

    gen_gradient = gTape.gradient(generator_loss,self.generator.trainable_variables)
    disc_gradient = dTape.gradient(disc_loss,self.discriminator.trainable_variables)

    self.g_optimizor.apply_gradients(zip(gen_gradient,self.generator.trainable_variables))
    self.d_optimizor.apply_gradients(zip(disc_gradient,self.discriminator.trainable_variables))

    self.g_loss_tracker.update_state(generator_loss)
    self.d_loss_tracker.update_state(disc_loss)
    return {
        "d_loss": self.d_loss_tracker.result(),
        "g_loss": self.g_loss_tracker.result()
    }

  @property
  def metrics(self):
    """Loss trackers; listing them lets Keras (or a custom loop) reset them."""
    return [self.g_loss_tracker,self.d_loss_tracker]

In [None]:
# Training hyper-parameters: images per batch, and number of passes of the loop.
batch_size, epochs = 32, 50

In [None]:
## preprocessing dataset

def downscale_image(image,scale=2):
  """Degrade `image` by bicubic down-sampling by `scale`, then up-sampling back.

  The result keeps the input's height/width but loses fine detail. Expects a
  single (H, W, C) image: `tf.shape(image)[:2]` would pick the wrong axes for
  a batched tensor.
  """
  original_hw = tf.shape(image)[:2]
  shrunk = tf.image.resize(image, original_hw // scale, method='bicubic')
  return tf.image.resize(shrunk, original_hw, method='bicubic')

def preprocess_image(image,hr_size=input_shape[:-1]):
  """Turn one RGB image into an (lr, hr) float32 training pair.

  hr: `image` resized to `hr_size` and scaled from [0, 255] to [0, 1].
  lr: hr degraded with `downscale_image`, then resized to the generator input
      size `input_shape[:-1]`.
  """
  target = tf.cast(tf.image.resize(image, hr_size), tf.float32) / 255.0
  degraded = tf.image.resize(downscale_image(target), input_shape[:-1])
  return degraded, target

# load cifar10 dataset
# Labels are unused: (lr, hr) pairs are synthesised from the images alone.
(x_train,_),(x_test,_) = keras.datasets.cifar10.load_data()
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test)

# Shuffle (before batching) with a buffer covering the whole training set. The
# training loop only takes the first N batches of each pass, so without a full
# shuffle every epoch would see the same few thousand images in the same order.
# The shuffle is seeded for reproducibility and reshuffled on every iteration.
train_dataset_srgan = (
    train_dataset
    .shuffle(buffer_size=len(x_train), seed=42, reshuffle_each_iteration=True)
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

for lr_image,hr_image in train_dataset_srgan.take(1):
  print("low batch size shape:",lr_image.shape)
  print("high batch size shape:",hr_image.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 0us/step
low batch size shape: (32, 32, 32, 3)
high batch size shape: (32, 32, 32, 3)


In [None]:
#Instantiate Model
# Build both networks and wrap them in the SRGAN trainer.
# NOTE(review): these assignments rebind the builder-function names `generator`
# and `discriminator` to the built models, so this cell cannot be re-run without
# first re-running the model-definition cells. TODO: consider distinct names.
generator = generator()
discriminator = discriminator()
srgan = SRGAN(discriminator=discriminator, generator=generator)
srgan.compile(
    d_optimizor=keras.optimizers.Adam(learning_rate=3e-4),
    g_optimizor=keras.optimizers.Adam(learning_rate=3e-4),
)
discriminator.summary()



In [None]:
# callbacks

import os

# Checkpoints are written to Google Drive so they survive Colab runtime resets.
from google.colab import drive
drive.mount('/content/drive')

checkpoint_dir = '/content/drive/MyDrive/srgan_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
steps_per_epoch = 100  # must match the `take(...)` count in the training loop

callbacks = [
    keras.callbacks.ModelCheckpoint(
        # Weights-only checkpoints must use the `.weights.h5` suffix in Keras 3
        # (`.keras` is reserved for full-model saves).
        filepath=checkpoint_dir + '/srgan_{epoch:02d}.weights.h5',
        monitor='g_loss',
        mode='min',
        save_best_only=True,
        save_weights_only=True
    ),
    keras.callbacks.EarlyStopping(
        monitor='g_loss',
        mode='min',
        # Must be smaller than `epochs` (50) or early stopping can never trigger.
        patience=10,
        restore_best_weights=True
    )
]
callbackList = keras.callbacks.CallbackList(callbacks)
callbackList.set_model(srgan)
callbackList.set_params({
    'epochs': epochs,
    'steps': steps_per_epoch,  # batches per epoch (was mistakenly batch_size)
    'verbose': 1
})

Mounted at /content/drive


In [None]:
## training loop - without fit
import datetime
import glob
import os
import re

steps_per_epoch = 100  # batches drawn from train_dataset_srgan per epoch


def _find_latest_weights(directory):
  """Return (path, epoch) of the highest-epoch `srgan_<epoch>.weights.h5` in `directory`, or (None, 0)."""
  latest_path, latest_epoch = None, 0
  for path in glob.glob(os.path.join(directory, 'srgan_*.weights.h5')):
    match = re.search(r'srgan_(\d+)\.weights\.h5$', os.path.basename(path))
    # Compare numerically: lexicographic order breaks once epochs pass 99.
    if match and int(match.group(1)) >= latest_epoch:
      latest_path, latest_epoch = path, int(match.group(1))
  return latest_path, latest_epoch


# Resume from the newest Keras weights file. tf.train.latest_checkpoint only
# recognises TF-format checkpoints, so it never found these files.
latest_weights, initial_epoch = _find_latest_weights(checkpoint_dir)
if latest_weights:
  srgan.load_weights(latest_weights)
  print(f"Resumed from {latest_weights} (completed epochs: {initial_epoch})")

callbackList.on_train_begin()
for epoch in range(initial_epoch, epochs):
  print(f"Epoch {epoch+1}/{epochs}, time: {datetime.datetime.now()}")
  # Reset the running-mean trackers so logged losses describe this epoch only.
  for metric in srgan.metrics:
    metric.reset_state()
  callbackList.on_epoch_begin(epoch)
  logs = {}
  for batch,(lr_images,hr_images) in enumerate(train_dataset_srgan.take(steps_per_epoch)):
    callbackList.on_train_batch_begin(batch)
    losses = srgan.train_step((lr_images,hr_images))
    # Plain floats keep callback comparisons/serialisation free of tensors.
    logs = {name: float(value) for name, value in losses.items()}
    callbackList.on_train_batch_end(batch,logs=logs)
    if batch%10==0:
      print(f"Batch {batch}, Disc Loss: {logs['d_loss']} , Gen Loss: {logs['g_loss']}, time: {datetime.datetime.now()}")

  callbackList.on_epoch_end(epoch,logs=logs)
  # EarlyStopping signals through model.stop_training; honour it.
  if getattr(srgan, 'stop_training', False):
    break
callbackList.on_train_end()

Epoch 1/50, time: 2024-09-28 06:18:34.979133
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m80134624/80134624[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step
Batch 0, Disc Loss: 1.5423977375030518 , Gen Loss: 0.0012898046988993883, time: 2024-09-28 06:18:53.077561
Batch 10, Disc Loss: 2.661508798599243 , Gen Loss: 0.040615107864141464, time: 2024-09-28 06:19:16.652437
Batch 20, Disc Loss: 2.2686097621917725 , Gen Loss: 0.03400161862373352, time: 2024-09-28 06:19:40.371244
Batch 30, Disc Loss: 3.805115222930908 , Gen Loss: 0.03214787319302559, time: 2024-09-28 06:20:04.018296
Batch 40, Disc Loss: 3.7624447345733643 , Gen Loss: 0.026745760813355446, time: 2024-09-28 06:20:27.862250
Batch 50, Disc Loss: 3.1052513122558594 , Gen Loss: 0.02433622069656849, time: 2024-09-28 06:20:51.752935
Batch 60, Disc Loss: 2.662524700164795 , Gen Loss: 0.022574448958039284, time: 2024-09-28 06:2

  return saving_lib.save_model(model, filepath)


Epoch 2/50, time: 2024-09-28 06:22:53.485026
Batch 0, Disc Loss: 1.802703857421875 , Gen Loss: 0.01972830295562744, time: 2024-09-28 06:22:56.049227
Batch 10, Disc Loss: 1.6517943143844604 , Gen Loss: 0.018953239545226097, time: 2024-09-28 06:23:20.754145
Batch 20, Disc Loss: 1.5161113739013672 , Gen Loss: 0.018829479813575745, time: 2024-09-28 06:23:44.890361
Batch 30, Disc Loss: 1.4110040664672852 , Gen Loss: 0.019493063911795616, time: 2024-09-28 06:24:08.904911
Batch 40, Disc Loss: 1.3152350187301636 , Gen Loss: 0.0188888106495142, time: 2024-09-28 06:24:32.909126
Batch 50, Disc Loss: 1.231011152267456 , Gen Loss: 0.018261725082993507, time: 2024-09-28 06:24:56.945845
Batch 60, Disc Loss: 1.1598132848739624 , Gen Loss: 0.017635570839047432, time: 2024-09-28 06:25:21.210705
Batch 70, Disc Loss: 1.0946106910705566 , Gen Loss: 0.01710994727909565, time: 2024-09-28 06:25:45.195347
Batch 80, Disc Loss: 1.0358434915542603 , Gen Loss: 0.016650313511490822, time: 2024-09-28 06:26:09.139197

In [None]:
# Load a test photo as a (1, H, W, 3) RGB batch in [0, 1] at the generator's input size.
image_bgr = cv2.imread("rose.jpg")
if image_bgr is None:  # cv2.imread returns None (no exception) for unreadable paths
  raise FileNotFoundError("Could not read rose.jpg")
# OpenCV decodes to BGR; the generator was trained on RGB CIFAR-10 images.
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
# antialias avoids aliasing when shrinking a large photo down to 32x32.
image = tf.image.resize(image_rgb / 255.0, input_shape[:-1], antialias=True)
# NOTE: training inputs were additionally blurred with downscale_image(); apply it
# here as well if the inference input should mimic that degradation.
image = tf.expand_dims(image, axis=0)
plt.imshow(image[0])

In [None]:
# Super-resolve the prepared input. The result gets its own name so the 32x32
# `image` is preserved and this cell can be re-run (overwriting `image` would feed
# the 128x128 output back into the 32x32 generator).
sr_image = srgan.generator.predict(image)
# Clip to imshow's valid float range; out-of-range pixels would otherwise be
# clipped with a warning.
plt.imshow(np.clip(sr_image[0], 0.0, 1.0))