# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from pathlib import Path

# Configuration
# Images are resized/cropped to IMG_SIZE x IMG_SIZE; the networks below are built for 64.
IMG_SIZE = 64
BATCH_SIZE = 128
LATENT_DIM = 100   # length of the noise vector z fed to the generator
EPOCHS = 100

# Adam hyperparameters: separate learning rates for the generator (G) and the
# discriminator (D), plus beta1 (beta2 = 0.999 is set where the optimizers are built).
LR_G = 2e-4
LR_D = 2e-4
BETA1 = 0.5

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output folder for the per-epoch sample grids written during training.
Path("generated_images").mkdir(exist_ok=True)

# Custom dataset
class PortraitDataset(Dataset):
    """Unlabelled image-folder dataset that yields RGB images, optionally transformed.

    Every regular file under ``root_dir`` (searched recursively) whose suffix is in
    ``extensions`` is treated as an image. Suffixes are compared case-insensitively,
    so ``photo.JPG`` and ``photo.jpg`` are both picked up. The file list is sorted so
    that index ``i`` refers to the same file on every run.

    Args:
        root_dir: Directory to scan recursively.
        transform: Optional callable applied to each ``PIL.Image`` before it is returned.
        extensions: Accepted file suffixes (including the dot), matched case-insensitively.

    ``dataset[i]`` returns only the (transformed) image; there are no labels.
    """

    DEFAULT_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None, extensions=DEFAULT_EXTENSIONS):
        self.root_dir = Path(root_dir)
        self.transform = transform
        allowed = {ext.lower() for ext in extensions}
        # One recursive walk with a case-insensitive suffix filter replaces three
        # case-sensitive globs that could miss "*.JPG" or list a file twice.
        self.images = sorted(
            path
            for path in self.root_dir.rglob("*")
            if path.suffix.lower() in allowed and path.is_file()
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # The context manager closes the file right away. Otherwise PIL keeps the
        # handle open until garbage collection, which matters with many workers.
        # convert() returns a new image, so it stays valid after the close.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image

# Preprocessing pipeline applied to every dataset image:
transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),      # int size: the shorter side becomes IMG_SIZE
        transforms.CenterCrop(IMG_SIZE),  # square IMG_SIZE x IMG_SIZE center crop
        transforms.ToTensor(),            # PIL image -> float CHW tensor in [0, 1]
        # (x - 0.5) / 0.5 per channel maps [0, 1] to [-1, 1], the range of the
        # generator's final Tanh.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Weight initialization
