In [2]:
import torch                              
import torch.nn as nn                      
import torch.optim as optim                
from torch.utils.data import DataLoader    
import torchvision                       
import torchvision.transforms as transforms 
import numpy as np                         
import cv2                                
from skimage.metrics import structural_similarity as ssim  
import matplotlib.pyplot as plt           
import time
from pytorch_wavelets import DWT, IDWT

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
def calculate_psnr(denoised, ground_truth):
    mse = np.mean((denoised - ground_truth) ** 2)
    if mse == 0:
        return float('inf')
    PIXEL_MAX = 1.0 
    psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse))
    return psnr
from skimage.metrics import structural_similarity as ssim


def calculate_ssim(denoised, ground_truth):
    return ssim(ground_truth, denoised, data_range=ground_truth.max() - ground_truth.min(), win_size=7, channel_axis=-1)


In [23]:
class WaveletBlock(nn.Module):
    """
    Predicts the noise residual in the wavelet domain.
    """
    def __init__(self, channels, wavelet='haar'):
        super().__init__()
        # Forward and inverse DWT (single level)
        self.dwt = DWT(J=1, wave=wavelet)
        self.idwt = IDWT(wave=wavelet)

        # Feed-forward in wavelet domain
        self.ffn = nn.Sequential(
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),
        )

        # Learnable soft-thresholds for LH, HL, HH bands
        # Shape: [3 (bands), channels, 1, 1]
        self.threshold = nn.Parameter(torch.zeros(3, channels, 1, 1))

    def forward(self, x):
        # x: [B, C, H, W]
        ll, yh = self.dwt(x)
        # yh is a list of length J; for J=1, yh[0] has shape [B, C, 3, H/2, W/2]
        detail = yh[0]
        # Split into subbands
        lh, hl, hh = torch.unbind(detail, dim=2)

        # Concatenate lowpass and highpass subbands along channel dim
        stacked = torch.cat([ll, lh, hl, hh], dim=1)
        y = self.ffn(stacked)

        # Split FFN output back into subbands
        c = x.size(1)
        ll2, lh2, hl2, hh2 = torch.split(y, c, dim=1)

        # Apply learnable soft-threshold to high-frequency bands
        t = torch.sigmoid(self.threshold)
        lh2 = lh2 * t[0]
        hl2 = hl2 * t[1]
        hh2 = hh2 * t[2]

        # Re-stack into shape expected by IDWT: [B, C, 3, H/2, W/2]
        y_high = torch.stack([lh2, hl2, hh2], dim=2)

        # Reconstruct residual from wavelet coefficients
        # IDWT expects a tuple (lowpass, [highpass_list])
        out = self.idwt((ll2, [y_high]))
        return out  # residual noise prediction

class HybridDnCNN(nn.Module):
    """
    Combines standard DnCNN residual prediction with a wavelet-based residual branch.
    Input: noisy image x
    Output: denoised image
    """
    def __init__(self, channels=3, num_layers=17, features=64, wavelet='haar'):
        super(HybridDnCNN, self).__init__()
        # Standard DnCNN branch
        layers = [
            nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True)
        ]
        for _ in range(num_layers - 2):
            layers += [
                nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
            ]
        layers.append(nn.Conv2d(features, channels, kernel_size=3, padding=1, bias=False))
        self.dncnn = nn.Sequential(*layers)

        # Wavelet residual branch
        self.wavelet_block = WaveletBlock(channels, wavelet)

    def forward(self, x):
        # Predict residual noise via DnCNN
        noise_dn = self.dncnn(x)
        # Predict residual noise via WaveletBlock
        noise_wave = self.wavelet_block(x)
        # Combine residual predictions
        noise_combined = noise_dn + noise_wave
        # Subtract noise from input to get denoised output
        clean = x - noise_combined
        return clean

In [24]:
transform = transforms.Compose([
    transforms.ToTensor()
])

# Download and load the CIFAR-10 training and test datasets
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders for training and testing
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader  = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [25]:
model = HybridDnCNN(channels=3).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 20
noise_std = 0.1  

In [26]:
print("Starting Training...")
model.train()  
best_loss = float('inf')
patience = 2  
patience_counter = 0
best_model_state = None


for epoch in range(num_epochs):
    epoch_loss = 0
    start_time = time.time()

    for data, _ in train_loader:
        data = data.to(device)  
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        loss = criterion(output, data)
        epoch_loss += loss.item() * data.size(0)

        optimizer.zero_grad()  
        loss.backward()     
        optimizer.step()       

    epoch_loss /= len(train_dataset)
    elapsed = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}, Time: {elapsed:.2f} sec")

    # Early stopping check
    if epoch_loss < best_loss - 1e-6:  
        best_loss = epoch_loss
        patience_counter = 0
        best_model_state = model.state_dict() 
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")


Starting Training...
Epoch [1/20], Loss: 0.008989, Time: 102.82 sec
Epoch [2/20], Loss: 0.003195, Time: 85.25 sec
Epoch [3/20], Loss: 0.002378, Time: 88.75 sec
Epoch [4/20], Loss: 0.002063, Time: 84.14 sec
Epoch [5/20], Loss: 0.001862, Time: 86.15 sec
Epoch [6/20], Loss: 0.001759, Time: 80.57 sec
Epoch [7/20], Loss: 0.001692, Time: 80.86 sec
Epoch [8/20], Loss: 0.001664, Time: 81.26 sec
Epoch [9/20], Loss: 0.001632, Time: 81.00 sec
Epoch [10/20], Loss: 0.001603, Time: 82.57 sec
Epoch [11/20], Loss: 0.001590, Time: 89.32 sec
Epoch [12/20], Loss: 0.001564, Time: 87.07 sec
Epoch [13/20], Loss: 0.001569, Time: 83.47 sec
No improvement. Patience: 1/2
Epoch [14/20], Loss: 0.001531, Time: 82.51 sec
Epoch [15/20], Loss: 0.001533, Time: 83.37 sec
No improvement. Patience: 1/2
Epoch [16/20], Loss: 0.001519, Time: 88.35 sec
Epoch [17/20], Loss: 0.001500, Time: 83.30 sec
Epoch [18/20], Loss: 0.001494, Time: 80.74 sec
Epoch [19/20], Loss: 0.001468, Time: 81.02 sec
Epoch [20/20], Loss: 0.001483, Tim

In [27]:
model.eval()
psnr_list = []
ssim_list = []

with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        
        # Move tensors to CPU and convert to numpy arrays, clipping values into [0,1]
        output_np = output.cpu().numpy().transpose(0, 2, 3, 1)   # (N, H, W, C)
        clean_np  = data.cpu().numpy().transpose(0, 2, 3, 1)
        noisy_np  = noisy_data.cpu().numpy().transpose(0, 2, 3, 1)
        
        # Calculate metrics image by image
        for denoised, clean in zip(output_np, clean_np):
            denoised = np.clip(denoised, 0., 1.)
            clean = np.clip(clean, 0., 1.)
            psnr_val = calculate_psnr(denoised, clean)
            ssim_val = calculate_ssim(denoised, clean)
            psnr_list.append(psnr_val)
            ssim_list.append(ssim_val)

mean_psnr = np.mean(psnr_list)
mean_ssim = np.mean(ssim_list)

print(f"Test PSNR: {mean_psnr:.2f} dB")
print(f"Test SSIM: {mean_ssim:.4f}")

Test PSNR: 28.47 dB
Test SSIM: 0.9149
