# Creating Monet Paintings using GANs and PyTorch
In this notebook I will create a simple GAN to generate Monet-style paintings.

In [None]:
# Expose the Weights & Biases API key to wandb through the environment.
# The key itself is read from the Kaggle secrets store, never hard-coded.
from kaggle_secrets import UserSecretsClient
import os

secrets_client = UserSecretsClient()
os.environ["WANDB_API_KEY"] = secrets_client.get_secret("WANDB_API_KEY")

In [None]:
import pandas as pd

import random
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import wandb
import torch
from torch import nn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import tqdm

In [None]:
# Folder with the Monet training pictures
training_pictures = '../input/gan-getting-started/monet_jpg/'

# Image size (pixels per side). The generator check asserts this output size.
IMG_SIZE = 256

# Batch size during training
BATCH_SIZE = 100

# Noise dimension (number of channels of the latent vector fed to the generator)
Z_DIM = 100

# Number of epochs to train
NUM_EPOCHS = 2000

# Log metrics and sample images every `log_step` epochs
log_step = 20

# Seed every RNG the notebook draws from (image sampling, weight init, noise)
# so that a Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train on the first GPU when one is available, otherwise on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

# Loading and visualizing the Pictures

In [None]:
class ImageDataset(Dataset):
    """Dataset over every ``.jpg`` file found directly inside a folder.

    Attributes:
        root: folder that was scanned for pictures.
        transform: callable applied to each loaded image array.
        all_imgs: tuple with the full paths of the pictures, sorted by name so
            that indexing is stable across runs and file systems.
    """
    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, root: str, transform: "transforms.Compose"):
        self.root = root
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        self.all_imgs = tuple(sorted(
            os.path.join(root, p) for p in os.listdir(root) if p.endswith('.jpg')
        ))

    def __len__(self) -> int:
        """Number of pictures found in ``root``."""
        return len(self.all_imgs)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load picture ``idx`` and return it after applying ``transform``."""
        return self.transform(self.load_image(self.all_imgs[idx]))

    @staticmethod
    def load_image(path: str) -> np.ndarray:
        """Read the picture at ``path`` as a height x width x 3 RGB array."""
        with Image.open(path) as p:
            # Force three channels so grayscale/CMYK files cannot break the
            # 3-channel normalisation; np.array makes a writable copy that is
            # safe to hand to torch.from_numpy after the file is closed.
            img = np.array(p.convert("RGB"))
        return img

    def plot_random_images(self, n_img: int) -> None:
        """Show a grid with up to ``n_img`` randomly sampled pictures.

        ``n_img`` is capped at the dataset size; nothing is drawn when there
        are no pictures to show.
        """
        n_img = min(n_img, len(self.all_imgs))
        if n_img <= 0:
            return

        # Sample without replacement
        sampled_images = random.sample(self.all_imgs, n_img)

        # make_grid expects channel-first tensors (C, H, W)
        images_list = [torch.from_numpy(self.load_image(p)).permute(2, 0, 1) for p in sampled_images]
        image_grid = vutils.make_grid(images_list)
        plt.figure(figsize=(20, 20))
        # Back to channel-last (H, W, C) for matplotlib
        plt.imshow(image_grid.permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.show()
        
    
# ToTensor followed by Normalize(0.5, 0.5) per channel; for uint8 pictures this
# yields values in [-1, 1], the same range as the generator's Tanh output.
monet_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = ImageDataset(root=training_pictures, transform=monet_transform)

# Preview a random sample of the training pictures
dataset.plot_random_images(40)

# Create the GAN

## Generator

In [None]:
def get_noise(n_samples, z_dim=None):
    """Sample standard-normal latent vectors for the generator.

    Parameters:
        n_samples: number of noise vectors to draw
        z_dim: number of latent channels; defaults to the global Z_DIM
    Returns:
        a float tensor of shape (n_samples, z_dim, 1, 1) on the global `device`
    """
    if z_dim is None:
        z_dim = Z_DIM
    return torch.randn(n_samples, z_dim, 1, 1, device=device, dtype=torch.float)

class Generator(nn.Module):
    """
    Transposed-convolution generator.

    Upsamples a (N, z_dim, 1, 1) noise tensor through three
    ConvTranspose -> BatchNorm -> ReLU stages and a final
    ConvTranspose -> Tanh stage into a (N, 3, 256, 256) image batch
    (spatial sizes 1 -> 11 -> 40 -> 125 -> 256).
    """
    def __init__(self, z_dim: int, hidden_dim: int):
        super().__init__()
        # Channel widths entering / leaving the three hidden stages
        widths = (z_dim, hidden_dim*4, hidden_dim*2, hidden_dim)
        # Geometry of each hidden stage, in order
        stage_geometry = (
            dict(kernel_size=9, stride=3, padding=0, output_padding=2),
            dict(kernel_size=9, stride=3, padding=0, output_padding=1),
            dict(kernel_size=8, stride=3, padding=0, output_padding=0),
        )
        layers = [
            self._block(c_in, c_out, **geometry)
            for c_in, c_out, geometry in zip(widths[:-1], widths[1:], stage_geometry)
        ]
        # Output stage: project to 3 channels and squash into [-1, 1]
        layers.append(nn.ConvTranspose2d(hidden_dim, 3, kernel_size=8, stride=2, padding=0, output_padding=0))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def _block(self, in_channels: int, out_channels: int, kernel_size=8, stride=1, dilation=1, padding=0, output_padding=0) -> nn.Sequential:
        """Build one hidden stage: ConvTranspose2d -> BatchNorm2d -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            output_padding=output_padding,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, input):
        """Map a batch of noise tensors to a batch of generated images."""
        generated = self.main(input)
        return generated
    
