<a href="https://colab.research.google.com/github/raycmarange/AML425_RAY/blob/main/project2_generative_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# AIML425 - Assignment 2 (Generative Models)
# Refactored solution with explicit SGD and comprehensive evaluations
# Author: Ray Marange (original), refactored for completeness and robustness

import os
import math
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.stats import kstest, chisquare

# =========================
# Reproducibility and device
# =========================
# Fix the seeds of the global PyTorch and NumPy random generators used by the
# samplers and weight initialisation below.
torch.manual_seed(42)
np.random.seed(42)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# =========================
# Utilities
# =========================
def to_np(t):
    """Convert tensor ``t`` to a NumPy array.

    The tensor is first detached from the autograd graph and moved to the CPU,
    so this is safe to call on tensors that require grad or live on the GPU.
    """
    host_tensor = t.detach().cpu()
    return host_tensor.numpy()

def ensure_dir(path):
    """Create directory ``path`` (including missing parents) if it is absent.

    Calling this on a directory that already exists is a no-op.  Creation is
    delegated to ``os.makedirs(..., exist_ok=True)`` so there is no window
    between an existence check and the creation in which another process
    could create the directory and make this call fail.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

# Directory (relative to the working directory) that the figures saved below
# are written into; it is created up front so the savefig calls can rely on it.
OUTDIR = "outputs_assn2"
ensure_dir(OUTDIR)

# =========================
# Distribution samplers
# =========================
def gaussian_2d(batch_size):
    """Draw ``batch_size`` points from the 2D standard normal N(0, I).

    Returns a tensor of shape (batch_size, 2) created by ``torch.randn``.
    """
    shape = (batch_size, 2)
    return torch.randn(shape)

def uniform_2d(batch_size):
    """Draw ``batch_size`` points uniformly from the square [-0.5, 0.5]^2.

    Returns a tensor of shape (batch_size, 2): ``torch.rand`` samples shifted
    down by 0.5 so the square is centred on the origin.
    """
    unit_square = torch.rand((batch_size, 2))
    return unit_square - 0.5

def uniform_1d(batch_size):
    """Draw ``batch_size`` scalars uniformly from the interval [-0.5, 0.5].

    Returns a column tensor of shape (batch_size, 1): ``torch.rand`` samples
    shifted down by 0.5 so the interval is centred on zero.
    """
    unit_interval = torch.rand((batch_size, 1))
    return unit_interval - 0.5

class EmpiricalSampler:
    """Mini-batch sampler over a fixed dataset.

    ``data`` is anything ``torch.as_tensor`` accepts, laid out as (N, d); it is
    stored as a float32 tensor in ``self.data`` and its row count in ``self.N``.
    Calling the instance with a batch size returns that many rows picked
    uniformly at random (with replacement) via ``torch.randint``.
    """

    def __init__(self, data):
        dataset = torch.as_tensor(data, dtype=torch.float32)
        self.data = dataset
        self.N = dataset.shape[0]

    def __call__(self, batch_size):
        rows = torch.randint(low=0, high=self.N, size=(batch_size,))
        return self.data[rows]

# =========================
# Kernels and MMD
# =========================
def gaussian_kernel_multi(x, y, sigmas=(0.2, 0.5, 1.0, 2.0, 5.0)):
    """
    Multi-kernel Gaussian RBF: average of RBFs with multiple bandwidths.
    More stable across scales than a single sigma.

    Args:
        x: tensor of shape (n, d).
        y: tensor of shape (m, d).
        sigmas: bandwidths; one RBF kernel exp(-||a - b||^2 / (2 s^2)) is
            evaluated per entry and the results are averaged.

    Returns:
        Tensor of shape (n, m) whose (i, j) entry is the averaged kernel value
        between x[i] and y[j]; every entry lies in [0, 1].
    """
    x_sq = (x ** 2).sum(dim=1, keepdim=True)
    y_sq = (y ** 2).sum(dim=1, keepdim=True)
    # ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2 a.b.  Cancellation in
    # floating point can make this slightly negative for (nearly) identical
    # points, which would give kernel values above 1, so clamp at zero.
    sq_dist = torch.clamp(x_sq + y_sq.T - 2 * (x @ y.T), min=0.0)
    K = sum(torch.exp(-sq_dist / (2 * s ** 2)) for s in sigmas)
    return K / len(sigmas)

def mmd2_unbiased(x, y, kernel=gaussian_kernel_multi):
    """
    Unbiased MMD^2 estimator (U-statistic).

    Args:
        x: tensor of shape (n, d), samples from the first distribution.
        y: tensor of shape (m, d), samples from the second distribution.
            ``m`` may differ from ``n``.
        kernel: callable ``kernel(a, b)`` returning the (len(a), len(b))
            kernel matrix.

    Returns:
        Scalar tensor with the MMD^2 estimate.  Because the estimator is
        unbiased it can be slightly negative when the distributions match.

    Raises:
        ValueError: if either sample has fewer than two rows; the within-sample
            terms average over pairs of distinct points, so they are undefined
            (division by zero) in that case.
    """
    n = x.shape[0]
    m = y.shape[0]
    if n < 2 or m < 2:
        raise ValueError(
            f"Unbiased MMD^2 needs at least 2 samples per set, got {n} and {m}."
        )
    Kxx = kernel(x, x)
    Kyy = kernel(y, y)
    Kxy = kernel(x, y)
    # Within-sample terms: drop the diagonal (each point paired with itself)
    # and average over the remaining ordered pairs.
    sum_Kxx = (Kxx.sum() - Kxx.diag().sum()) / (n * (n - 1))
    sum_Kyy = (Kyy.sum() - Kyy.diag().sum()) / (m * (m - 1))
    # Cross term: plain mean over all n * m pairs.
    sum_Kxy = Kxy.mean()
    return sum_Kxx + sum_Kyy - 2 * sum_Kxy

# =========================
# Networks
# =========================
class Generator(nn.Module):
    """Fully connected feed-forward network used as the generator.

    The network consists of ``depth - 1`` hidden blocks (``nn.Linear`` followed
    by the chosen activation, each ``hidden_dim`` wide) and a final
    ``nn.Linear`` to ``output_dim`` with no activation.  ``activation`` is
    matched case-insensitively against 'relu', 'gelu', 'tanh' and 'leakyrelu'
    (negative slope 0.2); any other name raises ``KeyError``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128, depth=3, activation='relu'):
        super().__init__()
        activation_factories = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
            'leakyrelu': lambda: nn.LeakyReLU(0.2),
        }
        make_activation = activation_factories[activation.lower()]

        # Input width of every linear layer, first to last.
        widths = [input_dim] + [hidden_dim] * (depth - 1)
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(make_activation())
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.net(x)

