In [1]:
from pathlib import Path
import torch

# Machine-specific dataset roots (both on the 5.625deg grid) -- replace with your own paths.
era_path = Path("/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg")
cmip_path = Path("/data0/datasets/weatherbench/data/esgf/cmip6/5.625deg")

In [2]:
import xarray as xr

import sys

# Machine-specific source checkout, placed at the front of the module search path.
# NOTE(review): presumably so the `climate_learn` imports in later cells resolve to
# this checkout rather than an installed copy -- confirm.
climate_learn_src = '/home/seongbin/climate-learn/src'
sys.path.insert(0, climate_learn_src)

In [3]:
from climate_learn.data import DataModuleArgs, DataModule
from climate_learn.data.climate_dataset.args import ERA5Args, CMIP6Args
from climate_learn.data.tasks.args import ForecastingArgs

In [4]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules before
# each cell executes so that edits to source files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [6]:
# --- CMIP6 forecasting data ---
# Dataset description: which variables to read from the CMIP6 root directory.
cmip_data_args = CMIP6Args(
    root_dir=cmip_path,
    variables=[
        "temperature",
        "geopotential",
        "u_component_of_wind",
        "v_component_of_wind",
        "specific_humidity",
        "air_temperature",
    ],
    years=range(1979, 2014),
)

# Forecasting task on top of that dataset: six input variables, three output targets.
# NOTE(review): pred_range = 3 * 24 is presumably a 72-hour (3-day) lead time -- confirm units.
forecasting_args = ForecastingArgs(
    dataset_args=cmip_data_args,
    in_vars=[
        "geopotential",
        "u_component_of_wind",
        "v_component_of_wind",
        "temperature",
        "specific_humidity",
        "air_temperature",
    ],
    out_vars=[
        "temperature_850",
        "geopotential_500",
        "air_temperature",
    ],
    pred_range=3 * 24,
)

# Year boundaries for the train / val / test splits.
data_module_args = DataModuleArgs(
    task_args=forecasting_args,
    train_start_year=2011,
    val_start_year=2012,
    test_start_year=2013,
    end_year=2014,
)

# Building the DataModule creates the train, val and test datasets (see the cell output).
cmip_data_module = DataModule(
    data_module_args=data_module_args,
    batch_size=32,
    num_workers=4,
)

<class 'climate_learn.data.climate_dataset.cmip6_module.CMIP6'>
Creating train dataset


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.41it/s]


(1460, 36, 32, 64)
<class 'climate_learn.data.climate_dataset.cmip6_module.CMIP6'>
Creating val dataset


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.52it/s]


(1464, 36, 32, 64)
<class 'climate_learn.data.climate_dataset.cmip6_module.CMIP6'>
Creating test dataset


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.87it/s]


(2920, 36, 32, 64)


In [59]:
# Sanity check: element dtype of the CMIP6 test-split input array.
cmip_test_inputs = cmip_data_module.test_dataset.inp_data
cmip_test_inputs.dtype

dtype('float32')

In [8]:
# --- ERA5 forecasting data ---
# Dataset description: which variables to read from the ERA5 root directory.
data_args = ERA5Args(
    root_dir=era_path,
    variables=[
        "geopotential",
        "u_component_of_wind",
        "v_component_of_wind",
        "temperature",
        "specific_humidity",
        "2m_temperature",
    ],
    years=range(1979, 2018),
)

# Forecasting task on top of that dataset: six input variables, three output targets.
# NOTE(review): pred_range = 3 * 24 is presumably a 72-hour lead time, and subsample=6
# presumably keeps every 6th time step -- confirm against ForecastingArgs.
forecasting_args = ForecastingArgs(
    dataset_args=data_args,
    in_vars=[
        "geopotential",
        "u_component_of_wind",
        "v_component_of_wind",
        "temperature",
        "specific_humidity",
        "2m_temperature",
    ],
    out_vars=[
        "temperature_850",
        "geopotential_500",
        "2m_temperature",
    ],
    pred_range=3 * 24,
    subsample=6,
)

# Year boundaries for the train / val / test splits (same values as the CMIP6 cell).
data_module_args = DataModuleArgs(
    task_args=forecasting_args,
    train_start_year=2011,
    val_start_year=2012,
    test_start_year=2013,
    end_year=2014,
)

# Building the DataModule creates the train, val and test datasets (see the cell output).
data_module = DataModule(
    data_module_args=data_module_args,
    batch_size=32,
    num_workers=4,
)

<class 'climate_learn.data.climate_dataset.era5_module.ERA5'>
Creating train dataset


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 20.81it/s]


(1460, 36, 32, 64)
<class 'climate_learn.data.climate_dataset.era5_module.ERA5'>
Creating val dataset


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 20.35it/s]


(1464, 36, 32, 64)
<class 'climate_learn.data.climate_dataset.era5_module.ERA5'>
Creating test dataset


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 21.21it/s]


(2920, 36, 32, 64)


In [9]:
from climate_learn.models import load_model
from torch.optim import AdamW

In [13]:
# Network hyper-parameters for the CMIP6-side ResNet. 36 input channels matches the
# channel dimension of the dataset shapes printed above; 3 outputs matches the three out_vars.
cmip6_model_kwargs = dict(in_channels=36, out_channels=3, n_blocks=19)

