In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np

In [10]:
# Custom Dataset Class
class UnderwaterImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.raw_images = sorted(os.listdir(os.path.join(root_dir, 'Train/Raw')))
        self.reference_images = sorted(os.listdir(os.path.join(root_dir, 'Train/Reference')))

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, idx):
        raw_image_path = os.path.join(self.root_dir, 'Train/Raw', self.raw_images[idx])
        reference_image_path = os.path.join(self.root_dir, 'Train/Reference', self.reference_images[idx])

        raw_image = Image.open(raw_image_path).convert("RGB")
        reference_image = Image.open(reference_image_path).convert("RGB")

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

In [11]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 
])

dataset = UnderwaterImageDataset(root_dir='Dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [12]:
# UNet Generator
class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        self.encoder1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)  # Downsample
        self.encoder2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.decoder1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # Upsample
        self.decoder2 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        enc1 = F.leaky_relu(self.encoder1(x))
        enc2 = F.leaky_relu(self.encoder2(enc1))
        dec1 = F.relu(self.decoder1(enc2))
        dec2 = self.decoder2(dec1)
        return dec2

class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [13]:
# L1 Loss Function
def l1_loss(y_true, y_pred):
    return F.l1_loss(y_true, y_pred)

In [14]:
# Hyperparameters
lambda_l1 = 100
learning_rate = 0.0001
num_epochs = 150
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)

In [15]:
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (real_images, target_images) in enumerate(dataloader):
        real_images = real_images.to(device)
        target_images = target_images.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        fake_images = generator(real_images)

        real_pairs = torch.cat((real_images, target_images), dim=1)
        fake_pairs = torch.cat((real_images, fake_images.detach()), dim=1)

        D_real = discriminator(real_pairs)
        D_fake = discriminator(fake_pairs)

        # Clamping outputs
        eps = 1e-8  # A small value to avoid log(0)
        loss_D = -torch.mean(torch.log(D_real + eps) + torch.log(1 - D_fake + eps))
        loss_D.backward()
        optimizer_D.step()


        # Train Generator
        optimizer_G.zero_grad()

        # Computation of G loss
        D_fake_for_G = discriminator(fake_pairs)
        loss_G_GAN = -torch.mean(torch.log(D_fake_for_G))
        loss_G_L1 = l1_loss(target_images, fake_images)
        loss_G = loss_G_GAN + lambda_l1 * loss_G_L1

        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch}/{num_epochs}], Loss D: {loss_D.item()}, Loss G: {loss_G.item()}')

Epoch [0/150], Loss D: 0.5153475999832153, Loss G: 24.847997665405273
Epoch [1/150], Loss D: 0.7337636947631836, Loss G: 18.99017906188965
Epoch [2/150], Loss D: 0.45549485087394714, Loss G: 30.96002769470215
Epoch [3/150], Loss D: 0.2736993432044983, Loss G: 23.886207580566406
Epoch [4/150], Loss D: 0.8936703205108643, Loss G: 20.60154151916504
Epoch [5/150], Loss D: 0.7140700817108154, Loss G: 16.593381881713867
Epoch [6/150], Loss D: 1.3349840641021729, Loss G: 19.49226951599121
Epoch [7/150], Loss D: 0.32761266827583313, Loss G: 23.42490577697754
Epoch [8/150], Loss D: 0.22850024700164795, Loss G: 23.12481117248535
Epoch [9/150], Loss D: 0.2803192138671875, Loss G: 21.566743850708008
Epoch [10/150], Loss D: 0.7787575721740723, Loss G: 23.647706985473633
Epoch [11/150], Loss D: 0.4643474817276001, Loss G: 23.781238555908203
Epoch [12/150], Loss D: 1.1090317964553833, Loss G: 15.731419563293457
Epoch [13/150], Loss D: 0.4522349238395691, Loss G: 16.636924743652344
Epoch [14/150], Los

In [16]:
import torch
import torch.nn.functional as F

# MSE Calculation
def mse(image1, image2):
    return F.mse_loss(image1, image2)

# PSNR Calculation
def psnr(image1, image2, max_val=1.0):
    mse_value = mse(image1, image2)
    psnr_value = 10 * torch.log10(max_val ** 2 / mse_value)
    return psnr_value

# SSIM Calculation
def ssim(image1, image2, C1=0.01**2, C2=0.03**2):
    mu1 = F.avg_pool2d(image1, kernel_size=11, stride=1, padding=5)
    mu2 = F.avg_pool2d(image2, kernel_size=11, stride=1, padding=5)
    
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.avg_pool2d(image1 ** 2, kernel_size=11, stride=1, padding=5) - mu1_sq
    sigma2_sq = F.avg_pool2d(image2 ** 2, kernel_size=11, stride=1, padding=5) - mu2_sq
    sigma12 = F.avg_pool2d(image1 * image2, kernel_size=11, stride=1, padding=5) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


In [17]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Define the custom dataset
class PairedImageDataset(Dataset):
    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform
        self.image_filenames = os.listdir(raw_dir)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        raw_image_path = os.path.join(self.raw_dir, image_filename)
        reference_image_path = os.path.join(self.reference_dir, image_filename)

        raw_image = Image.open(raw_image_path).convert('RGB')
        reference_image = Image.open(reference_image_path).convert('RGB')

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

transform = transforms.Compose([
    transforms.Resize((256, 256)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 
])

raw_dir = 'Dataset/Test/Raw'
reference_dir = 'Dataset/Test/Reference'
test_dataset = PairedImageDataset(raw_dir, reference_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

import torch

# Set the generator to evaluation mode
generator.eval()

# Initialize accumulators for MSE, PSNR, SSIM
mse_total = 0.0
psnr_total = 0.0
ssim_total = 0.0
num_samples = len(test_loader)

for raw_image, reference_image in test_loader:
    with torch.no_grad():
        generated_image = generator(raw_image)

    generated_image = (generated_image + 1) / 2
    reference_image = (reference_image + 1) / 2

    mse_value = mse(generated_image, reference_image)
    psnr_value = psnr(generated_image, reference_image)
    ssim_value = ssim(generated_image, reference_image)

    mse_total += mse_value.item()
    psnr_total += psnr_value.item()
    ssim_total += ssim_value.item()

mean_mse = mse_total / num_samples
mean_psnr = psnr_total / num_samples
mean_ssim = ssim_total / num_samples

print(f"Mean MSE: {mean_mse}")
print(f"Mean PSNR: {mean_psnr}")
print(f"Mean SSIM: {mean_ssim}")

Mean MSE: 0.0171629457023779
Mean PSNR: 19.315635078831722
Mean SSIM: 0.8253874237600126
