In [None]:
!pip install opendatasets --upgrade --quiet

In [None]:
# Fetch the FFHQ face dataset from Kaggle with the opendatasets helper.
import opendatasets as od

# Kaggle dataset page to download from.
dataset_url = 'https://www.kaggle.com/datasets/greatgamedota/ffhq-face-data-set'
# Per the recorded output, the archive is placed under ./ffhq-face-data-set.
# NOTE(review): this presumably prompts for Kaggle credentials on first use — confirm.
od.download(dataset_url)

Dataset URL: https://www.kaggle.com/datasets/greatgamedota/ffhq-face-data-set
Downloading ffhq-face-data-set.zip to ./ffhq-face-data-set


100%|██████████| 1.97G/1.97G [00:18<00:00, 116MB/s]





In [None]:
import os

# Root folder of the male/female faces dataset; the recorded output shows it
# containing the sub-folders 'Male Faces' and 'Female Faces'.
# NOTE(review): no visible cell downloads this dataset (the download cell above
# fetches ffhq-face-data-set), so this cell likely fails on a fresh kernel —
# confirm where this directory is supposed to come from.
DATA_DIR = './male-and-female-faces-dataset/Male and Female face dataset'
print(os.listdir(DATA_DIR))

['Male Faces', 'Female Faces']


In [None]:
# Peek at the first ten file names inside the 'Male Faces' sub-folder.
male_faces_dir = os.path.join(DATA_DIR, 'Male Faces')
print(os.listdir(male_faces_dir)[:10])

['1 (1548).jpg', '1 (1915).jpg', '1 (1856).jpg', '1 (1994).jpg', '1 (1234).jpg', '1 (2579).jpg', '1 (2278).jpg', '1 (628).jpg', '1 (1572).jpg', '1 (1589).jpg']


In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T

In [None]:
# Pick the compute device for training and display it.
# NOTE(review): get_default_device is not defined in any visible cell — confirm
# it is defined earlier in the notebook, otherwise this raises NameError on a
# fresh kernel. `device` is also reassigned later in the training launch cell.
device = get_default_device()
device

device(type='cuda')

In [None]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

def train_gan(dataloader, output_dir, num_epochs=100, latent_dim=100, lr=0.0002, beta1=0.5, device='cuda', checkpoint_path=None):
    """Train a Generator/Discriminator pair with BCE loss, optionally resuming.

    Args:
        dataloader: Iterable yielding ``(real_images, _)`` batches; the second
            element of each batch is ignored.
        output_dir: Directory (created if missing) that receives sample images
            and ``checkpoint_epoch_<n>.pt`` files.
        num_epochs: Exclusive upper bound of the epoch loop. When resuming this
            is the absolute final epoch, not the number of additional epochs.
        latent_dim: Size of the noise vector passed to the generator.
        lr: Base learning rate; the generator uses ``lr * 0.5`` and the
            discriminator ``lr * 0.2``.
        beta1: First Adam beta coefficient for both optimizers.
        device: Device the models, noise and labels are placed on.
        checkpoint_path: Optional path to a checkpoint written by this
            function. Checkpoints without scheduler state (older format) are
            still accepted.

    Returns:
        Tuple ``(generator, discriminator)`` after training.
    """
    generator = Generator(latent_dim).to(device)
    discriminator = Discriminator().to(device)

    # Adjusted learning rates: the discriminator learns more slowly than the generator.
    g_optimizer = optim.Adam(generator.parameters(), lr=lr * 0.5, betas=(beta1, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr * 0.2, betas=(beta1, 0.999))

    # Learning rate schedulers: multiply the rate by 0.8 every 100 epochs.
    g_scheduler = StepLR(g_optimizer, step_size=100, gamma=0.8)
    d_scheduler = StepLR(d_optimizer, step_size=100, gamma=0.8)

    # NOTE(review): BCELoss presumably requires the discriminator to output
    # probabilities of shape (batch,) — confirm against the Discriminator class.
    criterion = nn.BCELoss()

    # Initialize starting epoch
    start_epoch = 0

    # Load from checkpoint if provided
    if checkpoint_path:
        print(f"Loading checkpoint from {checkpoint_path}...")
        # map_location lets a checkpoint saved on GPU be restored on whatever
        # device this run uses (e.g. a CPU-only runtime).
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1

        # Restore the decay schedule as well; otherwise the 100-epoch step
        # counter would restart from zero on every resume.
        for scheduler, key in ((g_scheduler, 'g_scheduler_state_dict'),
                               (d_scheduler, 'd_scheduler_state_dict')):
            if key in checkpoint:
                scheduler.load_state_dict(checkpoint[key])
            else:
                # Older checkpoints carry no scheduler state: realign the epoch
                # counter so the next decay lands on the original boundary
                # (the decayed rate itself came back with the optimizer state).
                scheduler.last_epoch = start_epoch
        print(f"Resumed from epoch {start_epoch}")

    # Fixed noise so the periodically saved sample grids are comparable across epochs.
    fixed_noise = torch.randn(64, latent_dim, device=device)
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            real_images = real_images.to(device)

            # Label smoothing: real targets are 0.9 instead of 1.0.
            real_labels = torch.ones(batch_size, device=device) * 0.9
            fake_labels = torch.zeros(batch_size, device=device)

            # Train Discriminator
            d_optimizer.zero_grad()
            output_real = discriminator(real_images)
            d_loss_real = criterion(output_real, real_labels)

            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            # detach() keeps the discriminator update from back-propagating into the generator.
            output_fake = discriminator(fake_images.detach())
            d_loss_fake = criterion(output_fake, fake_labels)

            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            d_optimizer.step()

            # Train Generator: re-score the same fakes (now with gradients) against
            # the smoothed "real" targets.
            g_optimizer.zero_grad()
            output_fake = discriminator(fake_images)
            g_loss = criterion(output_fake, real_labels)
            g_loss.backward()
            g_optimizer.step()

            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
                      f'D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')

        # Adjust learning rates with the schedulers at the end of each epoch
        g_scheduler.step()
        d_scheduler.step()

        is_last_epoch = epoch == num_epochs - 1

        if epoch % 5 == 0 or is_last_epoch:
            save_generated_images(generator, epoch, output_dir, fixed_noise=fixed_noise, device=device)

        # Also checkpoint the final epoch so the finished model is never lost
        # when num_epochs - 1 is not a multiple of 10.
        if epoch % 10 == 0 or is_last_epoch:
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'g_optimizer_state_dict': g_optimizer.state_dict(),
                'd_optimizer_state_dict': d_optimizer.state_dict(),
                'g_scheduler_state_dict': g_scheduler.state_dict(),
                'd_scheduler_state_dict': d_scheduler.state_dict(),
                'epoch': epoch,
            }, os.path.join(output_dir, f'checkpoint_epoch_{epoch}.pt'))

    return generator, discriminator


