In [None]:
# ============================================================
# PART 1 (Alternative): Training a GAN on 2D Synthetic Datasets
# ============================================================

# Import necessary packages
import torch                         # For deep learning models
import torch.nn as nn                # For neural network layers
import numpy as np                   # For numeric and array operations
import matplotlib.pyplot as plt      # For plotting results

# Seed NumPy and PyTorch from one shared value so reruns draw the same random
# numbers (GPU kernels may still introduce some nondeterminism).
_SEED = 321
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# Use a CUDA device when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ------------------------------------------
# 1. DATA GENERATION (Alternative Functions)
# ------------------------------------------

# Function to generate moons data (two interleaving half circles)
def sample_moons(n=600, noise=0.08):
    from sklearn.datasets import make_moons
    X, _ = make_moons(n_samples=n, noise=noise)
    return X.astype(np.float32)

# Function to generate concentric ring data
def sample_rings(n=600, radius1=1.0, radius2=2.3, noise=0.08):
    n1 = n // 2
    n2 = n - n1
    theta1 = np.random.uniform(0, 2 * np.pi, n1)
    theta2 = np.random.uniform(0, 2 * np.pi, n2)
    x1 = radius1 * np.cos(theta1) + np.random.normal(0, noise, n1)
    y1 = radius1 * np.sin(theta1) + np.random.normal(0, noise, n1)
    x2 = radius2 * np.cos(theta2) + np.random.normal(0, noise, n2)
    y2 = radius2 * np.sin(theta2) + np.random.normal(0, noise, n2)
    X = np.concatenate([
        np.stack([x1, y1], axis=1),
        np.stack([x2, y2], axis=1)
    ], axis=0)
    return X.astype(np.float32)

# ------------------------------------------
# 2. MODEL DEFINITIONS
# ------------------------------------------

# Generator module: maps random latent input to 2D sample
class Gen2D(nn.Module):
    def __init__(self, z_dim=2, h_dim=40):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ELU(),
            nn.Linear(h_dim, h_dim),
            nn.ELU(),
            nn.Linear(h_dim, 2)
        )

    def forward(self, z):
        return self.layers(z)

