In [None]:
import os
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import to_pil_image
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import torch
from torch.nn.functional import one_hot, leaky_relu, interpolate
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import random
import math
from pathlib import Path
from typing import Union, List
from tqdm.notebook import trange
from torchinfo import summary

## Configuration & Hyperparameters

In [None]:
# Run configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Stage toggles. TRAIN_CGAN selects between training the GAN and loading a
# saved generator; the other two are presumably consumed by later cells that
# are not part of this section.
TRAIN_CGAN = True
TRAIN_RESNET = True
TRAIN_AUGMENTED_RESNET = True

RESNET = "resnet"
CONV = "conv"
architecture = os.getenv('ARCH', RESNET)  # resnet or conv
# Fail fast on a typo: every hyperparameter below branches on either RESNET or
# CONV, so an unknown value would silently mix settings of both architectures.
if architecture not in (RESNET, CONV):
    raise ValueError(
        f"ARCH must be '{RESNET}' or '{CONV}', got '{architecture}'")
# Optional run tag (TAG env var) appended to every output path so that several
# experiments with the same architecture do not overwrite each other.
tag = os.getenv("TAG", None)
tag = f"_{tag}" if tag else ""
# num of workers for the dataloader
num_workers = 4 if architecture == RESNET else 8
generated_images_dir = f"CGAN/generated_images/{architecture}{tag}"
Path(generated_images_dir).mkdir(parents=True, exist_ok=True)
# parents=True so this line does not depend on the mkdir above having created "CGAN"
Path("CGAN/models").mkdir(parents=True, exist_ok=True)
gan_model_path = f"CGAN/models/generator_{architecture}{tag}.pth"


# Hyperparameters
LATENT_DIM = 100  # Dimensions of noise Z
NUM_CLASSES = 20
IMG_SIZE = 128
CHANNELS = 3  # RGB
# Learning rates for the generator (g_lr) and discriminator (d_lr)
g_lr = 0.0005 if architecture == CONV else 0.0003
d_lr = 0.00051 if architecture == CONV else 0.00031
b1, b2 = 0.5, 0.99  # Adam betas
N_EPOCHS = 1000
BATCH_SIZE = 64 if architecture == RESNET else 128
# Number of training iterations for discriminator for each iteration of generator
D_TURNS = 5 if architecture == RESNET else 3
lambda_gp = 10  # Hyper-parameter for gradient penalty

In [None]:
# TensorBoard writer shared by the logging helpers below. Event files are
# written to a per-run directory named after the architecture and optional tag.
writer = SummaryWriter(log_dir=f'CGAN/tensorboard/{architecture}{tag}')


def log_losses_to_tensorboard(epoch, g_loss, d_loss):
    """Write the generator and discriminator losses for `epoch` to TensorBoard.

    Args:
        epoch: Value used as the TensorBoard global step.
        g_loss: Generator loss (scalar).
        d_loss: Discriminator loss (scalar).
    """
    scalars = (
        ('Loss/Generator', g_loss),
        ('Loss/Discriminator', d_loss),
    )
    for series_name, loss_value in scalars:
        writer.add_scalar(series_name, loss_value, epoch)


def log_gradients_to_tensorboard(model, epoch, model_name):
    """Log per-parameter gradient L2 norms and their combined norm.

    Parameters whose `.grad` is None are skipped. The combined value is the
    square root of the sum of the squared per-parameter norms.

    Args:
        model: Module whose `named_parameters()` are inspected.
        epoch: Value used as the TensorBoard global step.
        model_name: Label inserted into the scalar tag (`Gradients/<model_name>/...`).
    """
    squared_sum = 0
    for param_name, parameter in model.named_parameters():
        grad = parameter.grad
        if grad is None:
            continue
        grad_norm = grad.norm(2).item()
        squared_sum += grad_norm ** 2
        writer.add_scalar(
            f'Gradients/{model_name}/{param_name}', grad_norm, epoch)
    overall_norm = squared_sum ** 0.5
    writer.add_scalar(f'Gradients/{model_name}/total_norm', overall_norm, epoch)

In [None]:
# Resolve the class names used by the rest of the notebook. The selection is
# cached in `filename` so that every run works with the same classes; it is only
# sampled at random the first time, when the cache file does not exist yet.
filename = "data/classes.txt"