In [None]:
import os

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Update these paths as needed
# NOTE: the absolute /content/... paths below only resolve on a Colab-style runtime.
data_path = '/content/ffhq-face-data-set/thumbnails128x128'  # Path to the folder containing images
output_dir = 'generated_images'    # Directory to save generated images and model checkpoints

# Set parameters
image_size = 64         # Image size expected by the model
batch_size = 64         # Number of images in each batch
num_epochs = 1000       # Absolute final epoch; when resuming, training runs from the checkpoint's epoch + 1 up to this value
latent_dim = 100        # Dimensionality of the latent space
lr = 0.0002             # Base learning rate
checkpoint_path = '/content/checkpoint_epoch_450.pt'  # Optional: Path to a previous checkpoint to resume training

# A missing checkpoint should not abort the run: fall back to training from scratch.
if checkpoint_path and not os.path.isfile(checkpoint_path):
    print(f"Checkpoint {checkpoint_path} not found; training from scratch.")
    checkpoint_path = None

# Define the data transformations: resize, light augmentation, then scale
# each channel from [0, 1] to [-1, 1] via Normalize(mean=0.5, std=0.5).
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Create the dataset and dataloader
# (FaceDataset is defined outside the cells visible here.)
dataset = FaceDataset(data_path, transform=transform)
# drop_last=True keeps every batch at exactly batch_size samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Start training
generator, discriminator = train_gan(
    dataloader=dataloader,
    output_dir=output_dir,
    num_epochs=num_epochs,
    latent_dim=latent_dim,
    lr=lr,
    device=device,
    checkpoint_path=checkpoint_path
)

print("Training completed.")


Found 70000 images
Using device: cuda
Loading checkpoint from /content/checkpoint_epoch_450.pt...
Resumed from epoch 451


  checkpoint = torch.load(checkpoint_path)


Epoch [451/1000] Batch [0/1093] D_loss: 0.1634 G_loss: 7.0972
Epoch [451/1000] Batch [100/1093] D_loss: 0.1639 G_loss: 12.0856
Epoch [451/1000] Batch [200/1093] D_loss: 0.1677 G_loss: 7.2193
Epoch [451/1000] Batch [300/1093] D_loss: 0.1639 G_loss: 8.5507
Epoch [451/1000] Batch [400/1093] D_loss: 0.1790 G_loss: 4.6491
Epoch [451/1000] Batch [500/1093] D_loss: 0.1649 G_loss: 7.2039
Epoch [451/1000] Batch [600/1093] D_loss: 0.1655 G_loss: 6.3064
Epoch [451/1000] Batch [700/1093] D_loss: 0.1646 G_loss: 6.7214
Epoch [451/1000] Batch [800/1093] D_loss: 0.1645 G_loss: 8.4637
Epoch [451/1000] Batch [900/1093] D_loss: 0.1638 G_loss: 6.6527
Epoch [451/1000] Batch [1000/1093] D_loss: 0.1642 G_loss: 7.6956
Epoch [452/1000] Batch [0/1093] D_loss: 0.2187 G_loss: 4.1996
Epoch [452/1000] Batch [100/1093] D_loss: 0.1667 G_loss: 7.3881
Epoch [452/1000] Batch [200/1093] D_loss: 0.1642 G_loss: 6.8492
Epoch [452/1000] Batch [300/1093] D_loss: 0.1654 G_loss: 6.3942
Epoch [452/1000] Batch [400/1093] D_loss: 