In [1]:
import os
import sys
print(os.path.abspath(os.path.join('../../climate-learn')))
module_path = os.path.abspath(os.path.join('../../climate-learn'))
if module_path not in sys.path:
    sys.path.append(module_path)
print(sys.path)

/home/snandy/climate-learn
['/home/snandy/climate-models/weatherbench-replication', '/home/snandy/climate-models/weatherbench-replication', '/home/software/utils', '/home/snandy/miniconda3/envs/climate-models1/lib/python310.zip', '/home/snandy/miniconda3/envs/climate-models1/lib/python3.10', '/home/snandy/miniconda3/envs/climate-models1/lib/python3.10/lib-dynload', '', '/home/snandy/miniconda3/envs/climate-models1/lib/python3.10/site-packages', '/home/snandy/miniconda3/envs/climate-models1/lib/python3.10/site-packages/PyQt5_sip-12.11.0-py3.10-linux-x86_64.egg', '/home/snandy/miniconda3/envs/climate-models1/lib/python3.10/site-packages/PyYAML-6.0-py3.10-linux-x86_64.egg', '/home/snandy/climate-learn']


In [2]:
USING_COLAB = False

In [3]:
import torch
torch.cuda.is_available()

True

In [4]:
if USING_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

In [5]:
from climate_learn.data import download

DOWNLOAD = False

if DOWNLOAD:
    if USING_COLAB:
        download(root = "/content/drive/MyDrive/Climate/.climate_tutorial", source = "weatherbench", variable = "geopotential_500", dataset = "era5", resolution = "5.625")
        download(root = "/content/drive/MyDrive/Climate/.climate_tutorial", source = "weatherbench", variable = "temperature_850", dataset = "era5", resolution = "5.625")
    else:
        download(root = "/data0/datasets/weatherbench", source = "weatherbench", variable = "geopotential_500", dataset = "era5", resolution = "5.625")
        download(root = "/data0/datasets/weatherbench", source = "weatherbench", variable = "temperature_850", dataset = "era5", resolution = "5.625")

# Temporal Forecasting - Z500

## Data Preprocessing


In [6]:
from climate_learn.utils.data import load_dataset, view

dataset_path = "/content/drive/MyDrive/Climate/.climate_tutorial/data/weatherbench/era5/5.625/" \
               if USING_COLAB else "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg/"

dataset = load_dataset(dataset_path + "geopotential/")
view(dataset)

Unnamed: 0,Array,Chunk
Bytes,34.78 GiB,892.12 MiB
Shape,"(350640, 13, 32, 64)","(8784, 13, 32, 64)"
Count,120 Tasks,40 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 34.78 GiB 892.12 MiB Shape (350640, 13, 32, 64) (8784, 13, 32, 64) Count 120 Tasks 40 Chunks Type float32 numpy.ndarray",350640  1  64  32  13,

Unnamed: 0,Array,Chunk
Bytes,34.78 GiB,892.12 MiB
Shape,"(350640, 13, 32, 64)","(8784, 13, 32, 64)"
Count,120 Tasks,40 Chunks
Type,float32,numpy.ndarray


## Data Conversion
We further convert the *NetCDF* files to *PyTorch* Dataloaders.

**Pros**: We can use the dataloaders for training and evaluating neural networks.\
**Cons**: We loose useful meta information (such as 'time', 'location') during conversion as dataloaders only allow for integer location based treatment. 

We store the useful information about the data ('lat', 'long') of the regions as _data members_ of our dataloaders. 



In [7]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = dataset_path,
    in_vars = ["geopotential_500"],
    out_vars = ["geopotential_500"],
    train_start_year = Year(1979),
    val_start_year = Year(2015),
    test_start_year = Year(2017),
    end_year = Year(2018),
    pred_range = Days(3),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 1
)

Creating train dataset


  0%|          | 0/36 [00:00<?, ?it/s]

vars
Creating val dataset


  0%|          | 0/2 [00:00<?, ?it/s]

vars
Creating test dataset


  0%|          | 0/2 [00:00<?, ?it/s]

vars


