<a href="https://colab.research.google.com/github/zhisheng-hua/Reproducing-and-Extending-the-CycleGAN-Model/blob/cy/CycleGAN_Draft1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
# Import necessary packages
import os 
import itertools

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from PIL import Image
# from tqdm import tqdm
# import albumentations as A
# from albumentations.pytorch import ToTensor
from torch.types import Number
import sys
# from torch.utils.data import DataLoader
# import torch.optim as optim
# from tqdm import tqdm
from torchvision.utils import save_image
# import albumentations as A

from torchvision import transforms

import time

import random

# Hyperparameters

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

Device: cuda


In [20]:
LR = 1e-5
LAMBDA_IDT = 10
LAMBDA_CYCLE = 5

BATCH_SIZE = 64
NUM_EPOCHS = 3

NUM_WORKERS = 4

LOAD_MODEL=False
SAVE_MODEL=False

# Functions

In [4]:

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)



# Dataset

In [4]:
class CustomDataSet(Dataset):
    def __init__(self, A_dir, B_dir, transform, random_pairs=False):
        self.A_dir = A_dir
        self.B_dir = B_dir
        self.transform = transform
        self.random_pairs = random_pairs

        self.A_img_filenames = os.listdir(self.A_dir)
        self.B_img_filenames = os.listdir(self.B_dir)
        self.A_size = len(self.A_img_filenames)
        self.B_size = len(self.B_img_filenames)

    def __len__(self):
        return max(self.A_size, self.B_size)

    def __getitem__(self, index):
        A_img_path = os.path.join(self.A_dir,
                                  self.A_img_filenames[index % self.A_size])
        if self.random_pairs:
          B_index = random.randint(0, self.B_size - 1)
        else:
          B_index = index % self.B_size
        B_img_path = os.path.join(self.B_dir, self.B_img_filenames[B_index])

        A_img = self.transform(Image.open(A_img_path))
        B_img = self.transform(Image.open(B_img_path))

        return A_img, B_img

In [6]:
# !mkdir ./Segment ./Ground

In [5]:
path1 = '/content/drive/MyDrive/ECE-GY 6953 DL/Project/Datasets/mini/trainA'
path2 = '/content/drive/MyDrive/ECE-GY 6953 DL/Project/Datasets/mini/trainB'

In [6]:
# Define image transformation

train_data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

test_data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor()
])

In [7]:
# Prepare data for input
train_dataset = CustomDataSet(path1, path2, train_data_transform)
# TODO: test_dataset

train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=BATCH_SIZE,
                                                shuffle=True,
                                                num_workers=NUM_WORKERS)
# TODO: test_data_loader



# Generator
TODO Modify the code

In [8]:
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
          self.conv1 = nn.Conv2d(in_channel, out_channel, padding_mode="reflect", kernel_size=kernel_size, padding=padding, **kwargs)
        else:
          self.conv1 = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, padding=padding, **kwargs)
        self.IN = nn.InstanceNorm2d(out_channel)
        if use_act:
          self.acv = nn.ReLU(inplace=True)
        else:
          self.acv = nn.Identity()
    def forward(self, x):
      x = self.conv1(x)
      x = self.IN(x)
      return  self.acv(x)

In [9]:
class ResBlock(nn.Module):
  def __init__(self, channels):
    super().__init__()
    self.Conv1 = ConvBlock(channels, channels, kernel_size=3, padding=1)
    self.Conv2 = ConvBlock(channels, channels, kernel_size=3, padding=1, use_act=False)
  
  def forward(self, x):
    x2 = self.Conv1(x)
    x2 = self.Conv2(x2)
    return x + x2

In [10]:

class Generator(nn.Module):
  def __init__(self, img_channels, num_features=64, num_residual=9):
    super().__init__()
    self.conv1 = nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
    self.acv1 = nn.ReLU(inplace=True)

    self.DownBlock = nn.ModuleList([
        nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
        nn.Conv2d(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
    ])

    self.residual_blocks = nn.Sequential(
        *[ResBlock(num_features*4) for _ in range(num_residual)]
    )

    self.up_blocks = nn.ModuleList([
        ConvBlock(num_features*4, num_features*2, kernel_size=3, stride=2, padding=1, output_padding=1, down=False),
        ConvBlock(num_features*2, num_features*1, kernel_size=3, stride=2, padding=1, output_padding=1, down=False),
    ])

    self.last_layer = nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    x = self.conv1(x)
    x = self.acv1(x)
    for layer in self.DownBlock:
      x = layer(x)

    for layer in self.up_blocks:
      x = layer(x)
    return torch.tanh(self.last_layer(x))
    


# Discriminator
TODO Modify the code

In [11]:
class DiscBlock(nn.Module):
  def __init__(self, in_channel, out_channel, stride):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channel, out_channel, 4, stride, 1, bias=True, padding_mode="reflect")
    self.IN1 = nn.InstanceNorm2d(out_channel)
    self.act1 = nn.LeakyReLU(0.2)
  def forward(self, x):
    x = self.conv1(x)
    x = self.IN1(x)
    x = self.act1(x)
    return x

