In [2]:
import torch                               
import torch.nn as nn                      
import torch.optim as optim                
from torch.utils.data import DataLoader   
import torchvision                  
import torchvision.transforms as transforms  
import numpy as np                       
import cv2                                
from skimage.metrics import structural_similarity as ssim  
import matplotlib.pyplot as plt         
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def calculate_psnr(denoised, ground_truth):
    mse = np.mean((denoised - ground_truth) ** 2)
    if mse == 0:
        return float('inf')
    PIXEL_MAX = 1.0 
    psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse))
    return psnr
from skimage.metrics import structural_similarity as ssim


def calculate_ssim(denoised, ground_truth):
    return ssim(ground_truth, denoised, data_range=ground_truth.max() - ground_truth.min(), win_size=7, channel_axis=-1)


import torch
import torch.nn as nn
from torch.fft import fft2, ifft2

# Define a single FNet Block that performs FFT-based global feature mixing
class FNetBlock(nn.Module):
    def __init__(self, channels, height, width):
        super(FNetBlock, self).__init__()
        
        # LayerNorm applied over full [C, H, W] dimensions
        self.layer_norm = nn.LayerNorm([channels, height, width])
        
        # Feedforward network with 1x1 convolutions for feature projection
        self.ffn = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),  # Project + mix channel-wise features
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=1)   # Final refinement
        )

    def forward(self, x):
        # Apply 2D FFT to each channel to capture global frequency features
        x_fft = torch.fft.fft2(x, dim=(-2, -1))
        
        # Convert back to spatial domain (real part only)
        x_ifft = torch.fft.ifft2(x_fft, dim=(-2, -1)).real
        
        # Normalize the result to stabilize feature values
        x_norm = self.layer_norm(x_ifft)
        
        # Pass through FFN and add residual connection (like ResNet)
        x_out = x_norm + self.ffn(x_norm)
        
        return x_out


# Full DnCNN-based denoising model with FNet Blocks replacing standard convolutions
class FnetDnCNNResidual(nn.Module):
    def __init__(self, channels=3, num_features=64, num_fnet_blocks=10, height=32, width=32):
        super(FnetDnCNNResidual, self).__init__()

        # Head: basic conv layer to extract initial features from input
        self.head = nn.Sequential(
            nn.Conv2d(channels, num_features, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True)
        )
        
        # FNetBlocks stack: performs global mixing and normalization
        self.fnet_blocks = nn.Sequential(
            *[FNetBlock(num_features, height, width) for _ in range(num_fnet_blocks)]
        )

        # Tail: reduce feature maps back to original channel dimension (for noise prediction)
        self.tail = nn.Conv2d(num_features, channels, kernel_size=3, padding=1, bias=False)
    
    def forward(self, x):
        # Extract features from input image
        features = self.head(x)

        # Pass features through FNet-based global attention blocks
        features = self.fnet_blocks(features)

        # Predict the residual noise in the image
        predicted_noise = self.tail(features)
    
        # Subtract predicted noise from input to get denoised output
        denoised = x - predicted_noise

        return denoised
    
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# model = DnCNN_FFT(channels=3).to(device)
model = FnetDnCNNResidual(channels=3, num_features=64, num_fnet_blocks=3, height=32, width=32).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 100
noise_std = 0.1  

print("Starting Training with Early Stopping...")
model.train()

best_loss = float('inf')
patience = 2  
patience_counter = 0
best_model_state = None

for epoch in range(num_epochs):
    epoch_loss = 0
    start_time = time.time()

    for data, _ in train_loader:
        data = data.to(device)
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise

        output = model(noisy_data)
        loss = criterion(output, data)
        epoch_loss += loss.item() * data.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss /= len(train_dataset)
    elapsed = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}, Time: {elapsed:.2f} sec")

    # Early stopping check
    if epoch_loss < best_loss - 1e-6:  
        best_loss = epoch_loss
        patience_counter = 0
        best_model_state = model.state_dict() 
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")


