In [1]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

In [2]:
# Define the UNet architecture exactly as in your code
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

In [3]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder
        self.conv1 = DoubleConv(1, 32)
        self.conv2 = DoubleConv(32, 64)
        self.conv3 = DoubleConv(64, 128)
        
        # Decoder
        self.up_conv2 = DoubleConv(192, 64)  # 128 + 64 channels
        self.up_conv1 = DoubleConv(96, 32)   # 64 + 32 channels
        
        # Final convolution
        self.final_conv = nn.Conv2d(32, 3, kernel_size=1)
        
        # Pooling and upsampling
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
    def forward(self, x):
        # Encoder
        conv1 = self.conv1(x)           # 32x32 -> 32x32
        x = self.pool(conv1)            # 32x32 -> 16x16
        
        conv2 = self.conv2(x)           # 16x16 -> 16x16
        x = self.pool(conv2)            # 16x16 -> 8x8
        
        # Bridge
        x = self.conv3(x)               # 8x8 -> 8x8
        
        # Decoder
        x = self.up(x)                  # 8x8 -> 16x16
        x = torch.cat([x, conv2], dim=1)
        x = self.up_conv2(x)            # 16x16 -> 16x16
        
        x = self.up(x)                  # 16x16 -> 32x32
        x = torch.cat([x, conv1], dim=1)
        x = self.up_conv1(x)            # 32x32 -> 32x32
        
        # Final convolution
        x = self.final_conv(x)          # 32x32 -> 32x32
        return x

In [4]:
def process_kodak_image(model, input_path, gt_path, device, patch_size=32, overlap=2):
    """
    Process a single pre-Bayered Kodak image through the UNet model
    Args:
        model: UNet model
        input_path: Path to the input Bayer image
        gt_path: Path to the ground truth image
        device: torch device
        patch_size: Size of patches to process
        overlap: Overlap between patches
    """
    # Load pre-Bayered image
    bayer = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)  # Read raw Bayer data
    bayer = bayer.astype(np.float32) / 255.0  # 16-bit input
    h, w = bayer.shape
    
    # Create bilinear interpolation for comparison
    bayer_uint8 = (bayer * 255).astype(np.uint8)
    bilinear = cv2.cvtColor(bayer_uint8, cv2.COLOR_BAYER_BG2RGB_EA)
    bilinear = bilinear.astype(np.float32) / 255.0
    
    # Load ground truth
    ground_truth = cv2.imread(gt_path)
    ground_truth = cv2.cvtColor(ground_truth, cv2.COLOR_BGR2RGB)
    ground_truth = ground_truth.astype(np.float32) / 255.0
    
    # Prepare output arrays
    output = np.zeros((h, w, 3), dtype=np.float32)
    weights = np.zeros((h, w, 3), dtype=np.float32)
    
    stride = patch_size - overlap
    model.eval()
    
    with torch.no_grad():
        # Calculate positions for patches
        y_positions = list(range(0, h - patch_size + 1, stride))
        x_positions = list(range(0, w - patch_size + 1, stride))
        
        # Add final positions if needed
        if h - patch_size not in y_positions:
            y_positions.append(h - patch_size)
        if w - patch_size not in x_positions:
            x_positions.append(w - patch_size)
        
        for y in y_positions:
            for x in x_positions:
                # Extract patch
                patch = bayer[y:y+patch_size, x:x+patch_size]
                patch_tensor = torch.from_numpy(patch).float().unsqueeze(0).unsqueeze(0)
                patch_tensor = patch_tensor.to(device)
                
                # Process patch
                output_patch = model(patch_tensor)
                output_patch = output_patch.squeeze().cpu().numpy()
                output_patch = output_patch.transpose(1, 2, 0)
                
                # Create weight mask (gaussian falloff)
                weight_mask = np.ones((patch_size, patch_size, 1))
                if overlap > 0:
                    for i in range(overlap):
                        weight = np.exp(-((i - overlap/2)**2) / (2*(overlap/4)**2))
                        weight_mask[i, :] *= weight
                        weight_mask[-(i + 1), :] *= weight
                        weight_mask[:, i] *= weight
                        weight_mask[:, -(i + 1)] *= weight
                
                # Add to output with weights
                output[y:y+patch_size, x:x+patch_size] += output_patch * weight_mask
                weights[y:y+patch_size, x:x+patch_size] += weight_mask
    
    # Average overlapping regions
    output = np.divide(output, weights, where=weights != 0)
    
    return output, bilinear, ground_truth, bayer

In [5]:
def calculate_psnr(img1, img2, max_val=1.0):
    """Calculate PSNR between two images"""
    mse = np.mean((img1 - img2) ** 2)
    return 20 * np.log10(max_val / np.sqrt(mse))

