In [None]:
# Use IPython's built-in tab completer instead of jedi.
%config Completer.use_jedi = False

In [None]:
import os
import datetime
import glob
import time
import cv2
import itertools
from tqdm.notebook import tqdm
import shutil

from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision.transforms as transforms
from torchvision.utils import make_grid

## Data loader

In [None]:
# Dataset root: Monet paintings live in monet_jpg/, photos in photo_jpg/.
img_path = '../input/gan-getting-started/'
# glob's result order depends on the filesystem, so sort for a reproducible index order.
monet_path = sorted(glob.glob(os.path.join(img_path, 'monet_jpg', '*')))
photo_path = sorted(glob.glob(os.path.join(img_path, 'photo_jpg', '*')))

print('Dataset')
print(f'- monet data : {len(monet_path)}\n- photo data : {len(photo_path)}')
# Fail fast on a wrong path instead of an obscure StopIteration/ValueError later in the DataLoader.
assert monet_path and photo_path, f'no images found under {img_path!r}'

In [None]:
class Custom_dataset(Dataset):
    '''
    Unpaired Monet / photo image dataset.

    Args:
        img_path: two-element sequence ``[monet_paths, photo_paths]``, each a list of
            image file paths.
        transforms: callable applied to every loaded PIL RGB image (e.g. a torchvision
            ``transforms.Compose``). If ``None`` the PIL image is returned unchanged.
        mode: ``'train'`` -> item ``idx`` is ``(monet[idx], photo[random index])`` and the
              length is the number of Monet images;
              ``'test'``  -> item ``idx`` is ``photo[idx]`` and the length is the number
              of photos.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    '''
    def __init__(self, img_path : list , transforms = None, mode = 'train'):
        super().__init__()
        # Reject unknown modes up front; previously __len__ silently returned None.
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

        self.path_monet = img_path[0]
        self.path_photo = img_path[1]
        self.transforms = transforms
        self.mode = mode

    def _load(self, path):
        '''Open ``path`` as an RGB PIL image and apply ``self.transforms`` when set.'''
        # The context manager closes the file handle; convert() returns a new in-memory image.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __getitem__(self, idx):
        if self.mode == 'train':
            monet_img = self._load(self.path_monet[idx])
            # Unpaired training: each Monet image is paired with a uniformly random photo.
            photo_idx = np.random.randint(0, len(self.path_photo))
            photo_img = self._load(self.path_photo[photo_idx])
            return monet_img, photo_img

        # 'test' mode: one photo per index, consistent with __len__.
        return self._load(self.path_photo[idx])

    def __len__(self):
        if self.mode == 'train':
            return len(self.path_monet)
        return len(self.path_photo)

In [None]:
def show_img(img, nrow = 6):
    '''
    Display a batch of images as a single grid.

    Args:
        img: image batch accepted by ``make_grid`` (presumably a (B, C, H, W) tensor
            with values in [0, 1] — TODO confirm for callers other than the preview cell).
        nrow: number of images per grid row; the default 6 matches the preview
            loader's batch size.
    '''
    grid = make_grid(img, nrow = nrow)
    # detach() + cpu() so grid tensors that require grad or live on the GPU can be plotted.
    grid = grid.detach().cpu().permute([1, 2, 0]).numpy()
    plt.figure(figsize = (12, 8))
    plt.imshow(grid)
    plt.show()

In [None]:
# Small preview loader: ToTensor only (pixels in [0, 1]), no normalisation.
# Uses its own name instead of rebinding img_path, which holds the dataset root string.
sample_paths = [monet_path, photo_path]
sample_transform = transforms.Compose([
    transforms.ToTensor()
])
sample = Custom_dataset(img_path = sample_paths, transforms = sample_transform, mode = 'train')
sample_loader = DataLoader(sample, batch_size = 6)
sample_monet, sample_photo = next(iter(sample_loader))

In [None]:
# Preview one batch from each domain (sample_transform is ToTensor only, so pixels are in [0, 1]).
print('Monet data')
show_img(sample_monet)

print('Photo data')
show_img(sample_photo)

## Modeling

In [None]:
# Define Conv Block
'''
1. Conv_up
2. Conv_down
3. Residual_block
'''

