<a href="https://colab.research.google.com/github/radanim/HM-AIRS-PROGRAM/blob/master/SRGAN_DIV2K.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf

In [None]:
# Load the "train" and "validation" splits of the DIV2K bicubic-x4 dataset.
# With as_supervised=True each element is a 2-tuple, which the rest of the
# notebook unpacks as (lr, hr).
train, valid = tfds.load(
    name="div2k/bicubic_x4",
    split=["train", "validation"],
    as_supervised=True,
)

def preprocessing(lr, hr):
    """Build one (low-res, high-res) patch pair from a high-resolution image.

    The ``lr`` image supplied by the dataset is not used; the low-resolution
    patch is produced by downscaling the cropped high-resolution patch.

    Args:
        lr: Low-resolution image from the dataset (ignored).
        hr: High-resolution image; it is cast to float32 and divided by 255.

    Returns:
        A ``(lr_patch, hr_patch)`` tuple: the 24x24 bicubic downscale of the
        crop, and the random 96x96x3 crop itself.
    """
    patch_size = 96
    scale = 4

    hr_scaled = tf.cast(hr, tf.float32) / 255
    # Large images are reduced to a randomly located (96, 96, 3) region.
    hr_patch = tf.image.random_crop(hr_scaled, size=[patch_size, patch_size, 3])

    # Shrink width and height of the crop to 1/4; this low-resolution patch
    # is what the SRGAN generator receives as input.
    lr_side = patch_size // scale
    lr_patch = tf.image.resize(hr_patch, [lr_side, lr_side], "bicubic")
    return lr_patch, hr_patch

# Training pipeline: patch extraction, shuffle (buffer of 10), endless
# repetition, batches of 8.
train = (
    train.map(preprocessing)
    .shuffle(buffer_size=10)
    .repeat()
    .batch(8)
)
# Validation pipeline: same as above but without shuffling.
valid = (
    valid.map(preprocessing)
    .repeat()
    .batch(8)
)

Downloading and preparing dataset 3.97 GiB (download: 3.97 GiB, generated: Unknown size, total: 3.97 GiB) to ~/tensorflow_datasets/div2k/bicubic_x4/2.0.0...
EXTRACTING {'train_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip', 'valid_lr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip', 'train_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', 'valid_hr_url': 'https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip'}


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/800 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/div2k/bicubic_x4/2.0.0.incompleteVPOAEN/div2k-train.tfrecord*...:   0%|       …

In [None]:
# k9n64s1: conv layer 내의 hyperparameter 설정 정보
# k: kernel size
# n: 사용한 필터수
# s: stride
# stride 1인 conv layer은 패딩을 통해 출력의 크기를 계속 유지 
# SRGAN은 Sequential API로 구현할 수 없고, Functional API를 통해서만 구현이 가능 

In [None]:
from tensorflow.keras import Input, Model, layers

# 파란색 블록을 정의
def gene_base_block(x):
    """Residual block of the generator (the blue block in the SRGAN diagram).

    Applies Conv -> BatchNorm -> PReLU -> Conv -> BatchNorm and then adds the
    block input back onto the result (skip connection).

    Args:
        x: Input feature map. ``Add`` requires it to have 64 channels so that
            it matches the output of the convolutions.

    Returns:
        The sum of ``x`` and the output of the second BatchNormalization.
    """
    out = layers.Conv2D(64, 3, 1, "same")(x)
    out = layers.BatchNormalization()(out)
    out = layers.PReLU(shared_axes=[1, 2])(out)
    # The second convolution must consume the activated features. Feeding it
    # `x` again would leave the first Conv/BN/PReLU disconnected from the
    # block output.
    out = layers.Conv2D(64, 3, 1, "same")(out)
    out = layers.BatchNormalization()(out)
    return layers.Add()([x, out])

In [None]:
# 뒤쪽 연두색 블록을 정의
def upsample_block(x):
    """Upsampling block of the generator (the light-green block in the diagram).

    Args:
        x: Input feature map.

    Returns:
        The PReLU-activated result of a 256-filter 3x3 convolution followed by
        ``tf.nn.depth_to_space`` with block size 2.
    """
    features = layers.Conv2D(256, 3, 1, "same")(x)
    # This is the part labelled "PixelShuffler" in the diagram: channels are
    # rearranged into 2x2 spatial blocks.
    shuffled = layers.Lambda(lambda t: tf.nn.depth_to_space(t, 2))(features)
    activation = layers.PReLU(shared_axes=[1, 2])
    return activation(shuffled)

