In [1]:
import os
import sys

from PIL import Image
import torch
from tqdm import tqdm

sys.path.append("../")
from src import create_hf_test_dataset, Predictor, AutoNetwork, InterpolatedNetwork

In [2]:
# Input locations: the test images and the second "test_real" image set.
test_data_dir, test_real_data_dir = "../data/test", "../data/test_real"
# Where trained-model logs are read from and where predictions are written to.
model_log_dir, results_dir = "../model_logs", "../data/results"

In [3]:
# Build the dataset object from `test_data_dir` using the project helper;
# the next cell reads its "lq_image_path" entries.
hf_test_dataset = create_hf_test_dataset(test_data_dir)

In [4]:
# Map each low-quality image's file name to its opened PIL image.
# os.path.basename is used instead of split("/") so the key is the bare file
# name regardless of the platform's path separator.
# Keys are file names only, so two paths sharing a basename would collide and
# the later one would win.
lq_imges_dict = {
    os.path.basename(lq_image_path): Image.open(lq_image_path)
    for lq_image_path in hf_test_dataset["lq_image_path"]
}

In [5]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [6]:
def generate_and_save_results(predictor, save_dir, images_dict=None):
    """Run ``predictor`` on every image and save each prediction to ``save_dir``.

    Args:
        predictor: Object exposing ``predict(image)``; the returned value must
            provide a ``save(path)`` method.
        save_dir: Directory the predictions are written to. It is created
            (including parents) if it does not exist yet.
        images_dict: Optional mapping of output file name -> input image.
            Defaults to the notebook-level ``lq_imges_dict`` so existing
            two-argument calls behave as before.
    """
    if images_dict is None:
        images_dict = lq_imges_dict

    # Saving into a missing directory would fail on the first image, so make
    # sure the target exists before starting the loop.
    os.makedirs(save_dir, exist_ok=True)

    for filename, lq_image in tqdm(images_dict.items(), desc="Generating results"):
        pred_image = predictor.predict(lq_image)
        # The prediction keeps the input's file name inside save_dir.
        pred_image.save(os.path.join(save_dir, filename))

# PSNR-based Model

In [7]:
# Name of the run to restore for the PSNR-based model (presumably a training-run
# folder name — confirm against the contents of `model_log_dir`).
psnr_based_model_restore_version = "train_240420182455"
# Log directory of that run, resolved under `model_log_dir`.
psnr_based_model_log_dir = os.path.join(model_log_dir, psnr_based_model_restore_version)

In [8]:
# Restore the PSNR-based network from its log directory.
psnr_based_model = AutoNetwork.from_pretrained(psnr_based_model_log_dir)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [9]:
# Wrap the model in the project's Predictor, passing the device string chosen
# earlier ("cuda:0" or "cpu").
psnr_based_model_predictor = Predictor(psnr_based_model, device)

In [10]:
# Output directory for this model's predictions: <results_dir>/psnr_based.
psnr_based_model_save_dir = os.path.join(results_dir, "psnr_based")

In [11]:
# Run the PSNR-based predictor over every entry of `lq_imges_dict` and save the
# outputs under `psnr_based_model_save_dir`.
generate_and_save_results(psnr_based_model_predictor, psnr_based_model_save_dir)

Generating results: 100%|██████████| 400/400 [00:43<00:00,  9.24it/s]


# GAN-based Model

In [12]:
# Name of the run to restore for the GAN-based model (presumably a training-run
# folder name — confirm against the contents of `model_log_dir`).
gan_based_model_restore_version = "train_240422160848"
# Log directory of that run, resolved under `model_log_dir`.
gan_based_model_log_dir = os.path.join(model_log_dir, gan_based_model_restore_version)

In [13]:
# Restore the GAN-based network from its log directory.
gan_based_model = AutoNetwork.from_pretrained(gan_based_model_log_dir)

In [14]:
# Wrap the model in the project's Predictor, passing the device string chosen
# earlier ("cuda:0" or "cpu").
gan_based_model_predictor = Predictor(gan_based_model, device)

In [15]:
# Output directory for this model's predictions: <results_dir>/gan_based.
gan_based_model_save_dir = os.path.join(results_dir, "gan_based")

In [16]:
# Run the GAN-based predictor over every entry of `lq_imges_dict` and save the
# outputs under `gan_based_model_save_dir`.
generate_and_save_results(gan_based_model_predictor, gan_based_model_save_dir)

Generating results: 100%|██████████| 400/400 [00:46<00:00,  8.55it/s]


# Interpolated Model

**Test Data**

In [17]:
# Combine the two restored networks into a single interpolated network.
# NOTE(review): which of the two models lambda_val=0.25 weights towards depends
# on InterpolatedNetwork's definition in `src` — confirm before tuning it.
interpolated_model = InterpolatedNetwork(psnr_based_model, gan_based_model, lambda_val=0.25)

In [18]:
# Wrap the interpolated model in the project's Predictor, passing the device
# string chosen earlier ("cuda:0" or "cpu").
interpolated_model_predictor = Predictor(interpolated_model, device)

In [19]:
# Output directory for this model's predictions: <results_dir>/interpolated.
interpolated_model_save_dir = os.path.join(results_dir, "interpolated")

In [20]:
# Run the interpolated predictor over every entry of `lq_imges_dict` and save
# the outputs under `interpolated_model_save_dir`.
generate_and_save_results(interpolated_model_predictor, interpolated_model_save_dir)

Generating results: 100%|██████████| 400/400 [00:49<00:00,  8.05it/s]


**Test Real Data**

In [21]:
# Build the dataset object from `test_real_data_dir` using the same project
# helper as for the test data; the next cell reads its "lq_image_path" entries.
hf_test_real_dataset = create_hf_test_dataset(test_real_data_dir)

In [22]:
# Map each "test_real" image's file name to its opened PIL image.
# os.path.basename is used instead of split("/") so the key is the bare file
# name regardless of the platform's path separator.
# Keys are file names only, so two paths sharing a basename would collide and
# the later one would win.
real_lq_imges_dict = {
    os.path.basename(lq_image_path): Image.open(lq_image_path)
    for lq_image_path in hf_test_real_dataset["lq_image_path"]
}

In [23]:
# Output directory for the interpolated model's predictions on the "test_real"
# images: <results_dir>/test_real_interpolated.
interpolated_model_real_save_dir = os.path.join(results_dir, "test_real_interpolated")

In [24]:
def generate_and_save_real_results(predictor, save_dir):
    """Run ``predictor`` on every ``real_lq_imges_dict`` image and save the results.

    This is the same loop as ``generate_and_save_results``, but it iterates over
    the notebook-level ``real_lq_imges_dict`` instead of ``lq_imges_dict``.

    Args:
        predictor: Object exposing ``predict(image)``; the returned value must
            provide a ``save(path)`` method.
        save_dir: Directory the predictions are written to. It is created
            (including parents) if it does not exist yet.
    """
    # Saving into a missing directory would fail on the first image, so make
    # sure the target exists before starting the loop.
    os.makedirs(save_dir, exist_ok=True)

    for filename, lq_image in tqdm(real_lq_imges_dict.items(), desc="Generating results"):
        pred_image = predictor.predict(lq_image)
        # The prediction keeps the input's file name inside save_dir.
        pred_image.save(os.path.join(save_dir, filename))

In [25]:
# Run the interpolated predictor over every entry of `real_lq_imges_dict` and
# save the outputs under `interpolated_model_real_save_dir`.
generate_and_save_real_results(interpolated_model_predictor, interpolated_model_real_save_dir)

Generating results: 100%|██████████| 6/6 [00:00<00:00,  7.92it/s]
