## Main Reference
1. [PyTorch-GAN | GitHub/eriklindernoren | Collection of PyTorch implementations of GANs](https://github.com/eriklindernoren/PyTorch-GAN)
2. [CycleGAN | Github/junyanz | Torch implementation for learning an image-to-image translation without input-output pairs](https://github.com/junyanz/CycleGAN)

## Index
```
Step 1. Import Libraries
Step 2. Initial Setting
Step 3. Define Generator
Step 4. Define Discriminator
Step 5. Define Loss Function
Step 6. Initialize Generator and Discriminator
Step 7. GPU Setting
Step 8. Weight Setting
Step 9. Configure Optimizer
Step 10. Learning Rate Scheduler Setting
Step 11. Image Transformation Setting
Step 12. DataLoader Setting
Step 13. Define function to get sample images
Step 14. Training
```
---

### Step 1. Import Libraries

In [None]:
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

### Step 2. Initial Setting

In [None]:
n_cpu = 2  # number of DataLoader worker processes used to load batches in parallel

In [None]:
# data (path)
dataset_name = 'gan-getting-started'
root = '../input/'+dataset_name  # directory containing monet_jpg/ and photo_jpg/

# data (img)
img_height = 256
img_width = 256
channels = 3

# training
epoch = 0  # epoch to start training from
n_epochs = 30  # number of epochs of training
batch_size = 1  # size of the batches
lr = 0.0002  # adam: learning rate
b1 = 0.5  # adam: decay of first-order momentum of gradient
b2 = 0.999  # adam: decay of second-order momentum of gradient
decay_epoch = 3  # epoch from which to start lr decay (suggested default: 100 when n_epochs is 200)


### Step 3. Define Generator

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflect-padded 3x3 convolutions plus an identity skip.

    Channel count and spatial size are preserved, so the block's output can be
    added element-wise to its input.

    Args:
        in_features: number of input (and output) feature maps.
    """

    def __init__(self, in_features):
        super().__init__()
        width = in_features
        # First conv -> norm -> ReLU; padding by 1 keeps H and W for a 3x3 kernel.
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(width, width, 3),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]
        # Second conv -> norm, no activation before the skip addition.
        layers += [
            nn.ReflectionPad2d(1),
            nn.Conv2d(width, width, 3),
            nn.InstanceNorm2d(width),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class GeneratorResNet(nn.Module):
    """ResNet-style CycleGAN generator translating images between two domains.

    Layout: 7x7 stem conv (64 maps) -> two stride-2 downsampling convs
    (128, 256 maps) -> ``num_residual_block`` ResidualBlocks -> two
    (upsample x2 + 3x3 conv) stages back to 64 maps -> 7x7 output conv with
    ``tanh``, so outputs lie in [-1, 1].

    Args:
        input_shape: ``(channels, height, width)``; only ``channels`` is used.
        num_residual_block: number of ResidualBlocks in the bottleneck.

    For height and width divisible by 4 the output has the input's shape.
    """

    def __init__(self, input_shape, num_residual_block):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]
        # The 7x7 stem/output convs need kernel // 2 = 3 pixels of reflection
        # padding to preserve H and W; this depends on the kernel size, not on
        # the number of image channels.
        edge_kernel = 7
        edge_pad = edge_kernel // 2

        # Initial Convolution Block
        out_features = 64
        model = [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(channels, out_features, edge_kernel),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling: each stage halves H and W and doubles the feature maps
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks (shape unchanged)
        for _ in range(num_residual_block):
            model += [ResidualBlock(out_features)]

        # Upsampling: each stage doubles H and W and halves the feature maps.
        # NOTE(review): reference CycleGAN generators also apply InstanceNorm2d
        # after these convs; inserting it would shift the nn.Sequential indices
        # (state_dict keys), so it is intentionally not added here.
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),  # --> width*2, height*2
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output Layer: back to `channels` maps, squashed into [-1, 1]
        model += [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(out_features, channels, edge_kernel),
            nn.Tanh(),
        ]

        # Unpacking
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

### Step 4. Define Discriminator

In [None]:
class Discriminator(nn.Module):
    """PatchGAN discriminator producing a grid of real/fake scores.

    Four stride-2 conv stages (64, 128, 256, 512 maps; no norm on the first)
    with LeakyReLU(0.2), then a zero-padded 4x4 conv down to one score map.

    Args:
        input_shape: ``(channels, height, width)`` of the images to score.

    Attributes:
        output_shape: expected ``(1, height // 16, width // 16)`` score-map shape.
    """

    def __init__(self, input_shape):
        super().__init__()

        channels, height, width = input_shape

        # Four stride-2 stages divide each side by 2**4; the final padded conv
        # keeps that size.
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        # (in maps, out maps, apply InstanceNorm?) for each downsampling stage
        stage_specs = [
            (channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]
        layers = []
        for n_in, n_out, use_norm in stage_specs:
            layers.append(nn.Conv2d(n_in, n_out, 4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

### Step 5. Define Loss Function

In [None]:
# Adversarial loss: mean squared error against all-ones / all-zeros targets.
criterion_GAN = nn.MSELoss()
# Cycle-consistency and identity losses: mean absolute (L1) pixel difference.
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

### Step 6. Initialize Generator and Discriminator

In [None]:
# (C, H, W) shared by every network -- (3, 256, 256) with the Step 2 settings.
input_shape = (channels, img_height, img_width)
n_residual_blocks = 9  # suggested default, number of residual blocks in generator

# One generator per translation direction, one discriminator per domain.
Generator_monet_to_photo = GeneratorResNet(input_shape=input_shape, num_residual_block=n_residual_blocks)
Generator_photo_to_monet = GeneratorResNet(input_shape=input_shape, num_residual_block=n_residual_blocks)
Discriminator_monet = Discriminator(input_shape=input_shape)
Discriminator_photo = Discriminator(input_shape=input_shape)

### Step 7. GPU Setting

In [None]:
cuda = torch.cuda.is_available()

# With a GPU, move the four networks (and the parameter-free loss modules)
# to the default CUDA device; otherwise everything stays on the CPU.
if cuda:
    Generator_monet_to_photo, Generator_photo_to_monet = (
        Generator_monet_to_photo.cuda(),
        Generator_photo_to_monet.cuda(),
    )
    Discriminator_monet, Discriminator_photo = (
        Discriminator_monet.cuda(),
        Discriminator_photo.cuda(),
    )

    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

### Step 8. Weight Setting

In [None]:
def weights_init_normal(m):
    """Re-initialise one module in place; meant for ``net.apply(weights_init_normal)``.

    - Modules whose class name contains ``'Conv'``: weight ~ N(0, 0.02),
      bias (if any) = 0.
    - ``BatchNorm2d`` modules: weight ~ N(1, 0.02), bias = 0 (skipped for
      ``affine=False`` layers, which have no weight/bias).
    - Every other module is left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)  # reset Conv weight with a Gaussian
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)  # reset Conv bias to 0
    # Sibling of the Conv check (not nested in it) so BatchNorm layers are reached.
    elif classname.find('BatchNorm2d') != -1:
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)  # reset BatchNorm2d weight around 1
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)  # reset BatchNorm2d bias to 0