# 1.Conv_up
class Conv_up(nn.Module):
    '''
    ConvTranspose2d - InstanceNorm2d - ReLU - (Dropout2d)

    Args:
        in_ch, out_ch: input / output channel counts.
        kernel_size, stride, padding, output_padding: forwarded to nn.ConvTranspose2d.
            With kernel_size=3, stride=2, padding=1, output_padding=1 the spatial
            size doubles.
        drop_out: if True, apply channel-wise dropout after the activation.
        p_drop: dropout probability (default 0.5).
    '''
    def __init__(self, in_ch, out_ch, kernel_size = 4, stride = 2,
                 padding = 1, output_padding = 1, drop_out = True, p_drop = 0.5):
        super().__init__()

        self.convT = nn.ConvTranspose2d(in_ch, out_ch,
                                        kernel_size = kernel_size,
                                        stride = stride,
                                        padding = padding,
                                        output_padding = output_padding,
                                        bias = False)
        self.instance_norm = nn.InstanceNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.drop_out = drop_out
        # Registered as a submodule (instead of being built inside forward) so that
        # model.eval() switches it off; a freshly constructed nn.Dropout2d is always
        # in training mode and would keep dropping channels at inference time.
        self.dropout = nn.Dropout2d(p_drop)

    def forward(self, x):
        x = self.convT(x)
        x = self.instance_norm(x)
        x = self.relu(x)
        if self.drop_out:
            x = self.dropout(x)

        return x

# 2. Conv_down
class Conv_down(nn.Module):
    '''
    Conv2d - (InstanceNorm2d) - ReLU / LeakyReLU

    Args:
        in_ch, out_ch: input / output channel counts.
        kernel_size, stride, padding: forwarded to nn.Conv2d. With the defaults
            (4, 2, 1) the spatial size is halved for even inputs.
        batch_Norm: if True (default), apply InstanceNorm2d after the convolution.
            The name is historical: the normalisation is instance norm, not batch norm.
        negative_slope: 0 (default) keeps the plain ReLU; a positive value (e.g. 0.2,
            the slope CycleGAN uses in its discriminator) switches to LeakyReLU.
    '''
    def __init__(self, in_ch, out_ch,
                 kernel_size = 4,
                 stride = 2,
                 padding = 1,
                 batch_Norm = True,
                 negative_slope = 0.0):
        super().__init__()

        self.conv = nn.Conv2d(in_ch, out_ch,
                              kernel_size = kernel_size,
                              stride = stride,
                              padding = padding,
                              bias = True)
        self.instance_norm = nn.InstanceNorm2d(out_ch)
        # Attribute names `relu` / `batch` are kept so existing code and pickles still match.
        self.relu = nn.LeakyReLU(negative_slope) if negative_slope else nn.ReLU()
        self.batch = batch_Norm

    def forward(self, x):
        x = self.conv(x)
        if self.batch:
            x = self.instance_norm(x)
        x = self.relu(x)

        return x
    
# 3. Residual_block
class Residual_block(nn.Module):
    '''
    Residual block computing ``x + F(x)`` with
    F = Conv2d (reflect padding, no bias) -> InstanceNorm2d -> ReLU
        -> Conv2d (reflect padding, no bias) -> InstanceNorm2d.

    The skip adds ``x`` unchanged, so callers must use ``in_ch == out_ch`` and a
    kernel/stride/padding combination that preserves spatial size (the defaults
    kernel_size=3, stride=1, padding=1 do).
    '''
    def __init__(self, in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1):
        super().__init__()

        def reflect_conv(channels_in):
            # Both convolutions share every hyper-parameter except the input width.
            return nn.Conv2d(channels_in, out_ch,
                             kernel_size = kernel_size,
                             stride = stride,
                             padding = padding,
                             padding_mode = 'reflect',
                             bias = False)

        layers = [
            reflect_conv(in_ch),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(),
            reflect_conv(out_ch),
            nn.InstanceNorm2d(out_ch),
        ]
        self.res = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.res(x)
        return x + residual

