In [1]:
import torch                              
import torch.nn as nn                      
import torch.optim as optim                
from torch.utils.data import DataLoader    
import torchvision                       
import torchvision.transforms as transforms 
import numpy as np                         
import cv2                                
from skimage.metrics import structural_similarity as ssim  
import matplotlib.pyplot as plt           
import time
from pytorch_wavelets import DWT, IDWT


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# import os

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

Using device: cuda


In [3]:
def calculate_psnr(denoised, ground_truth):
    mse = np.mean((denoised - ground_truth) ** 2)
    if mse == 0:
        return float('inf')
    PIXEL_MAX = 1.0 
    psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse))
    return psnr
from skimage.metrics import structural_similarity as ssim


def calculate_ssim(denoised, ground_truth):
    return ssim(ground_truth, denoised, data_range=ground_truth.max() - ground_truth.min(), win_size=7, channel_axis=-1)


In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Resize((64,64)) #update
])

# Download and load the CIFAR-10 training and test datasets
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders for training and testing
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# test_loader  = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
import torch
import torch.nn as nn
from torch.fft import fft2, ifft2
from pytorch_wavelets import DWT, IDWT

class FNetBlock(nn.Module):
    """
    Spatial-frequency mixing via FFT, with residual feed-forward.
    This block mixes information in both spatial and frequency domains.
    """
    def __init__(self, channels, height, width):
        super().__init__()
        self.layer_norm = nn.LayerNorm([channels, height, width])
        # Feed-forward network (FFN) consisting of two Conv2d layers with a ReLU in between
        self.ffn = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),  # 1x1 convolution to adjust channels
            nn.ReLU(inplace=True),  # ReLU activation for non-linearity
            nn.Conv2d(channels, channels, kernel_size=1)   # Another 1x1 convolution
        )

    def forward(self, x):
        # Apply FFT and IFFT to the input tensor (x), working in frequency space
        x_fft = fft2(x, dim=(-2, -1))  # Fast Fourier Transform (FFT)
        x_ifft = ifft2(x_fft, dim=(-2, -1)).real  # Inverse FFT (IFFT), only keeping real part
        x_norm = self.layer_norm(x_ifft)  # Normalize the result in the spatial domain
        # Add residual connection with feed-forward network applied on the normalized result
        return x_ifft + self.ffn(x_norm)


class WaveletBlock(nn.Module):
    """
    Predicts residual noise in the wavelet domain using single-level DWT.
    This block uses Discrete Wavelet Transform (DWT) to decompose the input, 
    processes the details, and then reconstructs the residual noise.
    """
    def __init__(self, channels, wavelet='haar'):
        super().__init__()
        self.dwt = DWT(J=1, wave=wavelet)  # Single-level DWT with specified wavelet
        self.idwt = IDWT(wave=wavelet)  # Inverse DWT for reconstruction
        self.ffn = nn.Sequential(
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),  # 1x1 convolution to adjust channels
            nn.ReLU(inplace=True),  # ReLU activation for non-linearity
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),  # Another 1x1 convolution
        )
        self.threshold = nn.Parameter(torch.zeros(3, channels, 1, 1))  # Threshold to scale the wavelet details

    def forward(self, x):
        # Decompose the input tensor into wavelet sub-bands (low and high frequency components)
        ll, yh = self.dwt(x)
        detail = yh[0]  # High-frequency components (details)
        lh, hl, hh = torch.unbind(detail, dim=2)  # Split high-frequency details into 3 sub-bands
        # Concatenate low and high-frequency components along the channel dimension
        stacked = torch.cat([ll, lh, hl, hh], dim=1)
        # Apply feed-forward network to the stacked components
        y = self.ffn(stacked)
        # Split the output back into corresponding sub-bands
        c = x.size(1)
        ll2, lh2, hl2, hh2 = torch.split(y, c, dim=1)
        t = torch.sigmoid(self.threshold)  # Sigmoid to generate scaling factors for the sub-bands
        lh2 = lh2 * t[0]  # Scale the lh sub-band
        hl2 = hl2 * t[1]  # Scale the hl sub-band
        hh2 = hh2 * t[2]  # Scale the hh sub-band
        y_high = torch.stack([lh2, hl2, hh2], dim=2)  # Stack the scaled high-frequency sub-bands
        # Reconstruct the residual noise using the inverse DWT
        out = self.idwt((ll2, [y_high]))
        return out