In [None]:
# Generator 정의
def get_generator(input_shape=(None, None, 3)):
    """Assemble the SRGAN generator with the Keras Functional API.

    Args:
        input_shape: Shape of one input image without the batch axis.

    Returns:
        A ``Model`` mapping the input image to a 3-channel, tanh-activated
        output that has passed through two ``upsample_block`` stages.
    """
    inputs = Input(input_shape)

    # Head: 9x9 convolution + PReLU. Its output is also the long skip branch.
    head = layers.Conv2D(64, 9, 1, "same")(inputs)
    head = layers.PReLU(shared_axes=[1, 2])(head)

    # Body: five residual blocks, then Conv + BatchNorm.
    body = head
    for _ in range(5):
        body = gene_base_block(body)
    body = layers.Conv2D(64, 3, 1, "same")(body)
    body = layers.BatchNormalization()(body)
    merged = layers.Add()([head, body])

    # Tail: two upsampling blocks followed by the output convolution.
    upsampled = merged
    for _ in range(2):
        upsampled = upsample_block(upsampled)

    final_conv = layers.Conv2D(3, 9, 1, "same", activation="tanh")
    return Model(inputs, final_conv(upsampled))
    

In [None]:
# 생성 고해상도 이미지와 원본 고해상도 이미지 사이를 판별해내는 Discriminator 구현 
# Generator처럼 Functional API 사용 

