In [None]:
# 1) Instalo los paquetes necesarios y los importo
#%pip install --quiet torch torchvision tqdm pillow gradio sentence-transformers pytorch-fid accelerate

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

# Environment report: installed PyTorch version and whether a CUDA device is visible.
for _label, _value in (('PyTorch version:', torch.__version__),
                       ('CUDA available:', torch.cuda.is_available())):
    print(_label, _value)

In [None]:
# 2) Asigno directorios de las imagenes de entrenamiento, checkpoints y salidas

# Project root. Defaults to the original local path; set the IA_GAN_BASE_DIR
# environment variable to run the notebook elsewhere (e.g. a Colab Drive mount)
# without editing code.
BASE_DIR = os.environ.get('IA_GAN_BASE_DIR', r'c:\Users\Walter\Desktop\IA_GAN_labels_project')
# Join path components individually instead of embedding '/' so the result uses
# a single, platform-consistent separator.
DATA_DIR = os.path.join(BASE_DIR, 'data', 'labels')
CKPT_DIR = os.path.join(BASE_DIR, 'checkpoints')
OUT_DIR = os.path.join(BASE_DIR, 'outputs')

os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
print('BASE_DIR =', BASE_DIR)


In [None]:
# 3) Dataset y DataLoader 
IMG_SIZE = 128  # square training resolution; must match the Generator output / Critic input
BATCH_SIZE = 32
z_dim = 128

# Resize + center-crop to IMG_SIZE, then light augmentation (brightness/contrast/
# saturation jitter, small rotation and perspective), then scale to [-1, 1] so the
# data range matches the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2)], p=0.8),
    transforms.RandomRotation(3),
    transforms.RandomPerspective(distortion_scale=0.05, p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# Always bind both names so later cells can detect a missing dataset
# (dataloader is None) instead of failing with a NameError.
dataset = None
dataloader = None
if not os.path.exists(DATA_DIR):
    print(f"[!] DATA_DIR not found: {DATA_DIR}. Please place your dataset (one sub-folder per class) there before running this cell.")
else:
    # ImageFolder expects DATA_DIR/<class_name>/<image files>.
    dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
    # Pinned memory only helps (and only avoids a warning) when there is a GPU to copy to.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
                            pin_memory=torch.cuda.is_available())
    print('Dataset loaded. Classes:', dataset.classes)
    print('Dataset size:', len(dataset))


In [None]:
# 4) Modelos: Definición del Generador y Crítico (estilo WGAN-GP)
import torch.nn as nn

def conv_block(in_c, out_c, k=4, s=2, p=1, batchnorm=True):
    """Return Conv2d -> (optional) BatchNorm2d -> LeakyReLU(0.2) as an nn.Sequential.

    Args:
        in_c, out_c: input / output channel counts.
        k, s, p: kernel size, stride and padding of the convolution (bias disabled).
        batchnorm: insert a BatchNorm2d over ``out_c`` channels after the convolution.

    Note: not referenced by the Critic/Generator defined in this notebook.
    """
    conv = nn.Conv2d(in_c, out_c, k, s, p, bias=False)
    norm = [nn.BatchNorm2d(out_c)] if batchnorm else []
    return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

