In [1]:
import sys
sys.path.append('/app')

from src.climate_learn import convert_nc2npz, IterDataModule
from src.climate_learn.utils import load_downscaling_module
import numpy as np
import os
import glob

from IPython.display import HTML
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    RichModelSummary,
    RichProgressBar
)
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

The following function calls convert the raw ERA5, CMIP6, and E-OBS data into `.npz` files that are easily ingestible by PyTorch models, and define the training-validation-testing splits. In addition, we shard the data to create sets of smaller files rather than one large file for each split.

### CMIP_ERA task

In [3]:
# ERA5 (WeatherBench) target data for the CMIP_ERA task, daily frequency.
# This directory is the `align_target` of the CMIP6 cell below, so it is not
# aligned to anything itself.
convert_nc2npz(
    src="era5",
    root_dir="/app/data/raw/cmip6-era5/era5_0.25/D/",
    save_dir="/app/data/processed/cmip6-era5/era5_0.25/D",
    variables=[
        "2m_temperature",
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        # "surface_pressure" is intentionally left out for the "D" frequency.
        "total_precipitation",
    ],
    frequency="D",  # one of: H | 3H | D
    num_shards=5,  # 5 shards for "D", 20 for "3H"
    # Split boundaries (first year of each split) and last year of the data.
    start_train_year=1960,
    start_val_year=2011,
    start_test_year=2013,
    end_year=2015,
    align_target=None,
)

100%|██████████| 1/1 [01:45<00:00, 105.81s/it]
100%|██████████| 2/2 [03:02<00:00, 91.36s/it]
100%|██████████| 2/2 [02:59<00:00, 89.88s/it]


In [2]:
# CMIP6 input data for the CMIP_ERA task, daily frequency, aligned to the
# ERA5 0.25° daily grid converted in the previous cell.
convert_nc2npz(
    src="cmip6",
    root_dir="/app/data/raw/cmip6-era5/cmip6/D",
    save_dir="/app/data/processed/cmip6-era5/cmip6/D",
    variables=[
        "air_temperature",
        "u_component_of_wind",
        "v_component_of_wind",
        # Use "surface_pressure" for 3H data and "pressure_sea_level" for "D".
        "pressure_sea_level",
        "precipitation",
        "specific_humidity",
        "cloud_cover",
        "upward_heat_flux",
        "moisture_in_soil",
    ],
    frequency="D",  # one of: H | 3H | D
    num_shards=5,  # 5 shards for "D", 20 for "3H"
    # Split boundaries (first year of each split) and last year of the data.
    start_train_year=1960,
    start_val_year=2011,
    start_test_year=2013,
    end_year=2015,
    align_target="/app/data/raw/cmip6-era5/era5_0.25/D",
    scale_factor=4,
)

100%|██████████| 51/51 [17:35<00:00, 20.70s/it]
100%|██████████| 2/2 [00:33<00:00, 16.96s/it]
100%|██████████| 2/2 [00:36<00:00, 18.23s/it]


### ERA_EOBS task

In [7]:
# ERA5 daily data for the ERA-EOBS task. The E-OBS cell below aligns to the
# raw directory read here.
convert_nc2npz(
    src="era5",
    root_dir="/app/data/raw/era5-eobs/era5_0.25_D/",
    save_dir="/app/data/processed/era5-eobs/era5_0.25_D/",
    variables=[
        "2m_temperature",
        "maximum_temperature",
        "minimum_temperature",
        # "surface_pressure" is intentionally left out for the "D" frequency.
        "rainfall",
    ],
    frequency="D",
    num_shards=5,
    # Split boundaries (first year of each split) and last year of the data.
    start_train_year=1960,
    start_val_year=2018,
    start_test_year=2020,
    end_year=2022,
    align_target=None,
    periodic=False,  # whether the data cover the whole globe
)

100%|██████████| 58/58 [47:28<00:00, 49.11s/it] 
100%|██████████| 2/2 [00:22<00:00, 11.30s/it]
100%|██████████| 2/2 [00:22<00:00, 11.07s/it]


In [5]:
# E-OBS ensemble-mean daily data for the ERA-EOBS task, aligned to the raw
# ERA5 0.25° daily directory.
# NOTE(review): with scale_factor=0.5 the output presumably lands on a 0.125°
# grid (0.25 * 0.5), matching the "0125_grid" save_dir name — confirm.
convert_nc2npz(
    src="eobs",
    root_dir="/app/data/raw/era5-eobs/e-obs/ensemble_mean/010_grid/1950-2023/",
    save_dir="/app/data/processed/era5-eobs/e-obs/ensemble_mean/0125_grid/",
    variables=[
        "mean_temperature",
        "minimum_temperature",
        "maximum_temperature",
        "precipitation_sum",
        "sea_level_pressure_avg",
        "relative_humidity_avg",
        "global_radiation_mean",
    ],
    frequency="D",
    num_shards=5,  # 5 shards for "D", 20 for "3H"
    # Split boundaries (first year of each split) and last year of the data.
    start_train_year=1960,
    start_val_year=2018,
    start_test_year=2020,
    end_year=2022,
    align_target="/app/data/raw/era5-eobs/era5_0.25_D",
    scale_factor=0.5,
    periodic=False,  # whether the data cover the whole globe
)

100%|██████████| 58/58 [1:09:28<00:00, 71.86s/it]
100%|██████████| 2/2 [01:24<00:00, 42.31s/it]
100%|██████████| 2/2 [01:42<00:00, 51.28s/it]


In [6]:
# Create mask for E-OBS dataset to exclude NaNs.
# A grid cell is marked valid (True) only if none of the checked variables is
# NaN at that cell in any processed shard of any split. The mask is saved next
# to the shards as `mask.npy`.
src_dir = '/app/data/processed/era5-eobs/e-obs/ensemble_mean/0125_grid/'


def _build_nan_mask(root, splits=("train", "test", "val"), keys=("tg", "tx", "tn", "rr")):
    """Return a boolean mask that is True where no shard has a NaN for any key.

    Parameters
    ----------
    root : str
        Directory containing one sub-directory per split with ``*.npz`` shards.
    splits : iterable of str
        Split sub-directories to scan.
    keys : iterable of str
        Array names checked inside each ``.npz`` shard.
        NOTE(review): presumably the E-OBS short names for mean/max/min
        temperature and precipitation — confirm against the converter.

    Returns
    -------
    numpy.ndarray of bool
        Combined NaN-free mask (per-array masks reduced over axis 0, AND-ed).

    Raises
    ------
    FileNotFoundError
        If no non-climatology ``.npz`` shard is found under ``root``.
    """
    splits = tuple(splits)
    keys = tuple(keys)
    # Start from "no mask yet" instead of probing for a global name, so a
    # mask left in the kernel by an earlier run can never leak into this one.
    combined = None
    for split in splits:
        shard_paths = sorted(glob.glob(os.path.join(root, split, "*.npz")))
        for path in shard_paths:
            # Climatology files share the directory but are not data shards.
            if "climatology" in path:
                continue
            # `with` closes the file handle held open by the NpzFile object.
            with np.load(path) as shard:
                for key in keys:
                    # Valid where the array has no NaN anywhere along axis 0.
                    # NOTE(review): axis 0 is presumably time — confirm.
                    valid = ~np.isnan(shard[key]).any(axis=0)
                    combined = valid if combined is None else combined & valid
    if combined is None:
        raise FileNotFoundError(
            f"No .npz shards found under {root!r} for splits {list(splits)}"
        )
    return combined


mask_global = _build_nan_mask(src_dir)

# Save mask
np.save(os.path.join(src_dir, "mask.npy"), mask_global)

### CMIP-CMIP task

In [2]:
# CMIP6 high-resolution (HR) daily data for the CMIP-CMIP task. This is the
# target grid that the LR cell below aligns to, so it is not aligned itself.
convert_nc2npz(
    src="cmip6",
    root_dir="/app/data/raw/cmip6-cmip6/HR",
    save_dir="/app/data/processed/cmip6-cmip6/HR",
    variables=[
        "air_temperature",
        "u_component_of_wind",
        "v_component_of_wind",
        # Use "surface_pressure" for 3H data and "pressure_sea_level" for "D".
        "pressure_sea_level",
        "precipitation",
        "specific_humidity",
        "cloud_cover",
        "upward_heat_flux",
        "moisture_in_soil",
    ],
    frequency="D",  # one of: H | 3H | D
    num_shards=5,  # 5 shards for "D", 20 for "3H"
    # Split boundaries (first year of each split) and last year of the data.
    start_train_year=1960,
    start_val_year=2011,
    start_test_year=2013,
    end_year=2015,
    align_target=None,
)

100%|██████████| 51/51 [20:59<00:00, 24.70s/it]
100%|██████████| 2/2 [00:45<00:00, 22.83s/it]
100%|██████████| 2/2 [00:46<00:00, 23.22s/it]


In [2]:
# CMIP6 low-resolution (LR) daily data for the CMIP-CMIP task, aligned to the
# raw HR directory converted in the previous cell.
convert_nc2npz(
    src="cmip6",
    root_dir="/app/data/raw/cmip6-cmip6/LR",
    save_dir="/app/data/processed/cmip6-cmip6/LR",
    variables=[
        "air_temperature",
        "u_component_of_wind",
        "v_component_of_wind",
        # Use "surface_pressure" for 3H data and "pressure_sea_level" for "D".
        "pressure_sea_level",
        "precipitation",
        "specific_humidity",
        "cloud_cover",
        "upward_heat_flux",
        "moisture_in_soil",
    ],
    frequency="D",  # one of: H | 3H | D
    num_shards=5,  # 5 shards for "D", 20 for "3H"
    # Split boundaries (first year of each split) and last year of the data.
    start_train_year=1960,
    start_val_year=2011,
    start_test_year=2013,
    end_year=2015,
    align_target="/app/data/raw/cmip6-cmip6/HR",
    scale_factor=2,
)

100%|██████████| 51/51 [07:51<00:00,  9.25s/it]
100%|██████████| 2/2 [00:16<00:00,  8.18s/it]
100%|██████████| 2/2 [00:16<00:00,  8.27s/it]


The processed data is loaded into a PyTorch Lightning data module. The first data module below (`dm_cmip`) uses the following settings:
- `task = "downscaling"`. The model maps the input variables read from `inp_root_dir` to the output variables read from `out_root_dir`.
- `subsample = 6`. In the forecasting setting, the dataset is subsampled by a factor of 6 so that training is faster; one could also use no subsampling (_i.e._, `subsample = 1`, which is the default).
- `pred_range = 24`. In the forecasting setting, this is how far ahead (24 hours) the model predicts.
- `history = 3`. In the forecasting setting, the model is given data at time `t`, `t-subsample`, and `t-subsample*2`.

The `subsample`, `pred_range`, and `history` arguments come from a forecasting setup; whether `IterDataModule` uses them when `task = "downscaling"` should be confirmed. For a description of the related forecasting types (direct, iterative, and continuous forecasting), see section 3 of [this paper by Rasp and Theurey](https://arxiv.org/pdf/2008.08626.pdf).

Note further that `in_vars` and `out_vars` are the same (`air_temperature`, `u_component_of_wind`, `v_component_of_wind`), and `inp_root_dir` equals `out_root_dir`, so this first setup is an identity sanity check.

Before running this next code cell, we recommend switching to a GPU-accelerated runtime and then re-running the library-import cell. You do _NOT_ need to re-download or re-process the data; the processed files are already saved under `/app/data/processed`.

##### CMIP-to-CMIP downscaling (sanity check: input and output directories are identical, so MSE ≈ 0 and Pearson ≈ 1)

In [None]:
# Identity sanity check: the same CMIP6 directory and variables serve as both
# input and target, so a perfect model reproduces its input exactly.
# NOTE(review): "/app/data/processed/cmip6/3H" is not written by any conversion
# cell above (they write daily data under cmip6-era5/ and cmip6-cmip6/) —
# confirm the directory exists before running.
# NOTE(review): history / subsample / pred_range look like forecasting options;
# confirm whether IterDataModule uses them for task="downscaling".
dm_cmip = IterDataModule(
    task="downscaling",
    src="cmip6",
    inp_root_dir="/app/data/processed/cmip6/3H",
    out_root_dir="/app/data/processed/cmip6/3H",
    in_vars=[
        "air_temperature",
        "u_component_of_wind",
        "v_component_of_wind",
    ],
    out_vars=[
        "air_temperature",
        "u_component_of_wind",
        "v_component_of_wind",
    ],
    history=3,
    subsample=6,
    pred_range=24,
    batch_size=256,
)
dm_cmip.setup()

In [None]:
# Evaluate the bilinear-interpolation baseline on the test split of dm_cmip.
# Use architecture="nearest-interpolation" to compare against nearest-neighbour.
interpolation = load_downscaling_module(
    architecture="bilinear-interpolation",
    data_module=dm_cmip,
)

trainer = pl.Trainer()
trainer.test(interpolation, dm_cmip)

##### CMIP-to-ERA5 downscaling: nearest-interpolation baseline

In [None]:
# CMIP6 inputs -> ERA5 0.25° targets, read from 3-hourly processed directories.
# NOTE(review): these 3H directories are not produced by the daily conversion
# cells above — confirm they exist before running.
# NOTE(review): unlike dm_cmip, `src` is not passed here, so IterDataModule's
# default applies — confirm that is intended.
dm_cmip_era = IterDataModule(
    task="downscaling",
    inp_root_dir="/app/data/processed/cmip6/3H",
    out_root_dir="/app/data/processed/era5_0.25deg/3H",
    # CMIP6 variable names on the input side ...
    in_vars=[
        "air_temperature",
        "u_component_of_wind",
        "v_component_of_wind",
        "surface_pressure",
    ],
    # ... and the corresponding ERA5 variable names on the target side.
    out_vars=[
        "2m_temperature",
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        "surface_pressure",
    ],
    batch_size=256,
)
dm_cmip_era.setup()

In [None]:
# Evaluate the nearest-interpolation baseline on the test split of dm_cmip_era.
nearest = load_downscaling_module(
    architecture="nearest-interpolation",
    data_module=dm_cmip_era,
)

trainer = pl.Trainer()
trainer.test(nearest, dm_cmip_era)

Loading architecture: nearest-interpolation
Using optimizer associated with architecture
Using learning rate scheduler associated with architecture
Loading training loss: mse
No train transform
Loading validation loss: rmse
Loading validation loss: pearson
Loading validation loss: mean_bias
Loading validation loss: mse
Loading validation transform: denormalize
Loading validation transform: denormalize
Loading validation transform: denormalize
No validation transform
Loading test loss: rmse
Loading test loss: pearson
Loading test loss: mean_bias
Loading test loss: mse
Loading test transform: denormalize
Loading test transform: denormalize
Loading test transform: denormalize
No test transform


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/conda/envs/bias_correction/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test/rmse:2m_temperature': 5.359344482421875,
  'test/rmse:10m_u_component_of_wind': 5.9006452560424805,
  'test/rmse:10m_v_component_of_wind': 5.991654872894287,
  'test/rmse:surface_pressure': 1908.81640625,
  'test/rmse:aggregate': 481.51708984375,
  'test/pearson:2m_temperature': 1.0766712427139282,
  'test/pearson:10m_u_component_of_wind': 0.45916029810905457,
  'test/pearson:10m_v_component_of_wind': 0.21616476774215698,
  'test/pearson:surface_pressure': 1.1198564767837524,
  'test/pearson:aggregate': 0.7179632186889648,
  'test/mean_bias:2m_temperature': -0.12681575119495392,
  'test/mean_bias:10m_u_component_of_wind': -0.023121390491724014,
  'test/mean_bias:10m_v_component_of_wind': -0.0034593292511999607,
  'test/mean_bias:surface_pressure': -44.6484375,
  'test/mean_bias:aggregate': -11.200458526611328,
  'test/mse:2m_temperature': 29.183940887451172,
  'test/mse:10m_u_component_of_wind': 34.938941955566406,
  'test/mse:10m_v_component_of_wind': 36.00859832763672,
  'tes

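ClimateLearn provides standard metrics. For downscaling, as in the output above, it reports RMSE, Pearson correlation, mean bias, and MSE for each output variable. For forecasting, it displays the latitude-weighted RMSE and the latitude-weighted ACC. Lower RMSE is better, while higher ACC is better; like any correlation coefficient, ACC ranges from -1 to 1. Latitude weighting adjusts for the fact that we flatten the curved surface of the Earth onto a 2D grid, which squishes information at the equator and stretches information near the poles. For more info about these metrics, see this link: https://geo.libretexts.org/Bookshelves/Meteorology_and_Climate_Science/Practical_Meteorology_(Stull)/20%3A_Numerical_Weather_Prediction_(NWP)/20.7%3A_Forecast_Quality_and_Verfication

**Note:** the Pearson values reported above for `2m_temperature` (≈1.077) and `surface_pressure` (≈1.120) exceed 1, which is impossible for a correlation coefficient. The metric computation (or its weighting) should be checked before these numbers are trusted.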

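Also, you might have noticed the metrics with `aggregate` as the suffix. These represent averages over the output variables. For example, `test/rmse:aggregate` is the average of `test/rmse:2m_temperature`, `test/rmse:10m_u_component_of_wind`, `test/rmse:10m_v_component_of_wind`, and `test/rmse:surface_pressure`. Because it is an unweighted average of values in different units, it is dominated by the variable with the largest magnitude (here `surface_pressure`).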

Besides these metrics, ClimateLearn also provides visualization tools. In the following cell, we first get the denormalization transform. It maps the data returned by the PyTorch Lightning data module, which was normalized to $\mathcal{N}(0,1)$, back into its original range. As we can see from the logging messages displayed in the previous cell's output, the nearest-interpolation model's 0-th test transformation is denormalization.

Then, we visualize the ground truth, prediction, and bias for the persistence prediction made on the 0-th sample of the testing set. Bias is defined as predicted minus observed (see the link provided above). It is useful to gain a visual understanding of model performance. In this example, we can see that persistence generally underpredicts the true values.

For weather forecasting with history greater than 1, the visualization function also returns a value which we save here as `in_graphic`. This graphic can be animated, as seen in the next code cell.