In [1]:
# 1. Imports and Setup
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, utils
from torchvision.utils import save_image
from pytorch_fid.fid_score import calculate_fid_given_paths
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from PIL import Image
import time
import random
import shutil
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 2. Model Definitions
class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, feature_maps):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, feature_maps*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps*8), nn.ReLU(True),
            nn.ConvTranspose2d(feature_maps*8, feature_maps*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps*4), nn.ReLU(True),
            nn.ConvTranspose2d(feature_maps*4, feature_maps*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps*2), nn.ReLU(True),
            nn.ConvTranspose2d(feature_maps*2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps), nn.ReLU(True),
            nn.ConvTranspose2d(feature_maps, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, x):
        return self.net(x)

class Discriminator(nn.Module):
    def __init__(self, img_channels, feature_maps):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_maps, feature_maps*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps*2), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_maps*2, feature_maps*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps*4), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_maps*4, feature_maps*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps*8), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(feature_maps*8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.net(x).view(-1)

# 3. Training Function
def train(data_path, out_dir, epochs=50, batch_size=128, lr=2e-4,
          latent_dim=100, gf_maps=64, df_maps=64, sample_interval=5):
    os.makedirs(out_dir, exist_ok=True)
    # Data loader
    transform = transforms.Compose([
        transforms.Resize((64,64)), transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    class CatDataset(Dataset):
        def __init__(self, file_list, transform=None):
            self.file_list = file_list
            self.transform = transform

        def __len__(self):
            return len(self.file_list)

        def __getitem__(self, idx):
            path = self.file_list[idx]
            img = Image.open(path)
            return transform(img)
    
    files = [os.path.join(data_path, line.strip()) for line in os.listdir(data_path)]
    dataset = CatDataset(files, transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Initialize models
    gen = Generator(latent_dim, 3, gf_maps).to(device)
    disc = Discriminator(3, df_maps).to(device)
    opt_g = optim.Adam(gen.parameters(), lr=lr, betas=(0.5,0.999))
    opt_d = optim.Adam(disc.parameters(), lr=lr, betas=(0.5,0.999))
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        start_time = time.time()
        for imgs in loader:
            bs = imgs.size(0)
            real = imgs.to(device)
            real_labels = torch.ones(bs, device=device)
            fake_labels = torch.zeros(bs, device=device)
            # Train Discriminator
            z = torch.randn(bs, latent_dim, 1, 1, device=device)
            fake = gen(z)
            d_loss = criterion(disc(real), real_labels) + criterion(disc(fake.detach()), fake_labels)
            disc.zero_grad(); d_loss.backward(); opt_d.step()
            # Train Generator
            g_loss = criterion(disc(fake), real_labels)
            gen.zero_grad(); g_loss.backward(); opt_g.step()
        print(f"Epoch {epoch+1}/{epochs} | D: {d_loss:.4f} | G: {g_loss:.4f} | calculation time {round((time.time() - start_time) / 60, 2)} min")
        if (epoch+1) % sample_interval == 0:
            path = os.path.join(out_dir, f"sample_{epoch+1}.png")
            save_image(utils.make_grid(fake[:16], normalize=True), path)
    # Save models
    torch.save(gen.state_dict(), os.path.join(out_dir, "generator.pth"))
    torch.save(disc.state_dict(), os.path.join(out_dir, "discriminator.pth"))
    return gen, disc

# 4. Quantitative Evaluation (FID)
def evaluate_fid(real_path, fake_path, batch_size=50):
    fid = calculate_fid_given_paths([real_path, fake_path], batch_size, device, dims=2048)
    print(f"FID: {fid:.3f}")
    return fid

# 5. Mode Collapse Check
def check_mode_collapse(gen, latent_dim, num_samples=1000, threshold=0.9):
    gen.eval()
    zs = torch.randn(num_samples, latent_dim, 1, 1, device=device)
    with torch.no_grad():
        sims = cosine_similarity(gen(zs).reshape(num_samples, -1).cpu().numpy())
    high = np.sum(sims > threshold) - num_samples
    all = num_samples * (num_samples - 1)
    print(f"High-similarity pairs: {high}/{all}")
    return high

# 6. Latent-Space Interpolation
def interpolate(gen, z1, z2, steps=10):
    gen.eval()
    alphas = torch.linspace(0,1,steps, device=device).view(-1,1,1,1)
    zs = (1-alphas)*z1 + alphas*z2
    with torch.no_grad(): imgs = gen(zs).cpu()
    grid = utils.make_grid(imgs, normalize=True, nrow=steps)
    return grid

# 7. Hyperparameter Grid Search
def grid_search(params):
    import itertools
    best = {'fid': float('inf'), 'params': None}
    for lr, bs in itertools.product(params['lr'], params['batch_size']):
        # Run training & FID eval here
        fid = np.random.rand()*100  # placeholder
        if fid < best['fid']:
            best.update({'fid': fid, 'params': {'lr': lr, 'batch_size': bs}})
    print(best)
    return best

def create_real_subset_for_fid(data_dir):

    # Set seed
    random.seed(42)
    sample_size = 1000

    temp_real_path = Path(data_dir).parent / "real_subset_for_fid"
    if temp_real_path.exists():
        shutil.rmtree(temp_real_path)
    temp_real_path.mkdir(parents=True, exist_ok=True)

    real_images = os.listdir(data_dir)
    sampled_images = random.sample(real_images, sample_size)

    for img_path in sampled_images:
        shutil.copy(f'{data_dir}/{img_path}', temp_real_path / img_path)

DATA_DIR = './cats_data'
OUT_DIR  = './output'

Using device: cpu


In [2]:
# Full Pipeline
# 1) Train on cats

gen, disc = train(DATA_DIR, OUT_DIR, epochs=50, batch_size=128, lr=2e-4) # 10 epochs -> 74 min

Epoch 1/50 | D: 0.3554 | G: 1.7254 | calculation time 6.31 min
Epoch 2/50 | D: 0.8433 | G: 1.4582 | calculation time 6.66 min
Epoch 3/50 | D: 0.9313 | G: 2.5497 | calculation time 6.82 min
Epoch 4/50 | D: 0.7322 | G: 6.8336 | calculation time 7.16 min
Epoch 5/50 | D: 1.3214 | G: 9.0833 | calculation time 7.52 min
Epoch 6/50 | D: 0.1538 | G: 5.3249 | calculation time 7.64 min
Epoch 7/50 | D: 0.5576 | G: 6.0926 | calculation time 7.72 min
Epoch 8/50 | D: 0.2614 | G: 3.5190 | calculation time 7.81 min
Epoch 9/50 | D: 0.3989 | G: 2.9087 | calculation time 7.88 min
Epoch 10/50 | D: 0.0293 | G: 4.8368 | calculation time 7.91 min
Epoch 11/50 | D: 0.3420 | G: 3.3074 | calculation time 7.96 min
Epoch 12/50 | D: 0.4660 | G: 8.2312 | calculation time 8.05 min
Epoch 13/50 | D: 0.3022 | G: 7.6864 | calculation time 7.76 min
Epoch 14/50 | D: 0.0446 | G: 6.3629 | calculation time 7.89 min
Epoch 15/50 | D: 0.7125 | G: 7.2566 | calculation time 8.08 min
Epoch 16/50 | D: 0.2614 | G: 5.2256 | calculation

In [3]:
# 2) Generate and save a fresh batch of samples

fixed_z = torch.randn(64, 100, 1, 1, device=device)
fresh_samples = gen(fixed_z)
save_image(utils.make_grid(fresh_samples, normalize=True), os.path.join(OUT_DIR, 'fresh_samples.png'))


In [4]:
# 3) Quantitative evaluation (FID)

fake_dir = os.path.join(OUT_DIR, 'generated_for_fid')
os.makedirs(fake_dir, exist_ok=True)
#    Generate a set of images for FID
with torch.no_grad():
    z_fid = torch.randn(1000, 100, 1, 1, device=device)
    imgs_fid = gen(z_fid).cpu()
    for i, img in enumerate(imgs_fid):
        save_image(img, os.path.join(fake_dir, f"fid_{i}.png"))
fid_score = evaluate_fid(real_path='real_subset_for_fid', fake_path=fake_dir, batch_size=50)

# 30 epochs -> 114

FID: 132.665


In [5]:
# 4) Mode-collapse check

too_many_similar = check_mode_collapse(gen, latent_dim=100, num_samples=1000, threshold=0.9)

High-similarity pairs: 2412/999000


In [6]:
# 5) Latent-space interpolation (save grid)

z1 = torch.randn(1, 100, 1, 1, device=device)
z2 = torch.randn(1, 100, 1, 1, device=device)
interp_grid = interpolate(gen, z1, z2, steps=10)
save_image(interp_grid, os.path.join(OUT_DIR, 'interpolation.png'))

In [None]:
# TODO # 6) Hyperparameter grid search 

tune_params = {'lr': [1e-4, 2e-4, 5e-4], 'batch_size': [64, 128]}
best = grid_search(tune_params)
print(f"Best hyperparameters: {best['params']} with FID {best['fid']:.2f}")