In [1]:
from pathlib import Path
import torch

import os

# Dataset roots. The defaults are the locations on the original machine; point the
# notebook elsewhere by setting CMIP6_DATA_DIR / ERA5_DATA_DIR instead of editing this cell.
cmip_path = Path(os.environ.get("CMIP6_DATA_DIR", "/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg"))
era_path = Path(os.environ.get("ERA5_DATA_DIR", "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg"))

In [2]:
import xarray as xr

import os
import sys

# Make the local climate_learn checkout importable ahead of any installed copy.
# Override the location with CLIMATE_LEARN_DIR. Existing entries are removed before
# inserting at the front, so re-running the cell does not stack duplicates.
climate_learn_dir = os.environ.get("CLIMATE_LEARN_DIR", "/home/seongbin/climate_learn")
sys.path[:] = [p for p in sys.path if p != climate_learn_dir]
sys.path.insert(0, climate_learn_dir)

In [3]:
from climate_learn.utils.data import load_dataset, view

# Raw CMIP6 geopotential-height files (all pressure levels) from the "geopotential" subdirectory.
cmip6_t_dataset = load_dataset(cmip_path.joinpath("geopotential"))

In [4]:
# Raw ERA5 geopotential files (all pressure levels) from the "geopotential" subdirectory.
era5_t_dataset = load_dataset(era_path.joinpath("geopotential"))

In [5]:
# Convert CMIP6 pressure levels from Pa to hPa so they are labelled like ERA5's
# `level` coordinate (e.g. 500). The units check makes the cell safe to re-run:
# without it every re-execution divides plev by another factor of 100.
if cmip6_t_dataset['plev'].attrs.get('units') != 'hPa':
    cmip6_t_dataset['plev'] = cmip6_t_dataset['plev'] / 100
    cmip6_t_dataset['plev'].attrs['units'] = 'hPa'

In [6]:
# Convert CMIP6 geopotential height `zg` into geopotential so it has the same units
# as ERA5's `z` (m**2 s**-2). Use standard gravity 9.80665 m s**-2 (the value ECMWF
# uses for z <-> height); the previous 9.8 biased every CMIP6 value low by ~0.07%.
# The units check keeps a re-run of this cell from multiplying a second time.
STANDARD_GRAVITY = 9.80665  # m s**-2
if cmip6_t_dataset['zg'].attrs.get('units') != 'm**2 s**-2':
    cmip6_t_dataset['zg'] = cmip6_t_dataset['zg'] * STANDARD_GRAVITY
    cmip6_t_dataset['zg'].attrs['units'] = 'm**2 s**-2'

In [7]:
# Record that zg now holds geopotential (same units string as ERA5's `z`).
cmip6_t_dataset['zg'].attrs.update(units='m**2 s**-2')

In [8]:
# CMIP6 geopotential at 500 hPa. Select by coordinate label rather than positional
# index (previously index 15 on the plev axis), so a different level ordering or
# count cannot silently pick the wrong level. Requires plev in hPa (converted above).
cmip6_z500_dataset = cmip6_t_dataset['zg'].sel(plev=500)
cmip6_z500_dataset

Unnamed: 0,Array,Chunk
Bytes,1.84 GiB,57.09 MiB
Shape,"(241060, 32, 64)","(7308, 32, 64)"
Count,165 Tasks,33 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 1.84 GiB 57.09 MiB Shape (241060, 32, 64) (7308, 32, 64) Count 165 Tasks 33 Chunks Type float32 numpy.ndarray",64  32  241060,

Unnamed: 0,Array,Chunk
Bytes,1.84 GiB,57.09 MiB
Shape,"(241060, 32, 64)","(7308, 32, 64)"
Count,165 Tasks,33 Chunks
Type,float32,numpy.ndarray


In [9]:
# ERA5 geopotential at 500 hPa, selected by label (previously positional index 7).
# Named z500: this is 500 hPa geopotential, not 850 hPa temperature as the old
# name `era5_t850_dataset` suggested.
era5_z500_dataset = era5_t_dataset['z'].sel(level=500)
era5_z500_dataset

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,160 Tasks,40 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.68 GiB 68.62 MiB Shape (350640, 32, 64) (8784, 32, 64) Count 160 Tasks 40 Chunks Type float32 numpy.ndarray",64  32  350640,

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,160 Tasks,40 Chunks
Type,float32,numpy.ndarray


# Training on CMIP6, Testing on ERA5

In [10]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

# CMIP6 z500 forecasting data module: train from 1851, validate from 2011,
# test from 2013 through 2014, 3-day lead time.
# NOTE(review): the CMIP6 time axis is 6-hourly (train ends 2010-12-31T18:00 in the
# output below) while ERA5 is hourly. If DataModule counts pred_range/subsample in
# time steps rather than hours, the effective lead time differs from ERA5's — confirm.
cmip6_data_module = DataModule(
    dataset="CMIP6",
    task="forecasting",
    root_dir=cmip_path,
    # variables
    in_vars=["geopotential"],
    out_vars=["geopotential"],
    # split boundaries
    train_start_year=Year(1851),
    val_start_year=Year(2011),
    test_start_year=Year(2013),
    end_year=Year(2014),
    # forecast setup
    pred_range=Days(3),
    subsample=Hours(6),
    # loader
    batch_size=128,
    num_workers=16,
    # pre-selected 500 hPa array built in In [8]
    data_file=cmip6_z500_dataset,
)

Creating train dataset
returning load_from_nc
done with get_lat_lon
<xarray.DataArray 'zg' (time: 233756, plev: 1, lat: 32, lon: 64)>
dask.array<transpose, shape=(233756, 1, 32, 64), dtype=float32, chunksize=(7308, 1, 32, 64), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 1851-01-01 ... 2010-12-31T18:00:00
  * plev     (plev) float64 500.0
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
Attributes:
    units:    m**2 s**-2
<xarray.DataArray 'zg' (time: 233756, plev: 1, lat: 32, lon: 64)>
dask.array<transpose, shape=(233756, 1, 32, 64), dtype=float32, chunksize=(7308, 1, 32, 64), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 1851-01-01 ... 2010-12-31T18:00:00
  * plev     (plev) float64 500.0
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.

In [11]:
# ERA5 z500 forecasting data module with the same val/test years, lead time and
# loader settings as the CMIP6 module; the training split starts in 1979.
era5_data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = era_path,
    in_vars = ["geopotential_500"],
    out_vars = ["geopotential_500"],
    train_start_year = Year(1979),
    val_start_year = Year(2011),
    test_start_year = Year(2013),
    end_year = Year(2014),
    pred_range = Days(3),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 16,
    # Bug fix: this previously passed `cmip6_z500_dataset` (copied from the CMIP6
    # cell), handing CMIP6 data to the ERA5 module. Pass the ERA5 500 hPa slice.
    data_file = era5_t_dataset['z'].sel(level=500)
)

Creating train dataset


  0%|          | 0/32 [00:00<?, ?it/s]

<xarray.DataArray 'z' (time: 280512, level: 1, lat: 32, lon: 64)>
dask.array<concatenate, shape=(280512, 1, 32, 64), dtype=float32, chunksize=(8784, 1, 32, 64), chunktype=numpy.ndarray>
Coordinates:
  * level    (level) int64 500
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * time     (time) datetime64[ns] 1979-01-01 ... 2010-12-31T23:00:00
Attributes:
    units:          m**2 s**-2
    long_name:      Geopotential
    standard_name:  geopotential
Creating val dataset


  0%|          | 0/2 [00:00<?, ?it/s]

<xarray.DataArray 'z' (time: 17544, level: 1, lat: 32, lon: 64)>
dask.array<concatenate, shape=(17544, 1, 32, 64), dtype=float32, chunksize=(8784, 1, 32, 64), chunktype=numpy.ndarray>
Coordinates:
  * level    (level) int64 500
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * time     (time) datetime64[ns] 2011-01-01 ... 2012-12-31T23:00:00
Attributes:
    units:          m**2 s**-2
    long_name:      Geopotential
    standard_name:  geopotential
Creating test dataset


  0%|          | 0/2 [00:00<?, ?it/s]

<xarray.DataArray 'z' (time: 17520, level: 1, lat: 32, lon: 64)>
dask.array<concatenate, shape=(17520, 1, 32, 64), dtype=float32, chunksize=(8760, 1, 32, 64), chunktype=numpy.ndarray>
Coordinates:
  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19
  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4
  * level    (level) int64 500
  * time     (time) datetime64[ns] 2013-01-01 ... 2014-12-31T23:00:00
Attributes:
    units:          m**2 s**-2
    long_name:      Geopotential
    standard_name:  geopotential


In [12]:
from climate_learn.models import load_model


# ResNet forecaster; channel counts follow the number of input/output variables
# configured on the CMIP6 data module (one each: geopotential).
model_kwargs = dict(
    in_channels=len(cmip6_data_module.hparams.in_vars),
    out_channels=len(cmip6_data_module.hparams.out_vars),
    n_blocks=4,
)

# Optimizer / schedule settings handed to load_model.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
)

model_module1 = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [13]:
from climate_learn.models import set_climatology
# Attach climatology from the CMIP6 data module to model_module1.
# NOTE(review): confirm what else set_climatology copies in this climate_learn fork
# (e.g. normalization stats); whatever it sets here is also in effect for later
# tests of model_module1 unless it is called again with another data module.
set_climatology(model_module1, cmip6_data_module)

In [14]:
from climate_learn.training import Trainer, WandbLogger

# Trainer for the CMIP6-trained model: fixed seed, GPU, 16-bit precision.
# max_epochs mirrors optim_kwargs["max_epochs"]; keep the two in sync.
trainer1 = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # logger=WandbLogger(project="climate_tutorial", name="forecast-resnet-cmip6"),
)

In [15]:
# Fit model_module1 on the CMIP6 data module.
trainer1.fit(model_module1, cmip6_data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [16]:
# Cross-dataset evaluation: CMIP6-trained model on the ERA5 test split.
# set_climatology in In [13] attached CMIP6 climatology to model_module1; point it at
# ERA5 for this test so metrics are computed against the data being evaluated.
# Afterwards, restore the CMIP6 climatology so the in-distribution test is unaffected.
# NOTE(review): assumes set_climatology takes everything it sets from the data module passed.
set_climatology(model_module1, era5_data_module)
try:
    cmip6_model_on_era5 = trainer1.test(model_module1, era5_data_module)
finally:
    set_climatology(model_module1, cmip6_data_module)
cmip6_model_on_era5

Output()

In [17]:
# In-distribution evaluation: CMIP6-trained model on the CMIP6 test split (2013-2014).
trainer1.test(model_module1, cmip6_data_module)

Output()

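**Unexpected result:** the model trained on CMIP6 scores better on the ERA5 test set than on the CMIP6 test set. **TODO:** rule out evaluation artifacts before interpreting this. (1) Check that the ERA5 evaluation uses ERA5's climatology rather than the CMIP6 one attached by `set_climatology(model_module1, cmip6_data_module)`. (2) CMIP6 here is 6-hourly (its time axis ends at 18:00) while ERA5 is hourly (ends at 23:00). If `pred_range`/`subsample` are counted in time steps, the two test sets do not share the same lead time.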

# Training on ERA5, Testing on CMIP6

In [18]:
# Same ResNet architecture, now sized from the ERA5 data module's variables.
model_kwargs = dict(
    in_channels=len(era5_data_module.hparams.in_vars),
    out_channels=len(era5_data_module.hparams.out_vars),
    n_blocks=4,
)

# Identical optimizer / schedule settings to the CMIP6 run.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
)

model_module2 = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [19]:
# Attach climatology from the ERA5 data module to model_module2.
# NOTE(review): confirm what else set_climatology copies (e.g. normalization stats).
set_climatology(model_module2, era5_data_module)

In [20]:
# Trainer for the ERA5-trained model, configured identically to trainer1.
# max_epochs mirrors optim_kwargs["max_epochs"]; keep the two in sync.
trainer2 = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # logger=WandbLogger(project="climate_tutorial", name="forecast-resnet-era5"),
)

In [21]:
# Fit model_module2 on the ERA5 data module.
trainer2.fit(model_module2, era5_data_module)

Output()

In [22]:
# Cross-dataset evaluation: ERA5-trained model on the CMIP6 test split.
# set_climatology in In [19] attached ERA5 climatology to model_module2; point it at
# CMIP6 for this test so metrics are computed against the data being evaluated.
# Afterwards, restore the ERA5 climatology so the in-distribution test is unaffected.
# NOTE(review): assumes set_climatology takes everything it sets from the data module passed.
set_climatology(model_module2, cmip6_data_module)
try:
    era5_model_on_cmip6 = trainer2.test(model_module2, cmip6_data_module)
finally:
    set_climatology(model_module2, era5_data_module)
era5_model_on_cmip6

Output()

In [23]:
# In-distribution evaluation: ERA5-trained model on the ERA5 test split (2013-2014).
trainer2.test(model_module2, era5_data_module)

Output()