In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
from torchvision.utils import make_grid
import PIL

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import os
import glob

In [None]:
def show_tensor_images(image_tensor, num_images=2, size=(3,256,256)):
    """Display the first ``num_images`` images of ``image_tensor`` as a grid.

    Args:
        image_tensor: tensor whose values are mapped with ``(x + 1) / 2``
            before plotting and which can be reshaped to ``(-1, *size)``.
        num_images: how many images from the reshaped batch to show.
        size: per-image shape ``(channels, height, width)``.
    """
    rescaled = (image_tensor + 1) / 2
    batch = rescaled.detach().cpu().view(-1, *size)
    grid = make_grid(batch[:num_images], nrow=2)
    # make_grid returns channels-first data; imshow expects channels-last.
    channels_last = grid.permute(1, 2, 0).squeeze()
    plt.imshow(channels_last)
    plt.show()

In [None]:
class ImageDataset(Dataset):
    """Unpaired dataset yielding one Monet image and one photo per item.

    Files are read from the ``monet_jpg`` and ``photo_jpg`` sub-folders of
    ``root_dir``. Each Monet image is paired with a photo picked through a
    random permutation of the photo indices (see ``new_perm``).

    Args:
        root_dir: directory containing ``monet_jpg`` and ``photo_jpg``.
        transform: callable applied to each PIL image. When omitted, a plain
            ``transforms.ToTensor()`` is used so the default is usable
            (previously ``None`` raised a ``TypeError`` on first access).
    """
    def __init__(self,root_dir,transform=None):
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.file_monet = sorted(glob.glob(os.path.join(root_dir,"monet_jpg","*.*")))
        self.file_photo = sorted(glob.glob(os.path.join(root_dir,"photo_jpg","*.*")))
        self.new_perm()

    def new_perm(self):
        """Draw a fresh random assignment of photos to Monet indices."""
        self.randperm = torch.randperm(len(self.file_photo))[:len(self.file_monet)]

    def _load(self,path):
        """Open ``path``, force 3-channel RGB and apply the transform."""
        # The context manager releases the file handle deterministically;
        # convert("RGB") keeps grayscale/CMYK files from breaking batching.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self,index):
        """Return ``(monet, photo)`` rescaled with ``(x - 0.5) * 2``.

        Assuming the transform yields values in [0, 1] (as ``ToTensor`` does),
        the returned tensors lie in [-1, 1].
        """
        item_monet = self._load(self.file_monet[index%len(self.file_monet)])
        item_photo = self._load(self.file_photo[int(self.randperm[index])])
        return (item_monet-0.5)*2,(item_photo-0.5)*2

    def __len__(self):
        """Number of pairs: the size of the smaller of the two file lists."""
        return min(len(self.file_monet),len(self.file_photo))

## CycleGAN: Generator

