<table class="table table-bordered">
    <tr>
        <th style="text-align:center;"><h1>Visual Generative AI Application: Generative Adversarial Networks</h1><h2>Assignment</h2><h3>Specialist Diploma in Applied Generative AI (SDGAI) 
</h3></th>
    </tr>
</table>

# (1) State clearly the goal and objectives you hope to achieve in this notebook


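The objective of this project is to design and implement a generative model capable of creating images of fashion items from the 10 distinct categories (classes) of the Fashion-MNIST dataset. For this part of the assignment, I will develop a **conditional GAN (cGAN)** whose generator and discriminator use a DCGAN-style convolutional architecture; the class label is fed to both networks so that the category of each generated image can be chosen.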

For each model, we will analyze its performance and tune its hyperparameters during the training phase. Specifically, we will explore the following:
- Vary the number of epochs:
    - Start with 50
    - Change this to 100
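- Change the display step from 500 to 200 (this only changes how often the losses and sample images are shown during training; it does not change the model or the optimisation)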
- Increase the latent dimension to 200

For each model we will do the following:
1. Load and explore the dataset
2. Build the model
    1. Start with a baseline model
    2. Consider the different hyperparameters that can be tuned
    3. Perform the tuning
    4. Analyse the model's performance based on the different tuning strategies
3. Evaluate the model
    1. Use the model to generate new images from noise

# (2) Import libraries

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm

# Seed every RNG the notebook has imported so runs are reproducible.
# torch.manual_seed also seeds all CUDA devices; the trailing `;` hides the
# returned Generator object from the cell output.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

In [None]:
# `torch.cuda.is_available` must be *called*: the bare function object is always
# truthy, which would select 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# (1) Load/Download Dataset  

We now load the Fashion-MNIST dataset for this assignment from a local directory (`download=False`; set it to `True` to fetch the data the first time). You may amend the __batch_size__ parameter, __transform__ function or the attributes of the DataLoader as you deem fit for your processing.

In [None]:
batch_size = 64  # Batch size for training and testing

# Local copy of Fashion-MNIST. Set the FASHION_MNIST_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
dataset_path = os.environ.get(
    'FASHION_MNIST_DIR',
    'D:\\Users\\ng_a\\My NP SDGAI\\PDC-2\\VGAA\\Assignment\\',  # Use double backslashes
)

# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=False: the data is expected to already exist under dataset_path.
dataset = datasets.FashionMNIST(dataset_path, download=False, transform=transform)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Inspect one training sample to derive the image shape (C, H, W) and the number
# of classes; both are used by every later cell.
first_sample = dataset[0]
image, label = first_sample
print("Shape of the Fashion MNIST image:", image.shape, label)

mnist_shape, n_classes = tuple(image.shape), len(dataset.classes)
for value in (mnist_shape, n_classes):
    print(value)

# (2) Explore the data

In [None]:
# Human-readable names for the ten Fashion-MNIST class indices.
labels_map = {
    0: 'T-shirt',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle Boot',
}

# Show the first 16 training samples on a 4x4 grid, filled row by row.
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
for idx, ax in enumerate(axs.flat):
    image, label = dataset[idx]
    ax.imshow(image.numpy().squeeze(), cmap='gray')
    ax.axis('off')
    ax.set_title(f"{labels_map[label]}")

plt.tight_layout()
plt.show()

In [None]:
def show_tensor_images(image_tensor, num_images=9, size=(1, 28, 28), nrow=3):
    '''
    Plot the first `num_images` images of a batch in a uniform grid.

    Both the real images (normalised with mean 0.5 / std 0.5) and the generator
    outputs (Tanh) lie in [-1, 1]. They are rescaled to [0, 1] before plotting,
    because `make_grid` turns single-channel images into an RGB grid and
    matplotlib clips RGB float values outside [0, 1], so every negative pixel
    would otherwise be shown as black.

    Parameters:
        image_tensor: batch of images that can be reshaped to (-1, *size)
        num_images: how many images from the batch to show
        size: per-image shape (C, H, W)
        nrow: number of images per grid row
    '''
    # Detach from autograd and move to CPU so the tensor can be plotted.
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_unflat = ((image_unflat + 1) / 2).clamp(0, 1)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.axis('off')
    plt.show()

In [None]:
def get_one_hot_labels(labels, n_classes):
    '''
    One-hot encode a batch of integer class labels.

    Parameters:
        labels: integer tensor of class indices from the dataloader, size (n_samples,)
        n_classes: the total number of classes in the dataset, an integer scalar
    Returns:
        An integer tensor of shape (n_samples, n_classes).
    '''
    encoded = F.one_hot(labels, num_classes=n_classes)
    return encoded

In [None]:
def combine_vectors(x, y):
    '''
    Concatenate two batches along dimension 1 and return the result as float.

    Parameters:
      x: first tensor, shape (n_samples, a, ...), e.g. the noise vectors or a
         batch of (N, C, H, W) images
      y: second tensor, shape (n_samples, b, ...), e.g. one-hot class vectors
         or per-class label channels
    Returns:
      A float tensor of shape (n_samples, a + b, ...). The cast guarantees a
      float result even when one input is an integer one-hot tensor.
    '''
    joined = torch.cat([x, y], 1)
    return joined.float()

In [None]:
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    '''
    Compute the input sizes of the conditional generator and discriminator.

    Parameters:
        z_dim: the dimension of the noise vector, a scalar
        mnist_shape: the shape of each image as (C, H, W), e.g. (1, 28, 28)
        n_classes: the total number of classes in the dataset, an integer scalar
    Returns:
        generator_input_dim: noise length plus one entry per class (the one-hot
                             vector is appended to the noise)
        discriminator_im_chan: image channels plus one constant channel per class
    '''
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes

In [None]:
def plot_eval_curves(dis_loss_combine, gen_loss_combine, title, filename, save=False):
    '''
    Plot discriminator (red) and generator (blue) loss histories on one figure.

    Parameters:
        dis_loss_combine: sequence of discriminator losses, one per training step
        gen_loss_combine: sequence of generator losses, one per training step
        title: figure title
        filename: path the figure is written to when `save` is True
        save: if True, write the figure to `filename` (default False: only display)
    '''
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(dis_loss_combine)), dis_loss_combine, 'r', label='Dis Loss')
    ax.plot(np.arange(len(gen_loss_combine)), gen_loss_combine, 'b', label='Gen Loss')
    ax.legend()
    # do_train records one loss per batch, so the x-axis counts steps, not epochs.
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    if save:
        fig.savefig(filename)

# (3) Modeling

In [None]:
# Training hyperparameters shared by the experiments below. Individual experiment
# cells override n_epochs / z_dim / display_step where needed.
lr = 0.0002          # Adam learning rate for both networks
batch_size = 64      # NOTE: the DataLoader was built earlier; reassigning this does not change it
display_step = 500   # report losses / sample grids every this many batches
z_dim = 64           # length of the latent noise vector
n_epochs = 200

# The discriminator outputs raw logits (no final sigmoid), hence the logits loss.
criterion = nn.BCEWithLogitsLoss()

## (3a) Noise Generator

In [None]:
def get_noise(n_samples, input_dim, device='cuda'):
    '''
    Sample a batch of latent noise vectors from a standard normal distribution.

    Parameters:
        n_samples: number of vectors to draw, a scalar
        input_dim: length of each vector, a scalar
        device: device on which the tensor is allocated
    Returns:
        A float tensor of shape (n_samples, input_dim).
    '''
    shape = (n_samples, input_dim)
    return torch.randn(*shape, device=device)

## (3b) Generator

In [None]:
# Conditional DCGAN-style generator: (noise + one-hot class) vector -> image.
class Generator(nn.Module):
    '''
    Generator Class
    Values:
        input_dim: length of the input vector (noise plus one-hot class vector), a scalar
        im_chan: number of channels of the output image, a scalar
              (Fashion-MNIST is greyscale, so 1 channel is the default)
        hidden_dim: base width of the hidden feature maps, a scalar
    '''
    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        # Spatial size grows 1 -> 3 -> 6 -> 13 -> 28 through the four blocks.
        widths = [input_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        blocks = [
            self.make_gen_block(widths[0], widths[1]),
            self.make_gen_block(widths[1], widths[2], kernel_size=4, stride=1),
            self.make_gen_block(widths[2], widths[3]),
            self.make_gen_block(widths[3], im_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Build one generator block: a transposed convolution followed by
        batchnorm + ReLU, or, for the final block, by Tanh only.
        Parameters:
            input_channels: number of channels of the incoming feature map
            output_channels: number of channels the block produces
            kernel_size: side length of the square convolution kernel
            stride: stride of the transposed convolution
            final_layer: True for the output block (no batchnorm, Tanh activation)
        '''
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            # Tanh bounds the generated pixels to [-1, 1].
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True))

    def forward(self, noise):
        '''
        Generate a batch of images.
        Parameters:
            noise: tensor of shape (n_samples, input_dim)
        Returns:
            Images of shape (n_samples, im_chan, 28, 28) with the default layout.
        '''
        # Treat each input vector as an input_dim-channel 1x1 feature map.
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)

## (3c) Discriminator

In [None]:
# Conditional DCGAN-style discriminator: image (+ per-class channels) -> one real/fake logit.
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
      im_chan: number of input channels (image channels plus one channel per
            class in the conditional setup), a scalar
      hidden_dim: base width of the hidden feature maps, a scalar
    '''
    def __init__(self, im_chan=1, hidden_dim=64):
        super().__init__()
        # Spatial size shrinks 28 -> 13 -> 5 -> 1 through the three blocks.
        layers = [
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        ]
        self.disc = nn.Sequential(*layers)

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Build one discriminator block: a convolution followed by batchnorm +
        LeakyReLU(0.2), or, for the final block, the convolution alone.
        Parameters:
            input_channels: number of channels of the incoming feature map
            output_channels: number of channels the block produces
            kernel_size: side length of the square convolution kernel
            stride: stride of the convolution
            final_layer: True for the output block (no batchnorm, no activation)
        '''
        conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            # Raw logits: intended to be paired with BCEWithLogitsLoss.
            return nn.Sequential(conv)
        return nn.Sequential(conv, nn.BatchNorm2d(output_channels), nn.LeakyReLU(0.2, inplace=True))

    def forward(self, image):
        '''
        Score a batch of (conditioned) images as real or fake.
        Parameters:
            image: tensor of shape (n_samples, im_chan, H, W); not flattened
        Returns:
            Logits reshaped to (n_samples, -1); (n_samples, 1) for 28x28 inputs.
        '''
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

# (4) Training, Tuning, Evaluation

In [None]:
def weights_init(m):
    '''
    Weight initialiser meant to be used as `module.apply(weights_init)`.

    Conv2d / ConvTranspose2d weights and BatchNorm2d weights are drawn from
    N(0, 0.02); BatchNorm2d biases are set to zero. Conv biases and all other
    module types are left untouched.
    NOTE(review): the PyTorch DCGAN tutorial centres BatchNorm weights at 1.0
    rather than 0.0 - confirm the 0.0 mean here is intentional.
    '''
    is_conv = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    is_batchnorm = isinstance(m, nn.BatchNorm2d)
    if is_conv or is_batchnorm:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if is_batchnorm:
        torch.nn.init.constant_(m.bias, 0)

In [None]:
def do_create_gen(input_dim, weights_init, lr, device):
    '''
    Build a Generator on `device`, initialise its weights and attach an Adam optimiser.

    Returns:
        (gen, gen_opt)
    '''
    gen = Generator(input_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    # apply() re-initialises the parameters in place, so the optimiser created
    # above still tracks the same tensors.
    return gen.apply(weights_init), gen_opt

In [None]:
def do_create_disc(im_chan, weights_init, lr, device):
    '''
    Build a Discriminator on `device`, initialise its weights and attach an Adam optimiser.

    Returns:
        (disc, disc_opt)
    '''
    disc = Discriminator(im_chan).to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
    # apply() modifies the parameters in place and returns the same module.
    return disc.apply(weights_init), disc_opt

In [None]:
def _plot_loss_curves(generator_losses, discriminator_losses, step_bins=20):
    '''Plot both loss histories, each averaged over consecutive windows of `step_bins` batches.'''
    num_examples = (len(generator_losses) // step_bins) * step_bins
    x = range(num_examples // step_bins)
    plt.plot(
        x,
        torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Generator Loss"
    )
    plt.plot(
        x,
        torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Discriminator Loss"
    )
    plt.legend()
    plt.show()


def do_train(dataloader, gen, gen_opt, disc, disc_opt, n_epochs, display_step, z_dim, mnist_shape, n_classes, device, criterion=None):
    '''
    Train a conditional GAN and return the per-batch loss histories.

    Each batch performs (1) one discriminator step on real and detached fake
    images, each stacked with per-class label channels, and (2) one generator
    step that pushes the discriminator to label the fakes as real. Every
    `display_step` batches the mean recent losses, sample grids and smoothed
    loss curves are shown.

    Parameters:
        dataloader: yields (images, integer labels) batches
        gen, gen_opt: conditional generator and its optimiser
        disc, disc_opt: conditional discriminator and its optimiser
        n_epochs: number of passes over the dataloader
        display_step: reporting interval, in batches
        z_dim: length of the noise vector
        mnist_shape: image shape (C, H, W)
        n_classes: number of classes
        device: device to train on
        criterion: loss on raw discriminator logits; defaults to nn.BCEWithLogitsLoss()
    Returns:
        (generator_losses, discriminator_losses): lists with one float per batch.
    '''
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()

    # A generator that was sampled with evaluate_model() is left in eval mode;
    # make sure both networks use training behaviour (batch-norm batch statistics).
    gen.train()
    disc.train()

    cur_step = 0
    generator_losses = []
    discriminator_losses = []

    for epoch in range(n_epochs):
        print('Training at Epoch ', epoch)
        for real, labels in tqdm(dataloader):
            cur_batch_size = len(real)
            real = real.to(device)

            # One-hot vectors condition the generator; the same labels broadcast to
            # (N, n_classes, H, W) constant channels condition the discriminator.
            one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
            image_one_hot_labels = one_hot_labels[:, :, None, None].repeat(1, 1, mnist_shape[1], mnist_shape[2])

            ### Update discriminator ###
            disc_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            fake = gen(noise_and_labels)

            # Detach the fakes so the discriminator loss does not backpropagate into the generator.
            fake_image_and_labels = combine_vectors(fake.detach(), image_one_hot_labels)
            real_image_and_labels = combine_vectors(real, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels)
            disc_real_pred = disc(real_image_and_labels)

            disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            # No retain_graph needed: this graph is not reused, because the generator
            # step below runs a fresh discriminator forward pass.
            disc_loss.backward()
            disc_opt.step()
            discriminator_losses.append(disc_loss.item())

            ### Update generator ###
            gen_opt.zero_grad()
            fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels)
            gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_loss.backward()
            gen_opt.step()
            generator_losses.append(gen_loss.item())

            if cur_step % display_step == 0 and cur_step > 0:
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                disc_mean = sum(discriminator_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
                show_tensor_images(fake)
                show_tensor_images(real)
                _plot_loss_curves(generator_losses, discriminator_losses)
            elif cur_step == 0:
                print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
            cur_step += 1
    return generator_losses, discriminator_losses

## (4a) Start with 50 Epochs

In [None]:
# Input sizes for the conditional networks: the generator receives noise plus a
# one-hot class vector; the discriminator receives the image plus one channel per class.
generator_input_dim, discriminator_im_chan = get_input_dimensions(
    z_dim=z_dim, mnist_shape=mnist_shape, n_classes=n_classes
)

In [None]:
# Experiment 1 (baseline): fresh generator / discriminator with Adam optimisers.
gen_1, gen_opt_1 = do_create_gen(input_dim=generator_input_dim, weights_init=weights_init, lr=lr, device=device)
disc_1, disc_opt_1 = do_create_disc(im_chan=discriminator_im_chan, weights_init=weights_init, lr=lr, device=device)

In [None]:
# Baseline run: 50 epochs; z_dim and display_step keep their hyperparameter-cell values.
n_epochs = 50
gen_loss_combine_1, dis_loss_combine_1 = do_train(
    dataloader, gen_1, gen_opt_1, disc_1, disc_opt_1,
    n_epochs, display_step, z_dim, mnist_shape, n_classes, device,
)

## (5a) Tuning: Increase number of Epochs to 100

In [None]:
# Experiment 2: same architecture and hyperparameters as the baseline, trained for 100 epochs.
generator_input_dim_2, discriminator_im_chan_2 = get_input_dimensions(z_dim, mnist_shape, n_classes)

gen_2, gen_opt_2 = do_create_gen(generator_input_dim_2, weights_init, lr, device)
disc_2, disc_opt_2 = do_create_disc(discriminator_im_chan_2, weights_init, lr, device)

n_epochs = 100
gen_loss_combine_2, dis_loss_combine_2 = do_train(
    dataloader, gen_2, gen_opt_2, disc_2, disc_opt_2,
    n_epochs, display_step, z_dim, mnist_shape, n_classes, device,
)

## (5b) Tuning: Change display step to 200

In [None]:
# Experiment 3: 100 epochs, progress shown every 200 batches, z_dim kept at 64.
n_epochs, z_dim, display_step = 100, 64, 200

generator_input_dim_3, discriminator_im_chan_3 = get_input_dimensions(z_dim, mnist_shape, n_classes)
gen_3, gen_opt_3 = do_create_gen(generator_input_dim_3, weights_init, lr, device)
disc_3, disc_opt_3 = do_create_disc(discriminator_im_chan_3, weights_init, lr, device)

gen_loss_combine_3, dis_loss_combine_3 = do_train(
    dataloader, gen_3, gen_opt_3, disc_3, disc_opt_3,
    n_epochs, display_step, z_dim, mnist_shape, n_classes, device,
)

## (5c) Tuning:  Increase latent dimension to 200

In [None]:
# Experiment 4: larger latent space (z_dim=200), 100 epochs, display step back to 500.
n_epochs, z_dim, display_step = 100, 200, 500

generator_input_dim_4, discriminator_im_chan_4 = get_input_dimensions(z_dim, mnist_shape, n_classes)
gen_4, gen_opt_4 = do_create_gen(generator_input_dim_4, weights_init, lr, device)
disc_4, disc_opt_4 = do_create_disc(discriminator_im_chan_4, weights_init, lr, device)

gen_loss_combine_4, dis_loss_combine_4 = do_train(
    dataloader, gen_4, gen_opt_4, disc_4, disc_opt_4,
    n_epochs, display_step, z_dim, mnist_shape, n_classes, device,
)

In [None]:
# Current values of the global settings after the experiments above.
print(n_epochs, z_dim, display_step, batch_size, lr, mnist_shape, n_classes, device)

In [None]:
# Preview the first 100 training batches: print each batch shape and show a sample grid.
for snapshot, (image, _) in enumerate(tqdm(dataloader), start=1):
    print(image.shape)
    show_tensor_images(image)
    if snapshot == 100:
        break

# (6) Evaluation

In [None]:
def evaluate_model(gen, labels, z_dim, batch_size, num_classes, mnist_shape, device):
    '''
    Generate a batch of class-conditioned images with a trained generator.

    Parameters:
        gen: conditional generator (nn.Module or TorchScript module); it is
             switched to eval mode as a side effect
        labels: list of integer class indices, repeated cyclically to fill the batch
        z_dim: noise length the generator was trained with
        batch_size: number of images to generate
        num_classes: number of classes used for the one-hot encoding
        mnist_shape: image shape (C, H, W); not needed for sampling, kept for
                     backward compatibility
        device: device for the noise and label tensors
    Returns:
        The generated images (for the notebook's Generator: (batch_size, C, H, W)),
        without autograd history.
    Raises:
        ValueError: if `labels` is empty.
    '''
    if len(labels) == 0:
        raise ValueError("labels must contain at least one class index")

    # eval() makes batch norm use its running statistics instead of batch statistics.
    gen_e = gen.eval()

    # Repeat the requested classes and trim to exactly batch_size entries, so the
    # label count always matches the noise batch even when several classes are given.
    labels_ = torch.tensor((list(labels) * batch_size)[:batch_size], device=device)
    one_hot_labels_e = get_one_hot_labels(labels_, num_classes)

    fake_noise = get_noise(batch_size, z_dim, device)
    # The one-hot tensor is integer; cast it so the concatenated input is float.
    noise_and_labels = torch.cat((fake_noise, one_hot_labels_e.float()), dim=1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        fake = gen_e(noise_and_labels)
    return fake

# (6) Generating New Samples From Generative Models

In [None]:
print(batch_size)  # batch size used when sampling from the generators below

In [None]:
# Generate T-shirts (class 0, see `labels_map`) with each trained generator.
labels = [0]  # T-shirt

# Each generator must be sampled with the latent size it was trained with
# (models 1-3: z_dim=64, model 4: z_dim=200). Training-only settings (epochs,
# display step, learning rate) do not affect sampling, so they are not reset here.
eval_batch_size = 64
sampling_runs = [
    ('Model 1: 50 epochs, z_dim=64', gen_1, 64),
    ('Model 2: 100 epochs, z_dim=64', gen_2, 64),
    ('Model 3: 100 epochs, display_step=200, z_dim=64', gen_3, 64),
    ('Model 4: 100 epochs, z_dim=200', gen_4, 200),
]

samples = []
for run_title, run_gen, run_z_dim in sampling_runs:
    print(run_title)
    run_fake = evaluate_model(run_gen, labels, run_z_dim, eval_batch_size, n_classes, mnist_shape, device)
    show_tensor_images(run_fake)
    samples.append(run_fake)

fake_1, fake_2, fake_3, fake_4 = samples

## Saving the trained model

In [None]:
# Persist the weights (state_dict) of every trained network, one file per model.
checkpoints = [
    (gen_1, 'cgan_generator_1.pth'),
    (disc_1, 'cgan_discriminator_1.pth'),
    (gen_2, 'cgan_generator_2.pth'),
    (disc_2, 'cgan_discriminator_2.pth'),
    (gen_3, 'cgan_generator_3.pth'),
    (disc_3, 'cgan_discriminator_3.pth'),
    (gen_4, 'cgan_generator_4.pth'),
    (disc_4, 'cgan_discriminator_4.pth'),
]
for model, path in checkpoints:
    torch.save(model.state_dict(), path)

# (7) Conclusion

In [None]:
# Restore generator 1 from a previously saved checkpoint. NOTE: this replaces the
# weights of the gen_1 trained above. map_location remaps the stored tensors to
# the current device, so a checkpoint saved on a GPU also loads on a CPU-only machine.
checkpoint_path = 'D:\\Users\\ng_a\\My NP SDGAI\\PDC-2\\VGAA\\Assignment\\SavedModels\\cGANs\\9Nov2024\\cgan_generator_1.pth'
gen_1.load_state_dict(torch.load(checkpoint_path, map_location=device))
gen_1.eval()

In [None]:
def save_model(model, filename):
    '''
    Compile `model` with TorchScript and write it to `filename`.

    Parameters:
        model: nn.Module whose forward is TorchScript-compatible
        filename: destination path
    '''
    model_scripted = torch.jit.script(model)
    model_scripted.save(filename)


def load_model(filename, eval=True, map_location=None):
    '''
    Load a TorchScript model written by `save_model`.

    Parameters:
        filename: path of the saved TorchScript file
        eval: if truthy, return the model in eval mode, otherwise in train mode
              (the name shadows the builtin `eval`; kept for backward compatibility)
        map_location: optional device to remap the stored tensors to, e.g. 'cpu'
              when a model saved on a GPU is loaded on a CPU-only machine
    Returns:
        The loaded ScriptModule.
    '''
    model = torch.jit.load(filename, map_location=map_location)
    if eval:
        model.eval()
    else:
        model.train()
    return model

In [None]:
# Export every trained network as a TorchScript file.
scripted_exports = [
    (gen_1, 'cgan_generator_1.pt'),
    (disc_1, 'cgan_discriminator_1.pt'),
    (gen_2, 'cgan_generator_2.pt'),
    (disc_2, 'cgan_discriminator_2.pt'),
    (gen_3, 'cgan_generator_3.pt'),
    (disc_3, 'cgan_discriminator_3.pt'),
    (gen_4, 'cgan_generator_4.pt'),
    (disc_4, 'cgan_discriminator_4.pt'),
]
for model, path in scripted_exports:
    save_model(model, path)

In [None]:
# Round-trip generator 1 through TorchScript to check that the export can be reloaded.
scripted_path = 'gen_1.pt'
save_model(gen_1, scripted_path)
gen_1_a = load_model(scripted_path)

In [None]:
# gen_1 was trained with z_dim=64. The global z_dim was last set to 200 (model 4),
# so the latent size has to be passed explicitly or the generator's input view fails.
gen_1_z_dim = 64
labels = [0] # T-Shirt
fake_1 = evaluate_model(gen_1_a, labels, gen_1_z_dim, batch_size, n_classes, mnist_shape, device)
show_tensor_images(fake_1)

In [None]:
# gen_1 was trained with z_dim=64. The global z_dim was last set to 200 (model 4),
# so the latent size has to be passed explicitly or the generator's input view fails.
gen_1_z_dim = 64
labels = [0] # T-Shirt
fake_1 = evaluate_model(gen_1, labels, gen_1_z_dim, batch_size, n_classes, mnist_shape, device)
show_tensor_images(fake_1)