In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only input tree and echo every file path so the dataset layout is visible.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for name in file_names:
        full_path = os.path.join(root_dir, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision
import math
import matplotlib.pyplot as plt
import torch
import numpy as np
import PIL
from PIL import Image

# Use the default CUDA device when a GPU is present. Hard-coding "cuda:1" requires a
# second GPU and fails with an invalid-device-ordinal error on single-GPU runtimes as
# soon as a tensor or module is moved to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from (batch shuffling, noise sampling, weight init)
# so a Restart & Run All reproduces the same run.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
def load_image_pairs(raw_image_folder, ref_image_folder):
    """Load (raw, reference) image pairs that share a filename.

    Every regular file in ``raw_image_folder`` whose name also exists in
    ``ref_image_folder`` is opened from both folders. Images are converted to RGB so
    they always have the 3 channels the UNet is built for (``img_channels=3``), and the
    underlying files are closed after decoding.

    Args:
        raw_image_folder: directory with the degraded (raw) images.
        ref_image_folder: directory with the reference images, matched by filename.

    Returns:
        Tuple ``(raw_images, ref_images)`` of equal-length lists of PIL images; the
        i-th entries of both lists come from the same filename.
    """
    raw_images = []
    ref_images = []

    # sorted() makes the pairing order reproducible; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(raw_image_folder)):
        raw_image_path = os.path.join(raw_image_folder, filename)
        ref_image_path = os.path.join(ref_image_folder, filename)

        # Skip sub-directories and raw files without a matching reference.
        if not os.path.isfile(raw_image_path) or not os.path.exists(ref_image_path):
            continue

        # Image.open is lazy and keeps the file open; convert() decodes into a new
        # in-memory image, so the `with` block can release the handle.
        with Image.open(raw_image_path) as raw_image:
            raw_images.append(raw_image.convert('RGB'))
        with Image.open(ref_image_path) as ref_image:
            ref_images.append(ref_image.convert('RGB'))

    return raw_images, ref_images

# Paired training data: degraded underwater images and their reference counterparts,
# matched by filename. The reference folder must differ from the raw one -- pointing
# both at Train/Raw made every "reference" identical to its input. This mirrors the
# Test/Raw + Test/Reference layout used for evaluation.
raw_image_folder = '/kaggle/input/underwater4/Train/Raw'
ref_image_folder = '/kaggle/input/underwater4/Train/Reference'
raw_images, ref_images = load_image_pairs(raw_image_folder, ref_image_folder)
assert raw_images, f"no matching image pairs found in {raw_image_folder} / {ref_image_folder}"


In [None]:
class DiffusionModel:
    """Noise schedule with a closed-form forward (noising) step and a one-step reverse sampler.

    Uses a linear beta schedule from ``start_schedule`` to ``end_schedule`` over
    ``timesteps`` steps; ``alphas_cumprod[t]`` is the running product of ``1 - beta``.
    """

    def __init__(self, start_schedule=0.0001, end_schedule=0.02, timesteps=300):
        self.start_schedule = start_schedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps
        self.betas = torch.linspace(start_schedule, end_schedule, timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def forward(self, x_0, t, device):
        """Sample x_t from x_0 in closed form.

        Args:
            x_0: clean batch; its first dimension must match ``t`` (one timestep per sample).
            t: integer tensor of timesteps, one per sample.
            device: device the returned tensors are moved to.

        Returns:
            ``(x_t, noise)`` with ``x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise``,
            where ``noise`` is the standard-normal sample that was drawn.
        """
        noise = torch.randn_like(x_0)
        sqrt_alphas_cumprod_t = self.get_index_from_list(self.alphas_cumprod.sqrt(), t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_cumprod), t, x_0.shape)
        mean = sqrt_alphas_cumprod_t.to(device) * x_0.to(device)
        # Despite the name, this is the scaled noise term (std * eps), not a variance.
        variance = sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device)
        return mean + variance, noise.to(device)

    @torch.no_grad()
    def backward(self, x, t, model, **kwargs):
        """Take one reverse step x_t -> x_{t-1} using the model's noise prediction.

        Runs without autograd, so the result carries no gradient.

        Args:
            x: batch at timestep ``t``.
            t: integer tensor with one timestep per sample.
            model: callable ``model(x, t, **kwargs)`` returning predicted noise shaped like ``x``.

        Returns:
            ``sqrt(1/alpha_t) * (x - beta_t * eps_hat / sqrt(1 - abar_t))`` plus
            ``sqrt(beta_t) * z`` for each sample whose timestep is non-zero; samples at
            ``t == 0`` get the deterministic term only.
        """
        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_cumprod), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t, **kwargs) / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = betas_t

        if t.eq(0).all():
            return mean

        # Per-sample mask: in a batch with mixed timesteps only the t > 0 entries get
        # fresh noise (previously a single t > 0 sample caused t == 0 samples to be noised too).
        nonzero_mask = t.ne(0).to(x.dtype).reshape(t.shape[0], *((1,) * (x.dim() - 1)))
        noise = torch.randn_like(x)
        return mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        """Pick ``values[t[i]]`` per sample, shaped (batch, 1, ..., 1) to broadcast against ``x_shape``."""
        batch_size = t.shape[0]
        # The schedule tensors are created on the CPU, so gather with CPU indices.
        result = values.gather(-1, t.cpu())
        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)



