Source: https://hardikbansal.github.io/CycleGANBlog/

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os
from PIL import Image
from torchvision import transforms
import itertools
import torchvision.utils as vutils
%matplotlib inline

## HYPER-PARAMETERS

In [2]:
a_folder = "/home/tyler/data/image/horse2zebra/trainA/"
b_folder = "/home/tyler/data/image/horse2zebra/trainB/"
batch_size = 1
num_epochs = 200
lambda_a = 10
lambda_b = 10
lr = 0.0002
lr_d = 0.00005
save_results_every = 100
save_results_path = "/home/tyler/data/image/horse2zebra/results"
save_models_path = "/home/tyler/data/image/horse2zebra/models/"

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class GANDataSet(Dataset):
    def __init__(self, a_folder, b_folder, transform=None):
        self.a_folder = a_folder
        self.b_folder = b_folder
        self.a_images = os.listdir(a_folder)
        self.b_images = os.listdir(b_folder)
        self.transform = transform
        
    def __len__(self):
        return min(len(self.a_images), len(self.b_images))
    
    def __read_image(self, path):
        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
    
    def __getitem__(self, idx):
        a_img_name = os.path.join(self.a_folder,
                                self.a_images[idx])
        b_img_name = os.path.join(self.b_folder,
                                self.b_images[idx])
        a_img = self.__read_image(a_img_name)
        b_img = self.__read_image(b_img_name)

        return a_img, b_img

In [5]:
dataset = GANDataSet(a_folder, b_folder, 
                     transform=transforms.Compose([
                         transforms.Resize((256, 256)),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), 
                                              (0.5, 0.5, 0.5))]))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [6]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, 1, 1)
        self.conv2 = nn.Conv2d(64, 128, 2, 2)
        self.conv3 = nn.Conv2d(128, 256, 2, 2)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        
    def forward(self, input_image):
        output = F.relu(self.bn1(self.conv1(input_image)))
        output = F.relu(self.bn2(self.conv2(output)))
        return self.conv3(output)

In [7]:
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class ResBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)

        return out
    
class Transformer(nn.Module):
    
    def __init__(self, n_blocks):
        super(Transformer, self).__init__()
        
        layers = []
        for i in range(n_blocks):
            layers.append(ResBlock(256, 256))
            
        self.reslayers = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.reslayers(x)

In [8]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv1 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.conv2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.conv3 = nn.ConvTranspose2d(64, 3, 7, 1, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(64)
        
    def forward(self, input_image):
        output = F.relu(self.bn1(self.conv1(input_image)))
        output = F.relu(self.bn2(self.conv2(output)))
        return torch.tanh(self.conv3(output))

In [9]:
class Generator(nn.Module):
    def __init__(self, encoder, transformer, decoder):
        super(Generator, self).__init__()
        self.encoder = encoder
        self.transformer = transformer
        self.decoder = decoder
        
    def forward(self, x):
        enc = self.encoder(x)
        trans = self.transformer(enc)
        decoded = self.decoder(trans)
        return decoded

In [10]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.c1 = nn.Conv2d(3, 64, 4, 2, 1) #128
        self.c2 = nn.Conv2d(64, 128, 4, 2, 1) #64
        self.c3 = nn.Conv2d(128, 256, 4, 2, 1) #32
        self.c4 = nn.Conv2d(256, 512, 4, 2, 1) #16
        self.out = nn.Conv2d(512, 1, 16, 1, 0) # 1
        
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.c1(x)))
        x = F.relu(self.bn2(self.c2(x)))
        x = F.relu(self.bn3(self.c3(x)))
        x = F.relu(self.bn4(self.c4(x)))
        x = torch.sigmoid(self.out(x))
        return x.view(-1)

In [11]:
generator_a_b = Generator(Encoder().to(device), 
                          Transformer(6).to(device), 
                          Decoder().to(device)).to(device)
generator_b_a = Generator(Encoder().to(device), 
                          Transformer(6).to(device), 
                          Decoder().to(device)).to(device)
discriminator_a = Discriminator().to(device)
discriminator_b = Discriminator().to(device)

# generator_a_b.apply(weights_init)
# generator_b_a.apply(weights_init)
# discriminator_a.apply(weights_init)
# discriminator_b.apply(weights_init)

mse_loss = nn.MSELoss()
cycle_loss = nn.L1Loss()

