<a href="https://colab.research.google.com/github/tristanoprofetto/neural-networks/blob/main/GAN/CycleGAN/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CycleGAN

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
import torch.nn.functional as F
from skimage import color

### Image Processing

In [None]:
# Function to display images during training/testing
def showImages(image_tensor, n=25, size=(1, 28, 28)):
  """Plot the first ``n`` images of a batch as a grid with five images per row.

  Args:
    image_tensor: Batch of images. Values are shifted with ``(x + 1) / 2``
      before plotting, i.e. the input is expected to lie in [-1, 1].
    n: Maximum number of images taken from the batch.
    size: ``(channels, height, width)`` each image is reshaped to.
  """
  # Map [-1, 1] -> [0, 1] so matplotlib renders the colours correctly.
  image = (image_tensor + 1) / 2
  image = image.detach().cpu().view(-1, *size)
  image_grid = make_grid(image[:n], nrow=5)
  # make_grid returns (C, H, W); imshow wants (H, W, C).
  plt.imshow(image_grid.permute(1, 2, 0).squeeze())
  plt.show()


# Object class to load and split the two groups of images
class ImageData(Dataset):
  """Unpaired two-domain image dataset.

  Files are globbed from ``<root>/<mode>A`` and ``<root>/<mode>B``. After
  loading, the smaller of the two sets is always stored in ``files_A``. Each
  item pairs the ``index``-th image of ``files_A`` with a randomly chosen
  image of ``files_B``; the random pairing is redrawn whenever the last index
  is requested.

  Args:
    root: Directory containing the ``<mode>A`` and ``<mode>B`` folders.
    transform: Callable applied to every opened PIL image. It must return a
      channel-first tensor (the code reads ``.shape[0]`` and calls
      ``.repeat``); values are presumably in [0, 1], as produced by
      ``ToTensor`` - confirm for custom transforms.
    mode: Prefix of the two folder names (``'train'`` -> ``trainA``/``trainB``).

  Raises:
    AssertionError: If no image files are found.
  """

  def __init__(self, root, transform=None, mode='train'):
    self.transform = transform
    self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
    self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))

    # Keep the smaller set in files_A: newPerm() draws a permutation over
    # files_B and truncates it to len(files_A).
    if len(self.files_A) > len(self.files_B):
      self.files_A, self.files_B = self.files_B, self.files_A

    assert len(self.files_A) > 0, "make sure to download the images!!!!"
    self.newPerm()

  def newPerm(self):
    """Draw a fresh random pairing of ``files_B`` indices for ``files_A``."""
    self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

  def _load(self, path):
    """Open ``path``, apply the transform and force a 3-channel tensor."""
    # The context manager closes the file handle once the transform has run.
    with Image.open(path) as img:
      tensor = self.transform(img)
    # Replicate the channels of non-RGB (e.g. grayscale) images.
    if tensor.shape[0] != 3:
      tensor = tensor.repeat(3, 1, 1)
    return tensor

  def __getitem__(self, index):
    """Return the pair ``(A, B)`` for ``index``, rescaled with ``2 * (x - 0.5)``."""
    A = self._load(self.files_A[index % len(self.files_A)])
    B = self._load(self.files_B[int(self.randperm[index])])

    # Re-pair the two domains once the end of the dataset is reached.
    if index == len(self) - 1:
      self.newPerm()

    return 2 * (A - 0.5), 2 * (B - 0.5)

  # Backward-compatible alias for the original (mis-capitalised) method name.
  __getItem__ = __getitem__

  def __len__(self):
    """Number of pairs per epoch: the size of the smaller image set."""
    return min(len(self.files_A), len(self.files_B))


### Building Blocks

In [None]:
# Residual Blocks for adding previous outputs to the original inputs
class ResidualBlock(nn.Module):
  """Two 3x3 reflect-padded convolutions whose output is added to the input.

  Channel count and spatial size are preserved.

  Args:
    inputChannels: Number of channels of the input (and output) tensor.
  """

  def __init__(self, inputChannels):
    super(ResidualBlock, self).__init__()

    # Defining Convolutional layers
    self.c1 = nn.Conv2d(inputChannels, inputChannels, kernel_size=3, padding=1, padding_mode='reflect')
    self.c2 = nn.Conv2d(inputChannels, inputChannels, kernel_size=3, padding=1, padding_mode='reflect')
    # One instance is used after both convolutions (see forward).
    self.norm = nn.InstanceNorm2d(inputChannels)
    self.activation = nn.ReLU()

  def forward(self, x):
    """Return ``x + norm(c2(relu(norm(c1(x)))))``.

    nn.Module.__call__ dispatches to ``forward``, so the block must define it
    under this name to be callable as ``block(x)``.
    """
    # No clone of x is needed: none of the layers below operate in place.
    residual = self.c1(x)
    residual = self.norm(residual)
    residual = self.activation(residual)

    residual = self.c2(residual)
    residual = self.norm(residual)

    return x + residual

  def feedForward(self, x):
    """Backward-compatible alias of :meth:`forward`."""
    return self.forward(x)



