<a href="https://colab.research.google.com/github/sans-mishra/Underwater-Image-Enhancement-using-VAE-GAN-and-Diffusion/blob/main/Diffusion_image_enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import required libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity, mean_squared_error


# Define the forward diffusion process

In [None]:
def forward_diffusion(x0, noise, t, T):

    t = t.to(x0.device)
    # Ensure t is a float tensor
    t = t.float()

    # Reshape alpha to match x0 and noise
    alpha = 1 - (t / T)[:, None, None, None]

    return alpha * x0 + (1 - alpha) * noise

# U-Net model for denoising

In [None]:
class SimpleUNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleUNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)  # Encode input
        x = self.decoder(x)  # Decode back to image
        return x

# Custom dataset for underwater image enhancement

In [None]:
class UnderwaterDataset(Dataset):
    def __init__(self, raw_dir, ref_dir, transform=None):
        self.raw_dir = raw_dir
        self.ref_dir = ref_dir
        self.transform = transform
        self.raw_images = os.listdir(raw_dir)

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, idx):
        raw_image = Image.open(os.path.join(self.raw_dir, self.raw_images[idx])).convert('RGB')
        ref_image = Image.open(os.path.join(self.ref_dir, self.raw_images[idx])).convert('RGB')

        if self.transform:
            raw_image = self.transform(raw_image)
            ref_image = self.transform(ref_image)

        return raw_image, ref_image

# Transformation: resize and normalize images

In [None]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Data Loading for Training data

In [None]:
# Define dataset paths
train_raw_dir = '/content/drive/MyDrive/archive (2)/Train/Raw'
train_ref_dir = '/content/drive/MyDrive/archive (2)/Train/Reference'

# DataLoader for the training data
train_dataset = UnderwaterDataset(train_raw_dir, train_ref_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


# Model, optimizer, and loss function

In [None]:
model = SimpleUNet(in_channels=3, out_channels=3).cuda()  # 3 channels for RGB
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Training loop

In [None]:
epochs = 10
T = 1000  # Diffusion steps
for epoch in range(epochs):
    model.train()

    for batch_idx, (raw, ref) in enumerate(train_loader):
        raw, ref = raw.cuda(), ref.cuda()
        optimizer.zero_grad()

        noise = torch.randn_like(raw)  # Gaussian noise
        t = torch.randint(0, T, (raw.shape[0],))  # Random time steps
        xt = forward_diffusion(raw, noise, t, T)  # Apply noise

        reconstructed = model(xt)  # Denoise
        loss = criterion(reconstructed, ref)  # Compute loss
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

Epoch 1/10, Loss: 0.4498656988143921
Epoch 2/10, Loss: 0.28697434067726135
Epoch 3/10, Loss: 0.2705778181552887
Epoch 4/10, Loss: 0.25321105122566223
Epoch 5/10, Loss: 0.24637527763843536
Epoch 6/10, Loss: 0.24316878616809845
Epoch 7/10, Loss: 0.26397451758384705
Epoch 8/10, Loss: 0.23810561001300812
Epoch 9/10, Loss: 0.2239563763141632
Epoch 10/10, Loss: 0.23351508378982544


# Evaluation function for PSNR, SSIM, and MSE

In [None]:
def evaluate_metrics(original, enhanced):
    psnr = peak_signal_noise_ratio(original, enhanced)
    min_value = np.min(ref)  # Minimum pixel value in 'ref' image
    max_value = np.max(ref)  # Maximum pixel value in 'ref' image
    data_range = max_value - min_value  # Calculate data range

    ssim = structural_similarity(original, enhanced, multichannel=True, win_size=3, data_range=data_range)
    mse = mean_squared_error(original, enhanced)
    return psnr, ssim, mse

# Load test data

In [None]:
test_raw_dir = '/content/drive/MyDrive/archive (2)/Test/Raw'
test_ref_dir = '/content/drive/MyDrive/archive (2)/Test/Reference'

test_dataset = UnderwaterDataset(test_raw_dir, test_ref_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Evaluate the model

In [None]:
model.eval()
with torch.no_grad():
    for raw, ref in test_loader:
        raw = raw.cuda()
        noise = torch.randn_like(raw)  # Generate random noise
        t = torch.randint(0, T, (raw.shape[0],))  # Random time step
        xt = forward_diffusion(raw, noise, t, T)  # Apply noise

        enhanced = model(xt)  # Denoise

        # Move tensors to CPU for evaluation
        ref = ref.squeeze().permute(1, 2, 0).cpu().numpy()
        enhanced = enhanced.squeeze().permute(1, 2, 0).cpu().numpy()

        # Evaluate PSNR, SSIM, and MSE
        psnr, ssim, mse = evaluate_metrics(ref, enhanced)
        print(f"PSNR: {psnr}, SSIM: {ssim}, MSE: {mse}")

PSNR: 13.13808137811682, SSIM: -0.01711915610157074, MSE: 0.1942011751204478
PSNR: 12.014139418081399, SSIM: 0.01631523194941615, MSE: 0.2515625854222137
PSNR: 13.397052816903445, SSIM: 0.0004944980679069372, MSE: 0.18295939253025043
PSNR: 13.282965473591076, SSIM: 0.02898224223302146, MSE: 0.18782934497259704
PSNR: 11.634252024801274, SSIM: -0.1052635010717341, MSE: 0.2745584338478135
PSNR: 11.5041275445527, SSIM: 0.01608662752840337, MSE: 0.28290930829742283
PSNR: 14.26549004274751, SSIM: -0.008399057130857642, MSE: 0.14979971498335626
PSNR: 12.05526617798798, SSIM: -0.07899924693957412, MSE: 0.2491915856144131
PSNR: 11.377977897766272, SSIM: 0.013625137735428308, MSE: 0.2912474969036634
PSNR: 12.895117000675864, SSIM: 0.13819347176769972, MSE: 0.2053753380326512
PSNR: 11.814529260560143, SSIM: 0.05059349051651649, MSE: 0.26339472024787375
PSNR: 12.674528134336693, SSIM: 0.049218206935405066, MSE: 0.2160763215571154
PSNR: 11.726470340325534, SSIM: 0.05447072080301181, MSE: 0.26878990