In [1]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Define the model architecture
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder
        self.conv1 = DoubleConv(1, 32)
        self.conv2 = DoubleConv(32, 64)
        self.conv3 = DoubleConv(64, 128)
        
        # Decoder
        self.up_conv2 = DoubleConv(192, 64)  # 128 + 64 channels
        self.up_conv1 = DoubleConv(96, 32)   # 64 + 32 channels
        
        # Final convolution
        self.final_conv = nn.Conv2d(32, 3, kernel_size=1)
        
        # Pooling and upsampling
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
    def forward(self, x):
        # Encoder
        conv1 = self.conv1(x)
        x = self.pool(conv1)
        
        conv2 = self.conv2(x)
        x = self.pool(conv2)
        
        # Bridge
        x = self.conv3(x)
        
        # Decoder
        x = self.up(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.up_conv2(x)
        
        x = self.up(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.up_conv1(x)
        
        # Final convolution
        x = self.final_conv(x)
        return x

In [4]:
def calculate_psnr(img1, img2):
    """Calculate PSNR between two images."""
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(1.0 / np.sqrt(mse))
    
def calculate_cpsnr(img1, img2):
    """Calculate Color PSNR (average of R,G,B channels) and individual channel PSNRs."""
    psnr_r = calculate_psnr(img1[:,:,0], img2[:,:,0])
    psnr_g = calculate_psnr(img1[:,:,1], img2[:,:,1])
    psnr_b = calculate_psnr(img1[:,:,2], img2[:,:,2])
    return psnr_r, psnr_g, psnr_b

In [5]:
# Initialize model and load weights
model = UNet().to(device)
checkpoint = torch.load('best_unet_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Dataset path and test split file
dataset_path = 'dataset/MSR-Demosaicing/MSR-Demosaicing/Dataset_LINEAR_without_noise/bayer_panasonic'
test_split_file = os.path.join(dataset_path, 'test.txt')


  checkpoint = torch.load('best_unet_model.pth')


In [6]:
# Dataset path and test split file
dataset_path = 'dataset/MSR-Demosaicing/MSR-Demosaicing/Dataset_LINEAR_without_noise/bayer_panasonic'
test_split_file = os.path.join(dataset_path, 'test.txt')

# Read all test images
with open(test_split_file, 'r') as f:
    test_images = [line.strip() + '.png' for line in f.readlines()]

# Calculate PSNR for all test images
r_psnr_values = []
g_psnr_values = []
b_psnr_values = []
print("\nCalculating PSNR for all test images...")

for i, img_file in enumerate(test_images):
    print(f"Processing image {i+1}/{len(test_images)}", end='\r')
    
    # Load ground truth image
    gt_path = os.path.join(dataset_path, 'groundtruth', img_file)
    gt_img = cv2.imread(gt_path)
    gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGR2RGB)
    gt_img = gt_img.astype(np.float32) / 255.0
    
    # Load and normalize Bayer image
    input_path = os.path.join(dataset_path, 'input', img_file)
    bayer_img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    bayer_img = bayer_img.astype(np.float32) / 65535.0
    
    # Process with model
    with torch.no_grad():
        input_tensor = torch.from_numpy(bayer_img).float().unsqueeze(0).unsqueeze(0).to(device)
        output = model(input_tensor)
        output_img = output[0].cpu().numpy().transpose(1, 2, 0)
        output_img = np.clip(output_img, 0, 1)
    
    # Calculate PSNR for each channel
    r_psnr, g_psnr, b_psnr = calculate_cpsnr(gt_img, output_img)
    r_psnr_values.append(r_psnr)
    g_psnr_values.append(g_psnr)
    b_psnr_values.append(b_psnr)

print("\nPSNR calculation complete!")

# Calculate statistics
mean_r_psnr = np.mean(r_psnr_values)
mean_g_psnr = np.mean(g_psnr_values)
mean_b_psnr = np.mean(b_psnr_values)
mean_cpsnr = (mean_r_psnr + mean_g_psnr + mean_b_psnr) / 3.0

std_r_psnr = np.std(r_psnr_values)
std_g_psnr = np.std(g_psnr_values)
std_b_psnr = np.std(b_psnr_values)
std_cpsnr = np.std([r_psnr_values, g_psnr_values, b_psnr_values])

print(f"\nTest Set Results:")
print(f"R channel - Mean: {mean_r_psnr:.2f} dB, Std: {std_r_psnr:.2f} dB")
print(f"G channel - Mean: {mean_g_psnr:.2f} dB, Std: {std_g_psnr:.2f} dB")
print(f"B channel - Mean: {mean_b_psnr:.2f} dB, Std: {std_b_psnr:.2f} dB")
print(f"CPSNR     - Mean: {mean_cpsnr:.2f} dB, Std: {std_cpsnr:.2f} dB")


Calculating PSNR for all test images...
Processing image 200/200
PSNR calculation complete!

Test Set Results:
R channel - Mean: 37.01 dB, Std: 4.45 dB
G channel - Mean: 41.59 dB, Std: 4.01 dB
B channel - Mean: 37.97 dB, Std: 4.37 dB
CPSNR     - Mean: 38.86 dB, Std: 4.72 dB


In [7]:
# Save first 4 results
os.makedirs('results_unet', exist_ok=True)
for i in range(4):
    img_file = test_images[i]
    
    # Load and normalize Bayer image
    input_path = os.path.join(dataset_path, 'input', img_file)
    bayer_img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    bayer_img = bayer_img.astype(np.float32) / 65535.0
    
    # Process with model
    with torch.no_grad():
        input_tensor = torch.from_numpy(bayer_img).float().unsqueeze(0).unsqueeze(0).to(device)
        output = model(input_tensor)
        output_img = output[0].cpu().numpy().transpose(1, 2, 0)
        output_img = np.clip(output_img, 0, 1)
    
    # Save output image
    output_img = (output_img * 255).astype(np.uint8)
    output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(f'results_unet/{img_file}', output_img)

print("\nOutput images saved in 'results_unet' directory.")


Output images saved in 'results_unet' directory.