# =========================
# Training (explicit SGD)
# =========================
def train_generator(
    source_sampler, target_sampler, input_dim, output_dim,
    epochs=5000, batch_size=512, lr=0.01,
    reg_lambda=0.0, reg_type=None,
    hidden_dim=128, depth=3, activation='relu',
    kernel=gaussian_kernel_multi, log_every=500, name="net"
):
    """
    Trains a generator with explicit SGD updates (no optimizer.step()).

    Each "epoch" is a single SGD step on one freshly sampled mini-batch: the
    generator output for a source batch is compared with a target batch via
    the unbiased MMD^2, an optional L1/L2 penalty over all parameters is
    added, and every parameter is moved by ``-lr * grad``.

    Args:
        source_sampler: callable ``batch_size -> (batch_size, input_dim)`` tensor.
        target_sampler: callable ``batch_size -> (batch_size, output_dim)`` tensor.
        input_dim, output_dim: generator input and output widths.
        epochs: number of SGD steps.
        batch_size: rows drawn from each sampler per step.
        lr: SGD step size.
        reg_lambda: weight of the regularization penalty.
        reg_type: ``None`` (no penalty), ``'l1'`` or ``'l2'``.
        hidden_dim, depth, activation: forwarded to ``Generator``.
        kernel: kernel function forwarded to ``mmd2_unbiased``.
        log_every: snapshot/print period in steps.  The first and the last
            step are always logged; a falsy value logs only those two.
        name: label used in the progress lines.

    Returns:
        ``(net, weight_history, loss_history)`` where ``weight_history`` is a
        float32 array with one flattened parameter vector per logged step (its
        last row is always the final weights) and ``loss_history`` is a float32
        array with the total loss (MMD^2 plus penalty) of every step.

    Raises:
        ValueError: if ``reg_type`` is not one of ``None``, ``'l1'``, ``'l2'``.
            Previously an unrecognised value silently disabled regularization.
    """
    if reg_type not in (None, 'l1', 'l2'):
        raise ValueError(
            f"Unknown reg_type {reg_type!r}; expected None, 'l1' or 'l2'."
        )

    net = Generator(input_dim, output_dim, hidden_dim, depth, activation).to(device)
    params = list(net.parameters())
    weight_history, loss_history = [], []

    net.train()
    for epoch in range(1, epochs + 1):
        # Sample source and target mini-batches
        x = source_sampler(batch_size).to(device)
        y = target_sampler(batch_size).to(device)

        # Forward pass and MMD^2 between generated and target samples
        y_hat = net(x)
        mmd = mmd2_unbiased(y_hat, y, kernel=kernel)

        # Optional penalty over every parameter (weights and biases)
        loss = mmd
        if reg_type == 'l1':
            loss = mmd + reg_lambda * sum(p.abs().sum() for p in params)
        elif reg_type == 'l2':
            loss = mmd + reg_lambda * sum((p ** 2).sum() for p in params)

        # Backprop: clear stale gradients first, since backward() accumulates
        net.zero_grad()
        loss.backward()

        # Explicit SGD update, done in place under no_grad so the step itself
        # is not recorded by autograd
        with torch.no_grad():
            for p in params:
                if p.grad is not None:
                    p.sub_(lr * p.grad)

        # Logging
        loss_history.append(float(loss.detach().cpu()))
        is_log_step = (
            epoch == 1
            or epoch == epochs  # guarantees weight_history[-1] is the final state
            or (log_every and epoch % log_every == 0)
        )
        if is_log_step:
            w = torch.cat([p.detach().reshape(-1) for p in params]).cpu().numpy()
            weight_history.append(w)
            # Report the MMD^2 term itself; when a penalty is active the total
            # objective is shown separately so the two are not conflated.
            msg = f"[{name}] epoch {epoch:5d}/{epochs}  MMD2={float(mmd.detach().cpu()):.6f}"
            if reg_type is not None:
                msg += f"  loss={loss_history[-1]:.6f}"
            print(msg)

    return net, np.array(weight_history, dtype=np.float32), np.array(loss_history, dtype=np.float32)