class DnCNNBranch(nn.Module):
    """
    Standard DnCNN residual noise predictor.
    This is a simple deep neural network that predicts the noise residuals from the input image.
    """
    def __init__(self, channels=3, num_layers=17, features=64):
        super().__init__()
        layers = [nn.Conv2d(channels, features, kernel_size=3, padding=1, bias=True),  # First layer
                  nn.ReLU(inplace=True)]  # ReLU activation
        # Add several convolutional layers followed by batch normalization and ReLU
        for _ in range(num_layers - 2):
            layers += [nn.Conv2d(features, features, 3, padding=1, bias=False),
                       nn.BatchNorm2d(features),
                       nn.ReLU(inplace=True)]
        # Final convolutional layer to match output channels with input
        layers.append(nn.Conv2d(features, channels, 3, padding=1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Forward pass through the DnCNN network
        return self.net(x)


class HybridTriBranchModel(nn.Module):
    """
    Hybrid denoiser combining DnCNN, FNet, and Wavelet residual branches.
    This model uses three separate branches (DnCNN, FNet, and Wavelet) to predict residual noise
    and then fuses the results to denoise the input image.
    """
    def __init__(self,
                 channels=3,
                 height=32,
                 width=32,
                 dncnn_layers=17,
                 dncnn_features=64,
                 fnet_blocks=5,
                 wavelet='haar'):
        super().__init__()
        # Define the branches: DnCNN, FNet, and Wavelet
        self.dncnn_branch = DnCNNBranch(channels, dncnn_layers, dncnn_features)
        self.fnet_blocks = nn.Sequential(
            *[FNetBlock(dncnn_features, height, width) for _ in range(fnet_blocks)]  # Stack FNet blocks
        )
        self.wavelet_branch = WaveletBlock(channels, wavelet)
        # Fusion layer: combines the residual predictions from the three branches
        self.fusion = nn.Conv2d(channels * 3, channels, kernel_size=1)

    def forward(self, x):
        # Predict residual noise using each branch
        r_dn = self.dncnn_branch(x)
        # Project input through DnCNN to obtain features for FNet
        feat = self.dncnn_branch.net[0:2](x)  # First two layers of DnCNN (Conv + ReLU)
        feat = self.fnet_blocks(feat)  # Apply FNet blocks
        r_fnet = self.dncnn_branch.net[-1:](feat)  # Final layer to bring it back to input channels
        r_wave = self.wavelet_branch(x)  # Predict residual from wavelet branch
        # Concatenate the residuals from all three branches
        r_cat = torch.cat([r_dn, r_fnet, r_wave], dim=1)
        # Fuse the concatenated residuals and subtract from the input for denoising
        r = self.fusion(r_cat)
        return x - r  # Denoised output


In [6]:
model = HybridTriBranchModel(channels=3).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 50
noise_std = 0.1  

In [7]:
print("Starting Training...")
model.train()  
best_loss = float('inf')
patience = 2  
patience_counter = 0
best_model_state = None


for epoch in range(num_epochs):
    epoch_loss = 0
    start_time = time.time()

    for data, _ in train_loader:
        data = data.to(device)  
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        loss = criterion(output, data)
        epoch_loss += loss.item() * data.size(0)

        optimizer.zero_grad()  
        loss.backward()     
        optimizer.step()       

    epoch_loss /= len(train_dataset)
    elapsed = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}, Time: {elapsed:.2f} sec")

    # Early stopping check
    if epoch_loss < best_loss - 1e-6:  
        best_loss = epoch_loss
        patience_counter = 0
        best_model_state = model.state_dict() 
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")


Starting Training...
Epoch [1/50], Loss: 0.003217, Time: 125.73 sec
Epoch [2/50], Loss: 0.001703, Time: 125.88 sec
Epoch [3/50], Loss: 0.001580, Time: 126.47 sec
Epoch [4/50], Loss: 0.001530, Time: 126.38 sec
Epoch [5/50], Loss: 0.001500, Time: 125.12 sec
Epoch [6/50], Loss: 0.001480, Time: 126.84 sec
Epoch [7/50], Loss: 0.001469, Time: 124.52 sec
Epoch [8/50], Loss: 0.001458, Time: 124.38 sec
Epoch [9/50], Loss: 0.001451, Time: 126.03 sec
Epoch [10/50], Loss: 0.001439, Time: 125.47 sec
Epoch [11/50], Loss: 0.001413, Time: 125.74 sec
Epoch [12/50], Loss: 0.001374, Time: 125.23 sec
Epoch [13/50], Loss: 0.001354, Time: 125.92 sec
Epoch [14/50], Loss: 0.001343, Time: 125.79 sec
Epoch [15/50], Loss: 0.001333, Time: 125.52 sec
Epoch [16/50], Loss: 0.001326, Time: 125.10 sec
Epoch [17/50], Loss: 0.001321, Time: 125.35 sec
Epoch [18/50], Loss: 0.001316, Time: 125.13 sec
Epoch [19/50], Loss: 0.001312, Time: 125.44 sec
Epoch [20/50], Loss: 0.001309, Time: 125.33 sec
Epoch [21/50], Loss: 0.00130

In [8]:
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")
    torch.save(model.state_dict(), 'best_model.pth')
    print("Best model saved as 'best_model.pth'.")

Loaded best model with lowest validation loss.
Best model saved as 'best_model.pth'.


In [10]:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class ImageFolderNoClass(Dataset):
    def __init__(self, folder_path, transform=None):
        self.file_paths = [os.path.join(folder_path, f) 
                           for f in os.listdir(folder_path) 
                           if f.lower().endswith(('png', 'jpg', 'jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img = Image.open(self.file_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, 0

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
])

# train_dataset = ImageFolderNoClass('./BSD500/train', transform=transform)
# val_dataset   = ImageFolderNoClass('./BSD500/val', transform=transform)
test_dataset  = ImageFolderNoClass('./BSD68/BSD68', transform=transform)

batch_size = 32
# train_loader = DataLoader(train_dataset+val_dataset, batch_size=batch_size, shuffle=True)
# val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


model.eval()
psnr_list = []
ssim_list = []



with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        
        # Move tensors to CPU and convert to numpy arrays, clipping values into [0,1]
        output_np = output.cpu().numpy().transpose(0, 2, 3, 1)   # (N, H, W, C)
        clean_np  = data.cpu().numpy().transpose(0, 2, 3, 1)
        noisy_np  = noisy_data.cpu().numpy().transpose(0, 2, 3, 1)
        
        # Calculate metrics image by image
        for denoised, clean in zip(output_np, clean_np):
            denoised = np.clip(denoised, 0., 1.)
            clean = np.clip(clean, 0., 1.)
            psnr_val = calculate_psnr(denoised, clean)
            ssim_val = calculate_ssim(denoised, clean)
            psnr_list.append(psnr_val)
            ssim_list.append(ssim_val)

mean_psnr = np.mean(psnr_list)
mean_ssim = np.mean(ssim_list)

print(f"Test PSNR: {mean_psnr:.2f} dB")
print(f"Test SSIM: {mean_ssim:.4f}")

Test PSNR: 29.77 dB
Test SSIM: 0.9104
