In [1]:
%load_ext autoreload
%autoreload 2

In [99]:
import os
import sys
import numpy as np
import torch
import gpytorch
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
import tqdm
import utils as utils

base_dir = os.path.join(os.getcwd(), '..')
sys.path.append(base_dir)

import fit_spatial_GP as plaingp
import fit_spatial_simpleFaIRGP as fairgp
from src.evaluation.metrics import spearman_correlation

In [61]:
def weighted_mean(x, wlat):
    mu = torch.sum(x * wlat, dim=(1, 2)) / (x.size(2) * wlat.sum())
    return mu

def compute_deterministic_metrics(prediction, groundtruth, wlat):
    # Compute raw distances metrics
    difference = prediction.sub(groundtruth)
    mean_bias = weighted_mean(difference, wlat).mean()
    rmse = weighted_mean(torch.square(difference), wlat).mean().sqrt()
    mae = weighted_mean(torch.abs(difference), wlat).mean()

    # Compute spearman correlation
    corr = spearman_correlation(prediction.flatten(), groundtruth.flatten())

    # Encapsulate results in output dictionnary
    output = {'mb': mean_bias.item(),
              'rmse': rmse.item(),
              'mae': mae.item(),
              'corr': corr}
    return output

def compute_probabilistic_metrics(prediction, groudtruth, wlat):
    ll = weighted_mean(prediction.log_prob(groundtruth), wlat).mean()
    lb, ub = prediction.icdf(torch.tensor(0.025)), prediction.icdf(torch.tensor(0.975))
    mask = (groundtruth >= lb) & (groundtruth <= ub)
    calib95 = weighted_mean(mask.float(), wlat).mean()
    output = {'ll': ll.item(),
              'calib95': calib95.item()}
    return output

## PlainGP

### SSP126

In [65]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp245', 'ssp370', 'ssp585']}}
train_data = plaingp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp126']}}
test_data = plaingp.make_data(test_cfg)

model = plaingp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp126/SpatialPlainGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [66]:
test_scenarios = test_data.scenarios
Xtest = test_scenarios.glob_inputs[:, 1:]
Xtest = (Xtest - model.mu) / model.sigma
test_tas = test_scenarios.tas
model = model.eval()

In [67]:
with torch.no_grad():
    posterior = model.posterior(Xtest, diag=False)
    posterior = model.likelihood(posterior)

In [68]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
ntimes = len(test_scenarios[0].timesteps)
time_slice = slice(-21, None)

prior_mean = model.mu_targets.view(1, nlat, nlon).repeat(ntimes, 1, 1)
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction =  model.sigma_targets.view(1, nlat, nlon) * posterior_mean
posterior_mean = prior_mean + posterior_mean

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets.view(1, nlat, nlon) * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]
wlat = torch.cos(torch.deg2rad(test_scenarios[0].lat)).clip(min=torch.finfo(torch.float64).eps)[:, None]

In [69]:
metrics_ssp126 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

### SSP245

In [82]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp126', 'ssp370', 'ssp585']}}
train_data = plaingp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp245']}}
test_data = plaingp.make_data(test_cfg)

model = plaingp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp245/SpatialPlainGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [83]:
test_scenarios = test_data.scenarios
Xtest = test_scenarios.glob_inputs[:, 1:]
Xtest = (Xtest - model.mu) / model.sigma
test_tas = test_scenarios.tas
model = model.eval()

In [84]:
with torch.no_grad():
    posterior = model.posterior(Xtest, diag=False)
    posterior = model.likelihood(posterior)

In [85]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
ntimes = len(test_scenarios[0].timesteps)
time_slice = slice(-21, None)

prior_mean = model.mu_targets.view(1, nlat, nlon).repeat(ntimes, 1, 1)
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction =  model.sigma_targets.view(1, nlat, nlon) * posterior_mean
posterior_mean = prior_mean + posterior_mean

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets.view(1, nlat, nlon) * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]
wlat = torch.cos(torch.deg2rad(test_scenarios[0].lat)).clip(min=torch.finfo(torch.float64).eps)[:, None]