# Discriminator module: scores 2D input as real or fake
class Disc2D(nn.Module):
    def __init__(self, h_dim=40):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.Tanh(),
            nn.Linear(h_dim, h_dim),
            nn.Tanh(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

# ------------------------------------------
# 3. TRAINING FUNCTION
# ------------------------------------------

def _plot_loss_curves(g_loss_hist, d_loss_hist, fname):
    """Save the per-epoch generator/discriminator loss histories as a line plot."""
    fig = plt.figure()
    try:
        plt.plot(g_loss_hist, label="Generator")
        plt.plot(d_loss_hist, label="Discriminator")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("GAN Training Loss Curves")
        plt.legend()
        plt.tight_layout()
        plt.savefig(fname)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def run_gan(
    data_sampler,      # Callable: data_sampler(n) -> array of n real 2-D points
    z_dim=2,           # Latent noise vector size
    g_hidden=40,       # Generator hidden size
    d_hidden=40,       # Discriminator hidden size
    batch=128,         # Batch size
    epochs=1200,       # Training steps (one minibatch each)
    plot_every=300,    # Snapshot frequency; 0/None disables periodic snapshots
    tag="exp",         # Output file prefix
    lr=0.001,          # Adam learning rate for both networks
):
    """Train a fresh Gen2D/Disc2D pair on data drawn from ``data_sampler``.

    Each "epoch" is a single optimisation step:
    1. One discriminator update on a fresh batch of real points (target 1)
       and a batch of detached generator samples (target 0).
    2. One generator update that scores new fakes against target 1. This is
       the non-saturating GAN generator loss.

    ``data_sampler`` is expected to return an (n, 2) float32 array, as
    ``sample_moons`` and ``sample_rings`` do.

    Files written (paths relative to the working directory):
        ``{tag}_epoch{N}.png``: written every ``plot_every`` epochs and at
        the last epoch.
        ``{tag}_final.png``: a final real-vs-generated scatter.
        ``{tag}_loss_curves.png``: the per-epoch G and D loss histories.

    Returns:
        A tuple ``(G, D, g_loss_hist, d_loss_hist)``: the trained generator,
        the trained discriminator, and their per-epoch loss lists.
    """
    G = Gen2D(z_dim, g_hidden).to(DEVICE)
    D = Disc2D(d_hidden).to(DEVICE)

    # D ends in a Sigmoid, so plain BCE on probabilities is the matching loss.
    bce = nn.BCELoss()
    g_opt = torch.optim.Adam(G.parameters(), lr=lr)
    d_opt = torch.optim.Adam(D.parameters(), lr=lr)

    # Targets are identical every step; build them once instead of per step.
    real_labels = torch.full((batch, 1), 1.0, device=DEVICE)
    fake_labels = torch.full((batch, 1), 0.0, device=DEVICE)

    g_loss_hist, d_loss_hist = [], []

    for epoch in range(1, epochs + 1):
        # --- Discriminator step: real -> 1, detached fakes -> 0 ---
        real_samples = torch.tensor(data_sampler(batch), device=DEVICE)
        z = torch.randn(batch, z_dim, device=DEVICE)
        fake_samples = G(z).detach()  # block gradients into G for this step
        d_real = D(real_samples)
        d_fake = D(fake_samples)
        d_loss = bce(d_real, real_labels) + bce(d_fake, fake_labels)
        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        # --- Generator step: push D to score fresh fakes as real ---
        # This backward also populates D's .grad, but only g_opt steps here,
        # and d_opt.zero_grad() clears D's grads before its next update.
        z = torch.randn(batch, z_dim, device=DEVICE)
        g_loss = bce(D(G(z)), real_labels)
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        g_loss_hist.append(g_loss.item())
        d_loss_hist.append(d_loss.item())

        # Guard against plot_every == 0/None, which would divide by zero.
        periodic = bool(plot_every) and epoch % plot_every == 0
        if periodic or epoch == epochs:
            print(f"[{tag}] Epoch {epoch}/{epochs} | D Loss: {d_loss.item():.3f} | G Loss: {g_loss.item():.3f}")
            show_gan_samples(G, data_sampler, z_dim, f"{tag}_epoch{epoch}.png")

    show_gan_samples(G, data_sampler, z_dim, f"{tag}_final.png")
    _plot_loss_curves(g_loss_hist, d_loss_hist, f"{tag}_loss_curves.png")
    print(f"All results saved for {tag}.")
    return G, D, g_loss_hist, d_loss_hist

# ------------------------------------------
# 4. PLOTTING FUNCTION
# ------------------------------------------

def show_gan_samples(G, data_sampler, z_dim, fname, n=500):
    """Save a scatter plot comparing real points with samples from generator ``G``.

    Args:
        G: generator module mapping (n, z_dim) noise to (n, 2) points.
        data_sampler: callable; ``data_sampler(n)`` returns n real 2-D points.
        z_dim: latent dimension fed to ``G``.
        fname: output image path passed to ``plt.savefig``.
        n: number of real and generated points to draw (default 500).

    ``G`` runs in eval mode without gradients. Its previous train/eval mode
    is restored afterwards, even if generation raises.
    """
    real_points = data_sampler(n)

    # Draw noise on the generator's own device. Fall back to DEVICE for a
    # parameter-less module.
    first_param = next(G.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n, z_dim, device=device)
            fake_points = G(z).cpu().numpy()
    finally:
        # Restore the caller's mode instead of forcing train mode.
        G.train(was_training)

    fig = plt.figure(figsize=(6, 6))
    try:
        plt.scatter(real_points[:, 0], real_points[:, 1], c='limegreen', alpha=0.5, label='Real')
        plt.scatter(fake_points[:, 0], fake_points[:, 1], c='purple', alpha=0.5, label='Generated')
        plt.legend()
        plt.title("Real vs Generated 2D Samples")
        plt.tight_layout()
        plt.savefig(fname)
    finally:
        plt.close(fig)

# ------------------------------------------
# 5. RUN THE EXPERIMENTS
# ------------------------------------------

if __name__ == "__main__":
    # Hyperparameters shared by every experiment; only the dataset, the run
    # length and the output prefix differ between runs.
    common = dict(z_dim=2, g_hidden=40, d_hidden=40, batch=128, plot_every=400)
    for sampler, n_epochs, label in (
        (sample_moons, 1300, "moons"),  # two interleaving half-circles
        (sample_rings, 1200, "rings"),  # two concentric rings
    ):
        run_gan(data_sampler=sampler, epochs=n_epochs, tag=label, **common)

[moons] Epoch 400/1300 | D Loss: 1.429 | G Loss: 0.774
[moons] Epoch 800/1300 | D Loss: 1.419 | G Loss: 0.617
[moons] Epoch 1200/1300 | D Loss: 1.395 | G Loss: 0.672
[moons] Epoch 1300/1300 | D Loss: 1.395 | G Loss: 0.727
All results saved for moons.
[rings] Epoch 400/1200 | D Loss: 1.260 | G Loss: 0.745
[rings] Epoch 800/1200 | D Loss: 1.375 | G Loss: 0.717
[rings] Epoch 1200/1200 | D Loss: 1.352 | G Loss: 0.719
All results saved for rings.
