In [1]:
import torch
import torch.nn as nn
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from pathlib import Path
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2ycbcr, ycbcr2rgb
import time
import pandas as pd
from tqdm import tqdm

from classic_algos.lanczos import SR_lanczos
from classic_algos.bicubic_interpolation import SR_bicubic
from src.models import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
DATA_DIR = Path.home() / '.data'

TEST_DIR = DATA_DIR / 'UCMerced_LandUse_Split' / 'test'
VAL_DIR = DATA_DIR / 'UCMerced_LandUse_Split' / 'val'
MODEL_DIR = DATA_DIR

In [3]:
def PSNR(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    return 10 * np.log10(1. / (mse + 1e-20))

def ERGAS(img_gt, img_sr, scale):
    img_gt = img_gt.astype(np.float64)
    img_sr = img_sr.astype(np.float64)
    diff = img_gt - img_sr
    rmse_per_band = np.sqrt(np.mean(diff ** 2, axis=(0, 1)))
    mean_per_band = np.mean(img_gt, axis=(0, 1))
    mean_per_band[mean_per_band == 0] = 1e-10

    sum_sq = np.sum((rmse_per_band / mean_per_band) ** 2)
    return 100 * (1 / scale) * np.sqrt((1 / img_gt.shape[2]) * sum_sq)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [10]:
def get_model_prediction(lr_rgb, model_name, model, scale):
    H, W, C = lr_rgb.shape
    target_h, target_w = H * scale, W * scale

    mode = 'y' if model_name.startswith('y_') else 'rgb'

    with torch.no_grad():
        if mode == 'rgb':
            # (H,W,C) -> (1,C,H,W)
            lr_tensor = torch.from_numpy(np.transpose(lr_rgb, (2, 0, 1))).unsqueeze(0).float().to(device)
            sr_tensor = model(lr_tensor)

            # (1,C,H,W) -> (H,W,C)
            sr_rgb = sr_tensor.squeeze().cpu().numpy()
            sr_rgb = np.transpose(sr_rgb, (1, 2, 0))
            return np.clip(sr_rgb, 0, 1)

        elif mode == 'y':
            lr_uint8 = (np.clip(lr_rgb, 0, 1) * 255).astype(np.uint8)
            ycbcr = rgb2ycbcr(lr_uint8)
            lr_y = ycbcr[:, :, 0].astype(np.float32) / 255.0
            lr_cb = ycbcr[:, :, 1]
            lr_cr = ycbcr[:, :, 2]

            # (1, 1, H, W)
            lr_y_tensor = torch.from_numpy(lr_y).unsqueeze(0).unsqueeze(0).float().to(device)
            sr_y_tensor = model(lr_y_tensor)
            sr_y = sr_y_tensor.squeeze().squeeze().cpu().numpy() # (H_new, W_new)

            # add colors with bicubic
            lr_cb_exp = lr_cb[:, :, np.newaxis]
            lr_cr_exp = lr_cr[:, :, np.newaxis]

            sr_cb = SR_bicubic(lr_cb_exp, target_h, target_w, preserve_range=True, output_dtype=np.float32)
            sr_cr = SR_bicubic(lr_cr_exp, target_h, target_w, preserve_range=True, output_dtype=np.float32)

            sr_cb = sr_cb[:, :, 0]
            sr_cr = sr_cr[:, :, 0]

            sr_y = sr_y * 255.0

            sr_ycbcr = np.stack([sr_y, sr_cb, sr_cr], axis=2)
            sr_rgb = ycbcr2rgb(sr_ycbcr)

            return np.clip(sr_rgb, 0, 1)

In [5]:
models = {}

y_fsrcnn = FSRCNN_Y().to(device)
y_fsrcnn.load_state_dict(torch.load(MODEL_DIR / 'y_fsrcnn_best_model.pth', map_location=device))
models['y_fsrcnn'] = y_fsrcnn

y_ressr = ResSR(num_channels=1).to(device)
y_ressr.load_state_dict(torch.load(MODEL_DIR / 'y_ressr_best_model.pth', map_location=device))
models['y_ressr'] = y_ressr

rgb_ressr = ResSR(num_channels=3).to(device)
rgb_ressr.load_state_dict(torch.load(MODEL_DIR / 'rgb_ressr_best_model.pth', map_location=device))
models['rgb_ressr'] = rgb_ressr

y_rcan = RCAN(num_channels=1).to(device)
y_rcan.load_state_dict(torch.load(MODEL_DIR / 'y_rcan_best_model.pth', map_location=device))
models['y_rcan'] = y_rcan

rgb_rcan = RCAN(num_channels=3).to(device)
rgb_rcan.load_state_dict(torch.load(MODEL_DIR / 'rgb_rcan_best_model.pth', map_location=device))
models['rgb_rcan'] = rgb_rcan

In [12]:
results = []
file_paths = list(TEST_DIR.rglob("*.tif"))
SCALE = 2
for img_path in tqdm(file_paths):

    with rasterio.open(img_path) as src:
        hr = src.read()
    # (C, H, W) -> (H, W, C), float32 [0..1]
    hr = np.transpose(hr, (1, 2, 0)).astype(np.float32) / 255.0

    H, W, _ = hr.shape
    H_new = H - (H % SCALE)
    W_new = W - (W % SCALE)
    hr = hr[:H_new, :W_new, :]

    lr_h, lr_w = H_new // SCALE, W_new // SCALE
    lr = SR_bicubic(hr, lr_h, lr_w, preserve_range=True, output_dtype=np.float32)

    methods_list = [
        ('Bicubic', None),
        ('Lanczos', None)
    ]

    for name in models.keys():
        methods_list.append((name, models[name]))

    for method_name, model in methods_list:
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        if method_name == 'Bicubic':
            sr = SR_bicubic(lr, H_new, W_new, preserve_range=True, output_dtype=np.float32)
        elif method_name == 'Lanczos':
            sr = SR_lanczos(lr, H_new, W_new, preserve_range=True, output_dtype=np.float32)
        else:
            sr = get_model_prediction(lr, method_name, model, SCALE)

        torch.cuda.synchronize()
        end_time = time.perf_counter()
        inference_time = (end_time - start_time) * 1000


        sr_y = rgb2ycbcr(sr)[:, :, 0] / 255.0
        hr_y = rgb2ycbcr(hr)[:, :, 0] / 255.0

        #delete borders
        shave = SCALE
        hr_y = hr_y[shave:-shave, shave:-shave]
        sr_y = sr_y[shave:-shave, shave:-shave]

        p = PSNR(hr_y, sr_y)
        s = ssim(hr_y, sr_y, data_range=1.0)
        e = ERGAS(hr, sr, SCALE)
        fps = 1.0 / (end_time - start_time + 1e-10)

        results.append({
            'File': img_path.name,
            'Method': method_name,
            'PSNR': p,
            'SSIM': s,
            'ERGAS': e,
            'Time_ms': inference_time,
            'FPS': fps
        })

df = pd.DataFrame(results)

print(df.head())

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  return np.where(x == 0, 1.0, np.sin(np.pi * x) / (np.pi * x))
100%|██████████| 210/210 [00:31<00:00,  6.57it/s]

                 File     Method       PSNR      SSIM     ERGAS    Time_ms  \
0  agricultural42.tif    Bicubic  37.351933  0.963715  1.348210  14.059264   
1  agricultural42.tif    Lanczos  37.697346  0.967239  1.299914  13.651941   
2  agricultural42.tif   y_fsrcnn  36.808662  0.959945  1.425582  23.009431   
3  agricultural42.tif    y_ressr  38.364960  0.970497  1.207859  77.065407   
4  agricultural42.tif  rgb_ressr  38.314854  0.970332  1.218050  93.717263   

         FPS  
0  71.127478  
1  73.249657  
2  43.460440  
3  12.975991  
4  10.670393  





In [9]:
print(df)

                         File     Method       PSNR      SSIM      ERGAS  \
0          agricultural42.tif    Bicubic -10.778872  0.948717   1.348210   
1          agricultural42.tif    Lanczos -10.433456  0.953994   1.299914   
2          agricultural42.tif   y_fsrcnn -43.151047  0.000002  50.414881   
3          agricultural42.tif    y_ressr -43.151047  0.000002  50.414881   
4          agricultural42.tif  rgb_ressr  -9.815949  0.957222   1.218050   
...                       ...        ...        ...       ...        ...   
1465  sparseresidential38.tif   y_fsrcnn -40.966789  0.000048  51.970500   
1466  sparseresidential38.tif    y_ressr -40.966789  0.000048  51.970500   
1467  sparseresidential38.tif  rgb_ressr -18.775761  0.720565   4.122629   
1468  sparseresidential38.tif     y_rcan -40.966789  0.000048  51.970500   
1469  sparseresidential38.tif   rgb_rcan -18.741323  0.720620   4.105002   

        Time_ms         FPS  
0     12.522886   79.853797  
1      9.413458  106.230887