In [None]:
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import missingno as msno

import torchvision

import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import time
import shutil

import itertools

from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from tqdm.notebook import tqdm

from PIL import Image

# Preprocessing

In [None]:
class trainDataset(Dataset):
    """Unpaired Monet/photo image dataset for CycleGAN.

    ``data_dir`` is expected to contain ``monet_jpg/`` (domain A) and
    ``photo_jpg/`` (domain B).

    Args:
        data_dir: root directory holding the two image folders.
        mode: ``'train'`` uses the first ``max_train_images`` files (sorted by
            name) of each folder; ``'test'`` uses every photo for both A and B.
        transforms: optional callable applied independently to each PIL image.
        max_train_images: per-domain cap on the number of files in ``'train'`` mode.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.

    ``__getitem__`` returns ``(A, B)`` so that ``for A, B in loader`` receives the
    Monet painting as A and the photo as B (in ``'test'`` mode both are photos).
    """

    def __init__(self, data_dir, mode='train', transforms=None, max_train_images=300):
        A_dir = os.path.join(data_dir, 'monet_jpg')
        B_dir = os.path.join(data_dir, 'photo_jpg')

        def list_files(folder, limit=None):
            # Sorted so the selection (and test-time output numbering) is deterministic
            return [os.path.join(folder, name) for name in sorted(os.listdir(folder))[:limit]]

        if mode == 'train':
            self.A = list_files(A_dir, max_train_images)
            self.B = list_files(B_dir, max_train_images)
        elif mode == 'test':
            # Inference only needs photos, so A and B point at the same files
            self.A = list_files(B_dir)
            self.B = list_files(B_dir)
        else:
            raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))

        self.transforms = transforms

    def __len__(self):
        # Unpaired domains may differ in size; never index past the shorter list
        return min(len(self.A), len(self.B))

    def _load(self, path):
        """Open ``path`` as an RGB image, close the file, and apply the transforms."""
        with Image.open(path) as raw:
            # convert() returns an independent, fully loaded copy, so closing `raw` is safe;
            # it also normalises grayscale/RGBA files to the 3 channels the models expect
            img = raw.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __getitem__(self, index):
        return self._load(self.A[index]), self._load(self.B[index])