In [8]:
# from sklearn.linear_model import Ridge
print(type(data_module.train_dataset.inp_data))
print(data_module.train_dataset.inp_data.reshape(data_module.train_dataset.inp_data.shape[0], -1).shape)
print(data_module.train_dataset.out_data.shape)

x_train = data_module.train_dataset.inp_data.reshape(data_module.train_dataset.inp_data.shape[0], -1)
y_train = data_module.train_dataset.out_data.reshape(data_module.train_dataset.out_data.shape[0], -1)
x_val = data_module.val_dataset.inp_data.reshape(data_module.val_dataset.inp_data.shape[0], -1)
y_val = data_module.val_dataset.out_data.reshape(data_module.val_dataset.out_data.shape[0], -1)

# lr_model = Ridge(alpha=1.0)
# lr_model.fit(x_train, y_train)
# pred_val = lr_model.predict(x_val)
print(x_val.shape)
print(y_val.shape)
# print(pred_val.shape)

<class 'numpy.ndarray'>
(315576, 2048)
(315576, 1, 32, 64)
(17544, 2048)
(17544, 2048)


In [9]:
# print(pred_val)
# print(y_val)

## Model initialization 

In [10]:
from climate_learn.models import load_model

# model_kwargs = {
#     "img_size": [32, 64],
#     "patch_size": 2,
#     "drop_path": 0.1,
#     "drop_rate": 0.1,
#     "learn_pos_emb": True,
#     "in_vars": data_module.hparams.in_vars,
#     "out_vars": data_module.hparams.out_vars,
#     "embed_dim": 128,
#     "depth": 8,
#     "decoder_depth": 0,
#     "num_heads": 4,
#     "mlp_ratio": 4,
# }

model_kwargs = {
    "in_channels": len(data_module.hparams.in_vars),
    "out_channels": len(data_module.hparams.out_vars),
    "n_blocks": 4
}

optim_kwargs = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "warmup_epochs": 1,
    "max_epochs": 3,
}