In [None]:
IMAGE_SHAPE = (128, 128)

# PIL image -> float tensor (C, H, W), rescaled from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(IMAGE_SHAPE),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1),
])

# Tensor (C, H, W) in [-1, 1] -> PIL image. Values are clamped first: raw noise and
# model outputs routinely fall outside [-1, 1], and astype(np.uint8) would otherwise
# wrap them around (e.g. 256 -> 0, -1 -> 255) instead of saturating. detach() lets
# tensors that still require grad be converted directly.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda t: t.detach().clamp(-1, 1)),
    transforms.Lambda(lambda t: (t + 1) / 2),
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
    transforms.Lambda(lambda t: t * 255.),
    transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)),
    transforms.ToPILImage(),
])



In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of integer diffusion timesteps.

    Maps a 1-D tensor of N timesteps to an (N, dim) float tensor. The first
    ``dim // 2`` columns are sines and the next ``dim // 2`` cosines, over
    geometrically spaced frequencies from 1 down to 1/10000. For odd ``dim`` a zero
    column is appended, so the output width always equals ``dim``; the
    ``nn.Linear(dim, ...)`` layers that consume it rely on that.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids ZeroDivisionError when half_dim == 1 (dim 2 or 3).
        embeddings = math.log(10000) / max(half_dim - 1, 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad one zero column on the right of the last dimension.
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings

class Block(nn.Module):
    """One UNet stage: conv -> timestep bias -> conv -> 2x resample.

    ``conv1`` uses a ``num_filters`` x ``num_filters`` kernel (``num_filters`` is a kernel
    size, not a filter count) and ``conv2`` a 3x3 kernel. With ``downsample=True`` the
    stage ends in a stride-2 ``Conv2d`` (halving H and W). Otherwise ``conv1`` takes
    ``2 * channels_in`` inputs (the skip tensor concatenated onto the input) and the
    stage ends in a stride-2 ``ConvTranspose2d`` (doubling H and W).
    """

    def __init__(self, channels_in, channels_out, time_embedding_dims, num_filters=3, downsample=True):
        super().__init__()
        self.time_embedding = SinusoidalPositionEmbeddings(time_embedding_dims)
        self.downsample = downsample

        # Submodules are created in the same order as before, so parameter
        # initialisation and state_dict key order are unchanged.
        conv1_in = channels_in if downsample else 2 * channels_in
        self.conv1 = nn.Conv2d(conv1_in, channels_out, num_filters, padding=1)
        resample = nn.Conv2d if downsample else nn.ConvTranspose2d
        self.final = resample(channels_out, channels_out, 4, 2, 1)

        self.bnorm1 = nn.BatchNorm2d(channels_out)
        self.bnorm2 = nn.BatchNorm2d(channels_out)
        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding=1)
        self.time_mlp = nn.Linear(time_embedding_dims, channels_out)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        hidden = self.bnorm1(self.relu(self.conv1(x)))
        # Per-sample channel bias from the timestep, broadcast over H and W.
        time_bias = self.relu(self.time_mlp(self.time_embedding(t)))
        hidden = hidden + time_bias[:, :, None, None]
        hidden = self.bnorm2(self.relu(self.conv2(hidden)))
        return self.final(hidden)


