In [None]:
from google.colab import drive

# Mount Google Drive at `/gdrive` so the notebook can reach files stored there.
drive.mount('/gdrive')

# Specify the directory path where `assignment3.ipynb` exists.
# For example, if you saved `assignment3.ipynb` in the `/gdrive/My Drive/cs376/assignment3` directory,
# then set root = '/gdrive/My Drive/cs376/assignment3'
root = '/gdrive/My Drive/CycleGAN'

In [None]:
import os
from PIL.Image import Image
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, utils
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

from model import Discriminator, Generator
from dataset import ImageDataset
from torch.utils.tensorboard import SummaryWriter

import random

# Seed every random number generator the notebook relies on, not just torch:
# the dataset class shuffles its file lists with the stdlib `random` module,
# so seeding torch alone does not make a run reproducible.
torch.manual_seed(470)
torch.cuda.manual_seed(470)
random.seed(470)
np.random.seed(470)

from pathlib import Path
from datetime import datetime

# Captured once at start-up; used to build a unique log directory name.
now = datetime.now()

In [None]:
# Hyperparameters
# `__file__` is not defined inside a notebook kernel, so evaluating it raises
# NameError there. Fall back to the `root` directory configured in the Drive
# mount cell, and finally to the current working directory.
try:
    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    ROOT_DIR = globals().get('root') or os.getcwd()
# One log sub-directory per run, named after the start-up timestamp.
LOG_DIR = os.path.join(ROOT_DIR, "logs", now.strftime("%Y%m%d-%H%M%S"))
LOG_ITER = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCHSIZE = 4
# NOTE(review): the CycleGAN paper trains with Adam at lr = 0.0002; 0.002 may be
# a typo -- confirm before long training runs.
LEARNING_RATE = 0.002
MAX_EPOCH = 100

# Weight of the cycle-consistency term (the identity term uses 0.5 * LAMBDA).
LAMBDA = 10

