In [None]:
# Debug switch: when True the W&B run is named 'debug' and the batch size drops to 1.
DEBUG = False

In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import itertools
from tqdm.notebook import tqdm
import wandb

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

In [None]:
# Authenticate with Weights & Biases without embedding the API key in the notebook.
# wandb.login() uses the WANDB_API_KEY environment variable (or an existing login /
# interactive prompt). NOTE(review): the key that used to be hard-coded in this cell
# has been exposed and should be revoked.
wandb.login()

In [None]:
# Run name shown in W&B; debug runs get their own name.
NAME = 'cycle-gan-exp01' if not DEBUG else 'debug'

# Hyperparameters for the run; the same dict is handed to W&B below.
config = {
    'imgsize': 256,
    'batch_size': 2 if not DEBUG else 1,
    'input_nc': 3,       # channels fed to the networks
    'output_nc': 3,      # channels produced by the generators
    'lr': 2e-4,
    'epoch': 0,          # first epoch of the training loop / scheduler offset
    'n_epochs': 30,
    'decay_epoch': 15,   # epoch at which the learning-rate decay starts
}

wandb.init(project='cyclegan', name=NAME, config=config)

In [None]:
# Input folders: photographs and Monet paintings.
IMG_PATH = '../input/gan-getting-started/photo_jpg'
MONET_PATH = '../input/gan-getting-started/monet_jpg'

# File names found in each folder.
img_list = os.listdir(IMG_PATH)
monet_list = os.listdir(MONET_PATH)

# Number of photos, then number of paintings, one per line.
print(len(img_list), len(monet_list), sep='\n')

In [None]:
# Preview the first ten photos in a 2 x 5 grid.
plt.figure(figsize=(15, 10))

for slot in range(1, 11):
    photo_path = os.path.join(IMG_PATH, img_list[slot - 1])
    plt.subplot(2, 5, slot)
    plt.imshow(Image.open(photo_path).convert('RGB'))
    plt.axis('off')

plt.tight_layout()

In [None]:
# Preview the first ten Monet paintings in a 2 x 5 grid.
plt.figure(figsize=(15, 10))

for slot in range(1, 11):
    painting_path = os.path.join(MONET_PATH, monet_list[slot - 1])
    plt.subplot(2, 5, slot)
    plt.imshow(Image.open(painting_path).convert('RGB'))
    plt.axis('off')

plt.tight_layout()

In [None]:
class IMGDataset(Dataset):
    """Index-aligned (photo, Monet painting) pairs.

    File names are listed from ``IMG_PATH`` and ``MONET_PATH`` when the dataset
    is constructed. Item ``idx`` pairs the ``idx``-th photo with the ``idx``-th
    painting, so the dataset is only as long as the shorter of the two listings.

    Args:
        config: mapping that provides ``'imgsize'``.
        tfm: callable applied to each loaded RGB PIL image.
    """

    def __init__(self, config, tfm):
        # Stored for reference; no method of this class reads it.
        self.size = config['imgsize']
        self.img_list = os.listdir(IMG_PATH)
        self.monet_list = os.listdir(MONET_PATH)
        self.tfm = tfm

    def __len__(self):
        n_photos = len(self.img_list)
        n_paintings = len(self.monet_list)
        return n_photos if n_photos <= n_paintings else n_paintings

    @staticmethod
    def _open_rgb(folder, filename):
        """Open ``folder/filename`` and convert it to an RGB PIL image."""
        return Image.open(os.path.join(folder, filename)).convert('RGB')

    def __getitem__(self, idx):
        # The same index is used on both sides (no wrap-around, no random pairing).
        photo = self._open_rgb(IMG_PATH, self.img_list[idx])
        painting = self._open_rgb(MONET_PATH, self.monet_list[idx])

        return self.tfm(photo), self.tfm(painting)

In [None]:
# Preprocessing shared by both image domains: convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
dataset = IMGDataset(config, tfm)

