In [2]:
import torch
import torch.nn as nn
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from pathlib import Path
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2ycbcr, ycbcr2rgb
import time
import pandas as pd
from tqdm import tqdm

from classic_algos.lanczos import SR_lanczos
from classic_algos.bicubic_interpolation import SR_bicubic
from src.models import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
DATA_DIR = Path.home() / '.data'

TEST_DIR = DATA_DIR / 'UCMerced_LandUse_Split' / 'test'
VAL_DIR = DATA_DIR / 'UCMerced_LandUse_Split' / 'val'
MODEL_DIR = DATA_DIR

In [4]:
def PSNR(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    return 10 * np.log10(1. / (mse + 1e-20))

def ERGAS(img_gt, img_sr, scale):
    img_gt = img_gt.astype(np.float64)
    img_sr = img_sr.astype(np.float64)
    diff = img_gt - img_sr
    rmse_per_band = np.sqrt(np.mean(diff ** 2, axis=(0, 1)))
    mean_per_band = np.mean(img_gt, axis=(0, 1))
    mean_per_band[mean_per_band == 0] = 1e-10

    sum_sq = np.sum((rmse_per_band / mean_per_band) ** 2)
    return 100 * (1 / scale) * np.sqrt((1 / img_gt.shape[2]) * sum_sq)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [18]:
def get_model_prediction(lr_rgb, model_name, model, scale):
    H, W, C = lr_rgb.shape
    target_h, target_w = H * scale, W * scale

    mode = 'y' if model_name.startswith('y_') else 'rgb'

    with torch.no_grad():
        if mode == 'rgb':
            # (H,W,C) -> (1,C,H,W)
            lr_tensor = torch.from_numpy(np.transpose(lr_rgb, (2, 0, 1))).unsqueeze(0).float().to(device)
            sr_tensor = model(lr_tensor)

            # (1,C,H,W) -> (H,W,C)
            sr_rgb = sr_tensor.squeeze().cpu().numpy()
            sr_rgb = np.transpose(sr_rgb, (1, 2, 0))
            return np.clip(sr_rgb, 0, 1)

        elif mode == 'y':
            # --- 1. Получаем предсказание Y от нейросети ---
            # Конвертируем вход в YCbCr
            # rgb2ycbcr(float) -> возвращает Y в диапазоне [16..235], тип float
            ycbcr = rgb2ycbcr(lr_rgb)

            # Нормализуем Y к [0, 1] для подачи в сеть (как при обучении)
            lr_y = ycbcr[:, :, 0] / 255.0

            lr_y_tensor = torch.from_numpy(lr_y).unsqueeze(0).unsqueeze(0).float().to(device)
            sr_y_tensor = model(lr_y_tensor)

            # Получаем выход [0, 1] и возвращаем к масштабу [16..235]
            sr_y = sr_y_tensor.squeeze().squeeze().cpu().numpy()
            sr_y = sr_y * 255.0

            # --- 2. Получаем каналы Cb и Cr через обычный Бикубик ---
            # Мы просто апскейлим всю RGB картинку, так как SR_bicubic точно работает корректно с RGB
            sr_bicubic_rgb = SR_bicubic(lr_rgb, target_h, target_w, preserve_range=True, output_dtype=np.float32)

            # Переводим интерполированный RGB в YCbCr
            sr_bicubic_ycbcr = rgb2ycbcr(sr_bicubic_rgb)

            # Берем Cb и Cr от бикубика. Они уже правильного размера и масштаба.
            sr_cb = sr_bicubic_ycbcr[:, :, 1]
            sr_cr = sr_bicubic_ycbcr[:, :, 2]

            # --- 3. Подменяем Y канал и собираем ---
            # Проверка размеров на всякий случай (иногда бывают ошибки округления на 1 пиксель)
            # Если размеры не совпадают, подрежем под Y (основной контент)
            h_y, w_y = sr_y.shape
            sr_cb = sr_cb[:h_y, :w_y]
            sr_cr = sr_cr[:h_y, :w_y]

            sr_ycbcr = np.stack([sr_y, sr_cb, sr_cr], axis=2)
            sr_rgb = ycbcr2rgb(sr_ycbcr)

            return np.clip(sr_rgb, 0, 1)


In [10]:
models = {}

y_fsrcnn = FSRCNN_Y().to(device)
y_fsrcnn.load_state_dict(torch.load(MODEL_DIR / 'y_fsrcnn_best_model.pth', map_location=device))
models['y_fsrcnn'] = y_fsrcnn

y_ressr = ResSR(num_channels=1).to(device)
y_ressr.load_state_dict(torch.load(MODEL_DIR / 'y_ressr_best_model.pth', map_location=device))
models['y_ressr'] = y_ressr

rgb_ressr = ResSR(num_channels=3).to(device)
rgb_ressr.load_state_dict(torch.load(MODEL_DIR / 'rgb_ressr_best_model.pth', map_location=device))
models['rgb_ressr'] = rgb_ressr

y_rcan = RCAN(num_channels=1).to(device)
y_rcan.load_state_dict(torch.load(MODEL_DIR / 'y_rcan_best_model.pth', map_location=device))
models['y_rcan'] = y_rcan

rgb_rcan = RCAN(num_channels=3).to(device)
rgb_rcan.load_state_dict(torch.load(MODEL_DIR / 'rgb_rcan_best_model.pth', map_location=device))
models['rgb_rcan'] = rgb_rcan

In [19]:
results = []
file_paths = list(TEST_DIR.rglob("*.tif"))
SCALE = 2
for img_path in tqdm(file_paths):

    with rasterio.open(img_path) as src:
        hr = src.read()
    # (C, H, W) -> (H, W, C), float32 [0..1]
    hr = np.transpose(hr, (1, 2, 0)).astype(np.float32) / 255.0

    H, W, _ = hr.shape
    H_new = H - (H % SCALE)
    W_new = W - (W % SCALE)
    hr = hr[:H_new, :W_new, :]

    lr_h, lr_w = H_new // SCALE, W_new // SCALE
    lr = SR_bicubic(hr, lr_h, lr_w, preserve_range=True, output_dtype=np.float32)

    methods_list = [
        ('Bicubic', None),
        ('Lanczos', None)
    ]

    for name in models.keys():
        methods_list.append((name, models[name]))

    for method_name, model in methods_list:
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        if method_name == 'Bicubic':
            sr = SR_bicubic(lr, H_new, W_new, preserve_range=True, output_dtype=np.float32)
        elif method_name == 'Lanczos':
            sr = SR_lanczos(lr, H_new, W_new, preserve_range=True, output_dtype=np.float32)
        else:
            sr = get_model_prediction(lr, method_name, model, SCALE)

        torch.cuda.synchronize()
        end_time = time.perf_counter()
        inference_time = (end_time - start_time) * 1000


        sr_y = rgb2ycbcr(sr)[:, :, 0] / 255.0
        hr_y = rgb2ycbcr(hr)[:, :, 0] / 255.0

        #delete borders
        shave = SCALE
        hr_y = hr_y[shave:-shave, shave:-shave]
        sr_y = sr_y[shave:-shave, shave:-shave]

        p = PSNR(hr_y, sr_y)
        s = ssim(hr_y, sr_y, data_range=1.0)
        e = ERGAS(hr, sr, SCALE)
        fps = 1.0 / (end_time - start_time + 1e-10)

        results.append({
            'File': img_path.name,
            'Method': method_name,
            'PSNR': p,
            'SSIM': s,
            'ERGAS': e,
            'Time_ms': inference_time,
            'FPS': fps
        })

df = pd.DataFrame(results)

print(df.head())

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  return np.where(x == 0, 1.0, np.sin(np.pi * x) / (np.pi * x))
100%|██████████| 210/210 [00:26<00:00,  7.82it/s]

                 File     Method       PSNR      SSIM     ERGAS     Time_ms  \
0  agricultural42.tif    Bicubic  37.437710  0.963957  1.348210   13.285537   
1  agricultural42.tif    Lanczos  37.793297  0.967473  1.299914   14.981957   
2  agricultural42.tif   y_fsrcnn  36.998672  0.960414  1.425582   21.079180   
3  agricultural42.tif    y_ressr  38.478462  0.970782  1.207859   98.582398   
4  agricultural42.tif  rgb_ressr  38.428818  0.970605  1.218050  140.808570   

         FPS  
0  75.269821  
1  66.746954  
2  47.440175  
3  10.143799  
4   7.101840  





In [20]:
summary_df = df.groupby('Method').agg({
    'PSNR': ['mean', 'std'],
    'SSIM': ['mean', 'std'],
    'ERGAS': ['mean'],
    'Time_ms': ['mean'],
    'FPS': ['mean']
}).round(4)
print(summary_df)

                PSNR            SSIM           ERGAS  Time_ms       FPS
                mean     std    mean     std    mean     mean      mean
Method                                                                 
Bicubic    31.930300  5.7136  0.8991  0.0719  4.0396   5.1581  197.5189
Lanczos    32.009499  5.8919  0.8971  0.0764  4.0596  11.2325   89.9506
rgb_rcan   32.947399  5.6816  0.9115  0.0648  3.6407  14.9819   68.7333
rgb_ressr  32.850399  5.6990  0.9099  0.0672  3.6824   8.8153  122.9658
y_fsrcnn   31.813801  5.5740  0.8960  0.0728  4.0879  12.9035   78.3352
y_rcan     32.893700  5.6885  0.9104  0.0665  3.6680  20.3129   49.8528
y_ressr    32.786598  5.7015  0.9090  0.0680  3.7106  14.4445   71.3733
