In [None]:
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F

# Exploring Dataset and Preprocessing

The dataset contains a total of 2872 images. Each image is resized to 64 × 64 pixels and converted to a tensor of shape [3, 64, 64].

The dataset is passed through a dataloader to sample batches from.

In [None]:
# ----- configuration ----- #
path = "gallery"            # ImageFolder root: one sub-directory per class
batch_size = 32
image_size = (64, 64)       # (height, width) that every image is resized to
seed = 0
torch.manual_seed(seed)     # reproducible shuffling, noise sampling and weight init
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

img_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),  # float tensor with values in [0, 1]
])

dataset = torchvision.datasets.ImageFolder(root=path, transform=img_transforms)
# shuffle=True: ImageFolder lists files in sorted order, so without shuffling
# every epoch would present the images in exactly the same order.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

dataset




len(data) = 2

data[0].shape = torch.Size([32, 3, 64, 64])

data[1].shape = torch.Size([32])

In [None]:
# showing the first 3 batches
num_preview_batches = 3
for batch, (images, _) in zip(range(num_preview_batches), dataloader):
    # Batches come off the DataLoader as CPU tensors; no device round-trip is needed to plot them.
    grid = torchvision.utils.make_grid(images)   # single [3, H, W] mosaic of the batch
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(grid.permute(1, 2, 0).numpy())     # matplotlib expects channels last
    ax.set_title(f"Batch {batch}")
    ax.axis("off")
    plt.show()
    

# Model and Training

## GAN Model
The [Generative Adversarial Net](https://arxiv.org/abs/1406.2661) consists of two parts:
- a **generative model** $G$, which aims to capture the data distribution;
- a **discriminative model** $D$, which estimates the probability that a sample came from the dataset rather than from the generative model.

Since the generator produces images, its output must have the same shape as the real images: [3, 64, 64].

The discriminator, on the other hand, decides whether an image came from the dataset or from the generator. It therefore outputs a single probability in $[0, 1]$ per image, where 1 means "real" and 0 means "generated". This is the input format that `nn.BCELoss` expects.

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator that maps latent noise to images.

    Input:  noise of shape (batch, latent_size, 1, 1).
    Output: images of shape (batch, channels, 64, 64) with values in [0, 1],
            the same range that ``transforms.ToTensor`` produces for real images.

    Args:
        latent_size: number of channels of the input noise.
        feature_maps: base width of the network; layers use multiples of it.
        channels: number of output image channels (3 for RGB).
    """

    def __init__(self, latent_size=100, feature_maps=64, channels=3):
        super(Generator, self).__init__()
        self.latent_size = latent_size
        fm = feature_maps
        # Each ConvTranspose2d(k=4, s=2, p=1) doubles the spatial size; the first
        # one (k=4, s=1, p=0) expands the 1x1 noise to a 4x4 feature map.
        self.main = nn.Sequential(
            # (latent_size, 1, 1) -> (fm*8, 4, 4)
            nn.ConvTranspose2d(latent_size, fm * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(fm * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (fm*4, 8, 8)
            nn.ConvTranspose2d(fm * 8, fm * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (fm*2, 16, 16)
            nn.ConvTranspose2d(fm * 4, fm * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (fm, 32, 32)
            nn.ConvTranspose2d(fm * 2, fm, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (channels, 64, 64); Sigmoid keeps pixels in [0, 1] like the real data
            nn.ConvTranspose2d(fm, channels, 4, 2, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Generate a batch of images from noise ``x`` of shape (batch, latent_size, 1, 1)."""
        return self.main(x)
        
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores how "real" an image looks.

    Input:  images of shape (batch, channels, 64, 64).
    Output: tensor of shape (batch, 1) holding the probability (in [0, 1]) that
            each image came from the dataset, as required by ``nn.BCELoss``.

    Args:
        feature_maps: base width of the network; layers use multiples of it.
        channels: number of input image channels (3 for RGB).
        dropout: dropout probability applied before the final scoring layer.
    """

    def __init__(self, feature_maps=64, channels=3, dropout=0.2):
        super(Discriminator, self).__init__()
        fm = feature_maps
        # Each Conv2d(k=4, s=2, p=1) halves the spatial size: 64 -> 32 -> 16 -> 8 -> 4.
        self.features = nn.Sequential(
            nn.Conv2d(channels, fm, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fm, fm * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fm * 2, fm * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fm * 4, fm * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(fm * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.drop = nn.Dropout(p=dropout)
        # A 4x4 kernel over the final 4x4 map collapses it to one score per image;
        # Sigmoid turns the score into a probability for BCELoss.
        self.classifier = nn.Sequential(
            nn.Conv2d(fm * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return real-vs-generated probabilities of shape (batch, 1) for images ``x``."""
        out = self.features(x)
        out = self.drop(out)
        out = self.classifier(out)   # (batch, 1, 1, 1)
        return out.flatten(1)        # (batch, 1)
          
        

## Training 
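The discriminator $D$ is updated to maximize
$$\log D(x) + \log\left(1 - D(G(z))\right)$$

In theory, the generator $G$ is updated to minimize
$$\log\left(1 - D(G(z))\right)$$

In practice this objective saturates early in training, when $D$ easily rejects generated samples. So, as suggested in the original GAN paper, the generator is instead trained to maximize $\log D(G(z))$. With `nn.BCELoss`, this amounts to scoring generated images against "real" (all-ones) targets.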

In [None]:
# Noise fed to the generator has shape (batch, latent_size, 1, 1);
# latent_size must match the number of input channels the generator expects.
latent_size = 100

generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Each optimizer must receive the parameters of the model it is meant to train.
optimizerG = torch.optim.Adam(generator.parameters(), lr=1e-4)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

epochs = 200
criterion = nn.BCELoss()   # expects discriminator outputs to be probabilities in [0, 1]
lossesG = []               # generator loss history
lossesD = []               # discriminator loss history
generated_images = []      # generator output snapshots for visualising progress

In [None]:
# Fixed noise lets us watch how the output for the *same* latent vectors evolves.
fixed_noise = torch.randn(64, latent_size, 1, 1, device=device)
log_every = 10  # epochs between progress prints / image snapshots

for i in range(epochs):
    running_loss_D = 0.0
    running_loss_G = 0.0
    for real_images, _ in dataloader:
        real_images = real_images.to(device)
        b_size = real_images.size(0)
        real_targets = torch.ones(b_size, 1, device=device)
        fake_targets = torch.zeros(b_size, 1, device=device)

        # ----- train discriminator: maximize log(D(x)) + log(1 - D(G(z))) ----- #
        discriminator.zero_grad()
        loss_D_real = criterion(discriminator(real_images), real_targets)
        noise = torch.randn(b_size, latent_size, 1, 1, device=device)
        fake_images = generator(noise)
        # detach(): this step must update D only, not backpropagate into G.
        loss_D_fake = criterion(discriminator(fake_images.detach()), fake_targets)
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizerD.step()

        # ----- train generator: maximize log(D(G(z))) ----- #
        generator.zero_grad()
        # Scoring fakes against "real" targets gives the non-saturating generator loss.
        loss_G = criterion(discriminator(fake_images), real_targets)
        loss_G.backward()
        optimizerG.step()

        # .item() stores a plain float; keeping the tensor would keep its autograd graph alive.
        running_loss_D += loss_D.item()
        running_loss_G += loss_G.item()

    # Record the mean loss over the epoch (guard against an empty dataloader).
    num_batches = max(len(dataloader), 1)
    lossesD.append(running_loss_D / num_batches)
    lossesG.append(running_loss_G / num_batches)

    if i % log_every == 0 or i == epochs - 1:
        with torch.no_grad():
            generated_images.append(generator(fixed_noise).cpu())
        print(f"Epoch: {i}, discriminator loss: {lossesD[-1]:.4f}, generator loss: {lossesG[-1]:.4f}")