In [1]:
# import library
import glob

import cv2
import torch
import torch.nn as nn
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from src import consistency
from src import option
from src import sen12ms
from src import utils

from model import rcan
from model import scan
from model import srgan

In [2]:
# define setting
opt = option.Option()
# Upscaling factor of the test dataset.
# NOTE(review): load_model infers the network's scale separately from the
# checkpoint file name ('2x'/'4x'/'8x'); keep the two in agreement.
scale = 4
# Request the GPU, but fall back to the CPU when CUDA is unavailable: otherwise
# torch.load(..., map_location=device) in load_model raises on a CPU-only machine.
USE_CUDA = True
cuda = USE_CUDA and torch.cuda.is_available()
if USE_CUDA and not cuda:
    print('CUDA is not available; falling back to CPU')
device = torch.device('cuda' if cuda else 'cpu')

# create dataset
dataset = sen12ms.Dataset(
    data_dir=opt.test_dir,
    scale=scale,
    load_highs2=False,
    mean=opt.sen12ms_mean,
    std=opt.sen12ms_std,
)

In [3]:
# function of loading model
def load_model(timestamp, index, n_colors):
    """Find a saved checkpoint by timestamp and load it into the matching network.

    The architecture is chosen by the substring 'rcan', 'scan' or 'srgan' in the
    checkpoint path, checked in that order. For RCAN/SCAN the upscaling factor
    comes from the substring '2x', '4x' or '8x', also checked in that order.

    Args:
        timestamp: identifier passed to ``utils.find_model_file`` to search for
            checkpoint files (e.g. ``'20211218-154941_00062000'``).
        index: position, in the list returned by ``utils.find_model_file``, of
            the file to load. NOTE(review): in this notebook's saved output the
            list was ``[..._d.pth, ..._g.pth]`` and index 1 picked ``_g``.
            Confirm that ``find_model_file`` guarantees this ordering.
        n_colors: number of input channels of the network.

    Returns:
        The network with the checkpoint's weights loaded (``strict=True``),
        deserialised with ``map_location`` set to the module-level ``device``.

    Raises:
        ValueError: if no known architecture name appears in the path, or if an
            RCAN/SCAN checkpoint path contains no recognised scale.
    """
    model_path = utils.find_model_file(timestamp, verbose=True)[index]
    print('[{}] > {}'.format(index, model_path))

    if 'rcan' in model_path or 'scan' in model_path:
        # The first matching factor wins, keeping the 2x -> 4x -> 8x precedence.
        scale = next((s for s in (2, 4, 8) if '{}x'.format(s) in model_path), None)
        if scale is None:
            # Previously this silently built the network with scale=-1.
            raise ValueError('cannot infer scale (2x/4x/8x) from model path: {}'.format(model_path))
        arch = rcan.RCAN if 'rcan' in model_path else scan.SCAN
        model = arch(scale=scale, n_colors=n_colors, n_out_colors=1, act=nn.LeakyReLU(0.2, inplace=True))
    elif 'srgan' in model_path:
        # The SRGAN generator is constructed without a scale argument.
        model = srgan.Generator(in_channels=n_colors, out_channels=1)
    else:
        # Previously this fell through to `return model` and raised UnboundLocalError.
        raise ValueError('unknown architecture (expected rcan/scan/srgan) in model path: {}'.format(model_path))

    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    return model

### Evaluate models

In [4]:
# list of models to be evaluated -- one dict per checkpoint:
#   'time'     : timestamp handed to load_model to locate the checkpoint files
#   'idx'      : which of the located files to load
#   'n_colors' : number of input channels; the evaluation cells pass
#                use_optical=True exactly when this equals 4
# NOTE(review): the saved output of the inconsistency-correction cell shows
# timestamp 20211221-135216_00054750, so it was produced with a different
# list than this one. Re-run that cell to refresh it.
model_info_list = [
    dict(time='20211218-154941_00062000', idx=1, n_colors=1),
]

