In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
import os
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import glob

In [7]:
class CelebADataset(Dataset):
    """
    Custom PyTorch Dataset that serves every ``*.jpg`` file found directly inside one directory.

    No class subfolders are required, which makes it suitable for datasets such as
    CelebA where all images are stored in a single flat directory.

    Attributes:
        root_dir (str): Directory that was scanned for images.
        files (list[str]): Sorted paths of every ``*.jpg`` file found in ``root_dir``.
        transform (callable, optional): Function/transform applied to each loaded image.

    Methods:
        __len__(): Returns the number of images in the dataset.
        __getitem__(idx): Loads the image at the given index, applies the transform (if any),
                          and returns it together with a dummy label (0).
    """

    def __init__(self, root_dir, transform=None):
        """
        Collects the paths of all ``*.jpg`` files located directly in ``root_dir``.
        Args:
            root_dir (str): Directory where the images are stored.
            transform (callable, optional): Transformation applied to each image (default is None).
        """
        self.root_dir = root_dir
        # Sorting makes the index -> file mapping independent of filesystem listing order.
        self.files = sorted(glob.glob(os.path.join(root_dir, "*.jpg")))
        self.transform = transform

    def __len__(self):
        """
        Returns the total number of images in the dataset.
        Returns:
            int: Number of images.
        """
        return len(self.files)

    def __getitem__(self, idx):
        """
        Loads one image, converts it to RGB and applies the transform.
        Args:
            idx (int): Index of the image to load.
        Returns:
            tuple: The (optionally transformed) image and a dummy label (0).
        """
        img_path = self.files[idx]
        # The context manager releases the underlying file as soon as the pixel
        # data has been decoded by convert().
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Compare against None so that a transform object that happens to be falsy
        # (e.g. an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, 0  # The dataset carries no labels, so return a dummy label


In [9]:
# Preprocessing pipeline: resize to IMAGE_SIZE, center-crop to IMAGE_SIZE x IMAGE_SIZE,
# convert to a float tensor in [0, 1], then normalize each of the three RGB channels
# with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, which maps values to [-1, 1]
# (the same range as the generator's Tanh output).
IMAGE_SIZE = 64
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    # one mean/std entry per RGB channel instead of relying on broadcasting of a 1-tuple
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [12]:
# specifying dataset path and creating custom dataset class object
DATA_PATH = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba/"
dataset = CelebADataset(root_dir=DATA_PATH, transform=transform)

# Fail fast with a readable message: a wrong or unmounted path yields an empty
# dataset, which would otherwise only surface later as an obscure sampler error.
assert len(dataset) > 0, f"No .jpg images found under {DATA_PATH!r}; check that the dataset is attached"

In [13]:
# creating a dataloader object for CelebA dataset
BATCH_SIZE = 128
# Decode/resize JPEGs in background worker processes so the training loop is not
# stalled by image loading. Set NUM_WORKERS = 0 if worker processes cannot be
# started in your environment.
NUM_WORKERS = min(4, os.cpu_count() or 1)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # page-locked host memory speeds up host-to-GPU copies; only useful with CUDA
    pin_memory=torch.cuda.is_available(),
)

In [24]:
# initialize the variables
LATENT_DIM = 100        # number of channels of the generator's input noise tensor
EPOCHS = 10             # full passes over the dataloader
LEARNING_RATE = 0.0002  # step size passed to both Adam optimizers
BETA1 = 0.5             # first Adam beta passed to both optimizers
SEED = 42               # makes weight init, shuffling and noise sampling repeatable