In [None]:
class Discriminator(nn.Module):
    '''
    PatchGAN-style discriminator (n = n_features):

        Conv_down(3  -> n,  stride 2, no instance norm)
        Conv_down(n  -> 2n, stride 2)
        Conv_down(2n -> 4n, stride 2)
        Conv_down(4n -> 8n, stride 1)
        Conv2d(8n -> 1, kernel 4, stride 1, padding 1, no bias)

    Activations come from Conv_down. A (B, 3, 256, 256) input yields a
    (B, 1, 30, 30) map of per-patch scores.
    '''
    def __init__(self, n_features = 64):
        super().__init__()

        widths = [n_features, n_features * 2, n_features * 4, n_features * 8]
        layers = [Conv_down(3, widths[0], batch_Norm = False)]

        pairs = list(zip(widths[:-1], widths[1:]))
        for i, (c_in, c_out) in enumerate(pairs):
            # Only the last feature block keeps the resolution (stride 1).
            layers.append(Conv_down(c_in, c_out, stride = 1 if i == len(pairs) - 1 else 2))

        layers.append(nn.Conv2d(widths[-1], 1, kernel_size = 4, stride = 1, padding = 1, bias = False))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
    
def test_discriminator():
    '''Smoke test: a 256x256 RGB input must produce a (1, 1, 30, 30) patch-score map.'''
    D = Discriminator()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        out = D(x)
    print(out.shape)
    # 256 -> 128 -> 64 -> 32 (stride-2 blocks) -> 31 (stride-1 block) -> 30 (final conv)
    assert out.shape == (1, 1, 30, 30), f'unexpected discriminator output shape {tuple(out.shape)}'
    print('Discriminator is OK')

test_discriminator()

In [None]:
class Generator(nn.Module):
    '''
    Residual (ResNet-style) image-to-image generator; n = n_features.

        stage                      out channels   kernel   stride
        ReflectionPad(3) + conv    n              7x7      1      (no instance norm)
        down                       2n             3x3      2
        down                       4n             3x3      2
        Residual_block x n_res     4n             3x3      1
        up (transposed conv)       2n             3x3      2      (Conv_up default drop_out=True)
        up (transposed conv)       n              3x3      2      (Conv_up default drop_out=True)
        ReflectionPad(3) + conv    3              7x7      1      -> Tanh

    For inputs whose height and width are divisible by 4 the output keeps the
    input's spatial size, with 3 channels and values in [-1, 1].
    '''
    def __init__(self, n_features = 64, n_res = 9):
        super().__init__()

        n = n_features
        stem = [
            nn.ReflectionPad2d(3),
            Conv_down(3, n, kernel_size = 7, stride = 1, padding = 0, batch_Norm = False),
        ]
        down = [
            Conv_down(n, n * 2, kernel_size = 3, stride = 2),
            Conv_down(n * 2, n * 4, kernel_size = 3, stride = 2),
        ]
        body = [Residual_block(n * 4, n * 4) for _ in range(n_res)]
        up = [
            Conv_up(n * 4, n * 2, kernel_size = 3, stride = 2, padding = 1),
            Conv_up(n * 2, n, kernel_size = 3, stride = 2, padding = 1),
        ]
        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(n, 3, kernel_size = 7, stride = 1, padding = 0, bias = False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stem, *down, *body, *up, *head)

    def forward(self, x):
        return self.main(x)
    
def test_generator():
    '''Smoke test: a 256x256 RGB input must map to a same-sized RGB output in [-1, 1].'''
    G = Generator()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        out = G(x)
    print(out.shape)
    assert out.shape == x.shape, f'unexpected generator output shape {tuple(out.shape)}'
    assert out.abs().max() <= 1, 'generator output must lie in the Tanh range [-1, 1]'
    print('Generator is OK')

test_generator()

In [None]:
# weight initialization 
def weight_init(m : nn.Module):
    '''
    DCGAN/CycleGAN-style "normal" initialisation of ``m`` and all of its submodules.

    - Conv / ConvTranspose layers: weight ~ N(0, 0.02), bias (if present) = 0.
    - InstanceNorm layers with affine parameters: weight ~ N(1, 0.02), bias = 0.
      The default ``nn.InstanceNorm2d(ch)`` has no affine parameters and is skipped.

    The submodule walk means ``weight_init(model)`` initialises the whole network.
    ``model.apply(weight_init)`` also works, although it redraws nested weights
    more than once, which is redundant but harmless.
    '''
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                  nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    norm_types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)

    # isinstance (not class-name matching) so wrapper blocks such as Conv_up/Conv_down,
    # which have no .weight of their own, are traversed instead of crashing.
    for module in m.modules():
        if isinstance(module, conv_types):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias.data, 0.0)
        elif isinstance(module, norm_types) and module.weight is not None:
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0.0)

