> # **⟲CycleGAN_Pytorch (Training)⟲**

In [None]:
!pip install wandb
import wandb
# Authenticate this session with Weights & Biases.
wandb.login()
# Start a run in the 'Monet Cycle GAN' project; the returned handle is kept in `run`.
run = wandb.init(project='Monet Cycle GAN')

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import itertools
from tqdm import tqdm
import cv2


import torch
import torchvision

import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch.nn as nn

In [None]:
# Image directories: index 0 holds the Monet paintings, index 1 the photos.
# Both entries end with '/' because file names are appended to them directly.
path = [
    '../input/gan-getting-started/monet_jpg/',
    '../input/gan-getting-started/photo_jpg/',
]
# File names (not full paths) found in each directory.
monet, photo = (
    os.listdir('../input/gan-getting-started/monet_jpg'),
    os.listdir('../input/gan-getting-started/photo_jpg'),
)

In [None]:
# Visualization helpers

def de_norm(input, mean=0.5, std=0.5):
    """Undo a ``(x - mean) / std`` normalization.

    Args:
        input: Normalized value(s); anything supporting ``*`` and ``+`` with a
            scalar (float, numpy array, tensor).
        mean: Mean that was subtracted during normalization. Defaults to 0.5.
        std: Standard deviation that was divided by during normalization.
            Defaults to 0.5.

    Returns:
        ``input * std + mean``. With the defaults, values in [-1, 1] map to [0, 1].
    """
    return input * std + mean

def visulization(x,y,z):
    """Print value ranges and plot three images side by side.

    Each argument is moved to CPU, detached, converted to numpy, transposed
    with ``(1, 2, 0)`` and passed through ``de_norm`` before being displayed.

    Args:
        x: Tensor shown in the first panel, titled 'actual'.
        y: Tensor shown in the second panel, titled 'fake'.
        z: Tensor shown in the third panel, titled 'cycle'.
    """
    def _as_image(tensor):
        # Channel axis goes last, which is the layout plt.imshow expects.
        return tensor.cpu().detach().numpy().transpose(1,2,0)

    raw = [_as_image(t) for t in (x, y, z)]
    panels = [de_norm(img) for img in raw]

    range_labels = ('photo Data Range', 'Fake monet Data Range', 'Cycle photo Data Range')
    for label, img in zip(range_labels, panels):
        print(label, img.max(), img.min())

    plt.figure(figsize = (10,5))
    for position, (title, img) in enumerate(zip(('actual', 'fake', 'cycle'), panels), start = 1):
        plt.subplot(1,3,position)
        plt.title(title)
        plt.imshow(img)
        # Hide ticks and tick labels; only the picture matters here.
        plt.tick_params(left = False, bottom = False, labelleft = False, labelbottom = False)
    plt.show()
    
# weight initialization
def weight_init(m : 'nn.Module'):
    """Initialise convolution and instance-norm parameters of ``m`` in place.

    Works both when called directly on a whole model (``weight_init(model)``)
    and when used as a per-layer callback (``model.apply(weight_init)``),
    because it walks ``m.modules()``, which yields ``m`` itself and every
    descendant. Matching on the class name alone would never hit a container
    such as a full generator, so a direct call would silently do nothing.

    Layers are matched by class name:
      * names containing 'Conv' (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02)
      * names containing 'Instance': weight ~ N(1, 0.02), bias = 0

    Parameters that are ``None`` (e.g. InstanceNorm2d with ``affine=False``,
    or a convolution created with ``bias=False``) are skipped.

    Args:
        m: Module whose own parameters and sub-module parameters are initialised.
    """
    for module in m.modules():
        classname = module.__class__.__name__
        weight = getattr(module, 'weight', None)
        bias = getattr(module, 'bias', None)
        if 'Conv' in classname:
            if weight is not None:
                nn.init.normal_(weight.data, 0.0, 0.02)
        elif 'Instance' in classname:
            if weight is not None:
                nn.init.normal_(weight.data, 1.0, 0.02)
            if bias is not None:
                nn.init.constant_(bias.data, 0)

