<a href="https://colab.research.google.com/github/saakshigupta2002/COMP3710/blob/main/Demo2/PatternLab3c.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import os
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image

# Compute device: the current CUDA device when CUDA is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Output folder for the per-epoch sample grids (no error if it already exists).
os.makedirs(name='generated_images', exist_ok=True)

# Hyperparameters
image_size, latent_size = 64, 128  # 64x64 training images; channels of the (N, latent_size, 1, 1) noise
batch_size = 128
learning_rate = 2e-4  # Adam step size, the value used in the DCGAN paper
epochs = 100  # number of full passes over the training set

# Preprocessing: resize the shorter side to `image_size`, centre-crop to a
# square, convert to a float tensor in [0, 1], then map each RGB channel to
# [-1, 1] via (x - 0.5) / 0.5 so real images share the generator's tanh range.
transform_train = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# CelebA training split (split='train' is torchvision's default); with
# download=True it is fetched into ./data if not already present. Batches are
# reshuffled every epoch and loaded by 4 worker processes.
trainset = torchvision.datasets.CelebA(
    root='./data',
    split='train',
    transform=transform_train,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Discriminator architecture
class Discriminator(nn.Module):
    """DCGAN discriminator: 3x64x64 image -> probability that it is real.

    Four stride-2 4x4 convolutions halve the spatial size
    64 -> 32 -> 16 -> 8 -> 4 while widening channels 3 -> 64 -> 128 -> 256 -> 512.
    A final 4x4 valid convolution collapses each 512x4x4 map to a single value.
    That value is flattened to shape (batch, 1) and passed through a sigmoid,
    so the output can go straight into ``nn.BCELoss`` against (batch, 1) labels.

    Following the DCGAN architecture guidelines, the input layer has no
    batch normalization. The DCGAN paper reports that batch-normalizing every
    layer caused sample oscillation and model instability, and avoids it on
    the discriminator input layer.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.Sequential(
            # 3x64x64 -> 64x32x32 (input layer: deliberately no BatchNorm)
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x32x32 -> 128x16x16
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x16x16 -> 256x8x8
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256x8x8 -> 512x4x4
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512x4x4 -> 1x1x1, flattened to (batch, 1), squashed into (0, 1)
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def forward(self, image):
        """Return D(image) with shape (batch, 1) for a (batch, 3, 64, 64) input."""
        return self.discriminator(image)

# Generator architecture
class Generator(nn.Module):
    """DCGAN generator: latent noise -> 3x64x64 image with values in [-1, 1].

    For a latent tensor of shape (batch, latent_dim, 1, 1), which is the shape
    the training loop samples, the layers work as follows:

    * A 4x4 transposed convolution projects the latent to 512x4x4.
    * Four stride-2 transposed convolutions double the resolution
      4 -> 8 -> 16 -> 32 -> 64, narrowing channels 512 -> 256 -> 128 -> 64 -> 3.
    * The final tanh bounds outputs to [-1, 1]. This is the same range the
      training images occupy after ``Normalize(0.5, 0.5)``.

    Args:
        latent_dim: number of channels of the latent input. ``None`` (the
            default) falls back to the module-level ``latent_size``
            hyperparameter, so ``Generator()`` behaves exactly as before.
    """

    def __init__(self, latent_dim=None):
        super(Generator, self).__init__()
        if latent_dim is None:
            latent_dim = latent_size
        self.latent_dim = latent_dim
        self.generator = nn.Sequential(
            # latent_dim x1x1 -> 512x4x4
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512x4x4 -> 256x8x8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256x8x8 -> 128x16x16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128x16x16 -> 64x32x32
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64x32x32 -> 3x64x64; no BatchNorm on the output layer
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, latent):
        """Map ``latent`` (batch, latent_dim, 1, 1) to images (batch, 3, 64, 64)."""
        return self.generator(latent)

# DCGAN weight initialisation. PyTorch's default layer init differs from the
# scheme used by the DCGAN reference setup, so it is applied explicitly.
def weights_init(module):
    """Initialise one layer's parameters in place, DCGAN-style.

    Conv / transposed-conv weights are drawn from N(0, 0.02) and any bias is
    zeroed. BatchNorm scale is drawn from N(1, 0.02) and its shift is zeroed.
    Other module types are left untouched. Intended for ``nn.Module.apply``.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, mean=1.0, std=0.02)
        nn.init.zeros_(module.bias)


# Instantiate the models, move them to the GPU if available, and apply the
# DCGAN initialisation to every submodule.
gen = Generator().to(device)
dis = Discriminator().to(device)
gen.apply(weights_init)
dis.apply(weights_init)

# Binary cross-entropy on the discriminator's sigmoid outputs; used for both
# the discriminator and the generator objectives in the training loop.
criterion = nn.BCELoss()

# Adam for both networks, with beta1 = 0.5 (lower than Adam's 0.9 default).
optimizerD = torch.optim.Adam(dis.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Training loop
# A fixed batch of latent vectors is sampled once and rendered after every
# epoch, so the saved grids show how the *same* inputs evolve over training.
# 64 samples with nrow=8 give a square 8x8 grid.
sample_dir = 'generated_images'
os.makedirs(sample_dir, exist_ok=True)
fixed_noise = torch.randn(64, latent_size, 1, 1, device=device)

for epoch in range(epochs):
    # Per-epoch loss sums kept on-device to avoid a host sync every batch.
    d_loss_total = torch.zeros((), device=device)
    g_loss_total = torch.zeros((), device=device)
    num_batches = 0

    for batch, _ in train_loader:  # dataset targets are not used by the GAN
        batch = batch.to(device)
        n = batch.size(0)
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # --- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ---
        optimizerD.zero_grad()
        real_outputs = dis(batch)
        real_loss = criterion(real_outputs, real_labels)

        latent = torch.randn(n, latent_size, 1, 1, device=device)
        fake_images = gen(latent)
        # detach(): this loss must update D only, not backpropagate into G.
        fake_outputs = dis(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)

        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizerD.step()

        # --- Generator step (non-saturating loss): push D(G(z)) -> 1 ---
        optimizerG.zero_grad()
        latent = torch.randn(n, latent_size, 1, 1, device=device)
        fake_images = gen(latent)
        fake_outputs = dis(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizerG.step()

        d_loss_total += d_loss.detach()
        g_loss_total += g_loss.detach()
        num_batches += 1

    # Save samples generated from the fixed noise. no_grad(): this forward
    # pass is for visualisation only, so no autograd graph is built.
    with torch.no_grad():
        samples = gen(fixed_noise)
    fake_fname = os.path.join(sample_dir, f'generated_images_epoch{epoch + 1}.png')
    save_image(samples, fake_fname, nrow=8, normalize=True)
    print(f'Saved {fake_fname}')

    # Report mean losses over the epoch instead of the last batch's noisy value.
    num_batches = max(num_batches, 1)  # guard against an empty loader
    mean_d_loss = d_loss_total.item() / num_batches
    mean_g_loss = g_loss_total.item() / num_batches
    print(f'Epoch [{epoch + 1}/{epochs}], D Loss: {mean_d_loss:.4f}, G Loss: {mean_g_loss:.4f}')