In [None]:
def unnorm(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo ``transforms.Normalize`` in place: ``img = img * std + mean`` per channel.

    Args:
        img: float tensor shaped (C, H, W) or (N, C, H, W); modified in place.
        mean: per-channel means used for normalisation.
        std: per-channel standard deviations used for normalisation.

    Returns:
        The same tensor object ``img``, now in the un-normalised range.
    """
    # Reshape to (C, 1, 1) so the values broadcast over the channel axis for both
    # single images and batches (zipping over ``img`` would walk the batch axis instead).
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    img.mul_(std_t).add_(mean_t)
    return img

In [None]:
# Both pipelines resize to 256x256 and map pixel values from [0, 1] to [-1, 1]
# (mean 0.5, std 0.5 per channel), matching the generators' Tanh output range.
# Evaluation pipeline: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Training pipeline: same as above plus random horizontal/vertical flips.
transform_train = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


In [None]:
# Root of the competition data; must contain monet_jpg/ and photo_jpg/.
DATA_DIR = '../input/gan-getting-started'

trainLoader = DataLoader(
    trainDataset(DATA_DIR, 'train', transform_train),
    batch_size=1,
    shuffle=True,
    # Pinned host memory only helps host->GPU copies; on CPU-only kernels it just warns
    pin_memory=torch.cuda.is_available(),
)
# Inference loader: no shuffling, so the i-th saved image always corresponds to
# the i-th photo in sorted filename order.
testLoader = DataLoader(
    trainDataset(DATA_DIR, 'test', transform_test),
    batch_size=1,
    shuffle=False,
    pin_memory=torch.cuda.is_available(),
)

# Helper classes and functions (loss history, fake-image pool, LR schedule, gradient toggling)

In [None]:
class AvgStats:
    """Two parallel histories: recorded loss values and one companion value per entry.

    ``append`` adds one (loss, it) pair; ``reset`` empties both lists.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Replace both histories with fresh empty lists."""
        self.losses, self.its = [], []

    def append(self, loss, it):
        """Record ``loss`` in ``losses`` and ``it`` in ``its``."""
        for history, value in ((self.losses, loss), (self.its, it)):
            history.append(value)


In [None]:
class sample_fake(object):
    """History buffer of generated images fed to the discriminators.

    The first ``max_imgs`` images are stored and returned unchanged. Once the
    buffer is full, each incoming image is, with probability 0.5, swapped with a
    randomly chosen stored image: the stored one is returned and the new one takes
    its slot. Otherwise the new image is returned as-is. ``max_imgs <= 0`` disables
    the buffer entirely. Randomness comes from NumPy's global RNG.
    """

    def __init__(self, max_imgs=50):
        self.max_imgs = max_imgs
        self.cur_img = 0
        self.imgs = list()

    def __call__(self, imgs):
        """Return a list with one (possibly substituted) image per element of ``imgs``."""
        if self.max_imgs <= 0:
            # Nothing can ever be stored, and randint(0, 0) would raise; pass through.
            return list(imgs)
        ret = list()
        for img in imgs:
            if self.cur_img < self.max_imgs:
                self.imgs.append(img)
                ret.append(img)
                self.cur_img += 1
            elif np.random.random() > 0.5:
                idx = np.random.randint(0, self.max_imgs)
                ret.append(self.imgs[idx])
                self.imgs[idx] = img
            else:
                ret.append(img)
        return ret

In [None]:
class lr_sched():
    """Learning-rate multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    ``step(epoch)`` is 1.0 while ``epoch <= decay_epochs``, then falls linearly to
    0.0 at ``total_epochs`` and stays at 0.0 afterwards.
    """

    def __init__(self, decay_epochs=100, total_epochs=200):
        self.decay_epochs = decay_epochs
        self.total_epochs = total_epochs

    def step(self, epoch_num):
        if epoch_num <= self.decay_epochs:
            return 1.0
        decay_span = self.total_epochs - self.decay_epochs
        if decay_span <= 0:
            # No decay window configured: drop straight to 0 instead of dividing by zero
            return 0.0
        fract = (epoch_num - self.decay_epochs) / decay_span
        # Clamp so stepping past total_epochs never yields a negative learning rate
        return max(0.0, 1.0 - fract)

In [None]:
def update_grad(models, requires_grad=True):
    """Set ``requires_grad`` on every parameter of every model in ``models``.

    Used to freeze (False) or unfreeze (True) whole networks between optimisation steps.
    """
    all_params = (param for model in models for param in model.parameters())
    for param in all_params:
        param.requires_grad = requires_grad

# Model

In [None]:
class ResidualBlock(nn.Module):
    """Residual block computing ``shortcut(X) + seq(X)``.

    ``seq`` is: reflect-pad, 3x3 conv, LeakyReLU(0.2), InstanceNorm, Dropout(0.5),
    reflect-pad, 3x3 conv, InstanceNorm; spatial size is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions. When it differs from
            ``in_channels`` a 1x1 convolution projects the skip path so the sum is
            well-defined; otherwise the skip path is the identity (no parameters,
            so the state_dict layout is unchanged for in == out).
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.5),
            nn.ReflectionPad2d(1),
            # Consumes the first conv's output, so its input width is out_channels
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=1),
            nn.InstanceNorm2d(out_channels),
        )
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, X):
        return self.shortcut(X) + self.seq(X)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a spatial map of real/fake scores.

    Five unpadded 4x4 convolutions (four with stride 2, the last with stride 1);
    a 256x256 input yields a 1-channel 11x11 score map.

    Args:
        in_channels: number of channels of the input image.

    NOTE(review): the 512-channel stage has no LeakyReLU, and the activation/norm
    order differs between stages; confirm this is intended.
    """

    def __init__(self, in_channels):
        super(Discriminator, self).__init__()

        def conv4(c_in, c_out, stride):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(4, 4), stride=stride)

        def leaky():
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Layer order (and therefore state_dict indices) is significant for checkpoints
        layers = [
            conv4(in_channels, 64, 2), leaky(), nn.InstanceNorm2d(num_features=64),
            conv4(64, 128, 2), nn.InstanceNorm2d(num_features=128), leaky(),
            conv4(128, 256, 2), leaky(), nn.InstanceNorm2d(num_features=256),
            conv4(256, 512, 2), nn.InstanceNorm2d(num_features=512),
            conv4(512, 1, 1),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, X):
        return self.seq(X)

In [None]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem conv, two stride-2 3x3 downsampling convs, ``n_residual_blocks``
    256-channel ResidualBlocks, two stride-2 transposed-conv upsamplings, and a 7x7
    output conv followed by Tanh. A 256x256 input yields a 256x256 output.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image (used by the final conv).
        n_residual_blocks: number of residual blocks in the bottleneck (default 9;
            with 9 blocks the state_dict layout is the same as a hand-listed stack).
    """

    def __init__(self, in_channels, out_channels, n_residual_blocks=9):
        super(Generator, self).__init__()

        self.seq = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=1, padding=(0, 0)),
            nn.InstanceNorm2d(64),
            # NOTE(review): there is no activation between the stem and downsampling convs
            # (CycleGAN's reference layout is Conv-InstanceNorm-ReLU); confirm this is intended.
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=2, padding=(1, 1)),
            nn.InstanceNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=2, padding=(1, 1)),
            nn.InstanceNorm2d(256),
            *[ResidualBlock(in_channels=256, out_channels=256) for _ in range(n_residual_blocks)],
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.Dropout(0.5),
            nn.GELU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3, 3), stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.Dropout(0.5),
            nn.GELU(),
            nn.ReflectionPad2d(3),
            # Honour the requested output width instead of a hard-coded 3
            nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=(7, 7), stride=1, padding=(0, 0)),
            nn.Tanh(),
        )

    def forward(self, X):
        return self.seq(X)

