In [1]:
from pathlib import Path
import torch

import os

# Dataset roots for the 5.625-degree CMIP6 and ERA5 data.
# Each root can be overridden without editing the notebook by setting the
# CMIP6_DATA_DIR / ERA5_DATA_DIR environment variables before starting the kernel;
# when they are unset, the original hard-coded locations are used.
cmip_path = Path(os.environ.get("CMIP6_DATA_DIR", "/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg"))
era_path = Path(os.environ.get("ERA5_DATA_DIR", "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg"))

In [2]:
import xarray as xr

import sys
# Put the local climate_learn checkout at the front of the import search path.
# Guarded so that re-running this cell does not stack duplicate entries in sys.path.
# NOTE(review): this is a user-specific absolute path; adjust it for your machine.
if sys.path[:1] != ['/home/seongbin/climate_learn']:
    sys.path.insert(0, '/home/seongbin/climate_learn')

In [3]:
from climate_learn.utils.data import load_dataset, view

# Load the CMIP6 data stored in the "temperature" subdirectory of the CMIP6 root.
cmip6_t_dataset = load_dataset(cmip_path.joinpath("temperature"))

In [4]:
# Load the ERA5 data stored in the "temperature" subdirectory of the ERA5 root.
era5_t_dataset = load_dataset(era_path.joinpath("temperature"))

In [5]:
# Convert the CMIP6 pressure-level coordinate to hPa (divide by 100) and record the new units.
# The conversion is guarded on the units attribute: the division rewrites the coordinate in
# place, so without the guard re-running this cell would divide by 100 a second time and
# silently corrupt the pressure levels.
if cmip6_t_dataset['plev'].attrs.get('units') != 'hPa':
    cmip6_t_dataset['plev'] = cmip6_t_dataset['plev']/100
    cmip6_t_dataset['plev'].attrs['units'] = 'hPa'

In [6]:
# Take a single slice (position 6) along the second axis of the CMIP6 'ta' variable,
# keeping every entry on the remaining axes.
# NOTE(review): position 6 is assumed to be the 850 hPa level — confirm against
# cmip6_t_dataset['plev'].values before relying on it.
cmip6_t850_dataset = cmip6_t_dataset['ta'][:, 6]
cmip6_t850_dataset

Unnamed: 0,Array,Chunk
Bytes,1.84 GiB,57.09 MiB
Shape,"(241060, 32, 64)","(7308, 32, 64)"
Count,132 Tasks,33 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 1.84 GiB 57.09 MiB Shape (241060, 32, 64) (7308, 32, 64) Count 132 Tasks 33 Chunks Type float32 numpy.ndarray",64  32  241060,

Unnamed: 0,Array,Chunk
Bytes,1.84 GiB,57.09 MiB
Shape,"(241060, 32, 64)","(7308, 32, 64)"
Count,132 Tasks,33 Chunks
Type,float32,numpy.ndarray


In [7]:
# Take a single slice (position 10) along the second axis of the ERA5 't' variable,
# keeping every entry on the remaining axes.
# NOTE(review): position 10 is assumed to be the 850 hPa level — confirm against the
# dataset's level coordinate before relying on it.
era5_t850_dataset = era5_t_dataset['t'][:, 10]
era5_t850_dataset

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,160 Tasks,40 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.68 GiB 68.62 MiB Shape (350640, 32, 64) (8784, 32, 64) Count 160 Tasks 40 Chunks Type float32 numpy.ndarray",64  32  350640,

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,160 Tasks,40 Chunks
Type,float32,numpy.ndarray


# Training on CMIP6, Testing on ERA5

In [8]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

# CMIP6 forecasting data module.
# Year boundaries: training starts at 1851, validation at 2011, test at 2013; end_year is 2014.
# pred_range is 3 days and the data is subsampled with a 6-hour step.
cmip6_data_module = DataModule(
    dataset="CMIP6",
    task="forecasting",
    root_dir=cmip_path,
    in_vars=["temperature"],
    out_vars=["temperature"],
    train_start_year=Year(1851),
    val_start_year=Year(2011),
    test_start_year=Year(2013),
    end_year=Year(2014),
    pred_range=Days(3),
    subsample=Hours(6),
    batch_size=128,
    num_workers=1,
    data_file=cmip6_t850_dataset,
)

Creating train dataset
.1
.1
.1
Creating val dataset
.1
.1
.1
Creating test dataset
.1
.1
.1


In [9]:
# ERA5 forecasting data module.
# Year boundaries: training starts at 1979, validation at 2011, test at 2013; end_year is 2014.
# pred_range is 3 days and the data is subsampled with a 6-hour step.
era5_data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = era_path,
    in_vars = ["temperature_850"],
    out_vars = ["temperature_850"],
    train_start_year = Year(1979),
    val_start_year = Year(2011),
    test_start_year = Year(2013),
    end_year = Year(2014),
    pred_range = Days(3),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 1,
    # Must be the ERA5 slice: handing the CMIP6 slice to the ERA5 module would make any
    # "ERA5" evaluation that reads data_file actually run on CMIP6 data.
    data_file = era5_t850_dataset
)

Creating train dataset


  0%|          | 0/32 [00:00<?, ?it/s]

Creating val dataset


  0%|          | 0/2 [00:00<?, ?it/s]

Creating test dataset


  0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
from climate_learn.models import load_model


# Channel counts follow the number of input/output variables declared on the CMIP6 data module.
model_kwargs = dict(
    in_channels=len(cmip6_data_module.hparams.in_vars),
    out_channels=len(cmip6_data_module.hparams.out_vars),
    n_blocks=4,
)

# Optimizer / schedule settings handed to load_model.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
)

# ResNet forecasting model that will be trained on CMIP6.
model_module1 = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [11]:
from climate_learn.models import set_climatology
# Pass the CMIP6 data module to set_climatology for model_module1.
# NOTE(review): presumably this gives the model a climatology reference for its
# evaluation metrics — confirm in climate_learn.models.
set_climatology(model_module1, cmip6_data_module)

In [12]:
from climate_learn.training import Trainer, WandbLogger

# Trainer for the CMIP6-trained model: seed 0, GPU accelerator, 16-bit precision, 5 epochs.
# No logger is configured; WandbLogger (imported above) can be passed via `logger=` if
# experiment tracking is wanted.
trainer1 = Trainer(seed=0, accelerator="gpu", precision=16, max_epochs=5)

In [13]:
# Train model_module1 using the CMIP6 data module.
trainer1.fit(model_module1, cmip6_data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [15]:
# Cross-dataset evaluation: run the CMIP6-trained model against era5_data_module.
trainer1.test(model_module1, era5_data_module)

Output()

  rank_zero_warn(


In [16]:
# Same-dataset evaluation: run the CMIP6-trained model against cmip6_data_module.
trainer1.test(model_module1, cmip6_data_module)

Output()

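**TODO (unexplained result):** The model trained on CMIP6 scored better on the ERA5 evaluation than on its own CMIP6 evaluation. This is unexpected and should not be interpreted yet — first verify that each data module is actually reading the dataset it is named for (check the `data_file` and `root_dir` arguments of both `DataModule` calls), then re-run both tests.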

# Training on ERA5, Testing on CMIP6

In [17]:
# Channel counts follow the number of input/output variables declared on the ERA5 data module.
model_kwargs = dict(
    in_channels=len(era5_data_module.hparams.in_vars),
    out_channels=len(era5_data_module.hparams.out_vars),
    n_blocks=4,
)

# Optimizer / schedule settings handed to load_model.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
)

# ResNet forecasting model that will be trained on ERA5.
model_module2 = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [18]:
# Pass the ERA5 data module to set_climatology for model_module2.
# NOTE(review): presumably this gives the model a climatology reference for its
# evaluation metrics — confirm in climate_learn.models.
set_climatology(model_module2, era5_data_module)

In [19]:
# Trainer for the ERA5-trained model: seed 0, GPU accelerator, 16-bit precision, 5 epochs.
# No logger is configured; WandbLogger can be passed via `logger=` if experiment
# tracking is wanted.
trainer2 = Trainer(seed=0, accelerator="gpu", precision=16, max_epochs=5)

In [20]:
# Train model_module2 using the ERA5 data module.
trainer2.fit(model_module2, era5_data_module)

Output()

In [21]:
# Cross-dataset evaluation: run the ERA5-trained model against cmip6_data_module.
trainer2.test(model_module2, cmip6_data_module)

Output()

In [22]:
# Same-dataset evaluation: run the ERA5-trained model against era5_data_module.
trainer2.test(model_module2, era5_data_module)

Output()