In [1]:
import torch                              
import torch.nn as nn                      
import torch.optim as optim                
from torch.utils.data import DataLoader    
import torchvision                       
import torchvision.transforms as transforms 
import numpy as np                         
import cv2                                
from skimage.metrics import structural_similarity as ssim  
import matplotlib.pyplot as plt           
import time
from pytorch_wavelets import DWT, IDWT

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"


Using device: cuda


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [4]:
def calculate_psnr(denoised, ground_truth):
    mse = np.mean((denoised - ground_truth) ** 2)
    if mse == 0:
        return float('inf')
    PIXEL_MAX = 1.0 
    psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse))
    return psnr
from skimage.metrics import structural_similarity as ssim


def calculate_ssim(denoised, ground_truth):
    return ssim(ground_truth, denoised, data_range=ground_truth.max() - ground_truth.min(), win_size=7, channel_axis=-1)


In [5]:
class WaveletBlock(nn.Module):
    """
    Predicts the noise residual in the wavelet domain.
    Uses DWT (Discrete Wavelet Transform) to analyze image in frequency + spatial domain.
    """
    def __init__(self, channels, wavelet='haar'):
        super().__init__()
        
        # Forward and inverse wavelet transforms (1-level decomposition)
        self.dwt = DWT(J=1, wave=wavelet)
        self.idwt = IDWT(wave=wavelet)

        # Feed-forward network to process wavelet features (all 4 subbands)
        self.ffn = nn.Sequential(
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),  # Project + mix across channels
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0)   # Final projection
        )

        # Learnable soft-thresholding parameters for LH, HL, HH bands
        self.threshold = nn.Parameter(torch.zeros(3, channels, 1, 1))  # Shape: [3 subbands, C, 1, 1]

    def forward(self, x):
        # Apply Discrete Wavelet Transform
        ll, yh = self.dwt(x)           # ll: low-frequency approx, yh: list of high-freq details

        detail = yh[0]                 # Extract the first-level detail subbands
        lh, hl, hh = torch.unbind(detail, dim=2)  # Separate into LH, HL, HH components

        # Concatenate all 4 bands (LL, LH, HL, HH) for processing
        stacked = torch.cat([ll, lh, hl, hh], dim=1)

        # Pass through feedforward projection network
        y = self.ffn(stacked)

        c = x.size(1)  # Number of channels

        # Split the processed features back into subbands
        ll2, lh2, hl2, hh2 = torch.split(y, c, dim=1)

        # Apply learnable soft-thresholding to high-frequency bands
        t = torch.sigmoid(self.threshold)  # Thresholds in [0, 1] range
        lh2 = lh2 * t[0]
        hl2 = hl2 * t[1]
        hh2 = hh2 * t[2]

        # Stack high-frequency bands back into required shape for IDWT: [B, C, 3, H/2, W/2]
        y_high = torch.stack([lh2, hl2, hh2], dim=2)

        # Apply inverse DWT to get back to spatial domain (predicted residual noise)
        out = self.idwt((ll2, [y_high]))
        return out  # Final residual noise prediction from wavelet block


class HybridDnCNN(nn.Module):
    """
    Combines standard DnCNN residual prediction with a wavelet-based residual branch.
    Input: noisy image x
    Output: denoised image
    """
    def __init__(self, channels=3, num_layers=17, features=64, wavelet='haar'):
        super(HybridDnCNN, self).__init__()
        # Standard DnCNN branch
        layers = [
            nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True)
        ]
        for _ in range(num_layers - 2):
            layers += [
                nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
            ]
        layers.append(nn.Conv2d(features, channels, kernel_size=3, padding=1, bias=False))
        self.dncnn = nn.Sequential(*layers)

        # Wavelet residual branch
        self.wavelet_block = WaveletBlock(channels, wavelet)

    def forward(self, x):
        noise_dn = self.dncnn(x)
        noise_wave = self.wavelet_block(x)
        noise_combined = noise_dn + noise_wave
        clean = x - noise_combined
        return clean

In [6]:
transform = transforms.Compose([
    transforms.ToTensor()
])

# Download and load the CIFAR-10 training and test datasets
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders for training and testing
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class ImageFolderNoClass(Dataset):
    def __init__(self, folder_path, transform=None):
        self.file_paths = [os.path.join(folder_path, f) 
                           for f in os.listdir(folder_path) 
                           if f.lower().endswith(('png', 'jpg', 'jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img = Image.open(self.file_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, 0

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32,32))
])