# =========================
# Evaluation helpers
# =========================
def eval_uniform_2d(samples_np, grid_n=10):
    """
    Diagnostics for uniformity on [-0.5,0.5]^2:
    - Marginal KS tests
    - Correlation
    - 2D chi-square on a grid

    Args:
        samples_np: array of shape (N, 2); column 0 is x, column 1 is y.
        grid_n: number of histogram bins per axis for the chi-square test.

    Returns:
        Dict of plain Python floats:
        - ``ks_x_p`` / ``ks_y_p``: KS p-values of each marginal against
          Uniform(-0.5, 0.5).
        - ``corr_xy``: Pearson correlation between the two coordinates.
        - ``chi2_stat`` / ``chi2_p``: chi-square statistic and p-value of the
          ``grid_n`` x ``grid_n`` cell counts against equal expected counts.
          Only samples that fall inside the square are counted.
        - ``frac_outside``: fraction of samples that fell outside the square
          and were therefore ignored by the chi-square test.
    """
    x = samples_np[:, 0]
    y = samples_np[:, 1]

    # Marginal KS against Uniform(-0.5, 0.5); scipy's 'uniform' takes (loc, scale)
    px = kstest(x, 'uniform', args=(-0.5, 1.0)).pvalue
    py = kstest(y, 'uniform', args=(-0.5, 1.0)).pvalue

    # Independence (Pearson correlation ~ 0)
    corr = np.corrcoef(x, y)[0, 1]

    # 2D chi-square on the cell counts. histogram2d silently drops points
    # outside ``range``, so the dropped share is reported separately below.
    H, _, _ = np.histogram2d(x, y, bins=grid_n, range=[[-0.5, 0.5], [-0.5, 0.5]])
    H = H.astype(np.float64)
    expected = H.sum() / (grid_n * grid_n)
    # The tiny epsilon keeps the statistic finite when no sample is in range
    chi2_stat = ((H - expected) ** 2 / (expected + 1e-12)).sum()
    # p-value from scipy's chisquare on the flattened counts, flat expectation
    chi2_p = chisquare(H.ravel(), f_exp=np.full(H.size, expected)).pvalue

    n_total = x.shape[0]
    frac_outside = 1.0 - H.sum() / n_total if n_total else 0.0

    return {
        "ks_x_p": float(px), "ks_y_p": float(py),
        "corr_xy": float(corr),
        "chi2_stat": float(chi2_stat), "chi2_p": float(chi2_p),
        "frac_outside": float(frac_outside)
    }