## Train

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define Generators
# netG_A: monet (domain A) -> photo (domain B)
netG_A = Generator().to(device)
# netG_B: photo (domain B) -> monet (domain A)
netG_B = Generator().to(device)

# Define Discriminators
# netD_A judges real vs. generated Monet images (domain A); netD_B real vs. generated photos (domain B).
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

# weight initialization
# NOTE(review): these calls only change weights if weight_init visits the submodules of the
# model it receives (e.g. by iterating m.modules()); a weight_init that inspects only m's own
# class name does nothing when handed a whole Generator/Discriminator.
weight_init(netG_A)
weight_init(netG_B)
weight_init(netD_A)
weight_init(netD_B)

# setting
EPOCHS = 100
lr = 2e-4
# Adam beta coefficients, passed below as betas = (b1, b2)
b1 = 0.5
b2 = 0.999

# optimizer
# One Adam optimizer updates both generators jointly; a second one updates both discriminators.
netG_optim = optim.Adam(itertools.chain(netG_A.parameters(), netG_B.parameters()), lr = lr, betas = (b1, b2))
netD_optim = optim.Adam(itertools.chain(netD_A.parameters(), netD_B.parameters()), lr = lr, betas = (b1, b2))

In [None]:
# Data loader
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], 
                                [0.5, 0.5, 0.5])
])

# One (Monet, randomly drawn photo) pair per step; an epoch visits every Monet image once.
train_dataset = Custom_dataset([monet_path, photo_path], transforms = transform, mode = 'train')
train_loader = DataLoader(train_dataset, batch_size = 1, shuffle = True)

In [None]:
# loss function
'''
1. GAN loss - L2
2. Cycle loss - L1
3. Identity loss - L1
'''

# Adversarial loss: MSE against all-ones ("real") / all-zeros ("fake") targets (least-squares GAN objective).
GAN_LOSS = nn.MSELoss()
# Cycle-consistency loss: L1 between an image and its round-trip reconstruction (e.g. A -> B -> A).
Cycle_LOSS = nn.L1Loss()
# Identity loss: L1 between a generator's output on a target-domain image and that same image.
Identity_LOSS = nn.L1Loss()

def _set_requires_grad(nets, flag):
    '''Set ``requires_grad`` on every parameter of each module in ``nets``.'''
    for net in nets:
        for param in net.parameters():
            param.requires_grad_(flag)


