In [1]:
import xarray as xr
import numpy as np
import torch
import pandas as pd

from ens_transformer.data_module import IFSERADataModule
from ens_transformer.measures import crps_loss, WeightedScore

In [2]:
def _crps_score(prediction, target):
    """CRPS computed by `crps_loss` from the first two prediction entries and the target."""
    # `prediction` is indexed as a pair; the scoring cell passes (mean, stddev).
    pred_mean, pred_stddev = prediction[0], prediction[1]
    return crps_loss(pred_mean, pred_stddev, target)


def _squared_error(prediction, target):
    """Element-wise squared error between the first prediction entry and the target."""
    residual = prediction[0] - target
    return residual.pow(2)


def _predicted_variance(prediction, target):
    """Element-wise square of the second prediction entry; `target` is unused."""
    return prediction[1].pow(2)


# Every scoring function is wrapped in `WeightedScore`; keys become the
# column names of the score table.
metrics = {
    'crps': WeightedScore(_crps_score),
    'mse': WeightedScore(_squared_error),
    'var': WeightedScore(_predicted_variance),
}

In [3]:
# Data module providing the IFS / ERA5 datasets used below; memory pinning is
# switched off.
data_module = IFSERADataModule(pin_memory=False)

In [4]:
# Presumably populates `ds_train` / `ds_test`, which later cells read from
# `data_module` -- confirm against IFSERADataModule.
data_module.setup()

In [5]:
# Maps the experiment label shown in the score table to the NetCDF file that
# holds the stored prediction. Insertion order determines the row order of the
# resulting table.
prediction_paths = {
    # Files from the `subsampling` directory; the label carries the trailing
    # number of the file name.
    'Transformer (10)': '../data/processed/prediction/subsampling/transformer_1_10.nc',
    'Transformer (20)': '../data/processed/prediction/subsampling/transformer_1_20.nc',
    'Transformer (50)': '../data/processed/prediction/subsampling/transformer_1_50.nc',
    # Files from the `baseline_scale` directory; the label carries the first
    # number of the file name.
    'PPNN (0)': '../data/processed/prediction/baseline_scale/ppnn_0_20.nc',
    'PPNN (1)': '../data/processed/prediction/baseline_scale/ppnn_1_20.nc',
    'PPNN (5)': '../data/processed/prediction/baseline_scale/ppnn_5_20.nc',
    'Direct (1)': '../data/processed/prediction/baseline_scale/direct_1_20.nc',
    'Direct (5)': '../data/processed/prediction/baseline_scale/direct_5_20.nc',
    # Files from the `transformer_scale` directory; the label carries the first
    # number of the file name.
    'Transformer (1)': '../data/processed/prediction/transformer_scale/transformer_1_20.nc',
    'Transformer (2)': '../data/processed/prediction/transformer_scale/transformer_2_20.nc',
    'Transformer (3)': '../data/processed/prediction/transformer_scale/transformer_3_20.nc',
    'Transformer (4)': '../data/processed/prediction/transformer_scale/transformer_4_20.nc',
    'Transformer (5)': '../data/processed/prediction/transformer_scale/transformer_5_20.nc',
}

In [6]:
# Open each stored prediction file; the dictionary keys are the experiment
# labels taken from `prediction_paths`.
xr_pred = {
    label: xr.open_dataset(nc_path)
    for label, nc_path in prediction_paths.items()
}

In [7]:
# Raw-ensemble baseline built from the `t2m` variable of the IFS test data:
# ensemble mean and sample standard deviation (ddof=1) over the `ensemble`
# dimension.
ifs_t2m = data_module.ds_test.ifs.sel(var_name='t2m')
xr_pred['IFS-EPS raw'] = xr.Dataset({
    # NOTE(review): the 273.15 offset looks like a Kelvin -> degree Celsius
    # conversion to match the target's units -- confirm. The stddev needs no
    # offset because a constant shift does not change the spread.
    'mean': ifs_t2m.mean('ensemble') - 273.15,
    'stddev': ifs_t2m.std('ensemble', ddof=1),
})

In [8]:
# Climatological baseline: time mean and sample standard deviation (ddof=1) of
# the ERA5 field, with a leading `time` dimension of length one re-inserted.
# NOTE(review): `ds_train.dataset` suggests `ds_train` wraps a larger dataset;
# if so, these statistics cover that whole underlying dataset and not only the
# training split -- confirm.
era5_train = data_module.ds_train.dataset.era5
xr_pred['Climatology'] = xr.Dataset({
    'mean': era5_train.mean('time').expand_dims('time', axis=0),
    'stddev': era5_train.std('time', ddof=1).expand_dims('time', axis=0),
})

In [9]:
def _as_writable_tensor(array_like):
    """Wrap a NumPy-compatible array in a tensor backed by writable memory."""
    array = np.asarray(array_like)
    if not array.flags.writeable:
        # `torch.from_numpy` shares memory with its input and warns on
        # read-only arrays; copy only in that case so writable inputs stay
        # zero-copy.
        array = array.copy()
    return torch.from_numpy(array)


