# Exercise: Create your own AI image generator

In this exercise, you will create your own AI image generator. You will take the famous MNIST dataset and train a Variational Auto Encoder that will generate new images of handwritten digits.

Here are the steps we will follow in this exercise:
* Load the MNIST dataset
* Create a Variational Auto Encoder
* Train the Variational Auto Encoder
* Generate new images


Let's get started with setting up our environment.

In [None]:
# First, we will ensure the dependencies for this notebook are installed and imported.

# Dependencies for this notebook
! pip install torch torchvision matplotlib > /dev/null

# Imports for this notebook
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

# Set the device
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print('GPU is available!')
else:
    device = torch.device('cpu')
    print('GPU is not available, CPU will be used.')

## Step 1: Load the MNIST dataset

First, we will load the MNIST dataset. The MNIST dataset contains 60,000 training images and 10,000 test images of handwritten digits from zero to nine. Each digit has been size-normalized and centered in a fixed-size 28 x 28 pixel grayscale image. You can learn more about the dataset [here](https://huggingface.co/datasets/mnist).

In [None]:
# Next, we use torchvision's datasets module to download and load the MNIST dataset

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)

# Create a PyTorch dataloader, which loads data in batches and shuffles the data so the order of the images
# changes with each training epoch
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
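`transforms.ToTensor()` converts each 28 x 28 image with 8-bit pixel values in [0, 255] into a float tensor of shape (1, 28, 28) with values in [0, 1]. That range matters later, because the binary cross-entropy loss expects targets between 0 and 1. The sketch below imitates the conversion with NumPy on a made-up image; `to_tensor_like` is a hypothetical helper written for illustration, not part of torchvision.

```python
import numpy as np

def to_tensor_like(img_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for transforms.ToTensor() on a grayscale image:
    (H, W) uint8 in [0, 255] -> (1, H, W) float32 in [0, 1]."""
    scaled = img_uint8.astype(np.float32) / 255.0
    return scaled[np.newaxis, :, :]  # add the channel dimension in front

# A made-up 28 x 28 "image": black background with a bright 4 x 4 square
img = np.zeros((28, 28), dtype=np.uint8)
img[12:16, 12:16] = 255

x = to_tensor_like(img)
print(x.shape, x.min(), x.max())
```

The model later flattens each (1, 28, 28) image into a vector of 1 * 28 * 28 = 784 values with `x.view(-1, 784)`.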

## Step 2: Create a Variational Auto Encoder

Next, we will create a Variational Auto Encoder (VAE). A VAE is a type of neural network that can learn to generate new images. It is composed of two parts: an encoder and a decoder. Without going into too much detail, the encoder takes an image as input and compresses it into a short vector of numbers (a point in the so-called latent space) that represents the image, whereas the decoder takes such a vector as input and turns it back into an image. The encoder and decoder are trained together, so that the decoder learns to generate images that look like the images in the training dataset.
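More precisely, the encoder in this notebook outputs two vectors, a mean `mu` and a log-variance `log_var`, and the latent vector `z` is sampled from the Gaussian distribution they describe. The `sampling` method in the next cell does this with the reparameterization trick, z = mu + exp(0.5 * log_var) * eps with eps drawn from a standard normal, which keeps the sampling step differentiable with respect to `mu` and `log_var`. A NumPy sketch of the same computation, using made-up values for `mu` and `log_var`:

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up encoder outputs for one image, with a 4-dimensional latent space
mu = np.array([0.5, -1.0, 0.0, 2.0])
log_var = np.array([0.0, -2.0, 1.0, -4.0])

def sample_z(mu, log_var, n_samples, rng):
    """Reparameterization trick: z = mu + std * eps, with eps ~ N(0, I)."""
    std = np.exp(0.5 * log_var)                       # log-variance -> standard deviation
    eps = rng.standard_normal((n_samples, mu.shape[0]))
    return mu + eps * std                             # broadcasts over the sample axis

z = sample_z(mu, log_var, n_samples=100_000, rng=rng)
print(z.mean(axis=0).round(2))  # should be close to mu
print(z.std(axis=0).round(2))   # should be close to exp(0.5 * log_var)
```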

In [None]:
# Here, we define the architecture of our Variational Autoencoder (VAE)

class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        
        # Encoder parts
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)

        # Decoder parts
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

        self.z_dim = z_dim
        
    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))  # squash outputs to pixel values in [0, 1]
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 784))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# Construct the model given these parameters
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=4)
vae = vae.to(device) # move the model to the device (GPU or CPU)

print(vae)
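As a sanity check on the architecture, each `nn.Linear(in_features, out_features)` layer holds `in_features * out_features` weights plus `out_features` biases. The plain-Python sketch below adds these up for the layer sizes used above (x_dim=784, h_dim1=512, h_dim2=256, z_dim=4); `linear_params` is a hypothetical helper written for illustration.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Number of parameters in an nn.Linear(n_in, n_out) layer with bias."""
    return n_in * n_out + n_out

x_dim, h_dim1, h_dim2, z_dim = 784, 512, 256, 4

layers = {
    "fc1": (x_dim, h_dim1), "fc2": (h_dim1, h_dim2),
    "fc31": (h_dim2, z_dim), "fc32": (h_dim2, z_dim),   # mu and log_var heads
    "fc4": (z_dim, h_dim2), "fc5": (h_dim2, h_dim1), "fc6": (h_dim1, x_dim),
}

total = sum(linear_params(n_in, n_out) for n_in, n_out in layers.values())
print(f"{total:,} trainable parameters")  # 1,070,360 trainable parameters
```

In the notebook itself, `sum(p.numel() for p in vae.parameters())` should report the same number. Most of the parameters sit in the two large layers next to the 784-pixel input and output.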

## Step 3: Train the Variational Auto Encoder

Now that we have created our Variational Auto Encoder, we will train it on the MNIST dataset. We will train it for 6 epochs, which means that the Variational Auto Encoder will see the entire training dataset 6 times. Within each epoch, it processes the images in batches of 128 (the batch size we set on the DataLoader in Step 1) and updates its weights after each batch.
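With 60,000 training images and a batch size of 128, one epoch takes 469 optimizer steps, and the last batch holds only the 96 leftover images, because the DataLoader keeps an incomplete final batch by default (`drop_last=False`). The arithmetic in plain Python:

```python
import math

n_train, batch_size, n_epochs = 60_000, 128, 6

steps_per_epoch = math.ceil(n_train / batch_size)            # the partial last batch counts
last_batch = n_train - (steps_per_epoch - 1) * batch_size    # images in the final batch
total_steps = steps_per_epoch * n_epochs                     # optimizer steps over all epochs

print(steps_per_epoch, last_batch, total_steps)  # 469 96 2814
```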

In [None]:
# Here we define an optimizer and a loss function for our model

# The Adam optimizer is a popular optimizer for deep learning models
optimizer = optim.Adam(vae.parameters(), lr=5e-3)


# The loss function for our VAE contains multiple parts. You don't need to understand every detail yet,
# but note that it is more complex than the loss functions in the other examples we will see.
def loss_function(recon_x, x, mu, log_var):

    # Reconstruction error. To measure how well we have reconstructed the input image.
    # See https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss
    reconstruction_error = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # KL divergence. In essence, this term ensures that the distribution predicted in the latent space
    # is close to a standard normal distribution.
    kl_divergence = -0.5 * torch.sum(log_var - log_var.exp() - mu.pow(2) + 1)

    # Combining the two terms into a single loss value provides a trade-off between the two terms.
    return reconstruction_error + kl_divergence
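The `kl_divergence` line in `loss_function` is the closed-form KL divergence between the encoder's Gaussian N(mu, sigma^2), with sigma^2 = exp(log_var), and the standard normal N(0, 1), summed over the latent dimensions: KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)). The NumPy sketch below compares that formula with a Monte Carlo estimate of E[log q(z) - log p(z)] for a single made-up latent dimension. It is only a numerical check, not part of the training code.

```python
import numpy as np

def kl_closed_form(mu, log_var):
    """KL(N(mu, exp(log_var)) || N(0, 1)) summed over dimensions (same formula as loss_function)."""
    return -0.5 * np.sum(log_var - np.exp(log_var) - mu**2 + 1)

def log_normal(z, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var) evaluated at z."""
    return -0.5 * (np.log(2 * np.pi * var) + (z - mean) ** 2 / var)

mu, log_var = 1.0, -0.5                      # made-up values for one latent dimension
var = np.exp(log_var)

rng = np.random.default_rng(0)
z = mu + np.sqrt(var) * rng.standard_normal(500_000)    # samples from q = N(mu, var)
kl_mc = np.mean(log_normal(z, mu, var) - log_normal(z, 0.0, 1.0))

kl_exact = kl_closed_form(np.array([mu]), np.array([log_var]))
print(round(kl_exact, 4), round(kl_mc, 4))   # the two estimates should nearly agree
```

The KL term is zero exactly when mu = 0 and log_var = 0, i.e. when the encoder's distribution already equals the standard normal, which is what pulls the latent space toward N(0, 1).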

In [None]:
# Here we define a PyTorch training loop for our VAE for a single epoch

def train_epoch(epoch, device=device):
    vae.train() # set model to training mode
    train_loss = 0

    # iterate over the training data
    for batch_idx, (data, _) in enumerate(train_loader):

        data = data.to(device) # move data to device
        optimizer.zero_grad() # clear the gradients of all optimized variables
        
        output_images, mu, log_var = vae(data) # forward pass
        loss = loss_function(output_images, data, mu, log_var) # calculate loss
        
        loss.backward() # backward pass
        train_loss += loss.item() # update running training loss
        optimizer.step() # perform a single optimization step (parameter update)
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} '
                  f'[{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)] '
                  f'Loss: {loss.item() / len(data):.6f}'
                  )
    print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

In [None]:
# Now let's train the model! With each epoch, we will see the loss decrease as the model learns to
# reconstruct the input images.

! rm -rf ./data/gen_mnist_img && mkdir -p ./data/gen_mnist_img

z = torch.randn(16, vae.z_dim).to(device)

def write_gen_mnist_img(z, vae, epoch):
    with torch.no_grad():
        z_decoded = vae.decoder(z)  # output is already on the same device as z

    filename = f'./data/gen_mnist_img/epoch_{epoch}.png'   
    print(f"Writing {filename}")
    save_image(z_decoded.view(16, 1, 28, 28), filename)

for epoch in range(6):
    if epoch % 3 == 0:
        write_gen_mnist_img(z, vae, epoch)

    train_epoch(epoch)

write_gen_mnist_img(z, vae, epoch + 1)  # images after the final (6th) epoch
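`save_image` tiles the 16 decoded digits into a single PNG using `torchvision.utils.make_grid`, which by default places 8 images per row with 2 pixels of padding around each tile. The size of the saved image follows from that layout; the plain-Python sketch below reproduces the arithmetic (`grid_size` is a hypothetical helper written for illustration, assuming the default `nrow=8` and `padding=2`).

```python
import math

def grid_size(n_images, img_h=28, img_w=28, nrow=8, padding=2):
    """Height and width of the grid make_grid builds (defaults: nrow=8, padding=2)."""
    cols = min(nrow, n_images)
    rows = math.ceil(n_images / cols)
    height = rows * (img_h + padding) + padding
    width = cols * (img_w + padding) + padding
    return height, width

print(grid_size(16))  # (62, 242): two rows of eight digits
```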


In [None]:
# While training we generated images from the model. Let's take a look at them.

from IPython.display import Image, display
import os

image_dir = './data/gen_mnist_img'

# Show the saved grids in order: epoch_0 (untrained), epoch_3, epoch_6 (fully trained)
for filename in sorted(os.listdir(image_dir)):
    print("=========================================")
    print(filename)
    display(Image(os.path.join(image_dir, filename)))


Notice how before training the VAE, the images look like random noise. After training the VAE, the images look like handwritten digits.

## Step 4: Generate new images

Now that we have trained our Variational Auto Encoder, we can use it to generate new images. Our VAE represents an image as a vector of 4 numbers (the `z_dim=4` we chose in Step 2). We can generate a random vector of 4 numbers and let the decoder turn it into an image. We can also move step by step along a straight line between two such vectors and decode each intermediate vector, to see how the generated image changes along the way.
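Interpolating between two latent vectors z_a and z_b means decoding the points z(t) = (1 - t) * z_a + t * z_b for t running from 0 to 1. The NumPy sketch below only builds those points; `interpolate` is a hypothetical helper. To turn them into images, you could pass `torch.from_numpy(path).float().to(device)` to `vae.decoder` inside `torch.no_grad()`, as in the cells below.

```python
import numpy as np

def interpolate(z_a: np.ndarray, z_b: np.ndarray, steps: int) -> np.ndarray:
    """Return `steps` evenly spaced points on the straight line from z_a to z_b."""
    t = np.linspace(0.0, 1.0, steps)[:, np.newaxis]   # shape (steps, 1)
    return (1.0 - t) * z_a + t * z_b                   # shape (steps, z_dim)

rng = np.random.default_rng(0)
z_a, z_b = rng.standard_normal(4), rng.standard_normal(4)   # two random 4-D latent vectors
path = interpolate(z_a, z_b, steps=8)
print(path.shape)  # (8, 4)
```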

In [None]:
# Let's first generate 8 random handwritten digits

# Sample 8 random latent vectors from a standard normal distribution
z = torch.randn(8, vae.z_dim).to(device)

# Decode the latent vectors into images (no gradients needed for inference)
with torch.no_grad():
    z_decoded = vae.decoder(z)

# Plot the decoded digits
plt.figure(figsize=(8, 2))
for i in range(8):
    plt.subplot(1, 8, i+1)
    plt.imshow(z_decoded[i].view(28, 28).cpu().numpy(), cmap='gray')
    plt.axis('off')

In [None]:
# Let's explore the space of z a little more. First, let's handcraft
# vectors that should exhibit a smooth transition in the output space

z = torch.tensor(
    [
        [ 0.0000,  0.0000,  0.0000,  -1.400],
        [ 0.0000,  0.0000,  0.0000,  -1.200],
        [ 0.0000,  0.0000,  0.0000,  -1.000],
        [ 0.0000,  0.0000,  0.0000,  -0.800],
        [ 0.0000,  0.0000,  0.0000,  -0.600],
        [ 0.0000,  0.0000,  0.0000,  -0.400],
        [ 0.0000,  0.0000,  0.0000,  -0.200],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.2000],
        [ 0.0000,  0.0000,  0.0000,  0.4000],
        [ 0.0000,  0.0000,  0.0000,  0.6000],
        [ 0.0000,  0.0000,  0.0000,  0.8000],
        [ 0.0000,  0.0000,  0.0000,  1.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2000],
        [ 0.0000,  0.0000,  0.0000,  1.4000],
        [ 0.0000,  0.0000,  0.0000,  1.6000],
    ]
) * 3

# Then decode each vector into an image (no gradients needed for inference)
with torch.no_grad():
    z_decoded = vae.decoder(z.to(device))

filename = "./data/smooth_transition_1.png"
save_image(z_decoded.view(16, 1, 28, 28), filename)
display(Image(filename))
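The hand-written table above varies only the last latent coordinate, from -1.4 to 1.6 in steps of 0.2, and the `* 3` then stretches that range to -4.2 to 4.8. The same 16 vectors can be built programmatically. A NumPy sketch follows; a PyTorch version would use `torch.zeros` and `torch.linspace` in the same way.

```python
import numpy as np

steps, z_dim, dim = 16, 4, 3           # sweep the last of the 4 latent coordinates

z = np.zeros((steps, z_dim))
z[:, dim] = np.linspace(-1.4, 1.6, steps) * 3   # -4.2, -3.6, ..., 4.8 in steps of 0.6

print(z[:, dim].round(1))
```

Changing `dim` sweeps a different latent coordinate. Which coordinate controls which visual feature depends on the particular trained model, so the transitions you see may differ from run to run.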


Congratulations! You have successfully created your own AI image generator.