In [1]:
import os
from pathlib import Path

import torch

# Dataset roots. Each one can be overridden without editing the notebook by
# setting the matching environment variable before starting the kernel;
# when the variable is unset the original hard-coded location is used.
cmip_path = Path(os.environ.get("CLIMATE_LEARN_CMIP6_DIR", "/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg"))
era_path = Path(os.environ.get("CLIMATE_LEARN_ERA5_DIR", "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg"))

In [2]:
import xarray as xr

import sys
# Put the local climate-learn checkout at the front of the import search path.
# Any existing entry is dropped first so that re-running this cell does not
# stack duplicate entries in sys.path.
climate_learn_src = '/home/seongbin/climate-learn/src'
if climate_learn_src in sys.path:
    sys.path.remove(climate_learn_src)
sys.path.insert(0, climate_learn_src)

In [3]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

In [4]:
# ERA5 forecasting data module. Settings are collected in one mapping and then
# unpacked into DataModule: three input variables, three output fields,
# train/val/test start years 1979/2012/2013 with end year 2014,
# pred_range of Days(3) and subsample of Hours(6).
era5_data_config = {
    "dataset": "ERA5",
    "task": "forecasting",
    "root_dir": era_path,
    "in_vars": ["temperature", "geopotential", "2m_temperature"],
    "out_vars": ["temperature_850", "geopotential_500", "2m_temperature"],
    "train_start_year": Year(1979),
    "val_start_year": Year(2012),
    "test_start_year": Year(2013),
    "end_year": Year(2014),
    "pred_range": Days(3),
    "subsample": Hours(6),
    "batch_size": 32,
    "num_workers": 64,
}
era5_data_module = DataModule(**era5_data_config)

Creating train dataset


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:00<00:00, 46.45it/s]


out 289272
in 289272
Finished inp and out _data
Almost done!
Creating val dataset


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 47.56it/s]

out 8784





in 8784
Finished inp and out _data
Almost done!
Creating test dataset


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 45.60it/s]

out 17520





in 17520
Finished inp and out _data
Almost done!


In [5]:
# Batches per epoch of the ERA5 training loader.
era5_nb = len(era5_data_module.train_dataloader())
# Whole epochs that fit into 1000 batches, plus one; passed on later as "warmup_epochs".
era5_wepochs = 1 + 1000 // era5_nb
# Bare expression so the notebook displays the value.
era5_wepochs

1

In [7]:
# CMIP6 forecasting data module. Settings are collected in one mapping and then
# unpacked into DataModule: three input variables, three output fields,
# train/val/test start years 1979/2012/2013 with end year 2014 and
# pred_range of Days(3). No subsample argument is passed for this dataset.
cmip6_data_config = {
    "dataset": "CMIP6",
    "task": "forecasting",
    "root_dir": cmip_path,
    "in_vars": ["temperature", "geopotential", "air_temperature"],
    "out_vars": ["temperature_850", "geopotential_500", "air_temperature"],
    "train_start_year": Year(1979),
    "val_start_year": Year(2012),
    "test_start_year": Year(2013),
    "end_year": Year(2014),
    "pred_range": Days(3),
    "batch_size": 32,
    "num_workers": 64,
}
cmip6_data_module = DataModule(**cmip6_data_config)

Creating train dataset


  0%|          | 0/33 [00:00<?, ?it/s]

/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/1975*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/geopotential/1975*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/air_temperature/1975*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/geopotential/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/air_temperature/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/geopotential/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/air_temperature/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/geopotential/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/air_temperature/1980*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/1980*.nc
/data

  0%|          | 0/1 [00:00<?, ?it/s]

/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/2010*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/geopotential/2010*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/air_temperature/2010*.nc
in 1459
out 1459
Creating test dataset


  0%|          | 0/2 [00:00<?, ?it/s]

/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/2010*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/geopotential/2010*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/air_temperature/2010*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/temperature/2010*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/geopotential/2010*.nc
/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg/air_temperature/2010*.nc
in 2918
out 2918


In [8]:
# Batches per epoch of the CMIP6 training loader.
cmip6_nb = len(cmip6_data_module.train_dataloader())
# Whole epochs that fit into 1000 batches, plus one; passed on later as "warmup_epochs".
cmip6_wepochs = 1000 // cmip6_nb + 1
# Display the warm-up epochs next to both loaders' batch counts. cmip6_nb is
# reused instead of building a second CMIP6 training dataloader just to
# measure the same length again.
cmip6_wepochs, len(era5_data_module.train_dataloader()), cmip6_nb

(1, 1505, 1504)

In [9]:
from climate_learn.models import load_model

In [10]:
# Model sizing for the CMIP6 run: channel counts are taken from the lengths of
# the data module's in_vars / out_vars hyperparameters.
cmip_model_kwargs = dict(
    in_channels=len(cmip6_data_module.hparams.in_vars),
    out_channels=len(cmip6_data_module.hparams.out_vars),
    n_blocks=19,
)