class UNet(nn.Module):
    """Noise-prediction UNet.

    The forward pass runs:
    - a 3x3 stem conv,
    - a stack of downsampling Blocks,
    - a mirrored stack of upsampling Blocks, each fed the matching skip tensor,
    - a 1x1 head back to ``img_channels``.

    ``sequence_channels`` gives the channel width at each resolution. Input H and W
    should be divisible by ``2 ** (len(sequence_channels) - 1)`` so the skip tensors
    line up.
    """

    def __init__(self, img_channels=3, time_embedding_dims=128, sequence_channels=(64, 128, 256, 512)):
        super().__init__()
        down_pairs = list(zip(sequence_channels, sequence_channels[1:]))
        reversed_channels = sequence_channels[::-1]
        up_pairs = list(zip(reversed_channels, reversed_channels[1:]))
        # Creation order (down, up, stem, head) is kept, so init and state_dict are unchanged.
        self.downsampling = nn.ModuleList(
            [Block(c_in, c_out, time_embedding_dims) for c_in, c_out in down_pairs]
        )
        self.upsampling = nn.ModuleList(
            [Block(c_in, c_out, time_embedding_dims, downsample=False) for c_in, c_out in up_pairs]
        )
        self.conv1 = nn.Conv2d(img_channels, sequence_channels[0], 3, padding=1)
        self.conv2 = nn.Conv2d(sequence_channels[0], img_channels, 1)

    def forward(self, x, t):
        hidden = self.conv1(x)
        skips = []
        for down_block in self.downsampling:
            hidden = down_block(hidden, t)
            skips.append(hidden)
        # The deepest skip pairs with the first upsampling block.
        for up_block, skip in zip(self.upsampling, skips[::-1]):
            hidden = up_block(torch.cat((hidden, skip), dim=1), t)
        return self.conv2(hidden)


In [None]:
def plot_noise_distribution(noise, predicted_noise):
    """Overlay normalised histograms of the true and predicted noise values.

    Both tensors are flattened (any shape works), detached, and copied to the CPU
    before plotting.
    """
    series = (
        (noise, "ground truth noise"),
        (predicted_noise, "predicted noise"),
    )
    for tensor, label in series:
        flat_values = tensor.cpu().detach().numpy().flatten()
        plt.hist(flat_values, density=True, alpha=0.8, label=label)
    plt.legend()
    plt.show()


In [None]:
def plot_noise_prediction(noise, predicted_noise):
    """Show a ground-truth noise tensor next to the model's prediction as images.

    Args:
        noise: a single (C, H, W) tensor; reverse_transform permutes it to (H, W, C).
        predicted_noise: the model's prediction, same shape as ``noise``.
    """
    # plt.subplots creates the figure itself; the plt.figure() call that used to
    # precede it only produced an extra, empty, never-closed figure.
    fig, ax = plt.subplots(1, 2, figsize=(5, 5))
    ax[0].imshow(reverse_transform(noise.detach()))
    ax[0].set_title("ground truth noise", fontsize=10)
    ax[1].imshow(reverse_transform(predicted_noise.detach()))
    ax[1].set_title("predicted noise", fontsize=10)
    plt.show()
    plt.close(fig)


In [None]:
import random

NO_EPOCHS = 2000
PRINT_FREQUENCY = 50
LR = 0.001
BATCH_SIZE = 16
NUM_PREVIEW_IMAGES = 3  # samples plotted at each PRINT_FREQUENCY checkpoint

unet = UNet().to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)

diffusion_model = DiffusionModel()

assert len(raw_images) > 0, "no training images loaded; check raw_image_folder / ref_image_folder"