def eval_gaussian_2d(samples_np):
    """
    Summary statistics for checking 2D samples against N(0, I).

    ``samples_np`` is an (N, 2) array.  The returned dict holds the sample
    mean ("mean"), the 2x2 sample covariance ("cov"), the KS p-value of each
    coordinate against the standard normal ("ks_x_p", "ks_y_p"), and the trace
    and determinant of the covariance ("trace_cov", "det_cov").  For a good
    fit the mean is near 0 and the covariance near the identity.
    """
    first_coord, second_coord = samples_np[:, 0], samples_np[:, 1]
    covariance = np.cov(samples_np.T)

    report = {"mean": samples_np.mean(axis=0), "cov": covariance}
    # kstest's 'norm' with no extra args is the standard normal N(0, 1)
    report["ks_x_p"] = float(kstest(first_coord, 'norm').pvalue)
    report["ks_y_p"] = float(kstest(second_coord, 'norm').pvalue)
    report["trace_cov"] = float(np.trace(covariance))
    report["det_cov"] = float(np.linalg.det(covariance))
    return report

# =========================
# 2.1: f1: Gaussian -> Uniform([-0.5,0.5]^2)
# =========================
# Train f1: a 2 -> 2 generator fed with standard-normal samples and matched
# (via the MMD^2 loss inside train_generator) to uniform samples on the square.
print("Training f1: Gaussian -> Uniform...")
f1, w_f1_hist, f1_losses = train_generator(
    source_sampler=gaussian_2d,
    target_sampler=uniform_2d,
    input_dim=2, output_dim=2,
    epochs=10000, batch_size=512, lr=0.01,
    reg_lambda=0.0, reg_type=None,
    hidden_dim=128, depth=3, activation='relu',
    name="f1"
)

# Evaluate and visualize
# Push 4000 fresh Gaussian samples through f1 without tracking gradients.
with torch.no_grad():
    z = gaussian_2d(4000).to(device)
    y_gen = f1(z).detach().cpu()
# Uniformity diagnostics (marginal KS, correlation, grid chi-square) on the output.
y_eval = eval_uniform_2d(to_np(y_gen))

# Side-by-side scatter: source samples (left) and generated samples (right).
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.scatter(to_np(z)[:,0], to_np(z)[:,1], s=6, alpha=0.4)
plt.title("Source: Gaussian (2D)")
plt.subplot(1,2,2)
plt.scatter(to_np(y_gen)[:,0], to_np(y_gen)[:,1], s=6, alpha=0.4)
plt.title("Generated: Uniform [-0.5,0.5]^2")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "2_1_gaussian_to_uniform_scatter.png"), dpi=150)
plt.close()

# Loss curve: f1_losses has one entry per training step ("epoch" = one mini-batch).
plt.figure()
plt.plot(f1_losses)
plt.title("f1 training: MMD^2 vs epoch")
plt.xlabel("epoch")
plt.ylabel("MMD^2")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "2_1_f1_loss_curve.png"), dpi=150)
plt.close()

# One-line summary of the uniformity diagnostics computed above.
print(f"2.1 diagnostics: KS p(x)={y_eval['ks_x_p']:.4f}, KS p(y)={y_eval['ks_y_p']:.4f}, "
      f"corr={y_eval['corr_xy']:.4f}, chi2 p={y_eval['chi2_p']:.4f}")

# =========================
# 2.2: f2: Uniform -> Gaussian (inverse map)
# =========================
# Train f2: a 2 -> 2 generator fed with uniform samples on the square and
# matched to standard-normal samples (the reverse direction of f1).
print("\nTraining f2: Uniform -> Gaussian...")
f2, w_f2_hist, f2_losses = train_generator(
    source_sampler=uniform_2d,
    target_sampler=gaussian_2d,
    input_dim=2, output_dim=2,
    epochs=8000, batch_size=512, lr=0.01,
    reg_lambda=0.0, reg_type=None,
    hidden_dim=128, depth=3, activation='relu',
    name="f2"
)

# Push 4000 fresh uniform samples through f2 and compute Gaussian diagnostics
# (mean, covariance, marginal KS) on the output.
with torch.no_grad():
    u = uniform_2d(4000).to(device)
    z_gen = f2(u).detach().cpu()
z_eval = eval_gaussian_2d(to_np(z_gen))

# Side-by-side scatter: source samples (left) and generated samples (right).
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.scatter(to_np(u)[:,0], to_np(u)[:,1], s=6, alpha=0.4)
plt.title("Source: Uniform [-0.5,0.5]^2")
plt.subplot(1,2,2)
plt.scatter(to_np(z_gen)[:,0], to_np(z_gen)[:,1], s=6, alpha=0.4)
plt.title("Generated: Gaussian (2D)")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "2_2_uniform_to_gaussian_scatter.png"), dpi=150)
plt.close()