if os.path.exists(filename):
    # Load the class labels from the text file, ignoring blank lines so that a
    # stray empty line never turns into an empty class name
    with open(filename, 'r') as file:
        classes = [line.strip() for line in file if line.strip()]
else:
    with open('data/Animals_data/name of the animals.txt', 'r') as animals:
        animal_names = [name.strip() for name in animals if name.strip()]
    classes = random.sample(animal_names, NUM_CLASSES)
    with open(filename, 'w') as file:
        file.write('\n'.join(classes) + '\n')


classes

## Create Dataset

In [None]:

class AnimalDataset(Dataset):
    """Image-folder dataset with one sub-directory of `root_dir` per class.

    The label of an image is the index of its class name in `classes`.

    Args:
        root_dir: Directory containing one sub-directory per class name.
        classes: List of class names (sub-directory names).
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, root_dir, classes, transform=None):
        self.root_dir = root_dir  # Root directory containing the image data
        self.transform = transform  # Image transformations to be applied
        self.classes = classes  # List of class names
        self.image_paths = []  # List to store paths of all images
        self.labels = []  # List to store corresponding labels
        # In-memory RGB images filled by preload(); index-aligned with image_paths
        self.preloaded_images = []

        # Iterate through each class and collect image paths and labels
        for class_id, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir order is arbitrary; sort so that indices are reproducible
            for img_name in sorted(os.listdir(class_dir)):
                self.image_paths.append(os.path.join(class_dir, img_name))
                self.labels.append(class_id)

    def __len__(self):
        return len(self.image_paths)  # Return the total number of images in the dataset

    def preload(self):
        """Load every image currently listed in `image_paths` into memory.

        Useful for small datasets. Images added afterwards (for example by
        `augment_with_gan`) are not cached and are read from disk on access.
        """
        images = []
        for path in self.image_paths:
            # convert() returns an independent in-memory copy, so the file can be closed
            with Image.open(path) as img:
                images.append(img.convert('RGB'))
        self.preloaded_images = images

    def __getitem__(self, idx):
        """Return `(image, label)` for index `idx`, applying `transform` if set."""
        if 0 <= idx < len(self.preloaded_images):
            image = self.preloaded_images[idx]  # Served from the preload cache
        else:
            # Open the image, convert to RGB and release the file handle
            with Image.open(self.image_paths[idx]) as img:
                image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)  # Apply transformations if specified
        label = self.labels[idx]  # Get the corresponding label

        return image, label

    def augment_with_gan(self, generator, samples_per_class):
        """Extend the dataset with `samples_per_class` generated images per class.

        The images are written to `CGAN/augmentation_data/<class name>/<i>.jpg`
        and their paths and labels are appended to this dataset. The generator
        is switched to evaluation mode and left in that mode.
        """
        generator.eval()  # Set generator to evaluation mode

        for class_idx, class_name in enumerate(self.classes):
            class_path = f"CGAN/augmentation_data/{class_name}"
            Path(class_path).mkdir(parents=True, exist_ok=True)  # Create directory for augmented images

            # Generate noise and class labels for the generator
            noise = torch.randn(samples_per_class, LATENT_DIM).to(device)
            # Explicit integer dtype: the labels are used as class indices
            class_labels = torch.full(
                (samples_per_class,), class_idx, dtype=torch.long).to(device)

            with torch.no_grad():
                fake_samples = generator(noise, class_labels)  # Generate fake samples
                fake_samples = (fake_samples + 1) / 2.0  # Denormalize the generated images

            # Save generated images and update dataset
            for i, fake_sample in enumerate(fake_samples):
                fake_sample = to_pil_image(fake_sample)
                path = f"{class_path}/{i}.jpg"
                fake_sample.save(path)
                self.image_paths.append(path)
                self.labels.append(class_idx)


# Preprocessing pipeline: resize, crop the centre to a square of
# IMG_SIZE x IMG_SIZE, convert to a tensor and normalize each channel
# with mean 0.5 and std 0.5 (i.e. to the [-1, 1] range).
img_transform = transforms.Compose([
    transforms.Resize(size=IMG_SIZE),
    transforms.CenterCrop(size=IMG_SIZE),  # To preserve aspect ratio of landscape oriented images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Dataset over the selected classes
dataset = AnimalDataset(
    root_dir='data/Animals_data/animals/animals/',
    classes=classes,
    transform=img_transform,
)

# Shuffled mini-batch loader used by the training loop
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
)

## Create Models
- There are two models we need:
    1. A _generator_ that takes in random noise (z) and a class label as input and generates a sample image
    2. A _discriminator_ that takes in an image and class label as input and outputs a score signifying "realness" of the image.
- We perform experiments with two different architecture variants:
    1. **Convolution Based (DCGAN)**: Less complex and memory efficient
    2. **Resnet Based**: More complex but requires a lot more memory

### Generator

- The label is converted to an embedding/one-hot vector and concatenated with z
- The concatenated input is then passed to a linear layer which outputs a tensor of shape `(512, init_size, init_size)` where `init_size X init_size` is the initial size of 2D image
- The image is then passed through multiple convolutional blocks via upsampling till we get an image with 3 channels (RGB)
- The last layer is a tanh activation layer so that our RGB values are normalized to [-1, 1], which is the same normalization as the real images
- Instead of regular batch norm, we use conditional batch norm, which learns a different set of normalization parameters for each class

In [None]:
def init_weights(m):
    """Weight initializer intended for use with `Module.apply`.

    - Conv2d / ConvTranspose2d: weights drawn from N(0.0, 0.02); bias untouched.
    - BatchNorm2d / InstanceNorm2d with affine parameters: weights drawn from
      N(1.0, 0.02) and biases set to 0.
    - Any other module is left unchanged.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d)

    if isinstance(m, conv_types):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    # Norm layers built with affine=False have no weight and are skipped
    if isinstance(m, norm_types) and m.weight is not None:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