if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")
    torch.save(model.state_dict(), 'fft_model.pth')
    print("Best model saved as 'best_model.pth'.")

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Starting Training with Early Stopping...
Epoch [1/100], Loss: 0.007151, Time: 33.77 sec
Epoch [2/100], Loss: 0.002474, Time: 32.64 sec
Epoch [3/100], Loss: 0.002279, Time: 32.61 sec
Epoch [4/100], Loss: 0.002157, Time: 32.63 sec
Epoch [5/100], Loss: 0.002050, Time: 32.52 sec
Epoch [6/100], Loss: 0.001981, Time: 32.43 sec
Epoch [7/100], Loss: 0.001924, Time: 32.44 sec
Epoch [8/100], Loss: 0.001867, Time: 32.42 sec
Epoch [9/100], Loss: 0.001821, Time: 32.46 sec
Epoch [10/100], Loss: 0.001787, Time: 32.43 sec
Epoch [11/100], Loss: 0.001759, Time: 32.43 sec
Epoch [12/100], Loss: 0.001744, Time: 32.49 sec
Epoch [13/100], Loss: 0.001722, Time: 32.53 sec
Epoch [14/100], Loss: 0.001710, Time: 32.67 sec
Epoch [15/100], Loss: 0.001697, Time: 32.62 sec
Epoch [16/100], Loss: 0.001682, Time: 32.79 sec
Epoch [17/100], Loss: 0.001673, Time: 32.68 sec
Epoch [18/100], Loss: 0.001663, Time: 32.88 sec
Epoch [19

In [3]:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class ImageFolderNoClass(Dataset):
    def __init__(self, folder_path, transform=None):
        self.file_paths = [os.path.join(folder_path, f) 
                           for f in os.listdir(folder_path) 
                           if f.lower().endswith(('png', 'jpg', 'jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img = Image.open(self.file_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, 0

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
])

# train_dataset = ImageFolderNoClass('./BSD500/train', transform=transform)
# val_dataset   = ImageFolderNoClass('./BSD500/val', transform=transform)
test_dataset  = ImageFolderNoClass('./BSD68/BSD68', transform=transform)

batch_size = 32
# train_loader = DataLoader(train_dataset+val_dataset, batch_size=batch_size, shuffle=True)
# val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


model.eval()
psnr_list = []
ssim_list = []



with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        
        # Move tensors to CPU and convert to numpy arrays, clipping values into [0,1]
        output_np = output.cpu().numpy().transpose(0, 2, 3, 1)   # (N, H, W, C)
        clean_np  = data.cpu().numpy().transpose(0, 2, 3, 1)
        noisy_np  = noisy_data.cpu().numpy().transpose(0, 2, 3, 1)
        
        # Calculate metrics image by image
        for denoised, clean in zip(output_np, clean_np):
            denoised = np.clip(denoised, 0., 1.)
            clean = np.clip(clean, 0., 1.)
            psnr_val = calculate_psnr(denoised, clean)
            ssim_val = calculate_ssim(denoised, clean)
            psnr_list.append(psnr_val)
            ssim_list.append(ssim_val)

mean_psnr = np.mean(psnr_list)
mean_ssim = np.mean(ssim_list)

print(f"Test PSNR: {mean_psnr:.2f} dB")
print(f"Test SSIM: {mean_ssim:.4f}")

Test PSNR: 29.04 dB
Test SSIM: 0.8924


In [4]:
from pytorch_wavelets import DWT, IDWT

class WaveletBlock(nn.Module):
    """
    Predicts the noise residual in the wavelet domain.
    Uses DWT (Discrete Wavelet Transform) to analyze image in frequency + spatial domain.
    """
    def __init__(self, channels, wavelet='haar'):
        super().__init__()
        
        # Forward and inverse wavelet transforms (1-level decomposition)
        self.dwt = DWT(J=1, wave=wavelet)
        self.idwt = IDWT(wave=wavelet)

        # Feed-forward network to process wavelet features (all 4 subbands)
        self.ffn = nn.Sequential(
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0),  # Project + mix across channels
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * channels, 4 * channels, kernel_size=1, padding=0)   # Final projection
        )

        # Learnable soft-thresholding parameters for LH, HL, HH bands
        self.threshold = nn.Parameter(torch.zeros(3, channels, 1, 1))  # Shape: [3 subbands, C, 1, 1]

    def forward(self, x):
        # Apply Discrete Wavelet Transform
        ll, yh = self.dwt(x)           # ll: low-frequency approx, yh: list of high-freq details

        detail = yh[0]                 # Extract the first-level detail subbands
        lh, hl, hh = torch.unbind(detail, dim=2)  # Separate into LH, HL, HH components

        # Concatenate all 4 bands (LL, LH, HL, HH) for processing
        stacked = torch.cat([ll, lh, hl, hh], dim=1)

        # Pass through feedforward projection network
        y = self.ffn(stacked)

        c = x.size(1)  # Number of channels

        # Split the processed features back into subbands
        ll2, lh2, hl2, hh2 = torch.split(y, c, dim=1)

        # Apply learnable soft-thresholding to high-frequency bands
        t = torch.sigmoid(self.threshold)  # Thresholds in [0, 1] range
        lh2 = lh2 * t[0]
        hl2 = hl2 * t[1]
        hh2 = hh2 * t[2]

        # Stack high-frequency bands back into required shape for IDWT: [B, C, 3, H/2, W/2]
        y_high = torch.stack([lh2, hl2, hh2], dim=2)

        # Apply inverse DWT to get back to spatial domain (predicted residual noise)
        out = self.idwt((ll2, [y_high]))
        return out  # Final residual noise prediction from wavelet block