In [None]:
# Re-initialise every Conv layer of all four networks (N(0, 0.02) weights,
# zero biases) via weights_init_normal; .apply visits each submodule recursively.
Generator_monet_to_photo.apply(weights_init_normal)
Generator_photo_to_monet.apply(weights_init_normal)
Discriminator_monet.apply(weights_init_normal)
Discriminator_photo.apply(weights_init_normal)

In [None]:
def temp_weights_init_normal(m):
    """Debug helper for ``module.apply``: print the class name of each visited module."""
    print(m.__class__.__name__)

In [None]:
# Debug: print the class name of every submodule of one generator. The trailing
# ';' suppresses the notebook's echo of the returned module.
Generator_monet_to_photo.apply(temp_weights_init_normal);

### Step 9. Configure Optimizers

In [None]:
import itertools

# A single Adam optimizer updates both generators (their parameters are
# chained into one iterable); each discriminator gets its own Adam.
# All three use lr, b1 and b2 from Step 2.
optimizer_G = torch.optim.Adam(
    itertools.chain(
        Generator_monet_to_photo.parameters(),
        Generator_photo_to_monet.parameters(),
    ),
    lr=lr,
    betas=(b1, b2),
)

optimizer_Discriminator_monet = torch.optim.Adam(Discriminator_monet.parameters(), lr=lr, betas=(b1, b2))
optimizer_Discriminator_photo = torch.optim.Adam(Discriminator_photo.parameters(), lr=lr, betas=(b1, b2))

