In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Collect the full path of every file found beneath the Kaggle input directory.
file_list = [
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
]
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import re
def photo_jpg(x):
    """Return True when the string ``x`` contains the text 'photo_jpg'."""
    return re.search('photo_jpg', x) is not None
def monet_jpg(x):
    """Return True when the string ``x`` contains the text 'monet_jpg'."""
    return re.search('monet_jpg', x) is not None

# Split the collected paths by image domain using the two predicates above.
photo_path_list = [path for path in file_list if photo_jpg(path)]
monet_path_list = [path for path in file_list if monet_jpg(path)]

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data, datasets
import random
import matplotlib.pyplot as plt
import itertools
from torchvision import transforms

#Make Dataset

In [None]:
from PIL import Image

class GANDataset(torch.utils.data.Dataset):
    """Dataset that serves one photo and one Monet image per index.

    Index ``i`` returns the ``i``-th photo path and the ``i``-th Monet path,
    so the length is that of the shorter list.

    Args:
        photo_path_list: sequence of photo image paths.
        monet_path_list: sequence of Monet image paths.
        transform: callable applied to each loaded PIL image.
    """

    def __init__(self, photo_path_list, monet_path_list, transform):
        self.photo_path_list = photo_path_list
        self.monet_path_list = monet_path_list
        self.transform = transform
        self.path = {'Photo': photo_path_list, 'Monet': monet_path_list}

    def __len__(self):
        """Number of (photo, Monet) items, limited by the shorter path list."""
        return min(len(self.path['Photo']), len(self.path['Monet']))

    def _load_rgb(self, domain, idx):
        """Open the ``idx``-th image of ``domain`` and return it as an RGB image."""
        # Image.open is lazy; convert() reads the pixel data, so the file can be
        # closed as soon as the block exits. Forcing RGB guarantees three
        # channels whatever mode the file was stored in.
        with Image.open(self.path[domain][idx]) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``{'Photo': ..., 'Monet': ...}`` holding the transformed images."""
        imgP = self._load_rgb('Photo', idx)
        imgM = self._load_rgb('Monet', idx)
        return {'Photo': self.transform(imgP), 'Monet': self.transform(imgM)}


class ImgTransform(object):
    """Callable preprocessing pipeline: ``ToTensor`` followed by ``Normalize``.

    The scalar ``mean`` and ``std`` are repeated for each of three channels.
    """

    def __init__(self, mean, std):
        channel_means = (mean,) * 3
        channel_stds = (std,) * 3
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(channel_means, channel_stds),
        ]
        self.data_transform = transforms.Compose(steps)

    def __call__(self, img):
        """Apply the composed transform to ``img`` and return the result."""
        return self.data_transform(img)

In [None]:
# Preview one Monet sample. ImgTransform(0.5, 0.5) normalises pixels to [-1, 1],
# so map them back to [0, 1] before display; otherwise imshow clips every
# negative value and the preview looks wrong.
preview_monet = GANDataset(photo_path_list, monet_path_list, ImgTransform(0.5, 0.5))[15]['Monet']
plt.imshow(preview_monet.permute(1, 2, 0) * 0.5 + 0.5);

### Define models
Reference: https://github.com/davidADSP/GDL_code/blob/master/models/cycleGAN.py

# Define Resnet Generator model

In [None]:
class ResidualBlock(nn.Module):
    """Skip-connected pair of reflect-padded convolutions.

    The residual branch is conv -> instance norm -> ReLU -> conv -> instance
    norm; its output is added to the block input.
    """

    def __init__(self, num_filter, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Both convolutions share exactly the same configuration.
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding,
                           padding_mode='reflect', bias=True)
        layers = [
            nn.Conv2d(num_filter, num_filter, **conv_kwargs),
            nn.InstanceNorm2d(num_filter),
            nn.ReLU(),
            nn.Conv2d(num_filter, num_filter, **conv_kwargs),
            nn.InstanceNorm2d(num_filter),
        ]
        self.residual = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        shortcut = x
        return shortcut + self.residual(x)
    
class ConvBlock(torch.nn.Module):
    """Conv2d followed by optional instance normalisation and an activation.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size, stride, padding: forwarded to ``torch.nn.Conv2d``.
        activation: ``'relu'``, ``'tanh'`` or ``None`` (no activation).
        batch_norm: when True the normalisation layer is applied. Despite the
            name, the layer is ``InstanceNorm2d``.

    Raises:
        ValueError: if ``activation`` is not one of the supported values.
    """

    _SUPPORTED_ACTIVATIONS = ('relu', 'tanh', None)

    def __init__(self,in_dim, out_dim, kernel_size=3, stride=2, padding=1, activation='relu', batch_norm=True):
        super(ConvBlock,self).__init__()
        # Fail at construction time: an unknown activation used to make
        # forward() silently return None.
        if activation not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(
                "Unsupported activation {0!r}; expected 'relu', 'tanh' or None".format(activation))
        self.conv = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(out_dim)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        self.tanh = torch.nn.Tanh()

    def forward(self,x):
        """Apply conv, optional normalisation, then the configured activation."""
        x = self.conv(x)
        if self.batch_norm:
            x = self.bn(x)
        if self.activation == 'relu':
            return self.relu(x)
        if self.activation == 'tanh':
            return self.tanh(x)
        # No activation requested: hand back the (normalised) conv output.
        return x
            
class DeconvBlock(torch.nn.Module):
    """ConvTranspose2d followed by optional instance normalisation and an activation.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size, stride, padding, output_padding: forwarded to
            ``torch.nn.ConvTranspose2d``.
        activation: ``'relu'``, ``'tanh'`` or ``None`` (no activation).
        batch_norm: when True the normalisation layer is applied. Despite the
            name, the layer is ``InstanceNorm2d``.

    Raises:
        ValueError: if ``activation`` is not one of the supported values.
    """

    _SUPPORTED_ACTIVATIONS = ('relu', 'tanh', None)

    def __init__(self,in_dim,out_dim,kernel_size=3,stride=2,padding=1,output_padding=1,activation='relu',batch_norm=True):
        super(DeconvBlock,self).__init__()
        if activation not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(
                "Unsupported activation {0!r}; expected 'relu', 'tanh' or None".format(activation))
        self.deconv = torch.nn.ConvTranspose2d(in_dim,out_dim,kernel_size,stride,padding,output_padding)
        self.batch_norm = batch_norm
        self.bn = torch.nn.InstanceNorm2d(out_dim)
        self.activation = activation
        self.relu = torch.nn.ReLU(True)
        self.tanh = torch.nn.Tanh()

    def forward(self,x):
        """Apply the transposed conv, optional normalisation, then the activation."""
        x = self.deconv(x)
        # Reuse one variable throughout: the previous code only bound its
        # result name inside this branch, so batch_norm=False crashed.
        if self.batch_norm:
            x = self.bn(x)
        if self.activation == 'relu':
            return self.relu(x)
        if self.activation == 'tanh':
            return self.tanh(x)
        return x


In [None]:
class Resnet_Generator(nn.Module):
    """Generator built from three ConvBlocks, residual blocks and a decoder.

    Layout: reflection pad + 7x7 ConvBlock, two default ConvBlocks,
    ``num_resnet`` ResidualBlocks at ``4 * num_filter`` channels, two
    DeconvBlocks, then reflection pad + 7x7 ConvBlock with tanh and no
    normalisation.
    """

    def __init__(self, input_dim=3, output_dim=3, num_filter=64, num_resnet=9):
        super(Resnet_Generator, self).__init__()
        # One padding module serves both 7x7 convolutions (first and last).
        self.pad = torch.nn.ReflectionPad2d(3)

        # Encoder
        self.conv1 = ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=0)
        self.conv2 = ConvBlock(num_filter, num_filter * 2)
        self.conv3 = ConvBlock(num_filter * 2, num_filter * 4)

        # Residual trunk
        trunk = [ResidualBlock(4 * num_filter) for _ in range(num_resnet)]
        self.resnet_blocks = torch.nn.Sequential(*trunk)

        # Decoder; the final stage is a plain ConvBlock despite its name.
        self.deconv1 = DeconvBlock(num_filter * 4, num_filter * 2)
        self.deconv2 = DeconvBlock(num_filter * 2, num_filter)
        self.deconv3 = ConvBlock(num_filter, output_dim,
                                 kernel_size=7, stride=1, padding=0, activation='tanh', batch_norm=False)

    def forward(self, x):
        """Run ``x`` through the encoder, residual trunk and decoder."""
        encoded = self.conv3(self.conv2(self.conv1(self.pad(x))))
        transformed = self.resnet_blocks(encoded)
        decoded = self.deconv2(self.deconv1(transformed))
        return self.deconv3(self.pad(decoded))

# Define U-net Generator model

In [None]:
#downsample
def downsample(in_dim, out_dim, act_fn, f_size=4):
    """Build a stride-2 Conv2d -> InstanceNorm2d -> ``act_fn`` stage.

    ``f_size`` is the convolution kernel size; ``act_fn`` is used as given.
    """
    conv = nn.Conv2d(in_dim, out_dim, kernel_size=f_size, stride=2, padding=1)
    norm = nn.InstanceNorm2d(out_dim)
    return nn.Sequential(conv, norm, act_fn)

In [None]:
#upsample without concatenate
def upsample(in_dim, out_dim, act_fn, f_size=4):
    """Build a stride-2 ConvTranspose2d -> InstanceNorm2d -> ``act_fn`` stage.

    No skip concatenation happens here; ``f_size`` is the kernel size.
    """
    deconv = nn.ConvTranspose2d(in_dim, out_dim, kernel_size=f_size, stride=2, padding=1)
    norm = nn.InstanceNorm2d(out_dim)
    return nn.Sequential(deconv, norm, act_fn)
    

In [None]:
class Unet_Generator(nn.Module):
    """U-Net style generator with skip connections.

    Four ``downsample`` stages feed three ``upsample`` stages. Each upsample
    output is concatenated channel-wise with the matching downsample output.
    A nearest-neighbour 2x upsample and a 3x3 conv with tanh produce the result.
    """

    def __init__(self, in_dim=3, out_dim=3, num_filter=64):
        super(Unet_Generator, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        nf = num_filter
        relu = nn.ReLU()  # one activation module shared by every stage

        # Contracting path
        self.down_1 = downsample(in_dim, nf * 1, relu)
        self.down_2 = downsample(nf * 1, nf * 2, relu)
        self.down_3 = downsample(nf * 2, nf * 4, relu)
        self.down_4 = downsample(nf * 4, nf * 8, relu)

        # Expanding path; up_2 and up_3 take twice the channels because their
        # input is a concatenation with a skip tensor.
        self.up_1 = upsample(nf * 8, nf * 4, relu)
        self.up_2 = upsample(nf * 8, nf * 2, relu)
        self.up_3 = upsample(nf * 4, nf * 1, relu)
        self.up_4 = nn.Upsample(scale_factor=2, mode='nearest')

        self.out = nn.Sequential(
            nn.Conv2d(nf * 2, out_dim, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, input):
        """Encode ``input``, decode with skip concatenations, return the tanh output."""
        skip_1 = self.down_1(input)
        skip_2 = self.down_2(skip_1)
        skip_3 = self.down_3(skip_2)
        bottleneck = self.down_4(skip_3)

        merged = torch.cat([self.up_1(bottleneck), skip_3], dim=1)
        merged = torch.cat([self.up_2(merged), skip_2], dim=1)
        merged = torch.cat([self.up_3(merged), skip_1], dim=1)
        return self.out(self.up_4(merged))

# Define Discriminator

In [None]:
def conv4(in_dim, num_filter, stride, act_fn, norm=True):
    """Build a 4x4 convolution stage.

    Args:
        in_dim: number of input channels.
        num_filter: number of output channels.
        stride: convolution stride.
        act_fn: activation module appended as the last layer.
        norm: when True, insert ``InstanceNorm2d`` between conv and activation.

    Returns:
        ``nn.Sequential`` of conv, optional norm, and ``act_fn``.
    """
    # The kernel size is always 4. A local derived from the stride used to be
    # computed here and never used; it has been removed.
    layers = [nn.Conv2d(in_dim, num_filter, kernel_size=4, stride=stride, padding=1)]
    if norm:
        layers.append(nn.InstanceNorm2d(num_filter))
    layers.append(act_fn)
    return nn.Sequential(*layers)

In [None]:
# define discriminator

class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a single-channel score map.

    Four ``conv4`` stages (the first without normalisation, the last with
    stride 1) are followed by a 4x4 convolution down to one channel.
    """

    def __init__(self, in_dim=3, num_filter=64):
        super(Discriminator, self).__init__()
        self.in_dim = in_dim
        self.num_filter = num_filter
        leaky = nn.LeakyReLU(0.2)  # shared by all four stages

        self.layer_1 = conv4(in_dim, num_filter, 2, leaky, norm=False)
        self.layer_2 = conv4(num_filter, num_filter * 2, 2, leaky)
        self.layer_3 = conv4(num_filter * 2, num_filter * 4, 2, leaky)
        self.layer_4 = conv4(num_filter * 4, num_filter * 8, 1, leaky)

        self.out = nn.Conv2d(num_filter * 8, 1, kernel_size=4, stride=1, padding=0)

    def forward(self, input):
        """Return the score map produced for ``input``."""
        features = input
        for stage in (self.layer_1, self.layer_2, self.layer_3, self.layer_4):
            features = stage(features)
        return self.out(features)