# Noise-prediction training on the raw images only (reference images are just plotted).
# Each "epoch" here is a single optimisation step on one randomly drawn batch.
for epoch in range(NO_EPOCHS):
    mean_epoch_loss = []

    # Shuffle the raw images at the start of each epoch and take the first BATCH_SIZE indices
    shuffled_indices = list(range(len(raw_images)))
    random.shuffle(shuffled_indices)
    batch_indices = shuffled_indices[:BATCH_SIZE]

    batch_raw_images = torch.cat([transform(raw_images[i]).unsqueeze(0) for i in batch_indices]).to(device)
    batch_ref_images = [ref_images[i] for i in batch_indices]
    # The dataset may hold fewer than BATCH_SIZE images; t must match the real batch
    # so the per-sample schedule coefficients broadcast against it.
    batch_size = batch_raw_images.shape[0]

    optimizer.zero_grad()

    t = torch.randint(0, diffusion_model.timesteps, (batch_size,), device=device).long()
    x_t, noise = diffusion_model.forward(batch_raw_images, t, device)

    noise_pred = unet(x_t, t)
    loss = nn.MSELoss()(noise_pred, noise)

    loss.backward()
    optimizer.step()

    mean_epoch_loss.append(loss.item())

    if epoch % PRINT_FREQUENCY == 0:
        print(f"Epoch {epoch} | Loss: {np.mean(mean_epoch_loss)}")

        denoised_images = diffusion_model.backward(x_t, t, unet)

        for i in range(min(NUM_PREVIEW_IMAGES, batch_size)):
            # Plot raw image, reference image, and one-step denoised image. plt.subplots
            # creates the figure; an extra plt.figure() would only leak an empty one.
            fig, ax = plt.subplots(1, 3, figsize=(15, 15))

            ax[0].imshow(reverse_transform(batch_raw_images[i].detach()))
            ax[0].set_title(f"Raw Image (Epoch {epoch})", fontsize=10)

            ax[1].imshow(batch_ref_images[i])
            ax[1].set_title("Reference Image", fontsize=10)

            ax[2].imshow(reverse_transform(denoised_images[i].detach()))
            ax[2].set_title(f"Denoised Image (Epoch {epoch})", fontsize=10)

            plt.show()
            plt.close(fig)

        # Plot noise distribution and prediction
        plot_noise_distribution(noise, noise_pred)
        plot_noise_prediction(noise[0], noise_pred[0])

In [None]:
# Combined objective: pixel-wise MSE against the reference image plus the usual noise-prediction MSE (a perceptual loss could later replace the pixel term)
def combined_loss(denoised_images, reference_images, noise_pred, noise, alpha=0.5):
    """Weighted sum of two MSE terms.

    Returns ``alpha * MSE(denoised_images, reference_images)
    + (1 - alpha) * MSE(noise_pred, noise)`` as a scalar tensor.

    Args:
        denoised_images: image estimate compared against ``reference_images``.
        reference_images: target images, same shape as ``denoised_images``.
        noise_pred: predicted noise.
        noise: the noise that was actually added, same shape as ``noise_pred``.
        alpha: weight of the image term; ``1 - alpha`` weights the noise term.
    """
    criterion = nn.MSELoss()
    image_term = criterion(denoised_images, reference_images)
    noise_term = criterion(noise_pred, noise)
    return alpha * image_term + (1 - alpha) * noise_term
import random

NO_EPOCHS = 2000
PRINT_FREQUENCY = 50
LR = 0.001
BATCH_SIZE = 16
NUM_PREVIEW_IMAGES = 3  # samples plotted at each PRINT_FREQUENCY checkpoint

# NOTE(review): this re-initialises the UNet, discarding the weights trained by the previous cell.
unet = UNet().to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)

diffusion_model = DiffusionModel()

assert len(raw_images) > 0, "no training images loaded; check raw_image_folder / ref_image_folder"