In [6]:
# evaluate models
# For every entry: load its checkpoint, then let utils.test_metrics report
# PSNR / SSIM over the whole dataset (verbose=True prints the table).
# NOTE(review): in the saved output the Bilinear, Bicubic and SR rows are all
# 0.0000 and CR equals Nearest, which looks like a bug in utils.test_metrics
# or stale output. Verify before trusting these numbers.
for position, entry in enumerate(model_info_list, start=1):
    banner = '\n{}/{} ({}) ----------------------------------------'.format(
        position, len(model_info_list), entry['time'])
    print(banner)

    model = load_model(timestamp=entry['time'], index=entry['idx'], n_colors=entry['n_colors'])

    # optical input is enabled exactly for 4-channel models
    _ = utils.test_metrics(
        model,
        dataset,
        shuffle=False,
        full_comparison=True,
        cuda=cuda,
        use_optical=entry['n_colors'] == 4,
        verbose=True,
    )


1/1 (20211218-154941_00062000) ----------------------------------------
Search for: /home/vit134/vit/sar_sr/model/*20211218-154941_00062000*.pth
Not found
Search for: /home/vit134/vit/sar_sr/model/*/*20211218-154941_00062000*.pth
Found:
/home/vit134/vit/sar_sr/model/full_dataset/rcan_4x_20211218-154941_00062000_d.pth
/home/vit134/vit/sar_sr/model/full_dataset/rcan_4x_20211218-154941_00062000_g.pth
[1] > /home/vit134/vit/sar_sr/model/full_dataset/rcan_4x_20211218-154941_00062000_g.pth
[1041/1041]
         PSNR    SSIM
Nearest  21.6500 0.5003
Bilinear 0.0000 0.0000
Bicubic  0.0000 0.0000
SR       0.0000 0.0000
CR       21.6500 0.5003


### Correct inconsistency

In [5]:
# define paths
# Scratch directory: utils.apply saves the generated images here and the
# correction cell reads them back from it.
im_dir = '/home/vit134/vit/sar_sr/data/tmp'
# Despite its name this is a glob pattern (passed to glob.glob), not a regex.
sr_path_regex = ''.join((im_dir, '/*_SR.png'))

In [6]:
# measure metrics of SR and corrected ims.
#
# For each model, utils.apply writes the test-set images into `im_dir`. Every
# `*_SR.png` found there is corrected with consistency.correct (mode
# 'ave_add_uniformly', using the matching `*_LR.png`) and saved as `*_CR.png`.
# Mean PSNR / SSIM against `*_HR.png` are then reported for SR and CR, together
# with how many images got better / worse after the correction.


def _sibling_path(sr_path, tag):
    """Return the path stored next to ``sr_path`` with ``tag`` (e.g. 'LR') in place of 'SR'.

    Only the trailing ``_SR.png`` suffix is swapped. The previous
    ``str.replace('SR', ...)`` also rewrote any 'SR' in the directory or file stem.
    """
    suffix = '_SR.png'
    if not sr_path.endswith(suffix):
        raise ValueError('expected a path ending in {}: {}'.format(suffix, sr_path))
    return sr_path[:-len(suffix)] + '_{}.png'.format(tag)


def _read_gray(path):
    """Read ``path`` as a single-channel grayscale image, raising if it cannot be read."""
    im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if im is None:  # cv2.imread reports failure by returning None, not by raising
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return im


def _correct_images(sr_paths):
    """Write a consistency-corrected ``*_CR.png`` next to every ``*_SR.png`` in ``sr_paths``."""
    for k, sr_path in enumerate(sr_paths, start=1):
        print('\rcorrecting inconsistency [{}/{}]'.format(k, len(sr_paths)), end='')
        sr = _read_gray(sr_path)
        lr = _read_gray(_sibling_path(sr_path, 'LR'))
        cr = consistency.correct(hr=sr, lr=lr, correction='ave_add_uniformly')
        cr_path = _sibling_path(sr_path, 'CR')
        if not cv2.imwrite(cr_path, cr):  # cv2.imwrite returns False on failure
            raise OSError('cannot write image: {}'.format(cr_path))


