<a href="https://colab.research.google.com/github/zhisheng-hua/Reproducing-and-Extending-the-CycleGAN-Model/blob/main/CycleGAN_Draft1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Environment Setup

In [2]:
# Mount Google Drive in this Colab runtime; the dataset, checkpoint, and
# visualization folders configured below all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import itertools
import os
import random
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset
from torchvision.utils import save_image
from torchvision import transforms
from tqdm import tqdm

# Hyperparameters and Constant Variables

In [4]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
  DEVICE = "cuda"
else:
  DEVICE = "cpu"
print("Device:", DEVICE)

Device: cuda


In [5]:
# Root folder of the unpaired dataset. The trailing slash is required: the
# sub-folder paths below are built by plain string concatenation, and the
# training code derives the dataset name from DATASET_DIR[:-1].
DATASET_DIR = "/content/drive/MyDrive/ECE-GY 6953 DL/Project/Datasets/horse2zebra/"
# Domain A and domain B must come from different folders; pointing both at the
# "A" folders would make the model translate a domain onto itself.
TRAIN_A_PATH = DATASET_DIR + "trainA"
TRAIN_B_PATH = DATASET_DIR + "trainB"
TEST_A_PATH = DATASET_DIR + "testA"
TEST_B_PATH = DATASET_DIR + "testB"

# Where trained generator/discriminator weights are written.
MODEL_DIR = "/content/drive/MyDrive/ECE-GY 6953 DL/Project/SavedModels/"

# Where figures (e.g. loss curves) are written.
VISUAL_DIR = "/content/drive/MyDrive/ECE-GY 6953 DL/Project/Visualizations/"

In [24]:
# Optimization schedule.
BATCH_SIZE = 1
NUM_EPOCHS = 100
NUM_EPOCHS_DECAY = 0 # number of final epochs during which the learning rate is linearly decayed (0 keeps it constant)

# Initial Adam learning rate (shared by the generator and discriminator optimizers).
LR = 0.0002
# Weight of the cycle-consistency terms in the generator loss.
LAMBDA_CYCLE = 10.0
# Weight of the identity terms, applied on top of LAMBDA_CYCLE in the generator loss.
LAMBDA_IDT = 0.5
# Capacity of each history pool of generated images used when training the discriminators.
IMAGE_POOL_SIZE = 50

MAX_NUM_SAMPLES = 1000 # if larger than the number of training/testing samples, use all samples
NUM_DATALOADER_WORKERS = 2
# Toggles for writing checkpoints / the loss figure to Google Drive.
SAVE_MODELS = True
SAVE_LOSS_PLOTS = False

# Helper Functions

In [7]:
# The code in this cell is adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git.

# Method to obtain/insert images from/into the image pool
def query_image_pool(image_pool, pool_size, curr_images):
  """Swap freshly generated images with images kept in a history buffer.

  Args:
    image_pool: list used as the buffer; it is mutated in place.
    pool_size: maximum number of images the buffer may hold; 0 disables it.
    curr_images: batch tensor whose first dimension indexes the images.

  Returns:
    `curr_images` itself when `pool_size` is 0; otherwise a tensor with one
    entry per input image, concatenated along dimension 0, where each entry is
    either the current image or an older one taken from the buffer.
  """
  # A zero-sized pool switches the history mechanism off.
  if pool_size == 0:
    return curr_images

  selected = []
  for current in curr_images:
    current = torch.unsqueeze(current.data, 0)

    # While there is still room, store the image and pass it straight through.
    if len(image_pool) < pool_size:
      image_pool.append(current)
      selected.append(current)
      continue

    # Buffer is full: with probability 0.5 hand back a stored image and put the
    # current one in its slot; otherwise hand back the current image untouched.
    if random.uniform(0, 1) > 0.5:
      slot = random.randint(0, pool_size - 1)
      selected.append(image_pool[slot].clone())
      image_pool[slot] = current
    else:
      selected.append(current)

  return torch.cat(selected, 0)

In [8]:
# The code in this cell is adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git.