# Seed PyTorch's random number generators before any model is built or any noise is drawn.
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [25]:
class Generator(nn.Module):
    """
    DCGAN Generator model that transforms a latent noise vector into a 64x64 RGB image.
    Architecture:
    - Five transposed convolutions (kernel size 4, no bias). The first one uses
      stride 1 / padding 0 and turns the 1x1 input into a 4x4 map; the other four
      use stride 2 / padding 1 and double the spatial size (4 -> 8 -> 16 -> 32 -> 64).
    - Batch normalization and ReLU activations in all layers except the output.
    - Final layer uses Tanh activation to scale pixel values to [-1, 1].
    Args:
        latent_dim (int, optional): Number of channels of the input noise tensor.
            When omitted, the module-level ``LATENT_DIM`` constant is looked up at
            construction time, so ``Generator()`` keeps working as before.
    Attributes:
        latent_dim (int): Noise channel count this instance was built for.
        model (nn.Sequential): The layer stack.
    Forward Pass:
        x (Tensor): Input noise tensor of shape (batch_size, latent_dim, 1, 1).
        Returns: Generated image tensor of shape (batch_size, 3, 64, 64).
    """

    def __init__(self, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            # Fall back to the notebook-wide configuration value.
            latent_dim = LATENT_DIM
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            # (latent_dim, 1, 1) -> (512, 4, 4)
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # (512, 4, 4) -> (256, 8, 8)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # (256, 8, 8) -> (128, 16, 16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (128, 16, 16) -> (64, 32, 32)
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (64, 32, 32) -> (3, 64, 64)
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)


In [26]:
class Discriminator(nn.Module):
    """
    DCGAN Discriminator that scores 64x64 RGB images as real or fake.
    Architecture:
    - Four stride-2 convolutions (kernel 4, padding 1, no bias) halve the spatial
      size each time (64 -> 32 -> 16 -> 8 -> 4) while widening 3 -> 64 -> 128 -> 256 -> 512.
    - Batch normalization follows every one of those convolutions except the first.
    - LeakyReLU (negative slope 0.2, in-place) after each of the four stages.
    - A final 4x4 convolution (stride 1, no padding) collapses the 4x4 map to a
      single value, which a Sigmoid squashes into (0, 1).
    Forward Pass:
        x (Tensor): Input image tensor of shape (batch_size, 3, 64, 64).
        Returns: Tensor of shape (batch_size, 1, 1, 1) holding the probability that
                 each image is real; reshape with ``.view(-1, 1)`` to get (batch_size, 1).
    """

    def __init__(self):
        super().__init__()

        # Input stage: strided convolution without batch normalization.
        stages = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three identical down-sampling stages that only differ in channel width.
        for in_channels, out_channels in ((64, 128), (128, 256), (256, 512)):
            stages += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Output stage: reduce the 4x4 feature map to one probability per image.
        stages += [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        scores = self.model(x)
        return scores


In [27]:
def weights_init(m):
    """
    Re-initializes convolutional and batch normalization layers with the DCGAN scheme.

    Intended to be passed to ``nn.Module.apply``, which calls it once per submodule.
    Layers are recognised by class name, so every class whose name contains
    ``Conv`` (e.g. Conv2d, ConvTranspose2d) or ``BatchNorm`` is handled.

    Initialization Details:
    - Convolutional layers: weights are drawn from a normal distribution
      with mean 0 and standard deviation 0.02. Biases are left untouched.
    - BatchNorm layers: weights are drawn from a normal distribution
      (mean=1, std=0.02), and biases are set to 0.
    - Matching modules without a ``weight``/``bias`` tensor (for example
      ``BatchNorm2d(affine=False)``) are skipped instead of raising.
    - All other modules are left unchanged.

    Args:
        m (nn.Module): A single layer of the model.
    """
    classname = m.__class__.__name__
    # affine=False BatchNorm layers expose weight/bias as None; custom wrappers
    # whose name merely contains "Conv" may have no such attributes at all.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            # nn.init functions already run without gradient tracking, so the
            # parameter can be passed directly instead of going through .data
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)


In [28]:
# build the generator first, then the discriminator, and move each network's
# parameters and buffers onto the selected device
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [29]:
# run weights_init on every submodule of both networks
# (Module.apply returns the module itself, so the discriminator is echoed as cell output)
generator.apply(fn=weights_init)
discriminator.apply(fn=weights_init)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

