In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import json
import sys, os
sys.path.insert(0, os.path.abspath("diffusion"))
sys.path.insert(0, os.path.abspath("utils"))
sys.path.insert(0, os.path.abspath("models"))
sys.path.insert(0, os.path.abspath("samplers"))

from functools import partial

import torch
from torch.utils.data import (
    DataLoader,
    Dataset,
    Subset,
)

from torch.optim import Adam

from diffusion.schedules import LinearSchedule, CosineSchedule, NoiseSchedule
from diffusion.sde       import VESDE, VPSDE, SubVPSDE

from diffusion_utilities import (
    plot_image_grid,
    plot_image_evolution,
    animation_images,
)

import euler_maruyama, predictor_corrector, probability_flow_ode, exponential_integrator
from models  import score_net
import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import math
import os
from tqdm.auto import tqdm # Para barra de progreso
from abc import ABC, abstractmethod # Para clases abstractas
# Asumiendo que ScoreNet y sus bloques están definidos en otro lugar o arriba
# from models.score_net import GaussianRandomFourierFeatures


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:

# --- Configuraciones ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 32
IMG_CHANNELS = 3
NUM_CLASSES = 10 # CIFAR-10 tiene 10 clases
BATCH_SIZE = 1024 # Ajustar según memoria de GPU
LEARNING_RATE = 4*1e-4 # Puede necesitar ajuste para el clasificador
N_EPOCHS = 50 # Número de épocas de entrenamiento para el clasificador
T_END = 1.0 # Tiempo final para la SDE (DEBE SER EL MISMO QUE PARA SCORENET)
SDE_TYPE = 'VP' # 'VE' o 'VP' o 'SubVP'
SCHEDULE_TYPE = 'Linear' # 'Linear' o 'Cosine'
BETA_MIN = 0.1 # Mismos parámetros SDE/Schedule
BETA_MAX = 20.0
SIGMA_MIN = 0.01
SIGMA_MAX = 50.0
SIGMA = 25.0
CLASSIFIER_CHECKPOINT_DIR = './checkpoints_classifier'
CHECKPOINT_FREQ = 5 # Guardar checkpoint cada N épocas
EPS_T_SAMPLING = 1e-5 # Epsilon para muestrear t
S = 0.008 # Parametro para cosine Schedule

In [14]:
# --- Dataset ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normaliza a [-1, 1]
])

# Cargar dataset de entrenamiento CON etiquetas
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)

# (Opcional) Cargar dataset de validación para evaluar el clasificador
# val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=min(4, os.cpu_count()), pin_memory=True)


Files already downloaded and verified


In [15]:
# --- Inicialización ---
# Crear Schedule (DEBE SER LA MISMA INSTANCIA/CONFIGURACIÓN QUE PARA SCORENET)
if SCHEDULE_TYPE == 'Linear':
    schedule = LinearSchedule(beta_min=BETA_MIN, beta_max=BETA_MAX, T=T_END)
elif SCHEDULE_TYPE == 'Cosine':
    schedule = CosineSchedule(T=T_END, s=S)
else:
    raise ValueError(f"Schedule type {SCHEDULE_TYPE} no soportado.")

In [None]:
# Crear SDE (DEBE SER LA MISMA INSTANCIA/CONFIGURACIÓN QUE PARA SCORENET)
if SDE_TYPE == 'VP':
    sde = VPSDE(schedule=schedule)
elif SDE_TYPE == 'VE':
    sde = VESDE(sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX, sigma=SIGMA)
elif SDE_TYPE == 'SubVP':
    sde = SubVPSDE(schedule=schedule)
else:
    raise ValueError(f"SDE type {SDE_TYPE} no soportado.")

In [17]:
from models.WideResNet import TimeDependentWideResNet
# Crear Modelo Clasificador (Usando Placeholder)
# Ajusta time_emb_dim si es necesario
classifier_model = TimeDependentWideResNet(
    num_classes=NUM_CLASSES,
    time_emb_dim=128 # Ejemplo, ajusta según tu implementación de WRN
    # Pasa otros parámetros de WRN aquí (depth, widen_factor, etc.)
).to(DEVICE)

# Crear Optimizador para el clasificador
optimizer = optim.AdamW(classifier_model.parameters(), lr=LEARNING_RATE)

# Función de Pérdida para clasificación
criterion = nn.CrossEntropyLoss()

