# Lab 3: Implementing a GAN from Pseudocode (Goodfellow et al., 2014)

In [None]:
# --- Setup / hyperparameters ---
latent_dim = 100
batch_size = 64
num_steps = 100000
lr_G = 0.0002
lr_D = 0.0002
beta1 = 0.5        # Adam beta1 (first-moment decay); 0.5 is the usual DCGAN-style choice
n_critic = 1       # number of D updates per G update (often >1 for WGAN, e.g. 5)
use_non_saturating = True   # if True use -log(D(G)) for G (better gradients)
use_wgan_gp = False         # set True to use WGAN-GP variant
gp_lambda = 10.0            # gradient penalty weight for WGAN-GP
device = "cuda"

# Logging / monitoring cadence, measured in generator steps; read by the
# training loop's logging, sampling and checkpointing branches.
log_interval = 100          # print losses every log_interval steps
sample_interval = 1000      # save generated samples every sample_interval steps
checkpoint_interval = 10000 # save G and D weights every checkpoint_interval steps
visual_batch = 64           # number of samples generated for each monitoring save

# Initialize models and optimizers (pseudocode)
# D must end in a sigmoid (probabilities) for the vanilla BCE losses, but must
# output raw real-valued scores (no sigmoid) when use_wgan_gp is True.
G = Generator(latent_dim).to(device)
D = Discriminator().to(device)
# NOTE: the WGAN-GP paper reports Adam betas (0.0, 0.9); (beta1, 0.999) with
# beta1 = 0.5 is the common setting for the vanilla / non-saturating GAN.
opt_G = Adam(G.parameters(), lr=lr_G, betas=(beta1, 0.999))
opt_D = Adam(D.parameters(), lr=lr_D, betas=(beta1, 0.999))

# Utility functions (pseudocode)
def sample_noise(batch_size, latent_dim):
    """Draw one batch of latent vectors to feed the generator.

    Args:
        batch_size: how many noise vectors to draw.
        latent_dim: length of each noise vector.

    Returns:
        The result of ``random_normal(batch_size, latent_dim)``, intended to be
        standard-normal N(0, 1) samples. The caller moves it to the device.
    """
    noise = random_normal(batch_size, latent_dim)
    return noise

def sample_real_batch(dataset, batch_size):
    """Fetch the next batch of real training examples.

    Args:
        dataset: any object exposing ``next_batch(batch_size)``.
        batch_size: number of examples requested.

    Returns:
        Exactly what ``dataset.next_batch(batch_size)`` returns.
    """
    batch = dataset.next_batch(batch_size)
    return batch

def detach(tensor):
    """Return ``tensor`` cut off from the gradient graph.

    Gradients computed from the returned value do not flow back into whatever
    produced ``tensor``. The discriminator step relies on this so that D's loss
    does not update G.
    """
    cut = stop_gradient(tensor)
    return cut

# --- Training loop ---
# Each outer step runs n_critic discriminator/critic updates followed by one
# generator update (Goodfellow et al. 2014, Algorithm 1; optional WGAN-GP
# losses per Gulrajani et al. 2017).
for step in range(num_steps):

    # ---- Update Discriminator ----
    for _ in range(n_critic):
        real_x = sample_real_batch(dataset, batch_size).to(device)
        z = sample_noise(batch_size, latent_dim).to(device)
        fake_x = G(z)                         # generated samples
        # Detach once and reuse it for every D-step term (including the
        # gradient-penalty interpolates) so this step trains D only, not G.
        fake_x_d = detach(fake_x)

        # Discriminator outputs (probabilities or scores)
        D_real = D(real_x)                    # shape: (batch,)
        D_fake = D(fake_x_d)

        if use_wgan_gp:
            # WGAN-GP: D outputs real-valued scores (no sigmoid)
            loss_D = mean(D_fake) - mean(D_real)   # Wasserstein critic loss

            # Gradient penalty on random interpolates between real and fake.
            # One eps per sample, broadcast over the non-batch dims, created
            # on the same device as the data.
            eps = random_uniform(0, 1, size=batch_size).reshape(-1, 1, ...).to(device)
            # x_hat is built only from non-differentiable inputs, so it must be
            # marked as requiring grad before differentiating D w.r.t. it.
            x_hat = (eps * real_x + (1 - eps) * fake_x_d).requires_grad_(True)
            D_x_hat = D(x_hat)
            # NOTE: grad_of must keep a differentiable graph (create_graph=True
            # with torch.autograd.grad) so the penalty itself can train D.
            grad = grad_of(D_x_hat, x_hat)     # gradient of D(x_hat) wrt x_hat
            grad_norm = sqrt(sum_over_dims(grad**2) + 1e-12)
            gp = gp_lambda * mean((grad_norm - 1.0)**2)
            loss_D = loss_D + gp

        else:
            # Vanilla GAN: D outputs probabilities (after sigmoid)
            # Standard minimax discriminator loss:
            # loss_D = - mean(log(D_real)) - mean(log(1 - D_fake))
            # In practice use stable numerics / BCE loss function
            loss_D_real = binary_cross_entropy(D_real, ones_like(D_real))
            loss_D_fake = binary_cross_entropy(D_fake, zeros_like(D_fake))
            loss_D = loss_D_real + loss_D_fake

        # Backprop and optimizer step for D
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

    # ---- Update Generator ----
    z = sample_noise(batch_size, latent_dim).to(device)
    fake_x = G(z)
    D_fake_for_G = D(fake_x)    # do NOT detach: gradients should flow to G

    if use_wgan_gp:
        # WGAN generator objective: minimize -E[D(G(z))]
        loss_G = - mean(D_fake_for_G)

    else:
        if use_non_saturating:
            # Non-saturating heuristic: minimize -mean(log(D(G(z)))), which gives
            # stronger gradients early. BCE(p, 1) = -log(p), so this is that loss
            # (assuming binary_cross_entropy averages over the batch).
            loss_G = binary_cross_entropy(D_fake_for_G, ones_like(D_fake_for_G))
        else:
            # Minimax (saturating) version (less commonly used in practice):
            # minimize mean(log(1 - D(G(z)))). BCE(p, 0) = -log(1 - p), so the
            # objective is the NEGATED BCE against zeros. Minimizing the
            # un-negated BCE would push D(G(z)) toward 0, training G to look fake.
            loss_G = -binary_cross_entropy(D_fake_for_G, zeros_like(D_fake_for_G))

    # Backprop and optimizer step for G
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()

    # --- Logging & periodic evaluation ---
    if step % log_interval == 0:
        print("step", step, "loss_D", loss_D.item(), "loss_G", loss_G.item())
    if step % sample_interval == 0:
        # The noise must be on the same device as G. Detach the images because
        # monitoring output needs no autograd graph. In PyTorch, consider
        # torch.no_grad() and G.eval()/G.train() here if G uses batch-norm/dropout.
        z_vis = sample_noise(visual_batch, latent_dim).to(device)
        imgs = detach(G(z_vis))             # generate images for monitoring
        save_images(imgs, f"samples/step_{step}.png")
    if step % checkpoint_interval == 0:
        save_model(G, f"G_{step}.pt")
        save_model(D, f"D_{step}.pt")
