# Práctica 5: Simple GAN en MNIST

Plantilla limpia con: (1) entrenamiento GAN, (2) gráficas de evolución de pérdidas D/G y (3) cálculo de métricas básicas.

In [None]:
# Configuración e imports
import torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np, time

# Seed both RNGs used by the notebook so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU when one is visible to PyTorch, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Convert images to tensors, then shift/scale with mean 0.5 and std 0.5
# (single channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into ./data on first use.
train_set = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Modelos: Generator y Discriminator (MLP)
class Generator(nn.Module):
    """Fully connected generator: latent vector in, flattened image out.

    Args:
        z_dim: Number of features in the latent input vector.
        img_dim: Number of output features per sample (default 28*28).
    """

    def __init__(self, z_dim=100, img_dim=28*28):
        super().__init__()
        # Layer order is kept stable so state_dict keys stay
        # net.0 / net.2 / net.4.
        layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, img_dim),
            nn.Tanh(),  # bounds every output feature to [-1, 1]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Run the latent batch ``z`` through the network and return the result."""
        generated = self.net(z)
        return generated

class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image in, one probability out.

    Args:
        img_dim: Number of input features per sample (default 28*28).
    """

    def __init__(self, img_dim=28*28):
        super().__init__()
        # Layer order is kept stable so state_dict keys stay
        # net.0 / net.2 / net.4.
        layers = [
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squashes the single logit into (0, 1)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per row of ``x``, with shape (batch, 1)."""
        score = self.net(x)
        return score

# Build both networks on the selected device.
z_dim = 100
G = Generator(z_dim).to(device)
D = Discriminator().to(device)

# Report the number of trainable parameters of each network.
n_params_G = sum(w.numel() for w in G.parameters() if w.requires_grad)
n_params_D = sum(w.numel() for w in D.parameters() if w.requires_grad)
print('Params G:', n_params_G)
print('Params D:', n_params_D)

In [None]:
# GAN training with per-epoch metric logging
num_epochs = 25
lr = 2e-4
beta1 = 0.5

# One Adam optimizer per network, both with the same hyper-parameters.
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))
criterion = nn.BCELoss()

# Per-epoch averages: generator loss, discriminator losses and accuracies.
history = {
    'g': [],
    'd_total': [],
    'd_real': [],
    'd_fake': [],
    'd_acc_real': [],
    'd_acc_fake': [],
}

for epoch in range(num_epochs):
    # Running sums over the batches of the current epoch.
    g_loss_e = 0.0
    d_real_e = 0.0
    d_fake_e = 0.0
    acc_real_e = 0.0
    acc_fake_e = 0.0

    for x, _ in train_loader:
        b = x.size(0)
        x = x.view(b, -1).to(device)  # flatten each image to one row

        # Target labels: 1 for real samples, 0 for generated ones.
        real = torch.ones(b, 1, device=device)
        fake = torch.zeros(b, 1, device=device)

        # Discriminator step, real batch.
        opt_D.zero_grad()
        out_r = D(x)
        loss_r = criterion(out_r, real)
        acc_r = (out_r > 0.5).float().eq(real).float().mean()

        # Discriminator step, generated batch. detach() keeps this loss
        # from back-propagating into G.
        z = torch.randn(b, z_dim, device=device)
        x_fake = G(z).detach()
        out_f = D(x_fake)
        loss_f = criterion(out_f, fake)
        acc_f = (out_f > 0.5).float().eq(fake).float().mean()

        loss_d = 0.5 * (loss_r + loss_f)
        loss_d.backward()
        opt_D.step()

        # Generator step: fresh noise, scored against the "real" label so
        # G is pushed towards samples that D classifies as real.
        opt_G.zero_grad()
        z = torch.randn(b, z_dim, device=device)
        x_fake = G(z)
        out = D(x_fake)
        loss_g = criterion(out, real)
        loss_g.backward()
        opt_G.step()

        # Accumulate the batch statistics as Python floats.
        g_loss_e += loss_g.item()
        d_real_e += loss_r.item()
        d_fake_e += loss_f.item()
        acc_real_e += acc_r.item()
        acc_fake_e += acc_f.item()

    # Average over the number of batches and store the epoch summary.
    n = len(train_loader)
    history['g'].append(g_loss_e / n)
    history['d_real'].append(d_real_e / n)
    history['d_fake'].append(d_fake_e / n)
    history['d_total'].append(0.5 * (history['d_real'][-1] + history['d_fake'][-1]))
    history['d_acc_real'].append(acc_real_e / n)
    history['d_acc_fake'].append(acc_fake_e / n)

    # Log the first epoch and then every fifth one.
    if epoch == 0 or (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs} | G:{history['g'][-1]:.4f} D:{history['d_total'][-1]:.4f} Dr:{history['d_acc_real'][-1]:.3f} Df:{history['d_acc_fake'][-1]:.3f}")

In [None]:
# Plots of the per-epoch losses and discriminator accuracies
epochs = range(1, len(history['g'])+1)
fig, axes = plt.subplots(2, 2, figsize=(12, 9))

# Each entry: (target axes, panel title, [(history key, legend label), ...]).
panels = [
    (axes[0, 0], 'Loss G y D', [('g', 'G'), ('d_total', 'D')]),
    (axes[0, 1], 'Loss D (real vs fake)', [('d_real', 'D real'), ('d_fake', 'D fake')]),
    (axes[1, 0], 'Accuracy D', [('d_acc_real', 'Acc real'), ('d_acc_fake', 'Acc fake')]),
]
for ax, title, series in panels:
    for key, label in series:
        ax.plot(epochs, history[key], label=label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# The fourth cell of the 2x2 grid is intentionally left empty.
axes[1, 1].axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Model metrics (D and G) on one test batch and one batch of generated samples
test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=True)

# Take a single shuffled batch of 1000 test images and flatten it.
x_real, _ = next(iter(test_loader))
x_real = x_real.view(x_real.size(0), -1).to(device)

with torch.no_grad():
    # Real images count as correct when D scores them above 0.5.
    out_r = D(x_real)
    acc_r = (out_r > 0.5).float().mean().item()

    # Generated images count as correct when D does NOT score them above 0.5.
    z = torch.randn(1000, z_dim, device=device)
    x_fake = G(z)
    out_f = D(x_fake)
    fooled = out_f > 0.5
    acc_f = (~fooled).float().mean().item()

    # Fraction of generated samples that D labels as real.
    fool_rate = fooled.float().mean().item()
    # Generator loss: BCE of D's scores on fakes against the "real" label.
    g_loss = nn.functional.binary_cross_entropy(out_f, torch.ones_like(out_f)).item()

print(f'D accuracy real: {acc_r:.4f} | D accuracy fake: {acc_f:.4f} | G fool rate: {fool_rate:.4f} | G loss: {g_loss:.4f}')

# Sample grid: 25 generated images shown in a 5x5 layout
with torch.no_grad():
    imgs = G(torch.randn(25, z_dim, device=device)).cpu().view(-1, 28, 28)
    # Map the generator's Tanh output range [-1, 1] to [0, 1] for display.
    imgs = (imgs + 1) / 2

fig, axes = plt.subplots(5, 5, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    # Pin the colour range to [0, 1]. Without vmin/vmax, imshow stretches
    # every tile to its own min/max, which cancels the rescale above and
    # makes faint or washed-out samples look as sharp as good ones.
    ax.imshow(imgs[i], cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
plt.tight_layout()
plt.show()