# Chapter 9: CycleGAN

# Imports

In [343]:
import torch
import torch.nn as nn
import torch.cuda as cuda
from torch.optim import Adam
from torchvision.datasets import VisionDataset, ImageFolder
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.transforms import Compose, ToTensor, Resize
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt

# Loading data

PyTorch or torchvision does not come with the `apple2orange` dataset. Therefore we implement a full dataset which retrieves the data.

In [344]:
class Apple2Orange(VisionDataset):
    url = "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip"
    base_folder = "apple2orange"
    filename = "apple2orange.zip"
    zip_md5 = "5b58c340256288622a835d6f3b6198ae"
    
    def __init__(self, root, train=True, transform=None, target_transform=None,
             download=False):
        super().__init__(root, transform=transform,
                            target_transform=target_transform)
        self.train = train
        self.train_folder = os.path.join(self.root, self.base_folder, "train")
        self.test_folder = os.path.join(self.root, self.base_folder, "test")

        if download:
            self.download()
            
        if train:
            root = self.train_folder
        else:
            root = self.test_folder
            
        self.apples = [f for f in os.scandir(os.path.join(root, "apple")) if f.is_file()]
        self.oranges = [f for f in os.scandir(os.path.join(root, "orange")) if f.is_file()]
        
    def __getitem__(self, index):
        return self.transform(default_loader(self.apples[index])), self.transform(default_loader(self.oranges[index]))
    
    def __len__(self):
        return min(len(self.apples), len(self.oranges))
            
    def _check_integrity(self):
        fpath = os.path.join(self.root, self.filename)
        
        if check_integrity(fpath, self.zip_md5) and os.path.exists(os.path.join(self.root, self.base_folder)):
            return True
        else:
            return False
        
            
    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        
        #download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.zip_md5)
        
        train_path = os.path.join(self.root, self.base_folder, "train")
        test_path = os.path.join(self.root, self.base_folder, "test")
        source_train_a = os.path.join(self.root, self.base_folder, "trainA")
        source_train_b = os.path.join(self.root, self.base_folder, "trainB")
        source_test_a = os.path.join(self.root, self.base_folder, "testA")
        source_test_b = os.path.join(self.root, self.base_folder, "testB")
        
        os.mkdir(train_path)
        os.mkdir(test_path)
        
        os.renames(source_train_a, os.path.join(train_path, "apple"))
        os.renames(source_train_b, os.path.join(train_path, "orange"))
        os.renames(source_test_a, os.path.join(test_path, "apple"))
        os.renames(source_test_b, os.path.join(test_path, "orange"))


In [345]:
t = Compose([Resize(128), ToTensor()])
train_set = Apple2Orange('~/pytorch', train=True, download=True, transform=t)
test_set = Apple2Orange('~/pytorch', train=False, download=True, transform=t)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


# Support CUDA

In [346]:
if cuda.is_available():
    print("Using cuda device for training")
    device = 'cuda'
else:
    print("Using cpu for training")
    device = 'cpu'

Using cpu for training


# Generator

In [338]:
class Generator(nn.Module):
    class TransSkip(nn.Module):
        def __init__(self, in_channel, out_channel, kernel_size):
            super().__init__()
            self.trans1 = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(in_channel, out_channel, kernel_size, stride=1, padding=int(kernel_size/2)),
                nn.LeakyReLU(negative_slope=0.2),
                nn.InstanceNorm2d(out_channel/2))
            
        def forward(self, x, skip):
            trans = self.trans1(x)
            x = torch.cat([trans,  skip], dim=1)
            
            return x

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_layer(3, 32, 3)
        self.conv2 = self.conv_layer(32, 64, 3)
        self.conv3 = self.conv_layer(64, 128, 3)
        self.conv4 = self.conv_layer(128, 256, 3)
        
        self.trans1 = self.TransSkip(256, 128, 3)
        self.trans2 = self.TransSkip(256, 64, 3)
        self.trans3 = self.TransSkip(128, 32, 3)
        # Check if using skip from first image improves result
        # self.trans4 = self.TransSkip(64, 64, 3)
        # self.conv5 = nn.Sequential(
        #     nn.Conv2d(67, 3, 3, stride=1, padding=1),
        #     nn.Tanh())
        self.trans4 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv5 = nn.Sequential(
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh())
    
    def forward(self, x):
        skip_0 = x
        skip_1 = x = self.conv1(x)
        skip_2 = x = self.conv2(x)
        skip_3 = x = self.conv3(x)
        x = self.conv4(x)
        x = self.trans1(x, skip_3)
        x = self.trans2(x, skip_2)
        x = self.trans3(x, skip_1)
        # x = self.trans4(x, skip_0)
        x = self.trans4(x)
        x = self.conv5(x)
        
        return x
    
    def conv_layer(self, in_channels, out_channels, kernel_size):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.InstanceNorm2d(out_channels))