In [12]:
class Discriminator(nn.Module):
  def __init__(self, in_channel, num_features=[64, 128, 256, 512]):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channel, num_features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect")
    self.act1 = nn.LeakyReLU(0.2)

    layers = []
    in_channel = num_features[0]
    
    for counter in range(1,len(num_features)):

      layers.append(DiscBlock(in_channel, num_features[counter], stride=1 if counter==(len(num_features)-1) else 2))
      in_channel = num_features[counter]
    layers.append(nn.Conv2d(in_channel, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
    self.model = nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)
    x = self.act1(x)
    x = self.model(x)
    return torch.sigmoid(x)

   

# Training

In [13]:
# Method to train for one epoch
def train_CycleGAN_for_one_epoch(G_A, D_A, G_B, D_B,
                                 cri_GAN, cri_Cycle, cri_Idt,
                                 opt_G, opt_D,
                                 data_loader):
  G_A.train()
  D_A.train()
  G_B.train()
  D_B.train()

  # TODO: modify shape
  D_labels_real = torch.ones([BATCH_SIZE,1,30,30]).to(DEVICE)
  D_labels_fake = torch.zeros([BATCH_SIZE,1,30,30]).to(DEVICE)

  for i, (A_real, B_real) in enumerate(data_loader):
    A_real = A_real.to(DEVICE)
    B_real = B_real.to(DEVICE)

    if A_real.shape[0] < BATCH_SIZE:
      D_labels_real = D_labels_real[:A_real.shape[0]]
      D_labels_fake = D_labels_fake[:A_real.shape[0]]

    # forward pass; generate fake images
    B_fake = G_A(A_real) # G_A(A)
    A_fake = G_B(B_real) # G_B(B)

    ## Update generators
    opt_G.zero_grad()
    # GAN losses
    loss_GAN_G_A = cri_GAN(D_A(B_fake), D_labels_real)
    loss_GAN_G_B = cri_GAN(D_B(A_fake), D_labels_real)
    # Cycle losses
    loss_Cycle_G_A = cri_Cycle(G_B(B_fake), A_real) # loss(G_B(G_A(A)), A)
    loss_Cycle_G_B = cri_Cycle(G_A(A_fake), B_real) # loss(G_A(G_B(B)), B)
    # Identity losses
    loss_Idt_G_A = cri_Idt(G_A(B_real), B_real) # loss(G_A(B), B)
    loss_Idt_G_B = cri_Idt(G_B(A_real), A_real) # loss(G_B(A), A)
    # Combine all losses
    loss_G = (loss_GAN_G_A + loss_GAN_G_B
              + LAMBDA_CYCLE * (loss_Cycle_G_A + loss_Cycle_G_B)
              + LAMBDA_IDT * (loss_Idt_G_A + loss_Idt_G_B))
    # Update
    loss_G.backward()
    opt_G.step()

    ## Update discriminators
    opt_D.zero_grad()
    # Update D_A (discriminator for A->B)
    # Real images
    loss_D_A_real = cri_GAN(D_A(B_real), D_labels_real)
    # Fake images
    loss_D_A_fake = cri_GAN(D_A(B_fake.detach()), D_labels_fake)
    # Combine losses and compute gradients
    loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
    loss_D_A.backward()
    # Update D_B (discriminator for B->A)
    # Real images
    loss_D_B_real = cri_GAN(D_B(A_real), D_labels_real)
    # Fake images
    loss_D_B_fake = cri_GAN(D_B(A_fake.detach()), D_labels_fake)
    # Combine losses and compute gradients
    loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
    loss_D_B.backward()
    # Update the optimizer
    opt_D.step()

In [21]:
# Training setup

# Generator and Discriminator models
G_A = Generator(img_channels=3, num_residual=9).to(DEVICE)
D_A = Discriminator(in_channel=3).to(DEVICE)
G_B = Generator(img_channels=3, num_residual=9).to(DEVICE)
D_B = Discriminator(in_channel=3).to(DEVICE)

# Losses
cri_GAN = nn.MSELoss() # Least-Squares GAN; see Reference 35 of the paper
cri_Cycle = nn.L1Loss()
cri_Idt = nn.L1Loss()

# Optimizers
opt_G = torch.optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()),
                         lr=LR, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(itertools.chain(D_A.parameters(), D_B.parameters()),
                         lr=LR, betas=(0.5, 0.999))

