# Generative Adversarial Network

## Example
- Face Generation using the CelebA dataset (a large-scale dataset of celebrity images).
- The goal is to train a GAN to generate realistic images of celebrity faces by learning from the CelebA dataset.

### Step 1: Dataset Preparation
  - Dataset: The CelebA dataset contains over 200,000 images of celebrity faces.
  - Download and Preprocess:
    - Download the dataset and unzip it to a directory.
    - Preprocess the images: Resize to $64\times64$, normalize pixel values to [-1, 1].
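  - PyTorch Dataset: Use `torchvision.datasets.ImageFolder` to load and preprocess the images. `ImageFolder` treats each sub-directory of its root as a class, so the images must sit inside at least one sub-directory. Point it at the directory that *contains* the extracted image folder, not at the image folder itself.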

In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Preprocessing pipeline applied to every image:
#   1. resize to 64x64 (both sides are fixed, so the aspect ratio is not kept),
#   2. convert to a float tensor with values in [0, 1],
#   3. subtract 0.5 and divide by 0.5, mapping values to [-1, 1]
#      (the single-element mean/std is broadcast over all channels).
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Load CelebA. NOTE: ImageFolder treats every sub-directory of `root` as one
# class, so `dataset_path` must be the directory that *contains* the image
# folder, not the folder of images itself.
dataset_path = "path_to_celeba_dataset"
dataset = ImageFolder(dataset_path, transform=transform)

# Batches of 128 images, reshuffled on every pass over the data.
batch_size = 128
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import os

# ----------------------
# Custom Dataset Class
# ----------------------
class ImageDataset(Dataset):
    """Paired (low-resolution, high-resolution) images read from one directory.

    Every file directly inside ``root_dir`` whose name ends in ``.png``,
    ``.jpg`` or ``.jpeg`` (case-insensitive) is used; sub-directories are not
    searched. For each image, a "low-resolution" input is synthesized by
    downsampling the RGB image to ``low_res_size`` x ``low_res_size`` and then
    upsampling it back to ``high_res_size`` x ``high_res_size``. ``transform``
    (if given) is then applied to both images.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to both the high- and the
            low-resolution PIL image. NOTE(review): for the pair to be
            batchable it presumably needs to bring the high-resolution image to
            the same size as the low-resolution one (the training script uses a
            256x256 resize) — confirm for other transforms.
        low_res_size: Side length the image is shrunk to when simulating the
            low-resolution input. Defaults to 64.
        high_res_size: Side length the shrunken image is enlarged back to.
            Defaults to 256.
    """

    # Accepted file extensions, compared case-insensitively (".JPG" counts).
    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, transform=None, low_res_size=64, high_res_size=256):
        self.root_dir = root_dir
        self.transform = transform
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always maps to the same file.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.lower().endswith(self.IMG_EXTENSIONS)
        )

    def __getitem__(self, index):
        """Return ``(low_res_img, high_res_img)`` for the image at ``index``."""
        img_path = os.path.join(self.root_dir, self.image_files[index])
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as img:
            high_res_img = img.convert("RGB")

        # Simulate a low-resolution input: shrink, then enlarge back so detail
        # is lost but the size matches the high-resolution target.
        low_res_img = transforms.Resize((self.low_res_size, self.low_res_size))(high_res_img)
        low_res_img = transforms.Resize((self.high_res_size, self.high_res_size))(low_res_img)

        if self.transform:
            high_res_img = self.transform(high_res_img)
            low_res_img = self.transform(low_res_img)

        return low_res_img, high_res_img

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_files)

# ----------------------
# Generator Model
# ----------------------
class Generator(nn.Module):
    """Map a 3-channel image batch to a 3-channel image batch of the same size.

    Two 3x3 convolutions (3 -> 64 -> 3 channels, stride 1, padding 1) with a
    ReLU between them, so the spatial size is preserved. The last layer is a
    plain convolution, so output values are not bounded to any range.
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 64
        layers = [
            nn.Conv2d(3, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, 3, kernel_size=3, stride=1, padding=1),
        ]
        # Kept as a single Sequential named `model` so parameter names
        # (model.0.*, model.2.*) stay stable for saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> ReLU -> conv to ``x``."""
        return self.model(x)

# ----------------------
# Discriminator Model
# ----------------------
class Discriminator(nn.Module):
    """Score RGB image batches with a probability of being real.

    One 3x3 stride-2 convolution (3 -> 64 channels) with LeakyReLU(0.2),
    followed by flatten, a linear layer and a sigmoid. The output has shape
    (N, 1) with values in (0, 1).

    The linear layer is sized for a 64 x 64 feature map, which is what a
    128 x 128 input produces. Previously any other input size (for example
    the 256 x 256 images used by the training script) crashed with a shape
    mismatch. The feature map is now adaptively average-pooled to 64 x 64
    before flattening. That pooling is an exact identity for 128 x 128
    inputs, so results there are unchanged. Parameter names are also
    unchanged.
    """

    # Spatial size of the feature map the linear layer expects.
    FEATURE_SIZE = 64

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(64 * self.FEATURE_SIZE * self.FEATURE_SIZE, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-image real/fake probabilities of shape (N, 1)."""
        conv, act, flatten, linear, sigmoid = self.model
        features = act(conv(x))
        # Resize the feature map to what the linear layer was built for; this
        # is what makes the discriminator accept any input resolution.
        features = nn.functional.adaptive_avg_pool2d(features, self.FEATURE_SIZE)
        return sigmoid(linear(flatten(features)))

# ----------------------
# Training Loop
# ----------------------
def train_gan(dataloader, generator, discriminator, criterion, optimizer_G, optimizer_D, epochs, device):
    """Alternately update the discriminator and the generator for ``epochs`` epochs.

    Each batch from ``dataloader`` is a ``(low_res, high_res)`` pair. For
    every batch:

    1. The discriminator is trained so that ``criterion`` pushes its output
       toward 1 for ``high_res`` and toward 0 for ``generator(low_res)``.
    2. The generator is trained so that the updated discriminator's output
       on the generated images is pushed toward 1.

    Both models are moved to ``device`` (in place). A progress line is
    printed every 10 steps. Nothing is returned.
    """
    for net in (generator, discriminator):
        net.to(device)

    for epoch in range(epochs):
        for step, (low_res, high_res) in enumerate(dataloader):
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            n = low_res.size(0)
            target_real = torch.ones(n, 1, device=device)
            target_fake = torch.zeros(n, 1, device=device)

            # Discriminator step. The generated batch is detached so this
            # backward pass does not reach the generator's parameters.
            optimizer_D.zero_grad()
            loss_real = criterion(discriminator(high_res), target_real)
            generated = generator(low_res)
            loss_fake = criterion(discriminator(generated.detach()), target_fake)
            d_loss = loss_real + loss_fake
            d_loss.backward()
            optimizer_D.step()

            # Generator step: score generated images as "real". Gradients
            # reaching the discriminator here are cleared by the next
            # optimizer_D.zero_grad().
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(generated), target_real)
            g_loss.backward()
            optimizer_G.step()

            if step % 10 == 0:
                print(
                    f"Epoch [{epoch + 1}/{epochs}], Step [{step + 1}/{len(dataloader)}], "
                    f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}"
                )

# ----------------------
# Main Function
# ----------------------
if __name__ == "__main__":
    # ---- Configuration ------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    root_dir = "path_to_your_dataset"  # Replace with the path to your dataset
    epochs = 50
    batch_size = 16
    lr = 0.0002

    # ---- Data ---------------------------------------------------------------
    # Both images of each pair are resized to 256x256 and turned into float
    # tensors in [0, 1]; no normalization is applied here.
    # NOTE(review): the Discriminator therefore sees 256x256 images — make sure
    # its classifier head accepts that resolution.
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    dataset = ImageDataset(root_dir=root_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # ---- Models, loss and optimizers -----------------------------------------
    generator, discriminator = Generator(), Discriminator()
    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=lr)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

    # ---- Training -------------------------------------------------------------
    train_gan(
        dataloader,
        generator,
        discriminator,
        criterion,
        optimizer_G,
        optimizer_D,
        epochs,
        device,
    )
