In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from src.generator import Generator
from src.discriminator import Discriminator
from src.training import train
import torch.optim as optim

In [None]:
# Reproducibility: seed PyTorch's global RNG before any weights are initialised or data is shuffled.
# NOTE(review): this alone does not make cuDNN kernels deterministic on GPU.
seed = 42
torch.manual_seed(seed)

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load the CIFAR-10 training split and keep only the "dog" images.
DOG_CLASS = 5  # CIFAR-10 label order: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

# ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 per RGB channel rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
cifar_full = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Filter on the stored label list rather than iterating the dataset: iterating would load and
# transform every training image just to read its label.
selected_images = [idx for idx, label in enumerate(cifar_full.targets) if label == DOG_CLASS]
cifar_train = torch.utils.data.Subset(cifar_full, selected_images)

batch_size = 64
# drop_last=True keeps every batch exactly `batch_size` long. `train` is also given `batch_size`
# explicitly; if it sizes noise/label tensors from it (TODO confirm in src.training), a short final
# batch would not match them.
train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Build the generator (G) and discriminator (D) with their default constructors and move both to `device`.
# The tuple is evaluated left to right, so G is still constructed (and its weights drawn) before D.
G, D = Generator().to(device), Discriminator().to(device)

In [None]:
# One Adam optimizer per network; both share betas with beta1 = 0.5 (a common choice for GAN training).
# The generator's learning rate is set slightly below the discriminator's.
ADAM_BETAS = (0.5, 0.999)
disc_opt = optim.Adam(D.parameters(), lr=2e-4, betas=ADAM_BETAS)
gen_opt = optim.Adam(G.parameters(), lr=1.4e-4, betas=ADAM_BETAS)

In [None]:
# Train the GAN with the project's `train` routine (src.training) for a fixed number of epochs.
EPOCHS = 210
train(D, G, disc_opt, gen_opt, train_loader, batch_size, epochs=EPOCHS, device=device)