# Loss curve: f2_losses has one entry per training step.
plt.figure()
plt.plot(f2_losses)
plt.title("f2 training: MMD^2 vs epoch")
plt.xlabel("epoch")
plt.ylabel("MMD^2")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "2_2_f2_loss_curve.png"), dpi=150)
plt.close()

# Cycle consistency (sanity)
# Mean squared error between Gaussian samples z0 and f2(f1(z0)). f1 and f2
# come from two separate train_generator calls that each only match
# distributions, so nothing in training forces f2 to invert f1 point-wise.
with torch.no_grad():
    z0 = gaussian_2d(3000).to(device)
    y1 = f1(z0)
    z_rec = f2(y1)
    mse_cycle = torch.mean((z0 - z_rec)**2).item()
print(f"2.2 diagnostics: mean={z_eval['mean']}, trace(cov)={z_eval['trace_cov']:.3f}, "
      f"det(cov)={z_eval['det_cov']:.3f}, KS p(x)={z_eval['ks_x_p']:.4f}, KS p(y)={z_eval['ks_y_p']:.4f}, "
      f"cycle MSE={mse_cycle:.6f}")

# =========================
# 2.3: Regularization analysis for f1
# =========================
# Retrain the Gaussian -> Uniform generator under several regularization
# settings and compare the resulting weight distributions.
print("\n2.3: Regularization analysis on f1...")
# Label -> (reg_type, reg_lambda) passed to train_generator.
reg_grid = {
    "none_0.0": (None, 0.0),
    "l1_0.001": ('l1', 0.001),
    "l1_0.01":  ('l1', 0.01),
    "l1_0.1":   ('l1', 0.1),
    "l2_0.001": ('l2', 0.001),
    "l2_0.01":  ('l2', 0.01),
    "l2_0.1":   ('l2', 0.1),
}
# Label -> last recorded flattened weight vector (or None if nothing was recorded).
weight_snaps = {}
# Label -> summary statistics of that weight vector.
summary_stats = {}

for name, (rtype, lam) in reg_grid.items():
    print(f"Training f1 with reg={name} ...")
    # Only the weight history is kept; the trained model and losses are discarded.
    _, w_hist, _ = train_generator(
        source_sampler=gaussian_2d,
        target_sampler=uniform_2d,
        input_dim=2, output_dim=2,
        epochs=5000, batch_size=512, lr=0.01,
        reg_lambda=lam, reg_type=rtype,
        hidden_dim=128, depth=3, activation='relu',
        name=f"f1_{name}"
    )
    # Last snapshot in the history. With epochs=5000 and the default
    # log_every=500 the last snapshot is taken at the final step.
    w_last = w_hist[-1] if len(w_hist) > 0 else None
    weight_snaps[name] = w_last
    if w_last is not None:
        summary_stats[name] = {
            "mean": float(w_last.mean()),
            "std": float(w_last.std()),
            # Share of parameters with magnitude below 1e-3
            "frac_small": float((np.abs(w_last) < 1e-3).mean())
        }

# Plot histograms of final weights
# One subplot per configuration, laid out on a grid with 3 columns.
cols = 3
rows = math.ceil(len(weight_snaps) / cols)
plt.figure(figsize=(5*cols, 3.5*rows))
for i, (name, w) in enumerate(weight_snaps.items(), start=1):
    plt.subplot(rows, cols, i)
    if w is not None:
        plt.hist(w, bins=60, alpha=0.8, color="#4472c4")
    plt.title(name)
    plt.xlabel("weight")
    plt.ylabel("freq")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "2_3_regularization_weight_histograms.png"), dpi=150)
plt.close()

# Text summary of the per-configuration weight statistics.
print("2.3 summary stats (mean, std, frac(|w|<1e-3)):")
for k, v in summary_stats.items():
    print(f"  {k:10s} -> mean={v['mean']:.4f}, std={v['std']:.4f}, frac_small={v['frac_small']:.3f}")

# =========================
# 2.4: f3: 1D uniform -> 2D Gaussian (deterministic)
# =========================
# Train f3: a 1 -> 2 generator fed with 1D uniform samples and matched to
# 2D standard-normal samples.
print("\nTraining f3: 1D Uniform -> 2D Gaussian...")
f3, w_f3_hist, f3_losses = train_generator(
    source_sampler=uniform_1d,
    target_sampler=gaussian_2d,
    input_dim=1, output_dim=2,
    epochs=12000, batch_size=512, lr=0.01,
    reg_lambda=0.0, reg_type=None,
    hidden_dim=128, depth=4, activation='tanh',  # tanh can help smooth 1D mapping
    name="f3"
)