#### residual blocks

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 reflect-padded convolutions with a skip connection.

    Computes ``x + IN(conv2(ReLU(IN(conv1(x)))))`` where ``IN`` is instance
    normalisation. Channel count and spatial size are preserved.

    Args:
        input_channels: number of channels of the input (and output).
    """
    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels,input_channels,kernel_size=3,padding=1,padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels,input_channels,kernel_size=3,padding=1,padding_mode='reflect')
        # Created with the default arguments, so the single instance is
        # reused after both convolutions.
        self.instancenorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self,x):
        """Return ``x`` plus the residual branch evaluated on ``x``."""
        # No operation here modifies ``x`` in place, so the input itself can
        # serve as the skip connection; the former ``x.clone()`` only cost an
        # extra full-tensor copy per block.
        residual = self.conv1(x)
        residual = self.instancenorm(residual)
        residual = self.activation(residual)

        residual = self.conv2(residual)
        residual = self.instancenorm(residual)
        return x + residual

#### contracting and expanding blocks

In [None]:
class ContractingBlock(nn.Module):
    """Stride-2 convolution that doubles the channel count.

    Args:
        input_channels: channels of the input; the output has twice as many.
        use_bn: when true, instance normalisation follows the convolution.
        kernel_size: kernel size of the convolution.
        activation: ``'relu'`` selects ReLU; any other value selects
            LeakyReLU with slope 0.2.
    """
    def __init__(self,input_channels,use_bn=True,kernel_size=3,activation='relu'):
        super().__init__()
        doubled = input_channels * 2
        self.conv1 = nn.Conv2d(
            input_channels,
            doubled,
            kernel_size=kernel_size,
            stride=2,
            padding=1,
            padding_mode='reflect',
        )
        if activation == 'relu':
            self.activation = nn.ReLU()
        else:
            self.activation = nn.LeakyReLU(0.2)
        self.use_bn = use_bn
        if self.use_bn:
            self.instancenorm = nn.InstanceNorm2d(doubled)

    def forward(self,x):
        """Apply convolution, optional normalisation, then the activation."""
        features = self.conv1(x)
        if self.use_bn:
            features = self.instancenorm(features)
        return self.activation(features)
    

class ExpandingBlock(nn.Module):
    """Stride-2 transposed convolution that halves the channel count.

    Args:
        input_channels: channels of the input; the output has
            ``input_channels // 2``.
        use_bn: when true, instance normalisation follows the convolution.
    """
    def __init__(self,input_channels,use_bn=True):
        super().__init__()
        halved = input_channels // 2
        self.conv1 = nn.ConvTranspose2d(
            input_channels,
            halved,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.use_bn = use_bn
        if self.use_bn:
            self.instancenorm = nn.InstanceNorm2d(halved)
        self.activation = nn.ReLU()

    def forward(self,x):
        """Apply transposed convolution, optional normalisation, then ReLU."""
        features = self.conv1(x)
        if self.use_bn:
            features = self.instancenorm(features)
        return self.activation(features)


class FeatureMapBlock(nn.Module):
    """Single 7x7 reflect-padded convolution mapping between channel counts.

    Args:
        input_channels: channels of the input.
        output_channels: channels of the output.
    """
    def __init__(self,input_channels,output_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=7,
            padding=3,
            padding_mode='reflect',
        )

    def forward(self,x):
        """Return the convolution of ``x``; spatial size is preserved."""
        return self.conv(x)

#### cycleGAN generator

In [None]:
class Generator(nn.Module):
    """Generator: feature map, two contractions, nine residual blocks,
    two expansions, output feature map and a final tanh.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        hidden_channels: channels produced by the first feature-map block.
    """
    # Number of residual blocks, registered as ``res0`` ... ``res8``.
    _NUM_RES_BLOCKS = 9

    def __init__(self,input_channels,output_channels,hidden_channels=64):
        super().__init__()
        self.upfeature = FeatureMapBlock(input_channels,hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels*2)
        # Two contractions double the channels twice, hence the factor 4.
        res_channels = hidden_channels * 4
        for idx in range(self._NUM_RES_BLOCKS):
            setattr(self, f"res{idx}", ResidualBlock(res_channels))
        self.expand1 = ExpandingBlock(hidden_channels*4)
        self.expand2 = ExpandingBlock(hidden_channels*2)
        self.downfeature = FeatureMapBlock(hidden_channels,output_channels)
        self.tanh = torch.nn.Tanh()

    def forward(self,x):
        """Translate ``x``; the output passes through tanh, so it lies in [-1, 1]."""
        hidden = self.upfeature(x)
        hidden = self.contract1(hidden)
        hidden = self.contract2(hidden)
        for idx in range(self._NUM_RES_BLOCKS):
            hidden = getattr(self, f"res{idx}")(hidden)
        hidden = self.expand1(hidden)
        hidden = self.expand2(hidden)
        return self.tanh(self.downfeature(hidden))

## PatchGAN Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a one-channel map of scores.

    A feature-map block is followed by three stride-2 contracting blocks
    (kernel size 4, LeakyReLU; the first one without normalisation) and a
    1x1 convolution down to a single channel.

    Args:
        input_channels: channels of the input image.
        hidden_channels: channels produced by the first feature-map block.
    """
    def __init__(self,input_channels,hidden_channels=64):
        super().__init__()
        self.upfeature = FeatureMapBlock(input_channels,hidden_channels)
        self.contract1 = ContractingBlock(
            hidden_channels, use_bn=False, kernel_size=4, activation='lrelu'
        )
        self.contract2 = ContractingBlock(
            hidden_channels*2, kernel_size=4, activation='lrelu'
        )
        self.contract3 = ContractingBlock(
            hidden_channels*4, kernel_size=4, activation='lrelu'
        )
        self.final = nn.Conv2d(hidden_channels*8,1,kernel_size=1)

    def forward(self,x):
        """Return the raw (un-activated) score map for ``x``."""
        features = self.upfeature(x)
        for stage in (self.contract1, self.contract2, self.contract3):
            features = stage(features)
        return self.final(features)

#### Training parameters

In [None]:
# Adversarial objective: MSE between discriminator output and the 0/1 target
# (original note: BCE was considered as an alternative).
adv_criterion = nn.MSELoss()
# L1 reconstruction loss, passed to the generator loss as both the identity
# and the cycle-consistency criterion.
recon_criterion = nn.L1Loss()