def weights_init(m):
    """Initialize one module in place with the DCGAN-style scheme; use via ``Module.apply``.

    * Modules whose class name contains ``"Conv"`` (e.g. ``Conv2d``, ``ConvTranspose2d``):
      weight ~ N(0, 0.02). Conv biases, if any, are left at their default init.
    * Modules whose class name contains ``"BatchNorm"``: weight ~ N(1, 0.02), bias = 0.
    * Any other module is left untouched.

    Parameters that are ``None`` (``bias=False`` convs, ``affine=False`` batch norms)
    are skipped instead of raising ``AttributeError``.

    Args:
        m: The module to initialize.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    # nn.init.* already run under torch.no_grad(), so no `.data` access is needed.
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

# Generator
class Generator(nn.Module):
    """DCGAN generator: maps latent noise to a 64x64 image with values in [-1, 1].

    Five transposed convolutions upsample ``1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64``
    while channels go ``latent_dim -> 8f -> 4f -> 2f -> f -> img_channels``
    (``f = feature_maps``). Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU;
    the last stage ends in Tanh. With the default arguments the layer stack (and thus
    the ``state_dict`` keys ``main.<i>.*``) is exactly the 512/256/128/64/3 network.

    Args:
        latent_dim: Channels of the input noise ``z``, which has shape ``(N, latent_dim, 1, 1)``.
        img_channels: Channels of the generated image (3 for RGB).
        feature_maps: Base width ``f``; the widest layer has ``8 * f`` channels.
    """

    def __init__(self, latent_dim=100, img_channels=3, feature_maps=64):
        super().__init__()
        widths = [feature_maps * mult for mult in (8, 4, 2, 1)]
        layers = [
            # (N, latent_dim, 1, 1) -> (N, 8f, 4, 4): kernel 4, stride 1, no padding.
            nn.ConvTranspose2d(latent_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Each stage doubles H and W (kernel 4, stride 2, padding 1) and halves channels.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # (N, f, 32, 32) -> (N, img_channels, 64, 64); Tanh bounds outputs to [-1, 1].
        layers += [
            nn.ConvTranspose2d(widths[-1], img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        """Map ``z`` of shape ``(N, latent_dim, 1, 1)`` to images ``(N, img_channels, 64, 64)``."""
        return self.main(z)

# Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 64x64 images with a probability of being real.

    Four strided convolutions downsample ``64 -> 32 -> 16 -> 8 -> 4`` while channels
    go ``img_channels -> f -> 2f -> 4f -> 8f`` (``f = feature_maps``). Each is followed
    by LeakyReLU(0.2), and all but the first also by BatchNorm2d. A final 4x4
    convolution collapses the 4x4 map to a single value, squashed by Sigmoid. With
    the default arguments the layer stack (and ``state_dict`` keys) is exactly the
    3/64/128/256/512 network.

    Args:
        img_channels: Channels of the input image (3 for RGB).
        feature_maps: Base width ``f``.
    """

    def __init__(self, img_channels=3, feature_maps=64):
        super().__init__()
        widths = [feature_maps * mult for mult in (1, 2, 4, 8)]
        layers = [
            # (N, img_channels, 64, 64) -> (N, f, 32, 32); no BatchNorm on the input stage.
            nn.Conv2d(img_channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each stage halves H and W (kernel 4, stride 2, padding 1) and doubles channels.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # (N, 8f, 4, 4) -> (N, 1, 1, 1) -> probability in (0, 1).
        layers += [
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return a flat tensor of probabilities; shape ``(N,)`` for 64x64 inputs."""
        return self.main(x).view(-1)

# Training function
def _gan_train_step(G, D, optimizer_G, optimizer_D, criterion, real_imgs):
    """Run one discriminator update followed by one generator update.

    Args:
        G, D: Generator and discriminator, already on ``real_imgs.device``.
        optimizer_G, optimizer_D: Their optimizers.
        criterion: Loss on probabilities vs. 0/1 targets (``nn.BCELoss`` here).
        real_imgs: Batch of real images, shape ``(N, C, H, W)``.

    Returns:
        ``(loss_D, loss_G)`` as detached scalar tensors.
    """
    batch_size = real_imgs.size(0)
    device = real_imgs.device
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)

    # ---- Discriminator: push D(real) -> 1 and D(fake) -> 0 ----
    optimizer_D.zero_grad()
    loss_D_real = criterion(D(real_imgs), real_labels)
    z = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
    fake_imgs = G(z)
    # detach() keeps the discriminator loss from backpropagating into G.
    loss_D_fake = criterion(D(fake_imgs.detach()), fake_labels)
    loss_D = loss_D_real + loss_D_fake
    loss_D.backward()
    optimizer_D.step()

    # ---- Generator: make the (just updated) D label the same fakes as real ----
    optimizer_G.zero_grad()
    loss_G = criterion(D(fake_imgs), real_labels)
    # This backward also leaves gradients in D's parameters; they are cleared by
    # optimizer_D.zero_grad() before D's next update.
    loss_G.backward()
    optimizer_G.step()

    return loss_D.detach(), loss_G.detach()


def _save_epoch_samples(G, fixed_noise, epoch):
    """Generate images from *fixed_noise* in eval mode, save a 4x4 grid, return them on CPU.

    Args:
        G: Generator; its previous train/eval mode is restored even if sampling fails.
        fixed_noise: Latent batch reused every epoch so the grids are comparable.
        epoch: 1-based epoch number used in the file name.
    """
    path = f"generated_images/epoch_{epoch:04d}.png"
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            fake = G(fixed_noise).cpu()
    finally:
        G.train(was_training)
    save_image(fake[:16], path, nrow=4, normalize=True, padding=2)
    print(f"✓ Imagen guardada: {path}")
    return fake


def train_gan(data_path, epochs=EPOCHS):
    """Train a DCGAN on the images under *data_path* and return the models and history.

    Every iteration does one discriminator step (real batch plus a detached fake
    batch) and then one generator step on the same fakes. After each epoch, a 4x4
    grid generated from a fixed noise batch is saved to
    ``generated_images/epoch_XXXX.png``.

    Args:
        data_path: Root folder scanned recursively by ``PortraitDataset``.
        epochs: Number of passes over the dataset.

    Returns:
        ``(G, D, G_losses, D_losses, img_list)``:
        ``G_losses``/``D_losses`` hold one value per epoch, the mean loss over that
        epoch's batches. ``img_list`` holds, per epoch, the CPU tensor of the 64
        images generated from the fixed noise.

    Raises:
        ValueError: If there are fewer than ``BATCH_SIZE`` images (for example, a
            wrong ``data_path``). Because of ``drop_last=True``, such a dataset
            would produce no batches at all.
    """
    dataset = PortraitDataset(data_path, transform=transform)
    if len(dataset) < BATCH_SIZE:
        raise ValueError(
            f"Found {len(dataset)} images under {str(data_path)!r}; "
            f"need at least BATCH_SIZE={BATCH_SIZE} to form one training batch."
        )
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE,
                            shuffle=True, num_workers=2, drop_last=True)
    num_batches = len(dataloader)

    G = Generator(LATENT_DIM).to(DEVICE)
    D = Discriminator().to(DEVICE)
    G.apply(weights_init)
    D.apply(weights_init)

    optimizer_G = optim.Adam(G.parameters(), lr=LR_G, betas=(BETA1, 0.999))
    optimizer_D = optim.Adam(D.parameters(), lr=LR_D, betas=(BETA1, 0.999))
    # D ends in a Sigmoid, so plain BCE on probabilities is the matching loss.
    criterion = nn.BCELoss()

    # Fixed latent batch so per-epoch grids show the same "faces" evolving.
    fixed_noise = torch.randn(64, LATENT_DIM, 1, 1, device=DEVICE)
    # Create the output folder here too, so this function does not rely on
    # module-level setup (or the working directory it ran in).
    os.makedirs("generated_images", exist_ok=True)

    G_losses, D_losses, img_list = [], [], []

    print(f"Iniciando entrenamiento en {DEVICE}")
    print(f"Dataset: {len(dataset)} imágenes")

    for epoch in range(epochs):
        # Accumulate on the device; .item() happens only when logging or at epoch end.
        d_loss_sum = torch.zeros((), device=DEVICE)
        g_loss_sum = torch.zeros((), device=DEVICE)
        for i, real_imgs in enumerate(dataloader):
            real_imgs = real_imgs.to(DEVICE)
            loss_D, loss_G = _gan_train_step(
                G, D, optimizer_G, optimizer_D, criterion, real_imgs
            )
            d_loss_sum += loss_D
            g_loss_sum += loss_G

            if i % 50 == 0:
                print(f"[Época {epoch+1}/{epochs}] [Batch {i}/{len(dataloader)}] "
                      f"[D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}]")

        # Per-epoch history: the mean over all batches, not just the last batch.
        G_losses.append(g_loss_sum.item() / num_batches)
        D_losses.append(d_loss_sum.item() / num_batches)

        img_list.append(_save_epoch_samples(G, fixed_noise, epoch + 1))

    return G, D, G_losses, D_losses, img_list

# Función para visualizar resultados
def plot_results(G_losses, D_losses, img_list):
    # Gráfica de pérdidas
    plt.figure(figsize=(10, 5))
    plt.title("Pérdidas del Generador y Discriminador")
    plt.plot(G_losses, label="Generador", alpha=0.8)
    plt.plot(D_losses, label="Discriminador", alpha=0.8)
    plt.xlabel("Época")
    plt.ylabel("Pérdida")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("losses.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    # Análisis de estabilidad
    plt.figure(figsize=(10, 5))
    window = 5
    G_smoothed = np.convolve(G_losses, np.ones(window)/window, mode='valid')
    D_smoothed = np.convolve(D_losses, np.ones(window)/window, mode='valid')
    
    plt.title("Estabilidad del Entrenamiento (Suavizado)")
    plt.plot(G_smoothed, label="Generador (suavizado)", linewidth=2)
    plt.plot(D_smoothed, label="Discriminador (suavizado)", linewidth=2)
    plt.xlabel("Época")
    plt.ylabel("Pérdida")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("stability.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    # Evolución de imágenes generadas (muestreo cada 5 épocas)
    fig = plt.figure(figsize=(15, 8))
    epochs_to_show = min(20, len(img_list))
    step = max(1, len(img_list) // epochs_to_show)
    
    for idx, epoch_idx in enumerate(range(0, len(img_list), step)):
        if idx >= epochs_to_show:
            break
        ax = plt.subplot(4, 5, idx + 1)
        imgs = img_list[epoch_idx]
        grid = make_grid(imgs[:16], nrow=4, normalize=True, padding=2)
        ax.imshow(np.transpose(grid, (1, 2, 0)))
        ax.set_title(f"Época {epoch_idx + 1}")
        ax.axis('off')
    
    plt.suptitle("Evolución de Imágenes Generadas (Muestreo)", fontsize=16)
    plt.tight_layout()
    plt.savefig("evolution.png", dpi=150, bbox_inches='tight')
    plt.show()

# Main function
def main(data_path="../dataset/resized_together_images"):
    """Train the DCGAN, plot the results, save both models and a final 8x8 sample grid.

    Args:
        data_path: Folder containing the portrait images (searched recursively).
            The default is a relative path; adjust it to where your dataset lives.

    Raises:
        FileNotFoundError: If ``data_path`` is not an existing directory.
    """
    # Fail before any model is built if the dataset folder is missing.
    if not Path(data_path).is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {data_path!r}")

    print("=" * 60)
    print("DCGAN para Generación de Retratos 64x64")
    print("=" * 60)

    # Train
    G, D, G_losses, D_losses, img_list = train_gan(data_path)

    # Visualize results
    print("\nGenerando visualizaciones...")
    plot_results(G_losses, D_losses, img_list)

    # Save the models
    torch.save(G.state_dict(), "generator.pth")
    torch.save(D.state_dict(), "discriminator.pth")
    print("\nModelos guardados: generator.pth y discriminator.pth")

    # Draw 64 fresh samples; eval() makes BatchNorm use its running statistics.
    print("\nGenerando muestras finales...")
    G.eval()
    with torch.no_grad():
        z = torch.randn(64, LATENT_DIM, 1, 1, device=DEVICE)
        samples = G(z).cpu()

    plt.figure(figsize=(12, 12))
    grid = make_grid(samples, nrow=8, normalize=True, padding=2)
    plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW tensor -> HWC array for imshow
    plt.title("Muestras Finales Generadas", fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig("final_samples.png", dpi=150, bbox_inches='tight')
    plt.show()

    print("\n¡Entrenamiento completado!")
    print("Archivos generados:")
    print("  - generated_images/: Imágenes por época")
    print("  - losses.png: Gráfica de pérdidas")
    print("  - stability.png: Análisis de estabilidad")
    print("  - evolution.png: Evolución de imágenes")
    print("  - final_samples.png: Muestras finales")

# Entry point: run the full train/plot/save pipeline when this file is executed
# directly rather than imported.
if __name__ == "__main__":
    main()