# num_workers expects an integer worker count; the previous `True` only worked
# because bool is an int subclass (True == 1), so spell the 1 out.
# drop_last=True keeps every batch at exactly config['batch_size'] samples: the
# training code compares discriminator outputs against target tensors built once
# with that fixed batch size, so a trailing partial batch would not match them.
dataloader = DataLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
)

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 conv + instance-norm stages.

    A ReLU sits between the two stages and the block input is added to the
    result, so the channel count and spatial size are unchanged.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_stage():
            # pad by 1 then 3x3 conv keeps the spatial size
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(in_features, in_features, 3),
                    nn.InstanceNorm2d(in_features)]

        layers = conv_stage()
        layers.append(nn.ReLU(inplace=True))
        layers.extend(conv_stage())

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 downsampling convs,
    a stack of residual blocks, two stride-2 transposed convs and a 7x7
    output conv followed by tanh.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        n_residual_blocks: number of ``ResidualBlock`` modules in the middle.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super().__init__()

        def norm_relu(channels):
            return [nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

        width = 64

        # Stem: reflection pad 3 + 7x7 conv keeps the spatial size.
        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, width, 7)]
        layers.extend(norm_relu(width))

        # Downsampling: each step halves the resolution and doubles the channels.
        for _ in range(2):
            layers.append(nn.Conv2d(width, width * 2, 3, stride=2, padding=1))
            layers.extend(norm_relu(width * 2))
            width *= 2

        # Residual blocks at the bottleneck width.
        layers.extend(ResidualBlock(width) for _ in range(n_residual_blocks))

        # Upsampling: each step doubles the resolution and halves the channels.
        for _ in range(2):
            layers.append(nn.ConvTranspose2d(width, width // 2, 3, stride=2,
                                             padding=1, output_padding=1))
            layers.extend(norm_relu(width // 2))
            width //= 2

        # Output head: back to output_nc channels, squashed by tanh.
        layers.extend([nn.ReflectionPad2d(3),
                       nn.Conv2d(64, output_nc, 7),
                       nn.Tanh()])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    """Fully convolutional discriminator.

    Five 4x4 convolutions (the first three with stride 2) reduce the image to
    a one-channel score map; ``forward`` averages that map into a single score
    per image.

    Args:
        input_nc: channels of the images being judged.
    """

    def __init__(self, input_nc):
        super().__init__()

        # First stage has no normalisation.
        layers = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]

        # (in_channels, out_channels, stride) of the normalised stages.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1),
                       nn.InstanceNorm2d(c_out),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Final conv collapses to a single-channel score map.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        score_map = self.model(x)
        # Average over the whole spatial extent, then flatten to (batch, 1).
        pooled = F.avg_pool2d(score_map, score_map.size()[2:])
        return pooled.view(score_map.size()[0], -1)

In [None]:
def weights_init_normal(m):
    """Initialise one module in place; intended for ``net.apply(weights_init_normal)``.

    * Modules whose class name contains ``'Conv'`` get weights drawn from N(0, 0.02).
    * Modules whose class name contains ``'BatchNorm2d'`` get weights drawn from
      N(1, 0.02) and biases set to 0.
    * Every other module is left untouched.

    Modules that match by name but have no ``weight`` / ``bias`` tensor (for
    example a norm layer built with ``affine=False``) are skipped instead of
    raising.

    Args:
        m: the ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    # The in-place `normal_` / `constant_` initialisers replace the deprecated
    # `init.normal` / `init.constant` aliases.
    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias, 0.0)

In [None]:
class LambdaLR():
    """Learning-rate multiplier: 1.0 until the decay start, then linear down to 0.

    ``step`` is meant to be passed as ``lr_lambda`` to
    ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        n_epochs: epoch at which the multiplier reaches 0.
        offset: value added to the epoch passed to ``step``.
        decay_start_epoch: epoch after which the multiplier starts to fall;
            must be smaller than ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.decay_start_epoch = decay_start_epoch
        self.offset = offset
        self.n_epochs = n_epochs

    def step(self, epoch):
        """Return the multiplier for ``epoch``."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_length = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay/decay_length

In [None]:
# Two generators (A -> B and B -> A) and one discriminator per domain.
netG_A2B = Generator(input_nc=config['input_nc'], output_nc=config['output_nc'])
netG_B2A = Generator(input_nc=config['input_nc'], output_nc=config['output_nc'])
netD_A = Discriminator(input_nc=config['input_nc'])
netD_B = Discriminator(input_nc=config['output_nc'])

In [None]:
# Re-initialise every sub-module of the four networks, in the same order as before.
for net in (netG_A2B, netG_B2A, netD_A, netD_B):
    net.apply(weights_init_normal)

print('Weights Initialized')

In [None]:
# Use the GPU when one is visible, otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Module.to() moves the parameters in place, so the return value is not needed.
for net in (netG_A2B, netG_B2A, netD_A, netD_B):
    net.to(device)

print(f'Transferred to {device}')

In [None]:
# Losses
criterion_GAN = nn.MSELoss()       # adversarial term
criterion_cycle = nn.L1Loss()      # cycle-consistency term
criterion_identity = nn.L1Loss()   # identity term

# Optimizers: both generators share one Adam instance, each discriminator has its own.
adam_options = dict(lr=config['lr'], betas=(0.5, 0.999))
generator_params = itertools.chain(netG_A2B.parameters(), netG_B2A.parameters())

optimizer_G = torch.optim.Adam(generator_params, **adam_options)
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), **adam_options)
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), **adam_options)


def _linear_decay_scheduler(optimizer):
    """Wrap ``optimizer`` in a scheduler driven by the notebook's LambdaLR rule."""
    rule = LambdaLR(config['n_epochs'], config['epoch'], config['decay_epoch'])
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=rule.step)