# Optimizer / schedule settings handed to load_model.
# NOTE(review): max_epochs here (5) differs from the Trainer's max_epochs (64) -- confirm intended.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
    optimizer=AdamW,
)

cmip_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=cmip6_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [14]:
from climate_learn.training import Trainer, WandbLogger
from climate_learn.models import set_climatology

In [16]:
# Configure climatology on the CMIP6 model module from the CMIP6 data module.
# NOTE(review): presumably needed by climatology-relative metrics during evaluation --
# confirm against set_climatology.
set_climatology(cmip_model_module, cmip_data_module)

In [17]:
# Trainer used to evaluate the CMIP6-side model: fixed seed, GPU, precision=16.
# Experiment logging is disabled; uncomment the logger line to enable it.
cmip_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=64,
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

In [18]:
from climate_learn.models import fit_lin_reg_baseline

# Fit the linear-regression baseline for the CMIP6 model module on the CMIP6 data.
# The module passed here must be `cmip_model_module` (created in this section):
# a bare `model_module` is not defined anywhere in this notebook and raises
# NameError under Restart & Run All.
# NOTE(review): reg_hparam=0.0 presumably disables regularisation -- confirm.
fit_lin_reg_baseline(cmip_model_module, cmip_data_module, reg_hparam=0.0)

  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)


In [20]:
# cmip -> era: evaluate the CMIP6-side model on the ERA5 data module.
# NOTE(review): the third positional argument is presumably the checkpoint to restore
# before testing -- confirm against Trainer.test.
cmip_ckpt_path = "/data0/ckpts/seongbin/data-cross-train-2/clean-36-input/129j158z/checkpoints/last.ckpt"
cmip_trainer.test(cmip_model_module, data_module, cmip_ckpt_path)

Output()

Output()

IndexError: pop from empty list

In [21]:
# cmip -> cmip: evaluate the CMIP6-side model on the CMIP6 data module.
# NOTE(review): the third positional argument is presumably the checkpoint to restore
# before testing -- confirm against Trainer.test.
cmip_ckpt_path = "/data0/ckpts/seongbin/data-cross-train-2/clean-36-input/129j158z/checkpoints/last.ckpt"
cmip_trainer.test(cmip_model_module, cmip_data_module, cmip_ckpt_path)

Output()

Output()

IndexError: pop from empty list

## Train on ERA5

In [29]:
# Network hyper-parameters for the ERA5-side ResNet. 36 input channels matches the
# channel dimension of the dataset shapes printed above; 3 outputs matches the three out_vars.
era5_model_kwargs = dict(in_channels=36, out_channels=3, n_blocks=19)

# Optimizer / schedule settings handed to load_model.
# NOTE(review): max_epochs here (5) differs from the Trainer's max_epochs (64) -- confirm intended.
optim_kwargs = dict(
    lr=1e-4,
    weight_decay=1e-5,
    warmup_epochs=1,
    max_epochs=5,
    optimizer=AdamW,
)

era_model_module = load_model(
    name="resnet",
    task="forecasting",
    model_kwargs=era5_model_kwargs,
    optim_kwargs=optim_kwargs,
)

In [30]:
# Configure climatology on the ERA5 model module from the ERA5 data module.
# NOTE(review): presumably needed by climatology-relative metrics during evaluation --
# confirm against set_climatology.
set_climatology(era_model_module, data_module)

In [31]:
# Trainer used to evaluate the ERA5-side model: fixed seed, GPU, precision=16.
# Experiment logging is disabled; uncomment the logger line to enable it.
era5_trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=64,
    # logger=WandbLogger(project="climate_tutorial", name="forecast-vit"),
)

In [25]:
from climate_learn.models import fit_lin_reg_baseline

# Fit the linear-regression baseline for the ERA5 model module on the ERA5 data.
# NOTE(review): reg_hparam=0.0 presumably disables regularisation -- confirm.
fit_lin_reg_baseline(
    era_model_module,
    data_module,
    reg_hparam=0.0,
)

  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)


# Data

Days: 5, Var: geopotential 500, model: resnet

train_start_year = 1979 / 1851, val_start_year = 2011, test_start_year = 2013, end_year = 2014

In [32]:
# era -> era: evaluate the ERA5-side model on the ERA5 data module.
# NOTE(review): the third positional argument is presumably the checkpoint to restore
# before testing -- confirm against Trainer.test.
era_ckpt_path = "/data0/ckpts/seongbin/data-cross-train-2/clean-36-input/15jo4h1y/checkpoints/last.ckpt"
era5_trainer.test(era_model_module, data_module, era_ckpt_path)

Output()

Output()

IndexError: pop from empty list

In [33]:
# era -> cmip: evaluate the ERA5-side model on the CMIP6 data module.
# NOTE(review): the third positional argument is presumably the checkpoint to restore
# before testing -- confirm against Trainer.test.
era_ckpt_path = "/data0/ckpts/seongbin/data-cross-train-2/clean-36-input/15jo4h1y/checkpoints/last.ckpt"
era5_trainer.test(era_model_module, cmip_data_module, era_ckpt_path)

Output()

Output()

IndexError: pop from empty list