<a name="software-requirements"></a>
# Software Requirements
This notebook requires the following libraries:
*   `climate_learn` (pip)

`climate_learn` contains the source files used for modeling climate extremes.

The package is written using the `PyTorch` machine learning library.

In [1]:
# Set to True when running on Google Colab: this mounts Google Drive and switches
# the download root and dataset paths used below to their Drive locations.
USING_COLAB = False

# for auto-reloading external modules (so edits to the climate_learn sources are
# picked up without restarting the kernel)
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
import torch

# The Trainer in the Training section uses accelerator="gpu", so this should
# display True before continuing.
cuda_ok = torch.cuda.is_available()
cuda_ok

True

In [3]:
import os
import sys

# Put the local climate_learn checkout ahead of any installed copy on the import path.
# Override the location with the CLIMATE_LEARN_SRC environment variable rather than
# editing this user-specific default.
CLIMATE_LEARN_SRC = os.environ.get(
    "CLIMATE_LEARN_SRC", "/home/snandy/climate-learn-sys-gen/src/"
)
# Guard so re-running the cell does not keep prepending duplicate entries.
if CLIMATE_LEARN_SRC not in sys.path:
    sys.path.insert(0, CLIMATE_LEARN_SRC)

In [4]:
# Colab only: mount Google Drive so the Drive-based data paths used below resolve.
if USING_COLAB:
    from google.colab import drive
    mount_point = '/content/drive'
    drive.mount(mount_point)

In [5]:
from climate_learn.data import download

# Set DOWNLOAD = True to fetch the WeatherBench ERA5 (5.625 deg) variables used in
# this notebook. This only needs to happen once per machine / Drive.
DOWNLOAD = False

if DOWNLOAD:
    # The download root depends on where the notebook runs (Google Drive vs local disk).
    download_root = (
        "/content/drive/MyDrive/Climate/.climate_tutorial"
        if USING_COLAB
        else "/data0/datasets/weatherbench"
    )
    # "2m_temperature" is loaded explicitly in the Data Preprocessing section,
    # so it is fetched alongside the Z500 and T850 fields.
    for variable in ("geopotential_500", "temperature_850", "2m_temperature"):
        download(root = download_root, source = "weatherbench", variable = variable, dataset = "era5", resolution = "5.625")

# Temporal Forecasting - Z500

## Data Preprocessing


In [6]:
from climate_learn.utils.data import load_dataset, view

# Root directory of the 5.625-degree ERA5 WeatherBench data for the current environment.
# NOTE(review): the Colab path ends in "5.625/" but the local one in "5.625deg/" — confirm
# both match the on-disk layout.
if USING_COLAB:
    dataset_path = "/content/drive/MyDrive/Climate/.climate_tutorial/data/weatherbench/era5/5.625/"
else:
    dataset_path = "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg/"

# Open the 2 m temperature files and display the underlying array summary.
dataset = load_dataset(dataset_path + "2m_temperature/")
view(dataset)

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,120 Tasks,40 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.68 GiB 68.62 MiB Shape (350640, 32, 64) (8784, 32, 64) Count 120 Tasks 40 Chunks Type float32 numpy.ndarray",64  32  350640,

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,120 Tasks,40 Chunks
Type,float32,numpy.ndarray


In [7]:
# Same inspection for the geopotential_500 (Z500) field.
geopotential_dir = dataset_path + "geopotential_500/"
dataset = load_dataset(geopotential_dir)
view(dataset)

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,120 Tasks,40 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.68 GiB 68.62 MiB Shape (350640, 32, 64) (8784, 32, 64) Count 120 Tasks 40 Chunks Type float32 numpy.ndarray",64  32  350640,

Unnamed: 0,Array,Chunk
Bytes,2.68 GiB,68.62 MiB
Shape,"(350640, 32, 64)","(8784, 32, 64)"
Count,120 Tasks,40 Chunks
Type,float32,numpy.ndarray