def train_model(EPOCHS = EPOCHS, lambda_cycle = 10, lambda_identity = 5, save_every = 10):
    '''
    Train the CycleGAN generators and discriminators.

    Relies on notebook-level globals: netG_A (A -> B), netG_B (B -> A), netD_A, netD_B,
    netG_optim, netD_optim, train_loader (yields (A, B) = (monet, photo) batches),
    device, GAN_LOSS, Cycle_LOSS and Identity_LOSS.

    Each batch performs
      1. a generator step with the discriminators frozen:
         GAN loss + lambda_cycle * cycle L1 + lambda_identity * identity L1;
      2. a discriminator step on real images vs. the detached fakes from step 1.

    Args:
        EPOCHS: number of passes over train_loader.
        lambda_cycle: weight of the cycle-consistency loss (default 10).
        lambda_identity: weight of the identity loss (default 5).
        save_every: every ``save_every`` epochs the full model objects are saved to
            /kaggle/working/model_G and /kaggle/working/model_D; 0 disables saving.

    Returns:
        dict with keys 'G_losses' and 'D_losses': lists of per-epoch mean losses
        (Python floats).
    '''
    train_hist = {'G_losses': [], 'D_losses': []}
    discriminators = (netD_A, netD_B)
    n_batches = len(train_loader)

    print('train is starting')

    for epoch in range(EPOCHS):
        t = time.time()

        for net in (netG_A, netG_B, netD_A, netD_B):
            net.train()

        # Plain floats: accumulating the loss tensors themselves would chain every
        # batch's autograd graph into one graph kept alive for the whole epoch.
        G_losses = 0.0
        D_losses = 0.0

        for A, B in train_loader:
            # A : monet, B : photo
            A, B = A.to(device), B.to(device)

            # ---- Generator update ----
            A2B = netG_A(A)      # fake B
            B2A = netG_B(B)      # fake A
            A2B2A = netG_B(A2B)  # reconstructed A
            B2A2B = netG_A(B2A)  # reconstructed B

            # Freeze D so G_loss.backward() does not compute discriminator gradients
            # (gradients still flow through D into the generators).
            _set_requires_grad(discriminators, False)
            pred_fake_A = netD_A(B2A)
            pred_fake_B = netD_B(A2B)

            # GAN LOSS for Generator: push the discriminators to predict "real" (1) on fakes
            G_GAN_loss = GAN_LOSS(pred_fake_A, torch.ones_like(pred_fake_A)) + \
                         GAN_LOSS(pred_fake_B, torch.ones_like(pred_fake_B))

            # Cycle LOSS
            G_Cycle_loss = Cycle_LOSS(A2B2A, A) + Cycle_LOSS(B2A2B, B)

            # Identity LOSS: a generator fed an image of its target domain should return it unchanged
            G_identity_loss = Identity_LOSS(netG_A(B), B) + Identity_LOSS(netG_B(A), A)

            G_loss = G_GAN_loss + lambda_cycle * G_Cycle_loss + lambda_identity * G_identity_loss

            netG_optim.zero_grad()
            G_loss.backward()
            netG_optim.step()
            _set_requires_grad(discriminators, True)

            # ---- Discriminator update (fakes detached: no gradient into the generators) ----
            pred_real_A = netD_A(A)
            pred_fake_A = netD_A(B2A.detach())
            # GAN LOSS for Discriminator
            D_A_loss = GAN_LOSS(pred_real_A, torch.ones_like(pred_real_A)) + \
                       GAN_LOSS(pred_fake_A, torch.zeros_like(pred_fake_A))

            pred_real_B = netD_B(B)
            pred_fake_B = netD_B(A2B.detach())
            D_B_loss = GAN_LOSS(pred_real_B, torch.ones_like(pred_real_B)) + \
                       GAN_LOSS(pred_fake_B, torch.zeros_like(pred_fake_B))

            D_loss = (D_A_loss + D_B_loss) / 2

            netD_optim.zero_grad()
            D_loss.backward()
            netD_optim.step()

            D_losses += D_loss.item() / n_batches
            G_losses += G_loss.item() / n_batches

        print(f'[{epoch + 1}/{EPOCHS}]\tD_loss : {D_losses:.6f}\tG_loss : {G_losses:.6f}\ttime : {time.time() - t:.3f}s')

        train_hist['G_losses'].append(G_losses)
        train_hist['D_losses'].append(D_losses)

        # save model every `save_every` epochs
        if save_every and (epoch + 1) % save_every == 0:
            os.makedirs('/kaggle/working/model_G', exist_ok = True)
            os.makedirs('/kaggle/working/model_D', exist_ok = True)

            # NOTE: whole pickled modules are saved; loading them needs the same class
            # definitions. The "uNet" tag in the file names is historical — Generator is
            # built from residual blocks.
            # save Generator
            torch.save(netG_A, '/kaggle/working/model_G/' + f'netG(uNet)_A{epoch + 1}.pt')
            torch.save(netG_B, '/kaggle/working/model_G/' + f'netG(uNet)_B{epoch + 1}.pt')

            # save Discriminator
            torch.save(netD_A, '/kaggle/working/model_D/' + f'netD(uNet)_A{epoch + 1}.pt')
            torch.save(netD_B, '/kaggle/working/model_D/' + f'netD(uNet)_B{epoch + 1}.pt')

            print(f'Model is saved at {epoch + 1}epochs')

    return train_hist

In [None]:
# Run the full training loop; returns per-epoch mean G/D losses and writes checkpoints under /kaggle/working.
train_hist = train_model(EPOCHS)

In [None]:
# Per-epoch mean generator / discriminator losses returned by train_model.
fig, ax = plt.subplots()
ax.plot(train_hist['G_losses'], label = 'G loss')
ax.plot(train_hist['D_losses'], label = 'D loss')
ax.set(title = 'CycleGAN training losses', xlabel = 'epoch (0-based)', ylabel = 'mean loss per batch')
ax.legend()
plt.show()