In [None]:
# Training hyper-parameters.
batch_size = 1
epochs = 130

# Photo/Monet dataset; ImgTransform(0.5, 0.5) normalises every channel with mean 0.5 and std 0.5.
train_dataset = GANDataset(
    photo_path_list=photo_path_list,
    monet_path_list=monet_path_list,
    transform=ImgTransform(mean=0.5, std=0.5),
)

# Shuffle the items on every pass over the dataset.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
#reference: https://medium.com/humanscape-tech/ml-practice-cyclegan-f9153ef72297

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Generators: in the training loop netG is fed photos and netF is fed Monet images.
# (Replace Resnet_Generator() with Unet_Generator() to try the U-Net variant.)
netG, netF = Resnet_Generator().to(device), Resnet_Generator().to(device)

# Discriminators: netDP scores photo-domain images, netDM scores Monet-domain images.
netDP, netDM = Discriminator().to(device), Discriminator().to(device)

#initializing network
def init_weights(m):
    """Initialise every convolution and affine instance-norm layer inside ``m``.

    ``m`` may be a single layer or a whole network; all sub-modules (and ``m``
    itself) are visited. Convolution and transposed-convolution weights are
    drawn from N(0, 0.02) with biases set to 0. Instance-norm layers that have
    learnable parameters get weights from N(1, 0.02) and biases set to 0.

    Matching is by layer type rather than class-name substring. A name test
    skips top-level networks entirely and mis-matches wrapper classes such as
    ``ConvBlock`` that have no ``weight`` of their own.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                  nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    norm_types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)

    for layer in m.modules():
        if isinstance(layer, conv_types):
            nn.init.normal_(layer.weight, 0.0, 0.02)
            if layer.bias is not None:  # bias=False convolutions have no bias
                nn.init.constant_(layer.bias, 0)
        elif isinstance(layer, norm_types):
            # With affine=False (the default) weight and bias are None.
            if layer.weight is not None:
                nn.init.normal_(layer.weight, 1.0, 0.02)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)


# Run init_weights on both generators and both discriminators.
for network in (netG, netF, netDP, netDM):
    init_weights(network)

In [None]:
class LambdaLR():
    """Learning-rate multiplier that is 1.0 until a start epoch, then falls linearly.

    The multiplier reaches 0.0 when ``epoch + offset`` equals ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        decay_span = n_epochs - decay_start_epoch
        assert (decay_span > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplier for ``epoch``."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_span = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_span

In [None]:
# Losses: squared error for the adversarial terms, absolute error for the
# cycle-consistency and identity terms.
ganLoss = nn.MSELoss().to(device)
cycleConsistencyLoss = nn.L1Loss().to(device)
identityLoss = nn.L1Loss().to(device)

# Adam optimisers: the two generators share one optimiser, each discriminator has its own.
initial_learning_rate = 0.0002

optimG = torch.optim.Adam(
    itertools.chain(netG.parameters(), netF.parameters()),
    lr=initial_learning_rate,
    betas=(0.5, 0.999),
)
optimDP = torch.optim.Adam(
    netDP.parameters(),
    lr=initial_learning_rate,
    betas=(0.5, 0.999),
)
optimDM = torch.optim.Adam(
    netDM.parameters(),
    lr=initial_learning_rate,
    betas=(0.5, 0.999),
)

# Schedulers: each scales its optimiser's learning rate by LambdaLR(epochs, 0, 100).step,
# i.e. a constant rate up to epoch 100 followed by a linear decay.
schedulerG = torch.optim.lr_scheduler.LambdaLR(optimG, lr_lambda=LambdaLR(epochs, 0, 100).step)
schedulerDP = torch.optim.lr_scheduler.LambdaLR(optimDP, lr_lambda=LambdaLR(epochs, 0, 100).step)
schedulerDM = torch.optim.lr_scheduler.LambdaLR(optimDM, lr_lambda=LambdaLR(epochs, 0, 100).step)

In [None]:
# Inverse of Normalize(mean=0.5, std=0.5): a mean of -1 and a std of 2 map each
# channel value x to (x + 1) / 2.
denormalize = transforms.Normalize(mean=(-1.0,) * 3, std=(2.0,) * 3)

In [None]:
import time
import datetime
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Per-epoch loss history, stored as plain Python floats so it can be plotted
# directly and does not keep autograd graphs (or GPU memory) alive.
epoch_lossDM = []
epoch_lossDP = []
epoch_lossGen = []
epoch_ = []
save_path_list = []  # generator checkpoints written below, oldest first
start=time.time()

if len(train_dataloader) == 0:
    raise RuntimeError('train_dataloader yields no batches; check photo_path_list and monet_path_list.')

for epoch in range(1, epochs + 1):
    # train start
    netG.train()
    netF.train()
    netDP.train()
    netDM.train()

    # Running sums of the per-batch losses (detached), averaged after the epoch.
    sum_lossGen = 0.0
    sum_lossDP = 0.0
    sum_lossDM = 0.0
    num_batches = 0

    for batch, data in enumerate(train_dataloader, 1):
        # forward path
        realP = data['Photo'].to(device)
        realM = data['Monet'].to(device)

        # Validity : Whether the image created by each Generator tricks the discriminator
        fakeM = netG(realP)
        fakeP = netF(realM)

        # loss by discriminator
        discFakeP = netDP(fakeP)
        discFakeM = netDM(fakeM)

        lossG = ganLoss(discFakeM, torch.ones_like(discFakeM))
        lossF = ganLoss(discFakeP, torch.ones_like(discFakeP))

        # Reconstruction : Cycle reconstruction for original image
        reconP = netF(fakeM)
        reconM = netG(fakeP)

        lossPositiveCycle = cycleConsistencyLoss(reconP, realP)
        lossNegativeCycle = cycleConsistencyLoss(reconM, realM)

        # Identity : Whether each Generator maintain their target domain image's identity
        identityM = netG(realM)
        identityP = netF(realP)

        lossIdentityG = identityLoss(identityM, realM)
        lossIdentityF = identityLoss(identityP, realP)

        # Generator's total loss (all six terms are summed unweighted)
        lossGen = (lossG + lossF) + (lossPositiveCycle + lossNegativeCycle) + (lossIdentityG + lossIdentityF)

        optimG.zero_grad()
        lossGen.backward()
        optimG.step()

        ######## Discriminator for Photo ########
        # detach() keeps the discriminator update from back-propagating into the generators
        discFakeP = netDP(fakeP.detach())
        discRealP = netDP(realP)

        # torch.zeros_like() -> fake label, torch.ones_like() -> real label
        lossDFakeP = ganLoss(discFakeP, torch.zeros_like(discFakeP))
        lossDRealP = ganLoss(discRealP, torch.ones_like(discRealP))

        lossDP = (lossDFakeP + lossDRealP) * 0.5
        optimDP.zero_grad()
        lossDP.backward()
        optimDP.step()

        ######## Discriminator for Monet ########
        discFakeM = netDM(fakeM.detach())
        discRealM = netDM(realM)

        lossDFakeM = ganLoss(discFakeM, torch.zeros_like(discFakeM))
        lossDRealM = ganLoss(discRealM, torch.ones_like(discRealM))

        lossDM = (lossDFakeM + lossDRealM) * 0.5
        optimDM.zero_grad()
        lossDM.backward()
        optimDM.step()

        # Accumulate detached values only; holding the loss tensors themselves
        # would keep each batch's computation graph in memory.
        sum_lossGen += lossGen.detach()
        sum_lossDP += lossDP.detach()
        sum_lossDM += lossDM.detach()
        num_batches = batch

    schedulerG.step()
    schedulerDP.step()
    schedulerDM.step()

    #Saving models
    if epoch % 10 == 0:
        os.makedirs('cycleGANmodel', exist_ok=True)
        save_path = 'cycleGANmodel/Photo2Monet_generator_{0}.pt'.format(epoch)
        save_path_list.append(save_path)
        torch.save(netG.state_dict(), save_path)

    # Record the mean loss over the epoch as plain floats.
    mean_lossGen = float(sum_lossGen / num_batches)
    mean_lossDP = float(sum_lossDP / num_batches)
    mean_lossDM = float(sum_lossDM / num_batches)
    epoch_lossGen.append(mean_lossGen)
    epoch_lossDP.append(mean_lossDP)
    epoch_lossDM.append(mean_lossDM)
    epoch_.append(epoch)
    clear_output(wait = True)
    print("epoch : {0}  lossGen : {1} lossDP : {2} lossDM : {3}".format(epoch, mean_lossGen, mean_lossDP, mean_lossDM))

    
# Convert the recorded losses to plain floats before plotting. The history may
# hold tensors still attached to the autograd graph (possibly on the GPU),
# which matplotlib cannot turn into arrays.
loss_curves = {
    'lossGen': [float(value) for value in epoch_lossGen],
    'lossDP': [float(value) for value in epoch_lossDP],
    'lossDM': [float(value) for value in epoch_lossDM],
}

fig = plt.figure(figsize=(8,8))
fig.set_facecolor('white')
ax = fig.add_subplot()

for curve_label, curve_values in loss_curves.items():
    ax.plot(epoch_, curve_values, label=curve_label)
ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('loss')

plt.show()

# Elapsed wall-clock time since `start`, printed without the fractional seconds.
end = time.time()-start
times = str(datetime.timedelta(seconds=end)).split(".")
print('Finished in {0}'.format(times[0]))

# Plot 5 sample photos translated to Monet style by each saved generator checkpoint

In [None]:
from torchvision.utils import save_image
import os
import zipfile
#model_path = 'cycleGANmodel'

# Draw the sample photos once so every checkpoint is shown on the same inputs.
num_sample = min(5, len(photo_path_list))
random_sample = random.sample(photo_path_list, num_sample)
sample_transform = ImgTransform(mean=0.5, std=0.5)  # built once, reused for every image

for file in save_path_list:
    print(file)
    # map_location lets the checkpoint load even if it was saved on another device.
    netG.load_state_dict(torch.load('./{0}'.format(file), map_location=device))
    netG.eval()

    fakeM_ReverseNormed_list = []

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for photo_path in random_sample:
            with Image.open(photo_path) as imgP:
                imgP_transformed = sample_transform(imgP.convert('RGB')).unsqueeze(dim=0)
            imgP_transformed = imgP_transformed.to(device)

            fakeM = netG(imgP_transformed).to('cpu')
            fakeM_ReverseNormed = denormalize(fakeM)
            # (1, C, H, W) -> (H, W, C) for imshow
            fakeM_ReverseNormed = np.transpose(fakeM_ReverseNormed.numpy().squeeze(), (1,2,0))
            fakeM_ReverseNormed_list.append(fakeM_ReverseNormed)

    plt.figure(figsize=(50,50))
    for num, img in enumerate(fakeM_ReverseNormed_list):
        plt.subplot(1, num_sample, num+1)
        plt.axis('off')
        plt.xticks([])
        plt.yticks([])
        plt.imshow(img)

    plt.tight_layout()
    plt.subplots_adjust(left = 0, bottom = 0, right = 1, top = 1, hspace = 0, wspace = 0)
    plt.show()

# Save the fake Monet images to [model_name]/ and pack them into images.zip

In [None]:
from torchvision.utils import save_image
import os
import zipfile
import shutil
import time
import datetime
import glob
model_path = 'cycleGANmodel'
start = time.time()

# Checkpoints are only written every 10th epoch, so the list can be empty.
if not save_path_list:
    raise RuntimeError('No generator checkpoint was saved during training; nothing to export.')

last_path = save_path_list[-1]
output_dir = last_path[:-3]  # checkpoint path without its '.pt' suffix
netG.load_state_dict(torch.load('./{0}'.format(last_path), map_location=device))
netG = netG.to(device)
netG.eval()
os.makedirs(output_dir, exist_ok=True)

transform = ImgTransform(mean=0.5, std=0.5)  # built once, reused for every photo

# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    for cnt, photo_path in enumerate(photo_path_list, 1):
        with Image.open(photo_path) as imgP:
            imgP_transformed = transform(imgP.convert('RGB')).unsqueeze(dim=0)
        imgP_transformed = imgP_transformed.to(device)
        fakeM = netG(imgP_transformed).to('cpu')
        fakeM_ReverseNormed = denormalize(fakeM).squeeze(0)
        # Keep the source file name as-is; a regex on the path can truncate
        # names containing characters such as '_' or '-'.
        filename = os.path.basename(photo_path)
        save_image(fakeM_ReverseNormed, '{0}/{1}'.format(output_dir, filename))
        print('Saved transformed images by last model to {0}'.format(output_dir))
        print('{0}/{1} completed'.format(cnt, len(photo_path_list)))
        clear_output(wait = True)
    



# Directory that holds the generated images: the checkpoint path without '.pt'.
image_dir = last_path[:-3]
jpg_names = [name for name in os.listdir(image_dir) if name.endswith('.jpg')]

# The context manager closes the archive even if a write fails part-way through.
with zipfile.ZipFile('{0}.zip'.format('images'), "w") as zip_file:
    for jpg_name in jpg_names:
        zip_file.write(os.path.join(image_dir, jpg_name), jpg_name, compress_type=zipfile.ZIP_DEFLATED)

# Remove the loose images now that they are stored in the archive.
for jpg_name in jpg_names:
    os.remove(os.path.join(image_dir, jpg_name))

end = time.time() - start
times = str(datetime.timedelta(seconds=end)).split(".")
print('Finished in {0}'.format(times[0]))
