# Monet-ifying Photos with a Basic (Original) GAN in PyTorch

In this notebook we build a basic GAN that stays true to the original paper (Goodfellow et al., 2014). In later notebooks I will compare how improved GAN variants (DCGAN, WGAN, CycleGAN, etc.) stack up on this task.
This is part of a handout I am preparing for a presentation at school.
Please let me know if anything is unclear or if you have ideas for improvements.

In [None]:
#importing relevant packages
import numpy as np
import pandas as pd
import torch
from torch import nn
import os
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
from PIL import Image

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
## for TPU
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [None]:
def show_tensor_images(image_tensor, num_images=5, size=(3, 256, 256), nrow=5):
    '''
    Plot the first `num_images` images of a batch in a uniform grid.

    Parameters:
        image_tensor: a tensor that can be viewed as (-1, *size), e.g. a batch
            of flattened images of shape (batch, 3*256*256)
        num_images: how many images from the front of the batch to show
        size: per-image shape (channels, height, width) used to un-flatten
        nrow: number of images per grid row (default 5, as before)

    Note: imshow clips float RGB values to [0, 1], so images normalized to
    [-1, 1] will look dark unless they are un-normalized first.
    '''
    # Imported locally so the helper also works when it is called before the
    # notebook cell that imports matplotlib has been run.
    import matplotlib.pyplot as plt

    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.figure(figsize=(20, 10))
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
# The basic (fully connected) GAN only works on 1-D vectors, so every
# 3 x 256 x 256 RGB image is flattened to a vector of this length.
flatten_dim = 3 * 256 * 256

# *Generator*


The generator tries to learn the distribution of Monet paintings: given a photo x, it tries to output a plausible Monet painting y.
In other words, it tries to make the distribution of its outputs match the distribution of real Monet paintings as closely as possible.