# Assert output image shape.
# The configured Z_DIM is used as-is: reassigning it in this cell would silently
# change the latent size used by every later cell (model creation, training noise).
with torch.no_grad():  # shape check only, no need to build an autograd graph
    sample_image = Generator(Z_DIM, 5).to(device)(get_noise(1))[0]
# Report the shape channel-last: (height, width, channels)
gen_shape = tuple(sample_image.permute(1, 2, 0).shape)
assert gen_shape == (IMG_SIZE, IMG_SIZE, 3), gen_shape

## Discriminator

In [None]:
class Critic(nn.Module):
    '''
    Convolutional critic (discriminator) that scores images as real or fake.
    Four Conv -> BatchNorm -> LeakyReLU(0.2) stages are followed by a final
    Conv -> Sigmoid stage, so every score lies in [0, 1].
    Values:
        im_chan: the number of channels in the input images, a scalar (default 3)
        hidden_dim: the inner dimension, a scalar; the first three stages output
              hidden_dim*6, hidden_dim*3 and hidden_dim channels
    '''
    def __init__(self, im_chan=3, hidden_dim=20):
        super().__init__()
        stages = [
            self.make_crit_block(im_chan, hidden_dim*6, kernel_size=8, stride=3),
            self.make_crit_block(hidden_dim*6, hidden_dim*3, kernel_size=8, stride=3),
            self.make_crit_block(hidden_dim*3, hidden_dim, kernel_size=8, stride=3),
            self.make_crit_block(hidden_dim, 1, stride=1),
            self.make_crit_block(1, 1, kernel_size=9, final_layer=True),
        ]
        self.crit = nn.Sequential(*stages)

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False, padding=1):
        '''
        Build one stage of the critic.
        A regular stage is a bias-free convolution, a batchnorm and a LeakyReLU(0.2);
        the final stage is a convolution (with bias) followed by a sigmoid.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise
                      (affects bias, batchnorm and activation)
            padding: the zero-padding added to each side of the input
        '''
        conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=final_layer,  # only the final convolution carries a bias
        )
        if final_layer:
            return nn.Sequential(conv, nn.Sigmoid())
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        '''
        Forward pass of the critic: given a batch of images, returns one row of
        scores per image (the per-image output is flattened).
        Parameters:
            image: an image batch tensor with im_chan channels
        '''
        scores = self.crit(image)
        batch_size = scores.shape[0]
        return scores.view(batch_size, -1)
    
# Assert the output
# Feed one training picture (as a batch of one) through an untrained critic
# and check that it produces exactly one score.
single_image_batch = dataset[0].view(1, 3, IMG_SIZE, IMG_SIZE)
disc_size = Critic()(single_image_batch).view(-1).size()
assert disc_size == torch.Size([1]), disc_size

# Start the training

In [None]:
# Networks: the generator is built before the critic
gen = Generator(Z_DIM, 100).to(device)
crit = Critic().to(device)

# Optimizers: Adam defaults for the generator, a small learning rate for the critic
gen_opt = optim.Adam(gen.parameters())
crit_opt = optim.Adam(crit.parameters(), lr=5e-5)

