1. Install and Import Libraries

In [37]:
#!pip install torch torchvision matplotlib tqdm

import os
import numpy as np
import matplotlib.pyplot as plt
#from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from IPython.display import HTML, display
import matplotlib.animation as animation

from torch.autograd import Variable


2. Image Folder Creation

In [None]:
os.makedirs("progress_images", exist_ok=True)

 2. Set Hyperparameters

In [38]:
n_epochs = 5
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim = 100
img_size = 64
channels = 3
sample_interval = 500

img_shape = (channels, img_size, img_size)
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

3. Define Generator

In [39]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

4. Define Discriminator

In [40]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

5. Prepare Dataset

In [70]:
base_path = "./data/celeba"
print("Dataset folder exists:", os.path.exists(base_path))
print("Images folder exists:", os.path.exists(os.path.join(base_path, "img_align_celeba")))
print("Attributes file exists:", os.path.exists(os.path.join(base_path, "list_attr_celeba.txt")))
if os.path.exists(os.path.join(base_path, "img_align_celeba")):
    print("Number of images:", len(os.listdir(os.path.join(base_path, "img_align_celeba"))))


root = "./data"
 

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

celeba_dataset = datasets.CelebA(
    root="./data",
    download=False,  # Important to prevent re-downloading
    transform=transform
)

# dataloader = DataLoader(
#     celeba_dataset,
#     batch_size=batch_size,
#     shuffle = True,
#     num_workers = 4
# )



Dataset folder exists: True
Images folder exists: True
Attributes file exists: True
Number of images: 202599


RuntimeError: Dataset not found or corrupted. You can use download=True to download it

6. Initialize Models/Optimizers 

In [None]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)
adversarial_loss = nn.BCELoss().to(device)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))


7. Image Display

In [None]:
def show_images(images, nrow=5):
    grid = make_grid(images, nrow=nrow, normalize=True)
    plt.figure(figsize=(10,10))
    plt.imshow(np.transpose(grid.cpu(), (1,2,0)))
    plt.axis('off')
    plt.show()

8. Training Loop with Real Time Output

In [None]:
G_losses = []
D_losses = []

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(tqdm(dataloader)):
        
        valid = torch.ones(imgs.size(0), 1, device=device)
        fake = torch.zeros(imgs.size(0), 1, device=device)

        real_imgs = imgs.to(device)

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())

        if i % sample_interval == 0:
            clear_output(wait=True)
            print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            show_images(gen_imgs[:25])

        if epoch % 5 == 0:
            real_imgs, _ = next(iter(dataloader))
            real_imgs = denormalize(real_imgs)

            z = torch.randn(real_imgs.size(0), opt.latent_dim).to(device)
            gen_imgs = generator(z)
            gen_imgs = denormalize(gen_imgs)

            fig, axs = plt.subplots(1, 2, figsize=(12, 6))

            axs[0].imshow(make_grid(real_imgs[:25], nrow=5).permute(1, 2, 0).cpu())
            axs[0].set_title(f'Real CelebA Faces (Epoch {epoch})')
            axs[0].axis('off')

            axs[1].imshow(make_grid(gen_imgs[:25], nrow=5).permute(1, 2, 0).detach().cpu())
            axs[1].set_title(f'Generated Faces (Epoch {epoch})')
            axs[1].axis('off')

            plt.show()


    with torch.no_grad():
        z = torch.randn(25, opt.latent_dim).to(device)
        gen_imgs = generator(z)
        gen_imgs = denormalize(gen_imgs)
        save_image(gen_imgs, f"progress_images/epoch_{epoch:03d}.png", nrow=5, normalize=True)
    


9. Loss Plot Post Training

In [None]:
plt.plot(G_losses, label="Generator Loss")
plt.plot(D_losses, label="Discriminator Loss")
plt.legend()
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("GAN Losses")
plt.show()


10. Faces Animation

In [None]:
import glob
from PIL import Image

# Load saved images
image_files = sorted(glob.glob('progress_images/epoch_*.png'))

frames = []
for filename in image_files:
    img = Image.open(filename)
    frames.append(img)

# Create animation
frames[0].save('gan_training_progress.gif',
               format='GIF',
               append_images=frames[1:],
               save_all=True,
               duration=300,  # ms per frame
               loop=0)

# Display animation
from IPython.display import Image as IPyImage
IPyImage(open('gan_training_progress.gif', 'rb').read())

11. Generation of New Faces

In [None]:
z = torch.randn(25, latent_dim, device=device)
gen_imgs = generator(z)
show_images(gen_imgs, nrow=5)