## Data Conversion
We further convert the *NetCDF* files to *PyTorch* Dataloaders.

**Pros**: We can use the dataloaders for training and evaluating neural networks.\
**Cons**: We lose useful meta information (such as 'time' and 'location') during conversion, as dataloaders only allow integer-location-based indexing. 

We store the useful information about the data ('lat', 'lon') of the regions as _data members_ of the datasets behind our dataloaders. 



In [None]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

# Configuration of the ERA5 forecasting data module, gathered in one dict so each
# setting can be read (and tweaked) on its own line.
# NOTE(review): in_vars use "temperature"/"geopotential" while out_vars use the
# level-specific "temperature_850"/"geopotential_500" — confirm this is intended.
data_module_config = {
    "dataset": "ERA5",
    "task": "forecasting",
    "root_dir": dataset_path,
    "in_vars": ["temperature", "geopotential", "2m_temperature"],
    "out_vars": ["temperature_850", "geopotential_500", "2m_temperature"],
    # Split boundaries by year; whether end_year is inclusive is decided by
    # DataModule — TODO confirm.
    "train_start_year": Year(1979),
    "val_start_year": Year(2016),
    "test_start_year": Year(2017),
    "end_year": Year(2018),
    # Forecast lead time.
    "pred_range": Days(3),
    # Presumably the spacing between consecutive samples — confirm in DataModule.
    "subsample": Hours(6),
    "batch_size": 32,
    "num_workers": 8,
}
data_module = DataModule(**data_module_config)

Creating train dataset


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 37/37 [00:00<00:00, 58.50it/s]


input NP conversion


In [None]:
import numpy as np
# Sanity-check the prepared training split before building a model: show the
# dataset's input transform (presumably the input normalization) and the shape
# of the first input sample.
train_dataset = data_module.train_dataset
print(train_dataset.inp_transform)
print(train_dataset[0][0].shape)


## Model initialization 

In [None]:
from climate_learn.models import load_model

# Alternative ViT configuration (use together with the commented-out "vit"
# load_model call below):
# model_kwargs = {
#     "img_size": [32, 64],
#     "patch_size": 2,
#     "drop_path": 0.1,
#     "drop_rate": 0.1,
#     "learn_pos_emb": True,
#     "in_vars": data_module.hparams.in_vars,
#     "out_vars": data_module.hparams.out_vars,
#     "embed_dim": 128,
#     "depth": 8,
#     "decoder_depth": 0,
#     "num_heads": 4,
#     "mlp_ratio": 4,
# }

# ResNet forecaster: one input channel per input variable and one output channel
# per output variable of the data module.
model_kwargs = {
    "in_channels": len(data_module.hparams.in_vars),
    "out_channels": len(data_module.hparams.out_vars),
    "n_blocks": 4
}

# Run length the optimizer settings are sized for. Keep this equal to the
# `max_epochs` passed to the Trainer in the Training section (5): warmup_epochs /
# max_epochs presumably drive an epoch-based warmup + decay schedule, and a mismatch
# would size that schedule for a different run length than is actually trained.
MAX_EPOCHS = 5

optim_kwargs = {
    "lr": 1e-4,
    "weight_decay": 1e-5,
    "warmup_epochs": 1,
    "max_epochs": MAX_EPOCHS,
}