In [30]:
# binary cross-entropy on the discriminator's probabilities, plus one Adam
# optimizer per network sharing the same learning rate and beta settings
criterion = nn.BCELoss()
adam_settings = {"lr": LEARNING_RATE, "betas": (BETA1, 0.999)}
optimizerG = optim.Adam(generator.parameters(), **adam_settings)
optimizerD = optim.Adam(discriminator.parameters(), **adam_settings)

In [31]:
# train the model and save generated images

# One fixed batch of latent vectors is reused for every per-epoch snapshot, so the
# saved grids show the same samples evolving over training and can be compared
# epoch to epoch. It also means the snapshot no longer depends on a variable that
# only exists once the inner loop has run.
fixed_noise = torch.randn(64, LATENT_DIM, 1, 1, device=device)

for epoch in range(EPOCHS):
    for i, (real_images, _) in enumerate(dataloader):
        # ---- Discriminator update: real batch labelled 1, generated batch labelled 0 ----
        discriminator.zero_grad()
        real_images = real_images.to(device)
        # Labels are sized from the actual batch because the last batch may be smaller.
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)

        # The discriminator returns (batch, 1, 1, 1); flatten to match the labels.
        output = discriminator(real_images).view(-1, 1)
        loss_real = criterion(output, real_labels)
        loss_real.backward()

        noise = torch.randn(real_images.size(0), LATENT_DIM, 1, 1, device=device)
        fake_images = generator(noise)
        # detach() keeps this backward pass from reaching the generator's parameters.
        output = discriminator(fake_images.detach()).view(-1, 1)
        loss_fake = criterion(output, fake_labels)
        loss_fake.backward()
        # Gradients of the real and fake passes have accumulated; apply them together.
        optimizerD.step()

        # ---- Generator update: make the freshly updated discriminator say "real" ----
        generator.zero_grad()
        output = discriminator(fake_images).view(-1, 1)
        loss_G = criterion(output, real_labels)
        loss_G.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{EPOCHS}] Step [{i}/{len(dataloader)}] Loss D: {loss_real + loss_fake:.4f}, Loss G: {loss_G:.4f}")

    # Snapshot of the generator at the end of the epoch; no autograd graph is needed here.
    with torch.no_grad():
        epoch_samples = generator(fixed_noise).detach().cpu()
    vutils.save_image(epoch_samples, f"fake_images_epoch_{epoch}.png", normalize=True)

Epoch [0/10] Step [0/1583] Loss D: 1.5798, Loss G: 3.2713
Epoch [0/10] Step [100/1583] Loss D: 0.7159, Loss G: 7.9165
Epoch [0/10] Step [200/1583] Loss D: 0.3347, Loss G: 5.9586
Epoch [0/10] Step [300/1583] Loss D: 0.5652, Loss G: 3.1923
Epoch [0/10] Step [400/1583] Loss D: 0.8153, Loss G: 6.9473
Epoch [0/10] Step [500/1583] Loss D: 0.5408, Loss G: 3.5140
Epoch [0/10] Step [600/1583] Loss D: 0.4716, Loss G: 3.0437
Epoch [0/10] Step [700/1583] Loss D: 0.7238, Loss G: 3.4431
Epoch [0/10] Step [800/1583] Loss D: 0.7049, Loss G: 6.8274
Epoch [0/10] Step [900/1583] Loss D: 0.6737, Loss G: 6.0510
Epoch [0/10] Step [1000/1583] Loss D: 0.7181, Loss G: 6.7699
Epoch [0/10] Step [1100/1583] Loss D: 0.4998, Loss G: 5.3280
Epoch [0/10] Step [1200/1583] Loss D: 0.2871, Loss G: 5.0793
Epoch [0/10] Step [1300/1583] Loss D: 0.5343, Loss G: 5.6479
Epoch [0/10] Step [1400/1583] Loss D: 0.4275, Loss G: 3.4961
Epoch [0/10] Step [1500/1583] Loss D: 0.9757, Loss G: 2.3891
Epoch [1/10] Step [0/1583] Loss D: 0