In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random

In [2]:
def set_seed(seed):
    """Seed every RNG used by this notebook and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG, torch's CPU RNG and the CUDA
    RNGs (current device and all devices), then forces cuDNN to use
    deterministic kernels with the auto-tuner disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # pick deterministic convolution algorithms
    cudnn.benchmark = False  # the auto-tuner may choose nondeterministic kernels


# Image geometry (MNIST: 28x28, single channel) read by Generator, Discriminator and Encoder.
img_size, img_channels = 28, 1

In [3]:
# Generator model
class Generator(nn.Module):
    """MLP generator: latent vectors (last dim z_dim) -> images in [-1, 1].

    Hidden widths are 128 -> 256 -> 512 with LeakyReLU(0.2). The final Linear
    produces img_size * img_size * img_channels values squashed by Tanh, and
    forward reshapes them to (N, img_channels, img_size, img_size).
    """

    def __init__(self, z_dim):
        super().__init__()
        widths = [z_dim, 128, 256, 512]
        layers = []
        # Built in the same order as a literal Sequential, so the parameter
        # initialization order and state_dict keys (gen.0, gen.2, ...) are unchanged.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(widths[-1], img_size * img_size * img_channels), nn.Tanh()]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        flat = self.gen(x)
        return flat.view(-1, img_channels, img_size, img_size)

# Discriminator model
class Discriminator(nn.Module):
    """MLP discriminator: flattens each image and outputs a (N, 1) sigmoid score.

    train_gan labels real images 1 and generated images 0, so the score acts as
    the estimated probability that an input is real.
    """

    def __init__(self):
        super().__init__()
        in_features = img_size * img_size * img_channels
        self.disc = nn.Sequential(
            nn.Linear(in_features, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        flat = x.view(-1, img_size * img_size * img_channels)
        return self.disc(flat)

# Encoder model for orthogonal vector computation
class Encoder(nn.Module):
    """Small MLP that maps flattened images to 64-dimensional feature vectors.

    train_gan feeds generator samples through it to compute the orthogonality
    penalty. It never creates an optimizer for this module, so its weights stay
    at their random initialization.
    """

    def __init__(self):
        super().__init__()
        flat_dim = img_size * img_size * img_channels
        hidden = nn.Linear(flat_dim, 128)
        self.encoder = nn.Sequential(hidden, nn.LeakyReLU(0.2), nn.Linear(128, 64))

    def forward(self, x):
        flat = x.view(-1, img_size * img_size * img_channels)
        return self.encoder(flat)

# Binary cross-entropy on the discriminator's sigmoid outputs, averaged over elements.
criterion = nn.BCELoss(reduction="mean")

In [4]:

# Function to generate noise for the Generator
def generate_noise(batch_size, z_dim, device):
    """Return a (batch_size, z_dim) standard-normal latent batch placed on `device`.

    The sample is drawn from torch's CPU RNG and then moved, so results depend
    only on the CPU seed.
    """
    latent = torch.randn(batch_size, z_dim)
    return latent.to(device)

# Orthogonal loss function
def orthogonal_loss(feature1, feature2):
    """Mean squared cosine similarity between corresponding rows of two feature batches.

    Rows are compared pairwise along dim=1. The result is 0 when every pair is
    orthogonal and approaches 1 when every pair is (anti-)parallel. The 1e-8
    added to the norm product avoids division by zero for all-zero rows.
    """
    dots = (feature1 * feature2).sum(dim=1)
    denom = feature1.norm(dim=1) * feature2.norm(dim=1) + 1e-8
    cosine = dots / denom
    # Squaring penalizes both positive and negative alignment.
    return cosine.pow(2).mean()


In [5]:
# Visualize generated images after training
def visualize_generated_images(generators, z_dim, num_images, device):
    """Show one image grid per generator, all driven by the same latent batch.

    Sharing a single noise batch across generators makes their grids directly
    comparable. Sampling runs under torch.no_grad() because no gradients are
    needed; previously a full autograd graph was built and then discarded via
    .detach().

    Args:
        generators: iterable of generator modules that map (num_images, z_dim)
            noise to image tensors.
        z_dim: latent dimension expected by the generators.
        num_images: number of samples drawn per generator.
        device: device on which the noise (and generators) live.
    """
    noise = generate_noise(num_images, z_dim, device)
    for idx, gen in enumerate(generators):
        with torch.no_grad():
            fake_images = gen(noise).cpu()
        # normalize=True rescales the batch into [0, 1] using its own min/max for display.
        grid = torchvision.utils.make_grid(fake_images, normalize=True)
        fig, ax = plt.subplots()
        ax.set_title(f"Generator {idx + 1}")
        ax.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
        ax.axis("off")
        plt.show()
        # Release the figure; this is called repeatedly during training.
        plt.close(fig)

In [None]:
def _build_mnist_loader(batch_size, data_root):
    """Return a shuffled MNIST training DataLoader with pixels scaled to [-1, 1]."""
    # Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1], matching the generators' Tanh.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    train_dataset = torchvision.datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


def _plot_loss_histories(loss_history_disc, loss_history_gens):
    """Plot the per-epoch discriminator loss, then one figure per generator's per-epoch loss."""
    plt.figure(figsize=(10, 5))
    plt.title("Discriminator Loss")
    plt.plot(loss_history_disc, label="Discriminator Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    for idx, history in enumerate(loss_history_gens):
        plt.figure(figsize=(10, 5))
        plt.title(f"Generator {idx+1} Loss")
        plt.plot(history, label=f"Generator {idx+1} Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()


def train_gan(num_epochs, z_dim, lr, batch_size, num_generators, seed,
              data_root="./data", ortho_weight=1.0):
    """Train several MLP generators against one shared discriminator on MNIST.

    Each generator minimizes the non-saturating GAN loss BCE(D(G(z)), 1) plus an
    orthogonality penalty. The penalty is the mean squared cosine similarity
    between Encoder features of its samples and those of every other generator,
    averaged over the other generators. The Encoder is randomly initialized and
    never trained.

    Args:
        num_epochs: number of passes over the MNIST training set.
        z_dim: latent dimension fed to every Generator.
        lr: Adam learning rate for all generators and the discriminator.
        batch_size: DataLoader batch size.
        num_generators: number of generators (>= 1). With one generator there is
            nothing to be orthogonal to, and the penalty is omitted. Previously
            this case raised ZeroDivisionError.
        seed: seed passed to set_seed for reproducibility.
        data_root: directory where MNIST is downloaded / cached.
        ortho_weight: multiplier on the averaged orthogonality term. The default
            of 1.0 reproduces the original objective.

    Raises:
        ValueError: if num_generators < 1.

    Side effects: may download MNIST, prints per-epoch losses, shows sample grids
    every 5 epochs and the loss curves at the end. Returns None.
    """
    if num_generators < 1:
        raise ValueError("num_generators must be at least 1")

    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = _build_mnist_loader(batch_size, data_root)
    num_batches = len(train_loader)

    generators = [Generator(z_dim).to(device) for _ in range(num_generators)]
    disc = Discriminator().to(device)
    encoder = Encoder().to(device)
    # The encoder has no optimizer. Freezing its parameters stops autograd from
    # accumulating unused gradients into them; gradients still flow *through* it
    # to the generator being trained.
    encoder.requires_grad_(False)

    optimizer_gens = [optim.Adam(gen.parameters(), lr=lr) for gen in generators]
    optimizer_disc = optim.Adam(disc.parameters(), lr=lr)

    loss_history_gens = [[] for _ in range(num_generators)]
    loss_history_disc = []
    num_others = num_generators - 1

    for epoch in range(num_epochs):
        loss_disc_epoch = 0.0
        loss_gens_epoch = [0.0] * num_generators

        for real, _ in train_loader:
            real = real.to(device)
            current_batch = real.size(0)  # the last batch can be smaller than batch_size

            # --- Discriminator step ---
            noise = [generate_noise(current_batch, z_dim, device) for _ in range(num_generators)]
            # The fakes are only fed to the discriminator as constants, so no generator graph is needed.
            with torch.no_grad():
                fakes = [gen(z) for gen, z in zip(generators, noise)]

            disc_real = disc(real).view(-1)
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

            loss_disc_fake_total = 0
            for fake in fakes:
                disc_fake = disc(fake).view(-1)
                loss_disc_fake_total += criterion(disc_fake, torch.zeros_like(disc_fake))

            # Average of the real term and one fake term per generator.
            loss_disc = (loss_disc_real + loss_disc_fake_total) / (num_generators + 1)
            optimizer_disc.zero_grad()
            loss_disc.backward()
            optimizer_disc.step()
            loss_disc_epoch += loss_disc.item()

            # --- Generator steps (sequential: later generators see earlier ones' updated weights) ---
            for idx, gen in enumerate(generators):
                # One forward pass shared by the adversarial and orthogonality terms.
                fake = gen(noise[idx])
                output = disc(fake).view(-1)
                loss_gen = criterion(output, torch.ones_like(output))

                gen_feature = encoder(fake)
                ortho_loss_total = 0
                for other_idx, other_gen in enumerate(generators):
                    if other_idx == idx:
                        continue
                    # Only generator `idx` is stepped here, so the other generators'
                    # features are constants; their gradients were discarded anyway.
                    with torch.no_grad():
                        other_feature = encoder(other_gen(noise[other_idx]))
                    ortho_loss_total += orthogonal_loss(gen_feature, other_feature)

                total_loss_gen = loss_gen
                if num_others > 0:
                    total_loss_gen = total_loss_gen + ortho_weight * ortho_loss_total / num_others

                optimizer_gens[idx].zero_grad()
                total_loss_gen.backward()
                optimizer_gens[idx].step()
                loss_gens_epoch[idx] += total_loss_gen.item()

        # Per-epoch averages over batches.
        avg_disc = loss_disc_epoch / num_batches
        avg_gens = [total / num_batches for total in loss_gens_epoch]
        loss_history_disc.append(avg_disc)
        for history, avg in zip(loss_history_gens, avg_gens):
            history.append(avg)

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {avg_disc:.4f}")
        for idx, avg in enumerate(avg_gens):
            print(f"Epoch [{epoch+1}/{num_epochs}] Loss G{idx+1}: {avg:.4f}")
        print(f"Epoch [{epoch+1}/{num_epochs}] Total Loss G: {sum(loss_gens_epoch)/num_batches:.4f}")
        print('-' * 50)

        if (epoch + 1) % 5 == 0:
            visualize_generated_images(generators, z_dim, num_images=16, device=device)

    _plot_loss_histories(loss_history_disc, loss_history_gens)

# Example 1: two generators, seed 42, 100 epochs on MNIST.
train_gan(
    num_epochs=100,
    z_dim=100,
    lr=0.0002,
    batch_size=64,
    num_generators=2,
    seed=42,
)

In [None]:
# Re-run of the two-generator, seed-42 configuration with identical arguments.
# NOTE(review): this repeats a full 100-epoch run; presumably a reproducibility
# check. Remove the cell if the duplicate is unintended.
train_gan(
    num_epochs=100,
    z_dim=100,
    lr=0.0002,
    batch_size=64,
    num_generators=2,
    seed=42,
)

In [None]:
# Example 2: same hyperparameters and seed as the two-generator run, but with 3 generators.
# (The previous comment wrongly said "Example 1: Run with 2 generators".)
train_gan(
    num_epochs=100,
    z_dim=100,
    lr=0.0002,
    batch_size=64,
    num_generators=3,
    seed=42,
)