# Imports
Make sure the required packages are installed. You may need to `pip install` some of them; when you run the code, the error output will tell you which packages are missing.

In [None]:
# basics
import numpy as np
import matplotlib.pyplot as plt

# PyTorch and core machine learning libraries
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Libraries for data loading and pre-processing
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# for displaying images
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image

# libraries used for quantitative metrics calculation
from torchvision.utils import save_image
from torch_fidelity import calculate_metrics
import os

# Helper function for displaying images

In [None]:
def display_images(images, labels=None, title="", num_samples=20, cols=4):
    """Show a batch of images as one grid figure.

    Args:
        images: Batch of image tensors; only the first ``num_samples`` are shown.
        labels: Optional sequence with one caption per image. Each caption is
            drawn in white near the bottom of its grid cell.
        title: Title placed above the grid.
        num_samples: Maximum number of images to display.
        cols: Maximum number of images per grid row.
    """
    # ensure we don't exceed the number of available images
    images = images[:min(num_samples, len(images))]

    # tile the images into one tensor; scale_each=True rescales every image
    # independently before it is placed in the grid
    grid = make_grid(images, nrow=cols, normalize=True, scale_each=True)

    # convert the grid to a PIL image
    grid_img = to_pil_image(grid)

    # plot
    plt.figure(figsize=(12, 12))  # You can adjust the figure size as needed
    plt.imshow(grid_img, cmap="gray")
    plt.title(title, fontsize=20)
    plt.axis('off')

    # if labels are provided, draw one caption per grid cell
    if labels is not None:
        num_images = len(images)
        # make_grid never lays out more columns than there are images, so the
        # real column count can be smaller than ``cols``
        grid_cols = max(1, min(cols, num_images))
        rows = (num_images + grid_cols - 1) // grid_cols  # number of grid rows
        cell_width = grid_img.width / grid_cols
        for i, label in enumerate(labels[:num_images]):
            plt.text(
                # the text is centre-aligned, so anchor it at the horizontal
                # middle of the cell rather than at the cell's left edge
                (i % grid_cols + 0.5) * cell_width,
                (i // grid_cols + 1) * grid_img.height / rows - 10,  # just above the cell's bottom edge
                label,
                horizontalalignment='center',
                fontsize=10,
                color='white',
                weight='bold'
            )
    plt.show()

In [None]:
# Device configuration: use CUDA when PyTorch can see a GPU, otherwise the CPU.
# The GPU name that is printed is the one reported for CUDA device index 0.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Device Type: {device} | Name: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print(f"Device Type: {device} CPU")

# Unconditional Discriminator

In [None]:
# discriminator model to determine if an image is real or fake
class Discriminator(nn.Module):
    """Convolutional classifier that scores images as real or fake.

    Two stride-2 convolutions halve the spatial size twice; the result is
    flattened, passed through dropout and reduced to a single sigmoid output
    per image. The linear layer expects ``128 * 7 * 7`` features, which is what
    the two convolutions produce for 28x28 inputs (28 -> 14 -> 7).

    Args:
        in_channels: Number of channels in the input images.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # the layers are listed in execution order and wrapped in a single
        # Sequential stored as ``self.model``
        layers = [
            # downsampling stage 1
            nn.Conv2d(in_channels, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # downsampling stage 2
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # classifier head: flatten -> dropout -> one sigmoid unit
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score of every image in the batch ``x``."""
        scores = self.model(x)
        return scores

# Unconditional Generator

In [None]:
class Generator(nn.Module):
    """Map latent noise vectors to single-channel 28x28 images.

    A linear layer expands the noise to a 128x7x7 feature map, two transposed
    convolutions double the spatial size each (7 -> 14 -> 28), and a final
    convolution followed by ``Tanh`` produces one output channel in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector; also stored as
            ``self.latent_dim``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        def upsample_stage():
            # transposed convolution that doubles height and width, then
            # activation and batch normalisation
            return [
                nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                nn.BatchNorm2d(128),
            ]

        base_features = 128 * 7 * 7
        stages = [
            # project the noise onto a 7x7 feature map with 128 channels
            nn.Linear(latent_dim, base_features),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(base_features),
            nn.Unflatten(1, (128, 7, 7)),
            # 7x7 -> 14x14
            *upsample_stage(),
            # 14x14 -> 28x28
            *upsample_stage(),
            # collapse to a single image channel
            nn.Conv2d(128, 1, 7, padding=3, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Return the images generated from the noise batch ``z``."""
        images = self.model(z)
        return images

In [None]:
def generate_latent_points(latent_dim, n_samples, device='cpu'):
    """Draw latent noise vectors to feed into the GAN for image generation.

    This helper is not strictly necessary; it exists to make the calling code
    easier to read.

    Args:
        latent_dim: Length of each noise vector.
        n_samples: Number of noise vectors to create.
        device: Device on which the tensor is allocated.

    Returns:
        Tensor of shape ``(n_samples, latent_dim)`` filled with standard-normal
        samples.
    """
    shape = (n_samples, latent_dim)
    return torch.randn(shape, device=device)

# Adversarial training

In [None]:
from tqdm import tqdm

def train(generator, discriminator, dataloader, latent_dim, device=device, n_epochs=100, n_batch=128):
    """Adversarially train ``generator`` against ``discriminator``.

    For every full batch the discriminator is updated first (real images
    labelled 1, generated images labelled 0, the two BCE losses averaged), then
    the generator is updated so that its images are classified as real. After
    each epoch ten fresh samples are displayed. When training finishes the
    generator weights are written to ``dcgan.pt`` and the per-epoch mean losses
    are plotted.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Module mapping images to a ``(batch, 1)`` probability.
        dataloader: Iterable yielding ``(images, _)`` pairs; the second element
            is ignored.
        latent_dim: Length of the noise vectors fed to the generator.
        device: Device both models and all tensors are moved to.
        n_epochs: Number of passes over ``dataloader``.
        n_batch: Expected batch size. Batches of any other size are skipped, so
            this must match the batch size ``dataloader`` produces.

    Raises:
        ValueError: If an epoch contains no batch of size ``n_batch``.
    """
    generator.to(device)
    discriminator.to(device)
    # initialize separate optimizers for the generator and discriminator networks
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()

    # per-epoch mean losses
    losses_g = []
    losses_d = []

    # only batches of exactly n_batch images are processed, so the target
    # tensors never change and can be built once
    real_labels = torch.ones(n_batch, 1, device=device)
    fake_labels = torch.zeros(n_batch, 1, device=device)

    # loop through epochs
    for epoch in range(n_epochs):
        loss_g_accum = 0.0
        loss_d_accum = 0.0
        num_batches = 0

        # wrap dataloader with tqdm for a progress bar
        loop = tqdm(dataloader, leave=True, desc=f"Epoch [{epoch+1}/{n_epochs}]")
        for imgs, _ in loop:
            if imgs.size(0) != n_batch:  # Skip incomplete batches
                continue

            num_batches += 1

            # train Discriminator
            optimizer_d.zero_grad()
            real_imgs = imgs.to(device)
            real_loss = criterion(discriminator(real_imgs), real_labels)
            noise = torch.randn(n_batch, latent_dim, device=device)
            # one generator forward pass serves both updates: the detached copy
            # feeds the discriminator step, the attached tensor the generator
            # step (the generator weights do not change in between)
            gen_imgs = generator(noise)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake_labels)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_d.step()

            # train Generator against the freshly updated discriminator
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(gen_imgs), real_labels)
            g_loss.backward()
            optimizer_g.step()

            # read each loss value once and reuse it below
            d_loss_value = d_loss.item()
            g_loss_value = g_loss.item()
            loss_d_accum += d_loss_value
            loss_g_accum += g_loss_value

            # update the progress bar
            loop.set_postfix(D_loss=d_loss_value, G_loss=g_loss_value)

        # without this guard the averages below would fail with a bare
        # ZeroDivisionError that hides the real cause
        if num_batches == 0:
            raise ValueError(
                f"No batch of size {n_batch} was produced by the dataloader; "
                "n_batch must match the dataloader's batch size."
            )

        # calculate average losses for the current epoch and append to lists
        losses_d.append(loss_d_accum / num_batches)
        losses_g.append(loss_g_accum / num_batches)

        # generate and display images at the end of the epoch
        # NOTE(review): the generator's mode is not changed here, so these
        # previews use whatever mode (train/eval) the module is currently in
        with torch.no_grad():
            test_noise = torch.randn(10, latent_dim, device=device)
            test_imgs = generator(test_noise)
            display_images(test_imgs, labels=None, title="", num_samples=10, cols=10)

    # save the generator model
    torch.save(generator.state_dict(), 'dcgan.pt')

    # plot the losses after training
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(losses_g, label="Generator")
    plt.plot(losses_d, label="Discriminator")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [None]:
# length of the noise vectors the generator consumes
latent_dim = 100

# build the two networks (discriminator first, then generator)
discriminator = Discriminator()
generator = Generator(latent_dim)

# pre-processing pipeline: convert each image to a tensor, then shift and
# scale it with mean 0.5 and std 0.5 (normalizes to [-1, 1])
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# FashionMNIST training split, downloaded into ../data when missing, served
# in shuffled batches of 128
train_dataset = datasets.FashionMNIST(
    root='../data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

In [None]:
# run adversarial training on the training loader for 300 epochs
n_epochs = 300
train(
    generator=generator,
    discriminator=discriminator,
    dataloader=train_loader,
    latent_dim=latent_dim,
    n_epochs=n_epochs,
)

# Unconditional sampling

In [None]:
# Optionally restore a saved generator instead of using the one trained in this
# session ('model path' is a placeholder for the saved state-dict file). The
# restored model is bound to `generator` because that is the name used below.
# latent_dim = 100
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# generator = Generator(latent_dim)
# generator.load_state_dict(torch.load('model path'))
# generator.eval()  # Set the generator to evaluation mode
# generator.to(device)

# Generate images
n_samples = 50  # number of samples to generate
# create latent noise vectors; the size is read from the generator itself so it
# always matches the model, whatever latent dimension it was built with
latent_points = generate_latent_points(generator.latent_dim, n_samples, device=device)
with torch.no_grad():
    generated_images = generator(latent_points)  # generate images

# display generated images
display_images(generated_images, labels=None, title="Unconditional DCGAN Generations", num_samples=n_samples, cols=10)

# Quantitative Metrics Calculation (FID)

In [None]:
def save_images(images, folder="saved_images"):
    """Write every image of a tensor batch to ``folder`` as a PNG file.

    Files are named ``image_<index>.png`` using the image's position in the
    batch; the folder is created when it does not exist yet.

    Each image is squeezed (dropping size-1 dimensions such as a single colour
    channel) and written with the ``gray`` colormap. No ``vmin``/``vmax`` is
    passed to ``plt.imsave``, so per the matplotlib documentation the colour
    limits come from each array's own minimum and maximum.

    Args:
        images: Batch of images as a PyTorch tensor.
        folder: Destination directory.

    Raises:
        ValueError: If ``images`` is not a PyTorch tensor.
    """
    if not isinstance(images, torch.Tensor):
        raise ValueError("Images should be a PyTorch tensor")

    # drop autograd history and move to host memory before converting to NumPy
    cpu_batch = images.detach().cpu()

    os.makedirs(folder, exist_ok=True)
    for index, single_image in enumerate(cpu_batch):
        pixels = single_image.squeeze().numpy()
        target_path = os.path.join(folder, f'image_{index}.png')
        plt.imsave(target_path, pixels, cmap='gray')

def compute_fid(real_images_path, fake_images_path):
    """Compute the Frechet Inception Distance between two image folders.

    Args:
        real_images_path: Folder holding the reference (real) images.
        fake_images_path: Folder holding the generated images.

    Returns:
        The value stored under ``'frechet_inception_distance'`` in the metrics
        returned by ``torch_fidelity.calculate_metrics``.
    """
    metrics = calculate_metrics(
        input1=real_images_path,
        input2=fake_images_path,
        # use the GPU only when one is present, so the function also works on
        # CPU-only machines like the rest of this notebook
        cuda=torch.cuda.is_available(),
        isc=False,
        fid=True,
        kid=False,
        # images may sit in nested sub-folders of the given paths
        samples_find_deep=True,
    )
    return metrics['frechet_inception_distance']

In [None]:
# Optionally restore a saved generator instead of using the one trained in this
# session ('model path' is a placeholder for the saved state-dict file). The
# restored model is bound to `generator` because that is the name used below.
# latent_dim = 100
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# generator = Generator(latent_dim)
# generator.load_state_dict(torch.load('model path'))
# generator.eval()  # Set the generator to evaluation mode
# generator.to(device)

n_loops = 5        # number of independent image sets to generate
n_samples = 10000  # images per set

for loop in range(n_loops):
    # draw n_samples noise vectors; the size is read from the generator itself
    # so it always matches the model, whatever latent dimension it was built with
    latent_points = generate_latent_points(generator.latent_dim, n_samples, device=device)
    # generate all images of this set in a single forward pass
    with torch.no_grad():
        generated_images = generator(latent_points)

    # save each set in its own sub-folder: generated_images_dcgan/set_<k>
    loop_folder = os.path.join("generated_images_dcgan", f'set_{loop + 1}')
    save_images(generated_images, loop_folder)

In [None]:
# FashionMNIST training split, each image converted with ToTensor
dataset = datasets.FashionMNIST(
    root="../data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# take the first n_samples * n_loops images (labels are dropped) and stack
# them into one tensor
real_images = torch.stack(
    [image for image, _ in (dataset[idx] for idx in range(n_samples * n_loops))]
)

# write those real images to disk for the FID comparison
save_images(real_images, folder="real_images_dcgan")

In [None]:
# score the generated image folder against the real image folder and report it
fid_score = compute_fid(
    real_images_path="real_images_dcgan",
    fake_images_path="generated_images_dcgan",
)
print(f"FID Score: {fid_score}")