# Evaluate on a deterministic, evenly spaced grid of 3000 inputs covering
# [-0.5, 0.5] (not random draws), then compute Gaussian diagnostics on the output.
with torch.no_grad():
    u_line = torch.linspace(-0.5, 0.5, 3000).view(-1,1).to(device)
    z_from_line = f3(u_line).detach().cpu()
    z_eval_f3 = eval_gaussian_2d(to_np(z_from_line))

# Visualize mapping and distribution
# Three panels: the 1D inputs drawn on a line, the 2D outputs coloured by the
# input value that produced them, and the same outputs uncoloured.
plt.figure(figsize=(15,4))
plt.subplot(1,3,1)
plt.scatter(to_np(u_line), np.zeros_like(to_np(u_line)), s=6, alpha=0.5)
plt.title("1D Uniform Input (line)")
plt.subplot(1,3,2)
plt.scatter(to_np(z_from_line)[:,0], to_np(z_from_line)[:,1],
            c=to_np(u_line).ravel(), cmap='viridis', s=8)
plt.colorbar(label='input u')
plt.title("2D Output Colored by Input")
plt.subplot(1,3,3)
plt.scatter(to_np(z_from_line)[:,0], to_np(z_from_line)[:,1], s=6, alpha=0.5)
plt.title("Output Distribution (samples from f3(u))")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "2_4_1d_to_2d_mapping.png"), dpi=150)
plt.close()

# Covariance eigenvalues (is it roughly full support in covariance sense?)
# Computed from the same outputs that eval_gaussian_2d summarised above.
cov_f3 = np.cov(to_np(z_from_line).T)
eigs = np.linalg.eigvalsh(cov_f3)

# Loss curve: f3_losses has one entry per training step.
plt.figure()
plt.plot(f3_losses)
plt.title("f3 training: MMD^2 vs epoch")
plt.xlabel("epoch")
plt.ylabel("MMD^2")
plt.tight_layout()
plt.savefig(os.path.join(OUTDIR, "2_4_f3_loss_curve.png"), dpi=150)
plt.close()

# One-line summary of the diagnostics for the deterministic grid evaluation.
print(f"2.4 diagnostics (deterministic): mean={z_eval_f3['mean']}, "
      f"trace(cov)={z_eval_f3['trace_cov']:.3f}, det(cov)={z_eval_f3['det_cov']:.6f}, "
      f"KS p(x)={z_eval_f3['ks_x_p']:.4f}, KS p(y)={z_eval_f3['ks_y_p']:.4f}, "
      f"cov eigenvalues={eigs}")

print("\nAll plots saved in:", OUTDIR)

Training f1: Gaussian -> Uniform...
[f1] epoch     1/10000  MMD2=0.120015
[f1] epoch   500/10000  MMD2=-0.000145
[f1] epoch  1000/10000  MMD2=-0.000278
[f1] epoch  1500/10000  MMD2=0.000973
[f1] epoch  2000/10000  MMD2=-0.000392
[f1] epoch  2500/10000  MMD2=0.000000
[f1] epoch  3000/10000  MMD2=-0.000061
[f1] epoch  3500/10000  MMD2=0.001112
[f1] epoch  4000/10000  MMD2=-0.000137
[f1] epoch  4500/10000  MMD2=-0.000580
[f1] epoch  5000/10000  MMD2=-0.000104
[f1] epoch  5500/10000  MMD2=0.000086
[f1] epoch  6000/10000  MMD2=0.000386
[f1] epoch  6500/10000  MMD2=-0.000265
[f1] epoch  7000/10000  MMD2=0.001418
[f1] epoch  7500/10000  MMD2=0.000404
[f1] epoch  8000/10000  MMD2=0.000533
[f1] epoch  8500/10000  MMD2=-0.000335
[f1] epoch  9000/10000  MMD2=-0.000491
[f1] epoch  9500/10000  MMD2=-0.000003
[f1] epoch 10000/10000  MMD2=-0.000011
2.1 diagnostics: KS p(x)=0.0343, KS p(y)=0.0098, corr=0.0169, chi2 p=0.0000

Training f2: Uniform -> Gaussian...
[f2] epoch     1/8000  MMD2=0.401008
[f2]