In [None]:
from utils import *
from models import *

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [None]:
# Fix the seed of PyTorch's global random number generator so that repeated
# runs of the notebook start from the same random state.
torch.manual_seed(42)

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# GAN for MNIST

In [None]:
def gan_mnist():
    """Train a GAN on MNIST and save its loss plot, interpolation and weights.

    Steps, in order:
      1. Build ``GeneratorMNIST`` / ``DiscriminatorMNIST`` on ``device``.
      2. Train them with ``trainGan`` on the loader from ``getDataLoader('mnist', ...)``.
      3. Plot the two loss sequences returned by ``trainGan`` and save the figure
         to ``gan_mnist_loss_curves.png``.
      4. Call ``latent_space_interpolation_gan`` with the trained generator.
      5. Save both state dicts under ``models/``.

    Takes no arguments and returns nothing; results are written to disk and
    shown through matplotlib.
    """
    # Local import: only needed to prepare the checkpoint directory.
    import os

    # Hyperparameters
    latent_dim = 100
    num_epochs = 300
    batch_size = 256
    lr = 0.0002

    # Create the checkpoint directory *before* training so that a missing
    # folder cannot make torch.save fail after the whole training run.
    os.makedirs('models', exist_ok=True)

    # Initialize generator and discriminator
    generator = GeneratorMNIST().to(device)
    discriminator = DiscriminatorMNIST().to(device)

    # Loss function and optimizers
    criterion = nn.BCELoss()
    g_optimizer = optim.Adam(generator.parameters(), lr=lr)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

    # Get MNIST dataloader
    dataloader = getDataLoader('mnist', batch_size)

    # Train GAN
    G_losses, D_losses = trainGan(num_epochs, dataloader, generator, discriminator, latent_dim, criterion, g_optimizer, d_optimizer, device, 'Gan_MNIST')

    # Plot loss curves
    # NOTE(review): the x-axis label assumes trainGan returns one value per epoch — confirm.
    plt.figure(figsize=(10, 5))
    plt.plot(G_losses, label='Generator')
    plt.plot(D_losses, label='Discriminator')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('GAN MNIST Loss Curves')
    plt.savefig('gan_mnist_loss_curves.png')
    plt.show()

    # Perform latent space interpolation
    latent_space_interpolation_gan(generator, latent_dim, device, 'Gan_MNIST', num_steps=10)

    # Save model weights
    torch.save(generator.state_dict(), 'models/generator_mnist.pth')
    torch.save(discriminator.state_dict(), 'models/discriminator_mnist.pth')

    # Drop the local references to the networks
    del generator
    del discriminator

    print("Training complete. Generated images, loss curves, and latent space interpolation have been saved.")

# GAN for CIFAR-10

