In [1]:
from pathlib import Path
import torch

# Root of the local WeatherBench copy -- replace with your own path.
_weatherbench_root = Path("/data0/datasets/weatherbench/data/weatherbench")

# The two datasets use different directory names ("5.625" vs "5.625deg").
cmip_path = _weatherbench_root / "cmip6" / "5.625"
era_path = _weatherbench_root / "era5" / "5.625deg"

In [2]:
import xarray as xr

import sys

# Put the local climate-learn checkout at the front of sys.path so it is
# found before any other copy on the path.
# Any existing entry for the same directory is dropped first, which keeps the
# cell idempotent: re-running it no longer stacks duplicate entries on sys.path.
_climate_learn_src = '/home/seongbin/climate-learn/src'
sys.path[:] = [entry for entry in sys.path if entry != _climate_learn_src]
sys.path.insert(0, _climate_learn_src)

In [3]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

In [4]:
# ERA5 forecasting data module: "temperature" as input, "temperature_850" as
# output, with pred_range=Days(5) and subsample=Hours(6).
era5_data_module = DataModule(
    dataset="ERA5",
    task="forecasting",
    root_dir=era_path,
    in_vars=["temperature"],
    out_vars=["temperature_850"],
    # Year boundaries for the train / val / test splits.
    train_start_year=Year(2007),
    val_start_year=Year(2011),
    test_start_year=Year(2013),
    end_year=Year(2014),
    pred_range=Days(5),
    subsample=Hours(6),
    # Data loading settings.
    batch_size=128,
    num_workers=16,
)

Creating train dataset
update


100%|█| 4/4 [00:00<00:00, 96.6

out 35064





in 35064
Finished inp and out _data
True
Almost done!
Creating val dataset
update


100%|█| 2/2 [00:00<00:00, 135.


out 17544
in 17544
Finished inp and out _data
True
Almost done!
Creating test dataset
update


100%|█| 2/2 [00:00<00:00, 158.

out 17520





in 17520
Finished inp and out _data
True
Almost done!


In [5]:
# CMIP6 forecasting data module. Same variables, validation/test boundaries,
# pred_range and subsample as the ERA5 module; only the dataset, root
# directory and train_start_year (2000 here) differ.
cmip6_data_module = DataModule(
    dataset="CMIP6",
    task="forecasting",
    root_dir=cmip_path,
    in_vars=["temperature"],
    out_vars=["temperature_850"],
    # Year boundaries for the train / val / test splits.
    train_start_year=Year(2000),
    val_start_year=Year(2011),
    test_start_year=Year(2013),
    end_year=Year(2014),
    pred_range=Days(5),
    subsample=Hours(6),
    # Data loading settings.
    batch_size=128,
    num_workers=16,
)

Creating train dataset


  0%|          | 0/11 [00:00<?, ?it/s]

in 16069
out 16069
Creating val dataset


  0%|          | 0/2 [00:00<?, ?it/s]

in 2918
out 2918
Creating test dataset


  0%|          | 0/2 [00:00<?, ?it/s]

in 2918
out 2918


In [6]:
from climate_learn.models import load_model

In [7]:
# Channel counts are derived from the CMIP6 data module's variable lists.
cmip_model_kwargs = {
    "in_channels": len(cmip6_data_module.hparams.in_vars),
    "out_channels": len(cmip6_data_module.hparams.out_vars),
    "n_blocks": 4,
}

# Optimizer / schedule settings passed to load_model.
# "max_epochs" was 2 here, while the Trainer in this notebook runs for
# max_epochs = 5 and the ERA5 model's optim_kwargs use 5. It is set to 5 so
# the schedule horizon matches the actual training length and the CMIP6 and
# ERA5 models are trained with identical optimisation settings.
# NOTE(review): presumably "warmup_epochs"/"max_epochs" drive the learning-rate
# schedule inside climate_learn -- confirm against load_model.
optim_kwargs = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "warmup_epochs": 1,
    "max_epochs": 5,
}

# ResNet forecasting model to be trained on CMIP6.
cmip_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=cmip_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [15]:
from climate_learn.models import set_climatology
from climate_learn.training import Trainer, WandbLogger

In [None]:
# Give the CMIP6 ResNet the climatology of the CMIP6 data module.
# NOTE(review): presumably needed by climatology-based metrics/baselines at
# test time -- confirm. This cell has no execution count in the saved notebook.
set_climatology(cmip_model_module, cmip6_data_module)

In [10]:
# Trainer for the CMIP6 ResNet: seed 0, GPU accelerator, precision 16, 5 epochs.
cmip_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # Optional Weights & Biases logging (disabled):
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

In [11]:
# Train the CMIP6 ResNet on the CMIP6 data module.
cmip_trainer.fit(cmip_model_module, cmip6_data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [12]:
from climate_learn.models import fit_lin_reg_baseline
# Fit the linear-regression baseline for the CMIP6 model on the CMIP6 data.
# NOTE(review): reg_hparam=0.0 presumably means no regularisation -- confirm.
fit_lin_reg_baseline(cmip_model_module, cmip6_data_module, reg_hparam=0.0)

## Train on ERA5

In [9]:
# Channel counts are derived from the ERA5 data module's variable lists.
era5_model_kwargs = {
    "in_channels": len(era5_data_module.hparams.in_vars),
    "out_channels": len(era5_data_module.hparams.out_vars),
    "n_blocks": 4,
}

# Optimizer / schedule settings passed to load_model.
# This rebinds the notebook-level name `optim_kwargs`, so every later
# load_model call that passes `optim_kwargs` picks up these values.
optim_kwargs = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "warmup_epochs": 1,
    "max_epochs": 5,
}

# ResNet forecasting model to be trained on ERA5.
era5_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=era5_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [17]:
# Give the ERA5 ResNet the climatology of the ERA5 data module.
# NOTE(review): presumably needed by climatology-based metrics/baselines at
# test time -- confirm.
set_climatology(era5_model_module, era5_data_module)

In [18]:
# Trainer for the ERA5 ResNet: seed 0, GPU accelerator, precision 16, 5 epochs.
era5_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # Optional Weights & Biases logging (disabled):
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

In [19]:
# Train the ERA5 ResNet on the ERA5 data module.
era5_trainer.fit(era5_model_module, era5_data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [20]:
# Fit the linear-regression baseline for the ERA5 model on the ERA5 data.
# NOTE(review): reg_hparam=0.0 presumably means no regularisation -- confirm.
fit_lin_reg_baseline(era5_model_module, era5_data_module, reg_hparam=0.0)

# Data

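Setup: 5-day prediction range (`pred_range = Days(5)`), target variable `temperature_850`, model: ResNet.

Year splits: `train_start_year` = 1979 / 1851 (presumably ERA5 / CMIP6), `val_start_year` = 2011, `test_start_year` = 2013, `end_year` = 2014.

Note: the data-module cells in this notebook currently set `train_start_year` to 2007 (ERA5) and 2000 (CMIP6), so results from a fresh run will not correspond to the 1979 / 1851 setting.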

In [13]:
# cmip -> era: ResNet trained on CMIP6, evaluated on the ERA5 data module.
cmip_trainer.test(cmip_model_module, era5_data_module)

Output()

In [21]:
# era -> era: ResNet trained on ERA5, evaluated on the ERA5 data module.
era5_trainer.test(era5_model_module, era5_data_module)

Output()

In [15]:
# cmip -> cmip: ResNet trained on CMIP6, evaluated on the CMIP6 data module.
cmip_trainer.test(cmip_model_module, cmip6_data_module)

Output()

In [22]:
# era -> cmip: ResNet trained on ERA5, evaluated on the CMIP6 data module.
era5_trainer.test(era5_model_module, cmip6_data_module)

Output()

## U-Net

In [10]:
# U-Net counterparts of the two ResNets. They reuse the ResNet kwargs dicts
# (channel counts and "n_blocks": 4) and whichever `optim_kwargs` was bound
# most recently in the kernel.
cmip_unet_model_module = load_model(
    name="unet",
    task="forecasting",
    model_kwargs=cmip_model_kwargs,
    optim_kwargs=optim_kwargs,
)
era5_unet_model_module = load_model(
    name="unet",
    task="forecasting",
    model_kwargs=era5_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [13]:
# Give each U-Net the climatology of the data module it will be trained on.
# NOTE(review): presumably needed by climatology-based metrics at test time -- confirm.
set_climatology(era5_unet_model_module, era5_data_module)
set_climatology(cmip_unet_model_module, cmip6_data_module)

In [16]:
# Trainer for the CMIP6 U-Net: seed 0, GPU accelerator, precision 16, 5 epochs.
cmip_unet_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # Optional Weights & Biases logging (disabled):
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

# Train the CMIP6 U-Net on the CMIP6 data module.
cmip_unet_trainer.fit(cmip_unet_model_module, cmip6_data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [19]:
# Trainer for the ERA5 U-Net: seed 0, GPU accelerator, precision 16, 5 epochs.
era5_unet_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # Optional Weights & Biases logging (disabled):
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

# Train the ERA5 U-Net on the ERA5 data module.
era5_unet_trainer.fit(era5_unet_model_module, era5_data_module)

Output()

In [20]:
# era -> era: U-Net trained on ERA5, evaluated on the ERA5 data module.
era5_unet_trainer.test(era5_unet_model_module, era5_data_module)

Output()

In [17]:
# cmip -> era: U-Net trained on CMIP6, evaluated on the ERA5 data module.
cmip_unet_trainer.test(cmip_unet_model_module, era5_data_module)

Output()

In [18]:
# cmip -> cmip: U-Net trained on CMIP6, evaluated on the CMIP6 data module.
cmip_unet_trainer.test(cmip_unet_model_module, cmip6_data_module)

Output()

In [21]:
# era -> cmip: U-Net trained on ERA5, evaluated on the CMIP6 data module.
era5_unet_trainer.test(era5_unet_model_module, cmip6_data_module)

Output()

## ViT (Vision Transformer)

In [22]:
# ViT model kwargs. Unlike the ResNet/U-Net kwargs, the ViT takes variable
# names and an image size instead of channel counts. The two dicts are
# intentionally separate objects with identical contents.
# NOTE(review): in_vars here is ["temperature_850"], whereas the data modules
# are built with in_vars = ["temperature"] -- confirm this is intended.
# (An earlier, disabled variant used "2m_temperature" for both in_vars and
# out_vars with the same img_size.)
cmip_vit_model_kwargs = dict(
    img_size=[32, 64],
    in_vars=["temperature_850"],
    out_vars=["temperature_850"],
)

era5_vit_model_kwargs = dict(
    img_size=[32, 64],
    in_vars=["temperature_850"],
    out_vars=["temperature_850"],
)

In [23]:
# ViT forecasting models, one per training dataset. Both use whichever
# `optim_kwargs` was bound most recently in the kernel.
cmip_vit_model_module = load_model(
    name="vit",
    task="forecasting",
    model_kwargs=cmip_vit_model_kwargs,
    optim_kwargs=optim_kwargs,
)
era5_vit_model_module = load_model(
    name="vit",
    task="forecasting",
    model_kwargs=era5_vit_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [24]:
# Give each ViT the climatology of the data module it will be trained on.
# NOTE(review): presumably needed by climatology-based metrics at test time -- confirm.
set_climatology(era5_vit_model_module, era5_data_module)
set_climatology(cmip_vit_model_module, cmip6_data_module)

In [25]:
# Trainer for the CMIP6 ViT: seed 0, GPU accelerator, precision 16, 5 epochs.
cmip_vit_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # Optional Weights & Biases logging (disabled):
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

# Train the CMIP6 ViT on the CMIP6 data module.
cmip_vit_trainer.fit(cmip_vit_model_module, cmip6_data_module)

Output()

In [26]:
# Trainer for the ERA5 ViT: seed 0, GPU accelerator, precision 16, 5 epochs.
era5_vit_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5,
    # Optional Weights & Biases logging (disabled):
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

# Train the ERA5 ViT on the ERA5 data module.
era5_vit_trainer.fit(era5_vit_model_module, era5_data_module)

Output()

In [27]:
# cmip -> era: ViT trained on CMIP6, evaluated on the ERA5 data module.
cmip_vit_trainer.test(cmip_vit_model_module, era5_data_module)

Output()

In [28]:
# cmip -> cmip: ViT trained on CMIP6, evaluated on the CMIP6 data module.
cmip_vit_trainer.test(cmip_vit_model_module, cmip6_data_module)

Output()

In [29]:
# era -> era: ViT trained on ERA5, evaluated on the ERA5 data module.
era5_vit_trainer.test(era5_vit_model_module, era5_data_module)

Output()

In [30]:
# era -> cmip: ViT trained on ERA5, evaluated on the CMIP6 data module.
era5_vit_trainer.test(era5_vit_model_module, cmip6_data_module)

Output()