In [339]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = self.conv_layer(3, 64, 3)
        self.conv2 = self.conv_layer(64, 128, 3)
        self.conv3 = self.conv_layer(128, 256, 3)
        self.conv4 = self.conv_layer(256, 512, 3)
        self.conv5 = nn.Conv2d(512, 1, 4, stride=1, padding=2)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        
        return x
    
    def conv_layer(self, in_channels, out_channels, kernel_size):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=int(kernel_size/2)),
            nn.LeakyReLU(negative_slope=0.2))

In [340]:
x, y = train_set[0]
type(y)

torch.Tensor

In [341]:
gen_AB = Generator()
gen_BA = Generator()
dis_A = Discriminator()
dis_B = Discriminator()

cycle_loss_criteria = nn.L1Loss()
adversarial_loss_criteria = nn.MSELoss()

real = torch.ones(8, 1, 9, 9)
fake = torch.zeros(8, 1, 9, 9)

dis_A_optim = Adam(dis_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
dis_B_optim = Adam(dis_B.parameters(), lr=0.0002, betas=(0.5, 0.999))
gen_AB_optim = Adam(gen_AB.parameters(), lr=0.0002, betas=(0.5, 0.999))
gen_BA_optim = Adam(gen_BA.parameters(), lr=0.0002, betas=(0.5, 0.999))


for img_A, img_B in train_loader:
    # A -> B
    img_B_gen = gen_AB(img_A)
    d_out = dis_B(img_B_gen)
    fake_loss = adversarial_loss_criteria(d_out, fake)
    
    dis_A_optim.zero_grad()
    fake_loss.backward()
    dis_A_optim.step()
    
    d_out = dis_B(img_B)
    real_loss = adversarial_loss_criteria(d_out, real)
    
    dis_A_optim.zero_grad()
    real_loss.backward()
    dis_A_optim.step()
    
    # B -> A
    img_A_gen = gen_BA(img_A)
    d_out = dis_A(img_A_gen)
    fake_loss = adversarial_loss_criteria(d_out, fake)
    
    dis_B_optim.zero_grad()
    fake_loss.backward()
    dis_B_optim.step()
    
    d_out = dis_A(img_A)
    fake_loss = adversarial_loss_criteria(d_out, real)
    
    dis_B_optim.zero_grad()
    fake_loss.backward()
    dis_B_optim.step()

    # Generator
    img_B_gen = gen_AB(img_A)
    img_A_rec = gen_BA(img_B_gen)
    d_out = dis_B(img_B_gen)
    
    rec_A_loss = cycle_loss_criteria(img_A, img_A_rec) * 10
    adv_A_loss = adversarial_loss_criteria(d_out, real)
    
    img_A_gen = gen_BA(img_B)
    img_B_rec = gen_AB(img_A_gen)
    d_out = dis_A(img_A_gen)
    
    rec_B_loss = cycle_loss_criteria(img_B, img_B_rec) * 10
    adv_B_loss = adversarial_loss_criteria(d_out, real)
    
    gen_AB_optim.zero_grad()
    gen_BA_optim.zero_grad()
    
    rec_A_loss.backward(retain_graph=True)
    adv_A_loss.backward()
    
    rec_B_loss.backward(retain_graph=True)
    adv_B_loss.backward()
    
    print(f"A_rec: {rec_A_loss:.4f} A_adv: {adv_A_loss:.4f}")
    print(f"B_rec: {rec_B_loss:.4f} B_adv: {adv_B_loss:.4f}")
    
    gen_AB_optim.step()
    gen_BA_optim.step()

A_rec: 5.6418 A_adv: 1.0109
B_rec: 5.9284 B_adv: 1.0087
A_rec: 4.8296 A_adv: 1.0126
B_rec: 5.0547 B_adv: 1.0070
A_rec: 5.0231 A_adv: 1.0123
B_rec: 5.1306 B_adv: 1.0076
A_rec: 4.7849 A_adv: 1.0118
B_rec: 5.1849 B_adv: 1.0073
A_rec: 5.6057 A_adv: 1.0130
B_rec: 5.2316 B_adv: 1.0076


KeyboardInterrupt: 