In [None]:
import torch
import torch.nn as nn
import torchvision

import torch.optim as optim
from torch.utils.data import DataLoader,Dataset

import torchvision.transforms as transforms
from PIL import Image

from tqdm import tqdm
import os
import time
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## set random seed

In [None]:
# Seed every RNG the notebook touches (numpy, torch CPU, torch CUDA) with one value.
seed = 7
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Cycle GAN

<img src="./../img/cycle_gan_arch.jpeg">

In CycleGAN we have two generators and two discriminators:
* generator G_H_Z converts horses to zebras,
* generator G_Z_H converts zebras to horses,
* discriminators D_Z and D_H decide whether a zebra / horse image is real or generated.

A generator takes a horse/zebra image **A** of shape (N, 3, 256, 256) and outputs a zebra/horse image **B** of the same shape.

The discriminator's job is to identify whether a horse/zebra image is real or generated. It takes an image of shape (N, 3, 256, 256) as input and outputs a patch of shape (N, 1, 30, 30).
To train it, we compare this output with an all-ones (real) or all-zeros (fake) tensor of the same shape using an MSE loss.

The generator's loss is more complicated; it has three parts:
1. **Adversarial loss**: feed the generated image into the corresponding discriminator and compute the MSE between its output and an all-ones tensor. The generator wants its fakes to be classified as real.
2. **Identity loss**: G_Z_H is meant to produce horses, so when we feed it a horse image it should not change the image. We therefore want G_Z_H's output to be as close as possible to its horse input (and likewise for G_H_Z with zebra inputs), measured with an L1 loss.
3. **Cycle-consistency loss**: feed a horse image into G_H_Z, then pass the result through G_Z_H to reconstruct a horse. The reconstruction should be as close as possible to the original horse (and likewise zebra → horse → zebra), again measured with an L1 loss.

The identity and cycle terms are scaled by `LAMBDA_INDENTITY` and `LAMBDA_CYCLE` before the three parts are summed.

<img src="./../img/cyclegan.ppm">


The CycleGAN discriminator differs from a DCGAN discriminator in one key way.
In DCGAN, the discriminator takes an image and outputs a single scalar per image, shape (N,). CycleGAN's discriminator also takes an image, but outputs a square map of shape (N, 1, 30, 30). As in other GANs, every value lies in the range (0, 1).

The usual interpretation (the "PatchGAN" idea) is that each value in this square map judges whether one local patch of the input image looks real, rather than judging the image as a whole.

## define model

In [None]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions wrapped in an identity skip connection.

    Reflection padding of 1 followed by a 3x3 conv keeps the spatial size, and
    both convs map ``in_channel -> in_channel``, so the skip is a plain add.
    """

    def __init__(self, in_channel):
        super().__init__()
        layers = [
            # first sub-layer: pad -> conv -> instance norm -> ReLU
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channel, in_channel, (3, 3)),
            nn.InstanceNorm2d(in_channel),
            nn.ReLU(),
            # second sub-layer: pad -> conv -> instance norm, no activation before the add
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channel, in_channel, (3, 3)),
            nn.InstanceNorm2d(in_channel),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual



class Gen(nn.Module):
    """CycleGAN generator: down-sampling convs -> 9 residual blocks -> up-sampling transposed convs -> 7x7 conv + Tanh.

    For an input of shape (N, img_channel, H, W) with H and W divisible by 4,
    the output has the same shape, squashed into (-1, 1) by the final Tanh.
    """

    def __init__(self, img_channel) -> None:
        super().__init__()
        # Each row is [in_channel, out_channel, kernel_size, stride, padding] for the down path.
        # The up path reuses rows 2 and 1 (in that order) with in/out channels swapped.
        self.config_up_down = [
            [img_channel, 64, 7, 1, 3],  # (N, img_channel, H, W) -> (N, 64, H, W)
            [64, 128, 4, 2, 1],          # -> (N, 128, H/2, W/2)
            [128, 256, 4, 2, 1],         # -> (N, 256, H/4, W/4)
        ]

        # Encoder and decoder stacks (decoder mirrors rows 2 and 1 of the config).
        self.down_block = self._down_block()
        self.up_block = self._up_block()

        # Residual trunk at 256 channels; spatial size is unchanged.
        self.num_residual_block = 9
        self.residual_block = nn.Sequential(
            *(ResidualBlock(256) for _ in range(self.num_residual_block))
        )

        # Output head: the 64 -> img_channel step uses a plain conv (not a transposed one), then Tanh.
        self.final_part = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, img_channel, (7, 7)),
            nn.Tanh(),
        )

    def _conv_block(self, inchannel, outchannel, k_s, s, p):
        """Conv2d -> InstanceNorm2d -> ReLU."""
        conv = nn.Conv2d(inchannel, outchannel, k_s, s, p)
        return nn.Sequential(conv, nn.InstanceNorm2d(outchannel), nn.ReLU())

    def _deconv_block(self, inchannel, outchannel, k_s, s, p):
        """ConvTranspose2d -> InstanceNorm2d -> ReLU."""
        deconv = nn.ConvTranspose2d(inchannel, outchannel, k_s, s, p)
        return nn.Sequential(deconv, nn.InstanceNorm2d(outchannel), nn.ReLU())

    def _down_block(self):
        """One conv block per config row, in order."""
        return nn.Sequential(*[self._conv_block(*row) for row in self.config_up_down])

    def _up_block(self):
        """Transposed-conv blocks for config rows 2 and 1 with channels swapped.

        Row 0 (64 -> img_channel) is intentionally skipped; ``final_part`` handles it.
        """
        return nn.Sequential(*[
            self._deconv_block(out_c, in_c, k_s, s, p)
            for in_c, out_c, k_s, s, p in reversed(self.config_up_down[1:])
        ])

    def forward(self, x):
        # Run each stage's layers in turn: encoder -> residual trunk -> decoder -> output head.
        for stage in (self.down_block, self.residual_block, self.up_block, self.final_part):
            for layer in stage:
                x = layer(x)
        return x

The CycleGAN generator is made up of three parts:
1. the down block, which reduces the image size with convolutions (the last two have stride 2);
2. a stack of residual blocks, which keep the image size unchanged;
3. the up block (followed by a final 7x7 convolution), which restores the original image size.

TIP: the residual connection here is a plain element-wise addition. Each residual block keeps both the channel count and the image size, so no extra convolution is needed to match shapes.


In [None]:
class Disc(nn.Module):
    """Patch discriminator: scores local patches of an image instead of the whole image.

    input: (N, 3, 256, 256)
    output: (N, 1, 30, 30), every entry squashed into (0, 1) by a sigmoid
    """

    def __init__(self, img_channel) -> None:
        super().__init__()
        # Each config row is [in_channel, out_channel, kernel_size, stride, padding].
        self.config_init = [img_channel, 64, 4, 2, 1]  # (img_channel, 256, 256) -> (64, 128, 128)
        self.config_lst = [
            [64, 128, 4, 2, 1],   # -> (128, 64, 64)
            [128, 256, 4, 2, 1],  # -> (256, 32, 32)
            [256, 512, 4, 1, 1],  # -> (512, 31, 31)
        ]
        self.config_final = [512, 1, 4, 1, 1]  # -> (1, 30, 30)
        self.conv_layers = self._create_conv_layers()
        self.sigmoid = nn.Sigmoid()

    def _create_conv_layers(self):
        """First block without norm, middle blocks conv -> norm -> LeakyReLU, then a bare 1-channel conv."""
        blocks = [self._conv_block(self.config_init, norm=False)]
        blocks += [self._conv_block(cfg, norm=True) for cfg in self.config_lst]
        blocks.append(nn.Conv2d(*self.config_final[:5]))
        return nn.Sequential(*blocks)

    def _conv_block(self, config, norm):
        """Conv2d -> (InstanceNorm2d or Identity) -> LeakyReLU(0.2)."""
        inchannel, outchannel, k_s, s, p = config[:5]
        conv = nn.Conv2d(inchannel, outchannel, k_s, s, p)
        norm_layer = nn.InstanceNorm2d(outchannel) if norm == True else nn.Identity()
        return nn.Sequential(conv, norm_layer, nn.LeakyReLU(0.2))

    def forward(self, x):
        out = x
        for block in self.conv_layers:
            out = block(out)
        return self.sigmoid(out)

## define some params

In [None]:
# Adam learning rates; generators and discriminators share the same value.
LEARNING_RATE_GEN = LEARNING_RATE_DISC = 2e-4
# Loss weights for the cycle-consistency and identity terms
# (the identity name keeps its original spelling because the training loop uses it).
LAMBDA_CYCLE, LAMBDA_INDENTITY = 10, 5
IMAGE_CHANNEL = 3  # RGB
NUM_EPOCH, BATCH_SIZE = 10, 4
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

## prepare data

In [None]:
# download and unzip dataset
! wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip && unzip horse2zebra.zip

In [None]:
class horse_zebra(Dataset):
    """Unpaired horse/zebra image dataset for CycleGAN training.

    Each item is a ``(horse, zebra)`` pair. The zebra is selected by ``idx``,
    wrapping around when there are fewer zebras than ``len(self)``. The horse
    is drawn uniformly at random, so the pairs are deliberately unaligned.

    params:
        transform: callable applied to both PIL images (skipped when falsy)
        path_horse: directory whose entries are all horse images
        path_zebra: directory whose entries are all zebra images
    raises:
        ValueError: if either directory is empty
    """

    def __init__(self, transform, path_horse="./horse2zebra/trainA/", path_zebra="./horse2zebra/trainB/"):
        self.path_horse = path_horse
        self.path_zebra = path_zebra

        self.transform = transform

        self.lst_horse = os.listdir(self.path_horse)
        self.lst_zebra = os.listdir(self.path_zebra)
        # Fail fast: an empty folder would otherwise only show up as an obscure error inside __getitem__.
        if not self.lst_horse or not self.lst_zebra:
            raise ValueError(f"no images found in '{self.path_horse}' or '{self.path_zebra}'")

        # The two domains have different numbers of images; one epoch covers the larger one.
        self.length = max(len(self.lst_horse), len(self.lst_zebra))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        horse_idx = torch.randint(len(self.lst_horse), size=(1,)).item()
        # Wrap around so idx stays valid even when there are fewer zebras than horses.
        zebra_idx = idx % len(self.lst_zebra)
        horse = self._load_rgb(os.path.join(self.path_horse, self.lst_horse[horse_idx]))
        zebra = self._load_rgb(os.path.join(self.path_zebra, self.lst_zebra[zebra_idx]))
        if self.transform:
            horse = self.transform(horse)
            zebra = self.transform(zebra)
        return horse, zebra

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the file handle.

        Some images in the dataset have a single channel, hence the RGB conversion.
        """
        with Image.open(path) as img:
            return img.convert("RGB")



In [None]:
# Resize every image to 256x256, then map pixel values from [0, 1] to [-1, 1]
# ((x - 0.5) / 0.5), matching the range of the generator's Tanh output.
mean = [0.5] * 3
std = [0.5] * 3
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

In [None]:
# Unpaired horse/zebra training set, reshuffled every epoch, BATCH_SIZE pairs per step.
dataset = horse_zebra(transform)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Take one shuffled batch from the *training* set and keep it fixed as a visual probe:
# sampling the same images throughout training shows whether the generators are learning.
(fix_horse, fix_zebra) = next(iter(dataloader))

## keep the image buffer

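In the paper, the discriminator is not trained only on the images the generator has just produced. Instead, a buffer of previously generated images is kept, and the discriminator is fed a mix of new and buffered fakes (the paper uses this to reduce oscillation during training).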

In [None]:
class ImageBuffer:
    """History pool of generated images used when training the discriminators.

    Until the pool holds ``max_size`` images, every incoming image is stored and
    returned unchanged. After that, each incoming image is, with probability 0.5,
    swapped with a uniformly chosen stored image: the stored one is returned and
    the new one takes its slot. Otherwise the incoming image is returned as-is.

    params:
        max_size: capacity of the pool; ``max_size <= 0`` disables the pool
    """

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def get_image(self, data):
        """Return a batch shaped like ``data`` that mixes new and previously generated images.

        params:
            data: batch of generated images, indexed along dim 0
        return:
            tensor of the same shape as ``data``, detached from the autograd graph
        """
        # Detach so the pool never keeps the generator's autograd graph alive across
        # iterations (the discriminator is trained on detached fakes anyway).
        data = data.detach()
        if self.max_size <= 0 or len(data) == 0:
            return data

        tensor_return = []
        for each in data:
            each = each.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(each)
                tensor_return.append(each)
            elif torch.rand(1).item() < 0.5:
                # torch.rand is uniform on [0, 1), so this branch really fires with probability 0.5
                # (torch.randn(1) < 0.5 would fire about 69% of the time).
                choose = torch.randint(self.max_size, size=(1,)).item()
                tensor_return.append(self.data[choose])
                self.data[choose] = each
            else:
                tensor_return.append(each)
        return torch.cat(tensor_return, dim=0)


## define optim and loss

In [None]:
def weight_init(m):
    """Initialise one layer in place; intended for ``model.apply(weight_init)``.

    Matching is by class name:
    * name contains "Conv" (e.g. Conv2d, ConvTranspose2d): weight ~ N(0, 0.02),
      bias set to 0 when the layer has one;
    * name contains "BatchNorm2d": weight ~ N(1, 0.02), bias set to 0.
    Every other module (including InstanceNorm2d) is left untouched.

    params:
        m: one sub-module of the model
    return:
        None
    """
    init = torch.nn.init
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        init.normal_(m.weight.data, 0, 0.02)
        if getattr(m, "bias", None) is not None:
            init.constant_(m.bias.data, 0)

    # A separate `if` (not `elif`): both checks are applied independently.
    if "BatchNorm2d" in layer_type:
        init.normal_(m.weight.data, 1, 0.02)
        init.constant_(m.bias.data, 0)


In [None]:
# Two generators (horse -> zebra, zebra -> horse) and one discriminator per domain.
G_H_Z = Gen(IMAGE_CHANNEL).to(DEVICE)
G_Z_H = Gen(IMAGE_CHANNEL).to(DEVICE)
D_H = Disc(IMAGE_CHANNEL).to(DEVICE)
D_Z = Disc(IMAGE_CHANNEL).to(DEVICE)

for _net in (G_H_Z, G_Z_H, D_H, D_Z):
    _net.apply(weight_init)
del _net

# A single Adam optimiser updates both generators jointly; each discriminator has its own.
opt_gen = optim.Adam([*G_Z_H.parameters(), *G_H_Z.parameters()], lr=LEARNING_RATE_GEN, betas=(0.5, 0.999))
opt_d_h = optim.Adam(D_H.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999))
opt_d_z = optim.Adam(D_Z.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999))

# Least-squares (MSE) adversarial loss; L1 for the cycle-consistency and identity terms.
gan_criterion = nn.MSELoss()
cycle_criterion = nn.L1Loss()
identity_criterion = nn.L1Loss()

# Populated with each network's state_dict by the training loop before saving.
model_checkpoint = dict.fromkeys(("G_H_Z", "G_Z_H", "D_Z", "D_H"))


## set up the training loop

In [None]:
# Output folders for the sample grids. exist_ok makes the cell safe to re-run
# (a plain `mkdir` fails once the folder exists, and `&&` then skips the second one).
os.makedirs("./horse_output", exist_ok=True)
os.makedirs("./zebra_output", exist_ok=True)

In [None]:
# function to sample images
def _triplet_grid(real, fake, cycle):
    """Place real / translated / cycle-reconstructed batches side by side.

    Each batch becomes one column (``nrow=1``, min-max normalised by ``make_grid``),
    and the three columns are concatenated along the width axis (dim 2).
    """
    columns = [torchvision.utils.make_grid(t, nrow=1, normalize=True) for t in (real, fake, cycle)]
    return torch.cat(columns, dim=2)


def sample_image(fix_horse, fix_zebra, epoch=None, batch_idx=None):
    """Translate a fixed batch through both generators and save comparison grids as PNGs.

    Writes ``./horse_output/horse_epoch{epoch}_batch_{batch_idx}.png`` with columns
    (real horse | fake zebra | cycled horse), and the zebra equivalent under ``./zebra_output``.

    params:
        fix_horse, fix_zebra: image batches already on the generators' device
        epoch, batch_idx: used in the file names. When omitted they are read from the
            notebook's global ``epoch`` / ``batch_idx`` (the training-loop variables),
            which keeps old call sites working.
    return:
        None
    """
    # Backward compatibility: the original version silently read these loop globals.
    if epoch is None:
        epoch = globals()["epoch"]
    if batch_idx is None:
        batch_idx = globals()["batch_idx"]

    # Inference only; avoid building an autograd graph even if the caller forgot no_grad.
    with torch.no_grad():
        fake_zebra = G_H_Z(fix_horse)
        fake_horse = G_Z_H(fix_zebra)
        cycle_horse = G_Z_H(fake_zebra)
        cycle_zebra = G_H_Z(fake_horse)

    grid_horse = _triplet_grid(fix_horse, fake_zebra, cycle_horse)
    grid_zebra = _triplet_grid(fix_zebra, fake_horse, cycle_zebra)

    os.makedirs("./horse_output", exist_ok=True)
    os.makedirs("./zebra_output", exist_ok=True)
    torchvision.utils.save_image(grid_horse, f"./horse_output/horse_epoch{epoch}_batch_{batch_idx}.png", normalize=False)
    torchvision.utils.save_image(grid_zebra, f"./zebra_output/zebra_epoch{epoch}_batch_{batch_idx}.png", normalize=False)
    print("saving samples... done!")
def save_checkpoint(dict, path="model.pth.tar"):
    """Save a checkpoint object with ``torch.save`` and log when/where it was written.

    The notebook passes ``{"G_H_Z": ..., "G_Z_H": ..., "D_Z": ..., "D_H": ...}``
    (one ``state_dict()`` per network), but any object ``torch.save`` accepts works.

    params:
        dict: the object to save (the name shadows the builtin; kept so
            keyword callers keep working)
        path: destination file

    return:
        None
    """
    # "%m/%d/%y" is what "%D" expands to on POSIX; "%D" itself is not supported by the Windows C runtime.
    now = time.strftime("%m/%d/%y_%H:%M")
    print(f"saving checkpoint at {now}, path is '{path}'")
    torch.save(dict, path)
    print("saving model... done!")

In [None]:
# One history pool per discriminator, each with the default capacity.
fake_zebra_buffer, fake_horse_buffer = ImageBuffer(), ImageBuffer()

In [None]:
SAMPLE_EVERY = 40  # save sample grids of the fixed batch every this many steps

for epoch in range(NUM_EPOCH):
    loop = tqdm(dataloader, leave=True)
    for batch_idx, (horse, zebra) in enumerate(loop):
        # Sampling below switches the generators to eval mode; switch back every step.
        G_H_Z.train()
        G_Z_H.train()

        horse = horse.to(DEVICE)
        zebra = zebra.to(DEVICE)

        #### train generators ####
        opt_gen.zero_grad()

        # Adversarial loss on the freshly generated images: the generators want each
        # discriminator to score their fakes as real (all-ones target).
        fake_horse = G_Z_H(zebra)
        d_h_fake = D_H(fake_horse)
        fake_zebra = G_H_Z(horse)
        d_z_fake = D_Z(fake_zebra)
        gan_loss = gan_criterion(d_z_fake, torch.ones_like(d_z_fake)) + gan_criterion(d_h_fake, torch.ones_like(d_h_fake))

        # Identity loss: a generator fed an image of its own target domain should leave it unchanged.
        identity_zebra = G_H_Z(zebra)
        identity_horse = G_Z_H(horse)
        identity_loss = (identity_criterion(identity_zebra, zebra) + identity_criterion(identity_horse, horse)) * LAMBDA_INDENTITY

        # Cycle-consistency loss: translating there and back should reconstruct the input.
        cycle_horse = G_Z_H(fake_zebra)
        cycle_zebra = G_H_Z(fake_horse)
        cycle_loss = (cycle_criterion(cycle_horse, horse) + cycle_criterion(cycle_zebra, zebra)) * LAMBDA_CYCLE

        Gen_loss = gan_loss + identity_loss + cycle_loss
        Gen_loss.backward()
        opt_gen.step()

        #### train discriminators ####
        # Real images vs. fakes drawn from the history buffer. Fakes are detached before
        # entering the buffer so it never keeps the generator's graph alive.
        opt_d_h.zero_grad()
        d_h_real = D_H(horse)
        fake_horse_ = fake_horse_buffer.get_image(fake_horse.detach())
        d_h_fake = D_H(fake_horse_.detach())
        d_h_loss = gan_criterion(d_h_real, torch.ones_like(d_h_real)) + \
            gan_criterion(d_h_fake, torch.zeros_like(d_h_fake))
        d_h_loss.backward()
        opt_d_h.step()

        opt_d_z.zero_grad()
        d_z_real = D_Z(zebra)
        fake_zebra_ = fake_zebra_buffer.get_image(fake_zebra.detach())
        d_z_fake = D_Z(fake_zebra_.detach())
        d_z_loss = gan_criterion(d_z_real, torch.ones_like(d_z_real)) + \
            gan_criterion(d_z_fake, torch.zeros_like(d_z_fake))
        # No retain_graph: nothing backpropagates through this graph again, so retaining it only holds memory.
        d_z_loss.backward()
        opt_d_z.step()

        if batch_idx % SAMPLE_EVERY == 0:
            G_Z_H.eval()
            G_H_Z.eval()
            with torch.no_grad():
                # sample the generators' output on the fixed batch
                sample_image(fix_horse.to(DEVICE), fix_zebra.to(DEVICE))

    # Checkpoint all four networks at the end of every epoch (same default path each time).
    model_checkpoint['G_H_Z'] = G_H_Z.state_dict()
    model_checkpoint['G_Z_H'] = G_Z_H.state_dict()
    model_checkpoint['D_Z'] = D_Z.state_dict()
    model_checkpoint['D_H'] = D_H.state_dict()
    save_checkpoint(model_checkpoint)

In [None]:
# List the horse-side sample grids written so far by sample_image().
!ls ./horse_output

In [None]:
# Bundle the zebra-side sample grids (real zebra | fake horse | cycled zebra) into one archive.
# Note: only zebra_output is archived; horse_output is not included.
! zip -r zebraoutput.zip zebra_output