<a href="https://www.kaggle.com/code/tigershiva02/dl-project-enhancing-image-generation?scriptVersionId=208336405" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
pip install lpips


In [None]:
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
from torchvision.utils import save_image
from tqdm import tqdm
from math import log10
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt

# Turn on cuDNN benchmark mode (auto-tunes convolution algorithm selection)
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

# Prefer the GPU when one is visible, otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data pipeline configuration
IMG_SIZE = 32  # Images are resized to this size to keep memory use low
BATCH_SIZE = 8  # Small batches, also to keep memory use low
DATASET_DIR = "/kaggle/input/celeba-dataset/img_align_celeba"  # Update path if necessary

# Centre crop -> resize -> tensor -> shift/scale values from [0, 1] to [-1, 1]
transform = transforms.Compose(
    [
        transforms.CenterCrop(178),
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

dataset = datasets.ImageFolder(root=DATASET_DIR, transform=transform)

# Keep only a random sample of the images so an epoch stays short
subset_size = 5000
indices = torch.randperm(len(dataset))[:subset_size]
dataset = Subset(dataset, indices)

dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define U-Net
class UNet(nn.Module):
    """Convolutional encoder-decoder mapping a 3-channel image to a 3-channel image.

    Despite the name, ``forward`` runs the three stages strictly in sequence;
    there are no skip connections between encoder and decoder layers.

    Every (transposed) convolution uses kernel 4, stride 2, padding 1, so each
    layer halves (or doubles) the spatial size. The input is downsampled four
    times (three encoder layers plus the first middle layer) and upsampled four
    times, so inputs whose height and width are multiples of 16 come back at
    their original spatial size. The final ``Tanh`` bounds outputs to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 64 -> 128 -> 256 channels, spatial size / 8
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        # Bottleneck: one more downsample to 512 channels, then back up to 256
        self.middle = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        # Decoder: 256 -> 128 -> 64 -> 3 channels, spatial size * 8
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run encoder, bottleneck and decoder on ``x`` and return the result.

        Args:
            x: Tensor whose last two dimensions are height and width and whose
                channel dimension has size 3.

        Returns:
            Tensor with the same spatial size as ``x`` and values in [-1, 1].

        Raises:
            AssertionError: If the output spatial size differs from the input's
                (happens when height or width is not a multiple of 16).
        """
        # Compare against the input itself rather than a module-level constant,
        # so the model works for any compatible resolution.
        input_size = x.shape[-2:]
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        assert x.shape[-2:] == input_size, f"Output shape mismatch: {x.shape}"
        return x

# Build the network and place its parameters on the selected device
model = UNet()
model = model.to(device)

# Optimisation configuration
EPOCHS = 50
criterion_mse = nn.MSELoss()
optimizer = Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
# Multiply the learning rate by 0.5 every 10 scheduler steps
scheduler = StepLR(optimizer=optimizer, step_size=10, gamma=0.5)
# Gradient scaler used together with autocast for mixed precision training
scaler = GradScaler()


# PSNR Calculation Function
def calculate_psnr(original, generated, data_range=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Computes ``10 * log10(data_range ** 2 / MSE)`` over all elements.

    Args:
        original: Reference tensor.
        generated: Tensor to compare; must broadcast against ``original``.
        data_range: Difference between the largest and smallest possible
            value of the signal. The default of 1.0 is right for data in
            [0, 1]; for data normalized to [-1, 1] pass 2.0 (or rescale the
            tensors to [0, 1] before calling).

    Returns:
        The PSNR as a float, or 100.0 when the tensors are identical
        (MSE of exactly zero).

    Raises:
        ValueError: If ``data_range`` is not positive.
    """
    if data_range <= 0:
        raise ValueError(f"data_range must be positive, got {data_range}")

    # This is a metric, not a loss: skip autograd tracking and pull the value
    # to the host once instead of once per comparison/use.
    with torch.no_grad():
        mse = torch.mean((original - generated) ** 2).item()

    if mse == 0:
        return 100.0  # No difference
    return 10 * log10(data_range ** 2 / mse)

# Pixel Accuracy Function
def pixel_accuracy(original, generated, threshold=0.1):
    """Return the fraction of elements whose absolute error is below ``threshold``.

    Args:
        original: Reference tensor.
        generated: Tensor to compare against ``original``.
        threshold: Exclusive upper bound on the absolute difference for an
            element to count as accurate.

    Returns:
        A Python float in [0, 1].
    """
    abs_error = (original - generated).abs()
    within_tolerance = abs_error.lt(threshold)
    return within_tolerance.float().mean().item()

# Per-epoch metric histories (one value appended at the end of every epoch)
loss_history = []
psnr_history = []
accuracy_history = []

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    psnr_total = 0.0
    accuracy_total = 0.0

    for images, _ in tqdm(dataloader):
        images = images.to(device)

        # Denoising task: corrupt the clean batch with Gaussian noise (std 0.1)
        # and clamp back into the normalized [-1, 1] range. randn_like already
        # creates the noise on the same device as `images`.
        noise = torch.randn_like(images) * 0.1
        noisy_images = torch.clip(images + noise, -1, 1)

        optimizer.zero_grad()

        with autocast():  # Mixed precision context
            outputs = model(noisy_images)
            mse_loss = criterion_mse(outputs, images)
            total_loss = mse_loss

        # Scaled backward pass and optimizer step for mixed precision
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += total_loss.item()

        # Metrics are for reporting only: detach from the autograd graph and
        # promote to float32 so they are not accumulated in half precision.
        with torch.no_grad():
            predictions = outputs.detach().float()
            # calculate_psnr assumes a peak value of 1, but the images live in
            # [-1, 1]; map both tensors to [0, 1] so the reported dB is correct.
            psnr_total += calculate_psnr((images + 1) / 2, (predictions + 1) / 2)
            accuracy_total += pixel_accuracy(images, predictions)

        # NOTE: torch.cuda.empty_cache() is intentionally not called per batch;
        # it does not reduce the memory PyTorch needs and only slows training
        # by forcing cached blocks to be released and re-allocated.

    # Average the per-batch metrics over the epoch
    num_batches = len(dataloader)
    avg_loss = running_loss / num_batches
    avg_psnr = psnr_total / num_batches
    avg_accuracy = accuracy_total / num_batches

    loss_history.append(avg_loss)
    psnr_history.append(avg_psnr)
    accuracy_history.append(avg_accuracy)  # stored as a fraction in [0, 1]

    scheduler.step()
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f} dB, Accuracy: {avg_accuracy * 100:.2f}%")

# Plotting Results
def plot_results(loss_history, psnr_history, accuracy_history=None):
    """Plot training loss, PSNR and pixel accuracy side by side and show the figure.

    Args:
        loss_history: Sequence of per-epoch training losses.
        psnr_history: Sequence of per-epoch PSNR values in dB.
        accuracy_history: Sequence of per-epoch pixel accuracies expressed as
            fractions in [0, 1]. When omitted, the module-level
            ``accuracy_history`` is used (an empty series if it is not defined),
            so existing two-argument calls keep working.
    """
    if accuracy_history is None:
        accuracy_history = globals().get("accuracy_history", [])

    # The axis is labelled in percent, so convert the stored fractions.
    accuracy_percent = [value * 100 for value in accuracy_history]

    # All three panels share one 1x3 grid; mixing grid shapes makes the
    # accuracy axes overlap the PSNR axes.
    plt.figure(figsize=(15, 5))

    # Plot Training Loss
    plt.subplot(1, 3, 1)
    plt.plot(loss_history, label='Training Loss')
    plt.title("Training Loss Over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    # Plot PSNR
    plt.subplot(1, 3, 2)
    plt.plot(psnr_history, label='PSNR (dB)', color='orange')
    plt.title("PSNR Over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("PSNR (dB)")
    plt.legend()

    # Plot Pixel Accuracy
    plt.subplot(1, 3, 3)
    plt.plot(accuracy_percent, label="Pixel Accuracy (%)", color="green")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy (%)")
    plt.title("Pixel Accuracy Over Epochs")
    plt.legend()

    plt.tight_layout()
    plt.show()

# Render the training curves collected during the run
plot_results(loss_history=loss_history, psnr_history=psnr_history)
