In [1]:
from isr.model import MODEL_REGISTRY, IsrModel
from isr.data import DataManager
from isr.metrics import Metrics
from PIL import Image

import matplotlib.pyplot as plt
import cv2

ImportError: libGL.so.1: cannot open shared object file: No such file or directory

In [None]:
# Experiment configuration: dataset location and number of CV folds,
# the model family identifier, and the directory artifacts are written to.
config = dict(
    data=dict(
        data_dir="data/SuperResolution",
        num_folds=10,
    ),
    model=dict(model_type="InterpolationModel"),
    path_to_artifacts_dir="artifacts",
)

# Short-hand handles on each section of the configuration.
data_config, model_config, path_to_artifacts_dir = (
    config["data"],
    config["model"],
    config["path_to_artifacts_dir"],
)

In [None]:
# Build the fold split and take fold 0 as the train / validation partition.
data_manager = DataManager(config=data_config)
df_train, df_valid = data_manager.get_fold(0)

# Pull the low- and high-resolution image columns out as arrays.
train_lr_images, train_hr_images = (
    df_train[column].values for column in ("lr_image", "hr_image")
)
val_lr_images, val_hr_images = (
    df_valid[column].values for column in ("lr_image", "hr_image")
)

# Show the size of each split.
df_train.shape, df_valid.shape

In [None]:
# just add this to models.py file after development
class LocalRegressionModel(IsrModel):
    """Placeholder super-resolution model that upsamples with Lanczos resampling.

    Despite the name, no local regression is performed yet: ``train`` is a
    no-op and ``predict`` resizes each low-resolution image by
    ``scale_factor`` in both dimensions using ``Image.LANCZOS``. It acts as a
    non-learned baseline until the regression logic is implemented.

    Args:
        config: Model configuration, forwarded unchanged to ``IsrModel``.
        scale_factor: Positive integer upscaling factor applied to width and
            height. Defaults to 4, the factor that was previously hard-coded.

    Raises:
        ValueError: If ``scale_factor`` is not a positive integer.
    """

    def __init__(self, config, scale_factor=4):
        # Validate before touching the base class so a bad factor fails fast.
        if not isinstance(scale_factor, int) or scale_factor < 1:
            raise ValueError(
                f"scale_factor must be a positive integer, got {scale_factor!r}"
            )
        super().__init__(config)
        # Underscore-prefixed to avoid clashing with attributes IsrModel may set.
        self._scale_factor = scale_factor

    def train(self, list_of_lr_images, list_of_hr_images):
        """No-op: the Lanczos baseline has no trainable parameters.

        Both arguments are accepted only to satisfy the ``IsrModel`` interface.
        """
        print("LocalRegressionModel needs no training")

    def predict(self, list_of_lr_images):
        """Upsample each low-resolution image with Lanczos resampling.

        Args:
            list_of_lr_images: Iterable of images exposing ``.size`` and
                ``.resize`` (PIL images, as produced by the data cell).

        Returns:
            A list of images, each ``scale_factor`` times larger than its input
            in both width and height, in the same order as the input.
        """
        factor = self._scale_factor
        return [
            lr_image.resize(
                (lr_image.size[0] * factor, lr_image.size[1] * factor),
                Image.LANCZOS,
            )
            for lr_image in list_of_lr_images
        ]

In [None]:
# Fit (a no-op for this baseline) and super-resolve the validation inputs.
# The metric report is computed once, in the following cell; computing it here
# too only duplicated that (potentially expensive) work.
local_regression_model = LocalRegressionModel(model_config)
local_regression_model.train(train_lr_images, train_hr_images)

predictions = local_regression_model.predict(val_lr_images)

In [None]:
# Score every prediction against its ground-truth high-resolution image and
# display the collected metric values (indexed per image, e.g. "mse"/"ssim").
metric_report = Metrics.generate_metric_report(
    val_hr_images,
    predictions,
)

metric_report.values()

In [None]:
# Visual sanity check: for each validation image show the bicubic-upsampled
# input, the model prediction and the ground truth side by side, with that
# image's MSE / SSIM from `metric_report` in the figure title.
UPSCALE_FACTOR = 4  # must match the factor used by the model's predict()
MAX_IMAGES_TO_PLOT = None  # set to an int to cap the number of figures drawn

plot_triples = list(zip(predictions, val_hr_images, val_lr_images))[:MAX_IMAGES_TO_PLOT]
for image_index, (prediction, gt, lr_image) in enumerate(plot_triples):
    # Naive bicubic upsampling of the raw input, shown as a reference.
    bicubic_image = lr_image.resize(
        (lr_image.size[0] * UPSCALE_FACTOR, lr_image.size[1] * UPSCALE_FACTOR),
        Image.BICUBIC,
    )

    # Explicit figure per image instead of mutating global rcParams.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (bicubic_image, "Input Image (bicubic upsampling)"),
        (prediction, "Prediction"),
        (gt, "Ground Truth"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    image_mse = metric_report["mse"][image_index]
    image_ssim = metric_report["ssim"][image_index]

    fig.suptitle(f"MSE: {image_mse:.2f}, SSIM: {image_ssim:.2f}")
    fig.tight_layout()
    plt.show()
    # Release the figure; otherwise one figure per image stays in memory.
    plt.close(fig)