In [9]:
import __init__

from collections import defaultdict
from functools import partial
from typing import Callable, Literal

import numpy as np
from cv2.typing import MatLike
from eval import cv2_to_pil, load_data, pil_to_cv2
from eval_with_ground import evaluate_image_pair_from_pil
from eval_without_ground import evaluate_image
from tqdm import tqdm

from dehaze import dehaze_and_enhance
from gan import gan_method

In [6]:
paired_samples = load_data()
# eval() below indexes paired_samples by these exact (case-sensitive) keys; fail fast here
# rather than deep inside an evaluation run if load_data() stops providing one of them.
assert {"lol_dataset", "Dark_Face", "LOL-v2"} <= set(paired_samples), list(paired_samples)

100%|██████████| 7200/7200 [01:11<00:00, 100.95it/s]


In [None]:
def eval(
    dataset: Literal["lol_dataset", "Dark_Face", "LOL-v2"],
    method: Callable[[MatLike], MatLike],
):
    """Run an enhancement ``method`` over every sample of ``dataset`` and collect metric scores.

    Each sample's ``"low_image"`` is converted with ``pil_to_cv2``, passed through ``method``,
    and converted back with ``cv2_to_pil``. If the sample also has a ``"high_image"``, the
    enhanced image is scored against it with ``evaluate_image_pair_from_pil``. Otherwise it is
    scored with ``evaluate_image`` against the original low image. Samples whose
    ``"low_image"`` is ``None`` are skipped.

    Note: this name shadows the builtin ``eval`` inside the notebook. It is kept because the
    cells below call it by this name.

    Args:
        dataset: Key into ``paired_samples`` (as produced by ``load_data()``). Keys are
            case-sensitive, e.g. ``"LOL-v2"``.
        method: Enhancement function taking and returning an OpenCV image (``MatLike``).

    Returns:
        ``defaultdict(list)`` mapping each metric name returned by the scoring function to
        the list of its per-image scores.

    Raises:
        KeyError: If ``dataset`` is not a key of ``paired_samples``.
    """
    if dataset not in paired_samples:
        # Report the available keys; a bare KeyError hides casing mismatches like 'LOL-V2'.
        raise KeyError(f"unknown dataset {dataset!r}; available: {list(paired_samples)}")
    samples = paired_samples[dataset]
    results = defaultdict(list)
    for sample in tqdm(samples, desc=dataset):
        low_image = sample["low_image"]
        high_image = sample.get("high_image")
        if low_image is None:
            continue  # nothing to enhance for this sample
        enhanced_img = cv2_to_pil(method(pil_to_cv2(low_image)))
        if high_image is not None:
            # Ground truth available: full-reference metrics against the high image.
            scores = evaluate_image_pair_from_pil(enhanced_img, high_image)
        else:
            # No ground truth: score the enhanced image relative to its low-light input.
            scores = evaluate_image(enhanced_img, low_image)
        for metric_name, score in scores.items():
            results[metric_name].append(score)
    return results

In [5]:
# Show which dataset keys load_data() produced; these are the valid `dataset` arguments to eval().
paired_samples.keys()

dict_keys(['lol_dataset', 'Dark_Face', 'LOL-v2'])

### dehaze

In [6]:
# dehaze_and_enhance on LOL (reference images available -> full-reference metrics).
result = eval(dataset='lol_dataset', method=dehaze_and_enhance)

100%|██████████| 500/500 [00:15<00:00, 32.61it/s]


In [7]:
# Summarize each collected metric as mean ± standard deviation over all images.
for metric_name, values in result.items():
    mean, std = np.mean(values), np.std(values)
    print(f"{metric_name}: {mean:.4f} ± {std:.4f}")

MSE: 3876.0293 ± 3862.0581
PSNR: 14.2037 ± 4.2630
SSIM: 0.5057 ± 0.1516


In [8]:
# dehaze_and_enhance on Dark_Face (no high_image references -> no-reference metrics).
result = eval(dataset='Dark_Face', method=dehaze_and_enhance)

100%|██████████| 6000/6000 [05:32<00:00, 18.03it/s]


In [9]:
# Summarize each collected metric as mean ± standard deviation over all images.
for metric_name, values in result.items():
    mean, std = np.mean(values), np.std(values)
    print(f"{metric_name}: {mean:.4f} ± {std:.4f}")

SIMPLE_SCORE: 0.9235 ± 6.6518
CEI: 1.8999 ± 0.4300


In [10]:
# dehaze_and_enhance on LOL-v2 (reference images available -> full-reference metrics).
result = eval(dataset='LOL-v2', method=dehaze_and_enhance)

100%|██████████| 100/100 [00:02<00:00, 42.75it/s]


In [11]:
# Summarize each collected metric as mean ± standard deviation over all images.
for metric_name, values in result.items():
    mean, std = np.mean(values), np.std(values)
    print(f"{metric_name}: {mean:.4f} ± {std:.4f}")

MSE: 2107.1069 ± 1780.1074
PSNR: 16.2354 ± 3.3389
SSIM: 0.5460 ± 0.1234


In [None]:
# NOTE(review): this call currently fails with KeyError: 'lol' before the first sample
# completes. Presumably the lookup is inside gan_method — confirm its dataset-key mapping.
dataset = 'lol_dataset'
# Reset first: if the call below raises, the summary cell must not silently report the
# previous (dehaze) run's metrics that would otherwise still be held in `result`.
result = None
result = eval(dataset, partial(gan_method, dataset))

  0%|          | 0/500 [00:00<?, ?it/s]


KeyError: 'lol'

In [None]:
# Summarize each collected metric as mean ± standard deviation over all images.
for metric_name, values in result.items():
    mean, std = np.mean(values), np.std(values)
    print(f"{metric_name}: {mean:.4f} ± {std:.4f}")