# model_module = load_model(name = "vit", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
resnet_model_module = load_model(name = "resnet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)
#unet_model_module = load_model(name = "unet", task = "forecasting", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)

In [None]:
# add_description
from climate_learn.models import set_climatology
# Attach the data module's climatology to each trained-model module (presumably
# the reference field subtracted when computing ACC in the Evaluation section).
for module_needing_climatology in (resnet_model_module,):
    set_climatology(module_needing_climatology, data_module)
# set_climatology(unet_model_module, data_module)

## Training

In [None]:
from climate_learn.training import Trainer, WandbLogger

def make_trainer(max_epochs = 5, run_name = None):
    """Build a climate_learn Trainer with the settings shared by every model in this notebook.

    Args:
        max_epochs: number of training epochs. Should match optim_kwargs["max_epochs"]
            used when the model module was built.
        run_name: if given, attach a WandbLogger in the "climate_learn" project under
            this run name; if None (default) no logger is passed, as before.

    Returns:
        A Trainer seeded with 0, on GPU, with 16-bit precision.
    """
    trainer_kwargs = dict(
        seed = 0,
        accelerator = "gpu",
        precision = 16,
        max_epochs = max_epochs,
    )
    if run_name is not None:
        trainer_kwargs["logger"] = WandbLogger(project = "climate_learn", name = run_name)
    return Trainer(**trainer_kwargs)


# e.g. make_trainer(run_name = "forecast-resnet") to log the run to W&B.
resnet_trainer = make_trainer()
unet_trainer = make_trainer()

In [None]:
# Train the ResNet forecaster on `data_module` with the trainer configured above.
resnet_trainer.fit(resnet_model_module, data_module)

In [None]:
# unet_trainer.fit(unet_model_module, data_module)

## Evaluation 


Once our prediction model is trained, we want to be able to evaluate it against the ground truth labels for data samples in the test set. 

In addition to the latitude-weighted RMSE (Eq. 1), we also look at the Anomaly Correlation Coefficient (ACC), which is defined as:

<br>
$ACC = \frac{\sum_{i,j,k}L(j)f'_{i,j,k}t'_{i,j,k}}{\sqrt{\sum_{i,j,k}L(j)f'^{2}_{i,j,k}\sum_{i,j,k}L(j)t'^{2}_{i,j,k}}} \tag{3}$
<br>

where $'$ denotes the difference from the climatology. We define the climatology as the time mean of the ground truth at each grid point:

<br>
$climatology_{j,k} = \frac{1}{N_{time}}\sum_{i}t_{i,j,k}\tag{4}$
<br>

In [None]:
# Evaluate the trained ResNet on the test split of `data_module`
# (presumably reporting the metrics described above).
resnet_trainer.test(resnet_model_module, data_module)

In [None]:
#unet_trainer.test(unet_model_module, data_module)

## Visualization 

We visualize the **bias**, i.e. the difference between the predicted and the ground-truth values (prediction minus ground truth), to better analyze our learned model.

In [None]:
# from climate_learn.utils import visualize

import os
import random
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# more use cases for visualize, make a more intuitive api
# which split of the data? train/val/test? currently test with a random data point
# timestamp that we are visualizing?
# only timestamp works -- can infer the split, we don't have the timestamp info for now -> include it in the dataloader
# number: 5 data points

# add lat long information 
# plotly to zoom in

import torch

samples = []
def visualize(model_module, data_module, split = "test", samples = 2, save_dir = None):
    """Plot initial condition / ground truth / prediction / bias maps for a few samples.

    Args:
        model_module: trained model module; its ``forward`` is run on each sample and its
            ``denormalization`` maps tensors back to physical units before plotting.
        data_module: data module exposing ``<split>_dataset`` and ``hparams.task``
            ("forecasting" or "downscaling"; only used for the column titles).
        split: dataset split to draw from, e.g. "train", "val" or "test".
        samples: an int (number of random samples to draw), or a list of timestamps given
            as "%Y-%m-%d:%H" strings, ``datetime.datetime`` or ``np.datetime64`` values that
            must appear exactly in ``dataset.time``.
        save_dir: if given, the figure is written to ``<save_dir>/visualize.png`` and closed
            instead of being shown.

    Raises:
        TypeError: if ``samples`` is neither an int nor a list.
        ValueError: if a requested timestamp is not in ``dataset.time`` or no samples remain.
        NotImplementedError: for tasks other than forecasting / downscaling.
    """
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok = True)

    # getattr performs the same attribute lookup as the old eval() without executing code.
    dataset = getattr(data_module, f"{split}_dataset")

    # bool is excluded explicitly: it is an int subclass but never a sample count.
    if isinstance(samples, int) and not isinstance(samples, bool):
        idxs = random.sample(range(len(dataset)), samples)
    elif isinstance(samples, list):
        times = np.asarray(dataset.time)
        idxs = []
        for sample in samples:
            if isinstance(sample, str):
                sample = datetime.strptime(sample, "%Y-%m-%d:%H")
            # Require an exact hit: an insertion point (searchsorted) would silently plot a
            # different timestamp, or index past the end for dates after the split.
            matches = np.flatnonzero(times == np.datetime64(sample))
            if matches.size == 0:
                raise ValueError(f"Timestamp {sample} not found in the {split} dataset")
            idxs.append(int(matches[0]))
    else:
        raise TypeError("Invalid type for samples; Allowed int or list[str, datetime.datetime or np.datetime64]")

    if len(idxs) == 0:
        raise ValueError("No samples to visualize")

    if data_module.hparams.task == "forecasting":
        col_titles = ["Initial condition", "Ground truth", "Prediction", "Bias"]
    elif data_module.hparams.task == "downscaling":
        col_titles = ["Low resolution data", "High resolution data", "Downscaled", "Bias"]
    else:
        raise NotImplementedError

    fig, axes = plt.subplots(len(idxs), 4, figsize=(20, 2 * len(idxs)), squeeze = False)

    # Inference only: no autograd graph, and eval mode so layers such as dropout /
    # batch-norm (if the model has any) behave as at test time; restore the mode afterwards.
    was_training = model_module.training
    model_module.eval()
    try:
        with torch.no_grad():
            for row, idx in enumerate(idxs):
                x, y, _, _ = dataset[idx] # 1, 1, 32, 64 (per original note; confirm shapes)
                pred = model_module.forward(x.unsqueeze(0)) # add a batch dimension

                inv_normalize = model_module.denormalization
                init_condition, gt = inv_normalize(x), inv_normalize(y)
                pred = inv_normalize(pred)
                bias = pred - gt

                for col, tensor in enumerate([init_condition, gt, pred, bias]):
                    ax = axes[row][col]
                    im = ax.imshow(tensor.detach().squeeze().cpu().numpy())
                    im.set_cmap(cmap=plt.cm.RdBu)
                    fig.colorbar(im, ax=ax)
                    ax.set_title(col_titles[col])
    finally:
        model_module.train(was_training)

    fig.tight_layout()

    if save_dir is not None:
        fig.savefig(os.path.join(save_dir, 'visualize.png'))
        # A saved-but-never-shown figure otherwise stays alive inside pyplot.
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Plot randomly chosen test samples for the trained ResNet (matplotlib version).
visualize(resnet_model_module, data_module)

