In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Flatten, Dense,LeakyReLU,Conv2D,Reshape,Conv2DTranspose,Lambda,BatchNormalization,Dropout,ReLU,Concatenate,UpSampling2D
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow_addons.layers import InstanceNormalization

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(1)

2 Physical GPUs, 2 Logical GPUs


In [2]:
img_path = os.path.join('./apple2orange/')
img_list = os.listdir(img_path)
print(img_list)

['trainA', 'testB', 'trainB', 'testA']


In [3]:
data_A = []
data_B = []
arr = ['trainA','trainB']
for i in range(len(arr)):
    path = img_path + arr[i]
    img_dir = os.listdir(path)
    for j in range(len(img_dir)):
        img = cv2.imread(path + "/" + img_dir[j])
        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        img = cv2.resize(img,(128,128),interpolation = cv2.INTER_CUBIC)
        if i == 0:
            data_A.append(img)
        else:
            data_B.append(img)

data_A = np.array(data_A)
data_B = np.array(data_B)
data_A = data_A / 255.0
data_B = data_B / 255.0
print(data_A.shape, data_B.shape)

(995, 128, 128, 3) (1019, 128, 128, 3)


In [4]:
## make A,B, G_AB, G_BA

def build_descriminator(input_shape = (128,128,3)):
    input = Input(shape=input_shape)
    x = input
    stride = [2,2,2,1]
    filter = [32,64,128,256]
    for i in range(len(stride)):
        if i == 0:
            x = Conv2D(filters=filter[i],strides=stride[i],kernel_size=4, padding='same')(x)
            x = LeakyReLU()(x)
        else:
            x = Conv2D(filters=filter[i],strides=stride[i],kernel_size=4, padding='same')(x)
            x = InstanceNormalization(axis = -1, center = False, scale = False)(x)
            x = LeakyReLU()(x)
    output = Conv2D(filters=1,kernel_size=4,strides=1,padding='same')(x)
    model = Model(input,output)
    return model            


In [5]:
test = build_descriminator()
test.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 64, 64, 32)        1568      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        32832     
_________________________________________________________________
instance_normalization (Inst (None, 32, 32, 64)        0         
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 128)       131200

In [6]:
def downsample(layer_input, filters, f_size=4):
    d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
    d = InstanceNormalization(axis = -1, center = False, scale = False)(d)
    d = ReLU()(d)
    return d

def upsample(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    u = UpSampling2D(size=2)(layer_input)
    u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same')(u)
    u = InstanceNormalization(axis = -1, center = False, scale = False)(u)
    u = ReLU()(u)
    if dropout_rate:
        u = Dropout(dropout_rate)(u)
    u = Concatenate()([u, skip_input])
    return u

def build_generator_unet(input_shape=(128,128,3),filter=32):
    input = Input(shape=input_shape)
    
    d1 = downsample(input,filters=filter)
    d2 = downsample(d1,filters=filter * 2)
    d3 = downsample(d2,filters=filter * 4)
    d4 = downsample(d3,filters=filter * 8)

    u1 = upsample(d4,d3,filters=filter * 4)
    u2 = upsample(u1,d2,filters=filter * 2)
    u3 = upsample(u2,d1,filters=filter)

    u4 = UpSampling2D(size=2)(u3)
    output = Conv2D(filters=3, kernel_size=4, strides=1,padding='same')(u4)

    return Model(input,output)

In [7]:
test2 = build_generator_unet()
test2.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 32)   1568        input_2[0][0]                    
__________________________________________________________________________________________________
instance_normalization_3 (Insta (None, 64, 64, 32)   0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 64, 64, 32)   0           instance_normalization_3[0][0]   
____________________________________________________________________________________________

In [8]:
d_A = build_descriminator()
d_B = build_descriminator()
d_A.compile(loss = 'mse',optimizer = tf.keras.optimizers.Adam(learning_rate=0.00002,beta_1=0.5),metrics=['acc'])
d_B.compile(loss = 'mse',optimizer = tf.keras.optimizers.Adam(learning_rate=0.00002,beta_1=0.5),metrics=['acc'])