![What the generator attempts](https://i.imgur.com/t9zb0Cn.png)


In [None]:
def get_generator_block(input_dim, output_dim):
    '''
    Build one hidden block of the generator.

    Parameters:
        input_dim: size of the incoming feature vector, a scalar
        output_dim: size of the outgoing feature vector, a scalar
    Returns:
        an nn.Sequential that applies Linear -> BatchNorm1d -> in-place ReLU
    '''
    linear = nn.Linear(input_dim, output_dim)
    norm = nn.BatchNorm1d(output_dim)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(linear, norm, activation)

In [None]:
class Generator(nn.Module):
    '''
    Generator Class
    A fully connected generator that maps a flattened photo to a flattened
    "Monet" image of the same length.
    Values:
        im_dim: length of the flattened image (defaults to flatten_dim,
            i.e. 3*256*256); the flattened photo plays the role of the noise vector
        hidden_dim: width of the first hidden layer; the following hidden
            layers use 2x, 4x and 8x this width
    '''
    def __init__(self, im_dim=flatten_dim, hidden_dim=128):
        super(Generator, self).__init__()
        # Build the neural network
        self.gen = nn.Sequential(
            # in: flattened image of length im_dim, out: hidden_dim (128 by default)
            get_generator_block(im_dim, hidden_dim),
            # in 128, out 256
            get_generator_block(hidden_dim, hidden_dim * 2),
            # in 256, out 512
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            # in 512, out 1024
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            # in 1024, out: flattened image
            nn.Linear(hidden_dim * 8, im_dim),
            # The real Monets are normalized to [-1, 1] by ImageDataset
            # (Normalize with mean 0.5 and std 0.5), so the generator must
            # produce the same range. A Sigmoid ([0, 1]) would let the
            # discriminator spot fakes just by the absence of negative values.
            nn.Tanh()
        )

    def forward(self, image):
        '''
        Forward pass of the generator: given a batch of flattened photos,
        return a batch of generated (flattened) images.
        Parameters:
            image: a tensor of shape (n_samples, im_dim)
        Returns:
            a tensor of shape (n_samples, im_dim) with values in [-1, 1]
        '''
        return self.gen(image)
    

In [None]:
def get_discriminator_block(input_dim, output_dim):
    '''
    Discriminator Block
    Build one hidden block of the discriminator.
    Parameters:
        input_dim: size of the incoming feature vector, a scalar
        output_dim: size of the outgoing feature vector, a scalar
    Returns:
        an nn.Sequential applying a Linear layer followed by an in-place
        LeakyReLU with negative slope 0.2
        (https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html)
    '''
    layers = [
        nn.Linear(input_dim, output_dim),
        # LeakyReLU keeps a small gradient for negative inputs, which helps
        # avoid "dying" ReLU units.
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    ]
    return nn.Sequential(*layers)

In [None]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    A fully connected network that scores flattened images as real or fake.
    Values:
        im_dim: length of the flattened image input (defaults to flatten_dim)
        hidden_dim: base width; the three hidden blocks use 4x, 2x and 1x this width
    '''
    def __init__(self, im_dim=flatten_dim, hidden_dim=128):
        super(Discriminator, self).__init__()
        widths = (im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        hidden_blocks = [
            get_discriminator_block(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ]
        # The last layer emits a raw logit; there is no Sigmoid here because the
        # loss used in this notebook (BCEWithLogitsLoss) applies it internally.
        self.disc = nn.Sequential(*hidden_blocks, nn.Linear(hidden_dim, 1))

    def forward(self, image):
        '''
        Forward pass of the discriminator.
        Parameters:
            image: a batch of flattened images, shape (n_samples, im_dim)
        Returns:
            a tensor of shape (n_samples, 1) holding one real/fake logit per image
        '''
        return self.disc(image)


In [None]:
# Training hyper-parameters
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid to the discriminator's logits itself
n_epochs = 700
display_step = 500    # report losses and show samples every this many steps
batch_size = 128
lr = 1e-5             # learning rate for both the generator and discriminator optimizers

In [None]:
#taken from https://www.kaggle.com/nachiket273/cyclegan-pytorch by @NACHIKET273
#changed a little for understandability
#creates a dataset that yields unpaired (photo, monet) image tensors
class ImageDataset(Dataset):
    '''
    Unpaired dataset of (photo, Monet) image tensors.

    Item `idx` is the idx-th Monet painting paired with a photo drawn
    uniformly at random from the *whole* photo folder. Over many epochs the
    generator can therefore see every photo, not just a fixed subset.

    Parameters:
        monet_dir: folder containing the Monet images
        photo_dir: folder containing the photos
        normalize: if True, scale pixel values from [0, 1] to [-1, 1]
    '''
    def __init__(self, monet_dir, photo_dir, normalize=True):
        super().__init__()
        #folder with monets
        self.monet_dir = monet_dir
        #folder with photos
        self.photo_dir = photo_dir
        steps = [transforms.ToTensor()]
        if normalize:
            # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1] per channel
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)
        #index -> file name for every monet and every photo
        self.monet_idx = dict(enumerate(os.listdir(self.monet_dir)))
        self.photo_idx = dict(enumerate(os.listdir(self.photo_dir)))

    def _load(self, folder, name):
        '''Open one image, force 3-channel RGB, transform it, and close the file.'''
        with Image.open(os.path.join(folder, name)) as img:
            # convert('RGB') guarantees 3 channels for the 3-channel Normalize
            return self.transform(img.convert('RGB'))

    def __getitem__(self, idx):
        # Draw the photo index from the number of *photos* (not Monets) so that
        # every photo in the folder can be sampled.
        rand_idx = np.random.randint(len(self.photo_idx))
        photo_img = self._load(self.photo_dir, self.photo_idx[rand_idx])
        monet_img = self._load(self.monet_dir, self.monet_idx[idx])
        return photo_img, monet_img

    def __len__(self):
        # One epoch walks over the smaller of the two folders (the Monets here).
        return min(len(self.monet_idx), len(self.photo_idx))
    
    
class PhotoDataset(Dataset):
    '''
    Photos-only dataset, used at inference time to run the generator over
    every photo in `photo_dir` (in os.listdir order).

    Parameters:
        photo_dir: folder containing the photos
        size: (height, width) every photo is resized to
        normalize: if True, scale pixel values from [0, 1] to [-1, 1]
    '''
    def __init__(self, photo_dir, size=(256, 256), normalize=True):
        super().__init__()
        self.photo_dir = photo_dir
        steps = [transforms.Resize(size), transforms.ToTensor()]
        if normalize:
            # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1] per channel
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(steps)
        # index -> file name
        self.photo_idx = dict(enumerate(os.listdir(self.photo_dir)))

    def __getitem__(self, idx):
        photo_path = os.path.join(self.photo_dir, self.photo_idx[idx])
        # Close the file handle promptly and force 3-channel RGB so the
        # 3-channel Normalize always applies (same handling as ImageDataset).
        with Image.open(photo_path) as photo_img:
            return self.transform(photo_img.convert('RGB'))

    def __len__(self):
        return len(self.photo_idx)

In [None]:
#create dataset and dataloader to feed to GAN
# shuffle=True so the Monet paintings are regrouped into different batches
# every epoch; without it the same Monets land in the same batches each epoch.
img_ds = ImageDataset('../input/gan-getting-started/monet_jpg/', '../input/gan-getting-started/photo_jpg/')
dataloader = DataLoader(img_ds, batch_size=batch_size, shuffle=True, pin_memory=True)

In [None]:
# Generator (flattened photo -> flattened painting) and its Adam optimizer
gen = Generator(im_dim=flatten_dim, hidden_dim=128).to(device)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr)
# Optional plateau scheduler, currently disabled:
# gen_rlr = torch.optim.lr_scheduler.ReduceLROnPlateau(gen_opt, mode='min')

# Discriminator (flattened image -> one logit) and its Adam optimizer
disc = Discriminator(im_dim=flatten_dim, hidden_dim=128).to(device)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr)
# Optional plateau scheduler, currently disabled:
# disc_rlr = torch.optim.lr_scheduler.ReduceLROnPlateau(disc_opt, mode='min')

In [None]:
def get_disc_loss(gen, disc, criterion, photo, num_images, monet, device):
    '''
    Compute the discriminator loss for one batch.
    Parameters:
        gen: the generator model; maps a batch of photos to fake images
        disc: the discriminator model; returns one real/fake logit per image
        criterion: loss comparing the discriminator's logits with the targets
               (fake = 0, real = 1), e.g. nn.BCEWithLogitsLoss
        photo: batch of (flattened) photos fed to the generator
        num_images: batch size; not used, kept for interface compatibility
        monet: batch of real (flattened) Monet images
        device: the device type; not used, kept for interface compatibility
    Returns:
        disc_loss: a scalar tensor, the mean of the fake and real losses
    '''
    # Detach the fakes so this loss only back-propagates into the discriminator.
    generated = gen(photo).detach()
    fake_logits = disc(generated)
    fake_loss = criterion(fake_logits, torch.zeros_like(fake_logits))
    real_logits = disc(monet)
    real_loss = criterion(real_logits, torch.ones_like(real_logits))
    return 0.5 * (fake_loss + real_loss)

In [None]:
def get_gen_loss(gen, disc, criterion, num_images, photos, device):
    '''
    Compute the generator loss for one batch.
    Parameters:
        gen: the generator model; maps a batch of photos to fake images
        disc: the discriminator model; returns one real/fake logit per image
        criterion: loss comparing the discriminator's logits with the targets
               (fake = 0, real = 1), e.g. nn.BCEWithLogitsLoss
        num_images: batch size; not used, kept for interface compatibility
        photos: batch of (flattened) photos fed to the generator
        device: the device type; not used, kept for interface compatibility
    Returns:
        gen_loss: a scalar tensor for the current batch
    '''
    logits = disc(gen(photos))
    # The generator is rewarded when its fakes are classified as real (target 1).
    return criterion(logits, torch.ones_like(logits))

In [None]:
import matplotlib.pyplot as plt

# Sanity-check the data pipeline by showing one Monet from the first batch.
photo, monet = next(iter(dataloader))
# ImageDataset normalizes to [-1, 1]; undo it (x * 0.5 + 0.5) so imshow, which
# clips float RGB values to [0, 1], shows the real colors.
plt.imshow((monet[0] * 0.5 + 0.5).permute(1, 2, 0).numpy())

In [None]:
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
# Make sure BatchNorm uses batch statistics, even if this cell is re-run after
# the inference cell has switched the generator to eval mode.
gen.train()
for epoch in range(n_epochs):

    # Dataloader returns (photo, monet) batches
    for photo, monet in dataloader:
        cur_batch_size = len(photo)

        # Flatten each batch to (batch, 3*256*256) for the fully connected nets
        photo = photo.view(cur_batch_size, -1).to(device)
        monet = monet.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(gen, disc, criterion, photo, cur_batch_size, monet, device)
        # retain_graph is not needed: get_disc_loss detaches the fakes, and the
        # generator loss below builds its own graph.
        disc_loss.backward()
        disc_opt.step()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, photo, device)
        # This backward also writes gradients into disc's parameters; they are
        # cleared by disc_opt.zero_grad() before the next discriminator step.
        gen_loss.backward()
        gen_opt.step()

        # Running averages over the current window of display_step updates
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step
        # Count the step before checking, so each report averages exactly
        # display_step updates.
        cur_step += 1

        #show images every display_step
        if cur_step % display_step == 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            # Sample in eval mode without gradients, so this extra forward pass
            # neither builds a graph nor updates the BatchNorm running stats.
            gen.eval()
            with torch.no_grad():
                fake = gen(photo)
            gen.train()
            show_tensor_images(fake)
            show_tensor_images(photo)
            mean_generator_loss = 0
            mean_discriminator_loss = 0

In [None]:
# Photos-only data for inference: one photo per batch, no shuffling (os.listdir order).
photo_dataset = PhotoDataset(photo_dir='../input/gan-getting-started/photo_jpg/')
dataloader = DataLoader(dataset=photo_dataset, batch_size=1, pin_memory=True)

In [None]:
# Output folder for the generated paintings; exist_ok makes re-running the cell safe.
os.makedirs('../images', exist_ok=True)

In [None]:
# Inspect the contents of the current working directory
os.listdir('.')

In [None]:
def unnorm(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    '''
    Undo transforms.Normalize in place: x -> x * std + mean.

    The tensor is iterated along its first dimension, pairing the i-th slice
    with mean[i] and std[i] (for a (C, H, W) image that is one pair per
    channel). Slices beyond len(mean) / len(std) are left untouched.

    Parameters:
        img: tensor to un-normalize; modified in place
        mean: means used by the original normalization
        std: standard deviations used by the original normalization
    Returns:
        img itself, after modification
    '''
    for t, m, s in zip(img, mean, std):
        # Rescale by the std, then add back the mean (not the std).
        t.mul_(s).add_(m)
    return img

In [None]:
# Tensor -> PIL image converter (mode inferred from the data)
topil = transforms.ToPILImage(mode=None)

In [None]:
# Run the trained generator over every photo and save each result as a JPEG.
t = tqdm(dataloader, leave=False, total=len(dataloader))
gen.eval()
for i, photo in enumerate(t):
    with torch.no_grad():
        # Flatten the single photo to (1, 3*256*256) for the fully connected generator.
        photo = photo.view(1, -1).to(device)
        pred_monet = gen(photo)
    # The generator was trained against Monets normalized to [-1, 1]
    # (ImageDataset's Normalize with mean/std 0.5), so map its output back to [0, 1].
    pred_monet = unnorm(pred_monet)
    # Clamp so ToPILImage's float -> uint8 conversion never sees values outside [0, 1].
    pred_monet = torch.reshape(pred_monet, (3, 256, 256)).clamp_(0, 1)
    img = topil(pred_monet).convert("RGB")
    img.save(os.path.join("../images", f"{i + 1}.jpg"))

In [None]:
# Convert the last generated painting left over from the inference loop into a PIL image for a closer look.
b = topil(pred_monet)

In [None]:
# Shape of the PIL image as an array, (height, width, channels)
np.asarray(b).shape

In [None]:
# Display the converted PIL image inline.
plt.imshow(b)

In [None]:
import shutil

# Zip all generated images into /kaggle/working/images.zip.
# NOTE(review): "/kaggle/images" equals "../images" only when the working
# directory is /kaggle/working (as on Kaggle) — confirm when running elsewhere.
shutil.make_archive(base_name="/kaggle/working/images", format='zip', root_dir="/kaggle/images")

In [None]:
# Save the trained weights (state dicts only) so the models can be reloaded later
torch.save(obj=gen.state_dict(), f='generator')
torch.save(obj=disc.state_dict(), f='discriminator')