In [None]:
import os
import itertools

import numpy as np

import torch 
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as TF
from torchvision.utils import make_grid

import math

import random

import functools

from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

# Creating datasets and dataloaders

In [None]:
# Ask PyTorch to release its cached, currently unused CUDA memory before the
# datasets and models are created (presumably to recover memory left over from
# an earlier run in the same kernel).
torch.cuda.empty_cache()

In [None]:
from glob import glob
from PIL import Image

In [None]:
class CustomDataset(Dataset):
    """Unlabelled image dataset over the ``*.jpg`` files directly inside ``path``.

    Args:
        path: Directory searched (non-recursively) for ``*.jpg`` files.
        transform: Optional callable applied to each loaded PIL image; whatever
            it returns is what ``__getitem__`` returns.
    """

    def __init__(self, path, transform=None):
        super().__init__()

        self.path = path
        # glob() returns entries in arbitrary, filesystem-dependent order; sorting
        # makes the index -> file mapping stable across runs and machines.
        self.files = sorted(glob(f'{path}/*.jpg'))
        self.transform = transform

    def __len__(self):
        """Number of image files found in ``path``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` as 3-channel RGB and apply ``transform`` if one was given."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # The context manager closes the underlying file handle as soon as the
        # pixels are decoded. convert("RGB") returns an independent in-memory
        # image and guarantees three channels, so grayscale/CMYK JPEGs cannot
        # break a downstream 3-channel Normalize.
        with Image.open(self.files[idx]) as img_file:
            img = img_file.convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img

In [None]:
# Training-time preprocessing / augmentation applied to every image of both domains.
transform = T.Compose([
    # resize slightly larger than the target, then take a random 256x256 crop
    T.Resize((268, 268)),
    T.RandomCrop((256, 256)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    # per-channel (x - 0.5) / 0.5; undone by `inverse_transform`
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

In [None]:
def inverse_transform(img):
    """Undo the ``Normalize(0.5, 0.5)`` step: map values from [-1, 1] back to [0, 1]."""
    scale = 0.5
    shift = 0.5
    rescaled = img * scale
    return rescaled + shift

In [None]:
# One dataset per image domain; both share the same training-time `transform`.
monet_ds = CustomDataset(
    path='/kaggle/input/gan-getting-started/monet_jpg',
    transform=transform,
)
photo_ds = CustomDataset(
    path='/kaggle/input/gan-getting-started/photo_jpg',
    transform=transform,
)

In [None]:
class CombinedLoader:
    """Iterate over ``(photo_batch, monet_batch)`` pairs from two unpaired datasets.

    One pass ("epoch") yields exactly ``len(self)`` pairs, where ``len(self)`` is
    ``max(len(photo_ds), len(monet_ds)) // batch_size``. The shorter dataset is
    restarted (with a fresh shuffle) whenever it runs out, so every pair always
    contains two full batches.

    Args:
        photo_ds: Dataset of photos.
        monet_ds: Dataset of Monet paintings.
        batch_size: Batch size used for both underlying ``DataLoader`` objects.
        num_workers: Worker count used for both underlying ``DataLoader`` objects.
    """

    def __init__(self, photo_ds, monet_ds, batch_size, num_workers):
        self.photo_ds = photo_ds
        self.monet_ds = monet_ds
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.len = max(len(photo_ds), len(monet_ds))//batch_size

    @staticmethod
    def _repeat(loader):
        """Yield batches from ``loader`` forever, re-iterating it each time it is exhausted.

        Unlike ``itertools.cycle`` this keeps no copy of earlier batches, so memory
        stays flat and each restart goes through the DataLoader again (new shuffle,
        new random augmentations). Stops immediately if ``loader`` yields nothing,
        which avoids spinning forever on an empty loader.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                return

    def __iter__(self):
        self.photo_loader = DataLoader(self.photo_ds, batch_size=self.batch_size, drop_last=True, num_workers=self.num_workers, shuffle=True)
        self.monet_loader = DataLoader(self.monet_ds, batch_size=self.batch_size, drop_last=True, num_workers=self.num_workers, shuffle=True)
        self.photo_iter = self._repeat(self.photo_loader)
        self.monet_iter = self._repeat(self.monet_loader)
        self.counter = 0
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        # `>=` (not `>`): counter starts at 0, so this yields exactly self.len pairs
        # and the iteration count agrees with __len__.
        if self.counter >= self.len:
            raise StopIteration

        self.counter += 1
        return next(self.photo_iter), next(self.monet_iter)

# Defining the Model
original paper - [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)


### The Generator
The original paper proposes a generator built from multiple residual blocks; however, that architecture did not converge easily here.
The U-Net architecture used by pix2pix converges much faster.

Reflection padding is used to reduce artifacts

In the upsample blocks, we don't use `ConvTranspose2d`. Instead, we use a nearest-neighbour upsampling layer followed by a `Conv2d` layer.
This is done to further reduce artifacts and checkerboard patterns in the generated images. Source: https://distill.pub/2016/deconv-checkerboard/

### The Discriminator
We use a PatchGAN discriminator as mentioned in the original paper

In [None]:
class G_down_block(nn.Module):
    """Generator encoder stage: stride-2 4x4 convolution -> optional InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        norm: When True (default) an ``InstanceNorm2d`` follows the convolution.
    """

    def __init__(self, in_channels, out_channels, norm=True):
        super().__init__()

        # kernel 4 / stride 2 / padding 1 halves an even spatial size; the padding
        # is reflected rather than zero-filled
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            )
        ]
        if norm:
            stages.append(nn.InstanceNorm2d(out_channels))
        stages.append(nn.LeakyReLU(0.2))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        out = self.layers(x)
        return out

class G_up_block(nn.Module):
    """Generator decoder stage: 2x nearest-neighbour upsample -> 3x3 conv -> InstanceNorm -> activation.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        dropout: When True a ``Dropout(0.5)`` is added after the activation.
        act: ``"tanh"`` selects ``nn.Tanh``; any other value selects ``LeakyReLU(0.2)``.
    """

    def __init__(self, in_channels, out_channels, dropout=False, act="leaky-relu"):
        super().__init__()

        activation = nn.Tanh() if act == "tanh" else nn.LeakyReLU(0.2)

        stages = [
            nn.UpsamplingNearest2d(scale_factor=2),
            # 3x3 conv with reflected padding keeps the (already doubled) spatial size
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                padding_mode="reflect",
            ),
            # the norm is applied for every activation choice, including tanh
            nn.InstanceNorm2d(out_channels),
            activation,
        ]
        if dropout:
            stages.append(nn.Dropout(0.5))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        out = self.layers(x)
        return out

class Generator(nn.Module):
    """U-Net style generator: 8 down blocks, then 8 up blocks fed with skip connections.

    Every up block receives the current feature map concatenated (along the
    channel axis) with the output of the mirrored down block, which is why each
    up block's input width is twice the channel count of that down block. The
    first up block is paired with the last down block, i.e. the bottleneck is
    concatenated with itself.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, norm) for each encoder stage
        down_specs = [
            (3, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, False),
        ]
        self.down = nn.Sequential(
            *[G_down_block(c_in, c_out, norm=use_norm) for c_in, c_out, use_norm in down_specs]
        )

        # (in_channels, out_channels, dropout, act) for each decoder stage
        up_specs = [
            (2*512, 512, True, "leaky-relu"),
            (2*512, 512, True, "leaky-relu"),
            (2*512, 512, True, "leaky-relu"),
            (2*512, 512, False, "leaky-relu"),
            (2*512, 256, False, "leaky-relu"),
            (2*256, 128, False, "leaky-relu"),
            (2*128, 64, False, "leaky-relu"),
            (2*64, 3, False, "tanh"),
        ]
        self.up = nn.Sequential(
            *[G_up_block(c_in, c_out, dropout=drop, act=act) for c_in, c_out, drop, act in up_specs]
        )

    def forward(self, x):
        # encoder: remember every stage's output for the skip connections
        skips = []
        for down in self.down:
            x = down(x)
            skips.append(x)

        # decoder: deepest skip first
        for up, skip in zip(self.up, reversed(skips)):
            x = up(torch.cat([x, skip], dim=1))

        return x

class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a one-channel map of scores.

    Args:
        in_channels: Number of channels of the input image (default 3).
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # first stage has no normalisation
        stack = [
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
        ]

        # (in, out, stride) of the normalised middle stages
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            stack.append(nn.Conv2d(c_in, c_out, 4, stride, 1, bias=True))
            stack.append(nn.InstanceNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, True))

        # final projection to a single score channel, no activation
        stack.append(nn.Conv2d(512, 1, 4, 1, 1))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        scores = self.layers(x)
        return scores

In [None]:
class CycleGAN(nn.Module):
    """Container for the two generators and two discriminators of a CycleGAN.

    ``photo_G`` / ``monet_G`` produce photos / Monet-style images respectively, and
    ``photo_D`` / ``monet_D`` score images of the matching domain. All convolution
    weights are re-initialised on construction.
    """

    def __init__(self):
        super().__init__()

        self.photo_G = Generator()
        self.photo_D = Discriminator()

        self.monet_G = Generator()
        self.monet_D = Discriminator()

        # visits every sub-module, including this container itself
        self.apply(self._init_weights)

        print(f'Cycle GAN initialized with {self._get_num_params()} parameters')

    def _init_weights(self, module):
        """Give (transposed) conv layers N(0, 0.02) weights and zero biases; skip anything else."""
        if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            return

        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def _get_num_params(self):
        """Total number of parameter elements across the four sub-networks."""
        networks = (self.photo_G, self.photo_D, self.monet_G, self.monet_D)
        return sum(
            sum(p.numel() for p in net.parameters())
            for net in networks
        )

    def photo_to_monet(self, photo):
        """Run ``monet_G`` on ``photo`` without gradients; leaves ``monet_G`` in eval mode."""
        self.monet_G.eval()
        with torch.no_grad():
            return self.monet_G(photo)

    def monet_to_photo(self, monet):
        """Run ``photo_G`` on ``monet`` without gradients; leaves ``photo_G`` in eval mode."""
        self.photo_G.eval()
        with torch.no_grad():
            return self.photo_G(monet)

### Image Buffer
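Inspired by the original paper.
The buffer stores previously generated images. During training, each newly generated image is either passed straight to the discriminator or swapped (with probability 0.5) for a randomly chosen older image from the buffer, so the discriminator sees a mix of current and past fakes.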

In [None]:
class ImageBuffer:
    """History of generated images used to build discriminator batches.

    Until ``size`` images are stored, every incoming image is stored and returned
    unchanged. Once full, each incoming image is, with probability 0.5, swapped
    with a randomly chosen stored image (the stored one is returned, the new one
    takes its slot); otherwise it is returned as-is and not stored.

    Args:
        size: Maximum number of images kept. ``size <= 0`` disables the history,
            so ``pass_images`` returns its input images unchanged.
    """

    def __init__(self, size=70):
        self.size = size
        self.buffer = []

    def pass_images(self, imgs):
        """Return a batch the same length as ``imgs``, mixing in stored images.

        Args:
            imgs: Iterable of image tensors (e.g. a batched tensor iterated along
                its first dimension).

        Returns:
            The selected images stacked along a new first dimension.
        """
        if self.size <= 0:
            # No history requested: without this guard the swap branch below would
            # try to pick a random slot from an empty buffer and raise.
            return torch.stack(list(imgs), dim=0)

        out_buffer = []
        for img in imgs:
            if len(self.buffer) < self.size:
                # still filling up: keep the image and use it directly
                self.buffer.append(img)
                out_buffer.append(img)
            else:
                #faster than random.choice https://stackoverflow.com/questions/6824681/get-a-random-boolean-in-python
                if bool(random.getrandbits(1)):
                    # swap: hand out an older image and remember the new one
                    fetch_img_idx = random.randrange(self.size)
                    out_buffer.append(self.buffer[fetch_img_idx])
                    self.buffer[fetch_img_idx] = img
                else:
                    out_buffer.append(img)
        return torch.stack(out_buffer, dim=0)

# The Trainer

### Logging
We use TensorBoard for logging (disabled for submission notebooks).
Scalar metrics are logged at every step, and images are logged at the end of every epoch.

### Learning Rate Scheduler
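A custom learning-rate schedule (linear warmup followed by cosine decay) is used instead of a built-in PyTorch scheduler: it gives more control, and the built-in schedulers are confusing to work with.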

In [None]:
from dataclasses import dataclass
from typing import Tuple

from torch.utils.tensorboard import SummaryWriter

import itertools

In [None]:
@dataclass
class TrainerConfig:
    log_dir: str
    batch_size: int = 16
    num_workers: int = 2
    device: str|None = None

    h_lambda: int = 10
    lr: float = 3e-4
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
        
    weight_decay: float = 0
    warmup_iters: int = 500
    min_lr: float = 1e-5
    lr_decay_iters: int = 4000

    def __post_init__(self):
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class Trainer:
    """Training loop for a CycleGAN model.

    Owns the two Adam optimisers (one for both generators, one for both
    discriminators), the combined photo/Monet loader, one ``ImageBuffer`` per
    generated domain and the learning-rate schedule.

    Args:
        config: Object exposing the fields of ``TrainerConfig``.
        model: Model exposing ``photo_G``, ``monet_G``, ``photo_D``, ``monet_D``
            and ``photo_to_monet``.
    """

    def __init__(self, config, model):
        self.model = model
        self.device = config.device

        self.lr = config.lr

        self.warmup_iters = config.warmup_iters
        self.min_lr = config.min_lr
        self.lr_decay_iters = config.lr_decay_iters

        self.h_lambda = config.h_lambda

        self.g_optim, self.d_optim = self._get_optimizers(config)

        #self.log_writer = SummaryWriter(config.log_dir)

        self.loader = self._get_dataloaders(config)

        self.fake_monet_buffer = ImageBuffer()
        self.fake_photo_buffer = ImageBuffer()

        # one metrics dict per training step, appended by train()
        self.metrics_buffer = []

    #custom lr scheduler inspired by https://github.com/karpathy/nanoGPT
    def get_lr(self, it):
        """Learning rate for global step ``it``: linear warmup, cosine decay, then ``min_lr``."""
        # 1) linear warmup for warmup_iters steps
        if it < self.warmup_iters:
            return self.lr * it / self.warmup_iters
        # 2) once lr_decay_iters is reached, return min learning rate.
        #    `>=` gives the same value as the cosine branch at it == lr_decay_iters
        #    and avoids a 0/0 below when warmup_iters == lr_decay_iters.
        if it >= self.lr_decay_iters:
            return self.min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
        return self.min_lr + coeff * (self.lr - self.min_lr)

    def _get_optimizers(self, config):
        """Build ``(g_optim, d_optim)``: one Adam over both generators, one over both discriminators."""
        g_optim = Adam(
            itertools.chain(
                self.model.photo_G.parameters(),
                self.model.monet_G.parameters()
            ),
            lr=config.lr,
            betas=config.betas,
            eps=config.eps,
            weight_decay=config.weight_decay
        )
        d_optim = Adam(
            itertools.chain(
                self.model.photo_D.parameters(),
                self.model.monet_D.parameters()
            ),
            lr=config.lr,
            betas=config.betas,
            eps=config.eps,
            weight_decay=config.weight_decay
        )
        return g_optim, d_optim

    def _get_dataloaders(self, config):
        """Build the paired loader from the notebook-level ``photo_ds`` / ``monet_ds`` datasets."""
        return CombinedLoader(photo_ds, monet_ds, batch_size=config.batch_size, num_workers=config.num_workers)

    def _set_discriminator_requires_grad(self, requires_grad=False):
        """Toggle ``requires_grad`` on every parameter of both discriminators."""
        for net in [self.model.photo_D, self.model.monet_D]:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def _get_gan_loss(self, x, is_real):
        """Mean squared error between ``x`` and a constant target (1 if ``is_real`` else 0)."""
        # full_like matches x's shape, dtype and device, so the loss never depends
        # on self.device agreeing with where the predictions actually live
        target = torch.full_like(x, float(is_real))
        return F.mse_loss(x, target)

    def _get_discriminator_loss(self, net, real, fake):
        """Average of the real and fake losses for ``net``; also runs ``backward()`` on it."""
        loss_real = self._get_gan_loss(net(real), is_real=True)
        loss_fake = self._get_gan_loss(net(fake), is_real=False)
        loss = (loss_real + loss_fake) * 0.5
        loss.backward()
        return loss

    def train_step(self, real_photo, real_monet):
        """Run one generator update followed by one discriminator update.

        Args:
            real_photo: Batch of real photos.
            real_monet: Batch of real Monet images.

        Returns:
            Dict of Python floats with the individual and total losses of this step.
        """
        self.model.train()

        # translations and their round trips
        fake_photo = self.model.photo_G(real_monet)
        fake_monet = self.model.monet_G(real_photo)
        cycle_photo = self.model.photo_G(fake_monet)
        cycle_monet = self.model.monet_G(fake_photo)

        #lock discriminators
        self._set_discriminator_requires_grad(False)

        #generator optimization
        self.g_optim.zero_grad()

        #identity loss
        id_loss_1 = F.l1_loss(self.model.photo_G(real_photo), real_photo) * 0.5 * self.h_lambda
        id_loss_2 = F.l1_loss(self.model.monet_G(real_monet), real_monet) * 0.5 * self.h_lambda
        id_loss = id_loss_1 + id_loss_2

        #generator loss
        photo_G_loss = self._get_gan_loss(
            self.model.photo_D(fake_photo),
            is_real=True
        )
        monet_G_loss = self._get_gan_loss(
            self.model.monet_D(fake_monet),
            is_real=True
        )
        gen_loss = photo_G_loss + monet_G_loss

        #cycle loss
        loss_photo_cycle = F.l1_loss(cycle_photo, real_photo) * self.h_lambda
        loss_monet_cycle = F.l1_loss(cycle_monet, real_monet) * self.h_lambda
        cycle_loss = loss_photo_cycle + loss_monet_cycle

        #backprop for generators
        total_G_loss =  id_loss + gen_loss + cycle_loss
        total_G_loss.backward()

        nn.utils.clip_grad_norm_(
            itertools.chain(
                self.model.photo_G.parameters(),
                self.model.monet_G.parameters()
            ),
            1.0
        )

        self.g_optim.step()

        #unfreeze discriminators
        self._set_discriminator_requires_grad(True)

        #discriminator optimization
        self.d_optim.zero_grad()

        # fakes are detached so no gradient reaches the generators, and they are
        # routed through the image buffers so older fakes get mixed in
        photo_D_loss = self._get_discriminator_loss(
            self.model.photo_D,
            real_photo,
            self.fake_photo_buffer.pass_images(fake_photo.detach())
        )
        monet_D_loss = self._get_discriminator_loss(
            self.model.monet_D,
            real_monet,
            self.fake_monet_buffer.pass_images(fake_monet.detach())
        )

        nn.utils.clip_grad_norm_(
            itertools.chain(
                self.model.photo_D.parameters(),
                self.model.monet_D.parameters()
            ),
            1.0
        )

        self.d_optim.step()

        return {
            "monet_g_loss": monet_G_loss.item(),
            "photo_g_loss": photo_G_loss.item(),
            "monet_d_loss": monet_D_loss.item(),
            "photo_d_loss": photo_D_loss.item(),
            "identity_loss": id_loss.item(),
            "cyclic_loss": cycle_loss.item(),
            "total_g_loss": total_G_loss.item(),
            "total_d_loss": (photo_D_loss + monet_D_loss).item()
        }

    def eval_step(self, step):
        """Translate one random photo and display it next to its generated Monet.

        ``step`` is only used by the (commented-out) log writer call.
        """
        eval_photo = random.choice(photo_ds).to(self.device)
        gen_monet = self.model.photo_to_monet(eval_photo.unsqueeze(0)).squeeze(0)

        photo2monet_grid = inverse_transform(make_grid([eval_photo, gen_monet]))

        #self.log_writer.add_image('photo_to_monet', photo2monet_grid, step)
        display(TF.to_pil_image(photo2monet_grid))

    def train(self, epochs):
        """Train for ``epochs`` passes over ``self.loader``.

        Before each step the learning rate of both optimisers is set from
        ``get_lr``. After each epoch the mean generator/discriminator loss is
        printed and a sample translation is displayed.
        """
        step = 0
        tqdm_g_loss = 0.0
        tqdm_d_loss = 0.0
        for epoch in range(epochs):
            running_total_g_loss = 0.0
            running_total_d_loss = 0.0
            num_batches = 0
            for photo, monet in (pbar := tqdm(self.loader)):
                pbar.set_description(f"Epoch{epoch:2d} G_loss:{tqdm_g_loss:5.3f} D_loss:{tqdm_d_loss:5.3f}")

                photo, monet = photo.to(self.device), monet.to(self.device)

                new_lr = self.get_lr(step)
                for param_group in itertools.chain(
                    self.d_optim.param_groups,
                    self.g_optim.param_groups
                ):
                    param_group['lr'] = new_lr

                metrics_dict = self.train_step(photo, monet)
                self.metrics_buffer.append(metrics_dict)

                tqdm_g_loss = metrics_dict["total_g_loss"]
                tqdm_d_loss = metrics_dict["total_d_loss"]

                running_total_g_loss += tqdm_g_loss
                running_total_d_loss += tqdm_d_loss

                # for metric in metrics_dict.keys():
                #     self.log_writer.add_scalar(str(metric), metrics_dict[metric], step)

                num_batches += 1
                step += 1

            # cli log
            # Average over the batches actually processed rather than len(loader):
            # the mean stays correct whatever the loader yields, and an empty
            # epoch cannot divide by zero.
            denom = max(num_batches, 1)
            print(f"g_loss: {running_total_g_loss/denom} d_loss: {running_total_d_loss/denom}")

            # image log
            self.eval_step(step)


# Training

In [None]:
# Hyper-parameters for this run; everything not listed keeps its TrainerConfig default.
config = TrainerConfig(log_dir="./logs", batch_size=8, lr=2e-4, weight_decay=0.01)  # lr used in the paper

# Build the model on the configured device, then hand both to the trainer.
model = CycleGAN()
model = model.to(config.device)

trainer = Trainer(config, model)

In [None]:
# Run the full training loop for 24 epochs.
trainer.train(24)

In [None]:
# Inference-time data: same photos as training, but without the random
# resize/crop/flip/jitter augmentations and in a fixed (unshuffled) order.
photo_test_ds = CustomDataset(
    '/kaggle/input/gan-getting-started/photo_jpg',
    transform=T.Compose([
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
)
test_loader = DataLoader(
    photo_test_ds,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    drop_last=False,
)

### Save model weights (optional)

In [None]:
#torch.save(trainer.model.state_dict(), "/tmp/7_epochs_UNET_modified.pt")

In [None]:
def gen_monets(photos):
    """Translate a batch of normalised photos into a list of Monet-style PIL images."""
    batch = photos.to(config.device)
    translated = model.photo_to_monet(batch)

    pil_images = []
    for monet in translated:
        # back to [0, 1] before converting to a PIL image
        pil_images.append(TF.to_pil_image(inverse_transform(monet)))
    return pil_images

### Save generated monets

In [None]:
# Create the output directory; exist_ok makes this cell safe to re-run
# (a plain `mkdir` reports an error once the directory already exists).
os.makedirs("/tmp/images", exist_ok=True)

In [None]:
# 1) translate every test photo, batch by batch
generated_monets = []
for photo_batch in tqdm(test_loader, "generating monets"):
    generated_monets.extend(gen_monets(photo_batch))

# 2) write the results as 0.jpg, 1.jpg, ... in generation order
for idx, monet in enumerate(tqdm(generated_monets, "saving images")):
    out_path = f'/tmp/images/{idx}.jpg'
    monet.save(out_path)

In [None]:
import shutil

# Zip the generated images into /kaggle/working/images.zip
shutil.make_archive(
    base_name="/kaggle/working/images",
    format='zip',
    root_dir="/tmp/images",
)