# Taken from: https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
class ConditionalBatchNorm2d(nn.Module):
    """Batch normalization whose scale and shift are looked up per class label.

    A non-affine BatchNorm2d normalizes the input; an embedding table then
    provides, for every label, `num_features` scale values followed by
    `num_features` shift values.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        # Normalization only; the affine part comes from the embedding below
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # One row per class: [scale (num_features) | shift (num_features)]
        self.embed = nn.Embedding(num_classes, 2 * num_features)
        scale_init = self.embed.weight.data[:, :num_features]
        shift_init = self.embed.weight.data[:, num_features:]
        scale_init.normal_(1, 0.02)  # scales start close to 1
        shift_init.zero_()  # shifts start at 0

    def forward(self, x, labels):
        """Normalize `x` and apply the scale/shift selected by `labels`."""
        normalized = self.bn(x)
        scale, shift = torch.chunk(self.embed(labels), 2, dim=1)
        # Reshape to (N, C, 1, 1) so the parameters broadcast over the spatial dims
        scale = scale.reshape(-1, self.num_features, 1, 1)
        shift = shift.reshape(-1, self.num_features, 1, 1)
        return scale * normalized + shift


class ConditionalGenerator(nn.Module):
    """Convolutional conditional generator.

    The noise vector is concatenated with the one-hot label, projected by a
    linear layer to a `512 x init_size x init_size` tensor and then upsampled
    five times by a factor of 2 (four conv stages plus the output layer). The
    result has `channels` channels and is squashed by tanh.

    Args:
        latent_dim: Size of the noise vector z.
        num_classes: Number of class labels.
        channels: Number of output image channels.
    """

    def __init__(self, latent_dim, num_classes, channels):
        super().__init__()
        num_blocks = 5  # no of upsamples
        stage_widths = [512, 256, 128, 64, 32]  # channel count after each stage
        kernel_sizes = [5, 5, 3, 3]  # kernel of each intermediate convolution
        num_stages = num_blocks - 1  # the last upsample lives in out_layer

        # Spatial size of the feature map produced by the linear projection
        self.init_size = IMG_SIZE // (2 ** num_blocks)
        self.num_classes = num_classes

        # Projects [z, one-hot label] to the initial feature map
        projection = nn.Linear(
            latent_dim + num_classes,
            512 * self.init_size * self.init_size,
        )
        self.l1 = nn.Sequential(projection)

        # Conditional batch norm applied right after the projection
        self.cbn1 = ConditionalBatchNorm2d(512, num_classes)

        # Per-stage layers: conditional batch norm, activation, upsampling, convolution
        self.cbns = nn.ModuleList(
            [ConditionalBatchNorm2d(width, num_classes) for width in stage_widths[1:]]
        )
        self.relus = nn.ModuleList(
            [nn.LeakyReLU(0.2, inplace=True) for _ in range(num_stages)]
        )
        self.upsamplers = nn.ModuleList(
            [nn.Upsample(scale_factor=2) for _ in range(num_stages)]
        )
        self.convs = nn.ModuleList([
            nn.Conv2d(c_in, c_out, kernel, stride=1, padding=kernel // 2)
            for c_in, c_out, kernel in zip(stage_widths[:-1], stage_widths[1:], kernel_sizes)
        ])

        # Final upsampling, projection to image channels, and tanh
        self.out_layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Generate a batch of images from noise `z` and integer class `labels`."""
        label_embeddings = one_hot(labels, self.num_classes)
        features = self.l1(torch.cat([z, label_embeddings], dim=1))
        features = features.view(features.size(0), 512, self.init_size, self.init_size)
        features = self.cbn1(features, labels)

        # Each stage: upsample -> convolution -> conditional batch norm -> LeakyReLU
        stages = zip(self.upsamplers, self.convs, self.cbns, self.relus)
        for upsample, conv, cond_bn, activation in stages:
            features = activation(cond_bn(conv(upsample(features)), labels))

        return self.out_layer(features)