def _measure_metrics(sr_paths):
    """Compare the SR and CR images with HR. ``sr_paths`` must be non-empty.

    Returns a dict with:
      'sr', 'cr' : [mean PSNR, mean SSIM] over ``sr_paths``
      'count'    : [#PSNR up, #PSNR down, #SSIM up, #SSIM down] of CR relative
                   to SR (ties are counted in neither)
    """
    result = {'sr': [0., 0.], 'cr': [0., 0.], 'count': [0, 0, 0, 0], }
    n_images = len(sr_paths)
    for k, sr_path in enumerate(sr_paths, start=1):
        print('\rcalculating psnr and ssim [{}/{}]'.format(k, n_images), end='')
        sr = _read_gray(sr_path)
        cr = _read_gray(_sibling_path(sr_path, 'CR'))
        hr = _read_gray(_sibling_path(sr_path, 'HR'))

        # HR is the reference (image_true); cv2 reads these as 8-bit, hence data_range=255.
        p_sr = peak_signal_noise_ratio(hr, sr, data_range=255)
        p_cr = peak_signal_noise_ratio(hr, cr, data_range=255)
        # 2-D grayscale input has no channel axis. The old `multichannel=False`
        # was redundant and has been deprecated since scikit-image 0.19.
        s_sr = structural_similarity(hr, sr, data_range=255)
        s_cr = structural_similarity(hr, cr, data_range=255)

        result['sr'][0] += p_sr
        result['cr'][0] += p_cr
        result['sr'][1] += s_sr
        result['cr'][1] += s_cr

        if p_cr > p_sr:
            result['count'][0] += 1
        elif p_cr < p_sr:
            result['count'][1] += 1
        if s_cr > s_sr:
            result['count'][2] += 1
        elif s_cr < s_sr:
            result['count'][3] += 1

    # Average over the images actually evaluated. This previously divided by
    # len(dataset), which is wrong whenever the glob finds a different number of files.
    for key in ('sr', 'cr'):
        result[key] = [total / n_images for total in result[key]]
    return result


for i, info in enumerate(model_info_list):
    print('\n{}/{} ({}) ----------------------------------------'.format(i+1, len(model_info_list), info['time']))

    n_test = len(dataset)

    # load model
    model = load_model(
        timestamp=info['time'],
        index=info['idx'],
        n_colors=info['n_colors'],
        )

    # generate images and save to temporary directory
    print('generating images')
    utils.apply(
        model=model,
        dataset=dataset,
        n_test=n_test,
        shuffle=False,
        full_comparison=False,
        use_optical=(info['n_colors'] == 4),
        cuda=cuda,
        save_dir=im_dir,
        show_im=False,
        verbose=True,
        )

    # glob image paths to be corrected
    sr_paths = sorted(glob.glob(sr_path_regex))
    if not sr_paths:
        print('no SR images matched {}; skipping'.format(sr_path_regex))
        continue
    if len(sr_paths) != n_test:
        # Nothing in this cell clears im_dir, so files from earlier runs may be picked up.
        print('warning: found {} SR images but the dataset has {} samples; '
              'stale files in {} may be included'.format(len(sr_paths), n_test, im_dir))

    # correct average consistency
    _correct_images(sr_paths)

    # calc psnr and ssim regarding sr and cr
    result = _measure_metrics(sr_paths)

    # display result
    print('\n         PSNR    SSIM')
    print('SR       {p_sr:.4f} {s_sr:.4f}\n'
          'CR       {p_cr:.4f} {s_cr:.4f}'
          .format(p_sr=result['sr'][0], s_sr=result['sr'][1],
                  p_cr=result['cr'][0], s_cr=result['cr'][1]))

    print('psnr up: {}, down: {}'.format(result['count'][0], result['count'][1]))
    print('ssim up: {}, down: {}'.format(result['count'][2], result['count'][3]))


1/1 (20211221-135216_00054750) ----------------------------------------
Search for: /home/vit134/vit/sar_sr/model/*20211221-135216_00054750*.pth
Not found
Search for: /home/vit134/vit/sar_sr/model/*/*20211221-135216_00054750*.pth
Found:
/home/vit134/vit/sar_sr/model/full_dataset/rcan_4x_20211221-135216_00054750_d.pth
/home/vit134/vit/sar_sr/model/full_dataset/rcan_4x_20211221-135216_00054750_g.pth
[1] > /home/vit134/vit/sar_sr/model/full_dataset/rcan_4x_20211221-135216_00054750_g.pth
generating images
calculating psnr and ssim [1040/1041]
         PSNR    SSIM
SR       23.4995 0.6393
CR       23.4964 0.6387
psnr up: 18, down: 1023
ssim up: 7, down: 1034