# LR schedulers: one per optimizer, all following the same rule.
lr_scheduler_G = _linear_decay_scheduler(optimizer_G)
lr_scheduler_D_A = _linear_decay_scheduler(optimizer_D_A)
lr_scheduler_D_B = _linear_decay_scheduler(optimizer_D_B)

In [None]:
# Discriminator targets, shape (batch_size, 1) to match the discriminator output:
# 1 for "real" and 0 for "generated". target_fake used to be all ones as well,
# which trained the discriminators to call every image real.
target_real = torch.ones(config['batch_size'], dtype=torch.float).unsqueeze(1).to(device)
target_fake = torch.zeros(config['batch_size'], dtype=torch.float).unsqueeze(1).to(device)

In [None]:
wandb_step = 0
log_image_step = 50  # example images are logged whenever the step is a multiple of this
loader_len = len(dataloader)


def _first_image_hwc(batch):
    """Return the first image of an (N, C, H, W) batch as an (H, W, C) numpy array."""
    return batch.detach().cpu().numpy()[0].transpose(1, 2, 0)


###### Training ######
for epoch in range(config['epoch'], config['n_epochs']):
    for i, (photo, monet_img) in enumerate(tqdm(dataloader, total = loader_len)):

        # Set model input: domain A = photos, domain B = Monet paintings
        real_A = photo.to(device)
        real_B = monet_img.to(device)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss (weight 5.0)
        # G_A2B(B) should equal B if real B is fed
        same_B = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B)*5.0
        # G_B2A(A) should equal A if real A is fed
        same_A = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A)*5.0

        # GAN loss: the generators want the discriminators to answer "real"
        fake_B = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

        # Cycle loss (weight 10.0)
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        # Step the generators right away, before any discriminator backward pass,
        # so that only the generator objective contributes to their update.
        optimizer_G.step()

        ###################################

        ###### Discriminator A ######
        # zero_grad also discards the gradients that loss_G.backward() left on netD_A.
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss: detach so the discriminator loss cannot back-propagate into
        # the generators (this also removes the need for retain_graph=True).
        # NOTE(review): target_fake must be an all-zero tensor for this term to
        # teach the discriminator anything.
        pred_fake = netD_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake)*0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        ###################################

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_B)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss (detached for the same reason as above)
        pred_fake = netD_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake)*0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        ###################################

        # Progress report to Weights & Biases. Log plain Python floats so the
        # logged values do not hold on to autograd graphs.
        wandb_step = loader_len*epoch + i
        wandb.log({'loss_G': loss_G.item(),
                   'loss_G_identity': (loss_identity_A + loss_identity_B).item(),
                   'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A).item(),
                   'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB).item(),
                   'loss_D': (loss_D_A + loss_D_B).item()}, step = wandb_step)

        if wandb_step % log_image_step == 0:
            # First sample of the batch from each domain plus its translation.
            wandb.log({'exp01': [wandb.Image(_first_image_hwc(real_A), caption='real_A'),
                                 wandb.Image(_first_image_hwc(real_B), caption='real_B'),
                                 wandb.Image(_first_image_hwc(fake_A), caption = 'fake_A'),
                                 wandb.Image(_first_image_hwc(fake_B), caption = 'fake_B')]}, step = wandb_step)

    # Update learning rates (once per epoch)
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()