The ResUpBlock class is a building block for the ResNet-style generator in a Conditional GAN (Generative Adversarial Network). It implements a residual upsampling block that combines convolutional layers with conditional batch normalization and skip connections.

In [1]:
class ResUpBlock(nn.Module):
    """Residual block that doubles the spatial resolution.

    Main path: conditional BN -> LeakyReLU -> 2x upsample -> 3x3 conv ->
    conditional BN -> LeakyReLU -> 3x3 conv. Skip path: 2x upsample -> 1x1 conv.
    The two paths are summed.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        num_classes: Number of class labels for the conditional batch norms.
            Defaults to the notebook-level NUM_CLASSES when omitted.
    """

    def __init__(self, in_channels, out_channels, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.cond_bns = nn.ModuleList([
            ConditionalBatchNorm2d(in_channels, num_classes),
            ConditionalBatchNorm2d(out_channels, num_classes)
        ])

        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
        ])

        self.skip = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
        )
        self.initialize()

    def initialize(self):
        """Xavier-initialize the convolutions (gain sqrt(2) on the main path) and zero their biases."""
        for m in self.convs.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_uniform_(m.weight, math.sqrt(2))
                torch.nn.init.zeros_(m.bias)
        for m in self.skip.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.zeros_(m.bias)

    def forward(self, x, labels):
        """Return the upsampled feature map for input `x` conditioned on `labels`."""
        out = self.cond_bns[0](x, labels)
        out = leaky_relu(out, 0.2)
        out = interpolate(out, scale_factor=2)
        out = self.convs[0](out)
        out = self.cond_bns[1](out, labels)
        out = leaky_relu(out, 0.2)
        out = self.convs[1](out)
        return out + self.skip(x)


class ResGenerator(nn.Module):
    """ResNet-style conditional generator built from four `ResUpBlock`s.

    The noise vector is concatenated with the one-hot label, projected to a
    `512 x init_size x init_size` tensor, passed through the residual blocks and
    finally mapped to `channels` channels followed by tanh.

    Args:
        z_dim: Size of the noise vector z.
        num_classes: Number of class labels.
        channels: Number of output image channels.
    """

    def __init__(self, z_dim, num_classes, channels):
        super().__init__()
        self.z_dim = z_dim
        self.num_classes = num_classes
        num_blocks = 4  # no of upsamples (one per residual block)
        self.init_size = IMG_SIZE // (2**num_blocks)

        self.l1 = nn.Linear(z_dim + num_classes,
                            self.init_size * self.init_size * 512)

        # Pass num_classes down so the blocks agree with this generator's
        # one-hot width instead of silently relying on the global NUM_CLASSES
        self.blocks = nn.ModuleList([
            ResUpBlock(512, 256, num_classes),
            ResUpBlock(256, 128, num_classes),
            ResUpBlock(128, 64, num_classes),
            ResUpBlock(64, 32, num_classes),
        ]
        )
        self.cbn = ConditionalBatchNorm2d(32, num_classes)

        self.output = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )
        self.initialize()

    def initialize(self):
        """Xavier-initialize the input projection and the output convolution; zero their biases."""
        torch.nn.init.xavier_uniform_(self.l1.weight)
        torch.nn.init.zeros_(self.l1.bias)
        for m in self.output.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.zeros_(m.bias)

    def forward(self, z, labels):
        """Generate a batch of images from noise `z` and integer class `labels`."""
        label_embeddings = one_hot(labels, self.num_classes)
        z = torch.cat([z, label_embeddings], dim=1)
        z = self.l1(z)
        z = z.view(-1, 512, self.init_size, self.init_size)
        for block in self.blocks:
            z = block(z, labels)
        z = self.cbn(z, labels)
        return self.output(z)

