In [10]:
%load_ext autoreload
%autoreload 2

In [50]:
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
import torch

base_dir = os.path.join(os.getcwd(), '..')
sys.path.append(base_dir)

from src.evaluation import metrics
from src.evaluation.metrics import spearman_correlation

In [68]:
def weighted_mean(x, wlat):
    """Area-weighted spatial mean over the last two (lat, lon) dimensions of ``x``.

    Args:
        x: tensor whose last two dims are (lat, lon), e.g. (time, lat, lon) or a
            single (lat, lon) field.
        wlat: weights broadcastable against ``x``. Accepted forms: per-latitude
            weights of shape (lat,) or (lat, 1), or a full (lat, lon) weight grid.

    Returns:
        Tensor of shape ``x.shape[:-2]`` holding the weighted mean of each field.
    """
    if wlat.dim() == 1:
        # 1-D weights are latitude weights: add a trailing longitude axis so they
        # broadcast along lon instead of being matched against the lon dimension.
        wlat = wlat.unsqueeze(-1)
    # Normalise by the weights actually applied to each field, so the result is
    # correct whether the weights vary only with latitude or over the whole grid.
    weights = torch.broadcast_to(wlat, x.shape)
    return torch.sum(x * weights, dim=(-2, -1)) / weights.sum(dim=(-2, -1))

def compute_deterministic_metrics(prediction, groundtruth, wlat):
    """Score a point prediction against ground truth with area-weighted error metrics.

    Args:
        prediction: tensor of predicted values; the notebook passes (time, lat, lon).
        groundtruth: tensor with the same shape as ``prediction``.
        wlat: latitude weights forwarded to ``weighted_mean`` (the notebook passes
            a (lat, 1) column).

    Returns:
        dict with keys
          'mb'   - mean bias (prediction minus groundtruth),
          'rmse' - square root of the time-averaged weighted squared error,
          'mae'  - mean absolute error,
          'corr' - Spearman correlation of the flattened arrays (unweighted),
                   exactly as returned by ``spearman_correlation``.
    """
    error = prediction.sub(groundtruth)

    # Each statistic is spatially weighted per leading index (time step here),
    # then averaged over that leading dimension.
    bias = weighted_mean(error, wlat).mean()
    mean_sq_error = weighted_mean(error.square(), wlat).mean()
    mean_abs_error = weighted_mean(error.abs(), wlat).mean()

    return {
        'mb': bias.item(),
        'rmse': mean_sq_error.sqrt().item(),
        'mae': mean_abs_error.item(),
        'corr': spearman_correlation(prediction.flatten(), groundtruth.flatten()),
    }

def compute_probabilistic_metrics(prediction, groundtruth, wlat):
    """Score a probabilistic prediction with area-weighted probabilistic metrics.

    Args:
        prediction: a ``torch.distributions`` object exposing ``log_prob``, ``icdf``,
            ``mean`` and ``stddev`` (the notebook passes a ``Normal``).
        groundtruth: observed values matching the distribution's batch shape.
        wlat: latitude weights forwarded to ``weighted_mean``.

    Returns:
        dict with keys
          'll'      - mean log density of ``groundtruth`` under ``prediction``,
          'calib95' - weighted fraction of observations inside the central 95%
                      predictive interval,
          'CRPS'    - closed-form Gaussian CRPS (lower is better).
    """
    # Mean log-likelihood of the observations.
    ll = weighted_mean(prediction.log_prob(groundtruth), wlat).mean()

    # Empirical coverage of the [2.5%, 97.5%] predictive interval.
    lower = prediction.icdf(torch.tensor(0.025))
    upper = prediction.icdf(torch.tensor(0.975))
    inside = torch.logical_and(groundtruth >= lower, groundtruth <= upper)
    calib95 = weighted_mean(inside.float(), wlat).mean()

    # CRPS of N(mu, sigma^2) in closed form; only exact when `prediction` is Gaussian.
    mu = prediction.mean
    sigma = prediction.stddev
    z = (groundtruth - mu) / sigma
    std_normal = torch.distributions.Normal(0, 1)
    cdf = std_normal.cdf(z)
    pdf = std_normal.log_prob(z).exp()
    crps_field = sigma * (z * (2 * cdf - 1) + 2 * pdf - 1 / np.sqrt(np.pi))
    crps = weighted_mean(crps_field, wlat).mean()

    return {'ll': ll.item(), 'calib95': calib95.item(), 'CRPS': crps.item()}