In [86]:
metrics_ssp245 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

### SSP370

In [87]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp126', 'ssp245', 'ssp585']}}
train_data = plaingp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp370']}}
test_data = plaingp.make_data(test_cfg)

model = plaingp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp370/SpatialPlainGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [88]:
test_scenarios = test_data.scenarios
Xtest = test_scenarios.glob_inputs[:, 1:]
Xtest = (Xtest - model.mu) / model.sigma
test_tas = test_scenarios.tas
model = model.eval()

In [89]:
with torch.no_grad():
    posterior = model.posterior(Xtest, diag=False)
    posterior = model.likelihood(posterior)

In [90]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
ntimes = len(test_scenarios[0].timesteps)
time_slice = slice(-21, None)

prior_mean = model.mu_targets.view(1, nlat, nlon).repeat(ntimes, 1, 1)
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction =  model.sigma_targets.view(1, nlat, nlon) * posterior_mean
posterior_mean = prior_mean + posterior_mean

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets.view(1, nlat, nlon) * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]
wlat = torch.cos(torch.deg2rad(test_scenarios[0].lat)).clip(min=torch.finfo(torch.float64).eps)[:, None]

In [92]:
metrics_ssp370 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

### SSP585

In [93]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp126', 'ssp245', 'ssp370']}}
train_data = plaingp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp585']}}
test_data = plaingp.make_data(test_cfg)

model = plaingp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp585/SpatialPlainGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [94]:
test_scenarios = test_data.scenarios
Xtest = test_scenarios.glob_inputs[:, 1:]
Xtest = (Xtest - model.mu) / model.sigma
test_tas = test_scenarios.tas
model = model.eval()

In [95]:
with torch.no_grad():
    posterior = model.posterior(Xtest, diag=False)
    posterior = model.likelihood(posterior)

In [96]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
ntimes = len(test_scenarios[0].timesteps)
time_slice = slice(-21, None)

prior_mean = model.mu_targets.view(1, nlat, nlon).repeat(ntimes, 1, 1)
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction =  model.sigma_targets.view(1, nlat, nlon) * posterior_mean
posterior_mean = prior_mean + posterior_mean

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets.view(1, nlat, nlon) * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]
wlat = torch.cos(torch.deg2rad(test_scenarios[0].lat)).clip(min=torch.finfo(torch.float64).eps)[:, None]

In [97]:
metrics_ssp585 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

In [102]:
# metrics_plaingp = [metrics_ssp126, metrics_ssp245, metrics_ssp370, metrics_ssp585]
scores_plaingp = pd.DataFrame(metrics_plaingp)

# FaIRGP

### SSP126

In [103]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp245', 'ssp370', 'ssp585']}}
train_data = fairgp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp126']}}
test_data = fairgp.make_data(test_cfg)

model = fairgp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp126/SimpleSpatialFaIRGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [104]:
test_scenarios = test_data.scenarios
test_tas = test_scenarios.tas

model = model.eval()
with torch.no_grad():
    posterior = model.posterior(test_scenarios, diag=False)
    posterior = model.likelihood(posterior)

In [105]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
time_slice = slice(-21, None)

prior_mean = model._compute_mean(test_scenarios[0])
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction = model.mu_targets + model.sigma_targets * posterior_mean
posterior_mean = prior_mean + posterior_correction

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]

In [106]:
metrics_ssp126 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

### SSP245

In [107]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp126', 'ssp370', 'ssp585']}}
train_data = fairgp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp245']}}
test_data = fairgp.make_data(test_cfg)

model = fairgp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp245/SimpleSpatialFaIRGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [108]:
test_scenarios = test_data.scenarios
test_tas = test_scenarios.tas

model = model.eval()
with torch.no_grad():
    posterior = model.posterior(test_scenarios, diag=False)
    posterior = model.likelihood(posterior)

In [109]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
time_slice = slice(-21, None)

prior_mean = model._compute_mean(test_scenarios[0])
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction = model.mu_targets + model.sigma_targets * posterior_mean
posterior_mean = prior_mean + posterior_correction

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]