# Object class for Downsampling images 
class ContractingBlock(nn.Module):
  """Stride-2 convolution that doubles the channel count.

  Args:
    inputChannels: Number of input channels; the output has twice as many.
    use_bn: If true, apply ``nn.InstanceNorm2d`` after the convolution.
    kernel_size: Kernel size of the convolution (padding is fixed to 1).
    activation: ``'relu'`` selects ``nn.ReLU``; any other value selects
      ``nn.LeakyReLU(0.2)``.
  """

  def __init__(self, inputChannels, use_bn=True, kernel_size=3, activation='relu'):
    super(ContractingBlock, self).__init__()

    self.c1 = nn.Conv2d(inputChannels, 2 * inputChannels, kernel_size=kernel_size, padding=1, stride=2, padding_mode='reflect')
    # Honour the requested activation; it was previously ignored, so callers
    # asking for 'lrelu' silently got a plain ReLU.
    self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)

    if use_bn:
      self.norm = nn.InstanceNorm2d(2 * inputChannels)
    self.use_bn = use_bn

  def forward(self, x):
    """Apply convolution, optional instance norm, then the activation."""
    x = self.c1(x)

    if self.use_bn:
      x = self.norm(x)

    return self.activation(x)

  def feedForward(self, x):
    """Backward-compatible alias of :meth:`forward`."""
    return self.forward(x)