# GAN Architecture

In [None]:
class CycleGan(object):
    """CycleGAN trainer holding two generators, two discriminators, optimizers and LR schedules.

    Domain convention: each loader batch is unpacked as ``(A, B)``. ``GeneratorAB``
    maps A -> B, ``GeneratorBA`` maps B -> A, ``DiscriminatorA`` scores A-domain
    images and ``DiscriminatorB`` scores B-domain images.

    Args:
        in_channels: channels of A-domain images.
        out_channels: channels of B-domain images.
        epochs: total number of training epochs.
        device: torch device the models and batches are moved to.
        decay_epoch: last epoch at the full learning rate; afterwards it decays linearly.
        lmbda: weight of the cycle-consistency L1 loss.
        idt_coef: identity-loss weight, relative to ``lmbda``.
    """

    def __init__(self, in_channels, out_channels, epochs, device, decay_epoch, lmbda, idt_coef):
        self.epochs = epochs
        self.decay_epoch = decay_epoch
        self.device = device
        self.GeneratorAB = Generator(in_channels=in_channels, out_channels=out_channels).to(device)
        # B -> A runs the opposite direction, so its channel counts are swapped
        self.GeneratorBA = Generator(in_channels=out_channels, out_channels=in_channels).to(device)
        self.DiscriminatorA = Discriminator(in_channels=in_channels).to(device)
        self.DiscriminatorB = Discriminator(in_channels=out_channels).to(device)

        self.AdamGenerator = torch.optim.Adam(itertools.chain(self.GeneratorBA.parameters(),
                                                              self.GeneratorAB.parameters()),
                                              lr=2e-4, betas=(0.5, 0.999))
        self.AdamDiscriminator = torch.optim.Adam(itertools.chain(self.DiscriminatorB.parameters(),
                                                                  self.DiscriminatorA.parameters()),
                                                  lr=2e-4, betas=(0.5, 0.999))

        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()
        # Pools of past fakes, so each discriminator also sees older generator outputs
        self.sample_B = sample_fake()
        self.sample_A = sample_fake()

        self.lmbda = lmbda
        self.idt_coef = idt_coef

        self.Generator_stats = AvgStats()
        self.Discriminator_stats = AvgStats()

        Generator_lr = lr_sched(self.decay_epoch, self.epochs)
        Discriminator_lr = lr_sched(self.decay_epoch, self.epochs)
        self.Generator_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.AdamGenerator, Generator_lr.step)
        self.Discriminator_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.AdamDiscriminator, Discriminator_lr.step)

    def _generator_step(self, A_real, B_real):
        """One optimizer step for both generators; returns (loss value, detached fake_A, detached fake_B)."""
        # Freeze the discriminators so the generator loss does not accumulate gradients in them
        update_grad([self.DiscriminatorA, self.DiscriminatorB], False)
        self.AdamGenerator.zero_grad()

        fake_B = self.GeneratorAB(A_real)
        fake_A = self.GeneratorBA(B_real)

        # Cycle consistency: A -> B -> A and B -> A -> B should reconstruct the inputs
        cycle_A = self.GeneratorBA(fake_B)
        cycle_B = self.GeneratorAB(fake_A)
        loss_cycle_A = self.l1_loss(cycle_A, A_real) * self.lmbda
        loss_cycle_B = self.l1_loss(cycle_B, B_real) * self.lmbda

        # Identity: a generator fed an image already in its target domain should leave it unchanged
        id_A = self.GeneratorBA(A_real)
        id_B = self.GeneratorAB(B_real)
        loss_id_A = self.l1_loss(id_A, A_real) * self.lmbda * self.idt_coef
        loss_id_B = self.l1_loss(id_B, B_real) * self.lmbda * self.idt_coef

        # Least-squares adversarial loss: the generators want their fakes scored as real (1)
        disc_A = self.DiscriminatorA(fake_A)
        disc_B = self.DiscriminatorB(fake_B)
        loss_adversial_A = self.mse_loss(disc_A, torch.ones_like(disc_A))
        loss_adversial_B = self.mse_loss(disc_B, torch.ones_like(disc_B))

        generator_loss = (loss_id_A + loss_id_B
                          + loss_cycle_A + loss_cycle_B
                          + loss_adversial_A + loss_adversial_B)
        generator_loss.backward()
        self.AdamGenerator.step()

        update_grad([self.DiscriminatorA, self.DiscriminatorB], True)
        return generator_loss.item(), fake_A.detach(), fake_B.detach()

    def _discriminator_step(self, A_real, B_real, fake_A, fake_B):
        """One optimizer step for both discriminators; returns the combined loss value."""
        self.AdamDiscriminator.zero_grad()

        # Fakes come from the history pools (B pool first; both pools draw from NumPy's global RNG).
        # They stay as detached tensors on the device, avoiding a GPU -> NumPy -> GPU round trip.
        fake_B = self.sample_B([fake_B])[0]
        fake_A = self.sample_A([fake_A])[0]

        disc_A_real = self.DiscriminatorA(A_real)
        disc_B_real = self.DiscriminatorB(B_real)
        disc_A_fake = self.DiscriminatorA(fake_A)
        disc_B_fake = self.DiscriminatorB(fake_B)

        # Least-squares targets: real -> 1, fake -> 0; each D's loss averages the two terms
        A_desc_loss = (self.mse_loss(disc_A_real, torch.ones_like(disc_A_real))
                       + self.mse_loss(disc_A_fake, torch.zeros_like(disc_A_fake))) / 2
        B_desc_loss = (self.mse_loss(disc_B_real, torch.ones_like(disc_B_real))
                       + self.mse_loss(disc_B_fake, torch.zeros_like(disc_B_fake))) / 2
        disc_loss = A_desc_loss + B_desc_loss

        disc_loss.backward()
        self.AdamDiscriminator.step()
        return disc_loss.item()

    def _checkpoint_state(self, epoch):
        """Dict of model and optimizer state_dicts written after each epoch."""
        return {
            'epoch': epoch,
            'GeneratorAB': self.GeneratorAB.state_dict(),
            'GeneratorBA': self.GeneratorBA.state_dict(),
            'disc_m': self.DiscriminatorA.state_dict(),
            'disc_p': self.DiscriminatorB.state_dict(),
            'optimizer_gan': self.AdamGenerator.state_dict(),
            'optimizer_disc': self.AdamDiscriminator.state_dict(),
        }

    def train(self, trainLoader, ckpt_path='./current.ckpt'):
        """Train for ``self.epochs`` epochs over ``trainLoader`` (batches of ``(A, B)``).

        After every epoch this writes a checkpoint dict to ``ckpt_path``, appends the
        mean generator/discriminator loss and the epoch's wall time to
        ``Generator_stats``/``Discriminator_stats``, prints a summary line and steps
        both LR schedulers.
        """
        n_batches = len(trainLoader)
        for model in (self.GeneratorAB, self.GeneratorBA, self.DiscriminatorA, self.DiscriminatorB):
            # Re-enable Dropout in case an inference cell switched a model to eval()
            model.train()

        for epoch in range(self.epochs):
            start_time = time.time()
            avg_gen_loss = 0.0
            avg_disc_loss = 0.0
            t = tqdm(trainLoader, leave=False, total=n_batches)
            for A, B in t:
                A_real, B_real = A.to(self.device), B.to(self.device)
                gen_loss, fake_A, fake_B = self._generator_step(A_real, B_real)
                disc_loss = self._discriminator_step(A_real, B_real, fake_A, fake_B)
                avg_gen_loss += gen_loss
                avg_disc_loss += disc_loss
                t.set_postfix(gen_loss=gen_loss, disc_loss=disc_loss)

            # Save plain state_dicts rather than pickling the whole trainer object
            save_checkpoint(self._checkpoint_state(epoch + 1), ckpt_path)

            # max(..., 1) avoids ZeroDivisionError on an empty loader
            avg_gen_loss /= max(n_batches, 1)
            avg_disc_loss /= max(n_batches, 1)
            time_req = time.time() - start_time

            self.Generator_stats.append(avg_gen_loss, time_req)
            self.Discriminator_stats.append(avg_disc_loss, time_req)

            print("Epoch: (%d) | Generator Loss:%f | Discriminator Loss:%f" %
                  (epoch + 1, avg_gen_loss, avg_disc_loss))
            self.Generator_lr_sched.step()
            self.Discriminator_lr_sched.step()

