In [5]:
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm  # ‚úÖ progress bar

from models.generator import UNetGenerator
from models.discriminator import PatchDiscriminator
from utils import sample_random_age

# =====================
# üîß Hyperparameters
# =====================
num_epochs = 10  # ‚ö° Fewer epochs for faster testing
batch_size = 32  # ‚ö° Larger batch for faster convergence
learning_rate = 0.0002
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaler = torch.cuda.amp.GradScaler()  # ‚úÖ Mixed precision scaler

# =====================
# üß† Dataset
# =====================
class UTKFaceDataset(Dataset):
    def __init__(self, root, transform=None):
        self.files = glob.glob(os.path.join(root, "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.files[index]
        img = Image.open(img_path).convert("RGB")

        try:
            age = int(os.path.basename(img_path).split("_")[0])
        except:
            age = 0

        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(age, dtype=torch.float32)

    def __len__(self):
        return len(self.files)

# =====================
# üñºÔ∏è Transforms
# =====================
transform = transforms.Compose([
    transforms.Resize((64, 64)),   # ‚ö° smaller for faster compute
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# =====================
# üì¶ DataLoader
# =====================
dataset = UTKFaceDataset(root="data/utkface", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
print(f"Loaded {len(dataset)} images for training!")

# =====================
# ‚öôÔ∏è Models
# =====================
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# =====================
# üöÄ Training Loop
# =====================
for epoch in range(num_epochs):
    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for real_faces, real_ages in progress:
        real_faces, real_ages = real_faces.to(device), real_ages.to(device)
        target_ages = sample_random_age(real_ages.size(0)).to(device)

        # ======================
        # Train Discriminator
        # ======================
        with torch.cuda.amp.autocast():  # ‚úÖ mixed precision
            fake_faces = generator(real_faces, target_ages).detach()
            d_real = discriminator(real_faces, real_ages)
            d_fake = discriminator(fake_faces, target_ages)
            d_loss = -(torch.mean(torch.log(d_real + 1e-8) + torch.log(1 - d_fake + 1e-8)))

        optimizer_D.zero_grad()
        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)

        # ======================
        # Train Generator
        # ======================
        with torch.cuda.amp.autocast():
            fake_faces = generator(real_faces, target_ages)
            d_fake = discriminator(fake_faces, target_ages)
            g_loss = -torch.mean(torch.log(d_fake + 1e-8))

        optimizer_G.zero_grad()
        scaler.scale(g_loss).backward()
        scaler.step(optimizer_G)
        scaler.update()

        progress.set_postfix({"D_loss": f"{d_loss.item():.4f}", "G_loss": f"{g_loss.item():.4f}"})

print("‚úÖ Training completed successfully!")

os.makedirs("saved_models", exist_ok=True)
torch.save(generator.state_dict(), "saved_models/generator_fast.pth")
torch.save(discriminator.state_dict(), "saved_models/discriminator_fast.pth")


  scaler = torch.cuda.amp.GradScaler()  # ‚úÖ Mixed precision scaler


Loaded 23708 images for training!


  with torch.cuda.amp.autocast():  # ‚úÖ mixed precision
  with torch.cuda.amp.autocast():
Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [2:46:37<00:00, 13.49s/it, D_loss=0.0003, G_loss=9.7669]     
Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [1:27:02<00:00,  7.05s/it, D_loss=0.0001, G_loss=10.8563]    
Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [13:35:55<00:00, 66.07s/it, D_loss=0.0000, G_loss=12.0497]        
Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [23:17<00:00,  1.89s/it, D_loss=0.0000, G_loss=12.2667]
Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [14:52<00:00,  1.20s/it, D_loss=0.0000, G_loss=13.0618]
Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [14:23<00:00,  1.17s/it, D_loss=0.0000, G_loss=13.4370]
Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [15:01<00:00,  1.22s/it, D_loss=0.0000, G_loss=14.2681]
Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 741/741 [2:12:17<00:00, 10.71s/it, D_loss=0.00

‚úÖ Training completed successfully!


In [19]:
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm  # progress bar

from models.generator import UNetGenerator
from models.discriminator import PatchDiscriminator
from utils import normalize_age, sample_random_age  # make sure normalize_age() is in utils.py

# =====================
# üîß Hyperparameters
# =====================
num_epochs = 30
batch_size = 8
learning_rate = 0.0002
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================
# üß† Custom UTKFace Dataset
# =====================
class UTKFaceDataset(Dataset):
    def __init__(self, root, transform=None):
        self.files = glob.glob(os.path.join(root, "*.jpg"))
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.files[index]
        img = Image.open(img_path).convert("RGB")

        try:
            age = int(os.path.basename(img_path).split("_")[0])
        except:
            age = 0

        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(age, dtype=torch.float32)

    def __len__(self):
        return len(self.files)


# =====================
# üñºÔ∏è Transforms
# =====================
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# =====================
# üì¶ Load Dataset
# =====================
dataset = UTKFaceDataset(root="data/utkface", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(f"Loaded {len(dataset)} images for training!")

# =====================
# ‚öôÔ∏è Initialize Models
# =====================
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

criterion = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()  # For extra sharpness

# =====================
# üöÄ Training Loop
# =====================
for epoch in range(num_epochs):
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for real_faces, real_ages in progress_bar:
        real_faces = real_faces.to(device)
        real_ages = normalize_age(real_ages).to(device)
        target_ages = normalize_age(sample_random_age(batch_size)).to(device)

        # ===== Generator Forward =====
        fake_faces = generator(real_faces, target_ages)

        # ===== Train Discriminator =====
        real_preds = discriminator(real_faces, real_ages)
        fake_preds = discriminator(fake_faces.detach(), target_ages)

        d_loss_real = criterion(real_preds, torch.ones_like(real_preds))
        d_loss_fake = criterion(fake_preds, torch.zeros_like(fake_preds))
        d_loss = (d_loss_real + d_loss_fake) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ===== Train Generator =====
        fake_preds = discriminator(fake_faces, target_ages)
        g_adv_loss = criterion(fake_preds, torch.ones_like(fake_preds))
        g_l1 = l1_loss(fake_faces, real_faces) * 100  # pixel sharpness boost
        g_loss = g_adv_loss + g_l1

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        progress_bar.set_postfix(D_loss=d_loss.item(), G_loss=g_loss.item())

    # ‚úÖ Save checkpoint every few epochs
    if (epoch + 1) % 5 == 0:
        os.makedirs("saved_models", exist_ok=True)
        torch.save(generator.state_dict(), f"saved_models/gen_epoch_{epoch+1}.pth")

print("‚úÖ Training completed successfully! Model saved in 'saved_models/' folder.")


Loaded 23708 images for training!


Epoch 1/30:   0%|          | 0/2964 [00:00<?, ?it/s]


ValueError: only one element tensors can be converted to Python scalars