def calculate_cpsnr(img1, img2):
    """Calculate Color PSNR (average of R,G,B channels) and individual channel PSNRs."""
    psnr_r = calculate_psnr(img1[:,:,0], img2[:,:,0])
    psnr_g = calculate_psnr(img1[:,:,1], img2[:,:,1])
    psnr_b = calculate_psnr(img1[:,:,2], img2[:,:,2])
    return psnr_r, psnr_g, psnr_b

In [6]:
def process_kodak_dataset(model, device, input_dir='../dataset/kodak/input', 
                         gt_dir='../dataset/kodak/groundtruth'):
    """Process entire Kodak dataset"""

    # Calculate PSNR for all test images
    r_psnr_values = []
    g_psnr_values = []
    b_psnr_values = []

    # Making new dir to save results
    os.makedirs('result_kodak_dmcnn', exist_ok=True)
    
    # Get all input files
    image_files = [f for f in os.listdir(input_dir) if f.endswith('.png')]
    image_files.sort()
    
    
    for i, img_file in enumerate(image_files):
        print(f"\nProcessing image {i+1}/{len(image_files)}: {img_file}")
        
        # Construct full paths
        input_path = os.path.join(input_dir, img_file)
        gt_path = os.path.join(gt_dir, img_file)
        
        if not os.path.exists(gt_path):
            print(f"Warning: Ground truth not found for {img_file}")
            continue
        
        # Process image
        output, bilinear, ground_truth, bayer = process_kodak_image(
            model, input_path, gt_path, device)

        # Convert to uint8 and save
        output_img = (output * 255).astype(np.uint8)
        output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f'result_kodak_dmcnn/{img_file}', output_img)
        
        # Calculate PSNR for each channel
        r_psnr, g_psnr, b_psnr = calculate_cpsnr(output_img, ground_truth)
        r_psnr_values.append(r_psnr)
        g_psnr_values.append(g_psnr)
        b_psnr_values.append(b_psnr)
        
    # Calculate statistics
    mean_r_psnr = np.mean(r_psnr_values)
    mean_g_psnr = np.mean(g_psnr_values)
    mean_b_psnr = np.mean(b_psnr_values)
    mean_cpsnr = (mean_r_psnr + mean_g_psnr + mean_b_psnr) / 3.0
    
    std_r_psnr = np.std(r_psnr_values)
    std_g_psnr = np.std(g_psnr_values)
    std_b_psnr = np.std(b_psnr_values)
    std_cpsnr = np.std([r_psnr_values, g_psnr_values, b_psnr_values])
    
    print(f"\nTest Set Results:")
    print(f"R channel - Mean: {mean_r_psnr:.2f} dB, Std: {std_r_psnr:.2f} dB")
    print(f"G channel - Mean: {mean_g_psnr:.2f} dB, Std: {std_g_psnr:.2f} dB")
    print(f"B channel - Mean: {mean_b_psnr:.2f} dB, Std: {std_b_psnr:.2f} dB")
    print(f"CPSNR     - Mean: {mean_cpsnr:.2f} dB, Std: {std_cpsnr:.2f} dB")

        
    return None


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Initialize model
model = UNet().to(device)

Using device: cuda


In [8]:
 # Load trained weights
checkpoint = torch.load('../best_unet_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch']}")

Loaded model from epoch 58


  checkpoint = torch.load('../best_unet_model.pth')


In [9]:
# Process Kodak dataset
process_kodak_dataset(model, device)


Processing image 1/24: kodim01.png

Processing image 2/24: kodim02.png

Processing image 3/24: kodim03.png

Processing image 4/24: kodim04.png

Processing image 5/24: kodim05.png

Processing image 6/24: kodim06.png

Processing image 7/24: kodim07.png

Processing image 8/24: kodim08.png

Processing image 9/24: kodim09.png

Processing image 10/24: kodim10.png

Processing image 11/24: kodim11.png

Processing image 12/24: kodim12.png

Processing image 13/24: kodim13.png

Processing image 14/24: kodim14.png

Processing image 15/24: kodim15.png

Processing image 16/24: kodim16.png

Processing image 17/24: kodim17.png

Processing image 18/24: kodim18.png

Processing image 19/24: kodim19.png

Processing image 20/24: kodim20.png

Processing image 21/24: kodim21.png

Processing image 22/24: kodim22.png

Processing image 23/24: kodim23.png

Processing image 24/24: kodim24.png

Test Set Results:
R channel - Mean: -39.86 dB, Std: 2.68 dB
G channel - Mean: -41.24 dB, Std: 2.02 dB
B channel - Mean: 