In [9]:
# generator G_AB, G_BA
# A -> B,  B -> A 
# apple => orange, orange => apple

g_AB = build_generator_unet()
g_BA = build_generator_unet()

d_A.trainable = False
d_B.trainable = False

#A = apple, B = orange
img_A = Input(shape=(128,128,3))
img_B = Input(shape=(128,128,3))

fake_A = g_BA(img_A)
fake_B = g_AB(img_B)

valid_A = d_A(fake_A)
valid_B = d_B(fake_B)

reconstruct_A = g_BA(fake_B)
reconstruct_B = g_AB(fake_A)

img_A_id = g_BA(img_A)
img_B_id = g_AB(img_B)

combined = Model(inputs=[img_A,img_B],outputs=[valid_A,valid_B,reconstruct_A,reconstruct_B,img_A_id,img_B_id])

combined.compile(loss=['mse','mse','mae','mae','mae','mae'],
                 loss_weights=[1,1,10,10,2,2])

In [20]:
def train_discriminators(imgs_A, imgs_B, valid, fake):

        # Translate images to opposite domain
        fake_B = g_AB.predict(imgs_A)
        fake_A = g_BA.predict(imgs_B)

        # Train the discriminators (original images = real / translated = Fake)
        dA_loss_real = d_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = d_A.train_on_batch(imgs_A, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = d_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = d_B.train_on_batch(imgs_B, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        # Total disciminator loss
        d_loss_total = 0.5 * np.add(dA_loss, dB_loss)

        return (
            d_loss_total[0]
            , dA_loss[0], dA_loss_real[0], dA_loss_fake[0]
            , dB_loss[0], dB_loss_real[0], dB_loss_fake[0]
            , d_loss_total[1]
            , dA_loss[1], dA_loss_real[1], dA_loss_fake[1]
            , dB_loss[1], dB_loss_real[1], dB_loss_fake[1]
        )

In [21]:
def train_generators(imgs_A, imgs_B, valid):

    return combined.train_on_batch([imgs_A, imgs_B],
                                                [valid, valid,
                                                imgs_A, imgs_B,
                                                imgs_A, imgs_B])

In [22]:
def train_normal(batch_size,img_row,epochs,print_every_n_batches = 10):
    
    d_lossess = []
    g_lossess = []
    patch = int(img_row / 2 ** 3)
    disc_patch = (patch,patch,1)
    
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    epoch = 0

    for epoch in range(epoch,epochs):
        for i in range(100):

            data_batch_A = data_A[batch_size * epoch:batch_size * epoch + 1]
            data_batch_B = data_B[batch_size * epoch:batch_size * epoch + 1]
            d_loss = train_discriminators(data_batch_A, data_batch_B, valid, fake)
            g_loss = train_generators(data_batch_A, data_batch_B, valid)

            if i % 10 == 0:
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] "\
                        % ( epoch, epochs,
                            i, min(len(data_A),len(data_B)),
                            d_loss[0], 100*d_loss[7],
                            g_loss[0],
                            np.sum(g_loss[1:3]),
                            np.sum(g_loss[3:5]),
                            np.sum(g_loss[5:7]),
                            ))
                sample_images(i)
            d_lossess.append(d_loss)
            g_lossess.append(g_loss)     
        
    return d_lossess,g_lossess         

