In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

In [None]:
# Make the cloned repository importable so that the `src.*` imports further down resolve.
# Set REPO_PATH to the directory where you cloned the model.
REPO_PATH = "/path/to/repo/explaining_cirrus"
if REPO_PATH not in sys.path:  # guard keeps the cell idempotent when it is re-run
    sys.path.append(REPO_PATH)

In [None]:
from pprint import pprint
from copy import deepcopy
import gc

import pandas as pd
import xarray as xr
import numpy as np
import scipy

import torch
import torch.nn as nn
import torch.nn.functional as F

import sklearn
from sklearn.preprocessing import StandardScaler, PowerTransformer

from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, ExponentialLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import loggers as pl_loggers

from src.ml_pipeline.instantaneous.ml_preprocess import split_train_val_test, oh_encoding, CAT_VARS
from src.ml_pipeline.temporal.lstm_model import LSTMRegressor, LogCallback
from src.ml_pipeline.temporal.data_module import BacktrajDataModule, BacktrajDataset
from src.ml_pipeline.temporal.custom_loss_functions import *

In [None]:
import datetime

# Train & Evaluate LSTM+Attention Model on Backtrajectory dataset

We recommend using a machine with a GPU.

## Load Data

In [None]:
# Directory that holds the monthly `temporal_<YYYYMM>.csv` files of the temporal dataset.
# Adjust this to wherever the dataset is stored.
TEMPORAL_DATASET_PATH = "/path/to/temporal_dataset"

In [None]:
# Month keys in "YYYYMM" form for every month of the years 2007, 2008 and 2009.
all_months = [
    "{}{:02d}".format(yr, mon)
    for yr in (2007, 2008, 2009)
    for mon in range(1, 13)
]

In [None]:
# Show which month keys can be selected for loading.
print(f"available months: {all_months}")

In [None]:
# Column dtypes handed to `pd.read_csv`, grouped by numpy dtype string:
#   '<f8' / '<f4' -> little-endian 64- / 32-bit float
#   '<i8'         -> little-endian 64-bit integer
#   'O'           -> Python object (used here for the string-like columns)
_columns_by_dtype = {
    '<f8': [
        'lev', 'lat', 'lon', 'latr', 'lonr', 'timestep', 'cloud_cover',
        't', 'w', 'u', 'v', 'rh_ice',
        'land_water_mask', 'nightday_flag', 'instrument_flag',
        'dz_top_v2', 'iwc', 'icnc_5um', 'reffcli',
        'wind_speed', 'cloud_thickness_v2',
    ],
    '<f4': ['SO4', 'DU', 'DU_sub', 'DU_sup', 'surface_height'],
    '<i8': ['lat_region', 'lon_region'],
    'O': ['trajectory_id', 'season', 'year_month'],
}

# Flatten the grouping into the {column: dtype} mapping that `read_csv` expects.
dtypes = {
    column: dtype_code
    for dtype_code, columns in _columns_by_dtype.items()
    for column in columns
}

# Columns that `read_csv` should parse as datetimes.
parse_dates = ['time', 'date']

In [None]:
# Read one CSV file per selected month into a list of DataFrames.
# caution I: loading all months requires a lot of memory
# caution II: a minimum of 8 months has to be selected when using train_size=0.8
month_dfs = []
for month in all_months:
    print(month)  # progress indicator: one line per file being read
    # A single f-string builds the path; chaining `.format(month)` onto it was redundant.
    month_csv_path = f"{TEMPORAL_DATASET_PATH}/temporal_{month}.csv"
    month_df = pd.read_csv(month_csv_path, dtype=dtypes, parse_dates=parse_dates)
    month_dfs.append(month_df)

In [None]:
# Stack the monthly frames row-wise into one trajectory DataFrame.
# NOTE(review): the per-file row labels are kept (no `ignore_index`), so index labels
# presumably repeat across months — confirm that downstream code does not rely on a
# unique index.
df = pd.concat(month_dfs, axis=0)

## Define Hyperparameters and other config

In [None]:
# Temporally resolved input features.
# Further candidates that were left disabled in this notebook:
# "cc_traj", "IWC", "RWC", "LWC", "SWC"
sequential_features = [
    "t",
    "w",
    "wind_speed",
    "DU_sup",
    "DU_sub",
    "SO4",
    "surface_height",
]

# Static input features.
static_features = [
    "land_water_mask",
    "season",
    "dz_top_v2",
    "cloud_thickness_v2",
]

# Coordinate / bookkeeping columns (not referenced by any other cell in this notebook).
coord_vars = [
    "lev",
    "lat",
    "lon",
    "latr",
    "lonr",
    "time",
    "date",
    "timestep",
    "trajectory_id",
    "cloud_cover",
]

