In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
from pyiqa import create_metric
import model

In [2]:
def lowlight_compare(folder_path):
    # Initialize quality metrics
    piqe = create_metric('piqe')
    niqe = create_metric('niqe')
    brisque = create_metric('brisque')

    # Lists to store scores for each model
    ic_scores = {'piqe': [], 'niqe': [], 'brisque': []}

    # Loop through all images in the folder
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.bmp')):
            image_path = os.path.join(folder_path, filename)
            
            # Load the image
            data_lowlight = Image.open(image_path)
            data_lowlight = (np.asarray(data_lowlight) / 255.0)
            data_lowlight = torch.from_numpy(data_lowlight).float()
            data_lowlight = data_lowlight.permute(2, 0, 1)  # Convert to (C, H, W)
            data_lowlight = data_lowlight.cuda().unsqueeze(0)  # Add batch dimension

            # Load the Model
            IC_net = model.illumi_curve_net().cuda()
            IC_net.load_state_dict(torch.load('snapshots/model-best.pth'))
            enhanced_image_ic, _ = IC_net(data_lowlight)

            # Convert tensors to images for evaluation
            def tensor_to_image(tensor_image):
                tensor_image = tensor_image.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
                return (tensor_image * 255).astype(np.uint8)

            enhanced_image_ic_np = tensor_to_image(enhanced_image_ic)
            enhanced_image_ic_tensor = torch.tensor(enhanced_image_ic_np).permute(2, 0, 1).unsqueeze(0).float() / 255.0

            # Calculate metrics for DCE Model
            ic_scores['piqe'].append(piqe(enhanced_image_ic_tensor).item())
            ic_scores['niqe'].append(niqe(enhanced_image_ic_tensor).item())
            ic_scores['brisque'].append(brisque(enhanced_image_ic_tensor).item())

            # Save the enhanced images and intermediate images
            if filename == "14.TIF":
                tif_dce_image = enhanced_image_ic_np
                original_image = (data_lowlight.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    # Calculate average scores
    def calculate_average(scores):
        return {key: np.mean(scores[key]) for key in scores}

    avg_ic_scores = calculate_average(ic_scores)

    # Print average scores
    def print_scores(model_name, scores):
        print(f"Average Scores for {model_name} Model (lower is better):")
        print(f"PIQE: {scores['piqe']:.2f}, NIQE: {scores['niqe']:.2f}, BRISQUE: {scores['brisque']:.2f}\n")

    # Display Results
    def display_results():
        if all(img is not None for img in [tif_dce_image]):
            fig, axes = plt.subplots(2, 1, figsize=(10, 30))
            
            # Original Image
            axes[0].imshow(original_image)
            axes[0].set_title("Original Image")
            axes[0].axis("off")

            # Enhanced Image - IC-Net Model
            axes[1].imshow(tif_dce_image)
            axes[1].set_title("Enhanced (IC-Net Model)")
            axes[1].axis("off")

            plt.tight_layout()
            plt.show()

    print_scores("IC-Net", avg_ic_scores)

In [3]:
# Get Metric Values
lowlight_compare('data/test_data/PSR/')  # Replace with the folder containing your images

  IC_net.load_state_dict(torch.load('snapshots/model-best.pth'))


Average Scores for IC-Net Model (lower is better):
PIQE: 36.36, NIQE: 8.38, BRISQUE: 36.56

