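In this notebook we compute a series of evaluation metrics (MSE, RMSE, an SSIM-based error and their average) for the scaling and root-scaling models on the 211, 193, 171 and 094 channels.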

In [8]:
import logging

import pandas as pd
import numpy as np
import math
from typing import Optional, List, Tuple
import torch

from sdo.viz.compare_models import load_pred_and_gt
from sdo.metrics.ssim_metric import ssim

In [28]:
def load_pred_and_gt_no_mask(results_path: str, revert_root: bool = False,
                             frac: Optional[float] = 1.0,
                             img_shape: Tuple[int, int] = (512, 512)
                             ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load ground truths and predictions from a single ``.npy`` file without
    applying any mask, optionally reverting root scaling.

    The file is expected to hold a 4-D array whose axis 2 stacks the ground
    truth (first half) on top of the prediction (second half).

    Args:
        results_path: path to a ``.npy`` file readable by ``np.load``.
        revert_root: if True, both ground truth and predictions are squared,
            undoing a square-root scaling.
        frac: fraction (not percentage) of samples to keep, taken
            sequentially from the start of axis 0. ``None`` or values
            >= 1.0 keep every sample.
        img_shape: (height, width) each sample is reshaped to; the default
            (512, 512) matches the original hard-coded size.

    Returns:
        Y_test, Y_pred: arrays of shape ``(nsamples, 1, *img_shape)``.

    Raises:
        ValueError: if a sample cannot be reshaped to ``img_shape``.
    """
    Y = np.load(results_path)
    half = Y.shape[2] // 2

    # Subsample first, so the (optional) squaring below only touches kept rows.
    # `frac is not None` honours the Optional annotation instead of raising TypeError.
    if frac is not None and frac < 1.0:
        sample_size = int(Y.shape[0] * frac)
        Y = Y[0:sample_size]

    Y_test = Y[:, :, :half, :]
    Y_pred = Y[:, :, half:, :]

    if revert_root:
        logging.info('Reverting root scaling')
        Y_test = np.power(Y_test, 2)
        Y_pred = np.power(Y_pred, 2)

    nsamples = Y_test.shape[0]
    Y_test = Y_test.reshape(nsamples, 1, *img_shape)
    Y_pred = Y_pred.reshape(nsamples, 1, *img_shape)
    return Y_test, Y_pred

In [29]:
# Root directory of the experiment outputs; every path below is appended to it.
results_path = '/fdl_sdo_data/bucket/EXPERIMENT_RESULTS/VIRTUAL_TELESCOPE'

# Channel -> list of experiments. Each experiment is a positional triple:
#   [predictions file (relative to results_path), revert_root flag, column label]
# NOTE(review): in every channel revert_root is True for the plain-labelled entry
# and False for the '_root'-labelled one — confirm this pairing is intended.
# NOTE(review): channel '094' uses a 0400-prefixed file while all others use 0600.
dict_exp = {
    '211': [
        ['/vale_exp_23/0600_vale_exp_23_test_predictions.npy', True, '211'],
        ['/vale_exp_20/0600_vale_exp_20_test_predictions.npy', False, '211_root'],
    ],
    '193': [
        ['/vale_exp_25/0600_vale_exp_25_test_predictions.npy', True, '193'],
        ['/vale_exp_13bis/0600_vale_exp_13bis_test_predictions.npy', False, '193_root'],
    ],
    '171': [
        ['/vale_exp_26/0600_vale_exp_26_test_predictions.npy', True, '171'],
        ['/vale_exp_14bis/0600_vale_exp_14bis_test_predictions.npy', False, '171_root'],
    ],
    '094': [
        ['/vale_exp_27/0400_vale_exp_27_test_predictions.npy', True, '094'],
        ['/vale_exp_18/0600_vale_exp_18_test_predictions.npy', False, '094_root'],
    ],
}

# Fraction of test samples to load: keep 1.0 for final results,
# use a smaller value for quick test runs.
frac = 1.0

In [30]:
# Compute MSE, RMSE, |1 - ssim| and their combined score for every experiment,
# on the unmasked images returned by load_pred_and_gt_no_mask.
# Values are collected per column label and the DataFrame is built once at the
# end, instead of growing an empty DataFrame column by column.
metrics_by_label = {}
for key, experiments in dict_exp.items():
    print(f'Channel {key}')
    for exp in experiments:
        exp_path, revert_root, label = exp
        print(f'Experiment {exp_path}')
        Y_test, Y_pred = load_pred_and_gt_no_mask(results_path + exp_path,
                                                  revert_root=revert_root, frac=frac)
        mse = np.square(np.subtract(Y_test, Y_pred)).mean()
        rmse = math.sqrt(mse)
        # ssim() works on image tensors, hence the conversion from numpy.
        t_Y_pred = torch.from_numpy(Y_pred)
        t_Y_test = torch.from_numpy(Y_test)
        # NOTE(review): this is |1 - ssim(pred, gt)|, presumably a dissimilarity
        # (lower is better), not SSIM itself — it is reported under the 'SSIM' row.
        val_ssim = torch.abs(1 - ssim(t_Y_pred, t_Y_test))
        ssim_error = val_ssim.item()
        # NOTE(review): averages RMSE (data units) with a unitless score, so
        # channels with large pixel values (e.g. '094') are dominated by RMSE.
        pm = (rmse + ssim_error) / 2
        # Row order: MSE, RMSE, |1 - ssim|, combined score.
        metrics_by_label[label] = [mse, rmse, ssim_error, pm]

df_results = pd.DataFrame(metrics_by_label)

Channel 211
Experiment /vale_exp_23/0600_vale_exp_23_test_predictions.npy
Experiment /vale_exp_20/0600_vale_exp_20_test_predictions.npy
Channel 193
Experiment /vale_exp_25/0600_vale_exp_25_test_predictions.npy
Experiment /vale_exp_13bis/0600_vale_exp_13bis_test_predictions.npy
Channel 171
Experiment /vale_exp_26/0600_vale_exp_26_test_predictions.npy
Experiment /vale_exp_14bis/0600_vale_exp_14bis_test_predictions.npy
Channel 094
Experiment /vale_exp_27/0400_vale_exp_27_test_predictions.npy
Experiment /vale_exp_18/0600_vale_exp_18_test_predictions.npy


In [31]:
# Row labels, in the order the metrics loop stores each experiment's values
# ([mse, rmse, |1 - ssim|, combined]).
# NOTE(review): the 'SSIM' row holds |1 - ssim(pred, gt)| rather than SSIM itself,
# and the last row averages RMSE with that value — consider relabelling.
metric_labels = ['MSE', 'RMSE', 'SSIM', '(RMSE + SSIM)/2']
df_results.index = metric_labels
df_results

Unnamed: 0,211,211_root,193,193_root,171,171_root,094,094_root
MSE,0.002648,0.002311,0.000394,0.000382,0.001008,0.00067,25.048731,36.042522
RMSE,0.051461,0.048074,0.019843,0.019545,0.031757,0.025893,5.004871,6.003542
SSIM,0.040844,0.046189,0.022866,0.024522,0.030636,0.034892,0.114447,0.138455
(RMSE + SSIM)/2,0.046152,0.047131,0.021355,0.022033,0.031196,0.030392,2.559659,3.070999