# Target variables the model is trained to predict.
predictands = ["iwc", "icnc_5um"]

In [None]:
# Central experiment configuration; the trainer, data-module and model cells below read
# most of their settings from this dict.
experiment_config = {
    # --- task definition ---
    "predictands": predictands,  # list of predictands
    "sequential_features": sequential_features,  # list of temporally resolved features
    "static_features": static_features,  # list of static features
    "mtl_weighting_type": "equal",  # weighting between predictand losses
    # initialized loss class (RMSELoss is presumably provided by the wildcard import of
    # custom_loss_functions — verify)
    "criterion": RMSELoss(),

    # --- architecture ---
    # lstm architecture choices
    "lstm_hparams": {"num_layers": 1, "hidden_size": 250, "dropout": 0, "attention": True},
    # static branch architecture choices
    "static_branch_hparams": {"layer_sizes": [50], "dropout": 0.5, "batchnorm": True},
    # final fc architecture choices
    "final_fc_layers_hparams": {"layer_sizes": [100, 50], "dropout": 0.5, "batchnorm": True},

    # --- sample reweighting (deep imbalanced regression) ---
    # reweighting mechanism to use for sample based reweighting
    "reweight": "none",
    # when sample based reweighting is combined with multi-task learning: calculate the
    # sample weights 'individual' for each predictand or based on 1 'lead_predictand'
    "multiple_predictand_reweight_type": "individual",
    "reweight_lead_predictand": "iwc",
    "reweight_bin_width": 10,
    "lds": False,
    "lds_kernel": "gaussian",
    "lds_ks": 5,
    "lds_sigma": 2,

    # --- data handling ---
    "data_filters": [],  # conditions dataframe will be filtered on e.g. cloud_cover>0.8
    "sequential_scaler": StandardScaler(),
    "static_scaler": StandardScaler(),
    "regional_feature_resolution": 10,
    "backtraj_timesteps": 48,
    "train_size": 0.5,
    "batch_size": 1000,
    "num_workers": 1,

    # --- optimisation ---
    "learning_rate": 1e-5,
    # needs to be a learning_rate scheduler according to
    # https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/core/lightning.html#LightningModule.configure_optimizers
    "lr_scheduler": False,
    "grad_clip": 0.5,
    "early_stopping": False,
    "max_epochs": 5,
}

## Init, Train, Evaluate Model

In [None]:
# Optional early stopping on the validation loss, switched on via
# experiment_config["early_stopping"].
callbacks = []
if experiment_config["early_stopping"]:
    early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(
        monitor="val_loss", min_delta=0.001, patience=15
    )
    callbacks.append(early_stop_callback)

# Use every visible GPU when CUDA is available; otherwise fall back to a single CPU
# process so the notebook does not strictly require a GPU (it is only recommended).
use_gpu = torch.cuda.is_available()

trainer = Trainer(
    callbacks=callbacks,
    max_epochs=experiment_config['max_epochs'],
    log_every_n_steps=100,
    # NOTE(review): `progress_bar_refresh_rate` is no longer accepted by newer
    # pytorch_lightning releases — confirm against the installed version.
    progress_bar_refresh_rate=1000,
    gradient_clip_val=experiment_config["grad_clip"],
    accelerator="gpu" if use_gpu else "cpu",
    devices=-1 if use_gpu else 1,
    # fast_dev_run=1,  # enable for a quick smoke run
)

In [None]:
# Settings that BacktrajDataModule receives under the same name as their
# experiment_config key.
datamodule_config_keys = (
    "data_filters",
    "predictands",
    "sequential_features",
    "static_features",
    "sequential_scaler",
    "static_scaler",
    "batch_size",
    "train_size",
    "num_workers",
    "regional_feature_resolution",
    "reweight",
    "multiple_predictand_reweight_type",
    "reweight_lead_predictand",
    "reweight_bin_width",
    "lds",
    "lds_kernel",
    "lds_ks",
    "lds_sigma",
    "backtraj_timesteps",
)
datamodule_kwargs = {key: experiment_config[key] for key in datamodule_config_keys}

# Data module built from the concatenated trajectory frame plus the configured settings.
dm = BacktrajDataModule(traj_df=df, **datamodule_kwargs)

In [None]:
# Hyperparameters that LSTMRegressor receives under the same name as their
# experiment_config key.
model_config_keys = (
    "lstm_hparams",
    "static_branch_hparams",
    "final_fc_layers_hparams",
    "batch_size",
    "learning_rate",
    "lr_scheduler",
    "criterion",
    "grad_clip",
)

# The feature counts are read from the data module rather than from experiment_config
# (presumably because the data module may alter the feature lists, e.g. by encoding
# categorical variables — verify against BacktrajDataModule).
model = LSTMRegressor(
    predictands=experiment_config["predictands"],
    n_sequential_features=len(dm.sequential_features),
    n_static_features=len(dm.static_features),
    **{key: experiment_config[key] for key in model_config_keys},
)