def estimate_score(
        xr_prediction: "xr.Dataset",
        xr_target: "xr.DataArray",
        metric_funcs=None,
):
    """
    Evaluate every metric for one prediction and reduce each to a scalar.

    Parameters
    ----------
    xr_prediction : xr.Dataset
        Prediction holding the variables ``'mean'`` and ``'stddev'``.
    xr_target : xr.DataArray
        Target field the prediction is scored against.
    metric_funcs : dict, optional
        Mapping from metric name to a callable ``f(prediction, target)`` that
        returns a tensor, where ``prediction`` is the ``(mean, stddev)`` tensor
        pair. Defaults to the notebook-level ``metrics`` dictionary.

    Returns
    -------
    dict
        Metric name -> Python float, the mean over all elements of the tensor
        returned by the metric.
    """
    if metric_funcs is None:
        # Backward-compatible default: fall back to the notebook-level registry.
        metric_funcs = metrics
    prediction = (
        _as_writable_tensor(xr_prediction['mean'].values),
        _as_writable_tensor(xr_prediction['stddev'].values),
    )
    target = _as_writable_tensor(xr_target.values)
    return {
        metric_name: metric_func(prediction, target).mean().item()
        for metric_name, metric_func in metric_funcs.items()
    }

In [10]:
# Score every prediction (stored runs plus the two baselines) against the ERA5
# field of the test dataset.
era5_test = data_module.ds_test.era5
scores = {
    label: estimate_score(prediction_ds, era5_test)
    for label, prediction_ds in xr_pred.items()
}

  torch.from_numpy(xr_prediction['mean'].values),


In [11]:
# One row per experiment, one column per metric, plus derived columns:
#   rmse   - square root of the averaged squared error
#   spread - square root of the averaged predicted variance
#   ratio  - variance over MSE, i.e. the square of spread / rmse
pd_scores = pd.DataFrame(scores).T.assign(
    rmse=lambda table: np.sqrt(table['mse']),
    spread=lambda table: np.sqrt(table['var']),
    ratio=lambda table: table['var'] / table['mse'],
)

In [12]:
# Display the complete score table (one row per experiment).
pd_scores

Unnamed: 0,crps,mse,var,rmse,spread,ratio
Transformer (10),0.416692,0.819092,0.829157,0.905037,0.910581,1.012288
Transformer (20),0.422254,0.842562,0.812923,0.917912,0.901622,0.964823
Transformer (50),0.424164,0.85028,0.796596,0.922106,0.892522,0.936863
PPNN (0),0.439223,0.921808,0.760306,0.960108,0.871955,0.824799
PPNN (1),0.431505,0.898492,0.753128,0.947888,0.867829,0.838213
PPNN (5),0.418676,0.861397,0.764614,0.928115,0.874422,0.887644
Direct (1),0.445394,0.909814,0.49326,0.953842,0.702325,0.542155
Direct (5),0.446664,0.915982,0.495137,0.957069,0.70366,0.540554
Transformer (1),0.421108,0.835182,0.832842,0.913883,0.912602,0.997198
Transformer (2),0.420886,0.835741,0.837196,0.914189,0.914984,1.001741


In [13]:
# LaTeX source of the table restricted to the reported columns, rounded to two
# decimals.
reported_columns = ['crps', 'rmse', 'spread']
pd_scores[reported_columns].round(2).to_latex()

'\\begin{tabular}{lrrr}\n\\toprule\n{} &  crps &  rmse &  spread \\\\\n\\midrule\nTransformer (10) &  0.42 &  0.91 &    0.91 \\\\\nTransformer (20) &  0.42 &  0.92 &    0.90 \\\\\nTransformer (50) &  0.42 &  0.92 &    0.89 \\\\\nPPNN (0)         &  0.44 &  0.96 &    0.87 \\\\\nPPNN (1)         &  0.43 &  0.95 &    0.87 \\\\\nPPNN (5)         &  0.42 &  0.93 &    0.87 \\\\\nDirect (1)       &  0.45 &  0.95 &    0.70 \\\\\nDirect (5)       &  0.45 &  0.96 &    0.70 \\\\\nTransformer (1)  &  0.42 &  0.91 &    0.91 \\\\\nTransformer (2)  &  0.42 &  0.91 &    0.91 \\\\\nTransformer (3)  &  0.42 &  0.91 &    0.92 \\\\\nTransformer (4)  &  0.42 &  0.91 &    0.92 \\\\\nTransformer (5)  &  0.41 &  0.90 &    0.90 \\\\\nIFS-EPS raw      &  0.52 &  1.12 &    0.73 \\\\\nClimatology      &  2.60 &  6.12 &    6.05 \\\\\n\\bottomrule\n\\end{tabular}\n'