In [None]:
import os
import glob
from PIL import Image
from random import shuffle
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """Unpaired dataset yielding one image from each of two folders.

    The `*.jpg` files of both folders are listed, sorted, and then shuffled
    independently, so the pairing between the two domains is arbitrary. The
    dataset length is the size of the larger folder; the smaller folder is
    cycled through with a modulo index.

    Args:
        path_X: Directory containing the `*.jpg` files of the first domain.
        path_Y: Directory containing the `*.jpg` files of the second domain.
        transform: Optional callable applied to each loaded image.

    Raises:
        FileNotFoundError: If either directory contains no `*.jpg` file.
    """

    def __init__(self, path_X, path_Y, transform = None):
        self.path_X = self._list_images(path_X)
        self.path_Y = self._list_images(path_Y)
        shuffle(self.path_X)
        shuffle(self.path_Y)
        self.transform = transform
        self.length = max(len(self.path_X), len(self.path_Y))

    @staticmethod
    def _list_images(directory):
        """Return the sorted `*.jpg` paths in `directory`, failing fast when empty."""
        # glob order depends on the filesystem; sorting first makes a seeded
        # shuffle reproducible across machines.
        files = sorted(glob.glob(os.path.join(directory, '*.jpg')))
        if not files:
            # Without this check an empty folder only surfaces later as a
            # ZeroDivisionError from the modulo in __getitem__.
            raise FileNotFoundError(
                "no '*.jpg' files found in {!r}".format(directory))
        return files

    @staticmethod
    def _load_rgb(path):
        """Open `path`, convert it to 3-channel RGB, and close the file handle."""
        with Image.open(path) as img:
            # convert() returns a fully loaded copy, so it stays valid after
            # the file is closed; it also keeps channel counts uniform.
            return img.convert('RGB')

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return the (X, Y) image pair for `index`, wrapping the shorter list."""
        img_X = self._load_rgb(self.path_X[index % len(self.path_X)])
        img_Y = self._load_rgb(self.path_Y[index % len(self.path_Y)])

        if self.transform:
            img_X = self.transform(img_X)
            img_Y = self.transform(img_Y)

        return img_X, img_Y

In [None]:
# Construct Data Pipeline
# Both image folders sit in a `dataset` directory next to ROOT_DIR's parent.
data_dir_X, data_dir_Y = (
    os.path.join(Path(ROOT_DIR).parent, 'dataset', folder)
    for folder in ('photo_jpg', 'monet_jpg')
)
# Only conversion to a tensor is applied to each loaded image.
transform = transforms.Compose([transforms.ToTensor()])

#Helper Functions
def imshow(img):
    """Display one image tensor with matplotlib.

    The tensor's axes are reordered from (0, 1, 2) to (1, 2, 0) before
    plotting, i.e. the leading axis is moved to the end.

    Args:
        img: A 3-D torch tensor. It may live on any device and may be part of
            an autograd graph; it is detached and copied to the CPU first.
    """
    # Tensor.numpy() fails on CUDA tensors and on tensors that require grad
    # (for example generator outputs), so detach and move to the CPU first.
    img = img.detach().cpu().numpy()
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()


def show(max_batches=None):
    """Display the first image of each batch produced by `train_dataloader`.

    Args:
        max_batches: Optional upper bound on the number of batches to display.
            The default (None) walks the whole dataloader, as before.
    """
    for batch_idx, (inputs, _) in enumerate(train_dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        imshow(inputs[0])


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    Each convolution keeps the channel count at `in_features` and is followed
    by instance normalisation; a ReLU sits between the two stages. The input
    is added to the output of the stacked layers.
    """

    def __init__(self, in_features):
        super().__init__()
        self.conv_block = nn.Sequential(
            *self._pad_conv_norm(in_features),
            nn.ReLU(inplace=True),
            *self._pad_conv_norm(in_features),
        )

    @staticmethod
    def _pad_conv_norm(channels):
        """Build one ReflectionPad -> 3x3 Conv -> InstanceNorm stage."""
        return (
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual


class Generator(nn.Module):
    """ResNet-style image-to-image generator with nine residual blocks.

    Layer codes used in `create_layers`:
        c7s1-k : reflection-padded 7x7 convolution, stride 1, k filters
        dk     : 3x3 convolution, stride 2, k filters (downsampling)
        Rk     : residual block (channel count unchanged)
        uk     : 3x3 transposed convolution, stride 2, k filters (upsampling)

    Every layer except the last is followed by InstanceNorm and ReLU. The last
    layer is followed by Tanh only, so the output lies in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.features = self.create_layers()

    def forward(self, x):
        """Map a 3-channel image batch to a 3-channel image batch."""
        return self.features(x)

    def create_layers(self):
        """Build the `nn.Sequential` described by the layer-code list."""
        infos = ['c7s1-64', 'd128', 'd256', 'R256', 'R256', 'R256', 'R256',
                 'R256', 'R256', 'R256', 'R256', 'R256', 'u128', 'u64', 'c7s1-3']
        layers = []
        in_channels = 3
        last_index = len(infos) - 1
        for index, x in enumerate(infos):
            if x.startswith('c7s1-'):
                out_channels = int(x[5:])
                layers += [nn.ReflectionPad2d(3),
                           nn.Conv2d(in_channels, out_channels, kernel_size=7)]
                if index == last_index:
                    # Output layer: InstanceNorm + ReLU here would force every
                    # channel to zero mean / unit variance and clip negatives,
                    # so no valid image could be produced. Use Tanh instead.
                    # NOTE(review): Tanh emits [-1, 1]; if the inputs are only
                    # scaled to [0, 1], consider normalising them to [-1, 1].
                    layers += [nn.Tanh()]
                else:
                    layers += [nn.InstanceNorm2d(out_channels),
                               nn.ReLU(inplace=True)]
                in_channels = out_channels

            elif x.startswith('d'):
                out_channels = int(x[1:])
                layers += [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                           nn.InstanceNorm2d(out_channels),
                           nn.ReLU(inplace=True)]
                in_channels = out_channels

            elif x.startswith('R'):
                layers += [ResidualBlock(in_channels)]

            elif x.startswith('u'):
                out_channels = int(x[1:])
                # output_padding=1 makes the stride-2 transposed convolution
                # exactly double the spatial size.
                layers += [nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
                           nn.InstanceNorm2d(out_channels),
                           nn.ReLU(inplace=True)]
                in_channels = out_channels

        return nn.Sequential(*layers)


class Discriminator(nn.Module):
    """Patch-based discriminator producing a one-channel map of realness scores.

    Four 4x4 stride-2 convolutions (64, 128, 256, 512 filters), each followed
    by LeakyReLU(0.2), and a final 4x4 convolution down to one channel.
    InstanceNorm is applied after every convolution except the first one and
    the output convolution.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.features = self.create_layers()

    def forward(self, x):
        """Return the raw (un-activated) score map for the image batch `x`."""
        return self.features(x)

    def create_layers(self):
        """Build the convolutional stack as an `nn.Sequential`."""
        infos = [64, 128, 256, 512]
        layers = []
        in_channels = 3
        for index, x in enumerate(infos):
            layers.append(nn.Conv2d(in_channels, x, kernel_size=4, stride=2))
            # The first (C64) layer is left un-normalised; the deeper layers
            # are normalised. The previous condition had this inverted.
            if index > 0:
                layers.append(nn.InstanceNorm2d(x))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_channels = x
        # Final 1-channel convolution produces the per-patch scores.
        layers += [nn.Conv2d(in_channels, 1, 4, padding=1)]
        return nn.Sequential(*layers)


In [None]:
train_dataset = ImageDataset(data_dir_X, data_dir_Y, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True, num_workers=2)

# The training loop moves every batch to DEVICE, so the networks must live on
# the same device; move them before the optimizers are built.
Generator_XY = Generator().to(DEVICE)
Discriminator_X = Discriminator().to(DEVICE)

Generator_YX = Generator().to(DEVICE)
Discriminator_Y = Discriminator().to(DEVICE)

# A single optimizer drives BOTH generators. Previously only Generator_XY's
# parameters were registered, so Generator_YX was never updated.
optimizer_generator = optim.Adam(
    list(Generator_XY.parameters()) + list(Generator_YX.parameters()),
    lr=LEARNING_RATE)
optimizer_discriminator_X = optim.Adam(Discriminator_X.parameters(), lr=LEARNING_RATE)
optimizer_discriminator_Y = optim.Adam(Discriminator_Y.parameters(), lr=LEARNING_RATE)

# NOTE(review): not referenced by any cell visible here; kept in case a later
# cell uses it.
img_size = torch.empty(256, 256)

In [None]:
writer = SummaryWriter(LOG_DIR)
iteration = 0

# The loss modules hold no state, so build them once instead of per iteration.
MSELoss = torch.nn.MSELoss()
L1Norm = torch.nn.L1Loss()

for epoch in range(MAX_EPOCH):
    Generator_XY.train()
    Generator_YX.train()
    Discriminator_X.train()
    Discriminator_Y.train()
    # Keeps the end-of-epoch summary printable even if the loader is empty.
    loss = float('nan')
    for input_X, input_Y in train_dataloader:
        iteration += 1
        input_X = input_X.to(DEVICE)
        input_Y = input_Y.to(DEVICE)

        # ------------------------------------------------------------------
        # Generator update
        # ------------------------------------------------------------------
        X_to_Y = Generator_XY(input_X)
        Y_to_X = Generator_YX(input_Y)

        # Adversarial Loss (least-squares): generators want fakes scored as 1.
        result_XYY = Discriminator_Y(X_to_Y)
        result_YXX = Discriminator_X(Y_to_X)
        loss_GAN_G = (MSELoss(result_XYY, torch.ones_like(result_XYY)) + MSELoss(result_YXX, torch.ones_like(result_YXX))) / 2

        # Cycle Consistency Loss: X -> Y -> X must reproduce X, and
        # Y -> X -> Y must reproduce Y. (The first term previously fed Y_to_X
        # back through Generator_YX, which is not a cycle.)
        loss_cyc = (L1Norm(Generator_YX(X_to_Y), input_X) + L1Norm(Generator_XY(Y_to_X), input_Y))*LAMBDA

        # Identity Loss: a generator fed an image that is already in its
        # target domain should leave it unchanged. (Comparing X_to_Y against
        # input_Y is wrong here because the two images are unpaired.)
        loss_identity = (L1Norm(Generator_XY(input_Y), input_Y) + L1Norm(Generator_YX(input_X), input_X))*0.5*LAMBDA

        loss_G = loss_GAN_G + loss_cyc + loss_identity

        optimizer_generator.zero_grad()
        loss_G.backward()
        optimizer_generator.step()

        # ------------------------------------------------------------------
        # Discriminator updates
        # ------------------------------------------------------------------
        # These run AFTER the generator step and use detached fakes. Stepping
        # a discriminator first would modify its weights in place while
        # loss_G's graph still needed them, which makes autograd raise.
        # zero_grad() also discards the gradients that loss_G.backward()
        # accumulated in the discriminators.
        optimizer_discriminator_X.zero_grad()
        result_XX = Discriminator_X(input_X)
        result_YXX = Discriminator_X(Y_to_X.detach())
        loss_DX = MSELoss(result_XX, torch.ones_like(result_XX)) + MSELoss(result_YXX, torch.zeros_like(result_YXX))
        loss_DX.backward()
        optimizer_discriminator_X.step()

        optimizer_discriminator_Y.zero_grad()
        result_YY = Discriminator_Y(input_Y)
        result_XYY = Discriminator_Y(X_to_Y.detach())
        loss_DY = MSELoss(result_YY, torch.ones_like(result_YY)) + MSELoss(result_XYY, torch.zeros_like(result_XYY))
        loss_DY.backward()
        optimizer_discriminator_Y.step()

        loss = loss_G.item() + loss_DX.item() + loss_DY.item()

        # Log every LOG_ITER iterations (the configured constant was
        # previously ignored in favour of a hard-coded 20).
        if iteration % LOG_ITER == 0 and writer is not None:
            writer.add_scalar('train_loss', loss, iteration)
            print('[epoch: {}, iteration: {}] train loss : {:4f}'.format(epoch+1, iteration, loss))

    print('[epoch: {}] train loss : {:4f}'.format(epoch+1, loss))

# Make sure buffered scalars reach disk once training ends.
writer.flush()