In [22]:
# Train for all epoches
for epoch in range(NUM_EPOCHS):
  print("Epoch: ", epoch)
  s_time = time.time_ns()

  # Training process
  train_CycleGAN_for_one_epoch(G_A, D_A, G_B, D_B,
                               cri_GAN, cri_Cycle, cri_Idt,
                               opt_G, opt_D,
                               train_data_loader)
  
  e_time = time.time_ns()
  print("  Time elapsed: {} min".format((e_time - s_time) / 1e9 / 60))

Epoch:  0
  Time elapsed: 0.01541389385 min
Epoch:  1
  Time elapsed: 0.0305093105 min
Epoch:  2
  Time elapsed: 0.03222081765 min


In [19]:
# def TrainCicleGAN(Disc_Ground, Gen_Seg, Disc_Seg, Gen_Ground, loader, optim_Disc, optim_Gen, L1, MSE, Disc_scaler, Gen_scaler):
#   # loop_prop = tqdm(loader, leave=True)
#   index2 = 0
#   # for index, (segm, grd) in enumerate(loop_prop):
#   for index, (segm, grd) in enumerate(loader):
#     grd = grd.to(DEVICE)
#     segm = segm.to(DEVICE)

#     H_reals = 0
#     H_fakes = 0
#     with torch.cuda.amp.autocast():
#         print("test")
#         fake_ground = Gen_Ground(segm)
#         D_H_real = Disc_Ground(grd)
#         D_H_fake = Disc_Ground(fake_ground.detach())
#         H_reals += D_H_real.mean().item()
#         H_fakes += D_H_fake.mean().item()
#         D_H_real_loss = MSE(D_H_real, torch.ones_like(D_H_real))
#         D_H_fake_loss = MSE(D_H_fake, torch.zeros_like(D_H_fake))
#         D_H_loss = D_H_real_loss + D_H_fake_loss

#         print(torch.ones_like(D_H_real).shape)
#         print(torch.ones_like(D_H_real))
       
#         fake_seg = Gen_Seg(grd)
#         D_Z_real = Disc_Seg(segm)
#         D_Z_fake = Disc_Seg(fake_seg.detach())
#         D_Z_real_loss = MSE(D_Z_real, torch.ones_like(D_Z_real))
#         D_Z_fake_loss = MSE(D_Z_fake, torch.zeros_like(D_Z_fake))
#         D_Z_loss = D_Z_real_loss + D_Z_fake_loss

#         # put it togethor
#         D_loss = (D_H_loss + D_Z_loss)/2

#     optim_Disc.zero_grad()
#     Disc_scaler.scale(D_loss).backward()
#     Disc_scaler.step(optim_Disc)
#     Disc_scaler.update()

#     # Train Generators H and Z
#     with torch.cuda.amp.autocast():
#         # adversarial loss for both generators
#         D_H_fake = Disc_Ground(fake_ground)
#         D_Z_fake = Disc_Seg(fake_seg)
#         loss_G_H = MSE(D_H_fake, torch.ones_like(D_H_fake))
#         loss_G_Z = MSE(D_Z_fake, torch.ones_like(D_Z_fake))

#         # cycle loss
#         cycle_segm = Gen_Seg(fake_ground)
#         cycle_ground = Gen_Ground(fake_seg)
#         cycle_segm_loss = L1(segm, cycle_segm)
#         cycle_ground_loss = L1(grd, cycle_ground)

#         # identity loss (remove these for efficiency if you set lambda_identity=0)
#         identity_segm = Gen_Seg(segm)
#         identity_ground = Gen_Ground(grd)
#         identity_segm_loss = L1(segm, identity_segm)
#         identity_ground_loss = L1(grd, identity_ground)

#         # add all togethor
#         G_loss = (
#             loss_G_Z
#             + loss_G_H
#             + cycle_segm_loss * LAMBDA_CYCLE
#             + cycle_ground_loss * LAMBDA_CYCLE
#             # + identity_ground_loss * LAMBDA_IDENTITY
#             # + identity_segm_loss * LAMBDA_IDENTITY
#         )