# Training loop: each "epoch" is a single optimisation step on one random batch.
for epoch in range(NO_EPOCHS):
    mean_epoch_loss = []

    # Shuffle the raw images at the start of each epoch and take the first BATCH_SIZE indices
    shuffled_indices = list(range(len(raw_images)))
    random.shuffle(shuffled_indices)
    batch_indices = shuffled_indices[:BATCH_SIZE]

    batch_raw_images = torch.cat([transform(raw_images[i]).unsqueeze(0) for i in batch_indices]).to(device)
    batch_ref_images = torch.cat([transform(ref_images[i]).unsqueeze(0) for i in batch_indices]).to(device)
    # The dataset may hold fewer than BATCH_SIZE images; t must match the real batch.
    batch_size = batch_raw_images.shape[0]

    optimizer.zero_grad()

    t = torch.randint(0, diffusion_model.timesteps, (batch_size,), device=device).long()
    x_t, noise = diffusion_model.forward(batch_raw_images, t, device)

    noise_pred = unet(x_t, t)

    # Differentiable one-step denoised estimate, computed as
    # sqrt(1/alpha_t) * (x_t - beta_t * eps_hat / sqrt(1 - abar_t)) from the
    # prediction above. Using diffusion_model.backward here is wrong on two counts:
    # - it runs under @torch.no_grad, so the reference term had no gradient at all;
    # - it ran the UNet a second time each step.
    # The sampler's random noise term is omitted; it carries no gradient.
    betas_t = DiffusionModel.get_index_from_list(diffusion_model.betas, t, x_t.shape)
    sqrt_one_minus_alphas_cumprod_t = DiffusionModel.get_index_from_list(
        torch.sqrt(1. - diffusion_model.alphas_cumprod), t, x_t.shape)
    sqrt_recip_alphas_t = DiffusionModel.get_index_from_list(
        torch.sqrt(1.0 / diffusion_model.alphas), t, x_t.shape)
    denoised_estimate = sqrt_recip_alphas_t * (x_t - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t)

    # Compute the combined loss
    loss = combined_loss(denoised_estimate, batch_ref_images, noise_pred, noise)

    loss.backward()
    optimizer.step()

    mean_epoch_loss.append(loss.item())

    if epoch % PRINT_FREQUENCY == 0:
        print(f"Epoch {epoch} | Loss: {np.mean(mean_epoch_loss)}")

        denoised_images = diffusion_model.backward(x_t, t, unet)

        for i in range(min(NUM_PREVIEW_IMAGES, batch_size)):
            # Plot raw image, reference image, and one-step denoised image. plt.subplots
            # creates the figure; an extra plt.figure() would only leak an empty one.
            fig, ax = plt.subplots(1, 3, figsize=(15, 15))

            ax[0].imshow(reverse_transform(batch_raw_images[i].detach()))
            ax[0].set_title(f"Raw Image (Epoch {epoch})", fontsize=10)

            ax[1].imshow(reverse_transform(batch_ref_images[i].detach()))
            ax[1].set_title("Reference Image", fontsize=10)

            ax[2].imshow(reverse_transform(denoised_images[i].detach()))
            ax[2].set_title(f"Denoised Image (Epoch {epoch})", fontsize=10)

            plt.show()
            plt.close(fig)

        # Plot noise distribution and prediction
        plot_noise_distribution(noise, noise_pred)
        plot_noise_prediction(noise[0], noise_pred[0])


In [None]:
import os
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim
from PIL import Image  # Ensure to import Image
import torch.nn as nn  # Ensure nn is imported

# Save the trained model
# Persist only the learned parameters (the state_dict), not the pickled module object;
# reload with unet.load_state_dict(torch.load(model_save_path)).
model_save_path = 'diffusion_model.pth'
torch.save(obj=unet.state_dict(), f=model_save_path)
print(f'Model saved to {model_save_path}')

# Load images for testing
def load_test_image_pairs(raw_image_folder, ref_image_folder):
    """Load (raw, reference) test image pairs that share a filename, as RGB PIL images.

    Args:
        raw_image_folder: directory with the degraded (raw) test images.
        ref_image_folder: directory with the reference test images, matched by filename.

    Returns:
        Tuple ``(raw_images, ref_images)`` of equal-length lists; the i-th entries come
        from the same filename. Raw files without a reference, and sub-directories,
        are skipped.
    """
    raw_images = []
    ref_images = []

    # sorted() keeps the evaluation order (and per-image plots) reproducible.
    for filename in sorted(os.listdir(raw_image_folder)):
        raw_image_path = os.path.join(raw_image_folder, filename)
        ref_image_path = os.path.join(ref_image_folder, filename)

        if not os.path.isfile(raw_image_path) or not os.path.exists(ref_image_path):
            continue

        # convert('RGB') decodes into a new image, so the file can be closed right away.
        with Image.open(raw_image_path) as raw_image:
            raw_images.append(raw_image.convert('RGB'))
        with Image.open(ref_image_path) as ref_image:
            ref_images.append(ref_image.convert('RGB'))

    return raw_images, ref_images

# Define PSNR function
def psnr(denoised_image, ref_image, data_range=2.0):
    """Peak signal-to-noise ratio in dB: ``20 * log10(data_range / sqrt(MSE))``.

    Args:
        denoised_image: predicted image tensor.
        ref_image: reference tensor of the same shape.
        data_range: peak-to-peak value range of the images. The default 2.0 matches
            tensors normalised to [-1, 1] and the SSIM ``data_range=2`` used for the
            same images. A hard-coded 1.0 under-reports PSNR by 20*log10(2) ~= 6.02 dB.

    Returns:
        PSNR as a Python float; ``inf`` when the images are identical.
    """
    mse_value = nn.MSELoss()(denoised_image, ref_image).item()
    if mse_value == 0:
        return float('inf')  # If MSE is 0, PSNR is infinite
    return 20 * math.log10(data_range / math.sqrt(mse_value))