class Critic(nn.Module):
    """WGAN critic: maps an (N, in_channels, 128, 128) image batch to N unbounded scores.

    Four stride-2 4x4 convolutions halve the resolution 128 -> 64 -> 32 -> 16 -> 8
    while the width grows base_feat -> 2x -> 4x -> 8x. A final 8x8 convolution with
    no padding collapses the 8x8 map to a single score per image. BatchNorm follows
    every convolution except the first and the last; there is no output activation.

    NOTE(review): the WGAN-GP paper advises against BatchNorm in the critic, because
    the gradient penalty is a per-sample constraint and batch statistics couple the
    samples (LayerNorm / InstanceNorm or no normalization are the usual choices).
    Swapping it would invalidate the existing checkpoints, so it is kept here.
    """

    def __init__(self, in_channels=3, base_feat=64):
        super().__init__()
        widths = [base_feat * (2 ** i) for i in range(4)]
        # First downsampling conv: no normalization.
        layers = [
            nn.Conv2d(in_channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three normalized downsampling stages.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # 8x8 kernel, stride 1, no padding: 8x8 feature map -> 1x1 score.
        layers.append(nn.Conv2d(widths[-1], 1, 8, 1, 0, bias=False))
        # Flat Sequential in the same order as before, so state_dict keys
        # (net.0 ... net.11) and therefore saved checkpoints still match.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (N, 1, 1, 1) -> (N,)
        scores = self.net(x)
        return scores.view(x.size(0))


class Generator(nn.Module):
    """DCGAN-style generator: (N, z_dim, 1, 1) noise -> (N, out_channels, 128, 128) in [-1, 1].

    The first transposed conv (kernel 4, stride 1, no padding) projects the 1x1
    latent to a 4x4 map. Five stride-2 transposed convs then double the resolution
    4 -> 8 -> 16 -> 32 -> 64 -> 128, while the width shrinks
    8*base_feat -> 4x -> 2x -> base_feat -> base_feat//2 -> out_channels.
    BatchNorm + ReLU follow every stage except the last, which ends in Tanh.
    """

    def __init__(self, z_dim=128, out_channels=3, base_feat=64):
        super().__init__()
        widths = [base_feat * 8, base_feat * 4, base_feat * 2, base_feat, base_feat // 2]
        # 1x1 latent -> 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(z_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Upsampling stages, each doubling height and width.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Last upsampling to image channels, squashed to [-1, 1]; no normalization.
        layers.extend([
            nn.ConvTranspose2d(widths[-1], out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        # Same flat layer order as before: state_dict keys net.0 ... net.16 unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        return self.net(z)


# Verificación rápida de la cantidad de parámetros
# Instantiate both networks (generator first) and report their parameter counts.
z_dim = 128
gen = Generator(z_dim=z_dim)
critic = Critic()
print('Generator params:', sum(map(torch.numel, gen.parameters())))
print('Critic params:', sum(map(torch.numel, critic.parameters())))


In [None]:
# 5) Utilidades de entrenamiento y bucle de entrenamiento
import torch.autograd as autograd
from torch.cuda.amp import autocast, GradScaler

# Move both models to the GPU when one exists; a hard-coded 'cuda' crashes on
# CPU-only machines.
_model_device = 'cuda' if torch.cuda.is_available() else 'cpu'
critic.to(_model_device)
gen.to(_model_device)


def gradient_penalty(critic, real, fake, device):
    """WGAN-GP gradient penalty.

    Samples random points on straight lines between paired real and fake samples.
    It then penalizes the squared deviation of the L2 norm of the critic's input
    gradient at those points from 1.

    Args:
        critic: callable returning one score per sample (shape (N,)).
        real: batch of real samples, shape (N, ...); images are (N, C, H, W).
        fake: generated samples with the same shape as ``real``.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor: mean over the batch of (||grad||_2 - 1)^2. The gradient graph
        is kept (create_graph=True) so the penalty can be back-propagated.
    """
    bs = real.size(0)
    # One mixing weight per sample, broadcast over all remaining dimensions
    # (any rank; for images this is (N, 1, 1, 1)).
    eps = torch.rand((bs,) + (1,) * (real.dim() - 1), device=device)
    inter = eps * real + (1 - eps) * fake
    inter.requires_grad_(True)
    out = critic(inter)
    grad = torch.autograd.grad(outputs=out, inputs=inter,
                               grad_outputs=torch.ones_like(out),
                               create_graph=True, retain_graph=True)[0]
    # reshape (not view): autograd may hand back a non-contiguous gradient.
    # float(): keep the norm in fp32 even if the caller used half precision.
    grad = grad.reshape(bs, -1).float()
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()




def train_wgangp(gen, critic, dataloader, epochs=400, z_dim=128, device='cuda',
                 n_critic=5, lambda_gp=10.0, lr=1e-4, sample_every=200):
    """Train ``gen`` and ``critic`` with the WGAN-GP objective.

    For every batch from ``dataloader`` the critic is updated ``n_critic`` times and
    the generator once. Every ``sample_every`` generator steps, a 64-image grid from a
    fixed latent batch is written to OUT_DIR. After every epoch, a checkpoint
    {'gen', 'critic'} is written to CKPT_DIR as ckpt_epoch_<epoch>.pt.

    Mixed precision (autocast + GradScaler) is enabled only on CUDA devices. The
    gradient penalty's double-backward always runs in fp32 outside autocast.

    NOTE(review): the same real batch is reused for all ``n_critic`` critic updates;
    the reference WGAN-GP algorithm draws a fresh real batch for each one.

    Args:
        gen, critic: Generator / Critic modules; moved to ``device`` and set to train mode.
        dataloader: yields (images, labels) batches; labels are ignored.
        epochs: number of passes over ``dataloader``.
        z_dim: latent size; noise is drawn with shape (N, z_dim, 1, 1).
        device: 'cuda', 'cpu' or a torch.device.
        n_critic: critic updates per generator update.
        lambda_gp: weight of the gradient penalty term.
        lr: Adam learning rate for both networks (betas=(0.5, 0.9)).
        sample_every: generator steps between saved sample grids.

    Returns:
        (losses_g, losses_c): per generator step, the generator loss and the critic
        loss of that step's last critic update, as Python floats.
    """
    device = torch.device(device)
    use_amp = device.type == 'cuda'
    gen.to(device); critic.to(device)
    gen.train(); critic.train()
    opt_g = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.9))
    opt_c = optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.9))
    # Disabled scalers / autocast are pass-throughs, so the same loop runs on CPU.
    scaler_g = GradScaler(enabled=use_amp)
    scaler_c = GradScaler(enabled=use_amp)
    fixed_z = torch.randn(64, z_dim, 1, 1, device=device)

    os.makedirs(CKPT_DIR, exist_ok=True); os.makedirs(OUT_DIR, exist_ok=True)

    losses_g, losses_c = [], []
    step = 0
    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f'Epoch {epoch}')
        for real, _ in pbar:
            real = real.to(device)
            bs = real.size(0)

            # ---- critic updates ----
            for _ in range(n_critic):
                z = torch.randn(bs, z_dim, 1, 1, device=device)
                # The critic loss never back-propagates into the generator, so skip
                # building its graph entirely (the old code built it and then detached).
                with torch.no_grad(), autocast(enabled=use_amp):
                    fake = gen(z)
                with autocast(enabled=use_amp):
                    c_real = critic(real)
                    c_fake = critic(fake)
                # Double-backward under autocast is discouraged by the AMP docs and
                # fp16 input gradients can underflow: compute the penalty in fp32.
                gp = gradient_penalty(critic, real, fake.float(), device)
                loss_c = -(c_real.float().mean() - c_fake.float().mean()) + lambda_gp * gp
                opt_c.zero_grad()
                scaler_c.scale(loss_c).backward()
                scaler_c.step(opt_c)
                scaler_c.update()

            # ---- generator update ----
            z = torch.randn(bs, z_dim, 1, 1, device=device)
            with autocast(enabled=use_amp):
                fake = gen(z)
                loss_g = -torch.mean(critic(fake))
            # Critic grads accumulated here are cleared by the next opt_c.zero_grad().
            opt_g.zero_grad()
            scaler_g.scale(loss_g).backward()
            scaler_g.step(opt_g)
            scaler_g.update()

            losses_g.append(loss_g.item())
            losses_c.append(loss_c.item())
            pbar.set_postfix(loss_c=f'{losses_c[-1]:.3f}', loss_g=f'{losses_g[-1]:.3f}')

            if step % sample_every == 0:
                # eval(): BatchNorm uses its running statistics and is not updated
                # by the fixed sampling batch.
                gen.eval()
                with torch.no_grad():
                    fake_fixed = gen(fixed_z).detach().cpu()
                gen.train()
                save_image(fake_fixed, os.path.join(OUT_DIR, f'fake_ep{epoch}_step{step}.png'), nrow=8, normalize=True)
            step += 1

        # Per-epoch checkpoint (model weights only).
        torch.save({'gen': gen.state_dict(), 'critic': critic.state_dict()}, os.path.join(CKPT_DIR, f'ckpt_epoch_{epoch}.pt'))
        print(f"[Epoch {epoch}] checkpoint saved to {CKPT_DIR}")

    return losses_g, losses_c


# Train only if the dataset cell actually produced a DataLoader, and use the
# device that exists on this machine instead of assuming CUDA.
if globals().get('dataloader') is None:
    print('[!] No dataloader available: check DATA_DIR and re-run the dataset cell before training.')
else:
    train_wgangp(gen, critic, dataloader,
                 device='cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# 6) Mapeo de texto a espacio latente usando Sentence Transformers
from sentence_transformers import SentenceTransformer

class TextToLatent(nn.Module):
    """Map a text prompt to a generator latent of shape (1, z_dim, 1, 1).

    The prompt is embedded by a frozen SentenceTransformer ('all-MiniLM-L6-v2'),
    then projected by a 2-layer MLP (text_dim -> 256 -> z_dim). ``text_dim`` must
    equal the encoder's embedding size (384 for this model). The MLP is randomly
    initialized and untrained here.
    """

    def __init__(self, text_dim=384, z_dim=128):
        super().__init__()
        hidden = 256
        self.map = nn.Sequential(
            nn.Linear(text_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, z_dim),
        )
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')

    def forward(self, text):
        # No gradients through the encoder; only the MLP is differentiable.
        with torch.no_grad():
            sentence_emb = self.encoder.encode([text], convert_to_tensor=True)
        latent = self.map(sentence_emb)
        # (1, z_dim) -> (1, z_dim, 1, 1), the layout Generator expects.
        return latent[..., None, None]

# Text -> latent mapper (encoder frozen, projection untrained).
mapper = TextToLatent(z_dim=z_dim, text_dim=384)
print('Mapper created. Note: mapper is untrained; consider training with text-image pairs for better results.')


In [None]:
# 7) Interfaz Gradio para texto->imagen
import gradio as gr
import os

def load_generator(checkpoint_path, z_dim=128):
    """Build a Generator and load the 'gen' weights from a training checkpoint.

    Args:
        checkpoint_path: file written by train_wgangp (a dict with 'gen' and 'critic'
            state_dicts).
        z_dim: latent size the checkpoint was trained with.

    Returns:
        The Generator on CPU with the checkpoint weights loaded (still in the default
        train mode; callers switch to eval() for inference).
    """
    gen = Generator(z_dim=z_dim)
    # weights_only=True: the checkpoint holds only state_dicts, so refuse arbitrary
    # pickled objects (checkpoints may be uploaded, i.e. untrusted).
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    gen.load_state_dict(ckpt['gen'])
    return gen

def generate_from_text(prompt, gen, mapper, n=4, device='cuda', noise_std=0.5):
    """Generate a 2-column grid of ``n`` images for a text prompt.

    ``mapper`` turns the prompt into one latent code of shape (1, z_dim, 1, 1). Each of
    the ``n`` images gets that code plus independent Gaussian noise with std
    ``noise_std``. Without noise all ``n`` latents are equal, and with ``gen`` in eval
    mode all ``n`` images come out identical. Pass ``noise_std=0`` to get that
    deterministic behavior.

    Returns:
        (H, W, C) float numpy array in [0, 1], as accepted by gr.Image(type='numpy').
    """
    gen.to(device); gen.eval()
    mapper.to(device)
    with torch.no_grad():
        z = mapper(prompt).to(device)
        z = z.repeat(n, 1, 1, 1)
        if noise_std:
            z = z + noise_std * torch.randn_like(z)
        imgs = gen(z)
        # Tanh range [-1, 1] -> [0, 1]; clamp guards against tiny numeric overshoot.
        imgs = ((imgs + 1) / 2.0).clamp(0.0, 1.0)
        grid = make_grid(imgs, nrow=2)
        return grid.permute(1, 2, 0).cpu().numpy()


def launch_gradio_interface(checkpoint_name='ckpt_epoch_111.pt', share=True):
    """Load a generator checkpoint from CKPT_DIR and serve a text-to-image Gradio UI.

    Args:
        checkpoint_name: file name inside CKPT_DIR.
        share: forwarded to ``iface.launch``; True creates a public Gradio share link,
            so set False to keep the app local.

    Raises:
        FileNotFoundError: if the checkpoint does not exist.
    """
    checkpoint_path = os.path.join(CKPT_DIR, checkpoint_name)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f'Checkpoint not found: {checkpoint_path}. Train first or upload a checkpoint.')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gen = load_generator(checkpoint_path, z_dim=z_dim)
    mapper = TextToLatent(text_dim=384, z_dim=z_dim)

    def _generate(prompt):
        return generate_from_text(prompt, gen, mapper, n=4, device=device)

    iface = gr.Interface(fn=_generate,
                         inputs=gr.Textbox(lines=2, placeholder='Ej: etiqueta de vino elegante, fondo negro, tipografía dorada'),
                         outputs=gr.Image(type='numpy'),
                         title='Generador de etiquetas (GAN + Texto)',
                         description='Describe la etiqueta que querés generar. Usa términos visuales: colores, composición, estilo.')
    iface.launch(share=share)




In [None]:
# Start the demo UI with the epoch-111 generator checkpoint.
launch_gradio_interface(checkpoint_name='ckpt_epoch_111.pt')

In [None]:
# ==========================================================
# 📊 Evaluación de la GAN - Curvas, FID, Ejemplos, Evaluación Humana
# ==========================================================
import os
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image
from torchvision import datasets, transforms
from pytorch_fid import fid_score
import pandas as pd
from datetime import datetime

# ==========================
# 1. 📈 Curvas de pérdida
# ==========================
def plot_losses(losses_g, losses_c, out_dir):
    """Plot generator and critic loss curves and save them as out_dir/training_losses.png.

    Args:
        losses_g: per-step generator losses.
        losses_c: per-step critic losses (same step axis as ``losses_g``).
        out_dir: output directory; created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title("Curvas de pérdida - Generador vs Crítico")
    ax.plot(losses_g, label="Generator Loss")
    ax.plot(losses_c, label="Critic Loss")
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, "training_losses.png"))
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("✅ Gráfico de pérdidas guardado en:", out_dir)

# ==========================
# 2. 🧩 Muestras visuales por epoch
# ==========================
def save_sample_images(gen, fixed_z, epoch, out_dir):
    """Save an 8-column, min-max normalized grid of gen(fixed_z) to out_dir/samples_epoch_{epoch}.png.

    ``gen`` is put in eval mode for sampling and afterwards restored to the mode it
    had before. A generator that was in eval() for inference therefore stays in
    eval() instead of being forced back to train mode (which would make BatchNorm
    use per-batch statistics in later sampling).

    Args:
        gen: generator module.
        fixed_z: latent batch on the same device as ``gen``.
        epoch: label used in the output file name.
        out_dir: output directory; created if missing.
    """
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            fake_images = gen(fixed_z).detach().cpu()
    finally:
        gen.train(was_training)
    os.makedirs(out_dir, exist_ok=True)
    grid = make_grid(fake_images, nrow=8, normalize=True)
    save_path = os.path.join(out_dir, f"samples_epoch_{epoch}.png")
    save_image(grid, save_path)
    print(f"✅ Imagen generada guardada: {save_path}")

# ==========================
# 3. 📏 FID Score (real vs generado)
# ==========================
def compute_fid(real_dir, fake_dir, device=None, batch_size=32, dims=2048):
    """Compute the FID between the images in ``real_dir`` and ``fake_dir`` with pytorch-fid.

    Args:
        real_dir, fake_dir: directories of image files.
        device: torch device for the Inception network; None picks CUDA when available
            (the old default hard-coded 'cuda').
        batch_size: images per Inception forward pass.
        dims: Inception feature dimensionality (2048 = final pooling layer).

    Returns:
        The FID value (lower is better).

    NOTE(review): pytorch-fid stacks images into batches, so images within a
    directory are expected to share one size. With fewer samples than ``dims`` per
    directory the covariance estimate is rank-deficient and the score is unreliable.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fid_value = fid_score.calculate_fid_given_paths([real_dir, fake_dir],
                                                    batch_size=batch_size,
                                                    device=device,
                                                    dims=dims)
    print(f"✅ FID Score: {fid_value:.2f}")
    return fid_value

# ==========================
# 4. 🧍 Evaluación humana
# ==========================
def prepare_human_eval(fake_dir, n_samples=10, out_csv="human_eval_list.csv"):
    """Write a CSV template for manually rating up to ``n_samples`` generated PNGs.

    The first ``n_samples`` ``.png`` files of ``fake_dir`` (sorted by name) are listed
    with empty plausibility and aesthetics columns (scores 1-5) for raters to fill in.
    The CSV is written inside ``fake_dir``.

    Returns:
        Full path of the written CSV.
    """
    imgs = sorted(f for f in os.listdir(fake_dir) if f.endswith(".png"))
    sample_imgs = imgs[:n_samples]
    # Size the empty score columns from the images actually found: with fewer than
    # n_samples PNGs, n_samples-long columns made pandas raise a length mismatch.
    n_found = len(sample_imgs)
    df = pd.DataFrame({
        "imagen": sample_imgs,
        "plausibilidad(1-5)": [""] * n_found,
        "estetica(1-5)": [""] * n_found,
    })
    csv_path = os.path.join(fake_dir, out_csv)
    df.to_csv(csv_path, index=False)
    print(f"✅ CSV generado para evaluación humana: {csv_path}")
    return csv_path

# ==========================
# 5. 🚀 Ejemplo de uso
# ==========================
# Reuse the project root from the configuration cell. Hard-coded Colab/Drive paths
# here silently redirected evaluation away from the BASE_DIR tree that training used.
OUT_DIR = os.path.join(BASE_DIR, 'outputs')
CKPT_DIR = os.path.join(BASE_DIR, 'checkpoints')
REAL_DIR = os.path.join(BASE_DIR, 'data', 'labels', 'beer')  # or wine / coffee

# Cargar último checkpoint
def _ckpt_epoch(name):
    """Epoch number encoded in 'ckpt_epoch_<N>.pt' (-1 when <N> is not an integer)."""
    stem = name[len('ckpt_epoch_'):-len('.pt')]
    return int(stem) if stem.isdigit() else -1

# Pick the checkpoint with the highest epoch number. Sorting names lexicographically
# ranks 'ckpt_epoch_99.pt' above 'ckpt_epoch_399.pt', so it missed the latest one.
_ckpt_names = [f for f in os.listdir(CKPT_DIR) if f.startswith('ckpt_epoch_') and f.endswith('.pt')]
if not _ckpt_names:
    raise FileNotFoundError(f'No ckpt_epoch_*.pt checkpoint found in {CKPT_DIR}')
latest_ckpt = max(_ckpt_names, key=_ckpt_epoch)
ckpt_path = os.path.join(CKPT_DIR, latest_ckpt)
_eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# weights_only=True: the checkpoint only contains state_dicts.
checkpoint = torch.load(ckpt_path, map_location=_eval_device, weights_only=True)
gen.load_state_dict(checkpoint["gen"])
gen.to(_eval_device)
print(f"✅ Checkpoint cargado: {ckpt_path}")

# Generar muestras y graficar p√©rdidas (asume que guardaste las listas en el entrenamiento)
# Sample on the generator's own device instead of assuming 'cuda'.
fixed_z = torch.randn(64, z_dim, 1, 1, device=next(gen.parameters()).device)
save_sample_images(gen, fixed_z, epoch="final", out_dir=OUT_DIR)

# Si guardaste losses durante el entrenamiento
# plot_losses(losses_g, losses_c, OUT_DIR)

# Calcular FID (usando real vs generado)
# --- Generated images for FID ---
FAKE_DIR = os.path.join(OUT_DIR, "fid_fake_samples")
os.makedirs(FAKE_DIR, exist_ok=True)
# NOTE(review): FID with dims=2048 needs more than 2048 images per side for a
# full-rank covariance; 100 samples only give a rough, biased estimate.
N_FID_SAMPLES = 100
_fid_device = next(gen.parameters()).device
# eval(): BatchNorm must use its running statistics. In train mode every
# single-image batch would be normalized by its own statistics.
gen.eval()
with torch.no_grad():
    for i in range(N_FID_SAMPLES):
        z = torch.randn(1, z_dim, 1, 1, device=_fid_device)
        fake = gen(z).detach().cpu()
        save_image(fake, os.path.join(FAKE_DIR, f"fake_{i}.png"), normalize=True)

# --- Real images for FID ---
# Resize + center-crop like the training transform, so both directories contain
# IMG_SIZE x IMG_SIZE RGB images (pytorch-fid batches images; mixed sizes fail to stack).
REAL_FID_DIR = os.path.join(OUT_DIR, "fid_real_samples")
os.makedirs(REAL_FID_DIR, exist_ok=True)
_fid_real_prep = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.CenterCrop(IMG_SIZE)])
_real_names = sorted(f for f in os.listdir(REAL_DIR) if f.lower().endswith(datasets.folder.IMG_EXTENSIONS))
for j, name in enumerate(_real_names):
    img = datasets.folder.default_loader(os.path.join(REAL_DIR, name))
    _fid_real_prep(img).save(os.path.join(REAL_FID_DIR, f"real_{j}.png"))

fid = compute_fid(REAL_FID_DIR, FAKE_DIR, device=_fid_device.type)

# Preparar CSV para evaluación humana
# CSV template listing the first 10 generated images for manual rating.
prepare_human_eval(fake_dir=FAKE_DIR, n_samples=10)