# Crear directorio de checkpoints para el clasificador
os.makedirs(CLASSIFIER_CHECKPOINT_DIR, exist_ok=True)

In [18]:
# --- Bucle de Entrenamiento del Clasificador ---
print(f"Iniciando entrenamiento del clasificador en {DEVICE}...")
print(f"Clasificador: {classifier_model.__class__.__name__}")
print(f"Usando SDE: {sde.__class__.__name__}, Schedule: {schedule.__class__.__name__}")
print(f"Epochs: {N_EPOCHS}, Batch Size: {BATCH_SIZE}, LR: {LEARNING_RATE}")

global_step = 0
for epoch in range(N_EPOCHS):
    classifier_model.train() # Poner modelo en modo entrenamiento
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{N_EPOCHS}", leave=False)
    epoch_loss = 0.0
    epoch_correct = 0
    epoch_total = 0

    for batch_idx, (images_0, labels) in enumerate(progress_bar):
        images_0 = images_0.to(DEVICE) # Imagen limpia x_0
        labels = labels.to(DEVICE)     # Etiqueta verdadera y
        batch_size = images_0.shape[0]

        # 1. Muestrear tiempo t
        t = torch.rand(batch_size, device=DEVICE) * (T_END - EPS_T_SAMPLING) + EPS_T_SAMPLING

        # 2. Generar imagen ruidosa x_t
        mu = sde.mu_t(images_0, t)
        std = sde.sigma_t(t)
        # Asegurar broadcasting de std
        while len(std.shape) < len(images_0.shape):
            std = std.unsqueeze(-1)
        noise = torch.randn_like(images_0)
        images_t = mu + std * noise # Imagen ruidosa x_t

        # 3. Pasar x_t y t por el clasificador
        logits = classifier_model(images_t, t)

        # 4. Calcular pérdida Cross-Entropy
        loss = criterion(logits, labels)

        # 5. Backpropagation y optimización
        optimizer.zero_grad()
        loss.backward()
        # Opcional: torch.nn.utils.clip_grad_norm_(classifier_model.parameters(), max_norm=1.0)
        optimizer.step()

        # Calcular accuracy para logging
        preds = torch.argmax(logits, dim=1)
        epoch_correct += (preds == labels).sum().item()
        epoch_total += labels.size(0)

        epoch_loss += loss.item()
        global_step += 1

        # Actualizar barra de progreso
        progress_bar.set_postfix(loss=loss.item(), acc=f"{(preds == labels).float().mean().item():.3f}")

    avg_epoch_loss = epoch_loss / len(train_loader)
    avg_epoch_acc = epoch_correct / epoch_total
    print(f"Epoch {epoch+1}/{N_EPOCHS} - Loss: {avg_epoch_loss:.4f} - Accuracy: {avg_epoch_acc:.4f}")

    # Guardar checkpoint del clasificador
    if (epoch + 1) % CHECKPOINT_FREQ == 0 or epoch == N_EPOCHS - 1:
        checkpoint_path = os.path.join(CLASSIFIER_CHECKPOINT_DIR, f'classifier_cifar10_{SDE_TYPE}_{SCHEDULE_TYPE}_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': classifier_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_epoch_loss,
            'accuracy': avg_epoch_acc,
            'sde_type': SDE_TYPE, # Guardar para referencia
            'schedule_type': SCHEDULE_TYPE,
        }, checkpoint_path)
        print(f"Checkpoint del clasificador guardado en: {checkpoint_path}")

print("Entrenamiento del clasificador completado.")


Iniciando entrenamiento del clasificador en cuda...
Clasificador: TimeDependentWideResNet
Usando SDE: VPSDE, Schedule: LinearSchedule
Epochs: 50, Batch Size: 1024, LR: 0.0004


Epoch 1/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 1/50 - Loss: 2.5628 - Accuracy: 0.1026


Epoch 2/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 2/50 - Loss: 2.2929 - Accuracy: 0.1210


Epoch 3/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 3/50 - Loss: 2.2627 - Accuracy: 0.1404


Epoch 4/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 4/50 - Loss: 2.2416 - Accuracy: 0.1500


Epoch 5/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 5/50 - Loss: 2.2131 - Accuracy: 0.1656
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_5.pth


Epoch 6/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 6/50 - Loss: 2.1884 - Accuracy: 0.1787


Epoch 7/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 7/50 - Loss: 2.1602 - Accuracy: 0.1883


