Forecasts should outperform the following "null models":

 - Feed-forward: predicted mortality is the most recent observation in each pixel.
 - Average: predicted mortality is the mean of observed values in the pixel.

We will have to be a little creative to get these to work in the torch framework.

In [32]:
import xarray as xr
import numpy as np
from matplotlib import pyplot as plt
import torch
import torchmetrics
from tqdm import tqdm

import os
try:
    import util
except ImportError:
    os.chdir("..")
finally:
    import util

In [14]:
# Prepare data
# Ignore smoothing and total BA calculation because we don't use those at all here.

ds = xr.open_dataset("data_working/westmort.nc")

years = ds.time.values

train_years = years[16:]
valid_years = years[:8]
test_years = years[8:16]

print("Training years:", train_years)
print("Validation years:", valid_years)
print("Testing years:", test_years)

Training years: [2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023]
Validation years: [1997 1998 1999 2000 2001 2002 2003 2004]
Testing years: [2005 2006 2007 2008 2009 2010 2011 2012]


In [21]:
# Geometric mean over time of the training set
def safe_geometric_mean(arr, dim):
    # Prevent -Inf
    arr = (arr / 100) + 1e-3
    
    arr_log = np.log(arr)
    arr_log_mean = arr_log.mean(dim=dim)
    return (np.exp(arr_log_mean) - 1e-3) * 100
    
ds["mortality_average"]  = safe_geometric_mean(ds.sel(time=train_years).mortality, ("time"))

# Last valid observation in a cell
ds["mortality_last_obs"] = ds["mortality"].ffill(dim="time").shift(time=1)

In [29]:
# Use the same setup as when we train the convnets
window = dict(x=[8, False], y=[8, False], time=[5, False])
valid_wds = util.datasets.WindowXarrayDataset(ds.sel(time=valid_years), window, mask="mortality")
test_wds  = util.datasets.WindowXarrayDataset(ds.sel(time=test_years), window, mask="mortality")

print("N valid", len(valid_wds))
print("N test", len(test_wds))


N valid 8361
N test 22516


In [31]:
# Windows have the "null models" already associated so we can just iterate
# over and calculate metrics.
w = valid_wds[10]
w

In [37]:
last_obs_valid_metrics = util.training.get_regr_metrics() + [torchmetrics.MeanSquaredError()]
last_obs_test_metrics  = util.training.get_regr_metrics() + [torchmetrics.MeanSquaredError()]
avg_valid_metrics = util.training.get_regr_metrics() + [torchmetrics.MeanSquaredError()]
avg_test_metrics  = util.training.get_regr_metrics() + [torchmetrics.MeanSquaredError()]

for patch in tqdm(valid_wds):
    target   = torch.tensor(patch.mortality.values[-1, ...])
    last_obs = torch.tensor(patch.mortality_last_obs.values[-1, ...])
    avg      = torch.tensor(patch.mortality_average.values)

    for m in last_obs_valid_metrics:
        m(target.view(-1), last_obs.view(-1))

    for m in avg_valid_metrics:
        m(target.view(-1), avg.view(-1))

for patch in tqdm(test_wds):
    target   = torch.tensor(patch.mortality.values[-1, ...])
    last_obs = torch.tensor(patch.mortality_last_obs.values[-1, ...])
    avg      = torch.tensor(patch.mortality_average.values)

    for m in last_obs_test_metrics:
        m(target.view(-1), last_obs.view(-1))

    for m in avg_test_metrics:
        m(target.view(-1), avg.view(-1))

100%|██████████| 8361/8361 [00:32<00:00, 259.98it/s]
100%|██████████| 22516/22516 [01:26<00:00, 259.25it/s]


In [38]:
print("Last observation, validation")
for metric in last_obs_valid_metrics:
    print(f"\t{str(metric)}: {metric.compute():.3f}")

print("Last observation, testing")
for metric in last_obs_test_metrics:
    print(f"\t{str(metric)}: {metric.compute():.3f}")

print("Time average, validation")
for metric in avg_valid_metrics:
    print(f"\t{str(metric)}: {metric.compute():.3f}")

print("Time average, testing")
for metric in avg_test_metrics:
    print(f"\t{str(metric)}: {metric.compute():.3f}")

Last observation, validation
	NormalizedRootMeanSquaredError(): 5.495
	R2Score(): -0.349
	MeanAbsoluteError(): 2.147
	MeanSquaredError(): 67.392
Last observation, testing
	NormalizedRootMeanSquaredError(): 4.700
	R2Score(): -0.115
	MeanAbsoluteError(): 1.280
	MeanSquaredError(): 22.611
Time average, validation
	NormalizedRootMeanSquaredError(): 16.975
	R2Score(): -18.430
	MeanAbsoluteError(): 1.951
	MeanSquaredError(): 58.455
Time average, testing
	NormalizedRootMeanSquaredError(): 9.229
	R2Score(): -5.567
	MeanAbsoluteError(): 1.135
	MeanSquaredError(): 18.212


When interpreting these remember that we didn't divide through by 100, so the MAE and MSE values should be interpreted as percentages.