**Device**

In [None]:
import torch


def get_device() -> torch.device:
    """
    Returns the best available device (CUDA, MPS, or CPU).

    Returns:
        torch.device: The best available device ('cuda', 'mps', or 'cpu').
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")


device = get_device()

# Face Generation

We will define and train a **Style Generative Adversarial Network (StyleGAN)** on a dataset of faces. The goal is to get a generator network to generate *new* images of faces that look as realistic as possible!

### Get the Data

We'll be using the [CelebFaces Attributes Dataset (CelebA)](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) to train our adversarial networks.

This dataset has higher-resolution images than datasets like MNIST or SVHN, so we should expect to define deeper networks (here, a StyleGAN generator) and train them for longer to get good results. Using a GPU for training is recommended.

### Pre-processed Data

Since our focus is on building the GAN, we will use data pre-processed by [Udacity](https://udacity.com/). Each of the CelebA images has been cropped to remove parts of the image that don't include a face, then resized down to 64x64x3 NumPy images. Some sample data is shown below.

<img src='../assets/processed_face_data.png' width=60% />

> Please download this data from Udacity [by clicking here](https://s3.amazonaws.com/video.udacity-data.com/topher/2018/November/5be7eb6f_processed-celeba-small/processed-celeba-small.zip)

This is a zip file. Extract it in the home directory of this notebook for further loading and processing. After extracting the data, we should be left with a directory of data `processed_celeba_small/`.

In [None]:
from glob import glob
from typing import Tuple, Callable, Dict

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms as T
import torch.nn.functional as F

## Data pipeline

The [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset contains over 200,000 celebrity images with annotations. Since we're going to be generating faces, we won't need the annotations; we'll only need the images. Note that these are color images with [3 color channels (RGB)](https://en.wikipedia.org/wiki/Channel_(digital_image)#RGB_Images) each.

### Pre-process and Load the Data

* Implement the input data transformation
* Create a custom Dataset class that reads the CelebA data

### `get_transforms` function

The `get_transforms` function outputs a [`torchvision.transforms.Compose`](https://pytorch.org/vision/stable/generated/torchvision.transforms.Compose.html#torchvision.transforms.Compose) of different transformations. We have two constraints:
* the function takes a size tuple as input and should **resize the images** to that size
* the output images have values **ranging from -1 to 1**

In [None]:
def get_transforms(size: Tuple[int, int]) -> Callable:
    """Transforms to apply to the image."""
    transforms = [
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
        ),  # Scale values to range [-1, 1]
    ]

    return T.Compose(transforms)
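
To see why this pipeline yields values in [-1, 1], the per-pixel arithmetic can be traced by hand. The sketch below is plain Python and assumes 8-bit inputs. The names `to_unit`, `normalize`, and `denormalize_value` are made up for this illustration and are not torchvision functions: `ToTensor` maps 0..255 to 0..1, and `Normalize` with mean 0.5 and std 0.5 then computes `(x - 0.5) / 0.5`.

```python
def to_unit(pixel: int) -> float:
    """Mimics the scaling T.ToTensor applies to 8-bit images: [0, 255] -> [0, 1]."""
    return pixel / 255.0


def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimics T.Normalize for a single channel value: (x - mean) / std."""
    return (x - mean) / std


def denormalize_value(y: float) -> int:
    """Inverse mapping for plotting: [-1, 1] -> [0, 255] (rounded to nearest integer here)."""
    return int(round((y + 1.0) / 2.0 * 255))


for pixel in (0, 128, 255):
    y = normalize(to_unit(pixel))
    print(pixel, "->", round(y, 4), "->", denormalize_value(y))
```

Black maps to -1, white maps to 1, and the inverse mapping recovers the original pixel value.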

### Implement the DatasetDirectory class

The `DatasetDirectory` class is a torch Dataset that reads from the above data directory. The `__getitem__` method outputs a transformed tensor and the `__len__` method outputs the number of files in our dataset. Look at [this custom dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) for ideas.

In [None]:
import os
from typing import Callable
from torch.utils.data import Dataset
from PIL import Image
import torch


class DatasetDirectory(Dataset):
    """
    A custom dataset class that loads images from a specified folder.

    Args:
    - directory: Location of the images.
    - transforms: Transform function to apply to the images.
    - extension: File format of the images (default is '.jpg').
    """

    def __init__(
        self, directory: str, transforms: Callable = None, extension: str = ".jpg"
    ):
        self.directory = directory
        self.transforms = transforms
        self.extension = extension
        # Get a list of all files in the directory with the specified extension
        self.image_paths = [
            os.path.join(directory, file)
            for file in os.listdir(directory)
            if file.endswith(self.extension)
        ]

    def __len__(self) -> int:
        """Returns the number of items in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Loads an image and applies transformations."""
        # Load the image
        img_path = self.image_paths[index]
        image = Image.open(img_path).convert("RGB")  # Ensure 3 channels

        # Apply transformations if provided
        if self.transforms:
            image = self.transforms(image)

        return image

Set `image_size` to 64 and `img_channels` to 3, because we will generate 64x64 RGB images.

In [None]:
image_size = 64
img_channels = 3

Downloading the dataset

In [None]:
import urllib.request
import zipfile
from tqdm.notebook import tqdm

# Define the filename, URL, and folder name
filename = "processed-celeba-small.zip"
url = "https://s3.amazonaws.com/video.udacity-data.com/topher/2018/November/5be7eb6f_processed-celeba-small/processed-celeba-small.zip"
folder_name = "processed_celeba_small"


# Download progress callback
def download_progress_hook(t):
    """Download progress bar hook."""
    last_b = [0]

    def update_to(b=1, bsize=1, tsize=None):
        if tsize is not None:
            t.total = tsize
        t.update((b - last_b[0]) * bsize)
        last_b[0] = b

    return update_to


# Check if the folder already exists
if not os.path.exists(folder_name):
    # Download if the folder doesn't exist
    if not os.path.exists(filename):
        print("Downloading file...")
        with tqdm(unit="B", unit_scale=True, unit_divisor=1024) as t:
            urllib.request.urlretrieve(
                url, filename, reporthook=download_progress_hook(t)
            )
        print("Download complete.")
    else:
        print("File already downloaded. Skipping download.")

    # Unzip the file with progress
    print("Unzipping file...")
    with zipfile.ZipFile(filename, "r") as zip_ref:
        for file in tqdm(
            iterable=zip_ref.namelist(),
            total=len(zip_ref.namelist()),
            desc="Extracting",
        ):
            zip_ref.extract(file)
    print("Unzipping complete.")
else:
    print("Data already exists. Skipping download and unzip.")

In [None]:
data_dir = "processed_celeba_small/celeba/"

# run this cell to create the dataset
dataset = DatasetDirectory(data_dir, get_transforms((image_size, image_size)))

Let's visualize images from the dataset.

In [None]:
def denormalize(images):
    """Transform images from [-1.0, 1.0] to [0, 255] and cast them to uint8."""
    return ((images + 1.0) / 2.0 * 255).astype(np.uint8)


# plot the first few images of the dataset (there are no labels to show)
fig = plt.figure(figsize=(20, 4))
plot_size = 20
for idx in np.arange(plot_size):
    ax = fig.add_subplot(2, int(plot_size / 2), idx + 1, xticks=[], yticks=[])
    img = dataset[idx].numpy()
    img = np.transpose(img, (1, 2, 0))
    img = denormalize(img)
    ax.imshow(img)

## Model implementation

As we already know, a GAN is comprised of two adversarial networks, a discriminator and a generator. Now that we have a working data pipeline, we need to implement the discriminator and the generator. 

### Create the discriminator

The discriminator's job is to score real and fake images. We have two constraints here:
* the discriminator takes as input a **batch of 64x64x3 images**
* the output should be a single value (=score)

Feel free to get inspiration from the different architectures, such as DCGAN, WGAN-GP or DRAGAN.

#### Some tips
* To scale down from the input image, one can either use `Conv2d` layers with the correct hyperparameters or Pooling layers.
* If we plan to use a gradient penalty, we should not use Batch Normalization layers in the discriminator.

In [None]:
import torch
import torch.nn as nn

The `WSConv2d` class (weight-scaled convolutional layer) implements the equalized learning rate for the conv layers.

In [None]:
class WSConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size**2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
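
As a quick check of the scale factor used above, the following plain-Python sketch evaluates `sqrt(2 / (in_channels * kernel_size**2))` for two layer shapes that appear in this notebook. The helper name `ws_scale` is made up for this illustration. With weights drawn from N(0, 1) and this runtime factor, the effective weight standard deviation matches the one He initialization would use.

```python
import math


def ws_scale(in_channels: int, kernel_size: int) -> float:
    """He-style constant sqrt(2 / fan_in), with fan_in = in_channels * k * k."""
    fan_in = in_channels * kernel_size ** 2
    return math.sqrt(2.0 / fan_in)


# 3x3 conv on 128 channels: fan_in = 1152
print(ws_scale(128, 3))
# 1x1 "from RGB" conv on 3 channels: fan_in = 3
print(ws_scale(3, 1))
```

Layers with a large fan-in get a small factor, so all layers end up with activations of similar magnitude regardless of their width.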

The `Discriminator` class follows the one described in the [ProGAN paper](https://arxiv.org/pdf/1710.10196).

In [None]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x


class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # here we work back ways from factors because the discriminator
        # should be mirrored from the generator. So the first prog_block and
        # rgb layer we append will work for input size 1024x1024, then 512->256-> etc
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # perhaps confusing name "initial_rgb" this is just the RGB layer for 4x4 input size
        # did this to "mirror" the generator initial_rgb
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # down sampling using avg pool

        # this is the block for 4x4 input size
        self.final_block = nn.Sequential(
            # +1 to in_channels because we concatenate from MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # we use this instead of linear layer
        )

    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha should be a scalar within [0, 1], and downscaled.shape == out.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        # we take the std across the batch for every channel and pixel, average it into a
        # single scalar, then repeat that scalar as one extra channel and concatenate it
        # with the features. In this way the discriminator gets information about the
        # variation within the batch
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        # where we should start in the list of prog_blocks, maybe a bit confusing but
        # the last is for the 4x4. So example let's say steps=1, then we should start
        # at the second to last because input_size will be 8x8. If steps==0 we just
        # use the final block
        cur_step = len(self.prog_blocks) - steps

        # convert from rgb as initial step, this will depend on
        # the image size (each size has its own rgb layer)
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e, image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # because prog_blocks might change the channels, for down scale we use rgb_layer
        # from previous/smaller size which in our case correlates to +1 in the indexing
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # the fade_in is done first between the downscaled and the input
        # this is opposite from the generator
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)
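
The minibatch standard deviation used in `minibatch_std` can be illustrated without tensors. The sketch below is a simplified stand-in, not the notebook's implementation: each sample is a flat list instead of a `(C, H, W)` tensor, and `minibatch_std_scalar` is a name chosen for this example. It uses the sample standard deviation (divisor N - 1), which is also the default of `torch.std`.

```python
import statistics


def minibatch_std_scalar(batch):
    """batch: list of samples, each a flat list of feature values of equal length."""
    num_features = len(batch[0])
    # Standard deviation across the batch, computed separately for every feature position
    per_position_std = [
        statistics.stdev([sample[i] for sample in batch]) for i in range(num_features)
    ]
    # Average over all positions to get one scalar for the whole batch
    return sum(per_position_std) / num_features


batch = [[0.0, 2.0], [2.0, 2.0]]
stat = minibatch_std_scalar(batch)

# Every sample receives the same extra feature (an extra channel in the real model)
augmented = [sample + [stat] for sample in batch]
print(stat, augmented)
```

A batch of near-identical images produces a statistic close to zero, which gives the critic a direct cue for detecting mode collapse.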

In the original paper, `z_dim`, `w_dim`, and `in_channels` are all set to 512. Given the simple pre-processed data, and to reduce VRAM usage and speed up training, we will use 128. We might get better results if we doubled them.

In [None]:
in_channels = 128
z_dim = 128
w_dim = 128

Let's define a list named `factors` that contains the multipliers applied to `in_channels` to get the number of channels at each image resolution.

In [None]:
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
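
To make the role of `factors` concrete, this small sketch lists the channel count that results at each resolution, assuming `in_channels = 128` as set above and that `factors[i]` belongs to resolution `4 * 2**i`, which is how the generator loop indexes it. The dictionary name `channels_per_resolution` is only used for this illustration.

```python
in_channels = 128
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

# factors[i] belongs to resolution 4 * 2**i (4x4, 8x8, ..., 1024x1024)
channels_per_resolution = {
    4 * 2 ** i: int(in_channels * f) for i, f in enumerate(factors)
}

for resolution, channels in channels_per_resolution.items():
    print(f"{resolution:>4} x {resolution:<4} -> {channels} channels")
```

The list covers resolutions up to 1024x1024, although this notebook only targets 64x64 images.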

In [None]:
# run this cell to create the Discriminator (critic)
critic = Discriminator(in_channels, img_channels=img_channels).to(device)

In [None]:
critic

### Create the generator

The generator's job is to create the "fake images" and learn the dataset distribution. We have two constraints here:
* the generator takes as input a latent batch of dimension `[batch_size, latent_dimension]`
* the generator must output **64x64x3 images**

Feel free to get inspiration from the different architectures, such as DCGAN, WGAN-GP or DRAGAN.

#### Some tips:
* to scale up from the latent vector input, you can use `ConvTranspose2d` layers
* as is often the case with GANs, **Batch Normalization** helps with training

#### The WSLinear (Weighted Scaled Linear) Class

- In the initialization, we pass in `in_features` and `out_features`. A linear layer is created, followed by defining a scale factor, which is set to the square root of 2 divided by `in_features`. The bias of the linear layer is stored separately so that it is not affected by the scaling, and is then removed from the layer. Finally, the weights are initialized from a standard normal distribution and the bias is set to zero.
- In the forward method, we multiply `x` by the scale factor, pass it through the bias-free linear layer, and then add the bias.

In [None]:
class WSLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features) ** 0.5
        self.bias = self.linear.bias
        self.linear.bias = None

        # initialize linear layer
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias
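
Because a bias-free linear layer is linear, scaling its input is equivalent to scaling its weights, which is why the bias has to be kept outside the scaled path. The plain-Python sketch below checks this on a tiny hand-made layer. The weights, bias, and input are arbitrary example values, and `linear_no_bias` is a helper written only for this demonstration.

```python
def linear_no_bias(weights, x):
    """Matrix-vector product: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


in_features = 8
scale = (2 / in_features) ** 0.5  # sqrt(2 / in_features)
weights = [[1.0, -2.0, 0.5, 4.0, 0.0, 1.0, -1.0, 2.0]]  # a single output unit
bias = [10.0]
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

# What WSLinear.forward computes: scale the input, apply the weights, add the bias
scaled_input = [xi * scale for xi in x]
out_input_scaled = [y + b for y, b in zip(linear_no_bias(weights, scaled_input), bias)]

# Equivalent view: scale the weights instead of the input
scaled_weights = [[w * scale for w in row] for row in weights]
out_weight_scaled = [y + b for y, b in zip(linear_no_bias(scaled_weights, x), bias)]

print(out_input_scaled, out_weight_scaled)
```

Both paths give the same result, and the bias contributes at full strength in either case.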

#### Adaptive Instance Normalization (AdaIN) Class

- In the initialization method, we pass in `channels` and `w_dim`. We set up an instance normalization layer for normalizing the input and create `style_scale` and `style_bias`, which serve as the adaptive components. These are implemented using `WSLinear` and map the output `w` of the mapping network (of dimension `w_dim`) to `channels` values.

- In the forward pass, we pass `x` and `w`, apply instance normalization to `x`, and then return `style_scale * x + style_bias`, where both style terms are computed from `w`.

The `PixelNorm` class normalizes `z` before it enters the mapping network.

In [None]:
class PixelNorm(nn.Module):
    def __init__(self):
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
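
For a single latent vector, the same normalization can be written in a few lines of plain Python. This is an illustrative sketch, with `pixel_norm` as a made-up helper name: each entry is divided by the root mean square of the vector, so the result has a mean square of approximately 1.

```python
import math


def pixel_norm(vector, epsilon=1e-8):
    """Divide a vector by the root mean square of its entries."""
    mean_square = sum(v * v for v in vector) / len(vector)
    return [v / math.sqrt(mean_square + epsilon) for v in vector]


z = [3.0, 4.0]
z_normed = pixel_norm(z)
print(z_normed)
```

The direction of `z` is preserved and only its length changes.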

#### MappingNetwork Class

In the initialization method, we pass in `z_dim` and `w_dim`. The network consists of:
- A `PixelNorm` layer that normalizes the input `z`.
- A sequence of eight fully connected `WSLinear` layers, with a ReLU activation after each of the first seven, mapping `z_dim` to `w_dim`.

In the forward method, we apply this mapping to the input and return the transformed output.

In [None]:
class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            PixelNorm(),
            WSLinear(z_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
            nn.ReLU(),
            WSLinear(w_dim, w_dim),
        )

    def forward(self, x):
        return self.mapping(x)

In [None]:
class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias
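
The effect of AdaIN on one channel of one image can be traced in plain Python. The sketch below makes several simplifying assumptions: the channel is a flat list, the style scale and bias are given directly instead of being produced from `w`, and `eps = 1e-5` mirrors the default of `nn.InstanceNorm2d`. The function name `adain_channel` is hypothetical.

```python
import math


def adain_channel(values, style_scale, style_bias, eps=1e-5):
    """Instance-normalize one channel of one image, then apply the style affine map."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance, as in instance norm
    normalized = [(v - mean) / math.sqrt(var + eps) for v in values]
    return [style_scale * v + style_bias for v in normalized]


channel = [1.0, 2.0, 3.0, 4.0]  # flattened H*W activations of one channel
styled = adain_channel(channel, style_scale=2.0, style_bias=-1.0)

styled_mean = sum(styled) / len(styled)
styled_std = (sum((v - styled_mean) ** 2 for v in styled) / len(styled)) ** 0.5
print(styled_mean, styled_std)
```

After the operation, the channel's mean is set by the style bias and its standard deviation by the style scale, independent of the statistics it had before.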

#### InjectNoise Class

- In the initialization method, we pass in `channels`. We initialize `weight` to zeros, with one value per channel, and wrap it in `nn.Parameter` so that it becomes trainable and can be optimized during training.
- In the forward method, we pass an image `x` and return it with added random noise, allowing dynamic variations in the generator's output.

In [None]:
class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

#### GenBlock Class

In the initialization method, we pass `in_channels`, `out_channels`, and `w_dim`. The following components are then set up:
- `conv1`: A convolutional layer using `WSConv2d` that maps `in_channels` to `out_channels`.
- `conv2`: Another `WSConv2d` layer that maps `out_channels` to `out_channels`.
- `leaky`: A Leaky ReLU activation with a slope of 0.2, as specified in the paper.
- `inject_noise1` and `inject_noise2`: Instances of `InjectNoise` to add noise at different stages.
- `adain1` and `adain2`: Instances of `AdaIN` to normalize and modulate the features adaptively.

In the forward method, we pass `x` through the following sequence:
1. Apply `conv1`, followed by `inject_noise1`, then apply `leaky`, and normalize with `adain1`.
2. Pass the result through `conv2`, followed by `inject_noise2`, apply `leaky`, and normalize with `adain2`.

Finally, we return the modified `x`.

In [None]:
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = InjectNoise(in_channels)
        self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(
            len(factors) - 1
        ):  # -1 to prevent index error because of factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        # alpha should be a scalar within [0, 1], and upscaled.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, alpha, steps):
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            # use the fully processed 4x4 features (after noise, activation and AdaIN)
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # The number of channels in upscale will stay the same, while
        # out which has moved through prog_blocks might change. To ensure
        # we can convert both to rgb we use different rgb_layers
        # (steps-1) and steps for upscaled, out respectively
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)
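
The blending done by the generator's `fade_in` is easiest to see on a single pixel value. This plain-Python sketch uses arbitrary example values and applies the same formula as above, `tanh(alpha * generated + (1 - alpha) * upscaled)`.

```python
import math


def fade_in(alpha: float, upscaled: float, generated: float) -> float:
    """Blend one pixel value from the old (upscaled) and new (generated) paths."""
    return math.tanh(alpha * generated + (1 - alpha) * upscaled)


upscaled_px, generated_px = 0.2, 0.8
for alpha in (0.0, 0.5, 1.0):
    print(alpha, round(fade_in(alpha, upscaled_px, generated_px), 4))
```

At `alpha = 0` only the upscaled output of the previous resolution is visible, and at `alpha = 1` only the output of the newly added block, so increasing `alpha` gradually introduces the new layers.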

In [None]:
# run this cell to create the generator
generator = Generator(z_dim, w_dim, in_channels, img_channels=img_channels).to(device)
generator

## Optimizer

In the following section, we create the optimizers for the generator and discriminator.

### Implement the optimizers

In [None]:
g_lr = 0.001
d_lr = 0.001
g_betas = (0.0, 0.99)
d_betas = (0.0, 0.99)

In [None]:
import torch.optim as optim


def create_optimizers(generator: nn.Module, discriminator: nn.Module):
    """This function should return the optimizers of the generator and the discriminator"""
    # The mapping network gets its own parameter group with a much smaller learning rate
    g_optimizer = optim.Adam(
        [
            {
                "params": [
                    param
                    for name, param in generator.named_parameters()
                    if "map" not in name
                ]
            },
            {"params": generator.map.parameters(), "lr": 1e-5},
        ],
        lr=g_lr,
        betas=g_betas,
    )
    d_optimizer = optim.Adam(discriminator.parameters(), lr=d_lr, betas=d_betas)
    return g_optimizer, d_optimizer


g_optimizer, d_optimizer = create_optimizers(generator, critic)

## Losses implementation

In this section, we will implement the loss function for the generator and the discriminator. 

Some tips:
* Choose the commonly used binary cross-entropy loss, or select another loss such as the Wasserstein distance.
* Implement a gradient penalty function.

### Implement the generator loss

The generator's goal is to get the discriminator to think its generated images (= "fake" images) are real.

In [None]:
def generator_loss(fake_logits):
    """
    Computes the loss for the generator in a GAN using the Wasserstein loss formulation.

    Parameters:
    - fake_logits (Tensor): Logits from the discriminator for generated (fake) images.

    Returns:
    - Tensor: Calculated generator loss.
    """
    # The generator aims to maximize the discriminator's response on fake images.
    # In Wasserstein GANs, this corresponds to maximizing the mean of fake logits.
    # Using negative sign to achieve a minimization objective for gradient descent.
    return -torch.mean(fake_logits)

### Implement the discriminator loss

We want the discriminator to give high scores to real images and low scores to fake ones and the discriminator loss should reflect that.

In [None]:
def discriminator_loss(real_logits, fake_logits, gp, lambda_gp=10, drift_penalty=0.001):
    """
    Computes the loss for the discriminator in a GAN, incorporating real and fake logits,
    gradient penalty, and a drift penalty.

    Parameters:
    - real_logits (Tensor): Logits from the discriminator for real images.
    - fake_logits (Tensor): Logits from the discriminator for generated (fake) images.
    - gp (Tensor): Gradient penalty term to enforce the Lipschitz constraint.
    - lambda_gp (float): Weight for the gradient penalty term. Default is 10.
    - drift_penalty (float): Weight for the drift penalty term to stabilize training. Default is 0.001.

    Returns:
    - Tensor: Calculated discriminator loss.
    """
    # Wasserstein loss for real and fake logits
    wasserstein_loss = -(torch.mean(real_logits) - torch.mean(fake_logits))

    # Gradient penalty weighted by lambda_gp
    gradient_penalty_term = lambda_gp * gp

    # Drift penalty on real logits keeps the critic output from drifting too far from zero
    drift_penalty_term = drift_penalty * torch.mean(real_logits**2)

    # Total discriminator loss
    return wasserstein_loss + gradient_penalty_term + drift_penalty_term
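
A small worked example may help to check the signs in both losses. The sketch below repeats the formulas with plain Python floats instead of tensors. The logit values and the gradient penalty of 0.5 are invented for illustration, and the helper names are not part of the notebook.

```python
def mean(values):
    return sum(values) / len(values)


def generator_loss_value(fake_logits):
    return -mean(fake_logits)


def discriminator_loss_value(real_logits, fake_logits, gp, lambda_gp=10, drift_penalty=0.001):
    wasserstein = -(mean(real_logits) - mean(fake_logits))
    drift = drift_penalty * mean([r ** 2 for r in real_logits])
    return wasserstein + lambda_gp * gp + drift


real = [1.0, 3.0]   # critic scores on real images
fake = [-1.0, 1.0]  # critic scores on generated images

g_value = generator_loss_value(fake)                    # -mean(fake)
d_value = discriminator_loss_value(real, fake, gp=0.5)  # -(2 - 0) + 10 * 0.5 + 0.001 * 5
print(g_value, d_value)
```

The critic's loss decreases when real scores rise or fake scores fall, while the generator's loss decreases when fake scores rise.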

### Implement the gradient Penalty

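A gradient penalty, as in WGAN-GP, pushes the norm of the critic's gradient with respect to its input toward 1 on points interpolated between real and fake images. This softly enforces the Lipschitz constraint that the Wasserstein loss relies on.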

In [None]:
def gradient_penalty(
    discriminator, real_samples, fake_samples, alpha, train_step, device="cpu"
):
    """
    Calculates the gradient penalty for Wasserstein GAN with gradient penalty (WGAN-GP).

    Parameters:
    - discriminator (nn.Module): The discriminator model.
    - real_samples (Tensor): Batch of real images.
    - fake_samples (Tensor): Batch of fake images generated by the generator.
    - alpha (float): Mixing factor for progressive growing (usually between 0 and 1).
    - train_step (int): Current progressive-growing step (resolution index), forwarded to the discriminator.
    - device (str or torch.device): Device to perform calculations on (e.g., 'cpu' or 'cuda').

    Returns:
    - Tensor: Calculated gradient penalty.
    """
    batch_size, channels, height, width = real_samples.shape

    # Random weight for interpolation between real and fake samples
    beta = torch.rand((batch_size, 1, 1, 1), device=device).expand_as(real_samples)
    interpolated_images = real_samples * beta + fake_samples.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Compute discriminator scores on interpolated images
    mixed_scores = discriminator(interpolated_images, alpha, train_step)

    # Calculate gradients of scores with respect to interpolated images
    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(mixed_scores, device=device),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten gradients per example and compute L2 norm
    gradient = gradient.view(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)

    # Gradient penalty term enforcing the gradient norm to be close to 1
    gp = torch.mean((gradient_norm - 1) ** 2)

    return gp
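
The penalty formula itself can be checked on a case where the gradient is known without autograd. The sketch below assumes a hypothetical linear critic `f(x) = w . x`, whose gradient with respect to the input is `w` at every point, including the interpolated ones. It illustrates only the `(||grad|| - 1)^2` term, not the full function above.

```python
import math


def gradient_penalty_linear(weights):
    """Penalty (||grad|| - 1)^2 for the critic f(x) = w . x, whose input gradient is w."""
    grad_norm = math.sqrt(sum(w * w for w in weights))
    return (grad_norm - 1.0) ** 2


steep = gradient_penalty_linear([3.0, 4.0])      # gradient norm 5
unit_norm = gradient_penalty_linear([0.6, 0.8])  # gradient norm 1
print(steep, unit_norm)
```

A critic whose gradient norm is far from 1 is penalized quadratically, while one with unit gradient norm pays essentially nothing.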

## Training


Training will involve alternating between training the discriminator and the generator.

* We first train the discriminator on a batch of real images and a batch of fake images
* Then we train the generator, which tries to trick the discriminator and therefore has an opposing loss function

### Implement the generator step and the discriminator step functions

Each function does the following:
* calculate the loss
* backpropagate the gradient
* perform one optimizer step

In [None]:
from typing import Dict


def generator_step(
    generator,
    discriminator,
    batch_size: int,
    latent_dim: int,
    alpha: float,
    step: int,
    gen_optimizer,
) -> Dict[str, float]:
    """
    One training step of the generator.

    Parameters:
    - generator (nn.Module): The generator model.
    - discriminator (nn.Module): The discriminator model.
    - batch_size (int): Current batch size.
    - latent_dim (int): Dimension of the latent space.
    - alpha (float): Mixing factor for progressive growing (usually between 0 and 1).
    - step (int): Current progressive-growing step (resolution index).
    - gen_optimizer (Optimizer): Optimizer for the generator.

    Returns:
    - Dict[str, float]: Dictionary containing the generator loss.
    """
    # Sample random noise as input for the generator
    noise = torch.randn(batch_size, latent_dim).to(device)

    # Generate fake images
    fake_images = generator(noise, alpha, step)

    # Compute the generator loss
    fake_logits = discriminator(fake_images, alpha, step)
    g_loss = generator_loss(fake_logits)

    # Backpropagation and optimization step
    gen_optimizer.zero_grad()
    g_loss.backward()
    gen_optimizer.step()

    return {"loss": g_loss.item()}


def discriminator_step(
    discriminator,
    generator,
    real_images: torch.Tensor,
    batch_size: int,
    latent_dim: int,
    alpha: float,
    step: int,
    lambda_gp: float,
    disc_optimizer,
    drift_penalty=0.001,
) -> Dict[str, float]:
    """
    One training step of the discriminator.

    Parameters:
    - discriminator (nn.Module): The discriminator model.
    - generator (nn.Module): The generator model.
    - real_images (Tensor): Batch of real images.
    - batch_size (int): Current batch size.
    - latent_dim (int): Dimension of the latent space.
    - alpha (float): Mixing factor for progressive growing (usually between 0 and 1).
    - step (int): Current progressive-growing step (resolution index).
    - lambda_gp (float): Weight for the gradient penalty.
    - disc_optimizer (Optimizer): Optimizer for the discriminator.
    - drift_penalty (float): Weight for the drift penalty on real logits. Default is 0.001.

    Returns:
    - Dict[str, float]: Dictionary containing the discriminator loss and gradient penalty.
    """
    # Move real images to device
    real_images = real_images.to(device)

    # Generate fake images with the generator
    noise = torch.randn(batch_size, latent_dim).to(device)
    fake_images = generator(noise, alpha, step).detach()

    # Discriminator logits for real and fake images
    real_logits = discriminator(real_images, alpha, step)
    fake_logits = discriminator(fake_images, alpha, step)

    # Compute gradient penalty
    gp = gradient_penalty(
        discriminator, real_images, fake_images, alpha, step, device=device
    )

    # Discriminator loss: Wasserstein loss with gradient penalty and drift penalty
    d_loss = discriminator_loss(real_logits, fake_logits, gp, lambda_gp, drift_penalty)

    # Backpropagation and optimization step
    disc_optimizer.zero_grad()
    d_loss.backward()
    disc_optimizer.step()

    return {"loss": d_loss.item(), "gp": gp.item()}

### Main training loop

In [None]:
from datetime import datetime
from math import log2

In [None]:
# number of images in each batch
batch_sizes = [64, 32, 16, 8]

progressive_epochs = [30] * len(batch_sizes)
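
The way these lists are indexed can be checked in isolation. The sketch below assumes the convention the training code relies on, namely that stage `s` corresponds to square images of side `4 * 2**s`. The names `stage_for` and `schedule` are illustrative and not part of the notebook.

```python
from math import log2

# Copied from the cell above
batch_sizes = [64, 32, 16, 8]
progressive_epochs = [30] * len(batch_sizes)


def stage_for(image_size: int) -> int:
    """Map an image side length (4, 8, 16, ...) to its stage index."""
    return int(log2(image_size / 4))


# (stage, image size, batch size) for every configured stage
schedule = [(s, 4 * 2**s, batch_sizes[s]) for s in range(len(batch_sizes))]

for stage, size, batch_size in schedule:
    print(f"stage {stage}: {size}x{size} images, batch size {batch_size}")
```

With four entries, the lists cover resolutions up to 32x32. Training at 64x64 would need a fifth entry in both `batch_sizes` and `progressive_epochs`.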

In [None]:
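def get_loader(image_size: int):
    """
    Build a DataLoader for square images of the given size.

    The batch size is looked up in `batch_sizes` by resolution stage
    (4 -> index 0, 8 -> index 1, and so on).

    Returns:
    - (DataLoader, ImageFolder): The loader and the underlying dataset.
    """
    transform = get_transforms((image_size, image_size))
    batch_size = batch_sizes[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=False
    )
    return loader, dataset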

In [None]:
def display(images: torch.Tensor):
    """Display the first 16 images of a generated batch in a 2x8 grid."""
    fig = plt.figure(figsize=(14, 4))
    plot_size = 16
    for idx in range(plot_size):
        ax = fig.add_subplot(2, plot_size // 2, idx + 1, xticks=[], yticks=[])
        # Convert from (C, H, W) tensor to (H, W, C) array for imshow
        img = images[idx, ...].detach().cpu().numpy()
        img = np.transpose(img, (1, 2, 0))
        img = denormalize(img)
        ax.imshow(img)
    plt.show()

### Implement the training strategy

In [None]:
start_train_at_img_size = 8
lambda_gp = 10
alpha = 1e-5  # start with very low alpha

In [None]:
import os

from torchvision.utils import save_image


def generate_examples(gen, steps, n=100):
    """Save n generated images for resolution stage `steps` under saved_examples/."""
    out_dir = f"saved_examples/step{steps}"
    os.makedirs(out_dir, exist_ok=True)  # create the folder once, not per image
    gen.eval()
    alpha = 1.0  # new layers fully faded in
    with torch.no_grad():
        for i in range(n):
            noise = torch.randn(1, z_dim).to(device)
            img = gen(noise, alpha, steps)
            # Rescale from [-1, 1] to [0, 1] before saving
            save_image(img * 0.5 + 0.5, f"{out_dir}/img_{i}.png")
    gen.train()

In [None]:
# Start at the step that corresponds to the initial image size in config
step = int(log2(start_train_at_img_size / 4))
losses = []  # To store losses for plotting later
fixed_latent_vector = (
    torch.randn(16, z_dim).float().to(device)
)  # Fixed vector for consistent image generation
print_every = 200  # Frequency of printing loss updates

for num_epochs in progressive_epochs[step:]:
    alpha = 1e-5  # Start with a very low alpha
    loader, dataset = get_loader(4 * 2**step)
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        # Training loop for each batch
        for batch_idx, (real_images, _) in enumerate(loader):
            cur_batch_size = real_images.shape[0]

            # Discriminator Step
            d_loss = discriminator_step(
                discriminator=critic,
                generator=generator,
                real_images=real_images,
                batch_size=cur_batch_size,
                latent_dim=z_dim,
                alpha=alpha,
                step=step,
                lambda_gp=lambda_gp,
                disc_optimizer=d_optimizer,
            )

            # Generator Step
            g_loss = generator_step(
                generator=generator,
                discriminator=critic,
                batch_size=cur_batch_size,
                latent_dim=z_dim,
                alpha=alpha,
                step=step,
                gen_optimizer=g_optimizer,
            )

            # Fade in the new layers: alpha grows linearly with the number of images
            # seen and reaches 1 after about half of this stage's epochs
            alpha += cur_batch_size / ((progressive_epochs[step] * 0.5) * len(dataset))
            alpha = min(alpha, 1)  # Ensure alpha does not exceed 1

            # Store losses and print them periodically
            if batch_idx % print_every == 0:
                losses.append(
                    (d_loss["loss"], g_loss["loss"])
                )  # Append both losses for later plotting
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print(
                    f"{timestamp} | Epoch [{epoch+1}/{num_epochs}] | Batch {batch_idx}/{len(loader)} | "
                    f"d_loss: {d_loss['loss']:.4f} | g_loss: {g_loss['loss']:.4f} | gp: {d_loss['gp']:.4f}"
                )

        # Generate and display images at the end of each epoch
        generator.eval()
        with torch.no_grad():
            generated_images = generator(fixed_latent_vector, alpha, step)
            display(generated_images)
        generator.train()

    step += 1  # Progress to the next image size for the generator and discriminator
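
The fade-in schedule in the loop above can be examined without running any training. The following sketch replays only the `alpha` update rule. The dataset size of 3200 is a made-up value chosen to keep the arithmetic easy, and `simulate_alpha` is an illustrative helper, not part of the notebook.

```python
def simulate_alpha(num_epochs: int, dataset_size: int, batch_size: int,
                   start: float = 1e-5) -> list:
    """Replay the alpha update from the training loop; return alpha after each epoch."""
    alpha = start
    # drop_last=True in the loader means incomplete final batches are skipped
    batches_per_epoch = dataset_size // batch_size
    history = []
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            alpha += batch_size / ((num_epochs * 0.5) * dataset_size)
            alpha = min(alpha, 1)
        history.append(alpha)
    return history


# Hypothetical numbers: 30 epochs, 3200 images, batch size 32
history = simulate_alpha(num_epochs=30, dataset_size=3200, batch_size=32)
for epoch in (1, 14, 15, 30):
    print(f"epoch {epoch:2d}: alpha = {history[epoch - 1]:.4f}")
```

Under these assumptions `alpha` climbs linearly during the first half of the stage and stays at 1 for the remaining epochs, so the second half trains the fully faded-in network.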

### Training losses

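Plot the generator and discriminator losses recorded during training. The losses were logged once every `print_every` batches, so the x-axis counts logged samples rather than individual batches.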

In [None]:
fig, ax = plt.subplots()
losses = np.array(losses)
ax.plot(losses[:, 0], label="Discriminator", alpha=0.5)
ax.plot(losses[:, 1], label="Generator", alpha=0.5)
ax.set_title("Training Losses")
ax.legend()
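
Raw adversarial losses tend to be noisy, which can make trends hard to see in the plot. One option, sketched here as an assumption rather than something the notebook does, is to smooth each curve with a trailing moving average before plotting. `moving_average` is a hypothetical helper written in plain Python.

```python
def moving_average(values, window: int = 5):
    """Trailing moving average; early points average over the samples seen so far."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            # Drop the sample that just left the window
            running -= values[i - window]
        smoothed.append(running / min(i + 1, window))
    return smoothed


# Toy values standing in for one column of the recorded losses
example = [1.0, 2.0, 3.0, 4.0]
print(moving_average(example, window=2))
```

The smoothed list has the same length as its input, so it can be passed to the plotting call in place of a raw loss column.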