Epoch 8/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 8/50 - Loss: 2.1364 - Accuracy: 0.2009


Epoch 9/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 9/50 - Loss: 2.1144 - Accuracy: 0.2059


Epoch 10/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 10/50 - Loss: 2.0985 - Accuracy: 0.2147
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_10.pth


Epoch 11/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 11/50 - Loss: 2.0730 - Accuracy: 0.2282


Epoch 12/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 12/50 - Loss: 2.0476 - Accuracy: 0.2362


Epoch 13/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 13/50 - Loss: 2.0263 - Accuracy: 0.2448


Epoch 14/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 14/50 - Loss: 2.0060 - Accuracy: 0.2523


Epoch 15/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 15/50 - Loss: 1.9971 - Accuracy: 0.2553
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_15.pth


Epoch 16/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 16/50 - Loss: 1.9666 - Accuracy: 0.2693


Epoch 17/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 17/50 - Loss: 1.9507 - Accuracy: 0.2773


Epoch 18/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 18/50 - Loss: 1.9480 - Accuracy: 0.2764


Epoch 19/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 19/50 - Loss: 1.9438 - Accuracy: 0.2753


Epoch 20/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 20/50 - Loss: 1.9252 - Accuracy: 0.2870
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_20.pth


Epoch 21/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 21/50 - Loss: 1.9146 - Accuracy: 0.2885


Epoch 22/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 22/50 - Loss: 1.9103 - Accuracy: 0.2907


Epoch 23/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 23/50 - Loss: 1.8977 - Accuracy: 0.2959


Epoch 24/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 24/50 - Loss: 1.8892 - Accuracy: 0.3010


Epoch 25/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 25/50 - Loss: 1.8826 - Accuracy: 0.3007
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_25.pth


Epoch 26/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 26/50 - Loss: 1.8708 - Accuracy: 0.3072


Epoch 27/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 27/50 - Loss: 1.8616 - Accuracy: 0.3107


Epoch 28/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 28/50 - Loss: 1.8631 - Accuracy: 0.3108


Epoch 29/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 29/50 - Loss: 1.8465 - Accuracy: 0.3163


Epoch 30/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 30/50 - Loss: 1.8430 - Accuracy: 0.3161
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_30.pth


Epoch 31/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 31/50 - Loss: 1.8362 - Accuracy: 0.3215


Epoch 32/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 32/50 - Loss: 1.8242 - Accuracy: 0.3248


Epoch 33/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 33/50 - Loss: 1.8235 - Accuracy: 0.3278


Epoch 34/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 34/50 - Loss: 1.8159 - Accuracy: 0.3288


Epoch 35/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 35/50 - Loss: 1.8105 - Accuracy: 0.3338
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_35.pth


Epoch 36/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 36/50 - Loss: 1.8016 - Accuracy: 0.3316


Epoch 37/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 37/50 - Loss: 1.8035 - Accuracy: 0.3324


Epoch 38/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 38/50 - Loss: 1.7944 - Accuracy: 0.3372


Epoch 39/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 39/50 - Loss: 1.7898 - Accuracy: 0.3390


Epoch 40/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 40/50 - Loss: 1.7802 - Accuracy: 0.3398
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_40.pth


Epoch 41/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 41/50 - Loss: 1.7793 - Accuracy: 0.3423


Epoch 42/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 42/50 - Loss: 1.7727 - Accuracy: 0.3458


Epoch 43/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 43/50 - Loss: 1.7642 - Accuracy: 0.3504


Epoch 44/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 44/50 - Loss: 1.7797 - Accuracy: 0.3433


Epoch 45/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 45/50 - Loss: 1.7666 - Accuracy: 0.3466
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_45.pth


Epoch 46/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 46/50 - Loss: 1.7595 - Accuracy: 0.3512


Epoch 47/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 47/50 - Loss: 1.7521 - Accuracy: 0.3512


Epoch 48/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 48/50 - Loss: 1.7489 - Accuracy: 0.3556


Epoch 49/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 49/50 - Loss: 1.7469 - Accuracy: 0.3542


Epoch 50/50:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch 50/50 - Loss: 1.7420 - Accuracy: 0.3557
Checkpoint del clasificador guardado en: ./checkpoints_classifier/classifier_cifar10_VP_Linear_epoch_50.pth
Entrenamiento del clasificador completado.
