In [1]:
import torch                              
import torch.nn as nn                      
import torch.optim as optim                
from torch.utils.data import DataLoader    
import torchvision                       
import torchvision.transforms as transforms 
import numpy as np                         
import cv2                                
from skimage.metrics import structural_similarity as ssim  
import matplotlib.pyplot as plt           
import time
from pytorch_wavelets import DWT, IDWT


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
def calculate_psnr(denoised, ground_truth):
    mse = np.mean((denoised - ground_truth) ** 2)
    if mse == 0:
        return float('inf')
    PIXEL_MAX = 1.0 
    psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse))
    return psnr
from skimage.metrics import structural_similarity as ssim


def calculate_ssim(denoised, ground_truth):
    return ssim(ground_truth, denoised, data_range=ground_truth.max() - ground_truth.min(), win_size=7, channel_axis=-1)


In [4]:
transform = transforms.Compose([
    transforms.ToTensor()
])

# Download and load the CIFAR-10 training and test datasets
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders for training and testing
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader  = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
import torch
import torch.nn as nn
from torch.fft import fft2, ifft2
from pytorch_wavelets import DWT, IDWT


class FNetBlock(nn.Module):
    """
    Spatial-frequency mixing via FFT, with residual feed-forward.
    """
    def __init__(self, channels, height, width):
        super().__init__()
        # Normalize across [C, H, W]
        self.layer_norm = nn.LayerNorm([channels, height, width])
        self.ffn = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1)
        )

    def forward(self, x):
        # FFT and inverse FFT to mix spatial information globally
        x_fft = fft2(x, dim=(-2, -1))
        x_ifft = ifft2(x_fft, dim=(-2, -1)).real
        # Residual feed-forward
        x_norm = self.layer_norm(x_ifft)
        return x_ifft + self.ffn(x_norm)


class WaveletBlock(nn.Module):
    """
    Predicts residual noise in the wavelet domain using single-level DWT.
    """
    def __init__(self, channels, wavelet='haar'):
        super().__init__()
        self.dwt = DWT(J=1, wave=wavelet)
        self.idwt = IDWT(wave=wavelet)
        # FFN in wavelet domain (4 subbands concatenated)
        self.ffn = nn.Sequential(
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),
        )
        self.threshold = nn.Parameter(torch.zeros(3, channels, 1, 1))

    def forward(self, x):
        # Decompose
        ll, yh = self.dwt(x)
        detail = yh[0]               # [B, C, 3, H/2, W/2]
        lh, hl, hh = torch.unbind(detail, dim=2)
        # Process
        stacked = torch.cat([ll, lh, hl, hh], dim=1)
        y = self.ffn(stacked)
        # Split back
        c = x.size(1)
        ll2, lh2, hl2, hh2 = torch.split(y, c, dim=1)
        t = torch.sigmoid(self.threshold)
        lh2 = lh2 * t[0]
        hl2 = hl2 * t[1]
        hh2 = hh2 * t[2]
        y_high = torch.stack([lh2, hl2, hh2], dim=2)
        # Reconstruct noise residual
        out = self.idwt((ll2, [y_high]))
        return out