In [52]:
# Input files: ground-truth outputs and the archived GP posterior mean / std predictions.
SSP245_OUTPUTS_PATH = '../data/outputs_ssp245.nc'
GP_POSTERIOR_MEAN_PATH = '../../archived/hackathon2021/climatebench-gp-posterior-mean-tas-test-2019-2100.nc'
GP_POSTERIOR_STD_PATH = '../../archived/hackathon2021/climatebench-gp-posterior-std-tas-test-2019-2100.nc'

# Load everything into memory and release the file handles right away, instead of
# leaving them open for the lifetime of the kernel.
with xr.open_dataset(SSP245_OUTPUTS_PATH) as ssp245_ds:
    # Select 'tas' before averaging over 'member' so the other data variables are
    # not reduced (and read) only to be discarded.
    output_ssp245 = ssp245_ds['tas'].mean('member').load()
pred_climatebench = xr.load_dataarray(GP_POSTERIOR_MEAN_PATH)
stddev_climatebench = xr.load_dataarray(GP_POSTERIOR_STD_PATH)

# Cosine-of-latitude area weights on the ground-truth grid.
wlat = np.cos(np.deg2rad(output_ssp245.lat))

In [62]:
def _to_global_mean_tensor(field):
    """Return the `wlat`-weighted mean over ('lat', 'lon') of an xarray field as a torch tensor."""
    return torch.from_numpy(field.weighted(wlat).mean(['lat', 'lon']).data)


global_gt = _to_global_mean_tensor(output_ssp245)
global_pred = _to_global_mean_tensor(pred_climatebench)
# NOTE(review): this is the area-weighted mean of the per-gridcell stddevs, which is an
# upper bound on the stddev of the global mean (equal only if errors are perfectly
# spatially correlated). An overdispersed predictive could explain calib95 == 1.0 below
# — confirm this is the intended global uncertainty.
global_stddev = _to_global_mean_tensor(stddev_climatebench)
global_dist = torch.distributions.Normal(global_pred, global_stddev)

glob_det_scores = metrics.compute_deterministic_metrics(global_pred, global_gt)
glob_prob_scores = metrics.compute_probabilistic_metrics(global_dist, global_gt)
glob_scores = dict(glob_det_scores, **glob_prob_scores)
glob_df = pd.DataFrame(data=[glob_scores])
glob_df

Unnamed: 0,mb,rmse,mae,corr,ll,calib95,CRPS,ICI
0,0.01044,0.123675,0.103984,0.972437,-0.446444,1.0,0.152658,0.385557


In [63]:
# Spatial scores are computed only over time labels >= 2080 (label-based, inclusive);
# the time coordinate is assumed to hold integer years — TODO confirm.
SPATIAL_EVAL_PERIOD = slice(2080, None)


def _to_spatial_tensor(field):
    """Select SPATIAL_EVAL_PERIOD and return the field as a (time, lat, lon) torch tensor.

    `.data` drops dimension names, so transpose explicitly. `weighted_mean` expects
    (time, lat, lon), and ground truth, mean and stddev must line up element-wise
    regardless of each file's on-disk dimension order.
    """
    window = field.sel(time=SPATIAL_EVAL_PERIOD).transpose('time', 'lat', 'lon')
    return torch.from_numpy(window.data)


spatial_gt = _to_spatial_tensor(output_ssp245)
spatial_pred = _to_spatial_tensor(pred_climatebench)
spatial_stddev = _to_spatial_tensor(stddev_climatebench)
assert spatial_gt.shape == spatial_pred.shape == spatial_stddev.shape, "spatial fields are misaligned"
spatial_dist = torch.distributions.Normal(spatial_pred, spatial_stddev)

In [69]:
# Latitude weights as a (lat, 1) column so they broadcast along longitude.
torchwlat = torch.from_numpy(wlat.data).unsqueeze(-1)

spatial_det_scores = compute_deterministic_metrics(spatial_pred, spatial_gt, torchwlat)
spatial_prob_scores = compute_probabilistic_metrics(spatial_dist, spatial_gt, torchwlat)
spatial_metrics = {**spatial_det_scores, **spatial_prob_scores}

In [71]:
# One-row summary table of the spatial scores.
spatial_df = pd.DataFrame.from_records([spatial_metrics])
spatial_df

Unnamed: 0,mb,rmse,mae,corr,ll,calib95,CRPS
0,-0.133914,0.431235,0.320826,0.958299,-0.690149,0.98491,0.248752