# model_module = load_model(name = "vit", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
resnet_model_module = load_model(name = "resnet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
# unet_model_module = load_model(name = "unet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)

In [11]:
# add_description
from climate_learn.models import set_climatology
set_climatology(resnet_model_module, data_module)
# set_climatology(unet_model_module, data_module)

In [12]:
# add_description
from climate_learn.models import fit_lin_reg_baseline
fit_lin_reg_baseline(resnet_model_module, data_module, reg_hparam=0.0)
# fit_lin_reg_baseline(unet_model_module, data_module)

## Training

In [13]:
from climate_learn.training import Trainer, WandbLogger

resnet_trainer = Trainer(
    seed = 0,
    accelerator = "gpu",
    precision = 16,
    max_epochs = 5,
    # logger = WandbLogger(project = "climate_tutorial", name = "forecast-vit")
)

unet_trainer = Trainer(
    seed = 0,
    accelerator = "gpu",
    precision = 16,
    max_epochs = 5,
    # logger = WandbLogger(project = "climate_tutorial", name = "forecast-vit")
)

In [14]:
resnet_trainer.fit(resnet_model_module, data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [15]:
# unet_trainer.fit(unet_model_module, data_module)

## Evaluation 


Once our prediction model is trained, we want to be able to evaluate it against the ground truth labels for data samples in the test set. 

In addition to the Latitude weighted RMSE (Eq. 1), we shall look at the Anomaly Correlation Coefficient (ACC) which is defined as:

<br>
$ACC = \frac{\sum_{i,j,k}L(j)f'_{i,j,k}t'_{i,j,k}}{\sqrt{\sum_{i,j,k}L(j)f'^{2}_{i,j,k}L(j)t'^{2}_{i,j,k}}} \tag{3}$
<br>

where $'$ denotes the difference to the climatology. We define climatology as:

<br>
$climatology_{j,k} = \frac{1}{N_{time}}\sum{t_{j,k}}\tag{4}$
<br>

In [15]:
resnet_trainer.test(resnet_model_module, data_module)

Output()

  rank_zero_warn(


In [17]:
# unet_trainer.test(unet_model_module, data_module)

In [18]:
x_train = data_module.train_dataset.inp_data.reshape(data_module.train_dataset.inp_data.shape[0], -1)
y_train = data_module.train_dataset.out_data.reshape(data_module.train_dataset.out_data.shape[0], -1)
x_val = data_module.val_dataset.inp_data.reshape(data_module.val_dataset.inp_data.shape[0], -1)
y_val = data_module.val_dataset.out_data.reshape(data_module.val_dataset.out_data.shape[0], -1)

lr_model = resnet_model_module.lr_baseline
pred_val = lr_model.predict(x_val)

In [22]:
print(pred_val.shape)

(17544, 2048)


In [20]:
lr_pred = resnet_model_module.lr_baseline.predict(x_val.reshape((x_val.shape[0], -1))).reshape(y_val.shape)
# lr_pred = lr_pred[:, np.newaxis, :, :, :] # B, 1, C, H, W
lr_pred = torch.from_numpy(lr_pred).float().cuda()
print(lr_pred)

tensor([[50076.1836, 50114.4258, 50152.7969,  ..., 49876.9141, 49869.2695,
         49869.2656],
        [50099.1602, 50137.4570, 50175.7461,  ..., 49884.5977, 49884.5859,
         49892.2500],
        [50129.8398, 50168.0625, 50206.4258,  ..., 49861.5859, 49861.5938,
         49876.9219],
        ...,
        [50949.6914, 50965.0117, 50987.9688,  ..., 51210.1836, 51217.8242,
         51233.1602],
        [50926.7383, 50949.6719, 50972.6953,  ..., 51179.5273, 51202.5234,
         51217.8359],
        [50942.0469, 50964.9883, 50987.9922,  ..., 51187.1797, 51210.1797,
         51225.5078]], device='cuda:0')


## Visualization 

# We visualize the **bias**, given by the difference in the predicted and the ground truth values, to better analyze our learned model.

In [None]:
# from climate_tutorial.utils import visualize

import os
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# more use cases for visualize, make a more intuitive api
# which split of the data? train/val/test? currently test with a random data point
# timestamp that we are visualizing?
# only timestamp works -- can infer the split, we don't have the timestamp info for now -> include it in the dataloader
# number: 5 data points

# add lat long information 
# plotly to zoom in

samples = []
def visualize(model_module, data_module, split = "test", samples = 2, save_dir = None):
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok = True)

    # dataset.setup()
    dataset = eval(f"data_module.{split}_dataset")

    if(type(samples) == int):
        idxs = random.sample(range(0, len(dataset)), samples)
    elif(type(samples) == list):
        idxs = [np.searchsorted(dataset.time, np.datetime64(datetime.strptime(dt, "%Y-%m-%d:%H"))) for dt in samples]
    else:
        raise Exception("Invalid type for samples; Allowed int or list[datetime.datetime or np.datetime64]")

    fig, axes = plt.subplots(len(idxs), 4, figsize=(20, 2 * len(idxs)), squeeze = False)

    for index, idx in enumerate(idxs):
        x, y, _, _ = dataset[idx] # 1, 1, 32, 64
        pred = model_module.forward(x.unsqueeze(0)) # 1, 1, 32, 64

        inv_normalize = model_module.denormalization
        init_condition, gt = inv_normalize(x), inv_normalize(y)
        pred = inv_normalize(pred)
        bias = pred - gt

        for i, tensor in enumerate([init_condition, gt, pred, bias]):
            ax = axes[index][i]
            im = ax.imshow(tensor.detach().squeeze().cpu().numpy())
            im.set_cmap(cmap=plt.cm.RdBu)
            fig.colorbar(im, ax=ax)

        if(data_module.hparams.task == "forecasting"):
            axes[index][0].set_title("Initial condition")
            axes[index][1].set_title("Ground truth")
            axes[index][2].set_title("Prediction")
            axes[index][3].set_title("Bias")
        elif(data_module.hparams.task == "downscaling"):
            axes[index][0].set_title("Low resolution data")
            axes[index][1].set_title("High resolution data")
            axes[index][2].set_title("Downscaled")
            axes[index][3].set_title("Bias")
        else:
            raise NotImplementedError

    fig.tight_layout()
    
    if save_dir is not None:
        plt.savefig(os.path.join(save_dir, 'visualize.png'))
    else:
        plt.show()

In [None]:
visualize(resnet_model_module, data_module)

In [None]:
visualize(unet_model_module, data_module)

In [None]:
# from climate_tutorial.utils import visualize

import os
import random
import numpy as np
from datetime import datetime
from plotly.express import imshow
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# more use cases for visualize, make a more intuitive api
# which split of the data? train/val/test? currently test with a random data point
# timestamp that we are visualizing?
# only timestamp works -- can infer the split, we don't have the timestamp info for now -> include it in the dataloader
# number: 5 data points

# add lat long information 
# plotly to zoom in

samples = ["2017-01-01:12", "2017-02-01:18"]
def visualize(model_module, data_module, split = "test", samples = 2):
    # dataset.setup()
    dataset = eval(f"data_module.{split}_dataset")

    if(type(samples) == int):
        idxs = random.sample(range(0, len(dataset)), samples)
    elif(type(samples) == list):
        samples = [np.datetime64(datetime.strptime(dt, "%Y-%m-%d:%H")) for dt in samples]
        idxs = [dataset.time.index(dt) for dt in samples if dt in dataset.time]
    else:
        raise Exception("Invalid type for samples; Allowed int or list[datetime.datetime or np.datetime64]")

    # print(dataset.time[idxs[0]])
    # row_titles = [datetime.strftime(None, "%Y-%m-%d:%H") for idx in idxs]

    if(data_module.hparams.task == "forecasting"):
        col_titles = ["Initial condition", "Ground truth", "Prediction", "Bias"]
    elif(data_module.hparams.task == "downscaling"):
        col_titles = ["Low resolution data", "High resolution data", "Downscaled", "Bias"]
    else:
        raise NotImplementedError

    fig = make_subplots(len(idxs), 4, subplot_titles = col_titles * len(idxs))
    for i, idx in enumerate(idxs):
        x, y, _, _ = dataset[idx] # 1, 1, 32, 64
        pred = model_module.forward(x.unsqueeze(0)) # 1, 1, 32, 64

        inv_normalize = model_module.denormalization
        init_condition, gt = inv_normalize(x), inv_normalize(y)
        pred = inv_normalize(pred)
        bias = pred - gt

        for j, tensor in enumerate([init_condition, gt, pred, bias]):
            fig.add_trace(imshow(tensor.detach().squeeze().cpu().numpy(), color_continuous_scale = "rdbu", x = dataset.inp_lon if i == 0 else dataset.out_lon, y = dataset.inp_lat if i == 0 else dataset.out_lat).data[0], row = i + 1, col = j + 1)
            # fig.colorbar(im, ax=ax)

    # fig.tight_layout()
    fig.show()

In [None]:
visualize(resnet_model_module, data_module)

In [None]:
visualize(unet_model_module, data_module)

# Temporal Forecasting - T850

## Data Preprocessing


The data is stored in the [NetCDF](https://en.wikipedia.org/wiki/NetCDF) files with _.nc_ extension. One of the distinct features of this format is the **named** specification to the coordinates and the data variables. 

As shown below, we first merge all the yearly NetCDF files, and display the structure of the format. xarray library is used to read the NetCDF files. It allows the users to manipulate data based on more informative labels instead of integer location. 



In [None]:
from climate_learn.utils.data import load_dataset, view

dataset = load_dataset(dataset_path + "temperature_850")
view(dataset)

## Data Conversion
We further convert the *NetCDF* files to *PyTorch* Dataloaders.

**Pros**: We can use the dataloaders for training and evaluating neural networks.\
**Cons**: We loose useful meta information (such as 'time', 'location') during conversion as dataloaders only allow for integer location based treatment. 

We store the useful information about the data ('lat', 'long') of the regions as _data members_ of our dataloaders. 



In [None]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = dataset_path,
    in_vars = ["temperature_850"],
    out_vars = ["temperature_850"],
    train_start_year = Year(1979),
    val_start_year = Year(2015),
    test_start_year = Year(2017),
    end_year = Year(2018),
    pred_range = Days(3),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 1
)

## Model initialization 

In [None]:
from climate_learn.models import load_model

# model_kwargs = {
#     "img_size": [32, 64],
#     "patch_size": 2,
#     "drop_path": 0.1,
#     "drop_rate": 0.1,
#     "learn_pos_emb": True,
#     "in_vars": data_module.hparams.in_vars,
#     "out_vars": data_module.hparams.out_vars,
#     "embed_dim": 128,
#     "depth": 8,
#     "decoder_depth": 0,
#     "num_heads": 4,
#     "mlp_ratio": 4,
# }

model_kwargs = {
    "in_channels": len(data_module.hparams.in_vars),
    "out_channels": len(data_module.hparams.out_vars),
    "n_blocks": 4
}

optim_kwargs = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "warmup_epochs": 1,
    "max_epochs": 5,
}

# model_module = load_model(name = "vit", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
resnet_model_module = load_model(name = "resnet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
unet_model_module = load_model(name = "unet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)

In [None]:
# add_description
from climate_learn.models import set_climatology
set_climatology(resnet_model_module, data_module)
set_climatology(unet_model_module, data_module)

## Training

In [None]:
from climate_learn.training import Trainer, WandbLogger

resnet_trainer = Trainer(
    seed = 0,
    accelerator = "gpu",
    precision = 16,
    max_epochs = 5,
    # logger = WandbLogger(project = "climate_tutorial", name = "forecast-vit")
)

unet_trainer = Trainer(
    seed = 0,
    accelerator = "gpu",
    precision = 16,
    max_epochs = 5,
    # logger = WandbLogger(project = "climate_tutorial", name = "forecast-vit")
)

In [None]:
resnet_trainer.fit(resnet_model_module, data_module)

In [None]:
unet_trainer.fit(unet_model_module, data_module)

## Evaluation 


Once our prediction model is trained, we want to be able to evaluate it against the ground truth labels for data samples in the test set. 

In addition to the Latitude weighted RMSE (Eq. 1), we shall look at the Anomaly Correlation Coefficient (ACC) which is defined as:

<br>
$ACC = \frac{\sum_{i,j,k}L(j)f'_{i,j,k}t'_{i,j,k}}{\sqrt{\sum_{i,j,k}L(j)f'^{2}_{i,j,k}L(j)t'^{2}_{i,j,k}}} \tag{3}$
<br>

where $'$ denotes the difference to the climatology. We define climatology as:

<br>
$climatology_{j,k} = \frac{1}{N_{time}}\sum{t_{j,k}}\tag{4}$
<br>

In [None]:
resnet_trainer.test(resnet_model_module, data_module)

In [None]:
unet_trainer.test(unet_model_module, data_module)

## Visualization 

We visualize the **bias**, given by the difference in the predicted and the ground truth values, to better analyze our learned model.

In [None]:
# from climate_tutorial.utils import visualize

import os
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# more use cases for visualize, make a more intuitive api
# which split of the data? train/val/test? currently test with a random data point
# timestamp that we are visualizing?
# only timestamp works -- can infer the split, we don't have the timestamp info for now -> include it in the dataloader
# number: 5 data points

# add lat long information 
# plotly to zoom in

samples = []
def visualize(model_module, data_module, split = "test", samples = 2, save_dir = None):
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok = True)

    # dataset.setup()
    dataset = eval(f"data_module.{split}_dataset")

    if(type(samples) == int):
        idxs = random.sample(range(0, len(dataset)), samples)
    elif(type(samples) == list):
        idxs = [np.searchsorted(dataset.time, np.datetime64(datetime.strptime(dt, "%Y-%m-%d:%H"))) for dt in samples]
    else:
        raise Exception("Invalid type for samples; Allowed int or list[datetime.datetime or np.datetime64]")

    fig, axes = plt.subplots(len(idxs), 4, figsize=(20, 2 * len(idxs)), squeeze = False)

    for index, idx in enumerate(idxs):
        x, y, _, _ = dataset[idx] # 1, 1, 32, 64
        pred = model_module.forward(x.unsqueeze(0)) # 1, 1, 32, 64

        inv_normalize = model_module.denormalization
        init_condition, gt = inv_normalize(x), inv_normalize(y)
        pred = inv_normalize(pred)
        bias = pred - gt

        for i, tensor in enumerate([init_condition, gt, pred, bias]):
            ax = axes[index][i]
            im = ax.imshow(tensor.detach().squeeze().cpu().numpy())
            im.set_cmap(cmap=plt.cm.RdBu)
            fig.colorbar(im, ax=ax)

        if(data_module.hparams.task == "forecasting"):
            axes[index][0].set_title("Initial condition")
            axes[index][1].set_title("Ground truth")
            axes[index][2].set_title("Prediction")
            axes[index][3].set_title("Bias")
        elif(data_module.hparams.task == "downscaling"):
            axes[index][0].set_title("Low resolution data")
            axes[index][1].set_title("High resolution data")
            axes[index][2].set_title("Downscaled")
            axes[index][3].set_title("Bias")
        else:
            raise NotImplementedError

    fig.tight_layout()
    
    if save_dir is not None:
        plt.savefig(os.path.join(save_dir, 'visualize.png'))
    else:
        plt.show()

In [None]:
visualize(resnet_model_module, data_module)

In [None]:
visualize(unet_model_module, data_module)

In [None]:
# from climate_tutorial.utils import visualize

import os
import random
import numpy as np
from datetime import datetime
from plotly.express import imshow
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# more use cases for visualize, make a more intuitive api
# which split of the data? train/val/test? currently test with a random data point
# timestamp that we are visualizing?
# only timestamp works -- can infer the split, we don't have the timestamp info for now -> include it in the dataloader
# number: 5 data points

# add lat long information 
# plotly to zoom in

samples = ["2017-01-01:12", "2017-02-01:18"]
def visualize(model_module, data_module, split = "test", samples = 2):
    # dataset.setup()
    dataset = eval(f"data_module.{split}_dataset")

    if(type(samples) == int):
        idxs = random.sample(range(0, len(dataset)), samples)
    elif(type(samples) == list):
        samples = [np.datetime64(datetime.strptime(dt, "%Y-%m-%d:%H")) for dt in samples]
        idxs = [dataset.time.index(dt) for dt in samples if dt in dataset.time]
    else:
        raise Exception("Invalid type for samples; Allowed int or list[datetime.datetime or np.datetime64]")

    # print(dataset.time[idxs[0]])
    # row_titles = [datetime.strftime(None, "%Y-%m-%d:%H") for idx in idxs]

    if(data_module.hparams.task == "forecasting"):
        col_titles = ["Initial condition", "Ground truth", "Prediction", "Bias"]
    elif(data_module.hparams.task == "downscaling"):
        col_titles = ["Low resolution data", "High resolution data", "Downscaled", "Bias"]
    else:
        raise NotImplementedError

    fig = make_subplots(len(idxs), 4, subplot_titles = col_titles * len(idxs))
    for i, idx in enumerate(idxs):
        x, y, _, _ = dataset[idx] # 1, 1, 32, 64
        pred = model_module.forward(x.unsqueeze(0)) # 1, 1, 32, 64

        inv_normalize = model_module.denormalization
        init_condition, gt = inv_normalize(x), inv_normalize(y)
        pred = inv_normalize(pred)
        bias = pred - gt

        for j, tensor in enumerate([init_condition, gt, pred, bias]):
            fig.add_trace(imshow(tensor.detach().squeeze().cpu().numpy(), color_continuous_scale = "rdbu", x = dataset.inp_lon if i == 0 else dataset.out_lon, y = dataset.inp_lat if i == 0 else dataset.out_lat).data[0], row = i + 1, col = j + 1)
            # fig.colorbar(im, ax=ax)

    # fig.tight_layout()
    fig.show()

In [None]:
visualize(resnet_model_module, data_module)

In [None]:
visualize(unet_model_module, data_module)

# Temporal Forecasting - Z500 & T850

## Data Preprocessing


The data is stored in the [NetCDF](https://en.wikipedia.org/wiki/NetCDF) files with _.nc_ extension. One of the distinct features of this format is the **named** specification to the coordinates and the data variables. 

As shown below, we first merge all the yearly NetCDF files, and display the structure of the format. xarray library is used to read the NetCDF files. It allows the users to manipulate data based on more informative labels instead of integer location. 



In [None]:
from climate_learn.utils.data import load_dataset, view
import xarray as xr

z_dataset = load_dataset("/content/drive/MyDrive/Climate/.climate_tutorial/data/weatherbench/era5/5.625/geopotential_500")
t_dataset = load_dataset("/content/drive/MyDrive/Climate/.climate_tutorial/data/weatherbench/era5/5.625/temperature_850")
dataset = xr.merge([z_dataset.drop("level"), t_dataset.drop("level")])
view(dataset)

In [None]:
view(z_dataset)

In [None]:
view(t_dataset)

## Data Conversion
We further convert the *NetCDF* files to *PyTorch* Dataloaders.

**Pros**: We can use the dataloaders for training and evaluating neural networks.\
**Cons**: We loose useful meta information (such as 'time', 'location') during conversion as dataloaders only allow for integer location based treatment. 

We store the useful information about the data ('lat', 'long') of the regions as _data members_ of our dataloaders. 



In [None]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = "/content/drive/MyDrive/Climate/.climate_tutorial/data/weatherbench/era5/5.625/",
    in_vars = ["geopotential_500", "temperature_850"],
    out_vars = ["geopotential_500", "temperature_850"],
    train_start_year = Year(1979),
    val_start_year = Year(2015),
    test_start_year = Year(2017),
    end_year = Year(2018),
    pred_range = Days(3),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 1
)

## Model initialization 

In [None]:
from climate_learn.models import load_model

# model_kwargs = {
#     "img_size": [32, 64],
#     "patch_size": 2,
#     "drop_path": 0.1,
#     "drop_rate": 0.1,
#     "learn_pos_emb": True,
#     "in_vars": data_module.hparams.in_vars,
#     "out_vars": data_module.hparams.out_vars,
#     "embed_dim": 128,
#     "depth": 8,
#     "decoder_depth": 0,
#     "num_heads": 4,
#     "mlp_ratio": 4,
# }

model_kwargs = {
    "in_channels": len(data_module.hparams.in_vars),
    "out_channels": len(data_module.hparams.out_vars),
    "n_blocks": 4
}

optim_kwargs = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "warmup_epochs": 1,
    "max_epochs": 5,
}

# model_module = load_model(name = "vit", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
resnet_model_module = load_model(name = "resnet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
unet_model_module = load_model(name = "unet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)

In [None]:
# add_description
from climate_learn.models import set_climatology
set_climatology(resnet_model_module, data_module)
set_climatology(unet_model_module, data_module)

## Training

In [None]:
from climate_learn.training import Trainer, WandbLogger

resnet_trainer = Trainer(
    seed = 0,
    accelerator = "gpu",
    precision = 16,
    max_epochs = 5,
    # logger = WandbLogger(project = "climate_tutorial", name = "forecast-vit")
)

unet_trainer = Trainer(
    seed = 0,
    accelerator = "gpu",
    precision = 16,
    max_epochs = 5,
    # logger = WandbLogger(project = "climate_tutorial", name = "forecast-vit")
)

In [None]:
resnet_trainer.fit(resnet_model_module, data_module)

In [None]:
unet_trainer.fit(unet_model_module, data_module)

## Evaluation 


Once our prediction model is trained, we want to be able to evaluate it against the ground truth labels for data samples in the test set. 

In addition to the Latitude weighted RMSE (Eq. 1), we shall look at the Anomaly Correlation Coefficient (ACC) which is defined as:

<br>
$ACC = \frac{\sum_{i,j,k}L(j)f'_{i,j,k}t'_{i,j,k}}{\sqrt{\sum_{i,j,k}L(j)f'^{2}_{i,j,k}L(j)t'^{2}_{i,j,k}}} \tag{3}$
<br>

where $'$ denotes the difference to the climatology. We define climatology as:

<br>
$climatology_{j,k} = \frac{1}{N_{time}}\sum{t_{j,k}}\tag{4}$
<br>

In [None]:
resnet_trainer.test(resnet_model_module, data_module)

In [None]:
unet_trainer.test(unet_model_module, data_module)