<a href="https://colab.research.google.com/github/sans-mishra/Underwater-Image-Enhancement-using-VAE-GAN-and-Diffusion/blob/main/GAN_image_enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import required libraries

In [None]:
# Make Google Drive (where the dataset lives) visible under /content/drive.
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from torch.nn.functional import mse_loss
import numpy as np

# Dataset paths

In [None]:
# Root of the dataset on the mounted Google Drive. Change only this line to
# relocate the data. Each split (Train/, Test/) has a Raw/ folder of input
# images and a Reference/ folder of target images, paired by file name.
DATA_ROOT = "/content/drive/MyDrive/archive (2)"
TRAIN_RAW_PATH = f"{DATA_ROOT}/Train/Raw"
TRAIN_REFERENCE_PATH = f"{DATA_ROOT}/Train/Reference"
TEST_RAW_PATH = f"{DATA_ROOT}/Test/Raw"
TEST_REFERENCE_PATH = f"{DATA_ROOT}/Test/Reference"

# Define device (CUDA if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define dataset class

In [None]:
class UnderwaterImageDataset(Dataset):
    """Paired (raw, reference) image dataset.

    Each sample combines a file from ``raw_dir`` with the file of the same
    name in ``reference_dir``.

    Args:
        raw_dir: Directory containing the raw (input) images.
        reference_dir: Directory containing the reference (target) images;
            every raw file name is expected to exist here as well.
        transform: Optional callable applied separately to each PIL image.
    """

    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform

        # Sorted because os.listdir order is arbitrary; a fixed order makes
        # sample indices (and any index-based Subset) reproducible. Hidden
        # entries (e.g. .ipynb_checkpoints) and sub-directories are skipped
        # because they cannot be opened as images.
        self.raw_images = sorted(
            name
            for name in os.listdir(raw_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(raw_dir, name))
        )

    def __len__(self):
        """Return the number of samples (one per raw image file)."""
        return len(self.raw_images)

    def __getitem__(self, idx):
        """Return ``{'raw': ..., 'reference': ...}`` for sample ``idx``."""
        file_name = self.raw_images[idx]
        raw_image_path = os.path.join(self.raw_dir, file_name)
        reference_image_path = os.path.join(self.reference_dir, file_name)

        # convert() returns a new in-memory image, so the files can be
        # closed immediately.
        with Image.open(raw_image_path) as img:
            raw_image = img.convert("RGB")
        with Image.open(reference_image_path) as img:
            reference_image = img.convert("RGB")

        if self.transform:
            # The transform is invoked once per image. The pair stays
            # aligned only if the transform is deterministic (e.g.
            # Resize/ToTensor/Normalize); random augmentations would be
            # sampled independently for the raw and reference images.
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return {'raw': raw_image, 'reference': reference_image}

# Transformations

In [None]:
# Preprocessing applied to both raw and reference images:
#   1. resize to a fixed 256x256 resolution,
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. shift/scale each channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Define the Generator (U-Net Architecture)

In [None]:
class UNetGenerator(nn.Module):
    """U-Net generator with encoder/decoder skip connections (Pix2Pix-style).

    Six stride-2 down-sampling blocks each halve the spatial size. Five
    stride-2 up-sampling blocks plus a final transposed convolution restore
    it. Every decoder stage after the first receives the previous decoder
    output concatenated (along channels) with the encoder map of the same
    size. The output goes through ``tanh``, so values lie in [-1, 1].

    Inputs whose height and width are divisible by 64 (e.g. 256x256) keep the
    skip shapes aligned. Other sizes can cause concatenation shape errors.
    """

    def __init__(self):
        super().__init__()

        # Encoder: channel widths grow 3 -> 64 -> ... -> 512. The first stage
        # has no BatchNorm.
        self.enc1 = self.down_block(3, 64, apply_batchnorm=False)
        self.enc2 = self.down_block(64, 128)
        self.enc3 = self.down_block(128, 256)
        self.enc4 = self.down_block(256, 512)
        self.enc5 = self.down_block(512, 512)
        self.enc6 = self.down_block(512, 512)

        # Decoder: input widths of dec2..dec5 are doubled because each stage
        # receives [previous decoder output, matching encoder output].
        self.dec1 = self.up_block(512, 512)
        self.dec2 = self.up_block(1024, 512)
        self.dec3 = self.up_block(1024, 256)
        self.dec4 = self.up_block(512, 128)
        self.dec5 = self.up_block(256, 64)

        # Last up-sampling to the input resolution with 3 output channels;
        # its input is [dec5 output (64 ch), enc1 output (64 ch)].
        self.final_layer = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)

    def down_block(self, in_channels, out_channels, apply_batchnorm=True):
        """Conv(4x4, stride 2) -> [BatchNorm] -> LeakyReLU(0.2); halves H and W."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        norm = [nn.BatchNorm2d(out_channels)] if apply_batchnorm else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    def up_block(self, in_channels, out_channels, apply_dropout=False):
        """ConvTranspose(4x4, stride 2) -> BatchNorm -> [Dropout(0.5)] -> ReLU.

        Doubles H and W. No caller in this class passes ``apply_dropout=True``.
        """
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        if apply_dropout:
            stages.append(nn.Dropout(0.5))
        stages.append(nn.ReLU())
        return nn.Sequential(*stages)

    def forward(self, x):
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.enc6)
        decoders = (self.dec2, self.dec3, self.dec4, self.dec5)

        # Run the encoder, keeping every intermediate map for the skips.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)

        # The deepest map (enc6 output) feeds dec1 directly. After that, each
        # stage consumes the previous output plus the next-shallower skip.
        features = self.dec1(skips.pop())
        for decode in decoders:
            features = decode(torch.cat([features, skips.pop()], dim=1))

        # Only enc1's output is left on the stack; it joins the final layer.
        return torch.tanh(self.final_layer(torch.cat([features, skips.pop()], dim=1)))


# Define the Discriminator (PatchGAN Architecture)

In [None]:
class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator conditioned on the input image (Pix2Pix-style).

    ``forward(input, target)`` stacks the two 3-channel images along the
    channel axis (6 channels total). It returns a one-channel map of
    per-patch scores instead of a single scalar. No output activation is
    applied.
    """

    def __init__(self):
        super().__init__()

        # Four stride-2 blocks 6 -> 64 -> 128 -> 256 -> 512 channels; only
        # the first block omits BatchNorm.
        widths = [6, 64, 128, 256, 512]
        blocks = [
            self.down_block(c_in, c_out, apply_batchnorm=index > 0)
            for index, (c_in, c_out) in enumerate(zip(widths, widths[1:]))
        ]
        # Stride-1 conv collapses the features to a single-channel patch map.
        head = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        self.model = nn.Sequential(*blocks, head)

    def down_block(self, in_channels, out_channels, apply_batchnorm=True):
        """Conv(4x4, stride 2) -> [BatchNorm] -> LeakyReLU(0.2); halves H and W."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        norm = [nn.BatchNorm2d(out_channels)] if apply_batchnorm else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    def forward(self, input, target):
        # Condition on the input by pairing it with the candidate image.
        paired = torch.cat((input, target), dim=1)
        return self.model(paired)


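# Checking for unmatched images
(This error was encountered with this model: some file names exist in only one of the Raw/Reference directories, so loading such a pair fails because the partner file is missing. The function below removes these unmatched files from both directories. Note that this permanently deletes them.)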

In [None]:
def check_and_remove_unmatched_images(raw_dir, reference_dir, dry_run=False):
    """Delete files whose name appears in only one of two paired directories.

    Raw and reference images are paired by identical file name, so a name
    present in just one directory cannot form a training pair.

    WARNING: unless ``dry_run`` is True, unmatched files are permanently
    deleted from disk.

    Args:
        raw_dir: Directory with the raw images.
        reference_dir: Directory with the reference images.
        dry_run: If True, only print what would be removed and delete nothing.

    Returns:
        Tuple ``(unmatched_raw, unmatched_reference)`` of sorted file-name
        lists. These are the names that were (or, with ``dry_run``, would be)
        removed from ``raw_dir`` and ``reference_dir`` respectively.
    """
    raw_images = set(os.listdir(raw_dir))
    reference_images = set(os.listdir(reference_dir))

    # Sorted so the deletion order and log output are reproducible
    # (set iteration order is arbitrary).
    unmatched_raw = sorted(raw_images - reference_images)  # in raw only
    unmatched_reference = sorted(reference_images - raw_images)  # in reference only

    for directory, names, label in (
        (raw_dir, unmatched_raw, "raw"),
        (reference_dir, unmatched_reference, "reference"),
    ):
        for image in names:
            if dry_run:
                print(f"Would remove unmatched {label} image: {image}")
            else:
                os.remove(os.path.join(directory, image))
                print(f"Removed unmatched {label} image: {image}")

    return unmatched_raw, unmatched_reference

# Loss Functions

In [None]:
# criterion_GAN: adversarial loss, the mean squared error between
#   discriminator scores and the 1 (real) / 0 (fake) targets, i.e. a
#   least-squares GAN objective.
# criterion_L1: reconstruction loss, the mean absolute error between the
#   generated and reference images.
criterion_GAN, criterion_L1 = nn.MSELoss(), nn.L1Loss()

# Training function

In [None]:
def train(generator, discriminator, train_loader, optimizer_G, optimizer_D, epochs=20,
          lambda_l1=100, log_interval=10):
    """Train the conditional GAN (generator + PatchGAN discriminator).

    Each step does two updates:
    1. Discriminator: update it on the real pair ``(raw, reference)`` with
       target 1 and the generated pair ``(raw, G(raw))`` with target 0.
    2. Generator: push the discriminator's score on the generated pair toward
       1, plus an L1 term pulling ``G(raw)`` toward the reference image.

    Relies on the module-level ``device``, ``criterion_GAN`` and
    ``criterion_L1``.

    Args:
        generator: Maps a batch of raw images to enhanced images.
        discriminator: Called as ``discriminator(raw, candidate)``.
        train_loader: Yields dicts with ``'raw'`` and ``'reference'`` tensors.
        optimizer_G: Optimizer over the generator parameters.
        optimizer_D: Optimizer over the discriminator parameters.
        epochs: Number of passes over ``train_loader``.
        lambda_l1: Weight of the L1 reconstruction term in the generator loss.
        log_interval: Print the losses every ``log_interval`` steps.
    """
    # Ensure BatchNorm (and any dropout) run in training mode, even if the
    # models were switched to eval mode earlier, e.g. by an evaluation pass.
    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        for i, batch in enumerate(train_loader):
            raw_image = batch['raw'].to(device)
            reference_image = batch['reference'].to(device)

            fake_image = generator(raw_image)

            # ---- Discriminator update ----
            real_output = discriminator(raw_image, reference_image)
            real_loss = criterion_GAN(real_output, torch.ones_like(real_output))

            # detach() keeps the discriminator loss from backpropagating into G.
            fake_output = discriminator(raw_image, fake_image.detach())
            fake_loss = criterion_GAN(fake_output, torch.zeros_like(fake_output))

            d_loss = (real_loss + fake_loss) / 2

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # ---- Generator update ----
            # Re-score the non-detached fake with the updated D so gradients
            # reach G. The gradients this also leaves on D's parameters are
            # cleared by optimizer_D.zero_grad() before the next D backward.
            fake_output = discriminator(raw_image, fake_image)
            g_gan_loss = criterion_GAN(fake_output, torch.ones_like(fake_output))
            g_l1_loss = criterion_L1(fake_image, reference_image) * lambda_l1
            g_loss = g_gan_loss + g_l1_loss

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            if i % log_interval == 0:
                # Epochs are shown 1-based so the last one reads [epochs/epochs].
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{len(train_loader)}], "
                      f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

# Evaluation Function

In [None]:
def evaluate(generator, test_loader):
    """Score the generator on a test set with PSNR, SSIM and MSE.

    Outputs and references are mapped from [-1, 1] back to [0, 1] before
    scoring. Metrics are computed per image and averaged over all images.
    Uses the module-level ``device``. The generator's train/eval mode is
    restored on exit.

    Args:
        generator: Model mapping raw image batches to enhanced batches.
        test_loader: Yields dicts with ``'raw'`` and ``'reference'`` tensors.

    Returns:
        ``(avg_psnr, avg_ssim, avg_mse)`` as floats, or ``None`` if the loader
        produced no images. The averages are also printed.
    """
    psnr_scores = []  # per-image PSNR (dB)
    ssim_scores = []  # per-image SSIM
    mse_scores = []  # per-image MSE on the [0, 1] scale

    # At inference BatchNorm should use its running statistics (and dropout,
    # if any, should be off). In train mode the scores would depend on how the
    # test set is batched, and evaluation would overwrite the running stats.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for batch in test_loader:
                raw_image = batch['raw'].to(device)
                reference_image = batch['reference'].to(device)

                fake_image = generator(raw_image)

                for i in range(raw_image.size(0)):
                    # CHW tensors in [-1, 1] -> HWC arrays in [0, 1].
                    fake_image_np = np.clip(
                        fake_image[i].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5, 0, 1)
                    reference_image_np = np.clip(
                        reference_image[i].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5, 0, 1)

                    psnr_value = psnr(reference_image_np, fake_image_np, data_range=1.0)
                    # channel_axis=-1 computes SSIM per colour channel and
                    # averages, using the default 7x7 window. The previous
                    # multichannel=True is deprecated/removed in recent
                    # scikit-image. When it is not honoured, the channel axis
                    # is treated as a third spatial dimension, which is why
                    # win_size=3 had been needed.
                    ssim_value = ssim(reference_image_np, fake_image_np,
                                      channel_axis=-1, data_range=1.0)
                    mse_value = mse_loss(torch.from_numpy(fake_image_np),
                                         torch.from_numpy(reference_image_np)).item()

                    psnr_scores.append(psnr_value)
                    ssim_scores.append(ssim_value)
                    mse_scores.append(mse_value)
    finally:
        generator.train(was_training)

    if not psnr_scores:
        print("No scores to calculate average.")
        return None

    avg_psnr = float(np.mean(psnr_scores))
    avg_ssim = float(np.mean(ssim_scores))
    avg_mse = float(np.mean(mse_scores))
    print(f"Average PSNR: {avg_psnr:.4f}, Average SSIM: {avg_ssim:.4f}, Average MSE: {avg_mse:.4f}")
    return avg_psnr, avg_ssim, avg_mse

# Main function

In [None]:
def main():
    """Run the pipeline: clean the data, train the GAN, evaluate on the test set."""
    # Delete files that have no partner in the other directory (irreversible).
    check_and_remove_unmatched_images(TRAIN_RAW_PATH, TRAIN_REFERENCE_PATH)
    check_and_remove_unmatched_images(TEST_RAW_PATH, TEST_REFERENCE_PATH)

    train_dataset = UnderwaterImageDataset(
        raw_dir=TRAIN_RAW_PATH, reference_dir=TRAIN_REFERENCE_PATH, transform=transform)
    test_dataset = UnderwaterImageDataset(
        raw_dir=TEST_RAW_PATH, reference_dir=TEST_REFERENCE_PATH, transform=transform)

    # The two splits are used independently (training vs. evaluation), so
    # neither is truncated to the other's length. Doing so would throw away
    # training pairs whenever the test split is smaller.
    print(f"Training pairs: {len(train_dataset)}, test pairs: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

    generator = UNetGenerator().to(device)
    discriminator = PatchGANDiscriminator().to(device)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.001, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))

    train(generator, discriminator, train_loader, optimizer_G, optimizer_D, epochs=20)

    evaluate(generator, test_loader)

# Run Training

In [None]:
if __name__ == "__main__":
    # Run the full clean/train/evaluate pipeline when executed as a script or
    # notebook cell; importing this module does not trigger training.
    main()

Removed unmatched raw image: 301_img_ (1).png
Removed unmatched reference image: 95_img_ (1).png
Removed unmatched reference image: 96_img_ (1).png
Removed unmatched reference image: 928_img_ (1).png
Removed unmatched reference image: 9907 (1).png
Removed unmatched reference image: 929_img_ (1).png
Removed unmatched reference image: 97_img_ (1).png
Removed unmatched reference image: 9896 (1).png
Removed unmatched reference image: 9947 (1).png
Removed unmatched reference image: 9567 (1).png
Removed unmatched reference image: 92_img_ (1).png
Removed unmatched raw image: 494_img_ (1).png
Removed unmatched raw image: 4_img_ (1).png
Removed unmatched raw image: 491_img_ (1).png
Removed unmatched raw image: 50_img_ (1).png
Removed unmatched raw image: 495_img_ (1).png
Removed unmatched raw image: 502_img_ (1).png
Epoch [0/20], Step [0/12], D Loss: 0.3934, G Loss: 75.1066
Epoch [0/20], Step [10/12], D Loss: 0.2738, G Loss: 25.2139
Epoch [1/20], Step [0/12], D Loss: 0.2598, G Loss: 25.1609
Epo