# Method to initialize weights of the given model
def init_weights(model):
  """Re-initialize the parameters of `model` in place.

  Every sub-module whose class name contains 'Conv' or 'Linear' and that has a
  `weight` attribute gets weights drawn from N(0, 0.02) and, if it has a bias,
  a zero bias. Sub-modules whose class name contains 'BatchNorm2d' get weights
  drawn from N(1, 0.02) and a zero bias. Other sub-modules are left untouched.
  """
  def _init_module(module):
    name = type(module).__name__
    is_conv_or_linear = 'Conv' in name or 'Linear' in name
    if hasattr(module, 'weight') and is_conv_or_linear:
      init.normal_(module.weight.data, 0.0, 0.02)
      if getattr(module, 'bias', None) is not None:
        init.constant_(module.bias.data, 0.0)
    elif 'BatchNorm2d' in name:
      init.normal_(module.weight.data, 1.0, 0.02)
      init.constant_(module.bias.data, 0.0)

  # Module.apply visits the model itself and every nested sub-module.
  model.apply(_init_module)

In [9]:
# The code in this cell is adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git.

# Method to post-process the output image of the generator
def post_processing_image(image):
  """Convert a single channel-first image tensor into a uint8 HxWxC array.

  Args:
    image: tensor laid out as (channels, height, width) whose values are
      expected in [-1, 1] (the range of the generator's tanh output).

  Returns:
    A numpy uint8 array laid out as (height, width, channels) with values in
    [0, 255].
  """
  # detach() lets this accept tensors that still carry autograd history;
  # calling .numpy() on such a tensor would otherwise raise.
  image = image.detach().cpu().float().numpy()
  image_np = (np.transpose(image, (1, 2, 0)) + 1) / 2.0 * 255.0
  # Clip before the integer cast so values slightly outside [-1, 1] saturate
  # instead of wrapping around.
  return np.clip(image_np, 0.0, 255.0).astype(np.uint8)

# Data Preparation

In [10]:
class CustomDataset(Dataset):
    """Dataset that serves (A, B) image pairs from two separate directories.

    Args:
        A_dir: directory holding the domain-A image files.
        B_dir: directory holding the domain-B image files.
        transform: callable applied to each image after it is opened with PIL
            and converted to RGB.
        random_pairs: when True, the B image is drawn uniformly at random for
            every item; when False, B is indexed with the same index as A
            (wrapped by the number of B images).

    Raises:
        ValueError: if either directory yields no files.
    """
    def __init__(self, A_dir, B_dir, transform, random_pairs=True):
        self.A_dir = A_dir
        self.B_dir = B_dir
        self.transform = transform
        self.random_pairs = random_pairs

        # os.listdir returns entries in arbitrary order; sort first so that the
        # MAX_NUM_SAMPLES truncation selects the same files on every run.
        self.A_img_filenames = sorted(os.listdir(self.A_dir))[:MAX_NUM_SAMPLES]
        self.B_img_filenames = sorted(os.listdir(self.B_dir))[:MAX_NUM_SAMPLES]
        self.A_size = len(self.A_img_filenames)
        self.B_size = len(self.B_img_filenames)

        # Fail early with a clear message instead of a modulo-by-zero or an
        # empty-range error the first time an item is requested.
        if self.A_size == 0 or self.B_size == 0:
            raise ValueError(
                "CustomDataset needs at least one file per domain "
                "(found {} in {!r} and {} in {!r})".format(
                    self.A_size, self.A_dir, self.B_size, self.B_dir))

    def __len__(self):
        # One epoch covers the larger domain; the smaller one wraps around.
        return max(self.A_size, self.B_size)

    def _load(self, path):
        """Open `path`, convert it to RGB, and apply the transform."""
        # The context manager closes the underlying file handle promptly.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, index):
        A_img_path = os.path.join(self.A_dir,
                                  self.A_img_filenames[index % self.A_size])
        if self.random_pairs:
            B_index = random.randint(0, self.B_size - 1)
        else:
            B_index = index % self.B_size
        B_img_path = os.path.join(self.B_dir, self.B_img_filenames[B_index])

        return self._load(A_img_path), self._load(B_img_path)

In [11]:
# Image transformations applied to every loaded PIL image.

# Training pipeline: resize, random horizontal flip, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
train_transform_steps = [
    transforms.Resize(256),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
]
train_data_transform = transforms.Compose(train_transform_steps)