NameError: name 'nn' is not defined

### Discriminator

- The image is first downsampled through multiple convolutional blocks until we get a set of image features in a single channel
- The features are then concatenated with the label embedding
- The concatenation is then passed through a linear layer which outputs a scalar value
- For the resnet variant, instead of concatenating the feature maps with the label embedding, we add a projection of the features on the label embedding to the final scalar output. According to [Miyato et al. 2018](https://arxiv.org/abs/1802.05637), this approach is a better way to encode class information than concatenation since it allows a linear interaction between the class label and the image features. Projection works well with the resnet variant because the pooling layers aggregate spatial information into a single vector. In contrast, the convolution architecture retains spatial information, and flattening the structure for projection would lead to a loss of spatial information.
- The label embeddings are generated from a static embedding layer where the class embeddings are orthogonal to each other. This is done to maintain class separability in the image features.

## Training C-GAN
- We train using Wasserstein's distance as our loss function
- The discriminator is updated `D_TURNS` times for every generator update so that the generator does not overpower it
- We also add a [gradient penalty](https://arxiv.org/abs/1704.00028) to the discriminator loss to enforce the Lipschitz constraint required by the Wasserstein objective and to stabilize training.
- Our dataset is quite limited, so we apply augmentations to both the real and fake images for better generalization. Since gradients must flow back through the augmentations when updating the generator, the augmentations need to be differentiable. This idea comes directly from [Differentiable Augmentation for Data-Efficient GAN Training](https://arxiv.org/abs/2006.10738).

In [None]:
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# Code taken from: https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_pytorch.py

def DiffAugment(x, policy='color,translation,cutout', channels_first=True):
    """Apply the augmentations named in `policy` to the image batch `x`.

    Args:
        x: Batch of images, channels-first unless `channels_first` is False.
        policy: Comma-separated keys of AUGMENT_FNS, applied in the given
            order. An empty or None policy returns `x` unchanged.
        channels_first: If False, `x` is permuted to channels-first for the
            augmentations and permuted back afterwards.

    Returns:
        The augmented batch as a contiguous tensor (or `x` itself when no
        policy is given).
    """
    if not policy:
        return x

    channels_last = not channels_first
    if channels_last:
        x = x.permute(0, 3, 1, 2)

    for group in policy.split(','):
        for augment in AUGMENT_FNS[group]:
            x = augment(x)

    if channels_last:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) per sample to every element of that sample."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a random factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a random factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Shift every sample by a random integer offset along both spatial axes.

    The maximum shift per axis is `ratio` times the size of that axis
    (rounded). Areas shifted in from outside the image are filled with zeros.

    Args:
        x: Batch of images with layout (N, C, H, W).
        ratio: Maximum shift as a fraction of the spatial size.

    Returns:
        Tensor of the same shape as `x`.
    """
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # One independent offset per sample and axis, drawn from [-shift, shift]
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the behaviour the implicit default provided; passing it
    # explicitly avoids the deprecation warning raised on every call
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # The +1 accounts for the one-pixel zero border added below; clamping sends
    # every out-of-range coordinate to that border
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = torch.nn.functional.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
    return x


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangular window per sample.

    The window measures `ratio` times the spatial size (rounded) along each
    axis. Its position is random and it is clamped to the image bounds, so
    windows near the border are partly cut off.

    Args:
        x: Batch of images with layout (N, C, H, W).
        ratio: Window size as a fraction of the spatial size.

    Returns:
        Tensor of the same shape as `x` with the window set to zero in all channels.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # Random window position per sample
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the behaviour the implicit default provided; passing it
    # explicitly avoids the deprecation warning raised on every call
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # Centre the window on the offset and keep all coordinates inside the image
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # The same spatial mask is applied to every channel
    x = x * mask.unsqueeze(1)
    return x


# Maps each policy name accepted by DiffAugment to the augmentation functions
# that are applied, in list order, for that policy.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

In [None]:
# Save samples periodically to visually track progress
def generate_and_save_images(generator, epoch, display=False):
    """Generate a grid of samples (one row per class) and save it as a JPEG.

    The figure is written to `{generated_images_dir}/{epoch}.jpg`. Row `r` of
    the grid contains only images generated for class `r`, matching the row
    title. The generator is put into evaluation mode while sampling, so that
    batch-norm statistics are neither taken from nor updated by the sampling
    batch, and its previous mode is restored afterwards.

    Args:
        generator: Callable module taking `(z, labels)` and returning images
            in the [-1, 1] range with layout (N, C, H, W).
        epoch: Used for the figure title and the output file name.
        display: If True the figure is shown, otherwise it is closed.
    """
    # Generate 5 samples for each class
    samples_per_class = 5
    total_samples = NUM_CLASSES * samples_per_class

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Generate random noise for the generator
            z = torch.randn(total_samples, LATENT_DIM, device=device)
            # [0, 0, 0, 0, 0, 1, 1, ...]: image i belongs to class
            # i // samples_per_class, i.e. to the grid row it is drawn in
            labels = torch.arange(NUM_CLASSES, device=device).repeat_interleave(
                samples_per_class)

            # Generate images
            generated_images = generator(z, labels)
    finally:
        generator.train(was_training)

    # Move images to CPU and convert to numpy arrays
    generated_images = generated_images.cpu().numpy()

    # Rescale images from [-1, 1] to [0, 1]
    generated_images = (generated_images + 1) / 2.0

    # Plot images
    fig, axes = plt.subplots(NUM_CLASSES, samples_per_class, figsize=(
        samples_per_class*2, NUM_CLASSES*2))
    for i, ax in enumerate(axes.flatten()):
        img = np.transpose(generated_images[i], (1, 2, 0))
        ax.imshow(img)
        ax.axis('off')
        if i % samples_per_class == 0:
            ax.set_title(
                f"Class {dataset.classes[i // samples_per_class]}")

    # Figure-level title; plt.title would only label the last subplot
    fig.suptitle(f'Epoch {epoch}')
    fig.tight_layout()
    fig.savefig(f'{generated_images_dir}/{epoch}.jpg')
    if display:
        plt.show()
    else:
        plt.close(fig)

In [None]:
# Taken from [1704.00028] Improved Training of Wasserstein GANs
def gradient_penalty(discriminator, real_samples, fake_samples, labels, penalty_weight=None):
    """Compute the weighted WGAN-GP gradient penalty.

    Random per-sample interpolations between `real_samples` and `fake_samples`
    are scored by the discriminator. The penalty is the mean squared deviation
    of the per-sample gradient L2 norm from 1, multiplied by the weight.

    Args:
        discriminator: Callable taking `(images, labels)` and returning scores.
        real_samples: Batch of real images, layout (N, C, H, W).
        fake_samples: Batch of generated images with the same shape.
        labels: Class labels passed through to the discriminator.
        penalty_weight: Multiplier of the penalty. Defaults to the
            notebook-level `lambda_gp` when omitted.

    Returns:
        Scalar tensor that can be added to the discriminator loss.
    """
    if penalty_weight is None:
        penalty_weight = lambda_gp
    batch_size = real_samples.size(0)
    # Draw the mixing coefficients on the device/dtype of the inputs rather than
    # relying on a global device
    alpha = torch.rand(batch_size, 1, 1, 1,
                       dtype=real_samples.dtype, device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha)
                    * fake_samples).requires_grad_(True)
    c_interpolates = discriminator(interpolates, labels)
    # create_graph=True keeps the gradient itself differentiable so the penalty
    # can be backpropagated into the discriminator parameters
    gradients = torch.autograd.grad(
        outputs=c_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(c_interpolates),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty * penalty_weight


def train_cgan(generator, discriminator):
    """Train the conditional GAN with the Wasserstein loss and gradient penalty.

    The discriminator is updated on every batch; the generator is updated on
    every `D_TURNS`-th batch. Real and generated images are passed through
    `DiffAugment` before they reach the discriminator. Both learning rates are
    halved every `N_EPOCHS // 4` epochs. Losses and gradient norms are logged
    to TensorBoard once per epoch and sample images are saved every 50 epochs.

    Args:
        generator: Generator module, called as `generator(z, labels)`.
        discriminator: Discriminator module, called as `discriminator(images, labels)`.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    if len(dataloader) == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    # Optimizers
    optimizer_G = optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))
    optimizer_D = optim.Adam(
        discriminator.parameters(), lr=d_lr, betas=(b1, b2))

    scheduler_G = optim.lr_scheduler.StepLR(
        optimizer_G, step_size=N_EPOCHS//4, gamma=0.5)
    scheduler_D = optim.lr_scheduler.StepLR(
        optimizer_D, step_size=N_EPOCHS//4, gamma=0.5)

    generator.train()
    discriminator.train()

    # Training loop
    for epoch in trange(N_EPOCHS, desc="Training Progress"):
        for i, (images, labels) in enumerate(dataloader):
            labels = labels.to(device)
            images = images.to(device)
            batch_size = images.size(0)

            # ---- Discriminator update ----
            optimizer_D.zero_grad()
            z = torch.randn(batch_size, LATENT_DIM).to(device)
            # The generator is not updated in this step, so skip building its
            # autograd graph altogether
            with torch.no_grad():
                gen_imgs = generator(z, labels)
            images_aug = DiffAugment(images)
            gen_imgs_aug = DiffAugment(gen_imgs)
            real_validity = discriminator(images_aug, labels)
            fake_validity = discriminator(gen_imgs_aug.detach(), labels)

            gp = gradient_penalty(discriminator,
                                  images_aug.detach(), gen_imgs_aug.detach(), labels)

            # Wasserstein critic loss plus gradient penalty
            d_loss = -torch.mean(real_validity) + \
                torch.mean(fake_validity) + gp
            d_loss.backward()
            optimizer_D.step()

            # ---- Generator update (every D_TURNS-th batch) ----
            if i % D_TURNS == 0:
                optimizer_G.zero_grad()
                z = torch.randn(batch_size, LATENT_DIM).to(device)
                gen_imgs = generator(z, labels)
                g_loss = - \
                    torch.mean(discriminator(DiffAugment(gen_imgs), labels))
                g_loss.backward()
                optimizer_G.step()

        scheduler_D.step()
        scheduler_G.step()

        # Logged values come from the last discriminator / generator step of the epoch
        log_losses_to_tensorboard(epoch, g_loss.item(), d_loss.item())
        log_gradients_to_tensorboard(generator, epoch, 'Generator')
        log_gradients_to_tensorboard(
            discriminator, epoch, 'Discriminator')

        if epoch % 50 == 0:
            print(f"Epoch [{epoch}/{N_EPOCHS}]  D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")
            generate_and_save_images(generator, epoch)


# Build the generator once; it is needed both for training and for loading a
# previously saved checkpoint.
if architecture == RESNET:
    # One GPU is not enough for the resnet GAN so use DataParallel
    generator = nn.DataParallel(ResGenerator(
        LATENT_DIM, NUM_CLASSES, CHANNELS)).to(device)
else:
    generator = ConditionalGenerator(
        LATENT_DIM, NUM_CLASSES, CHANNELS).to(device)

if TRAIN_CGAN:
    # Instantiate the discriminator matching the generator architecture
    # (ResDiscriminator / ConditionalDiscriminator are defined in another cell)
    if architecture == RESNET:
        discriminator = nn.DataParallel(
            ResDiscriminator(NUM_CLASSES, CHANNELS)).to(device)
    else:
        discriminator = ConditionalDiscriminator(
            NUM_CLASSES, CHANNELS).to(device)
        # The resnet generator initializes itself in its constructor; the
        # convolutional models are initialized here
        generator.apply(init_weights)
        discriminator.apply(init_weights)

    train_cgan(generator, discriminator)
    torch.save(generator.state_dict(), gan_model_path)
    print("C-GAN Training Complete")
else:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    generator.load_state_dict(torch.load(
        gan_model_path, map_location=device, weights_only=True))


generate_and_save_images(generator, N_EPOCHS, display=True)