In [8]:
import os
import pickle
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm


In [7]:

class ImageGenerator(nn.Module):
    """Fully connected GAN generator.

    Maps inputs whose last dimension is ``noise_dim`` through three
    Linear + LeakyReLU(0.2) stages of width 256 -> 512 -> 1024, then a Linear
    projection to ``image_dim`` followed by Tanh, so every output value lies
    in [-1, 1].
    """

    def __init__(self, noise_dim, image_dim):
        super().__init__()

        def linear_leaky(in_features, out_features):
            # Hidden stage: affine map followed by LeakyReLU with slope 0.2.
            return nn.Sequential(nn.Linear(in_features, out_features), nn.LeakyReLU(0.2))

        # Attribute names and creation order are deliberately unchanged so that
        # saved state_dicts stay loadable and weight initialisation consumes the
        # RNG in the same sequence.
        self.first_layer = linear_leaky(noise_dim, 256)
        self.second_layer = linear_leaky(256, 512)
        self.third_layer = linear_leaky(512, 1024)
        self.final_layer = nn.Sequential(nn.Linear(1024, image_dim), nn.Tanh())

    def forward(self, noise):
        """Return generator outputs of size ``image_dim`` for ``noise``; values in [-1, 1]."""
        hidden = noise
        for stage in (self.first_layer, self.second_layer, self.third_layer):
            hidden = stage(hidden)
        return self.final_layer(hidden)


class ImageDiscriminator(nn.Module):
    """Fully connected GAN discriminator.

    Maps inputs whose last dimension is ``image_dim`` through three
    Linear + LeakyReLU(0.2) stages of width 1024 -> 512 -> 256, then a Linear
    layer to ``output_dim`` followed by Sigmoid, so each output is a score in
    [0, 1] (trained with BCELoss against real=1 / fake=0 labels in this file).
    """

    def __init__(self, image_dim, output_dim=1):
        super().__init__()

        def linear_leaky(in_features, out_features):
            # Hidden stage: affine map followed by LeakyReLU with slope 0.2.
            return nn.Sequential(nn.Linear(in_features, out_features), nn.LeakyReLU(0.2))

        # Same attribute names and creation order as before: keeps checkpoint
        # keys and RNG-dependent initialisation identical.
        self.first_layer = linear_leaky(image_dim, 1024)
        self.second_layer = linear_leaky(1024, 512)
        self.third_layer = linear_leaky(512, 256)
        self.final_layer = nn.Sequential(nn.Linear(256, output_dim), nn.Sigmoid())

    def forward(self, image):
        """Return Sigmoid scores of size ``output_dim`` for ``image``."""
        features = image
        for stage in (self.first_layer, self.second_layer, self.third_layer):
            features = stage(features)
        return self.final_layer(features)


def display_images(generator, noise, epoch, show=False, save=False, path='result.png',
                   grid_size=5, image_shape=(28, 28)):
    """Render a square grid of generator samples, optionally saving/showing it.

    Args:
        generator: module called on ``noise``; its output is moved to CPU and
            reshaped to ``(-1, *image_shape)`` before plotting.
        noise: batch passed unchanged to ``generator`` (caller is responsible
            for putting it on the generator's device). Only the first
            ``grid_size ** 2`` samples are drawn; missing cells stay blank.
        epoch: value shown in the figure title as ``Epoch {epoch}``.
        show: if True, call ``plt.show()``.
        save: if True, write the figure to ``path``.
        path: output file used when ``save`` is True.
        grid_size: rows and columns of the grid (default 5 -> 25 cells).
        image_shape: shape each sample is reshaped to (default 28x28, MNIST).
    """
    # Inference only: no autograd graph needed for visualisation.
    with torch.no_grad():
        images = generator(noise).cpu().view(-1, *image_shape)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(grid_size, grid_size),
                             sharex=True, sharey=True, squeeze=False)
    try:
        for ax, img in zip(axes.flatten(), images):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.imshow(img.numpy(), cmap='gray_r')
        fig.suptitle(f'Epoch {epoch}')

        if save:
            fig.savefig(path)
        if show:
            plt.show()
    finally:
        # Close this exact figure even if saving fails, so repeated calls during
        # training don't accumulate open matplotlib figures.
        plt.close(fig)