# Binary cross-entropy on the critic's sigmoid scores
criterion = nn.BCELoss()

In [None]:
# Serve the training pictures in shuffled mini-batches of BATCH_SIZE
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Start the Weights & Biases run, then track both networks
# (log="all", every `log_step` batches), generator first.
wandb.init(project='Monet GAN', entity='djrmarques')
for watched_model in (gen, crit):
    wandb.watch(watched_model, log="all", log_freq=log_step)

I want the generator to be trained on more samples than the critic. For this reason I created a function that computes, based on the current epoch, how many times larger the generator's batch is than the critic's batch. Note that the schedule is currently disabled: as used below, the function always returns 1.

In [None]:
def size_generator_training(epoch, max_multiple=1, num_epochs=None):
    """Return how many times larger the generator batch is than the critic batch.

    The multiplier grows linearly with training progress, from 1 at the first
    epoch up to `max_multiple`. With the default `max_multiple=1` the schedule
    is disabled and the function always returns 1.

    Parameters:
        epoch: current epoch (0-based)
        max_multiple: largest multiplier the schedule may reach
        num_epochs: total number of epochs; defaults to the global NUM_EPOCHS
    Returns:
        an int between 1 and max(1, max_multiple)
    Raises:
        ValueError: if the schedule is enabled and num_epochs is not positive
    """
    if max_multiple <= 1:
        # Schedule disabled: generator and critic see the same batch size.
        return 1
    if num_epochs is None:
        num_epochs = NUM_EPOCHS
    if num_epochs <= 0:
        raise ValueError(f"num_epochs must be positive, got {num_epochs}")
    progress = epoch / num_epochs
    return min(int(round(progress * max_multiple)) + 1, max_multiple)

size_generator_training(1)

In [None]:
for epoch in tqdm.trange(NUM_EPOCHS):
    # Per-batch losses of this epoch, averaged when logging
    gen_batch_loss = []
    crit_batch_loss = []

    # The generator is trained on `gen_train_mult` times as many fake images
    # as the critic sees in each step.
    gen_train_mult = size_generator_training(epoch)
    for data in dataloader:

        ## Update the Critic with real images (target label 1)
        crit.zero_grad()
        real_data = data.to(device)
        b_size = real_data.shape[0]
        real_output = crit(real_data).view(-1)
        real_error = criterion(real_output, torch.ones(b_size, dtype=torch.float, device=device))
        real_error.backward()

        ## Update the Critic with fake images (target label 0)
        noise = get_noise(b_size*gen_train_mult)
        fake = gen(noise)
        # The critic only sees the first b_size fakes; detach so this backward
        # pass does not reach the generator.
        fake_crit_train = fake[:b_size]
        fake_output = crit(fake_crit_train.detach()).view(-1)
        fake_error = criterion(fake_output, torch.zeros(b_size, dtype=torch.float, device=device))
        fake_error.backward()
        disc_error = real_error + fake_error
        crit_opt.step()

        ## Update the Generator: it wants the critic to label every fake as real
        gen.zero_grad()
        output = crit(fake).view(-1)
        gen_error = criterion(output, torch.ones(b_size*gen_train_mult, dtype=torch.float, device=device))
        gen_error.backward()
        # Drop the critic gradients produced by the generator's backward pass
        crit.zero_grad()
        gen_opt.step()

        gen_batch_loss.append(gen_error.item())
        crit_batch_loss.append(disc_error.item())

    if epoch%log_step == 0:
        # Sample a few images for the log; no autograd graph is needed here.
        with torch.no_grad():
            samples = gen(get_noise(5)).cpu()
        # The Tanh output lives in [-1, 1]; rescale to [0, 1] so the preview
        # is not clipped, then convert to a channel-last uint8 array.
        grid = vutils.make_grid(((samples + 1) / 2).clamp(0, 1))
        grid_image = (grid * 255).round().to(torch.uint8).permute(1, 2, 0).numpy()

        # Log W&B: a single call keeps metrics and images on the same step
        wandb.log({
            "Critic Error": sum(crit_batch_loss)/len(crit_batch_loss),
            'Generator Error': sum(gen_batch_loss)/len(gen_batch_loss),
            "Generated Images": wandb.Image(grid_image, caption="Generated Images"),
        })

# Save the trained model

In [None]:
# Persist only the learned weights (state dicts), generator first
for trained_model, weights_path in ((gen, "generator.model"), (crit, "critic.model")):
    torch.save(trained_model.state_dict(), weights_path)