In [None]:
def gan_cifar():
    """Train a GAN on CIFAR-10 and save its loss plot, interpolation, weights and a real-image grid.

    Steps, in order:
      1. Build ``GeneratorCIFAR`` / ``DiscriminatorCIFAR`` on ``device``.
      2. Train them with ``trainGan`` on the loader from ``getDataLoader('cifar', ...)``,
         using Adam with ``betas=(0.5, 0.999)``.
      3. Plot the two loss sequences returned by ``trainGan`` and save the figure
         to ``gan_cifar_loss_curves.png``.
      4. Call ``latent_space_interpolation_gan`` with the trained generator.
      5. Save both state dicts under ``models/``.
      6. Save a grid of up to 64 real images from the first batch to
         ``Gan_CIFAR/real_cifar10_images.png``.

    Takes no arguments and returns nothing; results are written to disk and
    shown through matplotlib.
    """
    # Local import: only needed to prepare the output directories.
    import os

    # Hyperparameters
    latent_dim = 100
    num_epochs = 300
    batch_size = 256
    lr = 0.0002
    beta1 = 0.5

    # Create the checkpoint directory *before* training so that a missing
    # folder cannot make torch.save fail after the whole training run.
    os.makedirs('models', exist_ok=True)

    # Initialize generator and discriminator
    generator = GeneratorCIFAR().to(device)
    discriminator = DiscriminatorCIFAR().to(device)

    # Loss function and optimizers
    criterion = nn.BCELoss()
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    # Get CIFAR dataloader
    dataloader = getDataLoader('cifar', batch_size)

    # Train GAN
    G_losses, D_losses = trainGan(num_epochs, dataloader, generator, discriminator, latent_dim, criterion, g_optimizer, d_optimizer, device, 'Gan_CIFAR')

    # Plot loss curves
    # NOTE(review): the x-axis label assumes trainGan returns one value per epoch — confirm.
    plt.figure(figsize=(10, 5))
    plt.plot(G_losses, label='Generator')
    plt.plot(D_losses, label='Discriminator')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('GAN CIFAR Loss Curves')
    plt.savefig('gan_cifar_loss_curves.png')
    plt.show()

    # Perform latent space interpolation
    latent_space_interpolation_gan(generator, latent_dim, device, 'Gan_CIFAR', num_steps=10)

    # Save model weights
    torch.save(generator.state_dict(), 'models/generator_cifar.pth')
    torch.save(discriminator.state_dict(), 'models/discriminator_cifar.pth')

    # Visualize real CIFAR-10 images
    def show_real_images():
        """Plot and save a grid of the first (up to) 64 images of one batch."""
        real_batch = next(iter(dataloader))
        plt.figure(figsize=(10, 10))
        # make_grid returns a (C, H, W) tensor; imshow wants (H, W, C).
        plt.imshow(make_grid(real_batch[0][:64], padding=2, normalize=True).permute(1, 2, 0))
        plt.title("Real CIFAR-10 Images")
        plt.axis('off')
        # Do not rely on an earlier step having created the output folder.
        os.makedirs('Gan_CIFAR', exist_ok=True)
        plt.savefig('Gan_CIFAR/real_cifar10_images.png')
        plt.show()

    show_real_images()

    # Drop the local references to the networks
    del generator
    del discriminator

    print("Training complete. Generated images, loss curves, and latent space interpolation have been saved.")
    

# VAE for MNIST

In [None]:
def vae_mnist():
    """Train the VAE on MNIST, save its weights, and plot losses, reconstructions and samples.

    Steps, in order:
      1. Build ``VAE`` on ``device`` and train it with ``trainVAE`` on the loader
         from ``getDataLoader('mnist', ...)``.
      2. Plot the two sequences returned by ``trainVAE`` and save the figure to
         ``vae_mnist_loss_curves.png``.
      3. Save the state dict to ``models/vae_mnist.pth``.
      4. Save a figure of 8 originals vs. their reconstructions
         (``vae_reconstructed_images.png``) and a figure of 64 images decoded
         from standard-normal latent vectors (``vae_generated_images.png``).

    Takes no arguments and returns nothing; results are written to disk and
    shown through matplotlib.
    """
    # Local import: only needed to prepare the checkpoint directory.
    import os

    # Hyperparameters
    batch_size = 256
    num_epochs = 300
    learning_rate = 1e-3
    # NOTE(review): VAE() is built without this value, so it must match the
    # latent size the class uses by default — confirm in models.
    latent_dim = 20

    # Create the checkpoint directory *before* training so that a missing
    # folder cannot make torch.save fail after the whole training run.
    os.makedirs('models', exist_ok=True)

    # Initialize model and optimizer
    model = VAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Get MNIST dataloader
    dataloader = getDataLoader('mnist', batch_size)

    # Train VAE
    elbo_losses, kl_divergences = trainVAE(num_epochs, dataloader, model, latent_dim, optimizer, device)

    # Plot loss curves
    plt.figure(figsize=(10, 5))
    plt.plot(elbo_losses, label='ELBO')
    plt.plot(kl_divergences, label='KLD')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('VAE MNIST Loss Curves')
    plt.savefig('vae_mnist_loss_curves.png')
    plt.show()

    # Save model weights
    torch.save(model.state_dict(), 'models/vae_mnist.pth')

    def plot_reconstructed_images():
        """Show 8 images from one batch above their reconstructions."""
        model.eval()
        with torch.no_grad():
            sample = next(iter(dataloader))[0][:8]
            # The model lives on `device`; its input must be on the same device.
            recon, _, _ = model(sample.to(device))
        # Bring the result back to host memory for matplotlib.
        recon = recon.cpu()

        fig, axes = plt.subplots(2, 8, figsize=(15, 5))
        for i in range(8):
            axes[0, i].imshow(sample[i].squeeze(), cmap='gray')
            axes[0, i].axis('off')
            axes[1, i].imshow(recon[i].view(28, 28), cmap='gray')
            axes[1, i].axis('off')
        # suptitle labels the whole figure; plt.title would only label the last subplot.
        fig.suptitle('Original (top) vs Reconstructed (bottom)')
        plt.savefig('vae_reconstructed_images.png')
        plt.show()

    def plot_generated_images():
        """Decode 64 standard-normal latent vectors and show them in an 8x8 grid."""
        model.eval()
        with torch.no_grad():
            # Sample directly on the model's device to avoid a device mismatch in decode().
            sample = torch.randn(64, latent_dim, device=device)
            generated = model.decode(sample).view(-1, 1, 28, 28).cpu()

        fig, axes = plt.subplots(8, 8, figsize=(15, 15))
        for i, ax in enumerate(axes.flat):
            ax.imshow(generated[i].squeeze(), cmap='gray')
            ax.axis('off')
        # suptitle labels the whole figure; plt.title would only label the last subplot.
        fig.suptitle('Generated Images')
        plt.savefig('vae_generated_images.png')
        plt.show()

    plot_reconstructed_images()
    plot_generated_images()

    # Drop the local reference to the network
    del model

    print("Training complete. Generated loss curves have been saved.")