# Save and load model

In [None]:
def save_checkpoint(state, save_path):
    """Serialize ``state`` (any object ``torch.save`` accepts) to the file ``save_path``."""
    return torch.save(obj=state, f=save_path)

In [None]:
def load_checkpoint(ckpt_path, map_location=None):
    """Load and return the object stored at ``ckpt_path`` via ``torch.load``.

    ``map_location`` is forwarded to ``torch.load``. A confirmation line is printed
    after a successful load.
    """
    checkpoint = torch.load(f=ckpt_path, map_location=map_location)
    print(f' [*] Loading checkpoint from {ckpt_path} succeed!')
    return checkpoint

# Model training and evaluation

In [None]:
# Prefer the GPU when one is visible to this kernel.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 3-channel images in both domains; 200 epochs with linear LR decay after epoch 100.
gan = CycleGan(
    in_channels=3,
    out_channels=3,
    epochs=200,
    device=device,
    decay_epoch=100,
    lmbda=10,
    idt_coef=0.5,
)

In [None]:
# Snapshot of the freshly initialised models and optimizers (epoch 0).
# NOTE(review): these key names ('gen_mtp', 'desc_m', ...) differ from the keys used
# for the per-epoch checkpoint in CycleGan.train ('GeneratorAB', 'disc_m', ...);
# align them if anything ever loads both files.
init_components = {
    'gen_mtp': gan.GeneratorAB,
    'gen_ptm': gan.GeneratorBA,
    'desc_m': gan.DiscriminatorA,
    'desc_p': gan.DiscriminatorB,
    'optimizer_gen': gan.AdamGenerator,
    'optimizer_desc': gan.AdamDiscriminator,
}
save_dict = {'epoch': 0}
save_dict.update({key: component.state_dict() for key, component in init_components.items()})
save_checkpoint(save_dict, './init.ckpt')

