In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

In [24]:
# Define the forward diffusion process
def forward_diffusion(x0, noise, t, T):
    alpha = 1 - (t / T)  # Simple linear schedule
    alpha = alpha.view(-1, 1, 1, 1)  # Reshape to (batch_size, 1, 1, 1) for broadcasting
    return alpha * x0 + (1 - alpha) * noise


# Simple U-Net-like model for reverse process (denoising)
class SimpleUNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleUNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Custom Dataset Class
class UnderwaterImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.raw_images = sorted(os.listdir(os.path.join(root_dir, 'Train/Raw')))
        self.reference_images = sorted(os.listdir(os.path.join(root_dir, 'Train/Reference')))

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, idx):
        raw_image_path = os.path.join(self.root_dir, 'Train/Raw', self.raw_images[idx])
        reference_image_path = os.path.join(self.root_dir, 'Train/Reference', self.reference_images[idx])

        raw_image = Image.open(raw_image_path).convert("RGB")
        reference_image = Image.open(reference_image_path).convert("RGB")

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

In [25]:
# Hyperparameters
T = 1000  # Number of time steps
batch_size = 32
learning_rate = 1e-4
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data Transformation
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalizing to [-1, 1]
])

# Create DataLoader
train_dataset = UnderwaterImageDataset(root_dir='Dataset',transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Model, optimizer, loss
model = SimpleUNet(in_channels=3, out_channels=3).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

In [27]:
# Training loop
for epoch in range(epochs):
    model.train()
    
    for batch_idx, (raw_images, ref_images) in enumerate(train_loader):
        raw_images, ref_images = raw_images.to(device), ref_images.to(device)
        optimizer.zero_grad()
        
        # Apply forward diffusion (add noise)
        noise = torch.randn_like(raw_images)
        t = torch.randint(0, T, (raw_images.shape[0],)).to(device)  # Random time steps
        xt = forward_diffusion(raw_images, noise, t, T)
        
        # Denoising step: predict original image from noisy image
        reconstructed = model(xt)
        loss = criterion(reconstructed, ref_images)  # Compare to reference image
        
        loss.backward()
        optimizer.step()
        print(f'Batch {batch_idx} in Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}')

Batch 0 in Epoch 1/10, Loss: 0.21540425717830658
Batch 1 in Epoch 1/10, Loss: 0.2042238563299179
Batch 2 in Epoch 1/10, Loss: 0.20347857475280762
Batch 3 in Epoch 1/10, Loss: 0.2022009640932083
Batch 4 in Epoch 1/10, Loss: 0.21769355237483978
Batch 5 in Epoch 1/10, Loss: 0.2031174898147583
Batch 6 in Epoch 1/10, Loss: 0.2085644155740738
Batch 7 in Epoch 1/10, Loss: 0.19531960785388947
Batch 8 in Epoch 1/10, Loss: 0.1828293651342392
Batch 9 in Epoch 1/10, Loss: 0.20270180702209473
Batch 10 in Epoch 1/10, Loss: 0.20430169999599457
Batch 11 in Epoch 1/10, Loss: 0.16860926151275635
Batch 12 in Epoch 1/10, Loss: 0.19175750017166138
Batch 13 in Epoch 1/10, Loss: 0.17395873367786407
Batch 14 in Epoch 1/10, Loss: 0.1756550818681717
Batch 15 in Epoch 1/10, Loss: 0.16906356811523438
Batch 16 in Epoch 1/10, Loss: 0.1445644348859787
Batch 17 in Epoch 1/10, Loss: 0.1587427407503128
Batch 18 in Epoch 1/10, Loss: 0.15751734375953674
Batch 19 in Epoch 1/10, Loss: 0.14399781823158264
Batch 20 in Epoch 

In [35]:
import torch
import torch.nn.functional as F

# MSE Calculation
def mse(image1, image2):
    return F.mse_loss(image1, image2)

# PSNR Calculation
def psnr(image1, image2, max_val=1.0):
    mse_value = mse(image1, image2)
    psnr_value = 10 * torch.log10(max_val ** 2 / mse_value)
    return psnr_value

# SSIM Calculation
def ssim(image1, image2, C1=0.01**2, C2=0.03**2):
    mu1 = F.avg_pool2d(image1, kernel_size=11, stride=1, padding=5)
    mu2 = F.avg_pool2d(image2, kernel_size=11, stride=1, padding=5)
    
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.avg_pool2d(image1 ** 2, kernel_size=11, stride=1, padding=5) - mu1_sq
    sigma2_sq = F.avg_pool2d(image2 ** 2, kernel_size=11, stride=1, padding=5) - mu2_sq
    sigma12 = F.avg_pool2d(image1 * image2, kernel_size=11, stride=1, padding=5) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()

In [36]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Define the custom dataset
class PairedImageDataset(Dataset):
    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform
        self.image_filenames = os.listdir(raw_dir)

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        raw_image_path = os.path.join(self.raw_dir, image_filename)
        reference_image_path = os.path.join(self.reference_dir, image_filename)

        raw_image = Image.open(raw_image_path).convert('RGB')
        reference_image = Image.open(reference_image_path).convert('RGB')

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

raw_dir = 'Dataset/Test/Raw'
reference_dir = 'Dataset/Test/Reference'
test_dataset = PairedImageDataset(raw_dir, reference_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

import torch

# Set the generator to evaluation mode
model.eval()

# Initialize accumulators for MSE, PSNR, SSIM
mse_total = 0.0
psnr_total = 0.0
ssim_total = 0.0
num_samples = len(test_loader)

for raw_image, reference_image in test_loader:
    with torch.no_grad():
        generated_image = model(raw_image)

    generated_image = (generated_image + 1) / 2
    reference_image = (reference_image + 1) / 2

    # Calculate MSE, PSNR, SSIM for this pair
    mse_value = mse(generated_image, reference_image)
    psnr_value = psnr(generated_image, reference_image)
    ssim_value = ssim(generated_image, reference_image)

    mse_total += mse_value.item()
    psnr_total += psnr_value.item()
    ssim_total += ssim_value.item()

mean_mse = mse_total / num_samples
mean_psnr = psnr_total / num_samples
mean_ssim = ssim_total / num_samples

print(f"Mean MSE: {mean_mse}")
print(f"Mean PSNR: {mean_psnr}")
print(f"Mean SSIM: {mean_ssim}")

Mean MSE: 0.025100612721258873
Mean PSNR: 17.150235718174983
Mean SSIM: 0.6281609176805145