train_dataset = ImageFolderNoClass('./BSD500/train', transform=transform)
val_dataset   = ImageFolderNoClass('./BSD500/val', transform=transform)
test_dataset  = ImageFolderNoClass('./BSD500/test', transform=transform)

# batch_size = 32
# train_loader = DataLoader(train_dataset+val_dataset, batch_size=batch_size, shuffle=True)
# val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(train_dataset+val_dataset+test_dataset, batch_size=batch_size, shuffle=False)


In [8]:
print(repr(device))


device(type='cuda')


In [9]:
model = HybridDnCNN(channels=3).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 100
noise_std = 0.1  

In [10]:
print("Starting Training...")
model.train()  
best_loss = float('inf')
patience = 2  
patience_counter = 0
best_model_state = None


for epoch in range(num_epochs):
    epoch_loss = 0
    start_time = time.time()

    for data,_ in train_loader:
        data = data.to(device)  
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        loss = criterion(output, data)
        epoch_loss += loss.item() * data.size(0)

        optimizer.zero_grad()  
        loss.backward()     
        optimizer.step()       

    epoch_loss /= len(train_dataset)
    elapsed = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}, Time: {elapsed:.2f} sec")

    # Early stopping check
    if epoch_loss < best_loss - 1e-6:  
        best_loss = epoch_loss
        patience_counter = 0
        best_model_state = model.state_dict() 
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")


Starting Training...
Epoch [1/100], Loss: 2.860587, Time: 80.38 sec
Epoch [2/100], Loss: 1.083148, Time: 72.84 sec
Epoch [3/100], Loss: 0.690991, Time: 73.10 sec
Epoch [4/100], Loss: 0.570792, Time: 73.30 sec
Epoch [5/100], Loss: 0.494640, Time: 73.15 sec
Epoch [6/100], Loss: 0.456837, Time: 73.01 sec
Epoch [7/100], Loss: 0.434533, Time: 75.68 sec
Epoch [8/100], Loss: 0.417939, Time: 72.06 sec
Epoch [9/100], Loss: 0.408424, Time: 71.93 sec
Epoch [10/100], Loss: 0.404299, Time: 71.97 sec
Epoch [11/100], Loss: 0.391720, Time: 71.40 sec
Epoch [12/100], Loss: 0.387643, Time: 72.57 sec
Epoch [13/100], Loss: 0.384211, Time: 72.35 sec
Epoch [14/100], Loss: 0.378280, Time: 72.77 sec
Epoch [15/100], Loss: 0.372295, Time: 73.28 sec
Epoch [16/100], Loss: 0.371394, Time: 73.69 sec
Epoch [17/100], Loss: 0.369361, Time: 72.70 sec
Epoch [18/100], Loss: 0.365060, Time: 72.60 sec
Epoch [19/100], Loss: 0.362386, Time: 72.70 sec
Epoch [20/100], Loss: 0.360790, Time: 72.71 sec
Epoch [21/100], Loss: 0.3609

In [15]:
model.eval()
psnr_list = []
ssim_list = []

with torch.no_grad():
    for data,_ in test_loader:
        data = data.to(device)
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        
        # Move tensors to CPU and convert to numpy arrays, clipping values into [0,1]
        output_np = output.cpu().numpy().transpose(0, 2, 3, 1)   # (N, H, W, C)
        clean_np  = data.cpu().numpy().transpose(0, 2, 3, 1)
        noisy_np  = noisy_data.cpu().numpy().transpose(0, 2, 3, 1)
        
        # Calculate metrics image by image
        for denoised, clean in zip(output_np, clean_np):
            denoised = np.clip(denoised, 0., 1.)
            clean = np.clip(clean, 0., 1.)
            psnr_val = calculate_psnr(denoised, clean)
            ssim_val = calculate_ssim(denoised, clean)
            psnr_list.append(psnr_val)
            ssim_list.append(ssim_val)

mean_psnr = np.mean(psnr_list)
mean_ssim = np.mean(ssim_list)

print(f"Test PSNR: {mean_psnr:.2f} dB")
print(f"Test SSIM: {mean_ssim:.4f}")

Test PSNR: 28.15 dB
Test SSIM: 0.8793