class HybridDnCNN(nn.Module):
    """
    Combines standard DnCNN residual prediction with a wavelet-based residual branch.
    Input: noisy image x
    Output: denoised image
    """
    def __init__(self, channels=3, num_layers=17, features=64, wavelet='haar'):
        super(HybridDnCNN, self).__init__()
        # Standard DnCNN branch
        layers = [
            nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True)
        ]
        for _ in range(num_layers - 2):
            layers += [
                nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
            ]
        layers.append(nn.Conv2d(features, channels, kernel_size=3, padding=1, bias=False))
        self.dncnn = nn.Sequential(*layers)

        # Wavelet residual branch
        self.wavelet_block = WaveletBlock(channels, wavelet)

    def forward(self, x):
        noise_dn = self.dncnn(x)
        noise_wave = self.wavelet_block(x)
        noise_combined = noise_dn + noise_wave
        clean = x - noise_combined
        return clean
    

model = HybridDnCNN(channels=3).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 100
noise_std = 0.1 

print("Starting Training...")
model.train()  
best_loss = float('inf')
patience = 2  
patience_counter = 0
best_model_state = None


for epoch in range(num_epochs):
    epoch_loss = 0
    start_time = time.time()

    for data,_ in train_loader:
        data = data.to(device)  
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        loss = criterion(output, data)
        epoch_loss += loss.item() * data.size(0)

        optimizer.zero_grad()  
        loss.backward()     
        optimizer.step()       

    epoch_loss /= len(train_dataset)
    elapsed = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}, Time: {elapsed:.2f} sec")

    # Early stopping check
    if epoch_loss < best_loss - 1e-6:  
        best_loss = epoch_loss
        patience_counter = 0
        best_model_state = model.state_dict() 
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")


if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model with lowest validation loss.")
    torch.save(model.state_dict(), 'wavelet_model.pth')
    print("Best model saved as 'best_model.pth'.")


Starting Training...
Epoch [1/100], Loss: 0.009476, Time: 72.08 sec
Epoch [2/100], Loss: 0.003776, Time: 70.72 sec
Epoch [3/100], Loss: 0.002577, Time: 70.76 sec
Epoch [4/100], Loss: 0.002205, Time: 70.74 sec
Epoch [5/100], Loss: 0.001924, Time: 70.65 sec
Epoch [6/100], Loss: 0.001778, Time: 70.68 sec
Epoch [7/100], Loss: 0.001698, Time: 70.79 sec
Epoch [8/100], Loss: 0.001671, Time: 70.69 sec
Epoch [9/100], Loss: 0.001619, Time: 70.70 sec
Epoch [10/100], Loss: 0.001599, Time: 70.74 sec
Epoch [11/100], Loss: 0.001570, Time: 70.68 sec
Epoch [12/100], Loss: 0.001560, Time: 70.62 sec
Epoch [13/100], Loss: 0.001603, Time: 70.75 sec
No improvement. Patience: 1/2
Epoch [14/100], Loss: 0.001537, Time: 70.76 sec
Epoch [15/100], Loss: 0.001518, Time: 70.66 sec
Epoch [16/100], Loss: 0.001522, Time: 70.74 sec
No improvement. Patience: 1/2
Epoch [17/100], Loss: 0.001507, Time: 70.76 sec
Epoch [18/100], Loss: 0.001536, Time: 70.74 sec
No improvement. Patience: 1/2
Epoch [19/100], Loss: 0.001518, Ti

In [5]:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class ImageFolderNoClass(Dataset):
    def __init__(self, folder_path, transform=None):
        self.file_paths = [os.path.join(folder_path, f) 
                           for f in os.listdir(folder_path) 
                           if f.lower().endswith(('png', 'jpg', 'jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img = Image.open(self.file_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, 0

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
])

# train_dataset = ImageFolderNoClass('./BSD500/train', transform=transform)
# val_dataset   = ImageFolderNoClass('./BSD500/val', transform=transform)
test_dataset  = ImageFolderNoClass('./BSD68/BSD68', transform=transform)

batch_size = 32
# train_loader = DataLoader(train_dataset+val_dataset, batch_size=batch_size, shuffle=True)
# val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


model.eval()
psnr_list = []
ssim_list = []



with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        noise = torch.randn_like(data) * noise_std
        noisy_data = data + noise
        output = model(noisy_data)
        
        # Move tensors to CPU and convert to numpy arrays, clipping values into [0,1]
        output_np = output.cpu().numpy().transpose(0, 2, 3, 1)   # (N, H, W, C)
        clean_np  = data.cpu().numpy().transpose(0, 2, 3, 1)
        noisy_np  = noisy_data.cpu().numpy().transpose(0, 2, 3, 1)
        
        # Calculate metrics image by image
        for denoised, clean in zip(output_np, clean_np):
            denoised = np.clip(denoised, 0., 1.)
            clean = np.clip(clean, 0., 1.)
            psnr_val = calculate_psnr(denoised, clean)
            ssim_val = calculate_ssim(denoised, clean)
            psnr_list.append(psnr_val)
            ssim_list.append(ssim_val)

mean_psnr = np.mean(psnr_list)
mean_ssim = np.mean(ssim_list)

print(f"Test PSNR: {mean_psnr:.2f} dB")
print(f"Test SSIM: {mean_ssim:.4f}")

Test PSNR: 29.05 dB
Test SSIM: 0.9008