# 그림의 파란색 블록을 정의합니다.
def disc_base_block(x, n_filters=128):
    """Discriminator building block (the blue block in the diagram).

    Two Conv -> BatchNorm -> LeakyReLU stages; the first uses stride 1 and
    the second uses stride 2.

    Args:
        x: Input feature map.
        n_filters: Number of filters for both convolutions.

    Returns:
        The output of the second LeakyReLU.
    """
    for stride in (1, 2):
        x = layers.Conv2D(n_filters, 3, stride, "same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)
    return x

# 전체 Discriminator 정의합니다.
def get_discriminator(input_shape=(None, None, 3)):
    """Assemble the SRGAN discriminator with the Keras Functional API.

    Args:
        input_shape: Shape of one input image without the batch axis.

    Returns:
        A ``Model`` mapping an image to sigmoid scores.
    """
    inputs = Input(input_shape)
    # The first convolution uses a fixed 64 filters, matching the strided
    # convolution that follows. The previous `global n_filters` lookup raised
    # NameError because no such global exists at this point.
    out = layers.Conv2D(64, 3, 1, "same")(inputs)
    out = layers.LeakyReLU()(out)
    out = layers.Conv2D(64, 3, 2, "same")(out)
    out = layers.BatchNormalization()(out)
    out = layers.LeakyReLU()(out)

    # `n_filters` is now an ordinary local loop variable.
    for n_filters in [128, 256, 512]:
        out = disc_base_block(out, n_filters)

    # NOTE(review): there is no Flatten/pooling before these Dense layers, so
    # they act on the last axis only and the output keeps its spatial
    # dimensions. Confirm this per-location score map is intended.
    out = layers.Dense(1024)(out)
    out = layers.LeakyReLU()(out)
    out = layers.Dense(1, activation="sigmoid")(out)
    return Model(inputs, out)

In [None]:
# SRGAN은 VGG19로 content loss 계산, 텐서플로에서 이미지넷 데이터에서 VGG19 제공 

from tensorflow.keras.applications import VGG19
def get_feature_extractor(input_shape=(None, None, 3)):
    """Build the VGG19-based feature extractor used for the content loss.

    Args:
        input_shape: Shape of one input image without the batch axis.

    Returns:
        A ``Model`` from the VGG19 input to the output of ``vgg.layers[20]``.
    """
    # Use the directly imported `VGG19`; the name `applications` is never
    # imported in this notebook, so `applications.vgg19.VGG19` raised NameError.
    vgg = VGG19(
        include_top=False,
        weights="imagenet",
        input_shape=input_shape,
    )
    # vgg.layers[20] is the last conv layer inside VGG19 (no top).
    return Model(vgg.input, vgg.layers[20].output)

In [None]:
!pip install --upgrade tensorflow

In [None]:
!pip install --upgrade keras

In [None]:
import keras
# Show the installed TensorFlow and Keras versions, one per line.
print(tf.__version__, keras.__version__, sep="\n")

In [None]:
# SRGAN 학습

from tensorflow.keras import losses, metrics, optimizers

# Instantiate the three networks used during training.
generator = get_generator()
discriminator = get_discriminator()
vgg = get_feature_extractor()

# Define the loss functions and optimizers.
# The discriminator ends in a sigmoid, hence from_logits=False.
bce = losses.BinaryCrossentropy(from_logits=False)
# The Keras class is `MeanSquaredError`; `MeanSquareError` does not exist
# and raised AttributeError.
mse = losses.MeanSquaredError()
gene_opt = optimizers.Adam()
disc_opt = optimizers.Adam()

def get_gene_loss(fake_out):
    """Adversarial loss of the generator.

    Args:
        fake_out: Discriminator output for generated images.

    Returns:
        Binary cross-entropy between ``fake_out`` and an all-ones target,
        i.e. the generator is rewarded when fakes are scored as real.
    """
    return bce(tf.ones_like(fake_out), fake_out)


def get_disc_loss(real_out, fake_out):
    """Loss of the discriminator.

    Args:
        real_out: Discriminator output for real high-resolution images.
        fake_out: Discriminator output for generated images.

    Returns:
        Sum of the cross-entropy of ``real_out`` against ones and of
        ``fake_out`` against zeros.
    """
    real_term = bce(tf.ones_like(real_out), real_out)
    fake_term = bce(tf.zeros_like(fake_out), fake_out)
    return real_term + fake_term

@tf.function
def get_content_losS(hr_real, hr_fake):
    """VGG19 feature-space content loss between real and generated images.

    Args:
        hr_real: Batch of real high-resolution images.
        hr_fake: Batch of generated high-resolution images.

    Returns:
        Mean squared error between the two VGG feature maps, each divided
        by 12.75.
    """
    # `applications` is not imported anywhere in the notebook; go through
    # the already-imported `tf` namespace instead.
    # NOTE(review): the training pipeline scales images by 1/255 before this
    # call - confirm that is the value range intended for preprocess_input.
    preprocess = tf.keras.applications.vgg19.preprocess_input
    hr_real = preprocess(hr_real)
    hr_fake = preprocess(hr_fake)

    hr_real_feature = vgg(hr_real) / 12.75
    hr_fake_feature = vgg(hr_fake) / 12.75
    return mse(hr_real_feature, hr_fake_feature)

@tf.function
def step(lr, hr_real):
    """Run one training step that updates both generator and discriminator.

    Args:
        lr: Batch of low-resolution inputs for the generator.
        hr_real: Batch of matching real high-resolution images.

    Returns:
        A ``(perceptual_loss, discriminator_loss)`` tuple for this batch.
    """
    with tf.GradientTape() as gene_tape, tf.GradientTape() as disc_tape:
        hr_fake = generator(lr, training=True)

        real_out = discriminator(hr_real, training=True)
        fake_out = discriminator(hr_fake, training=True)

        # Generator objective: content loss plus a small adversarial term.
        perceptual_loss = get_content_losS(hr_real, hr_fake) + 1e-3 * get_gene_loss(fake_out)
        discriminator_loss = get_disc_loss(real_out, fake_out)

    # The gradients were previously never computed, so the names used in the
    # apply_gradients calls below did not exist.
    gene_gradient = gene_tape.gradient(perceptual_loss, generator.trainable_variables)
    disc_gradient = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)

    gene_opt.apply_gradients(zip(gene_gradient, generator.trainable_variables))
    # The discriminator is updated with its own optimizer, not gene_opt.
    disc_opt.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
    return perceptual_loss, discriminator_loss

# Running means of the two losses, reset at the end of every epoch.
gene_losses = metrics.Mean()
disc_losses = metrics.Mean()

