# Imports
Make sure the required packages are installed. You may need to `pip install` some of them (for example `torch-fidelity` and `tqdm`); if any are missing, the console will report it when you run the code.

In [None]:
# basics
import numpy as np
import matplotlib.pyplot as plt

# PyTorch and core machine learning libraries
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Libraries for data loading and pre-processing
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# for displaying images
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image

# Imports for quantative metrics calculations
from torchvision.utils import save_image
from torch_fidelity import calculate_metrics
import os

# Helper function for displaying images

In [None]:
def display_images(images, labels=None, title="", num_samples=20, cols=4):
    """Show up to ``num_samples`` images as one grid in a matplotlib figure.

    Args:
        images: batch of image tensors accepted by ``torchvision.utils.make_grid``
            (presumably ``(N, C, H, W)`` -- TODO confirm for other callers).
        labels: optional per-image labels; each is drawn as white text near the
            bottom-centre of its grid cell. Tensor labels are shown by value.
        title: figure title.
        num_samples: maximum number of images to show.
        cols: maximum number of images per grid row.
    """
    # ensure we don't exceed the number of available images
    images = images[:min(num_samples, len(images))]
    num_images = len(images)

    # tile the images; every image is normalized to [0, 1] independently
    grid = make_grid(images, nrow=cols, normalize=True, scale_each=True)
    grid_img = to_pil_image(grid)

    plt.figure(figsize=(12, 12))  # adjust the figure size as needed
    plt.imshow(grid_img, cmap="gray")
    plt.title(title, fontsize=20)
    plt.axis('off')

    if labels is not None and num_images > 0:
        # make_grid never uses more columns than there are images, so the label
        # layout must use the same effective column count to line up
        n_cols = min(cols, num_images)
        n_rows = (num_images + n_cols - 1) // n_cols
        cell_w = grid_img.width / n_cols
        cell_h = grid_img.height / n_rows
        for i, label in enumerate(labels[:num_images]):
            row, col = divmod(i, n_cols)
            if isinstance(label, torch.Tensor):
                label = label.item()  # show "3" rather than "tensor(3)"
            plt.text(
                (col + 0.5) * cell_w,      # horizontal centre of the cell
                (row + 1) * cell_h - 10,   # just above the cell's bottom edge
                str(label),
                horizontalalignment='center',
                fontsize=10,
                color='white',
                weight='bold',
            )
    plt.show()

In [None]:
# Device configuration: use the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    print(f"Device Type: {device} | Name: {torch.cuda.get_device_name(0)}")
else:
    print(f"Device Type: {device} CPU")

# Class-Conditioned Discriminator