In [110]:
metrics_ssp245 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

### SSP370

In [111]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp126', 'ssp245', 'ssp585']}}
train_data = fairgp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp370']}}
test_data = fairgp.make_data(test_cfg)

model = fairgp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp370/SimpleSpatialFaIRGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [112]:
test_scenarios = test_data.scenarios
test_tas = test_scenarios.tas

model = model.eval()
with torch.no_grad():
    posterior = model.posterior(test_scenarios, diag=False)
    posterior = model.likelihood(posterior)

In [113]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
time_slice = slice(-21, None)

prior_mean = model._compute_mean(test_scenarios[0])
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction = model.mu_targets + model.sigma_targets * posterior_mean
posterior_mean = prior_mean + posterior_correction

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]

In [114]:
metrics_ssp370 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

### SSP585

In [115]:
train_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['historical', 'ssp126', 'ssp245', 'ssp370']}}
train_data = fairgp.make_data(train_cfg)

test_cfg = {'dataset' : {'dirpath': '../data/', 'keys': ['ssp585']}}
test_data = fairgp.make_data(test_cfg)

model = fairgp.make_model(train_cfg, train_data)
state_dict = torch.load('../data/models/leave-one-out-ssp/ssp585/SimpleSpatialFaIRGP/state_dict.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [116]:
test_scenarios = test_data.scenarios
test_tas = test_scenarios.tas

model = model.eval()
with torch.no_grad():
    posterior = model.posterior(test_scenarios, diag=False)
    posterior = model.likelihood(posterior)

In [117]:
nlat = len(test_scenarios[0].lat)
nlon = len(test_scenarios[0].lon)
time_slice = slice(-21, None)

prior_mean = model._compute_mean(test_scenarios[0])
posterior_mean = posterior.mean.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_correction = model.mu_targets + model.sigma_targets * posterior_mean
posterior_mean = prior_mean + posterior_correction

posterior_stddev = posterior.stddev.reshape(nlat, nlon, -1).permute(2, 0, 1)
posterior_stddev = model.sigma_targets * posterior_stddev

prediction = torch.distributions.Normal(posterior_mean[time_slice], posterior_stddev[time_slice])
groundtruth = test_tas[time_slice]

In [118]:
metrics_ssp585 = {**compute_deterministic_metrics(prediction.mean, groundtruth, wlat),
                  **compute_probabilistic_metrics(prediction, groundtruth, wlat)}

In [119]:
metrics_fairgp = [metrics_ssp126, metrics_ssp245, metrics_ssp370, metrics_ssp585]
scores_fairgp = pd.DataFrame(metrics_fairgp)

---

In [122]:
fairgp_summary = scores_fairgp.aggregate(['mean', 'std'])
gp_summary = scores_plaingp.aggregate(['mean', 'std'])

In [124]:
model_cols = ['FaIRGP', 'Plain GP']
metrics_cols = ['Bias', 'RMSE', 'MAE', 'Corr', 'LL', 'Calib95']


results = pd.concat([fairgp_summary.loc['mean'], gp_summary.loc['mean']], axis=1)
results.columns = model_cols
results = results.T
results.columns = metrics_cols

stddev = pd.concat([fairgp_summary.loc['std'], gp_summary.loc['std']], axis=1)
stddev.columns = model_cols
stddev = stddev.T
stddev.columns = metrics_cols

In [126]:
results.to_latex('./tables/spatial-ssp-experiment-mean-scores.tex')
results

  results.to_latex('./tables/spatial-ssp-experiment-mean-scores.tex')


Unnamed: 0,Bias,RMSE,MAE,Corr,LL,Calib95
FaIRGP,-0.077211,0.623604,0.455664,0.915279,-1.109238,0.788524
Plain GP,-0.205981,1.061161,0.76788,0.831512,-4.769297,0.615509


In [127]:
stddev

Unnamed: 0,Bias,RMSE,MAE,Corr,LL,Calib95
FaIRGP,0.152633,0.191426,0.143806,0.08274,0.626609,0.131048
Plain GP,0.239941,0.565356,0.415852,0.120903,4.075295,0.121804