# Optimizer settings handed to load_model; warm-up uses the CMIP6-derived epoch count.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=cmip6_wepochs,
    max_epochs=100,
)

# Build the "resnet" forecasting model module for CMIP6.
cmip_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=cmip_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [11]:
from climate_learn.models import set_climatology
from climate_learn.training import Trainer, WandbLogger

In [12]:
# Pass the CMIP6 data module to set_climatology for the CMIP6 model.
# NOTE(review): presumably this stores climatology statistics on the model for
# use by evaluation metrics — confirm against climate_learn.models.
set_climatology(cmip_model_module, cmip6_data_module)

In [13]:
# Trainer for the CMIP6 model: seed 0, "gpu" accelerator, precision 16,
# at most 100 epochs. The logger entry is left disabled.
cmip_trainer_config = {
    "seed": 0,
    "accelerator": "gpu",
    "precision": 16,
    "max_epochs": 100,
    # "logger": WandbLogger(project = "climate_tutorial", name = "forecast-vit"),
}
cmip_trainer = Trainer(**cmip_trainer_config)

In [14]:
# Fit the CMIP6 model on the CMIP6 data module using cmip_trainer.
cmip_trainer.fit(cmip_model_module, cmip6_data_module)

Output()

In [15]:
from climate_learn.models import fit_lin_reg_baseline

In [16]:
# Fit the linear-regression baseline for the CMIP6 model from the CMIP6 data module.
# NOTE(review): reg_hparam=0.0 presumably means no regularization — confirm
# what reg_hparam controls in climate_learn.models.
fit_lin_reg_baseline(cmip_model_module, cmip6_data_module, reg_hparam=0.0)

## Train on ERA5

In [17]:
# Model sizing for the ERA5 run: channel counts are taken from the lengths of
# the data module's in_vars / out_vars hyperparameters.
era5_model_kwargs = dict(
    in_channels=len(era5_data_module.hparams.in_vars),
    out_channels=len(era5_data_module.hparams.out_vars),
    n_blocks=19,
)

# Optimizer settings handed to load_model; warm-up uses the ERA5-derived epoch count.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=era5_wepochs,
    max_epochs=100,
)

# Build the "resnet" forecasting model module for ERA5.
era5_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=era5_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [18]:
# Pass the ERA5 data module to set_climatology for the ERA5 model.
# NOTE(review): presumably this stores climatology statistics on the model for
# use by evaluation metrics — confirm against climate_learn.models.
set_climatology(era5_model_module, era5_data_module)

In [19]:
# Trainer for the ERA5 model: seed 0, "gpu" accelerator, precision 16,
# at most 100 epochs. The logger entry is left disabled.
era5_trainer_config = {
    "seed": 0,
    "accelerator": "gpu",
    "precision": 16,
    "max_epochs": 100,
    # "logger": WandbLogger(project = "climate_tutorial", name = "forecast-vit"),
}
era5_trainer = Trainer(**era5_trainer_config)

In [23]:
# Fit the ERA5 model on the ERA5 data module using era5_trainer.
era5_trainer.fit(era5_model_module, era5_data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [24]:
# Fit the linear-regression baseline for the ERA5 model from the ERA5 data module.
# NOTE(review): reg_hparam=0.0 presumably means no regularization — confirm
# what reg_hparam controls in climate_learn.models.
fit_lin_reg_baseline(era5_model_module, era5_data_module, reg_hparam=0.0)

# Data

Days: 3 (both data modules use `pred_range = Days(3)`), Var: temp 850, model: resnet

train_start_year = 1979 (for both CMIP6 and ERA5), val_start_year = 2012, test_start_year = 2013, end_year = 2014

In [25]:
# cmip -> era: run the CMIP6-trained model through cmip_trainer.test on the ERA5 data module.
# NOTE(review): cmip_model_module had set_climatology and fit_lin_reg_baseline
# called with cmip6_data_module, and the CMIP6 module's third variable is
# "air_temperature" while ERA5's is "2m_temperature" — confirm these line up
# as intended for this cross-dataset evaluation.
cmip_trainer.test(cmip_model_module, era5_data_module)

Output()

In [26]:
# era -> era: run the ERA5-trained model through era5_trainer.test on the ERA5 data module.
era5_trainer.test(era5_model_module, era5_data_module)

Output()

In [22]:
# cmip -> cmip: run the CMIP6-trained model through cmip_trainer.test on the CMIP6 data module.
cmip_trainer.test(cmip_model_module, cmip6_data_module)

Output()

In [27]:
# era -> cmip: run the ERA5-trained model through era5_trainer.test on the CMIP6 data module.
# NOTE(review): era5_model_module had set_climatology and fit_lin_reg_baseline
# called with era5_data_module, and ERA5's third variable is "2m_temperature"
# while the CMIP6 module's is "air_temperature" — confirm these line up as
# intended for this cross-dataset evaluation.
era5_trainer.test(era5_model_module, cmip6_data_module)

Output()