#     optim_Gen.zero_grad()
#     Gen_scaler.scale(G_loss).backward()
#     Gen_scaler.step(optim_Gen)
#     Gen_scaler.update()

#     # if index2 % 200 == 0:
#         # save_image(fake_ground, SAVE_IMAGE + f"/Ground/{index2}.png")
#         # save_image(fake_seg, SAVE_IMAGE + f"/Segment/{index2}.png")
#     save_image(fake_ground, f"./Ground/{index2}.png")
#     save_image(fake_seg, f"./Segment/{index2}.png")
#     index2 += 1
#     loop_prop.set_postfix(H_real=H_reals/(index+1), H_fake=H_fakes/(index+1))

#   return fake_ground, fake_seg



# Main Function

In [None]:
# from torchvision import datasets

In [None]:



Disc_Ground = Discriminator(in_channel=3).to(DEVICE)
Disc_Seg = Discriminator(in_channel=3).to(DEVICE)
Gen_Ground = Generator(img_channels=3, num_residual=9).to(DEVICE)
Gen_Seg = Generator(img_channels=3, num_residual=9).to(DEVICE)





optim_Disc = torch.optim.Adam(list(Disc_Ground.parameters()) + list(Disc_Seg.parameters()), lr=LR, betas=(0.5, 0.999),)
optim_Gen = torch.optim.Adam(list(Disc_Ground.parameters()) + list(Gen_Seg.parameters()), lr=LR, betas=(0.5, 0.999),)

L1 = nn.L1Loss()
MSE = nn.MSELoss()


# if LOAD_MODEL:

#   Disc_Ground.load_state_dict(torch.load(WEIGHTS_DISC_G))
#   Disc_Seg.load_state_dict(torch.load(WEIGHTS_DISC_S))
#   Gen_Ground.load_state_dict(torch.load(WEIGHTS_GEN_G))
#   Gen_Seg.load_state_dict(torch.load(WEIGHTS_GEN_S))
#     # load_checkpoint(
#     #     WEIGHTS_GEN_G, Gen_Ground, optim_Gen, LEARNING_RATE,
#     # )
#     # load_checkpoint(
#     #     WEIGHTS_GEN_S, Gen_Seg, optim_Gen, LEARNING_RATE,
#     # )
#     # load_checkpoint(
#     #     WEIGHTS_DISC_G, Disc_Ground, optim_Disc, LEARNING_RATE,
#     # )
#     # load_checkpoint(
#     #     WEIGHTS_DISC_S, Disc_Seg, optim_Disc, LEARNING_RATE,
#     # )
# dataset = CityscapesDataSet(SEGMENT_PATH, GROUND_PATH, transforms2)
# test_dataset = CityscapesDataSet(SEGMENT_PATH_TEST, GROUND_PATH_TEST, transforms2)
dataset = CustomDataSet(path1, path2, train_data_transform)
# test_dataset = CityscapesDataSet(SEGMENT_PATH_TEST, GROUND_PATH_TEST, transforms2)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

# test_loader = DataLoader(
#   test_dataset,
#   batch_size=1,
#   shuffle=False,
#   pin_memory=True,
# )

Gen_scaler = torch.cuda.amp.GradScaler()
Disc_scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    print("Epoch:", epoch)
    s_time = time.time_ns()
    x1, x2 = TrainCicleGAN(Disc_Ground, Gen_Seg, Disc_Seg, Gen_Ground, loader, optim_Disc, optim_Gen, L1, MSE, Disc_scaler, Gen_scaler)
    e_time = time.time_ns()
    print("Time elapsed: {} min".format((e_time - s_time) / 1e9 / 60))
    print()

    # if SAVE_MODEL:
    #   torch.save(Disc_Ground.state_dict(), WEIGHTS_DISC_G)
    #   torch.save(Disc_Seg.state_dict(), WEIGHTS_DISC_S)

    #   torch.save(Gen_Ground.state_dict(), WEIGHTS_GEN_G)
    #   torch.save(Gen_Seg.state_dict(), WEIGHTS_GEN_S)



In [None]:
tmpimg = Image.open('/content/drive/MyDrive/ECE-GY 6953 DL/Project/Datasets/mini/trainA/n02381460_1127.jpg')
# tmpimg = train_data_transform(tmpimg).permute(1,2,0)
plt.imshow(tmpimg)

In [None]:
path = '/content/Segment/0.png'
# path = '/content/drive/MyDrive/ECE-GY 6953 DL/Project/Datasets/mini/trainA/n02381460_1127.jpg'
img = Image.open(path)
print(np.array(img).shape)
plt.imshow(img)