In [1]:
import os
import numpy as np
import cv2
import torch
from torch import nn
import matplotlib.pyplot as plt

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# DMCNN model definition (original, without padding)
class DMCNN(nn.Module):
    def __init__(self):
        super(DMCNN, self).__init__()
        
        self.feature_layer = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=9, padding=0),
            nn.ReLU()
        )
        
        self.mapping_layer = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1),
            nn.ReLU()
        )
        
        self.reconstruction_layer = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=5),
            nn.ReLU()
        )
    
    def forward(self, x):
        out = self.feature_layer(x)
        out = self.mapping_layer(out)
        out = self.reconstruction_layer(out)
        return out

def calculate_psnr(img1, img2, max_val=1.0):
    """Calculate PSNR between two images"""
    mse = np.mean((img1 - img2) ** 2)
    return 20 * np.log10(max_val / np.sqrt(mse))

def process_image(model, bayer_img, patch_size=33, overlap=6):
    """Process image using patch-based approach"""
    model.eval()
    
    h, w = bayer_img.shape
    output = np.zeros((h, w, 3), dtype=np.float32)
    weights = np.zeros((h, w, 3), dtype=np.float32)
    
    # Initial RGB estimate for full image
    bayer_uint8 = (bayer_img * 255).astype(np.uint8)
    initial_rgb = cv2.cvtColor(bayer_uint8, cv2.COLOR_BAYER_BG2RGB_EA)
    initial_rgb = initial_rgb.astype(np.float32) / 255.0
    
    stride = patch_size - overlap
    target_size = 21  # Output size after convolutions
    margin = (patch_size - target_size) // 2
    
    with torch.no_grad():
        for y in range(0, h - patch_size + 1, stride):
            for x in range(0, w - patch_size + 1, stride):
                # Extract patch from initial RGB estimate
                patch = initial_rgb[y:y+patch_size, x:x+patch_size]
                
                # Convert to tensor
                patch_tensor = torch.from_numpy(patch).float().permute(2, 0, 1).unsqueeze(0)
                patch_tensor = patch_tensor.to(device)
                
                # Process patch
                output_patch = model(patch_tensor)
                output_patch = output_patch.squeeze().cpu().numpy()
                output_patch = output_patch.transpose(1, 2, 0)
                
                # Calculate output position
                out_y = y + margin
                out_x = x + margin
                
                # Create weight mask (gaussian falloff)
                weight_mask = np.ones((target_size, target_size, 1))
                if overlap > 0:
                    for i in range(overlap):
                        weight = np.exp(-((i - overlap/2)**2) / (2*(overlap/4)**2))
                        weight_mask[i, :] *= weight
                        weight_mask[-(i+1), :] *= weight
                        weight_mask[:, i] *= weight
                        weight_mask[:, -(i+1)] *= weight
                
                # Add to output
                output[out_y:out_y+target_size, out_x:out_x+target_size] += output_patch * weight_mask
                weights[out_y:out_y+target_size, out_x:out_x+target_size] += weight_mask
    
    # Handle borders
    mask = (weights != 0)
    output[mask] /= weights[mask]
    output[~mask] = initial_rgb[~mask]
    
    return output, initial_rgb

# Initialize model and load weights
model = DMCNN().to(device)
checkpoint = torch.load('../best_demosaic_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Set up paths
input_dir = '../dataset/kodak/input'
gt_dir = '../dataset/kodak/groundtruth'

# Get list of images
input_images = sorted([f for f in os.listdir(input_dir) if f.endswith('.png')])
total_images = len(input_images)

print(f"\nProcessing {total_images} images from Kodak dataset...")

# Process each image
results = []
for idx, img_file in enumerate(input_images):
    print(f"Processing image {idx+1}/{total_images}: {img_file}")
    
    # Load input image
    input_path = os.path.join(input_dir, img_file)
    input_img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    input_img = input_img.astype(np.float32) / 255.0
    
    # Load ground truth image
    gt_path = os.path.join(gt_dir, img_file)
    gt_img = cv2.imread(gt_path)
    gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGR2RGB)
    gt_img = gt_img.astype(np.float32) / 255.0
    
    # Process with patches
    output_img, initial_rgb = process_image(model, input_img)
    
    # Calculate PSNR
    psnr_dmcnn = calculate_psnr(gt_img, output_img)
    psnr_initial = calculate_psnr(gt_img, initial_rgb)
    
    results.append({
        'image': img_file,
        'dmcnn_psnr': psnr_dmcnn,
        'initial_psnr': psnr_initial
    })
    
    # Print current image results
    print(f"DMCNN PSNR: {psnr_dmcnn:.2f} dB")
    print(f"Initial PSNR: {psnr_initial:.2f} dB")
    print(f"Improvement: {psnr_dmcnn - psnr_initial:.2f} dB")
    print("-" * 50)
    
    # Visualize results
    plt.figure(figsize=(15, 5))
    
    plt.subplot(131)
    plt.imshow(initial_rgb)
    plt.title(f'Bilinear\nPSNR: {psnr_initial:.2f} dB')
    plt.axis('off')
    
    plt.subplot(132)
    plt.imshow(np.clip(output_img, 0, 1))
    plt.title(f'DMCNN\nPSNR: {psnr_dmcnn:.2f} dB')
    plt.axis('off')
    
    plt.subplot(133)
    plt.imshow(gt_img)
    plt.title('Ground Truth')
    plt.axis('off')
    
    plt.suptitle(f'Image: {img_file}')
    plt.tight_layout()
    plt.show()

# Print final results
print("\nResults Summary:")
print("-" * 50)
avg_dmcnn_psnr = np.mean([r['dmcnn_psnr'] for r in results])
avg_initial_psnr = np.mean([r['initial_psnr'] for r in results])
print(f"Average DMCNN PSNR: {avg_dmcnn_psnr:.2f} dB")
print(f"Average Initial PSNR: {avg_initial_psnr:.2f} dB")
print(f"Average Improvement: {avg_dmcnn_psnr - avg_initial_psnr:.2f} dB")

Using device: cuda


  checkpoint = torch.load('../best_demosaic_model.pth', map_location=device)


RuntimeError: Error(s) in loading state_dict for DMCNN:
	Missing key(s) in state_dict: "feature_layer.0.weight", "feature_layer.0.bias", "mapping_layer.0.weight", "mapping_layer.0.bias", "reconstruction_layer.0.weight", "reconstruction_layer.0.bias". 
	Unexpected key(s) in state_dict: "feature_extraction.weight", "feature_extraction.bias", "nonlinear_mapping.weight", "nonlinear_mapping.bias", "reconstruction.weight", "reconstruction.bias". 