# Load your test dataset
# Held-out evaluation pairs, matched by filename: Raw = degraded inputs, Reference = targets.
test_raw_folder, test_ref_folder = (
    '/kaggle/input/underwater4/Test/Raw',
    '/kaggle/input/underwater4/Test/Reference',
)
test_raw_images, test_ref_images = load_test_image_pairs(test_raw_folder, test_ref_folder)

# Transform the test images
# Evaluation preprocessing, identical to the training `transform`:
# resize to IMAGE_SHAPE, convert to a [0, 1] tensor, then rescale to [-1, 1].
test_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SHAPE),
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img * 2 - 1),
    ]
)

# Initialize metrics lists
mse_values = []
psnr_values = []
ssim_values = []

# The test images are treated as samples x_t at this timestep.
EVAL_TIMESTEP = 0
# Schedule terms for inverting x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
# NOTE(review): with the default schedule abar_0 = 1 - 1e-4, so at t = 0 the estimate
# differs from the raw input by only ~0.01 * predicted noise; genuine enhancement would
# need a reference-conditioned, multi-step sampler.
eval_alpha_cumprod = diffusion_model.alphas_cumprod[EVAL_TIMESTEP].item()
sqrt_eval_alpha_cumprod = math.sqrt(eval_alpha_cumprod)
sqrt_one_minus_eval_alpha_cumprod = math.sqrt(1.0 - eval_alpha_cumprod)

# Evaluate the model on test images
unet.eval()  # Set model to evaluation mode

# Testing loop
for raw_pil, ref_pil in zip(test_raw_images, test_ref_images):
    # Preprocess the raw and reference images to (1, C, H, W) tensors in [-1, 1]
    raw_image = test_transform(raw_pil).unsqueeze(0).to(device)
    ref_image = test_transform(ref_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        t = torch.full((raw_image.shape[0],), EVAL_TIMESTEP, dtype=torch.long, device=device)
        # The UNet is trained to predict the noise in x_t (MSE against the sampled
        # noise), not a clean image. Its output is therefore converted into an x_0
        # estimate and clamped to the [-1, 1] range the metrics assume.
        predicted_noise = unet(raw_image, t)
        denoised_image = (
            (raw_image - sqrt_one_minus_eval_alpha_cumprod * predicted_noise) / sqrt_eval_alpha_cumprod
        ).clamp(-1, 1)

    # Compute metrics
    mse_value = nn.MSELoss()(denoised_image, ref_image).item()
    mse_values.append(mse_value)

    psnr_values.append(psnr(denoised_image, ref_image))

    # SSIM on (H, W, C) arrays in [-1, 1], hence data_range=2.
    ssim_value = ssim(
        denoised_image.squeeze(0).cpu().numpy().transpose(1, 2, 0),
        ref_image.squeeze(0).cpu().numpy().transpose(1, 2, 0),
        channel_axis=-1,  # replaces multichannel=True, deprecated since scikit-image 0.19
        win_size=3,
        data_range=2,
    )
    ssim_values.append(ssim_value)

    # Reverse the transformation for visualization
    denoised_image_np = reverse_transform(denoised_image.squeeze(0).cpu())
    ref_image_np = reverse_transform(ref_image.squeeze(0).cpu())

    # Display results
    fig = plt.figure(figsize=(15, 15))
    plt.subplot(1, 3, 1)
    plt.imshow(reverse_transform(raw_image.squeeze(0).cpu()))
    plt.title("Raw Image")

    plt.subplot(1, 3, 2)
    plt.imshow(ref_image_np)
    plt.title("Reference Image")

    plt.subplot(1, 3, 3)
    plt.imshow(denoised_image_np)
    plt.title("Denoised Image")

    plt.show()
    plt.close(fig)

# Display average metrics (np.mean of an empty list would print nan with a warning)
if not mse_values:
    print("No matching test image pairs were found; no metrics to report.")
else:
    print(f"Average MSE: {np.mean(mse_values):.4f}")
    print(f"Average PSNR: {np.mean(psnr_values):.4f}")
    print(f"Average SSIM: {np.mean(ssim_values):.4f}")

