In [6]:
import torch
from torchvision import transforms
from PIL import Image
import os
from models.generator import UNetGenerator
from utils import load_model, save_image_tensor

# --- Load Trained Generator ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = UNetGenerator()
generator = load_model(generator, "saved_models/generator_fast.pth", device)

generator.eval()

# --- Transform for input image ---
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# --- Input folder (images to test) ---
input_dir = "test_images/"
output_dir = "output_images/"
os.makedirs(output_dir, exist_ok=True)

# --- Choose target age manually or random ---
target_age = torch.tensor([0.8]).to(device)   # e.g., 0.8 = 80 years scaled value

# --- Testing Loop ---
for img_name in os.listdir(input_dir):
    if img_name.endswith(('.jpg', '.png', '.jpeg')):
        img_path = os.path.join(input_dir, img_name)
        img = Image.open(img_path).convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            output = generator(img_tensor, target_age)

        save_image_tensor(output, os.path.join(output_dir, f"aged_{img_name}"))
        print(f"âœ… Generated aged version for: {img_name}")

print("ðŸŽ¯ All test images processed successfully!")


âœ… Model loaded from saved_models/generator_fast.pth
ðŸŽ¯ All test images processed successfully!