### Step 10. Learning Rate Scheduler Setting

In [None]:
class LambdaLR:
    """Learning-rate multiplier: 1.0 at first, then a linear ramp down to 0.

    Intended to be passed as ``lr_lambda=LambdaLR(...).step`` to
    ``torch.optim.lr_scheduler.LambdaLR``. The factor stays at 1.0 until
    ``epoch + offset`` reaches ``decay_start_epoch``, then decreases linearly
    and reaches 0.0 when ``epoch + offset == n_epochs``.

    Args:
        n_epochs: total number of training epochs.
        offset: epoch the run starts from (non-zero when resuming).
        decay_start_epoch: epoch at which the linear decay begins; must be
            smaller than ``n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_length = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_length

In [None]:
# One scheduler per optimizer. Each scales its optimizer's base lr by
# LambdaLR(n_epochs, epoch, decay_epoch).step(...): 1.0 until decay_epoch,
# then a linear ramp down towards 0 at n_epochs.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_Discriminator_monet = torch.optim.lr_scheduler.LambdaLR(
    optimizer_Discriminator_monet, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_Discriminator_photo = torch.optim.lr_scheduler.LambdaLR(
    optimizer_Discriminator_photo, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

### Step 11. Image Transformation Setting

In [None]:
from PIL import Image
import torchvision.transforms as transforms

# Training-time preprocessing, applied to every Monet and photo image:
#   shorter side resized to int(256 * 1.12) = 286 (bicubic), random 256x256
#   crop, random horizontal flip, then to a tensor normalized to [-1, 1]
#   (matching the generator's tanh output range).
transforms_ = [
    transforms.Resize(int(img_height*1.12), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

### Step 12. DataLoader Setting

In [None]:
def to_rgb(image):
    """Return a new RGB image of the same size with *image* pasted onto it.

    PIL converts the pasted image to the canvas mode, which turns
    grayscale/palette/etc. inputs into 3-channel RGB.
    """
    canvas = Image.new(mode="RGB", size=image.size)
    canvas.paste(image)
    return canvas

In [None]:
import os
import glob

In [None]:
# Sanity check: location of the Monet image folder.
print(root+'/monet_jpg')

In [None]:
# Number of files in the Monet folder (shown as the cell output).
len(glob.glob(os.path.join(root+'/monet_jpg')+'/*.*'))

In [None]:
# Number of files in the photo folder (shown as the cell output).
len(glob.glob(os.path.join(root+'/photo_jpg')+'/*.*'))

In [None]:
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """Unpaired Monet-painting (A) / photo (B) dataset for CycleGAN.

    Files come from ``<root>/monet_jpg`` and ``<root>/photo_jpg`` and are
    sorted by path *before* the split is taken, so the split is reproducible:

    * ``'train'``: first 25 Monet files, first 250 photos
    * ``'test'``:  remaining Monet files, photos at sorted positions 250-300
    * ``'all'``:   every photo and no Monet files (``files_Monet is None``);
      items then only carry the ``'B'`` key.

    Args:
        root: dataset directory.
        transforms_: list of torchvision transforms, composed and applied to
            every image.
        unaligned: if True, pair each Monet image with a uniformly random
            photo instead of the photo at the same index.
        mode: ``'train'``, ``'test'`` or ``'all'``.

    Raises:
        ValueError: for any other ``mode``.
    """

    # (Monet slice, photo slice) taken from the sorted file lists.
    _SPLITS = {
        'train': (slice(None, 25), slice(None, 250)),
        'test': (slice(25, None), slice(250, 301)),
    }

    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        self.mode = mode
        photo_files = self._list_images(root, 'photo_jpg')
        if mode == 'all':
            self.files_Monet = None
            self.files_Photo = photo_files
        elif mode in self._SPLITS:
            monet_slice, photo_slice = self._SPLITS[mode]
            self.files_Monet = self._list_images(root, 'monet_jpg')[monet_slice]
            self.files_Photo = photo_files[photo_slice]
        else:
            raise ValueError(f"mode must be 'train', 'test' or 'all', got {mode!r}")

    @staticmethod
    def _list_images(root, subdir):
        """Sorted paths of every file directly inside ``root/subdir``."""
        return sorted(glob.glob(os.path.join(root, subdir, '*.*')))

    @staticmethod
    def _load_rgb(path):
        """Load *path* as an in-memory RGB image; the file is closed on return."""
        with Image.open(path) as image:
            if image.mode != 'RGB':
                return to_rgb(image)
            return image.copy()

    def __getitem__(self, index):
        n_photo = len(self.files_Photo)
        if self.files_Monet is None:
            item_B = self.transform(self._load_rgb(self.files_Photo[index % n_photo]))
            return {'B': item_B}

        image_Monet = self._load_rgb(self.files_Monet[index % len(self.files_Monet)])
        if self.unaligned:
            # randint's upper bound is exclusive, so every photo can be drawn
            photo_index = np.random.randint(0, n_photo)
        else:
            photo_index = index % n_photo
        image_Photo = self._load_rgb(self.files_Photo[photo_index])

        item_A = self.transform(image_Monet)
        item_B = self.transform(image_Photo)
        return {'A': item_A, 'B': item_B}

    def __len__(self):
        if self.mode == 'all':
            return len(self.files_Photo)
        return max(len(self.files_Monet), len(self.files_Photo))
            

In [None]:
# Training batches: random Monet/photo pairs, batch size from Step 2.
dataloader = DataLoader(
    ImageDataset(root, transforms_=transforms_, unaligned=True),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

# Validation batches of 5 match the nrow=5 grids built in sample_images().
val_dataloader = DataLoader(
    ImageDataset(root, transforms_=transforms_, unaligned=True, mode='test'),
    batch_size=5,
    shuffle=True,
    num_workers=n_cpu,
)

### Step 13. Define function to get sample images

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Legacy tensor type: CUDA float tensors when a GPU is available, CPU tensors
# otherwise. Used below both as a cast target (x.type(Tensor)) and as a
# constructor (Tensor(np.ones(...))).
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

In [None]:
def sample_images():
    """Show one validation batch translated in both directions.

    Draws a fresh batch from ``val_dataloader``, runs both generators in eval
    mode without autograd, and plots four rows: real Monet (A), Monet->photo
    (fake B), real photo (B), photo->Monet (fake A). The generators'
    previous train/eval mode is restored afterwards.
    """
    imgs = next(iter(val_dataloader))
    generators = (Generator_monet_to_photo, Generator_photo_to_monet)
    was_training = [g.training for g in generators]
    for g in generators:
        g.eval()
    try:
        # no_grad: inference only, so no autograd graph is built
        with torch.no_grad():
            real_Monet = imgs['A'].type(Tensor)  # A : monet
            fake_Photo = Generator_monet_to_photo(real_Monet)
            real_Photo = imgs['B'].type(Tensor)  # B : photo
            fake_Monet = Generator_photo_to_monet(real_Photo)
    finally:
        for g, mode in zip(generators, was_training):
            g.train(mode)
    # Arrange each batch along the x-axis (one row per tensor)...
    rows = [
        make_grid(t, nrow=5, normalize=True)
        for t in (real_Monet, fake_Photo, real_Photo, fake_Monet)
    ]
    # ...and stack the rows along the y-axis
    image_grid = torch.cat(rows, 1)
    plt.imshow(image_grid.cpu().permute(1, 2, 0))
    plt.title('Real A vs Fake B | Real B vs Fake A')
    plt.axis('off')
    plt.show()

> Test code: inspect one validation batch and the generators' outputs step by step

In [None]:
# One validation batch, a dict with 'A' (Monet) and 'B' (photo) tensors.
temp_imgs = next(iter(val_dataloader))

In [None]:
# Put both generators in eval mode before the scratch forward pass below.
Generator_monet_to_photo.eval() # eval mode
Generator_photo_to_monet.eval() # eval mode
print(temp_imgs['A'].shape)  # Monet batch
print(temp_imgs['B'].shape)  # photo batch


In [None]:
# Translate the batch in both directions; .detach() drops the autograd graph from the outputs.
temp_real_Monet = temp_imgs['A'].type(Tensor) # A : monet
temp_fake_Photo = Generator_monet_to_photo(temp_real_Monet).detach()
temp_real_Photo = temp_imgs['B'].type(Tensor) # B : photo
temp_fake_Monet = Generator_photo_to_monet(temp_real_Photo).detach()

In [None]:
# Compare the input and generated batch shapes.
print(temp_real_Monet.shape)
print(temp_fake_Photo.shape)
print(temp_real_Photo.shape)
print(temp_fake_Monet.shape)

In [None]:
# Tile each batch into one image row (nrow=5); normalize=True rescales the values to [0, 1] for display.
temp_real_Monet = make_grid(temp_real_Monet, nrow=5, normalize=True)
temp_real_Photo = make_grid(temp_real_Photo, nrow=5, normalize=True)
temp_fake_Monet = make_grid(temp_fake_Monet, nrow=5, normalize=True)
temp_fake_Photo = make_grid(temp_fake_Photo, nrow=5, normalize=True)

In [None]:
# make_grid returns a torch.Tensor
type(temp_real_Monet)

In [None]:
# A single row of five tiles is wide and short; a 100x100-inch canvas would be
# ~10^8 pixels at matplotlib's default 100 dpi.
plt.figure(figsize=[20, 4])
plt.imshow(temp_real_Monet.cpu().permute(1,2,0))
plt.title('Real Monet')
plt.axis('off');

In [None]:
# Shapes after tiling: each variable now holds a single grid image.
print(temp_real_Monet.shape)
print(temp_fake_Photo.shape)
print(temp_real_Photo.shape)
print(temp_fake_Monet.shape)

In [None]:
# Stack the rows top-to-bottom as real A | fake B | real B | fake A, matching
# the plot title below and the row order used by sample_images().
temp_image_grid = torch.cat((temp_real_Monet, temp_fake_Photo, temp_real_Photo, temp_fake_Monet), 1)
print(temp_image_grid.shape)

In [None]:
# Channels-last (H, W, C) layout, as plt.imshow expects.
temp_image_grid.cpu().permute(1,2,0).shape

In [None]:
# Four stacked rows of five tiles; a 100x100-inch canvas would be ~10^8 pixels
# at matplotlib's default 100 dpi.
plt.figure(figsize=[20, 16])
plt.imshow(temp_image_grid.cpu().permute(1,2,0))
plt.title('Real A | Fake B | Real B | Fake A ')
plt.axis('off');

### Step 14. Training

In [None]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import warnings

In [None]:
loss_vals = []  # mean generator loss of each epoch (plotted after training)
for epoch in range(epoch, n_epochs):
    epoch_loss = []  # generator loss of every batch in this epoch
    for i, batch in enumerate(tqdm(dataloader)):

        # Set model input
        real_Monet = batch['A'].type(Tensor)
        real_Photo = batch['B'].type(Tensor)

        # Adversarial ground truths: one target per PatchGAN output cell
        valid = Tensor(np.ones((real_Monet.size(0), *Discriminator_monet.output_shape)))  # requires_grad = False. Default.
        fake = Tensor(np.zeros((real_Monet.size(0), *Discriminator_monet.output_shape)))  # requires_grad = False. Default.

        # -----------------
        # Train Generators
        # -----------------
        Generator_monet_to_photo.train()  # train mode
        Generator_photo_to_monet.train()  # train mode

        optimizer_G.zero_grad()  # one optimizer updates both generators

        # Identity loss: a generator fed an image that is already in its
        # target domain should return it unchanged.
        loss_Identity_Monet = criterion_identity(Generator_photo_to_monet(real_Monet), real_Monet)
        loss_Identity_Photo = criterion_identity(Generator_monet_to_photo(real_Photo), real_Photo)
        loss_identity = (loss_Identity_Monet + loss_Identity_Photo) / 2

        # GAN loss: train the generators to make the discriminators score
        # translated images as real.
        fake_Photo = Generator_monet_to_photo(real_Monet)  # photo generated from a real Monet
        loss_GAN_M2P = criterion_GAN(Discriminator_photo(fake_Photo), valid)
        fake_Monet = Generator_photo_to_monet(real_Photo)  # Monet generated from a real photo
        loss_GAN_P2M = criterion_GAN(Discriminator_monet(fake_Monet), valid)
        loss_GAN = (loss_GAN_M2P + loss_GAN_P2M) / 2

        # Cycle loss: translating there and back should recover the input
        recov_A = Generator_photo_to_monet(fake_Photo)
        loss_cycle_Monet = criterion_cycle(recov_A, real_Monet)
        recov_B = Generator_monet_to_photo(fake_Monet)
        loss_cycle_Photo = criterion_cycle(recov_B, real_Photo)
        loss_cycle = (loss_cycle_Monet + loss_cycle_Photo) / 2

        # Total generator loss (cycle weight 10, identity weight 5)
        loss_G = loss_GAN + (10.0 * loss_cycle) + (5.0 * loss_identity)

        loss_G.backward()
        epoch_loss.append(loss_G.item())
        optimizer_G.step()

        # -----------------
        # Train Discriminator Monet
        # -----------------
        optimizer_Discriminator_monet.zero_grad()

        loss_real = criterion_GAN(Discriminator_monet(real_Monet), valid)  # real images -> real
        # .detach(): this loss must not send gradients into the generator
        loss_fake = criterion_GAN(Discriminator_monet(fake_Monet.detach()), fake)  # fake images -> fake

        loss_Discriminator_monet = (loss_real + loss_fake) / 2

        loss_Discriminator_monet.backward()
        optimizer_Discriminator_monet.step()

        # -----------------
        # Train Discriminator Photo
        # -----------------
        optimizer_Discriminator_photo.zero_grad()

        loss_real = criterion_GAN(Discriminator_photo(real_Photo), valid)  # real images -> real
        loss_fake = criterion_GAN(Discriminator_photo(fake_Photo.detach()), fake)  # fake images -> fake

        loss_Discriminator_photo = (loss_real + loss_fake) / 2

        loss_Discriminator_photo.backward()
        optimizer_Discriminator_photo.step()

        # Total discriminator loss (only used for the progress printout)
        loss_D = (loss_Discriminator_monet + loss_Discriminator_photo) / 2

        # -----------------
        # Show Progress
        # -----------------
        if (i + 1) % 50 == 0:
            sample_images()
            print('[Epoch %d/%d] [Batch %d/%d] [D loss : %f] [G loss : %f - (adv : %f, cycle : %f, identity : %f)]'
                  % (epoch + 1, n_epochs,     # [Epoch -]
                     i + 1, len(dataloader),  # [Batch -]
                     loss_D.item(),           # [D loss -]
                     loss_G.item(),           # [G loss -]
                     loss_GAN.item(),         # [adv -]
                     loss_cycle.item(),       # [cycle -]
                     loss_identity.item(),    # [identity -]
                     ))
    loss_vals.append(sum(epoch_loss) / len(epoch_loss))
    print(loss_vals)

    # Advance the Step 10 linear-decay schedules once per epoch (after the
    # optimizer steps); without these calls the learning rate never decays.
    lr_scheduler_G.step()
    lr_scheduler_Discriminator_monet.step()
    lr_scheduler_Discriminator_photo.step()


In [None]:
def my_plot(epochs, loss, title):
    """Draw *loss* against *epochs* on the current matplotlib axes, titled *title*."""
    axes = plt.gca()
    axes.set_title(title)
    axes.plot(epochs, loss)

In [None]:
# show Generators Loss
# loss_vals has one entry per epoch run by the training loop, which ends at
# n_epochs (assumes the loop ran to completion; it may have started from a
# non-zero epoch), so label exactly those 1-based epochs.
my_plot(np.arange(n_epochs - len(loss_vals) + 1, n_epochs + 1), loss_vals, "Generators Loss")

# **Submission**

In [None]:
# Directory for the generated submission images; exist_ok makes re-running
# this cell harmless (a repeated `mkdir` would fail instead).
os.makedirs('../images', exist_ok=True)

In [None]:
# Show the working directory; the relative ../images path used below resolves against it.
!pwd

In [None]:
def reverse_normalize(image, mean_=0.5, std_=0.5):
    """Undo ``Normalize(mean_, std_)`` and convert to 8-bit pixel values.

    Args:
        image: numpy array or torch tensor (CPU or CUDA) of normalized pixel
            values, e.g. generator output in [-1, 1] for the default mean/std.
        mean_: mean used by the forward normalization.
        std_: standard deviation used by the forward normalization.

    Returns:
        ``np.uint8`` array of the same shape with values in [0, 255]; values
        outside that range are clipped rather than wrapped.
    """
    if torch.is_tensor(image):
        # .cpu() so CUDA tensors can be converted to numpy as well
        image = image.detach().cpu().numpy()
    un_normalized_img = image * std_ + mean_
    un_normalized_img = un_normalized_img * 255
    # Clip before the cast: casting out-of-range floats to uint8 does not saturate.
    return np.uint8(np.clip(un_normalized_img, 0, 255))

In [None]:
Generator_photo_to_monet.eval() # eval mode for inference

mean_ = 0.5
std_ = 0.5

# Deterministic preprocessing for inference. The training `transforms_` apply a
# random crop and a random horizontal flip, which would make every submitted
# painting a randomly cropped / mirrored version of its source photo.
submit_transforms = [
    transforms.Resize((img_height, img_width), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Data loader over every photo, in sorted order, for the final translation / submission
submit_dataloader = DataLoader(
    ImageDataset(root, transforms_=submit_transforms, unaligned=False, mode='all'),
    batch_size=1,
    shuffle=False,
    num_workers=n_cpu,
)
print(len(submit_dataloader))

# Inference only: no autograd graph is needed for any image.
with torch.no_grad():
    for image_idx, fixed_X in enumerate(submit_dataloader):
        real_Photo = fixed_X['B'].type(Tensor) # B : photo
        fake_Monet_monet = Generator_photo_to_monet(real_Photo)
        # (1, C, H, W) tanh output in [-1, 1] -> (H, W, C) uint8 in [0, 255]
        image1_numpy = fake_Monet_monet.cpu().numpy()
        image1_numpy = np.squeeze(image1_numpy, axis=0).transpose(1, 2, 0)
        image1_numpy = reverse_normalize(image1_numpy, mean_, std_)
        # Save picture
        Image.fromarray(image1_numpy).save("../images/" + str(image_idx) + ".jpg")

    
    

In [None]:
import shutil

# Zip the generated images into /kaggle/working/images.zip.
# NOTE(review): assumes the notebook runs from /kaggle/working, so the images
# written to ../images live in /kaggle/images -- confirm on other environments.
shutil.make_archive(base_name="/kaggle/working/images", format='zip', root_dir="/kaggle/images")