# Test pipeline: resize and convert to a tensor only.
# NOTE(review): unlike the training pipeline this one does not normalize, so
# test tensors are in a different value range than training tensors - confirm
# that consumers of the test loader account for this.
test_transform_steps = [
    transforms.Resize(256),
    transforms.ToTensor(),
]
test_data_transform = transforms.Compose(test_transform_steps)

In [12]:
# Setup datasets and dataloaders
# Training pairs A with a random B image each time; testing pairs by index.
train_dataset = CustomDataset(TRAIN_A_PATH, TRAIN_B_PATH, train_data_transform,
                              random_pairs=True)
test_dataset = CustomDataset(TEST_A_PATH, TEST_B_PATH, test_data_transform,
                             random_pairs=False)
print("Number of training samples:", len(train_dataset))
# Report the size of the test set (not the training set).
print("Number of testing samples:", len(test_dataset))

train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=BATCH_SIZE,
                                                shuffle=True,
                                                num_workers=NUM_DATALOADER_WORKERS)
# shuffle=True is kept here so that taking the first batch in the evaluation
# section shows a different test sample on each call.
test_data_loader = torch.utils.data.DataLoader(test_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               num_workers=NUM_DATALOADER_WORKERS)

Number of training samples: 1
Number of testing samples: 1


# Model Architecture

## Generator
TODO Modify the code

In [13]:
class ConvBlock(nn.Module):
    """Convolution -> InstanceNorm2d -> optional ReLU.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: kernel size forwarded to the convolution.
        padding: padding forwarded to the convolution.
        down: when True use `nn.Conv2d` with reflect padding; when False use
            `nn.ConvTranspose2d`.
        use_act: when True finish with an in-place ReLU, otherwise with
            `nn.Identity`.
        **kwargs: extra keyword arguments forwarded to the convolution
            (e.g. `stride`, `output_padding`).
    """
    def __init__(self, in_channel, out_channel, kernel_size, padding, down=True, use_act=True, **kwargs):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, padding=padding, **kwargs)
        self.conv1 = (
            nn.Conv2d(in_channel, out_channel, padding_mode="reflect", **conv_options)
            if down else
            nn.ConvTranspose2d(in_channel, out_channel, **conv_options)
        )
        self.IN = nn.InstanceNorm2d(out_channel)
        self.acv = nn.ReLU(inplace=True) if use_act else nn.Identity()

    def forward(self, x):
        """Apply convolution, instance normalization, then the activation."""
        return self.acv(self.IN(self.conv1(x)))

In [14]:
class ResBlock(nn.Module):
  """Residual block: two 3x3 ConvBlocks whose result is added to the input.

  The first ConvBlock ends with its activation, the second one does not
  (`use_act=False`). Both keep the channel count unchanged.
  """
  def __init__(self, channels):
    super().__init__()
    conv_options = dict(kernel_size=3, padding=1)
    self.Conv1 = ConvBlock(channels, channels, **conv_options)
    self.Conv2 = ConvBlock(channels, channels, use_act=False, **conv_options)

  def forward(self, x):
    """Return `x` plus the output of the two stacked ConvBlocks."""
    residual = self.Conv2(self.Conv1(x))
    return x + residual

In [15]:

class Generator(nn.Module):
  """ResNet-style image-to-image generator.

  Pipeline: 7x7 reflect-padded conv + ReLU -> two stride-2 3x3 convs, each
  followed by InstanceNorm + ReLU -> `num_residual` ResBlocks -> two stride-2
  transposed-conv ConvBlocks -> 7x7 reflect-padded conv -> tanh.

  Args:
    img_channels: number of channels of the input and output images.
    num_features: channel width after the first convolution; the two
      downsampling stages widen it to 2x and 4x.
    num_residual: number of residual blocks at the bottleneck.

  NOTE(review): the stem (`conv1` + `acv1`) has no normalization layer, unlike
  every other stage - confirm whether that is intentional.
  """
  def __init__(self, img_channels, num_features=64, num_residual=9):
    super().__init__()
    self.conv1 = nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
    self.acv1 = nn.ReLU(inplace=True)

    self.DownBlock = nn.ModuleList([
        nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
        nn.Conv2d(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
    ])
    # One normalization layer per downsampling conv. With its default arguments
    # InstanceNorm2d holds no parameters or buffers, so the state_dict keys
    # (and therefore existing checkpoints) are unaffected by these modules.
    self.down_norms = nn.ModuleList([
        nn.InstanceNorm2d(num_features*2),
        nn.InstanceNorm2d(num_features*4),
    ])

    self.residual_blocks = nn.Sequential(
        *[ResBlock(num_features*4) for _ in range(num_residual)]
    )

    self.up_blocks = nn.ModuleList([
        ConvBlock(num_features*4, num_features*2, kernel_size=3, stride=2, padding=1, output_padding=1, down=False),
        ConvBlock(num_features*2, num_features*1, kernel_size=3, stride=2, padding=1, output_padding=1, down=False),
    ])

    self.last_layer = nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    """Translate a batch of images; the output is squashed to [-1, 1] by tanh."""
    x = self.acv1(self.conv1(x))

    # Downsample. Without a normalization + non-linearity after each conv the
    # two convolutions would collapse into a single linear map.
    for conv, norm in zip(self.DownBlock, self.down_norms):
      x = torch.relu(norm(conv(x)))

    # Bottleneck: the residual blocks were constructed but previously never
    # applied, so the generator skipped most of its capacity.
    x = self.residual_blocks(x)

    # Upsample back to the input resolution.
    for layer in self.up_blocks:
      x = layer(x)
    return torch.tanh(self.last_layer(x))
    


## Discriminator
TODO Modify the code

In [16]:
class DiscBlock(nn.Module):
  """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm2d -> LeakyReLU(0.2).

  Args:
    in_channel: number of input channels.
    out_channel: number of output channels.
    stride: stride of the convolution.
  """
  def __init__(self, in_channel, out_channel, stride):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channel, out_channel,
                           kernel_size=4, stride=stride, padding=1,
                           bias=True, padding_mode="reflect")
    self.IN1 = nn.InstanceNorm2d(out_channel)
    self.act1 = nn.LeakyReLU(0.2)

  def forward(self, x):
    """Apply convolution, instance normalization, then the leaky ReLU."""
    return self.act1(self.IN1(self.conv1(x)))

