The GAN framework consists of two neural networks: the generator and the discriminator. In the context of image generation, the generator produces fake images from noise given as input, and the discriminator classifies images as real or fake. During training, the two networks compete with each other in a game and, as a result, both get better at their jobs. The generator tries to produce more realistic images to fool the discriminator, and the discriminator tries to get better at telling real images from fake ones.

To train a GAN, we need a training dataset. Given a training dataset, the GAN will learn to generate new data with the same distribution as the training dataset. For instance, if we train a GAN on cat images, it will learn to generate new cat images that look real to our eyes.

## Creating the Dataset

In [None]:
from torchvision import datasets
import torchvision.transforms as transforms
import os

# path to store/load data
path2data="./data"
os.makedirs(path2data, exist_ok= True)
    
"""
The original images might be in different sizes, thus we used a Resize transformation 
to resize images to 64 by 64. Next, ToTensor scales the image pixels to the range of 
[0, 1]. Next, we applied a normalization. The normalization mean and std values were set 
to normalize inputs to the range of [-1, 1]. As you will see in the Defining Generator 
section, the generator ends with a tanh activation, so its outputs are also in the 
range [-1, 1].
"""
h, w = 64, 64
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
transform= transforms.Compose([
           transforms.Resize((h,w)),
           transforms.CenterCrop((h,w)),
           transforms.ToTensor(),
           transforms.Normalize(mean, std)])
    
train_ds=datasets.STL10(path2data, split='train', 
                        download=True,
                        transform=transform)
print(len(train_ds))
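
To make the normalization concrete, here is a minimal sketch in plain Python that applies the same arithmetic as `transforms.Normalize` with mean 0.5 and std 0.5 to a single pixel value. The helper names `normalize` and `denormalize` are ours and used only for illustration; they are not part of torchvision.

```python
def normalize(p, mean=0.5, std=0.5):
    """Mimic transforms.Normalize for a single pixel value in [0, 1]."""
    return (p - mean) / std


def denormalize(x, mean=0.5, std=0.5):
    """Invert the normalization, mapping [-1, 1] back to [0, 1]."""
    return x * std + mean


# Walk a few pixel values through normalization and back
for p in (0.0, 0.25, 0.5, 1.0):
    x = normalize(p)
    print(p, "->", x, "->", denormalize(x))
```

The inverse, `0.5 * x + 0.5`, is the same expression we use later to bring tensors back to [0, 1] before displaying them.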

In [None]:
import torch
"""
As expected, the extracted sample is a PyTorch tensor in the shape of (3, height, width) 
and is normalized to the range of [-1, 1].
"""
for x, _ in train_ds:
    print(x.shape, torch.min(x), torch.max(x))
    break

In [None]:
import torch

for x,y in train_ds:
    print(x.shape,y)
    break

In [None]:
from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline

# since the tensor was normalized to [-1, 1], we had to re-normalize it for visualization purposes
plt.imshow(to_pil_image(0.5*x+0.5))

In [None]:
import torch

batch_size = 32 # chosen
train_dl = torch.utils.data.DataLoader(train_ds, 
                                       batch_size=batch_size, 
                                       shuffle=True)

In [None]:
for x,y in train_dl:
    print(x.shape, y.shape)
    break

## Defining Generator

In [None]:
from torch import nn
import torch.nn.functional as F

class Generator(nn.Module):
    """
    nz: The size of the input noise vector (set to 100)
    ngf: A coefficient for the number of convolutional filters in the generator (set to 64)
    noc: The number of output channels (set to 3 for RGB images)
    
    A conv-transpose layer is also called a fractionally-strided convolution or a 
    deconvolution. They are used to upsample the input vector to the desired output 
    size.
    
    
    """
    def __init__(self, params):
        super(Generator, self).__init__()
        nz = params["nz"]
        ngf = params["ngf"]
        noc = params["noc"]
        self.dconv1 = nn.ConvTranspose2d( nz, ngf * 8, kernel_size=4,
                                         stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(ngf * 8)
        self.dconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, 
                                         stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ngf * 4)
        self.dconv3 = nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size=4, 
                                         stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(ngf * 2)
        self.dconv4 = nn.ConvTranspose2d( ngf * 2, ngf, kernel_size=4, 
                                         stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(ngf)
        self.dconv5 = nn.ConvTranspose2d( ngf, noc, kernel_size=4, 
                                         stride=2, padding=1, bias=False)

    def forward(self, x):
        x = F.relu(self.bn1(self.dconv1(x)))
        x = F.relu(self.bn2(self.dconv2(x)))            
        x = F.relu(self.bn3(self.dconv3(x)))        
        x = F.relu(self.bn4(self.dconv4(x)))    
        out = torch.tanh(self.dconv5(x))
        return out

In [None]:
params_gen = {
        "nz": 100,
        "ngf": 64,
        "noc": 3,
        }
model_gen = Generator(params_gen)
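# use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")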
model_gen.to(device)
print(model_gen)

In [None]:
# To make sure that the model was created properly, we passed some dummy input to the generator model. As expected, the model output is a tensor of shape [1, 3, 64, 64]
with torch.no_grad():
    y= model_gen(torch.zeros(1,100,1,1, device=device))
print(y.shape)    
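
To see why the output is 64 by 64, we can trace the spatial size through the five conv-transpose layers by hand. The sketch below uses the standard output-size formula for a transposed convolution, assuming `dilation=1` and `output_padding=0` (the defaults used above). The helper `conv_transpose_out` is a name we made up for this illustration.

```python
def conv_transpose_out(size, kernel_size, stride, padding):
    """Output size of ConvTranspose2d along one spatial dimension,
    assuming dilation=1 and output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel_size


# (kernel_size, stride, padding) for dconv1 ... dconv5 in the Generator
gen_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise vector enters as a 1x1 feature map
gen_sizes = [size]
for k, s, p in gen_layers:
    size = conv_transpose_out(size, k, s, p)
    gen_sizes.append(size)

print(gen_sizes)
```

The first layer expands the 1x1 input to 4x4, and each of the remaining four layers doubles the size.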

## Defining Discriminator

In [None]:
"""
Similarly, in the __init__ method, we defined the layers and in the forward method, we 
defined the connections between the layers. Notice that we did not use any pooling layers 
and instead set the stride argument to 2 to downsample the input size. The last layer 
uses a 4x4 kernel with stride 1 to collapse the 4x4 feature map into a single value. 
Also, notice that leaky_relu activation was used instead of relu, as suggested in the 
DCGAN paper, so that negative pre-activations still pass a small gradient.
"""
class Discriminator(nn.Module):
    def __init__(self, params):
        super(Discriminator, self).__init__()
        nic= params["nic"]
        ndf = params["ndf"]
        self.conv1 = nn.Conv2d(nic, ndf, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ndf * 2)            
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(ndf * 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(ndf * 8)
        self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2, inplace = True)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace = True)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace = True)
        x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace = True)        
        
        out = torch.sigmoid(self.conv5(x))
        return out.view(-1)

In [None]:
params_dis = {
    "nic": 3,
    "ndf": 64}
model_dis = Discriminator(params_dis)
model_dis.to(device)
print(model_dis)

In [None]:
# pass a dummy input through the discriminator; the output should have shape [1]
with torch.no_grad():
    y= model_dis(torch.zeros(1,3,h,w, device=device))
print(y.shape)    
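
The downsampling in the discriminator can be checked by hand in the same spirit. This is a small sketch using the standard output-size formula for a convolution, assuming `dilation=1`. The helper `conv_out` is a hypothetical name used only here.

```python
def conv_out(size, kernel_size, stride, padding):
    """Output size of Conv2d along one spatial dimension, assuming dilation=1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) for conv1 ... conv5 in the Discriminator
dis_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are 64x64
dis_sizes = [size]
for k, s, p in dis_layers:
    size = conv_out(size, k, s, p)
    dis_sizes.append(size)

print(dis_sizes)
```

Each stride-2 layer halves the size, and the final layer reduces the 4x4 map to a single value per image, which is why `out.view(-1)` yields one score per sample.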

In [None]:
"""
The DCGAN paper suggested initializing the weights using a normal distribution with 
mean=0 and std=0.02, as we did in the helper function.
"""
def initialize_weights(model):
    classname = model.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(model.weight.data, 1.0, 0.02)
        nn.init.constant_(model.bias.data, 0)

In [None]:
model_gen.apply(initialize_weights);
model_dis.apply(initialize_weights);

For the models to learn, we need to define a criterion. The discriminator model is a classification network and we can use the binary cross-entropy (BCE) loss function as its criterion. For the generator model to learn, we pass its output to the discriminator model and then evaluate the output of the discriminator model. Thus, the same BCE loss function can be used as a criterion to train the generator model. Also, we will use the Adam optimizer to update the parameters of the discriminator and generator models.
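
To illustrate how one loss function serves both models, here is a minimal sketch that evaluates binary cross-entropy by hand for a single prediction. The function `bce` and the value `d_out` are hypothetical and chosen only for this example. Note that `nn.BCELoss` averages this quantity over the batch by default.

```python
import math


def bce(p, y):
    """Binary cross-entropy for one prediction p in (0, 1) and target y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# Hypothetical discriminator output for a fake image: "10% likely to be real"
d_out = 0.1

# Discriminator's view: the target for a fake image is 0, so this loss is small
loss_as_fake = bce(d_out, 0)

# Generator's view: the target is 1, so the same output gives a large loss
loss_as_real = bce(d_out, 1)

print(loss_as_fake, loss_as_real)
```

With the target set to 1, the loss reduces to `-log(p)`, so the generator lowers its loss only by pushing the discriminator's output on fake images toward 1.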

## Defining Loss, Optimizer

In [None]:
loss_func = nn.BCELoss()

In [None]:
from torch import optim
"""
We defined the Adam optimizer from torch.optim for the discriminator model based on the 
hyperparameters suggested in the DCGAN paper. The paper suggested setting the learning 
rate to 0.0002 and the momentum term beta1 to 0.5 for training stability.
"""
lr = 2e-4 
beta1 = 0.5
opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1, 0.999))

Training the GAN framework is done in two stages: training the discriminator and training the generator. To this end, we will take the following steps:

1. Get a batch of real images with the target labels set to 1.
2. Generate a batch of fake images using the generator with the target labels set to 0.
3. Feed the mini-batches to the discriminator and compute the loss and gradients.
4. Update the discriminator parameters using the gradients.
5. Generate a batch of fake images using the generator with the target labels set to 1.
6. Feed the fake mini-batch to the discriminator and compute the loss and gradients.
7. Update only the generator parameters using the gradients.
8. Repeat from step 1.


## Training

In [None]:
"""
we defined a few parameters. We defined real_label and fake_label and set them to 1 and 0, 
respectively. Later, we will need to label a mini-batch using these parameters. The nz 
parameter specifies the size of the input noise vector to the generator model. This was 
set to 100 in Defining the generator and discriminator recipe. The num_epochs parameter 
specifies how many times we want to iterate over the training data. To store the loss 
values for the discriminator and generator models, we defined the loss_history dictionary.
"""
real_label = 1
fake_label = 0
nz = params_gen["nz"]
num_epochs = 100


loss_history={"gen": [],
              "dis": []}


"""
We implemented the training loop. The training loop iterates over the real dataset 
for num_epochs. In each epoch, we got a batch of real images from train_dl and fed it 
to the discriminator model and got its output as out_dis. Note that here, the real 
images were labeled with real_label using the torch.full method. Then, the loss value 
for the real mini-batch was calculated as loss_r. Next, the gradients of loss_r with 
respect to the discriminator parameters were calculated in a backward pass.
"""
batch_count = 0
for epoch in range(num_epochs):
    for xb, yb in train_dl:
        ba_si = xb.size(0)
        model_dis.zero_grad()
        xb = xb.to(device)
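        # BCELoss expects float targets; without dtype, an integer fill value yields an integer tensor
        yb = torch.full((ba_si,), real_label, dtype=torch.float, device=device)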
        out_dis = model_dis(xb)
        loss_r = loss_func(out_dis, yb)
        loss_r.backward()

        """
        In passing the output of the generator to the discriminator, we used the .detach() 
        method to avoid gradient tracking for the generator model. Note that at this point, 
        the fake images were labeled with fake_label using the .fill_ method. Then, 
        the loss value for the fake mini-batch was calculated as loss_f. Next, the gradients 
        of loss_f with respect to the discriminator parameters were calculated in a backward 
        pass.
        """
        noise = torch.randn(ba_si, nz, 1, 1, device=device)
        out_gen = model_gen(noise)
        out_dis = model_dis(out_gen.detach())
        yb.fill_(fake_label)    
        loss_f = loss_func(out_dis, yb)
        loss_f.backward()
        loss_dis = loss_r + loss_f  
        opt_dis.step()   

        """
        Next, we trained the generator model. To this end, we passed the fake images to 
        the discriminator model and got its output. Note that here, the fake images were 
        labeled with real_label using the .fill_ method. This may sound strange at first, 
        but it is done to force the generator model to generate better-looking images.
        """
        model_gen.zero_grad()
        yb.fill_(real_label)  
        out_dis = model_dis(out_gen)
        loss_gen = loss_func(out_dis, yb)
        loss_gen.backward()
        opt_gen.step()

        loss_history["gen"].append(loss_gen.item())
        loss_history["dis"].append(loss_dis.item())
        batch_count += 1
        if batch_count % 100 == 0:
            print(epoch, loss_gen.item(),loss_dis.item())
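
The role of `.detach()` in the discriminator step can be seen in isolation with a small sketch. The two `nn.Linear` layers below are toy stand-ins for the generator and discriminator, introduced only so that the block runs on its own. They are not the models trained above.

```python
import torch
from torch import nn

torch.manual_seed(0)
toy_gen = nn.Linear(4, 4)  # toy stand-in for the generator
toy_dis = nn.Linear(4, 1)  # toy stand-in for the discriminator

noise = torch.randn(8, 4)
fake = toy_gen(noise)

# Discriminator step: detach() cuts the graph, so no gradient reaches the generator
toy_dis(fake.detach()).mean().backward()
gen_grad_after_dis_step = toy_gen.weight.grad      # still None
dis_grad_after_dis_step = toy_dis.weight.grad      # populated

# Generator step: without detach(), gradients flow through the discriminator
# back into the generator
toy_dis.zero_grad()
toy_dis(fake).mean().backward()
gen_grad_after_gen_step = toy_gen.weight.grad      # now populated

print(gen_grad_after_dis_step is None, gen_grad_after_gen_step is not None)
```

This is why the training loop can reuse the same `out_gen` for both steps: the detached copy trains the discriminator, and the attached tensor trains the generator.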
        

In [None]:
plt.figure(figsize=(10,5))
plt.title("Loss Progress")
plt.plot(loss_history["gen"],label="Gen. Loss")
plt.plot(loss_history["dis"],label="Dis. Loss")
plt.xlabel("batch count")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# store models
import os
path2models = "./models/"
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, "weights_gen_128.pt")
path2weights_dis = os.path.join(path2models, "weights_dis_128.pt")

torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)

## Deploying Generator

Once we've trained a GAN, we end up with two trained models. Usually, we discard the discriminator model and keep the generator model. We can use the trained generator to generate new images. To deploy the generator model, we load the trained weights into the model and then feed it with random noise. Make sure to define the model class beforehand. To avoid repetition, we will not define the model class here.

In [None]:
# Load the weights:
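# map_location places the weights on the current device, even if they were saved on another one
weights = torch.load(path2weights_gen, map_location=device)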
model_gen.load_state_dict(weights)
model_gen.eval()

In [None]:
import numpy as np
# disable gradient tracking; the model was already set to evaluation mode above
with torch.no_grad():
    # feed random noise vectors into the model to get generated fake images
    fixed_noise = torch.randn(16, nz, 1, 1, device=device)
    print(fixed_noise.shape)
    img_fake = model_gen(fixed_noise).detach().cpu()    
print(img_fake.shape)

plt.figure(figsize=(10,10))
for ii in range(16):
    plt.subplot(4,4,ii+1)
    plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5)) # re-normalize the output tensor back to its original values for visualization purposes
    plt.axis("off")
    

Check out the generated images. Some of them may look very distorted, while others look relatively realistic. To improve the results, you can train the model on a single class of data as opposed to multiple classes together. GANs often perform better when they are trained on a single class, since the target distribution is less varied. The STL-10 dataset has multiple classes. Try to select one category and then train the GAN models. Also, you can try to train the model for a longer time and see how that changes the generated images.
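
As a starting point for the single-class experiment, here is a minimal sketch of how the sample indices for one class could be collected. The helper `indices_of_class` and the toy label list are ours. The commented-out usage assumes that the dataset exposes its integer labels as `train_ds.labels` and its class names as `train_ds.classes`, which is our understanding of torchvision's STL10 class; please verify this against your installed version. The class index 3 is only an example.

```python
def indices_of_class(labels, target):
    """Return the positions in `labels` whose value equals `target`."""
    return [i for i, label in enumerate(labels) if int(label) == target]


# Toy labels standing in for the dataset's label array
toy_labels = [3, 0, 3, 5, 1, 3]
selected = indices_of_class(toy_labels, 3)
print(selected)

# With the real dataset (not run here), something along these lines:
# from torch.utils.data import Subset, DataLoader
# print(train_ds.classes)  # check which index corresponds to which class
# single_class_ds = Subset(train_ds, indices_of_class(train_ds.labels, 3))
# train_dl = DataLoader(single_class_ds, batch_size=batch_size, shuffle=True)
```

The rest of the training code can stay unchanged, since the subset yields the same `(image, label)` pairs as the full dataset.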