<a href="https://colab.research.google.com/github/sans-mishra/Underwater-Image-Enhancement-using-VAE-GAN-and-Diffusion/blob/main/VAE_image_enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import required libraries

In [None]:
# Mount Google Drive into this Colab runtime at /content/drive; the dataset
# directories used later in the notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr, mean_squared_error as mse, structural_similarity as ssim

# Define transformations

In [None]:
# Preprocessing pipeline applied to both the raw and the reference images.
_preprocessing_steps = [
    # Bring every image to a common 256x256 spatial size.
    transforms.Resize((256, 256)),
    # Convert the image to a tensor.
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5, i.e. normalise to [-1, 1].
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)

# Custom Dataset class for loading images

In [None]:
class UnderwaterDataset(Dataset):
    """Paired dataset of raw images and their reference images.

    Files in ``raw_dir`` and ``reference_dir`` are paired by their position in
    the alphabetically sorted listing of each directory, so both directories
    must contain the same number of image files with matching sort order.

    Args:
        raw_dir: Directory holding the raw input images.
        reference_dir: Directory holding the reference (target) images.
        transform: Optional callable applied to each PIL image of a pair.

    Raises:
        ValueError: If the two directories do not contain the same number of
            files, since index-based pairing would then be wrong.
    """

    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform
        self.raw_images = self._list_files(raw_dir)
        self.reference_images = self._list_files(reference_dir)

        # Pairing is purely positional, so a count mismatch means at least one
        # raw image would be matched with the wrong (or a missing) reference.
        if len(self.raw_images) != len(self.reference_images):
            raise ValueError(
                f"Raw/reference file count mismatch: "
                f"{len(self.raw_images)} in {raw_dir!r} vs "
                f"{len(self.reference_images)} in {reference_dir!r}"
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``."""
        # Skip sub-directories (e.g. ``.ipynb_checkpoints``) and dotfiles, which
        # os.listdir reports but which are not images.
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Number of (raw, reference) pairs."""
        return len(self.raw_images)

    def __getitem__(self, idx):
        """Load the ``idx``-th pair as RGB and apply ``transform`` to both images.

        Returns:
            Tuple ``(raw_image, reference_image)``.
        """
        raw_path = os.path.join(self.raw_dir, self.raw_images[idx])
        reference_path = os.path.join(self.reference_dir, self.reference_images[idx])

        # ``convert`` returns a fully loaded copy, so the underlying file can be
        # closed immediately instead of lingering until garbage collection.
        with Image.open(raw_path) as raw_file:
            raw_image = raw_file.convert("RGB")
        with Image.open(reference_path) as reference_file:
            reference_image = reference_file.convert("RGB")

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

# Define the VAE architecture

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for image-to-image reconstruction.

    The encoder applies two stride-2 convolutions (each halving height and
    width), flattens the result and projects it to a 256-dimensional mean and
    log-variance. The decoder mirrors this with two linear layers followed by
    two stride-2 transposed convolutions.

    Args:
        image_size: Sequence ``(channels, height, width)``. Only the height
            and width entries are read; the network is hard-wired to 3 input
            and 3 output channels. Height and width should be divisible by 4
            so the decoder output matches the input size.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

        # Encoder
        self.enc_conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        # Two stride-2 convolutions shrink each spatial dimension by 4.
        self.flattened_size = 128 * (image_size[1] // 4) * (image_size[2] // 4)
        self.enc_fc1 = nn.Linear(self.flattened_size, 256)
        self.enc_fc21 = nn.Linear(256, 256)  # Mu (mean)
        self.enc_fc22 = nn.Linear(256, 256)  # Logvar (log variance)
        self.dropout = nn.Dropout(0.2)  # Dropout layer

        # Decoder
        self.dec_fc1 = nn.Linear(256, 256)
        self.dec_fc2 = nn.Linear(256, self.flattened_size)
        self.dec_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dec_conv2 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

    def encode(self, x):
        """Map an image batch to the latent mean and log-variance.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, 256)``.
        """
        x = F.relu(self.enc_conv1(x))
        x = F.relu(self.enc_conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.enc_fc1(x))
        x = self.dropout(x)  # Apply dropout in the encoder
        mu = self.enc_fc21(x)
        logvar = self.enc_fc22(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors of shape ``(batch, 256)`` to images in [-1, 1]."""
        z = F.relu(self.dec_fc1(z))
        z = F.relu(self.dec_fc2(z))
        z = z.view(-1, 128, self.image_size[1] // 4, self.image_size[2] // 4)
        z = F.relu(self.dec_conv1(z))
        # tanh, not sigmoid: the training targets are normalised to [-1, 1] and
        # the evaluation code rescales the output from [-1, 1] to [0, 1]. A
        # sigmoid output in [0, 1] could never reproduce negative target values.
        z = torch.tanh(self.dec_conv2(z))
        return z

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Loss function for VAE (reconstruction + KL divergence)

In [None]:
def vae_loss(recon_x, x, mu, logvar):
    """Combined VAE objective: mean-squared reconstruction error plus weighted KL term.

    Args:
        recon_x: Reconstructed batch produced by the decoder.
        x: Target batch; its first dimension is used as the batch size.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``mse + 0.01 * kl``, where the KL term is summed over all
        elements and divided by the batch size.
    """
    kl_weight = 0.01  # Adjusted weight for KL divergence
    batch_size = x.size(0)

    reconstruction_term = F.mse_loss(recon_x, x, reduction='mean')

    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_elements.sum() / batch_size

    return reconstruction_term + kl_weight * kl_term

# Metrics for evaluation: PSNR, MSE, SSIM

In [None]:
def compute_metrics(output, reference):
    """Compute PSNR, MSE and SSIM between one output image and its reference.

    Args:
        output: Image tensor in channel-first (C, H, W) layout with values in
            [-1, 1].
        reference: Reference image tensor with the same layout and range.

    Returns:
        Tuple ``(psnr_val, mse_val, ssim_val)`` computed on the images after
        rescaling to [0, 1].
    """
    # detach() allows tensors that still carry gradient history to be converted.
    output_np = output.detach().permute(1, 2, 0).cpu().numpy()  # Convert to HxWxC format
    reference_np = reference.detach().permute(1, 2, 0).cpu().numpy()

    # Ensure both images are in range [0, 1]
    output_np = (output_np + 1) / 2  # Convert from [-1, 1] to [0, 1]
    reference_np = (reference_np + 1) / 2

    data_range = 1.0
    psnr_val = psnr(reference_np, output_np, data_range=data_range)
    mse_val = mse(reference_np, output_np)
    # channel_axis replaces the `multichannel` flag, which scikit-image
    # deprecated in 0.19; without it the colour axis would be treated as a
    # third spatial dimension instead of being averaged per channel.
    ssim_val = ssim(reference_np, output_np, channel_axis=-1, win_size=3, data_range=data_range)

    return psnr_val, mse_val, ssim_val

# Main function with dataset paths and training/testing loops

In [None]:
def main(data_root='/content/drive/MyDrive/archive (2)', epochs=30, batch_size=16,
         learning_rate=1e-4, seed=42):
    """Train the VAE on the raw/reference pairs and report test-set metrics.

    Args:
        data_root: Directory containing ``Train/Raw``, ``Train/Reference``,
            ``Test/Raw`` and ``Test/Reference`` sub-directories.
        epochs: Number of training epochs.
        batch_size: Batch size for both the training and test loaders.
        learning_rate: Initial Adam learning rate; halved every 10 epochs.
        seed: Seed passed to ``torch.manual_seed`` for reproducible runs, or
            ``None`` to leave the random state untouched.

    Prints the average training loss after every epoch and the average PSNR,
    MSE and SSIM over the test set at the end.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Define dataset paths
    train_raw_dir = os.path.join(data_root, 'Train', 'Raw')
    train_reference_dir = os.path.join(data_root, 'Train', 'Reference')
    test_raw_dir = os.path.join(data_root, 'Test', 'Raw')
    test_reference_dir = os.path.join(data_root, 'Test', 'Reference')

    # Load datasets
    train_dataset = UnderwaterDataset(train_raw_dir, train_reference_dir, transform=transform)
    test_dataset = UnderwaterDataset(test_raw_dir, test_reference_dir, transform=transform)

    # DataLoader for batching
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model, optimizer, and device setup
    image_size = (3, 256, 256)  # must match the Resize step of `transform`
    vae_model = VAE(image_size).to(device)
    optimizer = optim.Adam(vae_model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Training loop
    for epoch in range(epochs):
        vae_model.train()
        total_loss = 0
        for raw, reference in train_loader:
            raw, reference = raw.to(device), reference.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = vae_model(raw)
            # The raw image is the input; the reference image is the target.
            loss = vae_loss(recon_batch, reference, mu, logvar)
            loss.backward()
            # Clip the gradient norm to keep updates bounded.
            torch.nn.utils.clip_grad_norm_(vae_model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}')

    # Evaluation
    vae_model.eval()
    psnr_scores, mse_scores, ssim_scores = [], [], []
    with torch.no_grad():
        for raw, reference in test_loader:
            raw, reference = raw.to(device), reference.to(device)
            recon_batch, _, _ = vae_model(raw)
            # Metrics are computed per image, then averaged over the test set.
            for i in range(raw.size(0)):
                psnr_val, mse_val, ssim_val = compute_metrics(recon_batch[i], reference[i])
                psnr_scores.append(psnr_val)
                mse_scores.append(mse_val)
                ssim_scores.append(ssim_val)

    print(f'Average PSNR: {np.mean(psnr_scores):.4f}, Average MSE: {np.mean(mse_scores):.4f}, Average SSIM: {np.mean(ssim_scores):.4f}')



# Run training and evaluation

In [None]:
# Run the full train-and-evaluate pipeline when this code is executed directly
# (as a notebook cell or script), but not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cuda
Epoch 1, Loss: 0.5559
Epoch 2, Loss: 0.3159
Epoch 3, Loss: 0.2663
Epoch 4, Loss: 0.2631
Epoch 5, Loss: 0.2616
Epoch 6, Loss: 0.2606
Epoch 7, Loss: 0.2597
Epoch 8, Loss: 0.2587
Epoch 9, Loss: 0.2582
Epoch 10, Loss: 0.2575
Epoch 11, Loss: 0.2570
Epoch 12, Loss: 0.2566
Epoch 13, Loss: 0.2564
Epoch 14, Loss: 0.2564
Epoch 15, Loss: 0.2561
Epoch 16, Loss: 0.2560
Epoch 17, Loss: 0.2560
Epoch 18, Loss: 0.2557
Epoch 19, Loss: 0.2557
Epoch 20, Loss: 0.2557
Epoch 21, Loss: 0.2558
Epoch 22, Loss: 0.2554
Epoch 23, Loss: 0.2555
Epoch 24, Loss: 0.2553
Epoch 25, Loss: 0.2554
Epoch 26, Loss: 0.2555
Epoch 27, Loss: 0.2555
Epoch 28, Loss: 0.2555
Epoch 29, Loss: 0.2555
Epoch 30, Loss: 0.2554
Average PSNR: 12.1438, Average MSE: 0.0650, Average SSIM: 0.1592