In [23]:
def sample_images(batch_i):
        
    r, c = 2, 4
    arr = ['trainA', 'testB', 'trainB', 'testA']
    for p in range(2):

        if p == 1:
            np.random.seed(42)
            path_A = img_path + arr[0]
            path_B = img_path + arr[2]
            img_list_A = os.listdir(path_A)
            img_list_B = os.listdir(path_B)
            imgs_A = cv2.imread(path_A + "/" + img_list_A[np.random.randint(0,len(img_list_A))])
            imgs_A = cv2.cvtColor(imgs_A,cv2.COLOR_BGR2RGB)
            imgs_A = cv2.resize(imgs_A,(128,128),interpolation = cv2.INTER_CUBIC)
            imgs_B = cv2.imread(path_B + "/" + img_list_B[np.random.randint(0,len(img_list_B))])
            imgs_B = cv2.cvtColor(imgs_B,cv2.COLOR_BGR2RGB)
            imgs_B = cv2.resize(imgs_B,(128,128),interpolation = cv2.INTER_CUBIC)
        else:
            np.random.seed(42)
            path_A = img_path + arr[0]
            path_B = img_path + arr[2]
            img_list_A = os.listdir(path_A)
            img_list_B = os.listdir(path_B)
            imgs_A = cv2.imread(path_A + "/" + img_list_A[np.random.randint(0,len(img_list_A))])
            imgs_A = cv2.cvtColor(imgs_A,cv2.COLOR_BGR2RGB)
            imgs_A = cv2.resize(imgs_A,(128,128),interpolation = cv2.INTER_CUBIC)
            imgs_B = cv2.imread(path_B + "/" + img_list_B[np.random.randint(0,len(img_list_B))])
            imgs_B = cv2.cvtColor(imgs_B,cv2.COLOR_BGR2RGB)
            imgs_B = cv2.resize(imgs_B,(128,128),interpolation = cv2.INTER_CUBIC)

        imgs_A = imgs_A.reshape((1,128,128,3))
        imgs_B = imgs_B.reshape((1,128,128,3))

        imgs_A = imgs_A / 255.0
        imgs_B = imgs_B / 255.0
            # Translate images to the other domain
        fake_B = g_AB.predict(imgs_A)
        fake_A = g_BA.predict(imgs_B)
            # Translate back to original domain
        reconstr_A = g_BA.predict(fake_B)
        reconstr_B = g_AB.predict(fake_A)

            # ID the images
        id_A = g_BA.predict(imgs_A)
        id_B = g_AB.predict(imgs_B)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, id_A, imgs_B, fake_A, reconstr_B, id_B])

            # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5
        gen_imgs = np.clip(gen_imgs, 0, 1)

        titles = ['Original', 'Translated', 'Reconstructed', 'ID']
        fig, axs = plt.subplots(r, c, figsize=(25,12.5))
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i,j].axis('off')
                cnt += 1
        plt.close()

In [19]:
d_losses,g_losses = train_normal(1,128,5)

[Epoch 0/5] [Batch 0/995] [D loss: 0.301325, acc:   8%] [G loss: 4.091131, adv: 0.376502, recon: 0.303146, id: 0.341583] 
[Epoch 0/5] [Batch 10/995] [D loss: 0.300359, acc:   8%] [G loss: 4.337029, adv: 0.395026, recon: 0.325746, id: 0.342274] 
[Epoch 0/5] [Batch 20/995] [D loss: 0.299454, acc:   8%] [G loss: 5.007251, adv: 0.377204, recon: 0.384688, id: 0.391583] 
[Epoch 0/5] [Batch 30/995] [D loss: 0.298584, acc:   7%] [G loss: 3.664577, adv: 0.380269, recon: 0.275238, id: 0.265964] 
[Epoch 0/5] [Batch 40/995] [D loss: 0.297759, acc:   7%] [G loss: 3.351560, adv: 0.335259, recon: 0.245854, id: 0.278881] 
[Epoch 0/5] [Batch 50/995] [D loss: 0.296946, acc:   7%] [G loss: 3.768960, adv: 0.394864, recon: 0.278904, id: 0.292530] 
[Epoch 0/5] [Batch 60/995] [D loss: 0.296203, acc:   6%] [G loss: 3.076727, adv: 0.328759, recon: 0.223722, id: 0.255375] 
[Epoch 0/5] [Batch 70/995] [D loss: 0.295457, acc:   6%] [G loss: 4.539282, adv: 0.304641, recon: 0.349170, id: 0.371470] 
[Epoch 0/5] [Batc

KeyboardInterrupt: 