n_epochs = 40
dim_P=3               # channels of the photo domain
dim_M=3               # channels of the Monet domain
display_step = 300    # log/visualise every 300 steps (original note: img/2epoch)
batch_size=2
lr = 0.0002      # values tried: 0.0002->0.001->0.0004->0.0001->0.0003
# Fall back to the CPU so the notebook still runs when no GPU is available.
device='cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Training-time preprocessing: random flip and small random rotation for
# augmentation, then conversion of the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ]
)
dataset = ImageDataset(root_dir="../input/gan-getting-started", transform=transform)

In [None]:
# Networks first (construction order kept: gen_PM, gen_MP, disc_P, disc_M),
# optimizers afterwards. "PM" maps photo -> Monet, "MP" maps Monet -> photo.
gen_PM = Generator(dim_P, dim_M).to(device)
gen_MP = Generator(dim_M, dim_P).to(device)
disc_P = Discriminator(dim_P).to(device)
disc_M = Discriminator(dim_M).to(device)

# A single optimizer updates both generators; each discriminator has its own.
gen_opt = torch.optim.Adam(
    [*gen_PM.parameters(), *gen_MP.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)
disc_P_opt = torch.optim.Adam(disc_P.parameters(), lr=lr, betas=(0.5, 0.999))
disc_M_opt = torch.optim.Adam(disc_M.parameters(), lr=lr, betas=(0.5, 0.999))

def weights_init(m):
    """Initialise ``m`` in place; meant to be used with ``Module.apply``.

    Weights of (transposed) convolutions and of ``BatchNorm2d`` layers are
    drawn from N(0, 0.02); ``BatchNorm2d`` biases are set to zero. All other
    module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Set to True to resume from a checkpoint written by train(save_model=True).
pretrain = False
if pretrain:
    # NOTE: the checkpoint path is empty and must be filled in before
    # enabling ``pretrain``.
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # restored onto whatever ``device`` is configured for this run.
    pre_dict = torch.load("", map_location=device)
    gen_PM.load_state_dict(pre_dict["gen_PM"])
    gen_MP.load_state_dict(pre_dict["gen_MP"])
    gen_opt.load_state_dict(pre_dict["gen_opt"])
    disc_P.load_state_dict(pre_dict["disc_P"])
    disc_P_opt.load_state_dict(pre_dict["disc_P_opt"])
    disc_M.load_state_dict(pre_dict["disc_M"])
    disc_M_opt.load_state_dict(pre_dict["disc_M_opt"])
else:
    # Fresh run: re-initialise every network with weights_init.
    gen_PM = gen_PM.apply(weights_init)
    gen_MP = gen_MP.apply(weights_init)
    disc_P = disc_P.apply(weights_init)
    disc_M = disc_M.apply(weights_init)

## Discriminator Loss

In [None]:
def disc_loss(real_x,fake_x,disc_x,adv_criterion):
    """Discriminator loss for one domain.

    Args:
        real_x: real samples of the domain.
        fake_x: generated samples; detached so no gradient reaches the generator.
        disc_x: discriminator for the domain.
        adv_criterion: loss comparing predictions against all-zeros
            (fake) or all-ones (real) targets.

    Returns:
        The average of the fake and the real loss terms.
    """
    pred_on_fake = disc_x(fake_x.detach())
    pred_on_real = disc_x(real_x)
    loss_on_fake = adv_criterion(pred_on_fake, torch.zeros_like(pred_on_fake))
    loss_on_real = adv_criterion(pred_on_real, torch.ones_like(pred_on_real))
    return (loss_on_fake + loss_on_real) / 2

## Generator Loss

#### adversarial loss

In [None]:
def gen_adversarial_loss(real_x,disc_y,gen_xy,adv_criterion):
    """Adversarial loss of the generator mapping domain X to domain Y.

    Args:
        real_x: real samples from domain X.
        disc_y: discriminator for domain Y.
        gen_xy: generator translating X to Y.
        adv_criterion: loss comparing the prediction against all-ones.

    Returns:
        ``(loss, fake_y)`` where ``fake_y = gen_xy(real_x)``.
    """
    fake_y = gen_xy(real_x)
    prediction = disc_y(fake_y)
    # The generator is rewarded when its output is scored as real (all ones).
    target_real = torch.ones_like(prediction)
    return adv_criterion(prediction, target_real), fake_y

#### identity loss

In [None]:
def identity_loss(real_x,gen_yx,identity_criterion):
    """Penalise the Y->X generator for altering inputs already in domain X.

    Args:
        real_x: real samples from domain X.
        gen_yx: generator translating Y to X.
        identity_criterion: loss comparing ``gen_yx(real_x)`` with ``real_x``.

    Returns:
        ``(loss, identity_x)`` where ``identity_x = gen_yx(real_x)``.
    """
    identity_x = gen_yx(real_x)
    return identity_criterion(identity_x, real_x), identity_x

#### cycle consistency loss

In [None]:
def cycle_consistency_loss(real_x,fake_y,gen_yx,cycle_criterion):
    """Loss between ``real_x`` and its reconstruction from ``fake_y``.

    Args:
        real_x: real samples from domain X.
        fake_y: translation of ``real_x`` into domain Y.
        gen_yx: generator translating Y back to X.
        cycle_criterion: loss comparing the reconstruction with ``real_x``.

    Returns:
        ``(loss, cycle_x)`` where ``cycle_x = gen_yx(fake_y)``.
    """
    cycle_x = gen_yx(fake_y)
    return cycle_criterion(cycle_x, real_x), cycle_x

### Generator loss (all terms combined)

In [None]:
def gen_loss_func(real_x,real_y,gen_xy,gen_yx,disc_x,disc_y,adv_criterion,identity_criterion,cycle_criterion,lambda_identity=0.1,lambda_cycle=10):
    """Total generator loss: adversarial + weighted cycle + weighted identity.

    Args:
        real_x, real_y: real samples from domains X and Y.
        gen_xy, gen_yx: generators translating X->Y and Y->X.
        disc_x, disc_y: discriminators for domains X and Y.
        adv_criterion: criterion for the adversarial terms.
        identity_criterion: criterion for the identity terms.
        cycle_criterion: criterion for the cycle-consistency terms.
        lambda_identity: weight of the summed identity losses.
        lambda_cycle: weight of the summed cycle-consistency losses.

    Returns:
        ``(gen_loss, fake_x, fake_y)`` with ``fake_x = gen_yx(real_y)`` and
        ``fake_y = gen_xy(real_x)``.
    """
    # Adversarial terms; these also produce the translated images.
    adv_y, fake_y = gen_adversarial_loss(real_x, disc_y, gen_xy, adv_criterion)
    adv_x, fake_x = gen_adversarial_loss(real_y, disc_x, gen_yx, adv_criterion)

    # Identity terms: feeding a generator its own target domain.
    idt_x, _ = identity_loss(real_x, gen_yx, identity_criterion)
    idt_y, _ = identity_loss(real_y, gen_xy, identity_criterion)

    # Cycle terms: translate there and back again.
    cyc_x, _ = cycle_consistency_loss(real_x, fake_y, gen_yx, cycle_criterion)
    cyc_y, _ = cycle_consistency_loss(real_y, fake_x, gen_xy, cycle_criterion)

    adversarial_term = adv_x + adv_y
    cycle_term = lambda_cycle * (cyc_x + cyc_y)
    identity_term = lambda_identity * (idt_x + idt_y)
    return adversarial_term + cycle_term + identity_term, fake_x, fake_y

## Training

In [None]:
plt.rcParams["figure.figsize"] = (10,10)

# Histories of the windowed mean losses; appended to by train() and plotted later.
genloss_plot=[]
discloss_plot=[]
def train(save_model):
    """Train both generators and both discriminators on the module-level data.

    Every step updates the Monet discriminator, then the photo discriminator,
    then both generators jointly. Every ``display_step`` steps the mean losses
    since the previous report are appended to ``genloss_plot`` /
    ``discloss_plot``, printed, and sample images are shown. The very first
    report (step 0) covers a single step.

    Only the Monet discriminator's loss is tracked as "discriminator loss".

    Args:
        save_model: when true, a checkpoint ``cycleGAN_{cur_step}.pth`` with
            all model and optimizer states is written at every report.
    """
    generator_loss_sum = 0.0
    discriminator_loss_sum = 0.0
    steps_in_window = 0
    dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
    cur_step=0

    for epoch in range(n_epochs):
        for real_M,real_P in tqdm(dataloader):
            real_M = real_M.to(device)
            real_P = real_P.to(device)

            #discriminator M
            disc_M_opt.zero_grad()
            with torch.no_grad():
                fake_M = gen_PM(real_P)
            disc_M_loss = disc_loss(real_M,fake_M,disc_M,adv_criterion)
            # The fake batch carries no graph, so nothing needs to be retained.
            disc_M_loss.backward()
            disc_M_opt.step()

            #discriminator P
            disc_P_opt.zero_grad()
            with torch.no_grad():
                fake_P = gen_MP(real_M)
            disc_P_loss = disc_loss(real_P,fake_P,disc_P,adv_criterion)
            disc_P_loss.backward()
            disc_P_opt.step()

            #generator
            gen_opt.zero_grad()
            # With X = Monet and Y = photo, gen_loss_func returns
            # (loss, fake_x, fake_y) = (loss, fake Monet, fake photo).
            gen_loss,fake_M,fake_P = gen_loss_func(
                real_M,real_P,gen_MP,gen_PM,disc_M,disc_P,adv_criterion,recon_criterion,recon_criterion
            )
            gen_loss.backward()
            gen_opt.step()

            discriminator_loss_sum += disc_M_loss.item()
            generator_loss_sum += gen_loss.item()
            steps_in_window += 1

            if cur_step%display_step == 0:
                # Divide by the number of steps actually accumulated so the
                # step-0 entry is a true mean rather than loss/display_step.
                mean_generator_loss = generator_loss_sum / steps_in_window
                mean_discriminator_loss = discriminator_loss_sum / steps_in_window
                genloss_plot.append(mean_generator_loss)
                discloss_plot.append(mean_discriminator_loss)

                print(f"epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
                show_tensor_images(torch.cat([real_M[0],real_P[0]]), size=(3,256,256))
                show_tensor_images(torch.cat([fake_P[0],fake_M[0]]), size=(3,256,256))
                generator_loss_sum = 0.0
                discriminator_loss_sum = 0.0
                steps_in_window = 0
                ##save model##
                if save_model:
                    torch.save({
                        'gen_PM':gen_PM.state_dict(),
                        'gen_MP':gen_MP.state_dict(),
                        'gen_opt':gen_opt.state_dict(),
                        'disc_P':disc_P.state_dict(),
                        'disc_P_opt':disc_P_opt.state_dict(),
                        'disc_M':disc_M.state_dict(),
                        'disc_M_opt':disc_M_opt.state_dict()
                    },f"cycleGAN_{cur_step}.pth")
                ##save model##
            cur_step += 1
        # Re-draw the Monet/photo pairing so later epochs are not restricted
        # to the photo subset sampled when the dataset was constructed.
        dataset.new_perm()

In [None]:
# Run the full training loop without writing checkpoints.
train(save_model=False)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# One row per report made by train(); one column per tracked loss.
df = pd.DataFrame({"generator_loss":genloss_plot,
                   "discriminator_loss":discloss_plot})
# The axes are created inside the style context so the darkgrid theme applies.
with sns.axes_style("darkgrid"):
    fig, ax = plt.subplots(figsize=(15,6))
    sns.lineplot(data=df,palette='flare',ax=ax)
    # Label the figure so it can be read on its own.
    ax.set(
        title="CycleGAN training losses",
        xlabel=f"Report index (one report every {display_step} training steps)",
        ylabel="Mean loss",
    )

## submission

In [None]:
# exist_ok makes the cell safe to re-run once the folder already exists.
os.makedirs('../images', exist_ok=True)

In [None]:
photo_imgs = sorted(glob.glob(os.path.join("../input/gan-getting-started","photo_jpg")+"/*.*"))
# Inference must be deterministic: convert to a tensor only, without the
# random flip/rotation augmentation used for training.
to_tensor = transforms.ToTensor()
gen_PM.eval()
# No gradients are needed when only generating images.
with torch.no_grad():
    for i in tqdm(range(len(photo_imgs))):
        with Image.open(photo_imgs[i]) as img:
            photo_img = to_tensor(img.convert("RGB"))
        # Apply the same (x - 0.5) * 2 rescaling the dataset applies during
        # training, then add the batch dimension.
        photo_img = ((photo_img - 0.5) * 2).unsqueeze(0).to(device)
        predict = gen_PM(photo_img)
        # The generator ends in tanh; map its [-1, 1] output back to [0, 1].
        predict = (predict[0].cpu() + 1) / 2
        predict = predict.permute(1, 2, 0).numpy()
        predict  = (255 * predict).astype(np.uint8)

        im = PIL.Image.fromarray(predict)
        im.save("../images/" + str(i) + ".jpg")

In [None]:
import shutil

# Zip the generated images into /kaggle/working/images.zip.
shutil.make_archive(
    base_name="/kaggle/working/images",
    format='zip',
    root_dir="/kaggle/images",
)