In [6]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from custom_layer import *
from image_loader import *
import os
import matplotlib.pyplot as plt
import time
from IPython import display
import cv2
import tensorflow_addons as tfa
from utils_from_gp import pbar

# Expose only CUDA device '0' to this process.
# NOTE(review): this assignment runs after `import tensorflow` above; the
# variable is normally expected to be set before the CUDA runtime is
# initialised — confirm it still takes effect here, or move it above the imports.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [14]:
class WGANGP:
    """WGAN-GP style trainer: a U-Net-like generator plus a convolutional critic.

    The generator is assembled in three stages (`generator_first`,
    `generator_second`, `build_generator`). Every intermediate tensor is
    appended to a shared `tensor_list`, and later stages pick up skip
    connections *by index* into that list, so the append order is load-bearing.

    Attributes set in `__init__`:
        epochs, batch_size, image_size, total_images: training-loop sizes.
        n_critic: critic updates performed per generator update.
        grad_penalty_weight: multiplier applied to the gradient penalty.
        g_opt, d_opt: Adam optimizers for generator / critic.
        G, g_tl: generator model and its list of intermediate tensors.
        D, d_tl: critic model and its list of intermediate tensors.
    """

    def __init__(self):
        """Set hyper-parameters, create the optimizers and build both networks."""
        self.epochs = 100
        self.batch_size = 3
        self.image_size = 512
        self.total_images = 4500
        self.n_critic = 5
        self.grad_penalty_weight = 10.0
        self.g_opt = tf.keras.optimizers.Adam(1e-4)
        self.d_opt = tf.keras.optimizers.Adam(1e-4)

        # Derive the symbolic inputs from the attributes above instead of
        # repeating the literals, so the two can never drift apart.
        # NOTE(review): the 8x8 VALID convolution in `generator_second` looks
        # tied to image_size == 512 — confirm before changing image_size.
        input_shape = (self.image_size, self.image_size, 3)
        self.G, self.g_tl = self.build_generator(
            keras.Input(input_shape, batch_size=self.batch_size))
        self.D, self.d_tl = self.build_discriminator(
            keras.Input(input_shape, batch_size=self.batch_size))

    # ------------------------------------------------------------------
    # Private building blocks. Each appends its tensors to `tensor_list`
    # in exactly the order conv -> activation -> normalisation, because
    # the generator addresses skip connections by list index.
    # NOTE(review): the conv helpers come from `custom_layer`; the argument
    # order is presumably (filters, kernel_size, stride) — confirm there.
    # ------------------------------------------------------------------
    def _conv_selu_bn(self, x, tensor_list, filters, kernel_size, stride):
        """Apply conv2d_layer_same -> SELU -> BatchNormalization; return the BN output."""
        conv = conv2d_layer_same(filters, kernel_size, stride)(x)
        tensor_list.append(conv)
        act = keras.activations.selu(conv)
        tensor_list.append(act)
        bn = keras.layers.BatchNormalization()(act)
        tensor_list.append(bn)
        return bn

    def _conv_resize_concat(self, x, tensor_list, filters, skip_index):
        """Apply conv -> resize(2) -> concat with `tensor_list[skip_index]` -> SELU -> BN."""
        conv = conv2d_layer_same(filters, 3, 1)(x)
        tensor_list.append(conv)
        resized = exe_resize_layer(conv, 2)
        tensor_list.append(resized)
        merged = tf.concat([resized, tensor_list[skip_index]], axis=-1)
        tensor_list.append(merged)
        act = keras.activations.selu(merged)
        tensor_list.append(act)
        bn = keras.layers.BatchNormalization()(act)
        tensor_list.append(bn)
        return bn

    def _conv_lrelu_in(self, x, tensor_list, filters, kernel_size, stride):
        """Apply conv2d_layer_same -> LeakyReLU -> InstanceNormalization (critic block)."""
        conv = conv2d_layer_same(filters, kernel_size, stride)(x)
        tensor_list.append(conv)
        act = keras.layers.LeakyReLU()(conv)
        tensor_list.append(act)
        norm = tfa.layers.InstanceNormalization()(act)
        tensor_list.append(norm)
        return norm

    def _critic_training(self, x):
        """Run the critic in training mode (callable handed to `gradient_penalty`)."""
        return self.D(x, training=True)

    # First group of the Generator, built with the functional API.
    def generator_first(self, input_):
        """Build the first five conv/SELU/BN blocks of the generator.

        Args:
            input_: symbolic Keras input tensor.

        Returns:
            (model, tensor_list): a Keras model from `input_` to the last BN
            output, and the list of all 16 tensors created so far
            (index 0 is `input_`, index 15 is the last BN output).
        """
        tensor_list = [input_]
        x = input_
        for filters, kernel_size, stride in (
                (16, 3, 1), (32, 5, 2), (64, 5, 2), (128, 5, 2), (128, 5, 2)):
            x = self._conv_selu_bn(x, tensor_list, filters, kernel_size, stride)

        model = keras.Model(inputs=[input_], outputs=x)
        return model, tensor_list

    def generator_second(self, input_):
        """Extend the first stage with the global-feature branch.

        Corresponds to the second part of the U-Net generator in the paper's
        code (9 layers). The final tensor (index 24 of the returned list) is
        the global image feature later used for the global concat.

        Args:
            input_: symbolic Keras input tensor.

        Returns:
            The shared `tensor_list`, now holding 25 tensors.
        """
        _, tensor_list = self.generator_first(input_)
        x = tensor_list[-1]

        x = self._conv_selu_bn(x, tensor_list, 128, 5, 2)
        x = self._conv_selu_bn(x, tensor_list, 128, 5, 2)

        hidden3 = conv2d_layer_valid(128, 8, 1)(x)
        tensor_list.append(hidden3)
        act3 = keras.activations.selu(hidden3)
        tensor_list.append(act3)

        # Extract the image's global feature used by the global concat.
        hidden4 = conv2d_layer_valid(128, 1, 1)(act3)
        tensor_list.append(hidden4)

        return tensor_list

    # Third layer group of the U-Net generator (30 layers). It relies on the
    # custom global-concat, resize and residual layers from `custom_layer`.
    def build_generator(self, input_):
        """Assemble the full generator.

        Args:
            input_: symbolic Keras input tensor.

        Returns:
            (model, tensor_list): the generator model mapping `input_` to the
            residual output, and the list of every intermediate tensor.
        """
        tensor_list = self.generator_second(input_)
        x = tensor_list[15]  # last BN output of generator_first

        # conv group 1: conv + global concat with the global feature (index 24)
        hidden1 = conv2d_layer_same(128, 3, 1)(x)
        tensor_list.append(hidden1)
        gc1 = exe_global_concat_layer(hidden1, tensor_list, 24)
        tensor_list.append(gc1)

        # conv group 2
        x = self._conv_selu_bn(gc1, tensor_list, 128, 1, 1)

        # conv groups 3-6: conv, resize by 2, then concatenate with the
        # encoder conv outputs at indices 10, 7, 4 and 1 respectively.
        for filters, skip_index in ((128, 10), (128, 7), (64, 4), (32, 1)):
            x = self._conv_resize_concat(x, tensor_list, filters, skip_index)

        # conv group 7
        x = self._conv_selu_bn(x, tensor_list, 16, 3, 1)

        # conv group 8: 3-channel conv followed by the custom residual layer
        hidden8 = conv2d_layer_same(3, 3, 1)(x)
        tensor_list.append(hidden8)
        res = exe_res_layer(hidden8, tensor_list, 0, [0, 1, 2])

        model = keras.Model(inputs=tensor_list[0], outputs=res)
        return model, tensor_list

    def build_discriminator(self, input_):
        """Assemble the critic: six conv/LeakyReLU/InstanceNorm blocks, a
        final VALID conv, and a mean over all non-batch axes.

        Args:
            input_: symbolic Keras input tensor.

        Returns:
            (model, tensor_list): the critic model producing one score per
            batch element, and the list of every intermediate tensor.
        """
        tensor_list = [input_]
        x = input_
        for filters, kernel_size, stride in (
                (16, 3, 1), (32, 5, 2), (64, 5, 2),
                (128, 5, 2), (128, 5, 2), (128, 5, 2)):
            x = self._conv_lrelu_in(x, tensor_list, filters, kernel_size, stride)

        # hidden 7
        # NOTE(review): the other call sites look like (filters, kernel, stride);
        # if so this is a 1x1 kernel with stride 16 — confirm that a 16x16
        # kernel with stride 1 was not intended.
        hidden7 = conv2d_layer_valid(1, 1, 16)(x)
        tensor_list.append(hidden7)
        # Reduce-mean "layer", written inline because it is short.
        # Arguments are, in order: tensor, axes, keepdims.
        rml = tf.reduce_mean(hidden7, [1, 2, 3], keepdims=False)
        tensor_list.append(rml)

        model = keras.Model(inputs=input_, outputs=rml)
        return model, tensor_list

    def img_L2_loss(self, img1, img2):
        """Return the mean squared difference between `img1` and `img2`."""
        return tf.reduce_mean(tf.square(tf.subtract(img1, img2)))

    def d_loss_fn(self, f_logit, r_logit):
        """Critic loss: mean(fake scores) - mean(real scores)."""
        f_loss = tf.reduce_mean(f_logit)
        r_loss = tf.reduce_mean(r_logit)
        return f_loss - r_loss

    def g_loss_fn(self, f_logit):
        """Generator loss: negative mean of the critic's scores on fakes."""
        return -tf.reduce_mean(f_logit)

    def gradient_penalty(self, f, real, fake):
        """Gradient penalty evaluated on random interpolates of `real` and `fake`.

        Args:
            f: callable mapping an image batch to critic scores.
            real, fake: image batches of identical shape (rank 4).

        Returns:
            Scalar tensor: mean over the batch of (||grad f(inter)||_2 - 1)^2.
        """
        # One mixing coefficient per sample. The batch size is read from the
        # data rather than from self.batch_size so a differently sized batch
        # does not break the broadcast.
        alpha = tf.random.uniform([tf.shape(real)[0], 1, 1, 1], 0., 1.)
        inter = real + alpha * (fake - real)
        with tf.GradientTape() as t:
            t.watch(inter)
            pred = f(inter)
        grad = t.gradient(pred, [inter])[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]))
        return tf.reduce_mean((slopes - 1.) ** 2)

    @tf.function
    def train_g(self, raw):
        """Run one generator update on the batch `raw`; return the generator loss."""
        with tf.GradientTape() as t:
            x_fake = self.G(raw, training=True)
            fake_logits = self.D(x_fake, training=True)
            loss = self.g_loss_fn(fake_logits)
        grad = t.gradient(loss, self.G.trainable_variables)
        self.g_opt.apply_gradients(zip(grad, self.G.trainable_variables))

        return loss

    @tf.function
    def train_d(self, raw, clean):
        """Run one critic update; return the critic loss including the penalty.

        Args:
            raw: generator input batch.
            clean: batch of real target images.
        """
        with tf.GradientTape() as t:
            x_fake = self.G(raw, training=True)
            fake_logits = self.D(x_fake, training=True)
            real_logits = self.D(clean, training=True)

            cost = self.d_loss_fn(fake_logits, real_logits)
            # A bound method replaces functools.partial, which is not
            # imported anywhere visible in this notebook.
            gp = self.gradient_penalty(self._critic_training, clean, x_fake)

            cost += self.grad_penalty_weight * gp

        grad = t.gradient(cost, self.D.trainable_variables)
        self.d_opt.apply_gradients(zip(grad, self.D.trainable_variables))

        return cost

    def _save_checkpoint_and_samples(self, raw, epoch):
        """Save the generator and a row of generated preview images for `epoch`."""
        keras.models.save_model(self.G, "gp_images/gp_model_epoch_%2d.h5" % epoch)

        # Draw exactly one model-sized batch, cast the same way as in training.
        sample_idx = np.random.randint(self.total_images, size=self.batch_size)
        sample_data = raw[sample_idx].astype(np.float32)
        samples = self.G(sample_data, training=False)

        print(np.max(samples[0]), "\n", np.min(samples[0]))
        for i in range(self.batch_size):
            plt.subplot(1, self.batch_size, i + 1)
            # Clip before the uint8 cast so out-of-range generator outputs
            # saturate instead of wrapping around.
            plt.imshow(tf.cast(tf.clip_by_value(samples[i], 0.0, 255.0), np.uint8))
        plt.savefig("gp_images/epoch_%2d_sample.png" % epoch)
        plt.show()
        plt.close()

    def train(self, raw, clean):
        """Train critic and generator on paired image arrays.

        For every batch the critic is updated `n_critic` times and the
        generator once. Images are cast to float32 and fed in their stored
        value range (no rescaling is applied here). Every 10th epoch
        (excluding epoch 0) the generator and a sample figure are written
        to `gp_images/`.

        Args:
            raw: array of generator inputs, indexable along axis 0.
            clean: array of matching real targets.
        """
        g_train_loss = keras.metrics.Mean()
        d_train_loss = keras.metrics.Mean()

        # Make sure the output directory exists before hours of training
        # end in a failed save.
        os.makedirs("gp_images", exist_ok=True)

        for epoch in range(self.epochs):
            bar = pbar(self.total_images, self.batch_size, epoch, self.epochs)
            for batch in range(0, self.total_images, self.batch_size):
                raw_data = raw[batch:batch + self.batch_size].astype(np.float32)
                clean_data = clean[batch:batch + self.batch_size].astype(np.float32)

                # Exactly n_critic critic steps, then exactly one generator
                # step; every step's loss is recorded.
                for _ in range(self.n_critic):
                    d_train_loss(self.train_d(raw_data, clean_data))
                g_train_loss(self.train_g(raw_data))

                bar.postfix['g_loss'] = f'{g_train_loss.result():6.3f}'
                bar.postfix['d_loss'] = f'{d_train_loss.result():6.3f}'
                bar.update(self.batch_size)

            g_train_loss.reset_states()
            d_train_loss.reset_states()

            bar.close()
            del bar

            if epoch % 10 == 0 and epoch != 0:
                self._save_checkpoint_and_samples(raw, epoch)

In [15]:
# Load the paired image arrays and keep the first 4500 of each.
iml = Loader()
raw, clean = iml.load_from_npy_old()
raw, clean = raw[:4500], clean[:4500]

# Swap the channel order of every target image (RGB -> BGR) and stack the
# results back into one array.
# NOTE(review): only `clean` is converted, `raw` is left as loaded — confirm
# the two arrays are stored with different channel orders.
clean_after = np.array([cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in clean])

In [16]:
# Instantiate the trainer; the constructor builds the generator and critic models.
wgan_gp = WGANGP()

In [None]:
# Start training: `raw` feeds the generator, the channel-swapped `clean_after`
# images are the real samples shown to the critic.
wgan_gp.train(raw,clean_after)

Epoch 1/100


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=4500.0), HTML(value='')), layout=Layout(d…