def update_req_grad(models, requires_grad=True):
    """Set ``requires_grad`` on every parameter of every model in ``models``.

    Args:
        models: Iterable of modules exposing ``parameters()``.
        requires_grad: Flag written to each parameter. Defaults to True.
    """
    every_param = (p for net in models for p in net.parameters())
    for p in every_param:
        p.requires_grad = requires_grad

In [None]:
# Defining Dataset
# Augmentation + normalization pipeline handed to CustomDataset below.
# Random horizontal/vertical flips (p = 0.5 each) and a random rotation with
# degrees=180 are applied first; ToTensor then converts the image to a tensor and
# Normalize with mean 0.5 / std 0.5 per channel rescales it (de_norm reverses this).
train_transform = T.Compose([T.RandomHorizontalFlip(p = 0.5),
                             T.RandomVerticalFlip(p = 0.5),
                             T.RandomRotation(180),
                             T.ToTensor(),
                             T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class CustomDataset(Dataset):
    """Unpaired Monet/photo dataset.

    One item is a Monet image selected by index together with a photo drawn
    uniformly at random, so one pass over the dataset visits every Monet once.

    Args:
        path: Two-element sequence ``[monet_dir, photo_dir]``.
        monet: File names inside ``path[0]``.
        photo: File names inside ``path[1]``.
        transforms: Optional callable applied to each PIL image.
        seed: Value passed to ``torch.manual_seed`` before each transform call.
    """

    def __init__(self, path, monet, photo, transforms = None, seed = 777):
        self.path = path
        self.monet = monet
        self.photo = photo
        self.seed = seed
        self.transforms = transforms
        # Each length attribute now describes the list it is named after
        # (they were previously swapped).
        self.monet_len = len(self.monet)
        self.photo_len = len(self.photo)
        self.length_dataset = max(self.photo_len, self.monet_len)

    def __len__(self):
        # The Monet list drives iteration; photos are sampled at random.
        return len(self.monet)

    def __getitem__(self, idx):
        """Return ``(monet, photo)`` for Monet index ``idx`` and a random photo."""
        # Resolve file paths (os.path.join tolerates a missing trailing '/').
        monet_path = os.path.join(self.path[0], self.monet[idx])
        photo_idx = np.random.randint(0, len(self.photo))
        photo_path = os.path.join(self.path[1], self.photo[photo_idx])

        # Load both images as 3-channel RGB.
        monet = Image.open(monet_path).convert('RGB')
        photo = Image.open(photo_path).convert('RGB')

        if self.transforms:
            # Re-seeding before each call makes both images receive the same
            # random draws.
            # NOTE(review): the seed is constant, so every sample in every
            # epoch gets the identical flip/rotation and the global torch RNG
            # is reset on each access - confirm this is intended.
            torch.manual_seed(self.seed)
            monet = self.transforms(monet)
            torch.manual_seed(self.seed)
            photo = self.transforms(photo)
        return monet, photo

> # **⟲CycleGAN Network⟲**

![](https://www.researchgate.net/profile/Irene-Gu-2/publication/343075688/figure/fig3/AS:916527951917056@1595528700573/Architecture-of-the-generator-and-discriminator-of-unpaired-CycleGAN-Conv-2D.ppm)


In [None]:
#Cycle GAN
#Generator(Photo <--> Monet)
#Defining layers
class _ResidualSequential(nn.Sequential):
    """Sequential container whose output is ``input + layers(input)``.

    Subclassing ``nn.Sequential`` keeps the child names ('0', '1', ...) and
    therefore the state_dict keys identical to a plain ``nn.Sequential``.
    """

    def forward(self, input):
        return input + super().forward(input)


class Generator(nn.Module):
    """Encoder -> 9 residual blocks -> decoder image-to-image generator.

    The encoder downsamples twice with stride-2 convolutions, the decoder
    upsamples twice with stride-2 transposed convolutions, and the result is
    squashed with ``tanh``, so the output lies in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        #1. Conv layers
        def CLayer(in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1, bias = True, norm = 'bnorm', relu = 'relu'):
            """ReflectionPad -> Conv2d -> optional norm -> optional activation."""
            # Padding is done by the reflection layer, hence padding = 0 in the conv.
            layers = [nn.ReflectionPad2d(padding),
                      nn.Conv2d(in_channels = in_ch,
                                out_channels = out_ch,
                                kernel_size = kernel_size,
                                stride = stride,
                                padding = 0,
                                bias = bias)]

            if norm == 'bnorm':
                layers.append(nn.BatchNorm2d(num_features = out_ch))
            elif norm == 'inorm':
                layers.append(nn.InstanceNorm2d(num_features = out_ch))

            if relu == 'relu':
                layers.append(nn.ReLU())
            elif relu == 'leakyrelu':
                layers.append(nn.LeakyReLU())

            return nn.Sequential(*layers)

        #2. Residual Blocks
        def Rblock(in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1, bias = True, norm = 'bnorm'):
            """Two conv layers (ReLU after the first only) plus a skip connection."""
            shared = dict(in_ch = in_ch,
                          out_ch = out_ch,
                          kernel_size = kernel_size,
                          stride = stride,
                          padding = padding,
                          bias = bias,
                          norm = norm)
            # The skip connection is what makes this a residual block; a plain
            # nn.Sequential here would just stack two convolutions.
            return _ResidualSequential(CLayer(**shared),
                                       CLayer(relu = None, **shared))

        #3. Transpose Conv layers
        def TCLayer(in_ch, out_ch, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, bias = True, norm = 'bnorm', relu = 'relu', dropout = None):
            """ConvTranspose2d -> optional dropout -> optional norm -> optional activation."""
            layers = [nn.ConvTranspose2d(in_channels = in_ch,
                                         out_channels = out_ch,
                                         kernel_size = kernel_size,
                                         stride = stride,
                                         padding = padding,
                                         output_padding = output_padding,
                                         bias = bias)]
            if dropout:
                layers.append(nn.Dropout(0.5))

            if norm == 'bnorm':
                layers.append(nn.BatchNorm2d(num_features = out_ch))
            elif norm == 'inorm':
                layers.append(nn.InstanceNorm2d(num_features = out_ch))

            if relu == 'relu':
                layers.append(nn.ReLU())
            elif relu == 'leakyrelu':
                layers.append(nn.LeakyReLU(0.2))

            return nn.Sequential(*layers)

        #Encoder: 3 -> 64 -> 128 -> 256 channels, two stride-2 downsamplings
        self.encoder1 = CLayer(in_ch = 3, out_ch = 64 , kernel_size = 7, stride = 1, padding = 3, norm = None, relu = 'relu')
        self.encoder2 = CLayer(in_ch = 64, out_ch = 128 , kernel_size = 3, stride = 2, padding = 1, norm = 'inorm', relu = 'relu')
        self.encoder3 = CLayer(in_ch = 128, out_ch = 256 , kernel_size = 3, stride = 2, padding = 1, norm = 'inorm', relu = 'relu')

        #Transformer: 9 residual blocks at 256 channels
        self.trans = nn.Sequential(*[Rblock(in_ch = 256, out_ch = 256, kernel_size = 3, stride = 1, padding = 1, norm = 'inorm')
                                     for _ in range(9)])

        #Decoder: 256 -> 128 -> 64 -> 3 channels, two stride-2 upsamplings
        self.decoder1 = TCLayer(in_ch = 256, out_ch = 128, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, norm = 'inorm', relu = 'relu', dropout=True)
        self.decoder2 = TCLayer(in_ch = 128, out_ch = 64, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, norm = 'inorm', relu = 'relu', dropout=True)
        self.decoder3 = CLayer(in_ch = 64, out_ch = 3 , kernel_size = 7, stride = 1, padding = 3, norm = 'inorm', relu = None) #No activation Function

    def forward(self, input):
        """Translate a batch of 3-channel images; returns a tensor in [-1, 1]."""
        #Encoder
        x = self.encoder1(input)
        x = self.encoder2(x)
        x = self.encoder3(x)

        #Transformer
        x = self.trans(x)

        #Decoder
        x = self.decoder1(x)
        x = self.decoder2(x)
        x = self.decoder3(x)

        return torch.tanh(x)

In [None]:
#Cycle GAN
#Discriminator(Photo <--> Monet)
#Defining layers
class Discriminator(nn.Module):
    """Five-layer convolutional discriminator producing a 1-channel score map.

    Three stride-2 4x4 convolutions followed by two stride-1 4x4 convolutions,
    with LeakyReLU(0.2) after all but the last layer. The output is a map of
    unbounded scores (no sigmoid), one per receptive-field patch.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        #1. Conv layers
        def CLayer(in_ch, out_ch, kernel_size = 3, stride = 1, padding = 1, bias = True, norm = 'bnorm', relu = 'relu'):
            """Conv2d -> optional norm -> optional activation."""
            # `padding` must be forwarded to the convolution: there is no
            # separate padding layer here, so hard-coding 0 would silently
            # drop the padding every caller asks for.
            layers = [nn.Conv2d(in_channels = in_ch,
                                out_channels = out_ch,
                                kernel_size = kernel_size,
                                stride = stride,
                                padding = padding,
                                bias = bias)]

            if norm == 'bnorm':
                layers.append(nn.BatchNorm2d(num_features = out_ch))
            elif norm == 'inorm':
                layers.append(nn.InstanceNorm2d(num_features = out_ch))

            if relu == 'relu':
                layers.append(nn.ReLU())
            elif relu == 'leakyrelu':
                layers.append(nn.LeakyReLU(0.2, inplace = True))

            return nn.Sequential(*layers)

        self.decoder1 = CLayer(in_ch = 3, out_ch = 64, kernel_size = 4, stride = 2, padding = 1, bias = False, norm = None, relu = 'leakyrelu')
        self.decoder2 = CLayer(in_ch = 64, out_ch = 128, kernel_size = 4, stride = 2, padding = 1, bias = False, norm = None, relu = 'leakyrelu')
        self.decoder3 = CLayer(in_ch = 128, out_ch = 256, kernel_size = 4, stride = 2, padding = 1, bias = False, norm = None, relu = 'leakyrelu')
        self.decoder4 = CLayer(in_ch = 256, out_ch = 512, kernel_size = 4, stride = 1, padding = 1, bias = False, norm = None, relu = 'leakyrelu')
        # Final layer: single output channel, no activation.
        self.decoder5 = CLayer(in_ch = 512, out_ch = 1, kernel_size = 4, stride = 1, padding = 1, bias = False, norm = None, relu = None)

    def forward(self, input):
        """Return the raw score map for a batch of 3-channel images."""
        x = self.decoder1(input)
        x = self.decoder2(x)
        x = self.decoder3(x)
        x = self.decoder4(x)
        return self.decoder5(x)

In [None]:
#Training
#configure
lr = 0.0002
n_epoch = 500
batch_size = 1
setting_patience = 7



device = 'cuda' if torch.cuda.is_available() else 'cpu'
#Defining Model
GM2P = Generator().to(device) #Generator: Monet ---> Photo
GP2M = Generator().to(device) #Generator: Photo ---> Monet
# In the training loop DM2P is fed Monet-domain images (real Monets and GP2M
# outputs) and DP2M is fed photo-domain images (real photos and GM2P outputs).
DM2P = Discriminator().to(device)
DP2M = Discriminator().to(device)

#Weight Initialize
weight_init(GM2P)
weight_init(GP2M)
weight_init(DM2P)
weight_init(DP2M)

#Load Model (uncomment to resume from saved state dicts)

#GM2P_path = '/content/drive/MyDrive/kaggle/model/GM2P.pt'
#GP2M_path = '/content/drive/MyDrive/kaggle/model/GP2M.pt'
#DM2P_path = '/content/drive/MyDrive/kaggle/model/DM2P.pt'
#DP2M_path = '/content/drive/MyDrive/kaggle/model/DP2M.pt'


#GM2P.load_state_dict(torch.load(GM2P_path))
#GP2M.load_state_dict(torch.load(GP2M_path))
#DM2P.load_state_dict(torch.load(DM2P_path))
#DP2M.load_state_dict(torch.load(DP2M_path))



#Loss Functions
#1. GAN loss - L2
gan_loss = nn.MSELoss().to(device)
#2. Cycle loss - L1
cyc_loss = nn.L1Loss().to(device)
#3. Identity loss - L1
iden_loss = nn.L1Loss().to(device)

#Optimizer (one Adam per role, each covering both translation directions)
OptimG = torch.optim.Adam(itertools.chain(GM2P.parameters(), GP2M.parameters()), lr = lr, betas=(0.5, 0.999))
OptimD = torch.optim.Adam(itertools.chain(DM2P.parameters(), DP2M.parameters()), lr = lr, betas=(0.5, 0.999))

# Exponential decay: lr is multiplied by 0.95 ** epoch each time step() is called.
# `verbose` is a deprecated scheduler argument and `last_epoch=-1` is the
# default, so neither is passed.
schedulerG = lr_scheduler.LambdaLR(optimizer=OptimG,
                                   lr_lambda=lambda epoch: 0.95 ** epoch)
schedulerD = lr_scheduler.LambdaLR(optimizer=OptimD,
                                   lr_lambda=lambda epoch: 0.95 ** epoch)

#Get Data from Dataloader
train_dataset = CustomDataset(path, monet, photo, transforms = train_transform, seed = 26)
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle  = True, num_workers = 2)
# Number of batches per epoch; used to average the per-batch losses.
train_total_batch = len(train_dataloader)
loss_G_list = []
loss_D_list = []

> # **⟲CycleGAN Architecture⟲**

![](https://www.researchgate.net/profile/Mihir-Parmar/publication/335840935/figure/fig2/AS:883763248656384@1587716986091/Proposed-CycleGAN-architecture-Here-W-Whisper-and-S-Speech-After-35.jpg)

Proposed CycleGAN architecture. Here, W: Whisper, and S: Speech. After [35].

In [None]:
#Training & Evaluation

# Bookkeeping values; none of them is read inside this loop.
best_G_loss = 100
best_D_loss = 100
best_loss = 100
total_patience = 0
for epoch in range(n_epoch):

    # (Re-)enter training mode; the periodic preview below switches the
    # generators to eval mode.
    GM2P.train()
    GP2M.train()
    DM2P.train()
    DP2M.train()

    # Running per-epoch means of the batch losses.
    loss_G_avg = 0.0
    loss_D_avg = 0.0

    with tqdm(train_dataloader, unit = 'batch') as train_bar:

        for monets, photos in train_bar:

            monets = monets.float().to(device)
            photos = photos.float().to(device)

            # ---------------- Generator step ----------------
            # Freeze discriminator weights so this backward pass only
            # accumulates gradients in the generators.
            update_req_grad([DM2P, DP2M], False)

            fake_photo = GM2P(monets)
            fake_monet = GP2M(photos)

            cycl_monet = GP2M(fake_photo)
            cycl_photo = GM2P(fake_monet)

            ident_monet = GP2M(monets)
            ident_photo = GM2P(photos)

            #Calculating loss (Identity, Adversarial, cycle consistency)
            #identity loss
            ident_loss_monet = iden_loss(ident_monet, monets) * 10 * 0.5
            ident_loss_photo = iden_loss(ident_photo, photos) * 10 * 0.5
            #Cycle loss
            cycle_loss_monet = cyc_loss(cycl_monet, monets) * 10
            cycle_loss_photo = cyc_loss(cycl_photo, photos) * 10
            #Adversarial loss
            # The fakes must NOT be detached here: the gradient has to flow
            # through the (frozen) discriminators back into the generators,
            # otherwise the adversarial term is a constant for OptimG.
            pred_fake_monet = DM2P(fake_monet)
            pred_fake_photo = DP2M(fake_photo)

            adv_loss_monet = gan_loss(pred_fake_monet, torch.ones_like(pred_fake_monet))
            adv_loss_photo = gan_loss(pred_fake_photo, torch.ones_like(pred_fake_photo))

            #Generator Loss
            loss_G = (ident_loss_monet + ident_loss_photo ) + (cycle_loss_monet + cycle_loss_photo) + (adv_loss_monet + adv_loss_photo)
            loss_G_avg += loss_G.item() / train_total_batch

            #Generator Backward
            OptimG.zero_grad()
            loss_G.backward()
            OptimG.step()

            # ---------------- Discriminator step ----------------
            update_req_grad([DM2P, DP2M], True)
            OptimD.zero_grad()

            # Fresh forward passes with trainable discriminators. The fakes are
            # detached so no gradient reaches the generators, while the
            # discriminators themselves still receive gradient from both the
            # real and the fake term.
            pred_real_monet = DM2P(monets)
            pred_real_photo = DP2M(photos)
            pred_fake_monet = DM2P(fake_monet.detach())
            pred_fake_photo = DP2M(fake_photo.detach())

            #Discriminator loss (real -> 1, fake -> 0)
            loss_D_monet = gan_loss(pred_real_monet, torch.ones_like(pred_real_monet)) + gan_loss(pred_fake_monet, torch.zeros_like(pred_fake_monet))
            loss_D_photo = gan_loss(pred_real_photo, torch.ones_like(pred_real_photo)) + gan_loss(pred_fake_photo, torch.zeros_like(pred_fake_photo))

            loss_D = (loss_D_monet + loss_D_photo) / 2
            loss_D_avg += loss_D.item() / train_total_batch

            #backward
            loss_D.backward()
            OptimD.step()

            train_bar.set_postfix(epoch = epoch+1, loss_G = loss_G.item(),loss_D = loss_D.item())

    # One scheduler step per epoch.
    schedulerG.step()
    schedulerD.step()

    wandb.log({'Epoch' : epoch+1, "Generator loss": loss_G_avg,
               "Discriminator loss": loss_D_avg,
               'G_lr' : OptimG.param_groups[0]['lr'],
               'D_lr' : OptimD.param_groups[0]['lr']})

    # Preview + checkpoint after the first epoch and then every 10th epoch.
    if (epoch + 1) == 1 or (epoch + 1) % 10 == 0:
        print('epoch : {} model save....'.format(epoch+1))

        # Preview on the last batch of the epoch, in eval mode so dropout is
        # disabled; train mode is restored at the top of the next epoch.
        GP2M.eval()
        GM2P.eval()
        with torch.no_grad():
            fake_monets = GP2M(photos)
            cycle_monets = GM2P(fake_monets)

        # Show only the first sample of the batch.
        photos_ = photos[0,:,:,:]
        fake_monets = fake_monets[0,:,:,:]
        cycle_monets = cycle_monets[0,:,:,:]
        visulization(photos_, fake_monets, cycle_monets)

        print('ðŸ”¨Model SaveðŸ”¨')
        torch.save(GM2P.state_dict(), '/content/drive/MyDrive/kaggle/model/GM2P_epoch_{}.pt'.format(epoch+1))
        torch.save(GP2M.state_dict(), '/content/drive/MyDrive/kaggle/model/GP2M_epoch_{}.pt'.format(epoch+1))
        torch.save(DM2P.state_dict(), '/content/drive/MyDrive/kaggle/model/DM2P_epoch_{}.pt'.format(epoch+1))
        torch.save(DP2M.state_dict(), '/content/drive/MyDrive/kaggle/model/DP2M_epoch_{}.pt'.format(epoch+1))

wandb.finish()