In [None]:
# Run the full training loop; the epoch count and LR-decay epoch were fixed when `gan`
# was constructed. A checkpoint is written to ./current.ckpt at the end of every epoch.
gan.train(trainLoader)

In [None]:
# Reload the latest per-epoch checkpoint. Binding the result stops the notebook from
# rendering the entire checkpoint as cell output, and map_location lets a checkpoint
# written on a GPU be opened on a CPU-only kernel.
checkpoint = load_checkpoint('./current.ckpt', map_location=device)

In [None]:
# Per-epoch mean losses recorded during training (red: generators, blue: discriminators).
fig, ax = plt.subplots()
ax.plot(gan.Generator_stats.losses, 'r', label='Generator Loss')
ax.plot(gan.Discriminator_stats.losses, 'b', label='Discriminator Loss')
ax.set_xlabel("Epochs")
ax.set_ylabel("Losses")
ax.legend()
plt.show()

# Run the generator over all test photos and save the results

In [None]:
# Output folder for generated images; exist_ok makes this cell safe to re-run
# (a bare `mkdir` fails once the folder exists).
os.makedirs('../images', exist_ok=True)

In [None]:
# Translate every test photo and save the results as ../images/1.jpg, 2.jpg, ... in loader order.
# NOTE(review): GeneratorBA is used as the photo -> Monet generator, which presumes training
# batches were unpacked with A = Monet and B = photo; confirm against trainDataset.__getitem__.
generator = gan.GeneratorBA
# The generator contains Dropout layers; eval() disables them so predictions are deterministic
generator.eval()
img_idx = 0
with torch.no_grad():
    for batch in tqdm(testLoader, leave=False, total=len(testLoader)):
        photos = batch[0].to(device)
        # Save every image in the batch, not only the first one
        for pred in generator(photos).cpu():
            # Back from [-1, 1] to [0, 1]; clamp guards against rounding just outside the range
            pred = unnorm(pred).clamp_(0.0, 1.0)
            img_idx += 1
            img = transforms.ToPILImage()(pred).convert("RGB")
            img.save("../images/" + str(img_idx) + ".jpg")

In [None]:
# Zip the contents of /kaggle/images into /kaggle/working/images.zip; the cell displays the archive path.
shutil.make_archive(base_name="/kaggle/working/images", format='zip', root_dir="/kaggle/images")