for epoch in range(1, 2):
    for i, (lr, hr) in enumerate(train):
        g_loss, d_loss = step(lr, hr)

        gene_losses.update_state(g_loss)
        disc_losses.update_state(d_loss)

        # Print the running mean losses every 10 iterations.
        if (i+1) % 10 == 0:
            print(f"EPOCH[{epoch}] - STEP[{i+1}] \nGenerator_loss:{gene_losses.result():.4f} \nDiscriminator_loss:{disc_losses.result():.4f}", end="\n\n")

        # Stop the epoch after 200 steps.
        if (i+1) == 200:
            break

    # `reset_state()` is the current Metric API; the older `reset_states()`
    # spelling is deprecated and no longer available in recent Keras.
    gene_losses.reset_state()
    disc_losses.reset_state()

In [None]:
# SRGAN 테스트 
# 테스트에서는 Generator만 이용함.
# Generator의 역할은 저해상도 입력을 이용해서 고해상도 이미지를 출력 

import numpy as np

def apply_srgan(image):
    """Super-resolve a single image and return it as a uint8 array.

    Args:
        image: One image without a batch axis.

    Returns:
        ``numpy.ndarray`` of dtype uint8: the model prediction clipped to
        [0, 255], rounded, with the batch axis removed.
    """
    # Add the batch axis expected by `predict`.
    image = tf.cast(image[np.newaxis, ...], tf.float32)
    # NOTE(review): `srgan` is not defined in the visible notebook cells; it
    # presumably refers to a trained/loaded generator model - confirm.
    sr = srgan.predict(image)
    # Lower bound is the number 0 (was the undefined name `o`).
    sr = tf.clip_by_value(sr, 0, 255)
    sr = tf.round(sr)
    # Correct dtype name is `uint8` (was the typo `unit8`).
    sr = tf.cast(sr, tf.uint8)
    return np.array(sr)[0]

# Reload the raw splits, since `train`/`valid` were replaced by the
# patch-based training pipelines earlier in the notebook.
train, valid = tfds.load(
    name="div2k/bicubic_x4",
    split=["train", "validation"],
    as_supervised=True,
)

# Advance through the validation split until index 6; `lr` and `hr` keep
# the values of that element once the loop stops.
for i, (lr, hr) in enumerate(valid):
    if i >= 6:
        break

# Super-resolve the selected low-resolution image.
srgan_hr = apply_srgan(lr)

In [None]:
# 이미지 전체를 시각화하면 세부적인 것의 비교가 어려우므로, 일부 영역만 잘라서 비교
# bicubic interpolation vs SRGAN vs ORIGIN 비교
# 자세한 시각화를 위해 3개 영역 잘라내기
# 잘라낸 부분의 좌상단 좌표 3개 

# NOTE(review): `crop` and `bicubic_hr` were used below but are not defined
# in any visible notebook cell; they are defined here so this cell can run.
# Confirm they match the intended helper / baseline.
def crop(image, left_top, x=200, y=200):
    """Return the ``x`` by ``y`` region of ``image`` starting at ``left_top``.

    Args:
        image: Array-like image indexed as ``[row, column, channel]``.
        left_top: ``(row, column)`` index of the region's top-left corner.
        x: Number of rows to keep.
        y: Number of columns to keep.
    """
    row, col = left_top
    return image[row:row + x, col:col + y, :]


# Bicubic-interpolation baseline: enlarge the low-resolution image to the
# size of the original. cv2 expects dsize as (width, height).
hr_height, hr_width = np.asarray(hr).shape[:2]
bicubic_hr = cv2.resize(
    np.asarray(lr),
    dsize=(hr_width, hr_height),
    interpolation=cv2.INTER_CUBIC,
)

# Top-left corners of the three regions to compare.
left_tops = [(400,500), (300,1200), (0,1000)]

# For every region collect the bicubic, SRGAN and original crops, in that order.
images = []
for left_top in left_tops:
    img1 = crop(bicubic_hr, left_top, 200, 200)
    img2 = crop(srgan_hr, left_top, 200, 200)
    img3 = crop(hr, left_top, 200, 200)
    images.extend([img1, img2, img3])

labels = ["Bicubic", "SRGAN", "HR"] * 3

# 3x3 grid: one row per region, one column per method.
plt.figure(figsize=(18,18))
for i in range(9):
    plt.subplot(3,3,i+1)
    plt.imshow(images[i])
    plt.title(labels[i], fontsize=30)
    
# SRGAN은 SRCNN보다 깊은 conv layer 사용, GAN과 VGG 구조를 활용하여 손실함수 사용, 복잡한 학습과정 구성 