## VAE for anomaly detection

In [None]:
def vae_anomaly():
    """Use the saved MNIST VAE as an anomaly detector based on reconstruction error.

    Loads ``models/vae_mnist.pth``, then for every batch from
    ``getDataLoader('mnist', 128)`` computes the per-image summed squared
    reconstruction error for the clean batch ("normal") and for a copy with
    additive Gaussian noise ("anomalous").  It then:

      * saves a histogram of both error distributions (``vae_recon_errors.png``),
      * thresholds at the 95th percentile of the normal errors and prints the
        share of each group classified correctly,
      * saves the ROC curve of error-as-score (``vae_detection_roc.png``).

    Takes no arguments and returns nothing.
    """
    # Local imports so this cell does not depend on names leaking in through
    # the notebook's wildcard imports.
    import numpy as np
    import torch.nn.functional as F
    from sklearn.metrics import roc_curve, auc

    # Load model; map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model = VAE().to(device)
    model.load_state_dict(torch.load('models/vae_mnist.pth', map_location=device, weights_only=True))
    model.eval()

    # Get MNIST dataloader
    dataloader = getDataLoader('mnist', 128)

    def add_noise(images, noise_factor=0.5):
        """Return `images` plus scaled Gaussian noise, clamped to [0, 1]."""
        # randn_like keeps the noise on the same device/dtype as `images`.
        noisy_images = images + noise_factor * torch.randn_like(images)
        # NOTE(review): the clamp assumes pixel values live in [0, 1] — confirm
        # against the transform used by getDataLoader.
        return torch.clamp(noisy_images, 0., 1.)

    def reconstruction_error(model, images):
        """Return a 1-D numpy array with one summed squared error per image."""
        model.eval()
        with torch.no_grad():
            recon, _, _ = model(images)
            error = F.mse_loss(recon, images.view(-1, 784), reduction='none').sum(dim=1)
        return error.cpu().numpy()

    # Calculate reconstruction errors for normal and anomalous images
    normal_batches = []
    anomalous_batches = []
    for data, _ in dataloader:
        # `data` already is the image batch; move it to the model's device.
        normal_images = data.to(device)
        anomalous_images = add_noise(normal_images)

        normal_batches.append(reconstruction_error(model, normal_images))
        anomalous_batches.append(reconstruction_error(model, anomalous_images))

    # Flatten the per-batch arrays into one error value per image so that the
    # histogram, percentile, comparisons and ROC all see 1-D arrays (the last
    # batch may be shorter, so a list of arrays would be ragged).
    normal_errors = np.concatenate(normal_batches)
    anomalous_errors = np.concatenate(anomalous_batches)

    # Plot distribution of reconstruction errors
    plt.figure(figsize=(10, 6))
    plt.hist(normal_errors, bins=50, alpha=0.5, label='Normal')
    plt.hist(anomalous_errors, bins=50, alpha=0.5, label='Anomalous')
    plt.xlabel('Reconstruction Error')
    plt.ylabel('Frequency')
    plt.legend()
    plt.title('Distribution of Reconstruction Errors')
    plt.savefig('vae_recon_errors.png')
    plt.show()

    # Set threshold for anomaly detection (95th percentile of normal errors)
    threshold = np.percentile(normal_errors, 95)

    # Classify images as normal or anomalous
    normal_classifications = normal_errors < threshold
    anomalous_classifications = anomalous_errors >= threshold

    print(f"Percentage of normal images classified as normal: {normal_classifications.mean()*100:.2f}%")
    print(f"Percentage of anomalous images classified as anomalous: {anomalous_classifications.mean()*100:.2f}%")

    # Calculate and plot ROC curve (label 0 = normal, 1 = anomalous; score = error)
    y_true = np.concatenate([np.zeros_like(normal_errors), np.ones_like(anomalous_errors)])
    y_scores = np.concatenate([normal_errors, anomalous_errors])

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.savefig('vae_detection_roc.png')
    plt.show()

    # Drop the local reference to the network
    del model
        