In [None]:
# --- fit the model on the data module ---
print("start training")
trainer.fit(model, dm)

# --- evaluate on the test split ---
trainer.test(model, datamodule=dm)

# --- persist the trained weights (state dict only) to the file 'lstm_model' ---
trained_state_dict = trainer.model.state_dict()
torch.save(trained_state_dict, 'lstm_model')

## Plot Attention weights

Retrieve the attention weights for the test data and plot the mean attention weight per timestep.

In [None]:
# Plotting setup.
# Importing hvplot.pandas registers the `.hvplot` accessor used on DataFrames below;
# holoviews is switched to its matplotlib backend.
import hvplot.pandas
import holoviews as hv
hv.extension('matplotlib')

from bokeh.resources import INLINE
import bokeh.io

# Configure bokeh notebook output with inline resources.
# (A stray trailing `z` on this line previously made the whole cell a SyntaxError.)
bokeh.io.output_notebook(INLINE)

In [None]:
# Loaders for the test and training splits provided by the data module.
# (Only the test loader is consumed by the cells visible below.)
test_dataloader, train_dataloader = dm.test_dataloader(), dm.train_dataloader()

In [None]:
# Materialise the complete test split as five tensors.
# Tensors are shared between processes via the 'file_system' strategy (presumably to
# avoid file-descriptor limits with DataLoader workers — verify).
torch.multiprocessing.set_sharing_strategy('file_system')

# Every batch is a 5-tuple (X_seq, X_static, y, weights, coords). Transposing the batch
# sequence with zip(*...) groups the pieces of all batches by position, and each group is
# then concatenated along the first dimension.
(
    X_seq_test,
    X_static_test,
    y_test,
    weights_test,
    coords_test,
) = (torch.concat(parts) for parts in zip(*test_dataloader))

In [None]:
# Run the LSTM and the attention module on the full test set to obtain the attention
# weights.
# Switch to evaluation mode first so that layers such as dropout / batchnorm behave
# deterministically during inference.
model.eval()

# Feed the inputs on whatever device the model parameters currently live on, so the
# cell works regardless of whether the model sits on the CPU or on a GPU.
model_device = next(model.parameters()).device

with torch.no_grad():
    lstm_out, (hn, cn) = model.lstm(X_seq_test.to(model_device))  # lstm_out (N, T, hidden_size)
    z, alpha = model.attention_module(lstm_out)

# Move to the CPU before converting to numpy (numpy cannot read GPU tensors).
attention_weights = alpha.cpu().numpy()

In [None]:
# Mean and standard deviation of the attention weights per timestep: both statistics
# reduce over axes 0 and 2, so axis 1 is the one that remains (assumed to be the
# timestep axis — TODO confirm against the attention module's output layout).
mean_attention = attention_weights.mean(axis=(0, 2))
sd_attention = attention_weights.std(axis=(0, 2))

# One row per timestep: the mean plus the lower/upper edges of the shaded band.
# NOTE(review): the band is asymmetric (mean - 0.5 * sd up to mean + 1 * sd) — confirm
# that this is intended.
attention_df = (
    pd.DataFrame(
        {
            "mean": mean_attention,
            "lower": mean_attention - 0.5 * sd_attention,
            "upper": mean_attention + sd_attention,
        }
    )
    .rename_axis("timestep")  # the positional index becomes the "timestep" column
    .reset_index()
)


In [None]:
# create attention plot
plt_options = {
    'fontsize': {'xlabel': '20px', 'ylabel': '15px', 'ticks': '15px', 'legend': '30px'},
    'cmap': 'Colorblind',
    'width': 900,
    'height': 200,
    'line_width': 3,
}

# Shift the x-axis by the configured trajectory length (48 with the configuration in this
# notebook, which is the value that used to be hard-coded here). The column is recomputed
# from the row position instead of being decremented in place, so re-running this cell no
# longer shifts the axis by another 48 each time.
timestep_offset = experiment_config["backtraj_timesteps"]
attention_df["timestep"] = np.arange(len(attention_df)) - timestep_offset

# Mean attention weight per timestep as a line ...
mean_line = attention_df.hvplot.line(
    x="timestep", y="mean", color="orange", **plt_options
)
# ... overlaid with the shaded band between the "lower" and "upper" columns.
spread_band = attention_df.hvplot.area(
    x="timestep",
    y="lower",
    y2="upper",
    color="orange",
    line_alpha=0,
    fill_alpha=0.2,
    xlabel="timestep [h]",
    ylabel="mean attention weight",
    **plt_options,
)
attention_plt = mean_line * spread_band

attention_plt