# Fairness of AI Weather Predictions
Author(s):
*   Nick Masi (nicholas_masi@alumni.brown.edu)
*   Randall Balestriero (randall_balestriero@brown.edu)

### Summary

In this tutorial, we will investigate AI weather prediction (AIWP) models using the [SAFE](https://pypi.org/project/safe-earth/) package. We will compare performance across different territories/countries, global subregions, and income categories of countries to analyze whether models have relatively fair performance or exhibit biases.

### Expected Learning Outcomes
* Measuring the performance of AIWP models in specific territories or regions
* Comparing the performance across strata to assess the fairness of AIWP models
* Visualizing AIWP model fairness

# Table of Contents

*   [Climate Impact](#climate-impact)
*   [Target Audience](#target-audience)
*   [Background & Prerequisites](#background-and-prereqs)
*   [Software Requirements](#software-requirements)
*   [Data Description](#data-description)
*   [Investigate](#investigate)
*   [References](#references)


<a name="climate-impact"></a>
# Climate Impact

Research has found that more accurate weather forecasts causally reduce mortality [7]. In tandem, the use of AI models for weather forecasting is growing. Despite this, there is little research evaluating how well AIWP models work in one region versus another. The recently introduced SAFE package makes this possible. Weather forecasters can use it to understand which models perform best in their particular region, improving the forecasts that emergency responses and policy decisions rely on and potentially saving lives.

<a name="target-audience"></a>
# Target Audience

* ML researchers who want to develop fairer AIWP models
* Weather forecasters trying to determine which models perform best in their region

<a name="background-and-prereqs"></a>
# Background & Prerequisites

In this tutorial we will explore Google's deterministic AIWP model, GraphCast.

GraphCast uses an MSE training objective $$\textup{MSE}=\frac{1}{|D||\Tau||I||J||V|}\sum\limits_{d,\tau,i,j,v}a_j(\hat{y}_{i,j,v}^{d+\tau}-y_{i,j,v}^{d+\tau})^2$$

where $y$ is the ground truth value for a given variable (e.g., temperature or wind speed) that a model is trying to predict, and $\hat{y}$ is the model's prediction for each corresponding $y\in Y$. Every $y$ is the value of the given variable at a specific combination of time $d\in D$, lead time $\tau\in\Tau$, longitude $i\in I$, latitude $j\in J$, and (for certain atmospheric variables) vertical level $v\in V$. $a_j$ is the latitude weight for the cell: the surface area of the grid cell (on Earth's surface) at $i,j$, normalized by the mean cell surface area. To evaluate the model's performance, the RMSE is used, which is $\sqrt{\textup{MSE}}$. Model performance is reported separately for each lead time, so we really have $\textup{RMSE}_\tau=\sqrt{\textup{MSE}_\tau}~\forall\tau\in\Tau$.
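
To make the latitude weighting concrete, here is a minimal sketch in plain Python. It assumes equally spaced latitudes that include both poles, and it approximates each cell's area as proportional to the difference of the sines of its latitude bounds, which is one common way to obtain $a_j$. The helper names `latitude_weights` and `weighted_rmse` are illustrative and are not part of SAFE or GraphCast, and the fields are made-up constants for a single variable and lead time.

```python
import math

def latitude_weights(lats_deg):
    """Area-proportional weights a_j for equally spaced latitudes, normalized to mean 1."""
    step = lats_deg[1] - lats_deg[0]
    areas = []
    for lat in lats_deg:
        # Cell bounds are half a step on either side of the gridpoint, clipped at the poles.
        south = math.radians(max(lat - step / 2, -90.0))
        north = math.radians(min(lat + step / 2, 90.0))
        # The area of a latitude band on a sphere is proportional to sin(north) - sin(south).
        areas.append(math.sin(north) - math.sin(south))
    mean_area = sum(areas) / len(areas)
    return [area / mean_area for area in areas]

def weighted_rmse(pred, truth, weights):
    """RMSE over a [latitude][longitude] grid for one lead time, weighted by a_j."""
    total, count = 0.0, 0
    for j, (pred_row, truth_row) in enumerate(zip(pred, truth)):
        for y_hat, y in zip(pred_row, truth_row):
            total += weights[j] * (y_hat - y) ** 2
            count += 1
    return math.sqrt(total / count)

lats = [-90.0 + 1.5 * k for k in range(121)]   # 121 latitudes at 1.5 degree spacing
weights = latitude_weights(lats)

# Toy fields: a constant 1 K error everywhere yields an RMSE of about 1,
# because the weights average to 1 by construction.
truth = [[280.0] * 240 for _ in lats]
pred = [[281.0] * 240 for _ in lats]
print(weighted_rmse(pred, truth, weights))
```

Cells near the equator receive weights above 1 and cells near the poles receive weights close to 0, so polar gridpoints contribute much less to the global score than their count on the grid would suggest.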

This is a standard approach in the field of using AI for weather and climate modeling. However, this approach masks spatial variation by averaging over latitude and longitude. This tutorial will demonstrate how to move past this paradigm.
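
A toy example with made-up numbers illustrates the masking effect: two equally sized regions with very different errors collapse into a single global RMSE that describes neither of them. The region names and values are hypothetical.

```python
import math

# Hypothetical squared errors at four gridpoints in each of two equally weighted regions.
squared_errors = {
    "region_a": [0.25, 0.25, 0.25, 0.25],   # a consistent 0.5 K error
    "region_b": [2.25, 2.25, 2.25, 2.25],   # a consistent 1.5 K error
}

def rmse(values):
    """Square root of the mean of a list of squared errors."""
    return math.sqrt(sum(values) / len(values))

# Global score: pool every gridpoint, ignoring which region it belongs to.
global_rmse = rmse([v for vals in squared_errors.values() for v in vals])
# Stratified scores: one RMSE per region.
regional_rmse = {name: rmse(vals) for name, vals in squared_errors.items()}

print(f"global: {global_rmse:.3f}")
for name, value in regional_rmse.items():
    print(f"{name}: {value:.3f}")
```

The global value sits between the two regional values, so a reader who only sees it cannot tell that one region is served three times worse than the other.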

<a name="software-requirements"></a>
# Software Requirements

The packages below were confirmed to be available and to work as used in this tutorial as of August 20, 2025.

In [1]:
# installs
!pip install safe-earth==0.0.18 -q # SAFE package
!pip install nbformat -q # for visualizations

In [2]:
# imports
import safe_earth as safe
from safe_earth.data.climate.era5 import ERA5Var
import numpy as np

<a name="data-description"></a>
# Data Description

Model training and evaluation data come from the ERA5 dataset [[1](#references)] provided by the European Centre for Medium-Range Weather Forecasts (ECMWF). A subset of this data is accessible through SAFE via a mirror hosted by WeatherBench 2 (WB2) [[2](#references)]. For this tutorial we will look at a 1.5° spatial resolution, meaning the Earth is segmented into a grid of 240 longitudes by 121 latitudes. Each cell is identified with the gridpoint it is centered on, which gives the cell precise latitude and longitude coordinates. Predictions at lead times every 12 hours from 12 to 48 hours are used.
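
The grid dimensions follow directly from the resolution, as the short sketch below shows. The `cell_bounds` helper is hypothetical (it is not a SAFE function), and clipping the polar cells at ±90° is an assumption inferred from the cell polygons printed later in this tutorial.

```python
RESOLUTION_DEG = 1.5

n_lon = int(360 / RESOLUTION_DEG)        # longitudes wrap around, so 360 is not repeated
n_lat = int(180 / RESOLUTION_DEG) + 1    # latitudes include both poles

def cell_bounds(lon, lat, res=RESOLUTION_DEG):
    """Return (west, south, east, north) of the cell centered on a gridpoint."""
    half = res / 2
    south = max(lat - half, -90.0)   # assumed: polar cells are clipped at the pole
    north = min(lat + half, 90.0)
    return (lon - half, south, lon + half, north)

print(n_lon, n_lat, n_lon * n_lat)
print(cell_bounds(0.0, -88.5))
print(cell_bounds(0.0, -90.0))
```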

One of the main benefits of SAFE is that it centralizes disparate datasets on attributes of gridpoints across the globe. The territory (e.g., which country a gridpoint is within), global subregion, and income attributes originate from the geoBoundaries dataset [[3](#references)] and the [pygeoboundaries](https://github.com/ibhalin/pygeoboundaries) package. geoBoundaries determines the territorial assignments. Subregion classifications for territories come from the United Nations Department of Economic and Social Affairs [[4](#references)]. Income classifications of territories are made by the World Bank based on its Atlas method for calculating gross national income (GNI) per capita [[5](#references)].

## Data Download

SAFE takes care of downloading model predictions from WB2 and ERA5 data. Attribute data is baked in. We will look at predictions made by Google's GraphCast [[6](#references)].

In [3]:
model = 'graphcast'
resolution = '240x121'
lead_times = [np.timedelta64(x, 'h') for x in range(12, 49, 12)]

variables = [ERA5Var('2m_temperature', name='T2M')]
era5 = safe.data.climate.era5.get_era5(resolution, variables=variables)
preds = safe.data.climate.wb2.get_wb2_preds(model, resolution, lead_times, variables=variables)



## Data Exploration

GraphCast was trained to predict the weather at various set times into the future ("lead times") with ERA5 data spanning 1979–2019. `era5` includes snapshots of the ERA5 weather data every 6 hours in 2020, part of GraphCast's evaluation set. Here, we analyze model performance on temperature at 2 meters above the surface. This means there are no vertical levels $v$ to reduce over. `preds` includes GraphCast's predictions for this variable (downsampled to every 12 hours) with lead times of 12 to 48 hours.
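
The sketch below, using only the standard library, illustrates how the two time axes relate: a forecast issued at time $d$ with lead time $\tau$ is compared against the ground truth at $d+\tau$. It assumes that both datasets cover every timestamp in 2020, which may not hold exactly for the downloaded data; the `timestamps` helper is illustrative.

```python
from datetime import datetime, timedelta

def timestamps(start, end, step_hours):
    """All timestamps from start (inclusive) to end (exclusive) at a fixed step."""
    out, current = [], start
    while current < end:
        out.append(current)
        current += timedelta(hours=step_hours)
    return out

start, end = datetime(2020, 1, 1), datetime(2021, 1, 1)
era5_times = timestamps(start, end, 6)     # ground-truth snapshots every 6 hours
init_times = timestamps(start, end, 12)    # forecast times d, assumed to span all of 2020
lead_times_h = [12, 24, 36, 48]

# Valid times d + tau for the first forecast of the year.
d = init_times[0]
valid_times = [d + timedelta(hours=tau) for tau in lead_times_h]

print(len(era5_times), len(init_times))
print(valid_times)
```

Because every valid time falls on a multiple of 12 hours, each prediction has a matching snapshot in the 6-hourly ground truth.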

In [4]:
era5

In [5]:
preds

<a name="investigate"></a>
# Investigate

## Per-gridpoint loss

In order to do any sort of fairness evaluation, we must first get the loss at each individual gridpoint, defined by its unique $(i,j)$ pair. We have already eliminated vertical level $v$ and, as stated, we want a separate metric for each lead time $\tau$. We can use SAFE to calculate the resulting per-gridpoint loss:

$$\textup{MSE}_{i,j}=\frac{1}{|D|}\sum\limits_{d}a_j(\hat{y}_{i,j}^{d}-y_{i,j}^{d})^2$$ 
where we are not reducing over any $v$ or $\tau$. 

With the SAFE package we can do this in two easy steps. First, we compute the per-gridpoint loss separately at every time $d\in D$, without yet averaging over time. This is $a_j(\hat{y}_{i,j}^{d}-y_{i,j}^{d})^2~\forall d\in D$, which resembles a weighted squared ($L2$) error more than an $\textup{MSE}$.

In [6]:
loss_gdf = safe.metrics.losses.climate_weighted_l2(
    data=preds, 
    ground_truth=era5, 
    lon_dim='longitude', 
    lat_dim='latitude',
    lead_time_dim='prediction_timedelta'
)
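
For intuition, the per-time loss and its later reduction over time can be sketched for a single gridpoint and lead time as follows. The temperatures and the weight are made up, and `weighted_l2` here is a hypothetical stand-in rather than SAFE's `climate_weighted_l2`.

```python
import math

def weighted_l2(pred_series, truth_series, a_j):
    """Per-time loss a_j * (y_hat - y)^2 at one gridpoint and one lead time."""
    return [a_j * (y_hat - y) ** 2 for y_hat, y in zip(pred_series, truth_series)]

# Hypothetical 2 m temperatures (K) at one gridpoint for three forecast times d.
truth = [280.0, 281.5, 279.0]
pred = [280.5, 281.0, 280.0]
a_j = 0.8   # made-up latitude weight for this row of the grid

losses = weighted_l2(pred, truth, a_j)   # one value per time d
mse_ij = sum(losses) / len(losses)       # MSE_{i,j}: mean over d
rmse_ij = math.sqrt(mse_ij)
print(losses, mse_ij, rmse_ij)
```

Keeping the per-time values, rather than averaging immediately, leaves the choice of how to aggregate over time and space to a later step.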

We should have 116160 entries: one for our variable at each of the 240x121 gridpoints and each of the 4 lead times.

In [7]:
assert 116160 == 240*121*4
assert 116160 == len(loss_gdf)

We can quickly inspect `loss_gdf`:

In [8]:
loss_gdf

| | longitude | latitude | weighted_l2 | geometry | variable | lead_time |
|---|---|---|---|---|---|---|
| 0 | 0.0 | -90.0 | [0.0005280023902331615, 0.034471699966342126, ... | POLYGON ((0.75 -90, 0.75 -89.25, -0.75 -89.25,... | T2M | 12 |
| 1 | 0.0 | -88.5 | [0.009565884919968941, 0.04475717726366594, 0.... | POLYGON ((0.75 -89.25, 0.75 -87.75, -0.75 -87.... | T2M | 12 |
| 2 | 0.0 | -87.0 | [0.0181766406392264, 0.07144891140664797, 0.00... | POLYGON ((0.75 -87.75, 0.75 -86.25, -0.75 -86.... | T2M | 12 |
| 3 | 0.0 | -85.5 | [0.01226580855314763, 0.04359714707207892, 0.0... | POLYGON ((0.75 -86.25, 0.75 -84.75, -0.75 -84.... | T2M | 12 |
| 4 | 0.0 | -84.0 | [0.027680050051256393, 0.04959510845460031, 0.... | POLYGON ((0.75 -84.75, 0.75 -83.25, -0.75 -83.... | T2M | 12 |
| ... | ... | ... | ... | ... | ... | ... |
| 116155 | -1.5 | 84.0 | [0.0008910524695590192, 0.010219005699457776, ... | POLYGON ((-0.75 83.25, -0.75 84.75, -2.25 84.7... | T2M | 48 |
| 116156 | -1.5 | 85.5 | [0.2818628197884443, 0.15576449761849318, 0.00... | POLYGON ((-0.75 84.75, -0.75 86.25, -2.25 86.2... | T2M | 48 |
| 116157 | -1.5 | 87.0 | [0.4700292219093021, 0.047008480208742, 0.0206... | POLYGON ((-0.75 86.25, -0.75 87.75, -2.25 87.7... | T2M | 48 |
| 116158 | -1.5 | 88.5 | [0.031107043484190522, 0.018230850748467243, 0... | POLYGON ((-0.75 87.75, -0.75 89.25, -2.25 89.2... | T2M | 48 |


We see that the `weighted_l2` column provides the values $a_j(\hat{y}_{i,j}^{d}-y_{i,j}^{d})^2~\forall d\in D$ as a list. So we now have the loss at each individual gridpoint, at every 12-hour timestamp $d$ in the year 2020, for each lead time.

## Per-strata RMSE

Now that we have the per-gridpoint losses, we can reduce them spatially in a stratified manner, aggregating within each group rather than over the whole globe. Setting `attributes='all'` asks SAFE for metrics stratified by territory, subregion, and income group.

In [9]:
import safe_earth as safe
metrics = safe.metrics.errors.stratified_rmse(
    loss_gdf,
    loss_metrics=['weighted_l2'],
    attributes='all',
    added_cols={'model': model}
)
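
Conceptually, a stratified reduction groups gridpoints by an attribute before averaging. The sketch below shows one plausible version under stated assumptions: losses are pooled over all gridpoints and times in a stratum, and the square root of their mean is reported. SAFE's actual implementation may differ, and `stratified_rmse_sketch` and the records are hypothetical.

```python
import math
from collections import defaultdict

# Hypothetical per-gridpoint records: a stratum label and the weighted losses over time.
records = [
    {"territory": "A", "weighted_l2": [0.20, 0.30]},
    {"territory": "A", "weighted_l2": [0.10, 0.40]},
    {"territory": "B", "weighted_l2": [0.90, 1.10]},
]

def stratified_rmse_sketch(rows, attribute, loss_col="weighted_l2"):
    """Pool losses over all gridpoints and times in each stratum, then take sqrt(mean)."""
    pooled = defaultdict(list)
    for row in rows:
        pooled[row[attribute]].extend(row[loss_col])
    return {name: math.sqrt(sum(vals) / len(vals)) for name, vals in pooled.items()}

result = stratified_rmse_sketch(records, "territory")
print(result)
```

In a real analysis the same grouping would be repeated separately for each variable and lead time, which is why the tables that follow carry `variable` and `lead_time` columns.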

We can investigate the stratified $\textup{RMSE}_\tau$ values of GraphCast by territory. SAFE reduced over $i$ and $j$ within each territory, so we have an $\textup{RMSE}_\tau$ metric for each territory.

In [10]:
metrics['territory']

| | variable | lead_time | model | rmse_weighted_l2 | territory |
|---|---|---|---|---|---|
| 0 | T2M | 12.0 | graphcast | 0.398030 | Antarctica |
| 1 | T2M | 24.0 | graphcast | 0.482999 | Antarctica |
| 2 | T2M | 36.0 | graphcast | 0.576830 | Antarctica |
| 3 | T2M | 48.0 | graphcast | 0.674554 | Antarctica |
| 4 | T2M | 12.0 | graphcast | 0.724330 | Ghana |
| ... | ... | ... | ... | ... | ... |
| 919 | T2M | 48.0 | graphcast | 0.405269 | Isle of Man |
| 920 | T2M | 12.0 | graphcast | 0.248690 | Guernsey |
| 921 | T2M | 24.0 | graphcast | 0.317997 | Guernsey |
| 922 | T2M | 36.0 | graphcast | 0.381150 | Guernsey |


We can look at the other attributes too:

In [11]:
metrics['subregion']

| | variable | lead_time | model | rmse_weighted_l2 | subregion |
|---|---|---|---|---|---|
| 0 | T2M | 12.0 | graphcast | 0.398030 | Antarctica |
| 1 | T2M | 24.0 | graphcast | 0.482999 | Antarctica |
| 2 | T2M | 36.0 | graphcast | 0.576830 | Antarctica |
| 3 | T2M | 48.0 | graphcast | 0.674554 | Antarctica |
| 4 | T2M | 12.0 | graphcast | 0.740554 | Western Africa |
| ... | ... | ... | ... | ... | ... |
| 87 | T2M | 48.0 | graphcast | 0.977685 | South America |
| 88 | T2M | 12.0 | graphcast | 0.261053 | Caribbean |
| 89 | T2M | 24.0 | graphcast | 0.336955 | Caribbean |
| 90 | T2M | 36.0 | graphcast | 0.379875 | Caribbean |


In [12]:
metrics['income']

| | income | variable | lead_time | model | rmse_weighted_l2 |
|---|---|---|---|---|---|
| 0 | Lower-middle-income Countries | T2M | 12.0 | graphcast | 0.686856 |
| 1 | Lower-middle-income Countries | T2M | 24.0 | graphcast | 0.779353 |
| 2 | Lower-middle-income Countries | T2M | 36.0 | graphcast | 0.861883 |
| 3 | Lower-middle-income Countries | T2M | 48.0 | graphcast | 0.909841 |
| 4 | Low-income Countries | T2M | 12.0 | graphcast | 0.73487 |
| 5 | Low-income Countries | T2M | 24.0 | graphcast | 0.883258 |
| 6 | Low-income Countries | T2M | 36.0 | graphcast | 0.986981 |
| 7 | Low-income Countries | T2M | 48.0 | graphcast | 1.042963 |
| 8 | Upper-middle-income Countries | T2M | 12.0 | graphcast | 0.706642 |
| 9 | Upper-middle-income Countries | T2M | 24.0 | graphcast | 0.799247 |


## Measuring performance in individual countries

Now that we have this data, let's look at some countries! We will consider predictions made 12 hours out ($\tau=12$).

In [13]:
# select the territory table once and combine both conditions into a single boolean mask
territory = metrics['territory']
at_12h = territory.lead_time == 12
usa_perf = territory[(territory.territory == 'United States') & at_12h].rmse_weighted_l2.item()
nam_perf = territory[(territory.territory == 'Republic of Namibia') & at_12h].rmse_weighted_l2.item()
print(f'RMSE in USA: {usa_perf}')
print(f'RMSE in Namibia: {nam_perf}')
print(f'Percent difference: {(nam_perf-usa_perf)/usa_perf}')

RMSE in USA: 0.7216426216894967
RMSE in Namibia: 0.9013644394784815
Percent difference: 0.24904545877324072




As an example, we see that **GraphCast is 24.9% worse at predicting surface temperature 12 hours out in Namibia than in the US.**
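
The 24.9% figure is the relative difference between the two RMSE values printed above, with the US as the reference. A small worked check, where `relative_gap` is an illustrative helper name:

```python
def relative_gap(reference, other):
    """Relative difference of `other` with respect to `reference`."""
    return (other - reference) / reference

usa_rmse = 0.7216426216894967       # values copied from the output above
namibia_rmse = 0.9013644394784815

gap = relative_gap(usa_rmse, namibia_rmse)
print(f"{gap:.1%}")   # prints 24.9%
```

The choice of reference matters: measured against Namibia instead, the same pair of values gives a gap of about 19.9% in the other direction.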

## Investigating Systemic Bias

We will now look to see if there are any persistent biases in GraphCast performance when grouping territories by their income level. It is easy to visualize any such trends with SAFE.

In [20]:
safe.viz.viz_metrics.incomes(metrics)

We see that across all lead times, **GraphCast performs worst at predicting temperature in low-income countries.**
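
The same pattern can be checked numerically against the rows of `metrics['income']` shown earlier. This sketch only uses the values visible in that truncated output, so it covers three income groups and omits any group whose rows were not displayed; `worst_group` is an illustrative helper, not a SAFE function.

```python
# RMSE values copied from the displayed rows of metrics['income'] (lead time in hours -> RMSE).
income_rmse = {
    "Lower-middle-income Countries": {12: 0.686856, 24: 0.779353, 36: 0.861883, 48: 0.909841},
    "Low-income Countries": {12: 0.73487, 24: 0.883258, 36: 0.986981, 48: 1.042963},
    "Upper-middle-income Countries": {12: 0.706642, 24: 0.799247},
}

def worst_group(table, lead_time):
    """Name of the group with the highest RMSE among groups that report this lead time."""
    candidates = {name: vals[lead_time] for name, vals in table.items() if lead_time in vals}
    return max(candidates, key=candidates.get)

# Only 12 and 24 hours are available for all three displayed groups.
for lead_time in (12, 24):
    print(lead_time, worst_group(income_rmse, lead_time))
```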

<a name="references"></a>
# References

[1] Hersbach, Hans, et al. "The ERA5 global reanalysis." *Quarterly Journal of the Royal Meteorological Society* 146.730 (2020): pp. 1999–2049.

[2] Rasp, Stephan, et al. "WeatherBench 2: A benchmark for the next generation of data-driven global weather models." *Journal of Advances in Modeling Earth Systems* 16.6 (2024), e2023MS004019.

[3] Runfola, Daniel, et al. "geoBoundaries: A global database of political administrative boundaries." *PLoS one* 15.4 (2020), e0231866.

[4] https://unstats.un.org/sdgs/indicators/regional-groups

[5] https://datahelpdesk.worldbank.org/knowledgebase/articles/906519-world-bank-country-and-lending-groups

[6] Lam, Remi, et al. "Learning skillful medium-range global weather forecasting." *Science* 382.6677 (2023): pp. 1416-1421.

[7] Shrader, Jeffrey G., Laura Bakkensen, and Derek Lemoine. "Fatal Errors: The Mortality Value of Accurate Weather Forecasts." Working Paper. June 2023. DOI: 10.3386/w31361. URL: https://www.nber.org/papers/w31361 (visited on 02/16/2025).