## VAE Latent space visualization

In [None]:
def latent_space_vis():
    """Scatter-plot the saved VAE's latent means for MNIST, coloured by label.

    Restores the weights stored at ``models/vae_mnist.pth``, encodes every
    batch produced by ``getDataLoader('mnist', 128)``, projects the latent
    means to 2-D with t-SNE when they have more than two dimensions, and saves
    the scatter plot to ``vae_latent_space_visualization.png``.

    Takes no arguments and returns nothing.
    """
    # Rebuild the network and restore the trained weights
    vae = VAE().to(device)
    state = torch.load(f'models/vae_mnist.pth', weights_only=True)
    vae.load_state_dict(state)
    vae.eval()

    mnist_loader = getDataLoader('mnist', 128)

    def encode_dataset(model, data_loader, device):
        """Return (latent means, labels), each concatenated over all batches."""
        model.eval()
        means, targets = [], []

        with torch.no_grad():
            for batch, target in data_loader:
                # NOTE(review): the batch is passed to encode() as the loader
                # yields it — confirm encode() accepts that shape.
                mean, _ = model.encode(batch.to(device))
                means.append(mean.cpu())
                targets.append(target)

        return torch.cat(means, dim=0), torch.cat(targets, dim=0)

    # Encode everything the loader yields
    latent_vectors, labels = encode_dataset(vae, mnist_loader, device)

    # Project to 2-D only when the latent space is wider than that
    latent_np = latent_vectors.numpy()
    if latent_vectors.shape[1] > 2:
        latent_2d = TSNE(n_components=2, random_state=42).fit_transform(latent_np)
    else:
        latent_2d = latent_np

    # Scatter plot, one colour per label value
    plt.figure(figsize=(10, 8))
    xs, ys = latent_2d[:, 0], latent_2d[:, 1]
    points = plt.scatter(xs, ys, c=labels, cmap='tab10')
    plt.colorbar(points)
    plt.title('VAE Latent Space Visualization')
    plt.xlabel('Latent Dimension 1')
    plt.ylabel('Latent Dimension 2')
    plt.savefig('vae_latent_space_visualization.png')
    plt.show()

    # Drop the local reference to the network
    del vae



# Run experiment

### Generated images at intermediate training steps for the GANs, and the final latent-space interpolation images, are stored in the Gan_MNIST and Gan_CIFAR folders

In [None]:
# Run every experiment in sequence, printing a banner before each one.
# Order matters: vae_anomaly() and latent_space_vis() load models/vae_mnist.pth,
# which vae_mnist() writes.
for banner, experiment in (
    ("Training GAN for MNIST", gan_mnist),
    ("Training GAN for CIFAR-10", gan_cifar),
    ("Training VAE for MNIST", vae_mnist),
    ("VAE for anomaly detection", vae_anomaly),
    ("VAE Latent Space Visualization", latent_space_vis),
):
    print(banner)
    experiment()