In [None]:
class Discriminator(nn.Module):
    """Class-conditioned discriminator for 28x28 images.

    The class label is embedded into a 28*28 vector, reshaped into a single
    28x28 plane and stacked onto the image as one extra input channel. Two
    stride-2 convolutions reduce 28x28 -> 14x14 -> 7x7, then a linear layer and
    a sigmoid produce one real/fake probability per sample (shape ``(B, 1)``).
    """

    def __init__(self, in_channels=1, n_classes=10):
        super().__init__()
        # one learnable 28*28 "label plane" per class
        self.label_embedding = nn.Embedding(n_classes, 28 * 28)

        layers = [
            # image channels + 1 label channel -> 128 maps, 28x28 -> 14x14
            nn.Conv2d(in_channels + 1, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 * 7 * 7 features -> single probability
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        batch = img.size(0)
        # labels may be (B,) or (B, 1); either way reshape to (B, 1, 28, 28)
        label_plane = self.label_embedding(labels).view(batch, 1, 28, 28)
        # stack along the channel axis -> (B, in_channels + 1, 28, 28)
        return self.model(torch.cat((img, label_plane), dim=1))

# Class-conditioned Generator

In [None]:
class Generator(nn.Module):
    """Class-conditioned generator producing 1x28x28 images in [-1, 1].

    The latent vector is projected to a 128x7x7 feature map and the class label
    is embedded and projected to one 7x7 map. The two are stacked (129 channels)
    and upsampled 7x7 -> 14x14 -> 28x28 by transposed convolutions, followed by a
    7x7 convolution down to one channel and ``tanh``.
    """

    def __init__(self, latent_dim, n_classes=10):
        super().__init__()
        self.latent_dim = latent_dim

        # label -> 50-d embedding -> 49 values, later reshaped to one 7x7 map
        self.label_embedding = nn.Embedding(n_classes, 50)
        self.fc_label = nn.Linear(50, 7 * 7)

        # latent vector -> 128 * 7 * 7 features
        self.fc_latent = nn.Sequential(
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(128 * 7 * 7),
        )

        upsample = []
        # first stage sees 129 channels: 128 latent maps + 1 label map
        for in_ch in (129, 128):
            upsample += [
                nn.ConvTranspose2d(in_ch, 128, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.BatchNorm2d(128),
            ]
        # 28x28 is kept by kernel 7 / padding 3; tanh bounds pixels to [-1, 1]
        upsample += [nn.Conv2d(128, 1, kernel_size=7, padding=3), nn.Tanh()]
        self.conv_transpose_layers = nn.Sequential(*upsample)

    def forward(self, z, labels):
        # labels may be (B,) or (B, 1); result is a (B, 1, 7, 7) label map
        label_map = self.fc_label(self.label_embedding(labels)).view(-1, 1, 7, 7)
        # (B, 128, 7, 7) latent feature maps
        latent_maps = self.fc_latent(z).view(-1, 128, 7, 7)
        stacked = torch.cat((latent_maps, label_map), dim=1)
        return self.conv_transpose_layers(stacked)

In [None]:
# Samples latent noise together with random class labels for conditional GAN training.
def generate_latent_points(latent_dim, batch_size, n_classes=10, device='cpu'):
    """Sample a batch of latent vectors and matching random class labels.

    Returns:
        ``(z, labels)``: ``z`` is a ``(batch_size, latent_dim)`` standard-normal
        tensor; ``labels`` is a ``(batch_size,)`` integer tensor drawn uniformly
        from ``[0, n_classes)``. Both are created on ``device``.
    """
    noise = torch.randn((batch_size, latent_dim), device=device)
    random_labels = torch.randint(low=0, high=n_classes, size=(batch_size,), device=device)
    return noise, random_labels

# Adversarial training

In [None]:
from tqdm import tqdm

def _show_epoch_samples(generator, latent_dim, n_classes, device):
    """Display one generated image per class, sampled with the generator in eval mode."""
    was_training = generator.training
    # eval mode: BatchNorm uses its running statistics instead of the stats of
    # this tiny batch, and (being in train mode otherwise) would also update
    # its running stats from it even under no_grad
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n_classes, latent_dim, device=device)  # one latent point per class
        labels = torch.arange(0, n_classes, device=device)      # one label per class
        generated_images = generator(z, labels.unsqueeze(1))
    generator.train(was_training)
    display_images(generated_images, labels=None, title="", num_samples=n_classes, cols=n_classes)


def _plot_loss_curves(losses_g, losses_d):
    """Plot the per-epoch average generator and discriminator losses."""
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(losses_g, label="Generator")
    plt.plot(losses_d, label="Discriminator")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def train(generator, discriminator, dataset_loader, device, latent_dim, n_epochs=100, n_batch=128, n_classes=10,
          save_path='conditional_gan_300.pt'):
    """Adversarially train a class-conditioned GAN, then save the generator.

    After every epoch one sample per class is displayed; after training the
    generator's ``state_dict`` is written to ``save_path`` and the loss curves
    are plotted.

    Args:
        generator: conditional generator called as ``generator(z, labels)``.
        discriminator: conditional discriminator called as ``discriminator(img, labels)``
            returning probabilities (it is trained with ``BCELoss``).
        dataset_loader: iterable of ``(images, labels)`` batches.
        device: device both models are moved to and trained on.
        latent_dim: size of the generator's latent vector.
        n_epochs: number of passes over ``dataset_loader``.
        n_batch: expected batch size; batches of any other size are skipped.
        n_classes: number of class labels the models were built for.
        save_path: file the generator weights are saved to.

    Raises:
        ValueError: if an epoch yields no batch of exactly ``n_batch`` samples.
    """
    generator.to(device)
    discriminator.to(device)
    # separate Adam optimizers with the usual DCGAN settings (lr 2e-4, beta1 0.5)
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = torch.nn.BCELoss()

    # per-epoch average losses
    losses_g = []
    losses_d = []

    for epoch in range(n_epochs):
        # earlier sampling may have left the models in eval mode; BatchNorm and
        # Dropout must be in training mode for the updates below
        generator.train()
        discriminator.train()
        loss_g_accum = 0.0
        loss_d_accum = 0.0
        num_batches = 0

        loop = tqdm(dataset_loader, leave=True, desc=f"Epoch [{epoch+1}/{n_epochs}]")
        for imgs, labels in loop:
            current_batch_size = imgs.size(0)
            if current_batch_size != n_batch:  # skip incomplete batches
                continue
            num_batches += 1

            # BCE targets: 1 = real, 0 = fake
            real_labels = torch.ones(current_batch_size, 1, device=device)
            fake_labels = torch.zeros(current_batch_size, 1, device=device)

            # --- discriminator step ---
            optimizer_d.zero_grad()
            real_imgs = imgs.to(device)
            labels_real = labels.to(device).unsqueeze(1)
            real_loss = criterion(discriminator(real_imgs, labels_real), real_labels)

            z_input, gen_labels = generate_latent_points(latent_dim, current_batch_size, n_classes, device)
            gen_labels = gen_labels.unsqueeze(1)
            # detach so the discriminator loss does not backpropagate into the generator
            fake_imgs = generator(z_input, gen_labels).detach()
            fake_loss = criterion(discriminator(fake_imgs, gen_labels), fake_labels)

            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_d.step()
            d_value = d_loss.item()
            loss_d_accum += d_value

            # --- generator step: same z/labels, but aim for the "real" target ---
            optimizer_g.zero_grad()
            gen_imgs = generator(z_input, gen_labels)
            g_loss = criterion(discriminator(gen_imgs, gen_labels), real_labels)
            g_loss.backward()
            optimizer_g.step()
            g_value = g_loss.item()
            loss_g_accum += g_value

            loop.set_postfix(D_loss=d_value, G_loss=g_value)

        if num_batches == 0:
            raise ValueError(
                f"epoch {epoch + 1}: no batch of size n_batch={n_batch} was produced; "
                "make n_batch match the data loader's batch size"
            )
        losses_d.append(loss_d_accum / num_batches)
        losses_g.append(loss_g_accum / num_batches)

        _show_epoch_samples(generator, latent_dim, n_classes, device)

    torch.save(generator.state_dict(), save_path)
    _plot_loss_curves(losses_g, losses_d)

In [None]:
# Size of the generator's latent vector
latent_dim = 100
# FashionMNIST has 10 classes
n_classes = 10

# Build both networks and move them to the configured device
generator = Generator(latent_dim=latent_dim, n_classes=n_classes).to(device)
discriminator = Discriminator(n_classes=n_classes).to(device)

# ToTensor, then shift/scale with mean 0.5 and std 0.5 so real pixels share
# the generator's tanh output range of [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

train_dataset = datasets.FashionMNIST(
    root='../data', train=True, download=True, transform=transform
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

In [None]:
n_epochs = 300
# train() skips every batch whose size differs from n_batch, so derive it from
# the loader instead of repeating the constant and risking the two drifting apart
n_batch = train_loader.batch_size
train(generator, discriminator, train_loader, device, latent_dim,
      n_epochs=n_epochs, n_batch=n_batch, n_classes=n_classes)

In [None]:
# generate n_samples images of target_class
def generate_images_for_class(generator, latent_dim, target_class, n_samples, device):
    """Generate ``n_samples`` images all conditioned on ``target_class``.

    The generator is moved to ``device`` and temporarily switched to eval mode,
    so BatchNorm uses its running statistics and is not updated by this batch.
    Its previous train/eval mode is restored afterwards.

    Args:
        generator: conditional generator called as ``generator(z, labels)``.
        latent_dim: size of each latent vector.
        target_class: class index every sample is conditioned on.
        n_samples: number of images to generate.
        device: device for the latent vectors, labels and generator.

    Returns:
        ``generator(z, labels)`` for ``z`` of shape ``(n_samples, latent_dim)``
        and ``labels`` of shape ``(n_samples,)``, computed without gradients.
    """
    generator.to(device)
    z = torch.randn(n_samples, latent_dim, device=device)
    labels = torch.full((n_samples,), target_class, dtype=torch.long, device=device)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            return generator(z, labels)
    finally:
        generator.train(was_training)

# Targeted sampling

In [None]:
generator.eval()  # BatchNorm/Dropout in inference mode for sampling

# Example usage: generate n_samples images of a single class
target_class = 1  # class index to condition on (0 .. n_classes - 1)
n_samples = 10
generated_images = generate_images_for_class(generator, latent_dim, target_class, n_samples, device)

# show all samples in a single row
display_images(generated_images, labels=None, title="", num_samples=n_samples, cols=n_samples)

In [None]:
# Option to load in a saved model instead of the one trained above.
# The loaded weights must be assigned to `generator`, because that is the
# object the code below samples from.
# latent_dim = 100
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# generator = Generator(latent_dim)
# generator.load_state_dict(torch.load('model path', map_location=device))
# generator.eval()  # Set the generator to evaluation mode
# generator.to(device)

n_samples = 5  # images per class

# n_samples images for each class, concatenated in class order
all_generated = torch.cat(
    [
        generate_images_for_class(generator, latent_dim, target_class, n_samples, device)
        for target_class in range(n_classes)
    ],
    dim=0,
)
display_images(all_generated, labels=None, title="", num_samples=n_samples * n_classes, cols=10)

# Quantitative Metrics Calculation (FID)

In [None]:
def save_images(images, folder):
    """Write every image in ``images`` to ``folder`` as a grayscale PNG.

    ``images`` is iterated two levels deep. The outer index ``i`` and the inner
    index ``j`` form the file name ``class_{i}_image_{j}.png``. For a list of
    per-class batches this gives one file per image per class. For a single
    ``(N, 1, H, W)`` tensor, the outer level is the sample and the inner level
    is the channel. Files are then named ``class_{sample}_image_0.png``, and
    ``i`` is a sample index, not a class.

    Each image is squeezed to 2-D and passed to ``plt.imsave`` with
    ``cmap='gray'`` and no ``vmin``/``vmax``. Matplotlib therefore scales every
    image to its own value range. ``folder`` is created if it does not exist.
    """
    os.makedirs(folder, exist_ok=True)
    for outer_idx, group in enumerate(images):
        for inner_idx, tensor in enumerate(group):
            # host memory + drop singleton dims so imsave receives a 2-D array
            pixels = tensor.cpu().squeeze().numpy()
            filename = f'class_{outer_idx}_image_{inner_idx}.png'
            plt.imsave(os.path.join(folder, filename), pixels, cmap='gray')

def compute_fid(real_images_path, fake_images_path, cuda=None):
    """Compute the Frechet Inception Distance between two folders of images.

    Args:
        real_images_path: folder containing the reference images.
        fake_images_path: folder containing the generated images.
        cuda: whether torch-fidelity should run on the GPU. ``None`` (default)
            uses the GPU only when ``torch.cuda.is_available()``, so the metric
            also works on CPU-only machines.

    Returns:
        The value torch-fidelity reports under ``'frechet_inception_distance'``.
    """
    if cuda is None:
        cuda = torch.cuda.is_available()
    metrics = calculate_metrics(input1=real_images_path, input2=fake_images_path,
                                cuda=cuda, isc=False, fid=True, kid=False)
    return metrics['frechet_inception_distance']

In [None]:
# Option to load in a saved model instead of the one trained above.
# The loaded weights must be assigned to `generator`, because that is the
# object the code below samples from.
# latent_dim = 100
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# generator = Generator(latent_dim)
# generator.load_state_dict(torch.load('path to .pt file', map_location=device))
# generator.eval()  # Set the generator to evaluation mode
# generator.to(device)

n_samples = 5000  # images per class
n_classes = 10
# Generate in sub-batches: a single 5000-sample forward pass materialises
# several (5000, 128, 28, 28) float32 activations (~2 GB each), which can
# exhaust GPU memory. Chunks keep the peak bounded.
gen_batch_size = 500

all_generated_images = []
for target_class in range(n_classes):
    for start in range(0, n_samples, gen_batch_size):
        chunk = min(gen_batch_size, n_samples - start)
        images = generate_images_for_class(generator, latent_dim, target_class, chunk, device)
        all_generated_images.append(images.cpu())  # collect on the host, not the GPU

# (n_classes * n_samples, 1, 28, 28), grouped by class
all_generated_images = torch.cat(all_generated_images)

# Save the images
save_images(all_generated_images, "generated_images_cdcgan")

In [None]:
# Load the FashionMNIST training split (ToTensor only, no Normalize).
dataset = datasets.FashionMNIST(root="../data", train=True, transform=ToTensor(), download=True)

n_samples = 5000  # images per class
n_classes = 10

# Read every sample's class once from `targets`, instead of decoding and
# transforming the whole dataset again for each class.
targets = np.asarray(dataset.targets)

selected_images = []
for class_id in range(n_classes):
    class_indices = np.flatnonzero(targets == class_id)
    # n_samples distinct indices of this class, chosen at random
    selected_indices = np.random.choice(class_indices, n_samples, replace=False)
    selected_images.extend(dataset[idx][0] for idx in selected_indices)

# (n_classes * n_samples, 1, 28, 28), grouped by class
real_images = torch.stack(selected_images)

# Save the selected images. They are in [0, 1], while generated ones are in
# [-1, 1]; save_images does not fix a value range when writing the PNGs.
save_images(real_images, "real_images_cdcgan")

In [None]:
# FID between the saved real and generated image folders (lower is better)
fid_score = compute_fid(real_images_path="real_images_cdcgan",
                        fake_images_path="generated_images_cdcgan")
print(f"FID Score: {fid_score}")