In [1]:
import os
import numpy as np
import cv2
import torch
from torch import nn
import matplotlib.pyplot as plt

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [2]:
class DMCNN(nn.Module):
    def __init__(self):
        super(DMCNN, self).__init__()
        self.feature_extraction = nn.Conv2d(3, 128, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU(inplace=True)
        self.nonlinear_mapping = nn.Conv2d(128, 64, kernel_size=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.reconstruction = nn.Conv2d(64, 3, kernel_size=5, padding=2)
        
    def forward(self, x):
        features = self.relu1(self.feature_extraction(x))
        mapped = self.relu2(self.nonlinear_mapping(features))
        out = self.reconstruction(mapped)
        return out


In [3]:
def calculate_psnr(img1, img2, max_val=1.0):
    """Calculate PSNR between two images"""
    mse = np.mean((img1 - img2) ** 2)
    return 20 * np.log10(max_val / np.sqrt(mse))

def calculate_cpsnr(img1, img2):
    """Calculate Color PSNR (average of R,G,B channels) and individual channel PSNRs."""
    psnr_r = calculate_psnr(img1[:,:,0], img2[:,:,0])
    psnr_g = calculate_psnr(img1[:,:,1], img2[:,:,1])
    psnr_b = calculate_psnr(img1[:,:,2], img2[:,:,2])
    return psnr_r, psnr_g, psnr_b


In [4]:
def create_cfa_channels(bayer_img):
    """Convert single-channel Bayer image to 3-channel representation."""
    H, W = bayer_img.shape
    cfa = np.zeros((H, W, 3), dtype=bayer_img.dtype)
    
    # RGGB pattern
    cfa[0::2, 0::2, 0] = bayer_img[0::2, 0::2]  # R
    cfa[0::2, 1::2, 1] = bayer_img[0::2, 1::2]  # G
    cfa[1::2, 0::2, 1] = bayer_img[1::2, 0::2]  # G
    cfa[1::2, 1::2, 2] = bayer_img[1::2, 1::2]  # B
    
    return cfa

In [5]:
def process_image(model, bayer_img, patch_size=33, overlap=6):
    """Process image using patch-based approach"""
    model.eval()
    
    h, w = bayer_img.shape
    output = np.zeros((h, w, 3), dtype=np.float32)
    weights = np.zeros((h, w, 3), dtype=np.float32)
    
    initial_rgb = np.stack([bayer_img, bayer_img, bayer_img], axis=-1)  # Create 3 identical channels
    
    stride = patch_size - overlap
    target_size = 33  # Output size after convolutions
    margin = (patch_size - target_size) // 2
    
    with torch.no_grad():
        for y in range(0, h - patch_size + 1, stride):
            for x in range(0, w - patch_size + 1, stride):
                # Extract patch from initial RGB estimate
                patch = initial_rgb[y:y+patch_size, x:x+patch_size]
                
                # Convert to tensor
                patch_tensor = torch.from_numpy(patch).float().permute(2, 0, 1).unsqueeze(0)
                patch_tensor = patch_tensor.to(device)
                
                # Process patch
                output_patch = model(patch_tensor)
                output_patch = output_patch.squeeze().cpu().numpy()
                output_patch = output_patch.transpose(1, 2, 0)
                
                # Calculate output position
                out_y = y + margin
                out_x = x + margin
                
                # Create weight mask (gaussian falloff)
                weight_mask = np.ones((target_size, target_size, 1))
                if overlap > 0:
                    for i in range(overlap):
                        weight = np.exp(-((i - overlap/2)**2) / (2*(overlap/4)**2))
                        weight_mask[i, :] *= weight
                        weight_mask[-(i+1), :] *= weight
                        weight_mask[:, i] *= weight
                        weight_mask[:, -(i+1)] *= weight
                
                # Add to output
                output[out_y:out_y+target_size, out_x:out_x+target_size] += output_patch * weight_mask
                weights[out_y:out_y+target_size, out_x:out_x+target_size] += weight_mask
    
    # Handle borders
    mask = (weights != 0)
    output[mask] /= weights[mask]
    output[~mask] = initial_rgb[~mask]
    
    return output, initial_rgb


In [6]:
# Initialize model and load weights
model = DMCNN().to(device)
checkpoint = torch.load('../best_demosaic_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Set up paths
input_dir = '../dataset/kodak/input'
gt_dir = '../dataset/kodak/groundtruth'

# Get list of images
input_images = sorted([f for f in os.listdir(input_dir) if f.endswith('.png')])
total_images = len(input_images)

print(f"\nProcessing {total_images} images from Kodak dataset...")



Processing 24 images from Kodak dataset...


  checkpoint = torch.load('../best_demosaic_model.pth', map_location=device)


In [7]:
# Calculate PSNR for all test images
r_psnr_values = []
g_psnr_values = []
b_psnr_values = []
print("\nCalculating PSNR for all test images...")

# Making new dir to save results
os.makedirs('result_kodak_dmcnn', exist_ok=True)

# Process each image
results = []
for idx, img_file in enumerate(input_images):
    print(f"Processing image {idx+1}/{total_images}: {img_file}")
    
    # Load input image
    input_path = os.path.join(input_dir, img_file)
    input_img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    input_img = input_img.astype(np.float32) / 255.0
    
    # Load ground truth image
    gt_path = os.path.join(gt_dir, img_file)
    gt_img = cv2.imread(gt_path)
    gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGR2RGB)
    gt_img = gt_img.astype(np.float32) / 255.0

    # Convert to 3-channel input
    input_3ch = create_cfa_channels(input_img)
    
   # Process with model
    with torch.no_grad():
        input_tensor = torch.from_numpy(input_3ch).float().permute(2, 0, 1).unsqueeze(0).to(device)
        output = model(input_tensor)
        output_img = output[0].cpu().numpy().transpose(1, 2, 0)
        output_img = np.clip(output_img, 0, 1)
        
    
    # Calculate PSNR for each channel
    r_psnr, g_psnr, b_psnr = calculate_cpsnr(gt_img, output_img)
    r_psnr_values.append(r_psnr)
    g_psnr_values.append(g_psnr)
    b_psnr_values.append(b_psnr)

    # Convert to uint8 and save
    output_img = (output_img * 255).astype(np.uint8)
    output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(f'result_kodak_dmcnn/{img_file}', output_img)

print("\nPSNR calculation complete!")


Calculating PSNR for all test images...
Processing image 1/24: kodim01.png
Processing image 2/24: kodim02.png
Processing image 3/24: kodim03.png
Processing image 4/24: kodim04.png
Processing image 5/24: kodim05.png
Processing image 6/24: kodim06.png
Processing image 7/24: kodim07.png
Processing image 8/24: kodim08.png
Processing image 9/24: kodim09.png
Processing image 10/24: kodim10.png
Processing image 11/24: kodim11.png
Processing image 12/24: kodim12.png
Processing image 13/24: kodim13.png
Processing image 14/24: kodim14.png
Processing image 15/24: kodim15.png
Processing image 16/24: kodim16.png
Processing image 17/24: kodim17.png
Processing image 18/24: kodim18.png
Processing image 19/24: kodim19.png
Processing image 20/24: kodim20.png
Processing image 21/24: kodim21.png
Processing image 22/24: kodim22.png
Processing image 23/24: kodim23.png
Processing image 24/24: kodim24.png

PSNR calculation complete!


In [8]:
# Calculate statistics
mean_r_psnr = np.mean(r_psnr_values)
mean_g_psnr = np.mean(g_psnr_values)
mean_b_psnr = np.mean(b_psnr_values)
mean_cpsnr = (mean_r_psnr + mean_g_psnr + mean_b_psnr) / 3.0

std_r_psnr = np.std(r_psnr_values)
std_g_psnr = np.std(g_psnr_values)
std_b_psnr = np.std(b_psnr_values)
std_cpsnr = np.std([r_psnr_values, g_psnr_values, b_psnr_values])

print(f"\nTest Set Results:")
print(f"R channel - Mean: {mean_r_psnr:.2f} dB, Std: {std_r_psnr:.2f} dB")
print(f"G channel - Mean: {mean_g_psnr:.2f} dB, Std: {std_g_psnr:.2f} dB")
print(f"B channel - Mean: {mean_b_psnr:.2f} dB, Std: {std_b_psnr:.2f} dB")
print(f"CPSNR     - Mean: {mean_cpsnr:.2f} dB, Std: {std_cpsnr:.2f} dB")


Test Set Results:
R channel - Mean: 36.77 dB, Std: 2.17 dB
G channel - Mean: 39.86 dB, Std: 2.03 dB
B channel - Mean: 36.37 dB, Std: 2.03 dB
CPSNR     - Mean: 37.67 dB, Std: 2.60 dB