# Performs a convolutional transpose to Upsample our images
class ExpandingBlock(nn.Module):
  """Stride-2 transposed convolution that halves the channel count.

  Args:
    inputChannels: Number of input channels; the output has
      ``inputChannels // 2``.
    use_bn: If true, apply ``nn.InstanceNorm2d`` after the convolution.
  """

  def __init__(self, inputChannels, use_bn=True):
    super(ExpandingBlock, self).__init__()

    self.c1 = nn.ConvTranspose2d(inputChannels, inputChannels // 2, kernel_size=3, stride=2, padding=1, output_padding=1)
    if use_bn:
      self.norm = nn.InstanceNorm2d(inputChannels // 2)
    self.use_bn = use_bn
    self.activation = nn.ReLU()

  def forward(self, x):
    """Apply transposed convolution, optional instance norm, then ReLU."""
    x = self.c1(x)
    if self.use_bn:
      x = self.norm(x)
    # Call the activation; the original assigned the module itself to x.
    return self.activation(x)

  def feedForward(self, x):
    """Backward-compatible alias of :meth:`forward`."""
    return self.forward(x)


# Final Layer of Generator ... maps each output to the desired number of output channels
class FeatureMapBlock(nn.Module):
  """7x7 reflect-padded convolution mapping ``inputChannels`` to ``outputChannels``.

  With ``padding=3`` and stride 1 the spatial size is unchanged.

  Args:
    inputChannels: Number of channels of the input tensor.
    outputChannels: Number of channels of the output tensor.
  """

  def __init__(self, inputChannels, outputChannels):
    super(FeatureMapBlock, self).__init__()
    self.conv = nn.Conv2d(inputChannels, outputChannels, kernel_size=7, padding=3, padding_mode='reflect')

  def forward(self, x):
    """Apply the convolution (named ``forward`` so ``block(x)`` works)."""
    return self.conv(x)

  def feedForward(self, x):
    """Backward-compatible alias of :meth:`forward`."""
    return self.forward(x)

### Generator

In [None]:
class Generator(nn.Module):
  """Image-to-image generator.

  Pipeline: feature-map block, two contracting blocks, nine residual blocks,
  two expanding blocks, a final feature-map block and a tanh.

  Args:
    inputChannels: Number of channels of the input image.
    outputChannels: Number of channels of the generated image.
    hiddenChannels: Channel count after the first feature-map block.
  """

  def __init__(self, inputChannels, outputChannels, hiddenChannels=64):
    super(Generator, self).__init__()

    # Attribute names are kept unchanged so state_dict keys stay the same.
    self.upsample = FeatureMapBlock(inputChannels, hiddenChannels)

    self.contract1 = ContractingBlock(hiddenChannels)
    self.contract2 = ContractingBlock(2 * hiddenChannels)

    self.res0 = ResidualBlock(4 * hiddenChannels)
    self.res1 = ResidualBlock(4 * hiddenChannels)
    self.res2 = ResidualBlock(4 * hiddenChannels)
    self.res3 = ResidualBlock(4 * hiddenChannels)
    self.res4 = ResidualBlock(4 * hiddenChannels)
    self.res5 = ResidualBlock(4 * hiddenChannels)
    self.res6 = ResidualBlock(4 * hiddenChannels)
    self.res7 = ResidualBlock(4 * hiddenChannels)
    self.res8 = ResidualBlock(4 * hiddenChannels)

    self.expand1 = ExpandingBlock(4 * hiddenChannels)
    self.expand2 = ExpandingBlock(2 * hiddenChannels)

    self.downSample = FeatureMapBlock(hiddenChannels, outputChannels)

    self.activation = nn.Tanh()

  def forward(self, x):
    """Translate a batch of images; the tanh bounds the output to [-1, 1]."""
    x = self.upsample(x)
    x = self.contract1(x)
    x = self.contract2(x)

    residual_blocks = (
        self.res0, self.res1, self.res2, self.res3, self.res4,
        self.res5, self.res6, self.res7, self.res8,
    )
    for block in residual_blocks:
      x = block(x)

    x = self.expand1(x)
    x = self.expand2(x)
    # Map hiddenChannels -> outputChannels; this layer was built in __init__
    # but never applied in the original forward pass.
    x = self.downSample(x)
    return self.activation(x)

  def feedForward(self, x):
    """Backward-compatible alias of :meth:`forward`."""
    return self.forward(x)

### Discriminator

In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing a one-channel prediction map.

  Pipeline: feature-map block, three contracting blocks (kernel size 4,
  requested with ``activation='lrelu'``), then a 1x1 convolution to a single
  channel.

  Args:
    inputChannels: Number of channels of the input image.
    hiddenChannels: Channel count after the first feature-map block.
  """

  def __init__(self, inputChannels, hiddenChannels=64):
    super(Discriminator, self).__init__()

    self.upsample = FeatureMapBlock(inputChannels, hiddenChannels)

    self.contract1 = ContractingBlock(hiddenChannels, use_bn=False, kernel_size=4, activation='lrelu')
    self.contract2 = ContractingBlock(2 * hiddenChannels, kernel_size=4, activation='lrelu')
    self.contract3 = ContractingBlock(4 * hiddenChannels, kernel_size=4, activation='lrelu')

    # Three contracting blocks double the channels three times, so the final
    # convolution receives 8 * hiddenChannels (not 8 * inputChannels).
    self.output = nn.Conv2d(8 * hiddenChannels, 1, kernel_size=1)

  def forward(self, x):
    """Return the raw (un-activated) prediction map for ``x``."""
    x = self.upsample(x)
    x = self.contract1(x)
    x = self.contract2(x)
    x = self.contract3(x)
    return self.output(x)

  def feedForward(self, x):
    """Backward-compatible alias of :meth:`forward`."""
    return self.forward(x)

In [None]:
# Data loading & augmentation.
# Images are resized to `load_shape`, then a random `target_shape` crop and a
# random horizontal flip are applied before conversion to a tensor.
load_shape = 286
target_shape = 256

transform = transforms.Compose(
    [
        transforms.Resize(load_shape),
        transforms.RandomCrop(target_shape),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Reads the images found under horse2zebra/trainA and horse2zebra/trainB.
data = ImageData(root="horse2zebra", transform=transform, mode="train")

AssertionError: ignored

In [None]:
# Initializing Training Parameters
ganCriterion = nn.MSELoss()          # adversarial loss
reconstructCriterion = nn.L1Loss()   # used for both identity and cycle losses

epochs = 20
dim_A = 3 # Number of channels for images in set A
dim_B = 3 # Number of channels for images in set B
displayStep = 200 # How often to display generated images to track progress during training
batchSize = 1
learnRate = 0.0002
# Fall back to the CPU when no GPU is available instead of failing in .to().
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Initializing Training Variables
gen_AB = Generator(dim_A, dim_B).to(device)
gen_BA = Generator(dim_B, dim_A).to(device)
# A single optimizer updates both generators together.
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=learnRate, betas=(0.5, 0.999)) # optimizer for Generator

disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=learnRate, betas=(0.5, 0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=learnRate, betas=(0.5, 0.999))

def weights(m):
  """Initialise the parameters of module ``m`` in place.

  Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d
  weights are drawn from N(0, 0.02) and its bias is set to 0. Every other
  module type is left untouched. Intended to be passed to
  ``nn.Module.apply``.
  """
  conv_types = (nn.Conv2d, nn.ConvTranspose2d)
  if isinstance(m, conv_types):
    torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
  if isinstance(m, nn.BatchNorm2d):
    torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    torch.nn.init.constant_(m.bias, val=0)

### Loss Functions

In [None]:
def genLoss(real_X, disc_Y, gen_XY, ganCriterion):
  """Adversarial loss of generator ``gen_XY`` against discriminator ``disc_Y``.

  Args:
    real_X: Batch of real images from domain X.
    disc_Y: Callable scoring images of domain Y.
    gen_XY: Callable translating images from X to Y.
    ganCriterion: Loss comparing the discriminator output with its target.

  Returns:
    Tuple ``(ganLoss, fakeY)``: the loss for wanting the fakes classified as
    real (target of all ones), and the generated images.
  """
  fakeY = gen_XY(real_X)
  discFakePred = disc_Y(fakeY)
  # The generator is rewarded when the discriminator outputs "real" (ones).
  ganLoss = ganCriterion(discFakePred, torch.ones_like(discFakePred))

  return ganLoss, fakeY


In [None]:
def discLoss(real_X, fake_X, disc_X, ganCriterion):
  """Loss of discriminator ``disc_X`` on one real and one fake batch.

  Args:
    real_X: Batch of real images from domain X.
    fake_X: Batch of generated images for domain X; detached here so no
      gradient flows back into the generator.
    disc_X: Callable scoring images of domain X.
    ganCriterion: Loss comparing the discriminator output with its target.

  Returns:
    The mean of the fake loss (target zeros) and the real loss (target ones).
  """
  discFakePred = disc_X(fake_X.detach())
  # Fakes must be pushed towards 0; a target of ones here would train the
  # discriminator to label everything as real.
  discFakeLoss = ganCriterion(discFakePred, torch.zeros_like(discFakePred))

  discRealPred = disc_X(real_X)
  discRealLoss = ganCriterion(discRealPred, torch.ones_like(discRealPred))

  return (discFakeLoss + discRealLoss) / 2

In [None]:
def IdentityLoss(real_X, gen_YX, identityCriterion):
  """Penalise ``gen_YX`` for altering an image that is already in domain X.

  Args:
    real_X: Batch of real images from domain X.
    gen_YX: Callable translating images into domain X.
    identityCriterion: Loss comparing the generator output with ``real_X``.

  Returns:
    Tuple ``(loss, gen_YX(real_X))``.
  """
  regenerated = gen_YX(real_X)
  return identityCriterion(regenerated, real_X), regenerated



def CycleConsistencyLoss(real_X, fake_Y, fake_X, cycleCriterion):
  """Compare ``real_X`` with its reconstruction from ``fake_Y``.

  Args:
    real_X: Batch of real images from domain X.
    fake_Y: Batch of images generated from ``real_X``.
    fake_X: Despite its name this is a *callable* (it is invoked as
      ``fake_X(fake_Y)``) that maps ``fake_Y`` back to domain X.
    cycleCriterion: Loss comparing the reconstruction with ``real_X``.

  Returns:
    Tuple ``(loss, reconstruction)``.
  """
  reconstructed = fake_X(fake_Y)
  return cycleCriterion(reconstructed, real_X), reconstructed


In [None]:
def ModelLoss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, ganCriterion, identityCriterion, cycleCriterion, lambdaIdentity=0.1, lambdaCycle=10):
  """Total generator loss: adversarial + weighted identity + weighted cycle.

  Args:
    real_A, real_B: Batches of real images from domains A and B.
    gen_AB, gen_BA: Generators translating A -> B and B -> A.
    disc_A, disc_B: Discriminators for domains A and B.
    ganCriterion: Criterion for the adversarial terms.
    identityCriterion: Criterion for the identity terms.
    cycleCriterion: Criterion for the cycle-consistency terms.
    lambdaIdentity: Weight of the identity loss.
    lambdaCycle: Weight of the cycle-consistency loss.

  Returns:
    Tuple ``(totalLoss, fake_A, fake_B)``.
  """
  # Adversarial terms. The sum is stored under a name that does not shadow
  # the module-level genLoss() function called on the two lines below.
  ganLoss_BA, fake_A = genLoss(real_B, disc_A, gen_BA, ganCriterion)
  ganLoss_AB, fake_B = genLoss(real_A, disc_B, gen_AB, ganCriterion)
  adversarialLoss = ganLoss_BA + ganLoss_AB

  # Identity terms: a generator fed an image of its target domain should
  # leave it unchanged.
  identityLoss_A, _ = IdentityLoss(real_A, gen_BA, identityCriterion)
  identityLoss_B, _ = IdentityLoss(real_B, gen_AB, identityCriterion)
  identityLoss = identityLoss_A + identityLoss_B

  # Cycle terms: A -> B -> A and B -> A -> B should reproduce the input.
  cycleLoss_BA, _ = CycleConsistencyLoss(real_A, fake_B, gen_BA, cycleCriterion)
  cycleLoss_AB, _ = CycleConsistencyLoss(real_B, fake_A, gen_AB, cycleCriterion)
  cycleLoss = cycleLoss_BA + cycleLoss_AB

  totalLoss = adversarialLoss + lambdaIdentity * identityLoss + lambdaCycle * cycleLoss

  return totalLoss, fake_A, fake_B


### Model Training

In [None]:
# Function for training the CycleGAN
def train(save_model=False):
  """Train the CycleGAN using the module-level models, optimizers and data.

  Every step updates discriminator A, discriminator B and then both
  generators. Every ``displayStep`` steps the running mean losses are
  printed, real and generated images are shown, and the running means are
  reset.

  Args:
    save_model: If true, write a checkpoint ``cycleGAN_<step>.pth`` at every
      display step.
  """
  mean_generator_loss = 0
  mean_discriminator_loss = 0
  images = DataLoader(data, batch_size=batchSize, shuffle=True)
  current_step = 0

  for epoch in range(epochs):

    for real_A, real_B in tqdm(images):

      real_A = nn.functional.interpolate(real_A, size=target_shape)
      real_A = real_A.to(device)

      real_B = nn.functional.interpolate(real_B, size=target_shape)
      real_B = real_B.to(device)

      # UPDATING DISCRIMINATOR WEIGHTS

      # Discriminator: A
      disc_A_opt.zero_grad()

      # The fake is only an input here, so no generator graph is needed.
      with torch.no_grad():
        fake_A = gen_BA(real_B)

      discLoss_A = discLoss(real_A, fake_A, disc_A, ganCriterion)
      # Each graph is back-propagated exactly once, so it is not retained.
      discLoss_A.backward()
      disc_A_opt.step()

      # Discriminator: B
      disc_B_opt.zero_grad()

      with torch.no_grad():
        fake_B = gen_AB(real_A)

      discLoss_B = discLoss(real_B, fake_B, disc_B, ganCriterion)
      discLoss_B.backward()
      disc_B_opt.step()

      # UPDATING GENERATOR WEIGHTS

      gen_opt.zero_grad()
      gen_loss, fake_A, fake_B = ModelLoss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, ganCriterion, reconstructCriterion, reconstructCriterion)

      gen_loss.backward()
      gen_opt.step()

      # Running means over one display window (only discriminator A is tracked).
      mean_discriminator_loss += discLoss_A.item() / displayStep
      mean_generator_loss += gen_loss.item() / displayStep

      # VISUALIZE TRAINING PROGRESS
      if current_step % displayStep == 0:

        print(f"Epoch {epoch}: Step {current_step}: Generator Loss: {mean_generator_loss}, Discriminator Loss: {mean_discriminator_loss}")

        showImages(torch.cat([real_A, real_B]), size=(dim_A, target_shape, target_shape))
        showImages(torch.cat([fake_B, fake_A]), size=(dim_B, target_shape, target_shape))

        # Start a fresh averaging window.
        mean_generator_loss = 0
        mean_discriminator_loss = 0

        if save_model:
          torch.save({
              'gen_AB': gen_AB.state_dict(),
              'gen_BA': gen_BA.state_dict(),
              'gen_opt': gen_opt.state_dict(),
              'disc_A': disc_A.state_dict(),
              'disc_A_opt': disc_A_opt.state_dict(),
              'disc_B': disc_B.state_dict(),
              'disc_B_opt': disc_B_opt.state_dict()
          }, f"cycleGAN_{current_step}.pth")

      # Advance on every batch, independent of display or saving.
      current_step += 1