optimizer_G = torch.optim.Adam(itertools.chain(generator_a_b.parameters(), 
                                               generator_b_a.parameters()),
                               lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(itertools.chain(discriminator_a.parameters(), 
                                               discriminator_b.parameters()),
                               lr=lr_d, betas=(0.5, 0.999))

In [12]:
def save_models(models):
    for name, model in models:
        torch.save(model.state_dict(), save_models_path + name)


total_step = len(dataloader)
for epoch in range(num_epochs):
    for i, (img_a, img_b) in enumerate(dataloader):
        
        img_a = img_a.to(device)
        img_b = img_b.to(device)

        gen_b = generator_a_b(img_a)
        cyclic_a = generator_b_a(gen_b)
        
        gen_a = generator_b_a(img_b)
        cyclic_b = generator_a_b(gen_a)
        
        disc_a_real_pred = discriminator_a(img_a)
        disc_a_fake_pred = discriminator_a(gen_a)
        
        disc_b_real_pred = discriminator_b(img_b)
        disc_b_fake_pred = discriminator_b(gen_b)
        
        ones =  torch.tensor(1.0).expand_as(disc_a_real_pred).to(device)
        zeros = torch.tensor(0.0).expand_as(disc_a_fake_pred).to(device)
        
        disc_a_loss_real = mse_loss(disc_a_real_pred, ones)
        disc_a_loss_fake = mse_loss(disc_a_fake_pred, zeros)
        
        disc_b_loss_real = mse_loss(disc_b_real_pred, ones)
        disc_b_loss_fake = mse_loss(disc_b_fake_pred, zeros)
        
        # discriminator losses
        disc_a_loss = (disc_a_loss_real + disc_a_loss_fake) / 2
        disc_b_loss = (disc_b_loss_real + disc_b_loss_fake) / 2
        
        # generator losses
        gen_loss_a = mse_loss(disc_a_fake_pred, ones)
        gen_loss_b = mse_loss(disc_b_fake_pred, ones)
        
        # cyclic losses
        cycle_loss_a = cycle_loss(cyclic_a, img_a) * lambda_a
        cycle_loss_b = cycle_loss(cyclic_b, img_b) * lambda_b
        
        #update weights
        optimizer_G.zero_grad()
        loss_g = gen_loss_a + gen_loss_b + cycle_loss_a + cycle_loss_b
        loss_g.backward(retain_graph=True)
        optimizer_G.step()
        
        optimizer_D.zero_grad()
        disc_a_loss.backward(retain_graph=True)
        disc_b_loss.backward()
        optimizer_D.step()
        
        if (i+1) % save_results_every == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_a: {:.4f}, g_a: {:.4f}, c_a: {:.4f}, d_b: {:.4f}, g_b: {:.4f}, c_b: {:.4f}' 
                  .format(epoch, num_epochs, i+1, total_step, 
                          disc_a_loss.item(), gen_loss_a.item(),cycle_loss_a.item(),
                          disc_b_loss.item(), gen_loss_b.item(),cycle_loss_b.item()))
            vutils.save_image(gen_b.data, 
                              '{0}/fake_samples_epoch_{1}.png'.format(save_results_path, epoch), 
                              normalize = True)
            vutils.save_image(img_a.data, 
                              '{0}/samples_epoch_{1}.png'.format(save_results_path, epoch), 
                              normalize = True)
            save_models([("gen_a_b", generator_a_b), ("gen_b_a", generator_b_a),
                         ("disc_a", discriminator_a), ("disc_b", discriminator_b)])

Epoch [0/200], Step [100/1067], d_a: 0.5122, g_a: 0.6833, c_a: 1.9676, d_b: 0.4931, g_b: 0.9997, c_b: 4.7854
Epoch [0/200], Step [200/1067], d_a: 0.4994, g_a: 0.9999, c_a: 2.1278, d_b: 0.4972, g_b: 0.9999, c_b: 3.2650
Epoch [0/200], Step [300/1067], d_a: 0.4999, g_a: 1.0000, c_a: 2.0287, d_b: 0.4997, g_b: 0.0000, c_b: 2.1554
Epoch [0/200], Step [400/1067], d_a: 0.5000, g_a: 0.0000, c_a: 2.3022, d_b: 0.4985, g_b: 1.0000, c_b: 3.2464


RuntimeError: cuda runtime error (6) : the launch timed out and was terminated at /opt/conda/conda-bld/pytorch_1532581333611/work/aten/src/THC/generated/../generic/THCTensorMathPointwise.cu:266