In [17]:
class Discriminator(nn.Module):
  """Fully convolutional discriminator producing a one-channel score map.

  Layout: a 4x4 stride-2 conv + LeakyReLU(0.2) stem, then one DiscBlock per
  remaining entry of `num_features` (stride 2, except the last one which uses
  stride 1), then a 4x4 stride-1 conv down to a single channel. The forward
  pass squashes the result with a sigmoid.

  Args:
    in_channel: number of channels of the input images.
    num_features: channel widths of the successive stages.
  """
  def __init__(self, in_channel, num_features=[64, 128, 256, 512]):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channel, num_features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect")
    self.act1 = nn.LeakyReLU(0.2)

    last_stage = len(num_features) - 1
    stages = [
        DiscBlock(num_features[stage - 1], num_features[stage],
                  stride=1 if stage == last_stage else 2)
        for stage in range(1, len(num_features))
    ]
    # Final projection from the widest stage to one score channel.
    stages.append(nn.Conv2d(num_features[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
    self.model = nn.Sequential(*stages)

  def forward(self, x):
    """Return the sigmoid-activated score map for a batch of images."""
    features = self.act1(self.conv1(x))
    return torch.sigmoid(self.model(features))

   

## CycleGAN

In [18]:
# Method to train CycleGAN for one epoch
# Combine generators and discriminators together
# Calculate losses and backpropagation
def train_CycleGAN_for_one_epoch(G_A, D_B, G_B, D_A,
                                 cri_GAN, cri_Cycle, cri_Idt,
                                 opt_G, opt_D,
                                 data_loader,
                                 use_image_pools=False, image_pools=None):
  """Run one training epoch over `data_loader` and update all four networks.

  Args:
    G_A: generator mapping domain A to domain B.
    D_B: discriminator judging domain-B images (real B vs. G_A output).
    G_B: generator mapping domain B to domain A.
    D_A: discriminator judging domain-A images (real A vs. G_B output).
    cri_GAN: adversarial criterion, called as cri_GAN(prediction, target).
    cri_Cycle: cycle-consistency criterion.
    cri_Idt: identity criterion.
    opt_G: optimizer over both generators.
    opt_D: optimizer over both discriminators.
    data_loader: iterable yielding (A_real, B_real) batches; must support len().
    use_image_pools: when True, discriminators see fakes drawn through
      `query_image_pool` instead of only the current fakes.
    image_pools: dict with list-valued keys 'A' and 'B'; required when
      `use_image_pools` is True.

  Returns:
    (model_name, (avg_loss_G, avg_loss_D_B, avg_loss_D_A)) where the averages
    are taken over the number of batches in `data_loader`.

  Raises:
    ValueError: if `data_loader` contains no batches.
  """
  num_batches = len(data_loader)
  if num_batches == 0:
    # Previously this surfaced as a ZeroDivisionError when averaging.
    raise ValueError("data_loader is empty; nothing to train on")

  # The name depends only on the arguments, so build it once up front.
  model_name = "CycleGAN_" + os.path.basename(DATASET_DIR[:-1])
  model_name += "_ImagePool" if use_image_pools else "_NoImagePool"

  def gan_loss(prediction, target_is_real):
    """Adversarial loss against an all-ones (real) or all-zeros (fake) target.

    The target is created with the prediction's own shape, dtype, and device,
    so it stays valid for any discriminator output size and for a final batch
    that is smaller than BATCH_SIZE.
    """
    make_target = torch.ones_like if target_is_real else torch.zeros_like
    return cri_GAN(prediction, make_target(prediction))

  avg_loss_G = 0
  avg_loss_D_B = 0 # Discriminator for A->B
  avg_loss_D_A = 0 # Discriminator for B->A

  G_A.train()
  D_B.train()
  G_B.train()
  D_A.train()

  for A_real, B_real in tqdm(data_loader):
    A_real = A_real.to(DEVICE)
    B_real = B_real.to(DEVICE)

    # forward pass; generate fake images
    B_fake = G_A(A_real) # G_A(A)
    A_fake = G_B(B_real) # G_B(B)

    # Fakes shown to the discriminators are detached so that the discriminator
    # update does not backpropagate into the generators.
    if use_image_pools:
      B_fake_for_D = query_image_pool(image_pools['B'], IMAGE_POOL_SIZE,
                                      B_fake.detach())
      A_fake_for_D = query_image_pool(image_pools['A'], IMAGE_POOL_SIZE,
                                      A_fake.detach())
    else:
      B_fake_for_D = B_fake.detach()
      A_fake_for_D = A_fake.detach()

    ### Update discriminators
    opt_D.zero_grad()
    ## Update D_B (discriminator for A->B)
    loss_D_B_real = gan_loss(D_B(B_real), True)
    loss_D_B_fake = gan_loss(D_B(B_fake_for_D), False)
    loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
    loss_D_B.backward()
    ## Update D_A (discriminator for B->A)
    loss_D_A_real = gan_loss(D_A(A_real), True)
    loss_D_A_fake = gan_loss(D_A(A_fake_for_D), False)
    loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
    loss_D_A.backward()
    # One step for both discriminators (they share opt_D).
    opt_D.step()

    ### Update generators
    opt_G.zero_grad()
    # GAN losses: generators want their fakes to be scored as real.
    loss_GAN_G_A = gan_loss(D_B(B_fake), True)
    loss_GAN_G_B = gan_loss(D_A(A_fake), True)
    # Cycle losses
    loss_Cycle_G_A = cri_Cycle(G_B(B_fake), A_real) # loss(G_B(G_A(A)), A)
    loss_Cycle_G_B = cri_Cycle(G_A(A_fake), B_real) # loss(G_A(G_B(B)), B)
    # Identity losses
    loss_Idt_G_A = cri_Idt(G_A(B_real), B_real) # loss(G_A(B), B)
    loss_Idt_G_B = cri_Idt(G_B(A_real), A_real) # loss(G_B(A), A)
    # Combine all losses; the identity weight is relative to LAMBDA_CYCLE.
    loss_G = (loss_GAN_G_A + loss_GAN_G_B
              + LAMBDA_CYCLE * (loss_Cycle_G_A + loss_Cycle_G_B
                                + LAMBDA_IDT * (loss_Idt_G_A + loss_Idt_G_B)))
    # This backward pass also deposits gradients in the discriminators; they
    # are cleared by opt_D.zero_grad() at the start of the next iteration.
    loss_G.backward()
    opt_G.step()

    avg_loss_G += loss_G.item()
    avg_loss_D_B += loss_D_B.item()
    avg_loss_D_A += loss_D_A.item()

  avg_loss_G = avg_loss_G / num_batches
  avg_loss_D_B = avg_loss_D_B / num_batches
  avg_loss_D_A = avg_loss_D_A / num_batches

  return model_name, (avg_loss_G, avg_loss_D_B, avg_loss_D_A)

# Training

In [22]:
# Training setup

# Two generator/discriminator pairs, one per translation direction.
G_A = Generator(img_channels=3, num_residual=9).to(DEVICE)  # generator A->B
D_B = Discriminator(in_channel=3).to(DEVICE)                # discriminator for A->B
G_B = Generator(img_channels=3, num_residual=9).to(DEVICE)  # generator B->A
D_A = Discriminator(in_channel=3).to(DEVICE)                # discriminator for B->A

# Re-initialize the weights of every network.
for network in (G_A, D_B, G_B, D_A):
  init_weights(network)

# Losses
cri_GAN = nn.MSELoss() # Least-Squares GAN; see Reference 35 of the paper
cri_Cycle = nn.L1Loss()
cri_Idt = nn.L1Loss()

# Optimizers: one Adam instance for both generators, one for both discriminators.
adam_options = dict(lr=LR, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()),
                         **adam_options)
opt_D = torch.optim.Adam(itertools.chain(D_B.parameters(), D_A.parameters()),
                         **adam_options)

# Image pools: history buffers of generated images, one per domain.
image_pools = {'A': [], 'B': []}

# Schedulers for learning rate decay
def lambda_rule(epoch): # adapted from authors' Github repo
  """Return the LR multiplier for `epoch`: 1.0 until the decay phase starts,
  then decreasing linearly with each further epoch."""
  epochs_into_decay = max(0, epoch - (NUM_EPOCHS - NUM_EPOCHS_DECAY))
  return 1.0 - epochs_into_decay / float(NUM_EPOCHS_DECAY + 1)
sclr_G = lr_scheduler.LambdaLR(opt_G, lr_lambda=lambda_rule)
sclr_D = lr_scheduler.LambdaLR(opt_D, lr_lambda=lambda_rule)

In [23]:
# Train for all epochs, recording the per-epoch average losses.
losses_G = []
losses_D_B = []
losses_D_A = []
separator = "----------------------------------------------------------------------------------"
for epoch in range(1, NUM_EPOCHS+1):
  print("Epoch: {} / {}".format(epoch, NUM_EPOCHS))

  # One pass over the training data, with image pools enabled.
  model_name, stats = train_CycleGAN_for_one_epoch(
      G_A, D_B, G_B, D_A,
      cri_GAN, cri_Cycle, cri_Idt,
      opt_G, opt_D,
      train_data_loader,
      use_image_pools=True,
      image_pools=image_pools)

  # Advance both learning-rate schedules once per epoch.
  for scheduler in (sclr_G, sclr_D):
    scheduler.step()

  # stats is (avg_loss_G, avg_loss_D_B, avg_loss_D_A), in that order.
  for history, epoch_loss in zip((losses_G, losses_D_B, losses_D_A), stats):
    history.append(epoch_loss)

  if epoch == NUM_EPOCHS:
    print()
  print(separator)
  print("  Loss G: %f, Loss D_B: %f, Loss D_A: %f, Next LR: %f" % (
      stats[0], stats[1], stats[2], sclr_G.get_last_lr()[0]
  ))
  print(separator)

Epoch: 1 / 1


100%|██████████| 1/1 [00:00<00:00,  3.90it/s]


----------------------------------------------------------------------------------
  Loss G: 19.023766, Loss D_B: 0.309831, Loss D_A: 0.311158, Next LR: 0.000100
----------------------------------------------------------------------------------





In [31]:
# Save the weights of all four networks as CPU tensors.
# nn.Module.cpu() moves a model in place, so calling it here would leave the
# live models on the CPU and break later cells that send inputs to DEVICE.
# Copy the state-dict tensors to the CPU instead and leave the models alone.
if SAVE_MODELS:
  os.makedirs(MODEL_DIR, exist_ok=True)
  for suffix, network in (("_GA", G_A), ("_DB", D_B), ("_GB", G_B), ("_DA", D_A)):
    cpu_state = {key: value.detach().cpu()
                 for key, value in network.state_dict().items()}
    torch.save(cpu_state, MODEL_DIR + model_name + suffix + ".pt")

In [32]:
# List the checkpoint folder in long format to confirm the files were written.
!ls "/content/drive/MyDrive/ECE-GY 6953 DL/Project/SavedModels/" -l

total 221082
-rw------- 1 root root 11062975 May 13 07:48 CycleGAN_horse2zebra_ImagePool_DA.pt
-rw------- 1 root root 11062975 May 13 07:48 CycleGAN_horse2zebra_ImagePool_DB.pt
-rw------- 1 root root 45533417 May 13 07:48 CycleGAN_horse2zebra_ImagePool_GA.pt
-rw------- 1 root root 45533417 May 13 07:48 CycleGAN_horse2zebra_ImagePool_GB.pt
-rw------- 1 root root 11062999 May 12 01:48 CycleGAN_horse2zebra_NoImagePool_DA.pt
-rw------- 1 root root 11062999 May 12 01:48 CycleGAN_horse2zebra_NoImagePool_DB.pt
-rw------- 1 root root 45533517 May 12 01:48 CycleGAN_horse2zebra_NoImagePool_GA.pt
-rw------- 1 root root 45533517 May 12 01:48 CycleGAN_horse2zebra_NoImagePool_GB.pt


# Evaluation

## Loss Curves

In [None]:
# Plot the per-epoch average losses of the generators and both discriminators.
loss_curves = {
    "Loss G": losses_G,
    "Loss D_B": losses_D_B,
    "Loss D_A": losses_D_A,
}
plt.figure()
for curve_label, curve_values in loss_curves.items():
  plt.plot(curve_values, label=curve_label)
plt.xlabel("Epochs")
plt.ylabel("Losses")
plt.legend()

# Optionally write the figure next to the other visualizations.
if SAVE_LOSS_PLOTS:
  plt.savefig(VISUAL_DIR + model_name + "_losses.png")

## Test Sample Visualization

In [None]:
# Take one batch from the test loader and keep the first domain-A image of it.
test_batch = next(iter(test_data_loader))
test_batch_A = test_batch[0]
tmp_real = test_batch_A[0] # A
print(tmp_real.dtype)
# imshow expects channels last, so move the channel axis to the end.
tmp_real_channels_last = tmp_real.permute(1,2,0)
plt.imshow(tmp_real_channels_last)
plt.show()

In [None]:
# Translate the sampled domain-A image with the A->B generator.
G_A.eval()
# Run on whichever device the generator currently lives on, rather than
# assuming DEVICE: the two can differ once the model has been moved.
generator_device = next(G_A.parameters()).device
# test_data_transform stops at ToTensor(), so tmp_real is in [0, 1], whereas the
# generator is trained on inputs normalized with mean 0.5 / std 0.5. Apply the
# same normalization here (drop it if the test transform starts normalizing).
generator_input = (tmp_real.float() * 2.0 - 1.0).unsqueeze(0).to(generator_device)
# Inference only: no autograd graph is needed.
with torch.no_grad():
  # Add and then strip a batch dimension so tmp_fake stays (C, H, W) on the CPU.
  tmp_fake = G_A(generator_input)[0].cpu()

In [None]:
# Convert the generator output to a uint8 HxWxC image with the shared helper
# (instead of repeating its arithmetic inline) and display it.
tmp_fake_post = post_processing_image(tmp_fake)
plt.imshow(tmp_fake_post)
plt.show()