In [1]:
import os
import sys

from PIL import Image
import torch

sys.path.append("../")
from src import create_hf_val_dataset, Evaluator, AutoNetwork, InterpolatedNetwork

In [2]:
# Locations, relative to the kernel's working directory, of the validation
# split and of the per-run log directories later loaded via
# AutoNetwork.from_pretrained.
val_data_dir, model_log_dir = "../data/val", "../model_logs"

In [3]:
# Build the validation dataset from the validation directory. Later cells read
# its "lq_image_path" and "gt_image_path" columns (presumably paths to the
# low-quality inputs and ground-truth targets; TODO confirm in src).
hf_val_dataset = create_hf_val_dataset(val_data_dir)

In [4]:
# Decode every validation image up front and release its file handle.
# `Image.open` is lazy: the file stays open until the pixel data is loaded or
# the image is closed. A plain list of `Image.open(...)` results therefore pins
# one open file descriptor per image for as long as the kernel lives.
# Copying inside a `with` block forces the data into memory and closes the file.
# NOTE(review): the copies keep mode, size and info, but `Image.copy()` does
# not carry over the `format` attribute. Confirm the Evaluator does not rely
# on it.
# `lq_imges` (sic) keeps its original spelling because later cells use it.
def _load_image(image_path):
    """Return a fully decoded, file-independent copy of the image at `image_path`."""
    with Image.open(image_path) as image:
        return image.copy()


lq_imges = [_load_image(lq_image_path) for lq_image_path in hf_val_dataset["lq_image_path"]]
gt_images = [_load_image(gt_image_path) for gt_image_path in hf_val_dataset["gt_image_path"]]

In [5]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# PSNR-based Model

In [6]:
# Training run to restore for the PSNR-based model. The suffix looks like a
# YYMMDDHHMMSS timestamp (2024-04-20 18:24:55). TODO confirm against the
# training launcher.
psnr_based_model_restore_version = "train_240420182455"
psnr_based_model_log_dir = os.path.join(model_log_dir, psnr_based_model_restore_version)

In [7]:
# Restore the PSNR-based network from its training log directory.
# NOTE(review): this cell's output shows a warning raised inside torch's
# `meshgrid` while loading. It is not raised by notebook code.
psnr_based_model = AutoNetwork.from_pretrained(psnr_based_model_log_dir)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [8]:
# Wrap the PSNR-based network in an Evaluator bound to the selected device.
psnr_based_model_evaluator = Evaluator(psnr_based_model, device)

In [9]:
# Evaluate the PSNR-based model on the validation images. The result is read
# by key ('psnr') in the next cell. `lq_imges` (sic) is the input image list
# built in the loading cell.
# NOTE(review): this cell does not disable autograd itself. Confirm that
# Evaluator.evaluate uses torch.no_grad()/inference_mode internally.
psnr_based_model_metrics = psnr_based_model_evaluator.evaluate(lq_imges, gt_images)

Evaluating: 400it [00:30, 13.04it/s]


In [10]:
# Report the PSNR-based model's validation PSNR to two decimal places.
print("PSNR of PSNR-based model: " + format(psnr_based_model_metrics["psnr"], ".2f"))

PSNR of PSNR-based model: 26.55


# GAN-based Model

In [20]:
# Training run to restore for the GAN-based model. The suffix looks like a
# YYMMDDHHMMSS timestamp (2024-04-22 16:08:48). TODO confirm against the
# training launcher.
gan_based_model_restore_version = "train_240422160848"
gan_based_model_log_dir = os.path.join(model_log_dir, gan_based_model_restore_version)

In [21]:
# Restore the GAN-based network from its training log directory.
gan_based_model = AutoNetwork.from_pretrained(gan_based_model_log_dir)

In [22]:
# Wrap the GAN-based network in an Evaluator on the same device as the
# PSNR-based one.
gan_based_model_evaluator = Evaluator(gan_based_model, device)

In [23]:
# Evaluate the GAN-based model on the same validation images, so its PSNR is
# directly comparable. `lq_imges` (sic) is the input image list built in the
# loading cell.
gan_based_model_metrics = gan_based_model_evaluator.evaluate(lq_imges, gt_images)

Evaluating: 400it [00:27, 14.74it/s]


In [24]:
# Report the GAN-based model's validation PSNR to two decimal places.
print("PSNR of GAN-based model: " + format(gan_based_model_metrics["psnr"], ".2f"))

PSNR of GAN-based model: 25.26


# Interpolated Model

In [25]:
# Blend weight handed to InterpolatedNetwork, named here so a retune or sweep
# touches one line. The direction of the weight is defined in src: presumably
# 0 reproduces one of the two source networks and 1 the other. TODO confirm
# which side 0.25 favours before reporting it.
interpolation_lambda = 0.25
interpolated_model = InterpolatedNetwork(psnr_based_model, gan_based_model, lambda_val=interpolation_lambda)

In [26]:
# Evaluator for the interpolated network, on the same device as the other two.
interpolated_model_evaluator = Evaluator(interpolated_model, device)

In [27]:
# Evaluate the interpolated model on the same validation images as the two
# source models, so all three PSNR figures are comparable. `lq_imges` (sic) is
# the input image list built in the loading cell.
interpolated_model_metrics = interpolated_model_evaluator.evaluate(lq_imges, gt_images)

Evaluating: 400it [00:28, 14.11it/s]


In [28]:
# Report the interpolated model's validation PSNR to two decimal places.
print("PSNR of Interpolated model: " + format(interpolated_model_metrics["psnr"], ".2f"))

PSNR of Interpolated model: 25.60