In [None]:
# unet_model_module only exists if the U-Net load_model line in "Model initialization"
# is uncommented; skip instead of raising NameError on Restart & Run All.
if "unet_model_module" in globals():
    visualize(unet_model_module, data_module)

In [None]:
# from climate_tutorial.utils import visualize

import os
import random
import numpy as np
from datetime import datetime
from plotly.express import imshow
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# more use cases for visualize, make a more intuitive api
# which split of the data? train/val/test? currently test with a random data point
# timestamp that we are visualizing?
# only timestamp works -- can infer the split, we don't have the timestamp info for now -> include it in the dataloader
# number: 5 data points

# add lat long information 
# plotly to zoom in

import torch

samples = ["2017-01-01:12", "2017-02-01:18"]
def visualize(model_module, data_module, split = "test", samples = 2):
    """Interactive (plotly) grid of initial condition / ground truth / prediction / bias maps.

    Args:
        model_module: trained model module; its ``forward`` is run on each sample and its
            ``denormalization`` maps tensors back to physical units before plotting.
        data_module: data module exposing ``<split>_dataset`` and ``hparams.task``
            ("forecasting" or "downscaling"; only used for the subplot titles).
        split: dataset split to draw from, e.g. "train", "val" or "test".
        samples: an int (number of random samples to draw), or a list of timestamps given
            as "%Y-%m-%d:%H" strings, ``datetime.datetime`` or ``np.datetime64`` values.
            Timestamps not present in ``dataset.time`` are skipped.

    Raises:
        TypeError: if ``samples`` is neither an int nor a list.
        ValueError: if no requested sample is available.
        NotImplementedError: for tasks other than forecasting / downscaling.
    """
    # getattr performs the same attribute lookup as the old eval() without executing code.
    dataset = getattr(data_module, f"{split}_dataset")

    # bool is excluded explicitly: it is an int subclass but never a sample count.
    if isinstance(samples, int) and not isinstance(samples, bool):
        idxs = random.sample(range(len(dataset)), samples)
    elif isinstance(samples, list):
        # np.asarray + flatnonzero works whether dataset.time is a list or an array
        # (list.index does not exist on numpy arrays).
        times = np.asarray(dataset.time)
        idxs = []
        for sample in samples:
            if isinstance(sample, str):
                sample = datetime.strptime(sample, "%Y-%m-%d:%H")
            matches = np.flatnonzero(times == np.datetime64(sample))
            if matches.size:
                idxs.append(int(matches[0]))
    else:
        raise TypeError("Invalid type for samples; Allowed int or list[str, datetime.datetime or np.datetime64]")

    if len(idxs) == 0:
        raise ValueError(f"None of the requested samples are available in the {split} dataset")

    # TODO: add per-row timestamp titles (dataset.time[idx]).
    if data_module.hparams.task == "forecasting":
        col_titles = ["Initial condition", "Ground truth", "Prediction", "Bias"]
    elif data_module.hparams.task == "downscaling":
        col_titles = ["Low resolution data", "High resolution data", "Downscaled", "Bias"]
    else:
        raise NotImplementedError

    fig = make_subplots(len(idxs), 4, subplot_titles = col_titles * len(idxs))

    # Inference only: no autograd graph, and eval mode for any dropout / batch-norm layers;
    # the previous mode is restored afterwards.
    was_training = model_module.training
    model_module.eval()
    try:
        with torch.no_grad():
            for row, idx in enumerate(idxs):
                x, y, _, _ = dataset[idx] # 1, 1, 32, 64 (per original note; confirm shapes)
                pred = model_module.forward(x.unsqueeze(0)) # add a batch dimension

                inv_normalize = model_module.denormalization
                init_condition, gt = inv_normalize(x), inv_normalize(y)
                pred = inv_normalize(pred)
                bias = pred - gt

                for col, tensor in enumerate([init_condition, gt, pred, bias]):
                    # Column 0 is the model input (input grid); the other columns are on the
                    # output grid. These differ for downscaling, so choose by column, not row.
                    lon = dataset.inp_lon if col == 0 else dataset.out_lon
                    lat = dataset.inp_lat if col == 0 else dataset.out_lat
                    # NOTE(review): px.imshow stores its colorscale on the layout's coloraxis,
                    # which is not copied by add_trace, so all panels may share one default
                    # color axis — confirm and consider go.Heatmap per panel.
                    heatmap = imshow(tensor.detach().squeeze().cpu().numpy(), color_continuous_scale = "rdbu", x = lon, y = lat).data[0]
                    fig.add_trace(heatmap, row = row + 1, col = col + 1)
    finally:
        model_module.train(was_training)

    fig.show()

In [None]:
# Interactive plotly view of randomly chosen test samples for the trained ResNet.
visualize(resnet_model_module, data_module)

In [None]:
# unet_model_module only exists if the U-Net load_model line in "Model initialization"
# is uncommented; skip instead of raising NameError on Restart & Run All.
if "unet_model_module" in globals():
    visualize(unet_model_module, data_module)