def plot_loss_history(history, show=False, save=False, path='loss_history.png', xlabel='Batch'):
    """Plot discriminator and generator loss curves on a single axis.

    Args:
        history: mapping with sequence values under ``'discriminator_losses'``
            and ``'generator_losses'``; each is plotted against its index.
        show: if True, call ``plt.show()``.
        save: if True, write the figure to ``path``.
        path: output file used when ``save`` is True.
        xlabel: x-axis label; should describe how often ``history`` was
            appended to (default ``'Batch'``).
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(history['discriminator_losses'], label='Discriminator Loss')
        ax.plot(history['generator_losses'], label='Generator Loss')
        ax.set_title("Training Losses")
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Loss")
        ax.legend()

        if save:
            fig.savefig(path)
        if show:
            plt.show()
    finally:
        # Always release the figure, even when savefig raises.
        plt.close(fig)


def generate_noise(batch_size, dimensions, device=None):
    """Sample a ``(batch_size, dimensions)`` tensor from the standard normal N(0, 1).

    Args:
        batch_size: number of noise vectors (rows).
        dimensions: length of each noise vector (columns).
        device: optional torch device to allocate on. ``None`` (default) keeps
            the previous behaviour of using torch's default device; passing the
            training device avoids a separate host-to-device copy.

    Returns:
        A new tensor of shape ``(batch_size, dimensions)``.
    """
    return torch.randn(batch_size, dimensions, device=device)


if __name__ == '__main__':
    # ---- Configuration ----------------------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seed = 0  # fixed so Restart & Run All reproduces the same run
    torch.manual_seed(seed)

    data_directory = './MNIST_data/'
    results_directory = './MNIST_GAN_results/'
    image_results_directory = os.path.join(results_directory, 'images')

    # makedirs creates parents too, so this also creates results_directory;
    # exist_ok makes the cell safe to re-run.
    os.makedirs(image_results_directory, exist_ok=True)

    batch_size = 100
    learning_rate = 0.0002
    training_epochs = 100
    log_every = 200     # batches between progress prints
    sample_every = 10   # epochs between saved sample grids

    noise_dimensions = 100
    flat_image_dimensions = 28 * 28

    # ---- Data -------------------------------------------------------------
    # ToTensor yields [0, 1]; Normalize(0.5, 0.5) rescales to [-1, 1], the
    # same range as the generator's Tanh output.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    dataset = datasets.MNIST(root=data_directory, train=True, download=True, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # ---- Models, optimisers, loss ------------------------------------------
    generator = ImageGenerator(noise_dimensions, flat_image_dimensions).to(device)
    discriminator = ImageDiscriminator(flat_image_dimensions).to(device)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

    loss_fn = nn.BCELoss()

    # One entry per batch, matching the "Batch" x-axis used by plot_loss_history.
    history = {'discriminator_losses': [], 'generator_losses': []}

    # Render the same noise every `sample_every` epochs so saved grids show how
    # the generator evolves rather than a different random draw each time.
    fixed_noise = generate_noise(25, noise_dimensions).to(device)

    start_time = time.time()

    # ---- Training loop -----------------------------------------------------
    for epoch in range(training_epochs):
        for i, (images, _) in enumerate(loader):
            # Use the real batch size: the final batch can be smaller than
            # `batch_size` when len(dataset) isn't a multiple of it.
            current_batch_size = images.size(0)
            images = images.view(current_batch_size, -1).to(device)
            real_labels = torch.ones(current_batch_size, 1, device=device)
            fake_labels = torch.zeros(current_batch_size, 1, device=device)

            # Discriminator step: push real -> 1 and generated -> 0.
            optimizer_D.zero_grad()

            outputs_real = discriminator(images)
            loss_real = loss_fn(outputs_real, real_labels)

            noise = generate_noise(current_batch_size, noise_dimensions).to(device)
            fake_images = generator(noise)
            # detach(): the discriminator loss must not backprop into the generator.
            outputs_fake = discriminator(fake_images.detach())
            loss_fake = loss_fn(outputs_fake, fake_labels)

            loss_D = loss_real + loss_fake
            loss_D.backward()
            optimizer_D.step()

            # Generator step: make the (just-updated) discriminator score fakes as real.
            # This backward also writes grads into the discriminator; they are
            # cleared by optimizer_D.zero_grad() at the start of the next batch.
            optimizer_G.zero_grad()
            outputs_fake = discriminator(fake_images)
            loss_G = loss_fn(outputs_fake, real_labels)
            loss_G.backward()
            optimizer_G.step()

            d_loss_value = loss_D.item()
            g_loss_value = loss_G.item()
            history['discriminator_losses'].append(d_loss_value)
            history['generator_losses'].append(g_loss_value)

            if (i + 1) % log_every == 0:
                print(f'Epoch [{epoch + 1}/{training_epochs}], Step [{i + 1}/{len(loader)}], Loss_D: {d_loss_value}, Loss_G: {g_loss_value}')

        # Save a sample grid periodically.
        if (epoch + 1) % sample_every == 0:
            display_images(generator, fixed_noise, epoch + 1, save=True,
                           path=os.path.join(image_results_directory, f'Epoch_{epoch + 1}.png'))

    total_time = time.time() - start_time
    print(f'Total Training Time: {total_time:.2f}s')

    # ---- Persist results ---------------------------------------------------
    torch.save(generator.state_dict(), os.path.join(results_directory, 'generator.pth'))
    torch.save(discriminator.state_dict(), os.path.join(results_directory, 'discriminator.pth'))

    plot_loss_history(history, save=True, path=os.path.join(results_directory, 'loss_history.png'))


            


Epoch [1/100], Step [200/600], Loss_D: 0.09874388575553894, Loss_G: 7.03685188293457
Epoch [1/100], Step [400/600], Loss_D: 0.13508620858192444, Loss_G: 5.099188804626465
Epoch [1/100], Step [600/600], Loss_D: 0.4734697937965393, Loss_G: 4.120473384857178
Epoch [2/100], Step [200/600], Loss_D: 0.4999423623085022, Loss_G: 2.0276451110839844
Epoch [2/100], Step [400/600], Loss_D: 0.8440132141113281, Loss_G: 1.513079285621643
Epoch [2/100], Step [600/600], Loss_D: 1.6606357097625732, Loss_G: 1.2800430059432983
Epoch [3/100], Step [200/600], Loss_D: 0.3642454147338867, Loss_G: 3.244110107421875
Epoch [3/100], Step [400/600], Loss_D: 0.9636448621749878, Loss_G: 1.4988113641738892
Epoch [3/100], Step [600/600], Loss_D: 2.0083513259887695, Loss_G: 4.388504981994629
Epoch [4/100], Step [200/600], Loss_D: 0.12781336903572083, Loss_G: 3.739996910095215
Epoch [4/100], Step [400/600], Loss_D: 2.888962745666504, Loss_G: 1.0993353128433228
Epoch [4/100], Step [600/600], Loss_D: 1.4035427570343018, L