class DnCNNBranch(nn.Module):
    """
    Standard DnCNN residual noise predictor.
    """
    def __init__(self, channels=3, num_layers=17, features=64):
        super().__init__()
        layers = [nn.Conv2d(channels, features, kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True)]
        for _ in range(num_layers - 2):
            layers += [nn.Conv2d(features, features, 3, padding=1, bias=False),
                       nn.BatchNorm2d(features),
                       nn.ReLU(inplace=True)]
        layers.append(nn.Conv2d(features, channels, 3, padding=1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class HybridTriBranchModel(nn.Module):
    """
    Hybrid denoiser combining DnCNN, FNet, and Wavelet residual branches.
    """
    def __init__(self,
                 channels=3,
                 height=32,
                 width=32,
                 dncnn_layers=17,
                 dncnn_features=64,
                 fnet_blocks=5,
                 wavelet='haar'):
        super().__init__()
        # Branches
        self.dncnn_branch = DnCNNBranch(channels, dncnn_layers, dncnn_features)
        self.fnet_blocks = nn.Sequential(
            *[FNetBlock(dncnn_features, height, width) for _ in range(fnet_blocks)]
        )
        self.wavelet_branch = WaveletBlock(channels, wavelet)
        # Fusion conv: combines 3 residual estimates
        self.fusion = nn.Conv2d(channels * 3, channels, kernel_size=1)

    def forward(self, x):
        # Predict residual noises
        r_dn = self.dncnn_branch(x)
        # Project into DnCNN feature space for FNet
        feat = self.dncnn_branch.net[0:2](x)  # first conv+ReLU
        feat = self.fnet_blocks(feat)
        r_fnet = self.dncnn_branch.net[-1:](feat) if hasattr(self.dncnn_branch.net, '__getitem__') else nn.Conv2d(feat.size(1), x.size(1), 3, padding=1)(feat)
        r_wave = self.wavelet_branch(x)
        # Concatenate residuals and fuse
        r_cat = torch.cat([r_dn, r_fnet, r_wave], dim=1)
        r = self.fusion(r_cat)
        # Subtract fused residual
        return x - r

# Example instantiation:
# model = HybridTriBranchModel(channels=3, height=32, width=32)
# Given an input tensor `noisy` of shape [B, 3, 32, 32],
# `denoised = model(noisy)` will produce the cleaned output.

In [7]:
model = HybridTriBranchModel(channels=3).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 20
noise_std = 0.1  

In [8]:
print("Starting Training...")
model.train()  
best_loss = float('inf')
patience = 2  
patience_counter = 0
best_model_state = None


for epoch in range(num_epochs):
    epoch_loss = 0
    start_time = time.time()

    for data, _ in train_loader:
        data = data.to(device)  
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        loss = criterion(output, data)
        epoch_loss += loss.item() * data.size(0)

        optimizer.zero_grad()  
        loss.backward()     
        optimizer.step()       

    epoch_loss /= len(train_dataset)
    elapsed = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}, Time: {elapsed:.2f} sec")

    # Early stopping check
    if epoch_loss < best_loss - 1e-6:  
        best_loss = epoch_loss
        patience_counter = 0
        best_model_state = model.state_dict() 
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")


Starting Training...
Epoch [1/20], Loss: 0.005254, Time: 136.89 sec
Epoch [2/20], Loss: 0.001924, Time: 133.02 sec
Epoch [3/20], Loss: 0.001744, Time: 145.09 sec
Epoch [4/20], Loss: 0.001640, Time: 134.85 sec
Epoch [5/20], Loss: 0.001592, Time: 135.01 sec
Epoch [6/20], Loss: 0.001562, Time: 134.47 sec
Epoch [7/20], Loss: 0.001538, Time: 134.53 sec
Epoch [8/20], Loss: 0.001502, Time: 133.98 sec
Epoch [9/20], Loss: 0.001445, Time: 133.77 sec
Epoch [10/20], Loss: 0.001401, Time: 133.72 sec
Epoch [11/20], Loss: 0.001381, Time: 132.76 sec
Epoch [12/20], Loss: 0.001368, Time: 132.33 sec
Epoch [13/20], Loss: 0.001354, Time: 132.03 sec
Epoch [14/20], Loss: 0.001349, Time: 131.96 sec
Epoch [15/20], Loss: 0.001343, Time: 132.16 sec
Epoch [16/20], Loss: 0.001338, Time: 132.62 sec
Epoch [17/20], Loss: 0.001334, Time: 131.48 sec
Epoch [18/20], Loss: 0.001328, Time: 131.98 sec
Epoch [19/20], Loss: 0.001325, Time: 131.57 sec
Epoch [20/20], Loss: 0.001318, Time: 131.95 sec
Loaded best model with lowes

In [9]:
model.eval()
psnr_list = []
ssim_list = []

with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        
        # Move tensors to CPU and convert to numpy arrays, clipping values into [0,1]
        output_np = output.cpu().numpy().transpose(0, 2, 3, 1)   # (N, H, W, C)
        clean_np  = data.cpu().numpy().transpose(0, 2, 3, 1)
        noisy_np  = noisy_data.cpu().numpy().transpose(0, 2, 3, 1)
        
        # Calculate metrics image by image
        for denoised, clean in zip(output_np, clean_np):
            denoised = np.clip(denoised, 0., 1.)
            clean = np.clip(clean, 0., 1.)
            psnr_val = calculate_psnr(denoised, clean)
            ssim_val = calculate_ssim(denoised, clean)
            psnr_list.append(psnr_val)
            ssim_list.append(ssim_val)

mean_psnr = np.mean(psnr_list)
mean_ssim = np.mean(ssim_list)

print(f"Test PSNR: {mean_psnr:.2f} dB")
print(f"Test SSIM: {mean_ssim:.4f}")

Test PSNR: 28.98 dB
Test SSIM: 0.9213
