## Chlorophyll-a forecasting using LSTM Model



#### Getting Started:
1. Before running the notebook, make sure the following Python version and libraries are installed: <br>
- python 3.9.12
- pytorch (https://pytorch.org/get-started/locally/)

2. Create an account on Weights and Biases (WANDB) (https://wandb.ai/home). While running the notebook, you may be prompted to log in to WANDB.

<br>
The requirements.txt file lists the basic libraries required. Running the following cell should install all of them (in case they are not already installed).

If a library is missing from that list, you will see an ImportError. In that case, install it with pip (`pip install library_name`).

In [1]:
!pip install -r requirements.txt



In [2]:
import random
import pandas as pd
import numpy as np
from tqdm import trange
import os
import datetime
import matplotlib.pyplot as plt
import math

import torch
import torch.nn as nn
from torch import optim

from utils import Utils
from encoder_decoder import seq2seq

import warnings
warnings.filterwarnings('ignore')

wandb: Currently logged in as: rladwig (computational-limnology). Use `wandb login --relogin` to force relogin


## 0. GPU Selection
Check whether a GPU is available on the machine running the notebook. If so, use the GPU; otherwise, run on the CPU.

In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(device)

cuda


## 1. Parameter setting

#### Specify the wandb project and wandb run
wandb refers to Weights and Biases. Integrating this tool into the notebook lets it record the run details and generate train and test curves, among other information.

In [4]:
# wandb project name
wandb_project = "mcl_lstm"

# wandb run name
# Read the clock once so the date and time parts come from the same instant
_now = datetime.datetime.now()
wandb_run = "mix_run_{}_{}".format(_now.date(), _now.time())

# True if we want wandb to save our python code, else False
save_code = True

#### Specify the input path (where the dataset is stored) and dataset name
Note: Different datasets may need different processing and handling. This notebook uses FCR (observational) data, which also has a metadata file that stores the column names and types.
<br>
To keep the tutorial simple, the notebook works only with the FCR data for now.

In [5]:
# Input path
path = './'

# Name of the file
file = '../1_trainingData/all_data_lake_modeling_in_time_perRow.csv'

# Name of the metadata file
#../metadata = 'LSTM_dataset_column_key_07OCT22.csv'

#### Specify the Time-series specific parameters

In [6]:
# Lookback window
input_window = 24*3

# horizon window
output_window = 24*3

# stride - While creating samples (lookback window + horizon window = 1 sample) define the amount of stride the sliding window needs to take
stride = 1

# The ratio in which train and test data is split. If it is 0.8, then first 80% of data goes into train and remaining 20% into test
split_ratio = 0.6

#### Specify the model specific parameters

In [7]:
# Types of Model include: LSTM, GRU, RNN
model_type = "LSTM"

# Number of layers in our deep learning model
num_layers = 2

# Hidden cell (RNN/LSTM/GRU) size
hidden_feature_size = 32

# Output size of our encoder_decoder model, i.e. number of target variables
output_size = 50

'''
Model Training parameters
'''
# batch_size during training
batch_size = 32

# Number of epochs we want to train the model for (1 epoch = 1 pass of the complete training data through the model)
epochs = 100

# Learning rate specifies the rate at which we want to update the model parameters after every training pass
learning_rate = 0.001

# eval_freq sets how often during training the model is evaluated on the validation data (to see its performance on non-training data)
eval_freq = 1 # logic is -> if iteration_num % eval_freq == 0 -> then perform evaluation

# While generating the training batches do we want the generator to shuffle the batches?
batch_shuffle = True

# Dropout is a form of regularization
dropout = 0.0

'''
Learning rate scheduler parameters
'''
max_lr=5e-3
div_factor=100
pct_start=0.05 
anneal_strategy='cos'
final_div_factor=10000.0

'''
Parameters for early stopping
'''
# Set to True if we want Early stopping
early_stop = False

# If there is no improvement for 'thres' number of epochs, stop the training process
thres=5

# Quantifying the improvement. If the validation loss is greater than min_val_loss_so_far + delta for thres number of iterations stop the training
delta=0.5

'''
Other parameters
'''
# Specify the amount of L2 regularization to be applied.
weight_decay=0.0

# Specify the percentage of times we want to enforce teacher forcing
teacher_forcing_ratio = 0.0
training_prediction = 'recursive'
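
The early-stopping comments above can be read as the rule sketched below. This is only an illustration, not the logic inside `encoder_decoder`: the helper `should_stop` is hypothetical, and it assumes that `thres` counts consecutive evaluations whose validation loss exceeds the best loss so far plus `delta`.

```python
def should_stop(val_losses, thres=5, delta=0.5):
    """Return the evaluation index at which training would stop, or None.

    Hypothetical helper: training stops once `thres` consecutive validation
    losses exceed (best loss so far + delta).
    """
    best = float("inf")
    bad_evals = 0
    for i, loss in enumerate(val_losses):
        if loss < best:
            # New best loss: remember it and reset the counter
            best = loss
            bad_evals = 0
        elif loss > best + delta:
            # Clearly worse than the best loss: count it
            bad_evals += 1
            if bad_evals >= thres:
                return i
        else:
            # Within the tolerance band: the streak is broken
            bad_evals = 0
    return None


# The best loss is 0.8, so every loss above 1.3 counts as "no improvement"
losses = [1.0, 0.8, 1.5, 1.6, 1.4, 1.7, 1.9]
print(should_stop(losses, thres=5, delta=0.5))  # 6
```

With `delta=0.5` on a normalized loss, the tolerance band is fairly wide, so a smaller `delta` may be worth trying if `early_stop` is switched on.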

## 2. Data Processing

#### Define the depth grid and the feature and target columns

In [8]:
depth_steps = 25 * 2

# Depth values 0.5, 1.0, ..., 25.0 (used as suffixes of the temperature column names)
depth_list = np.arange(1, depth_steps + 1) * 0.5



In [9]:
incoming_temp = ['temp_conv04_{}'.format(x) for x in depth_list]
outgoing_temp = ['temp_mix05_{}'.format(x) for x in depth_list]

dx = pd.read_csv(os.path.join(path,file))

#feature_cols = ['AirTemp_degC', 'Longwave_Wm-2', 'Latent_Wm-2', 'Sensible_Wm-2', 'Shortwave_Wm-2',
#                'lightExtinct_m-1', 'ShearStress_Nm-2',
#                 'day_of_year', 'time_of_day', 'ice', 'snow', 'snowice','Volume_m2','Osgood','MaxDepth_m',
#                'MeanDepth_m','Area_m2'] + incoming_temp

feature_cols = ['AirTemp_degC', 'Longwave_Wm-2', 'Latent_Wm-2', 'Sensible_Wm-2', 'Shortwave_Wm-2',
                'lightExtinct_m-1', 'ShearStress_Nm-2',
                 'day_of_year', 'time_of_day', 'ice', 'snow', 'snowice'] + incoming_temp

date_col = ['time']

target_cols = outgoing_temp

In [10]:
#dx = pd.read_csv(os.path.join(path, metadata))

# Extract all column names from the metadata
#feature_cols = dx[dx['column_type']=='feature']['column_names'].tolist()  # feature columns represent the input drivers
#target_cols = dx[dx['column_type']=='target']['column_names'].tolist()   # target columns represent the chlorophyll values
#date_col = dx[dx['column_type']=='date']['column_names'].tolist()[0]    # date column stores the date timeline

In [11]:
# Specify whether we want to add chlorophyll to the input feature set
#feature_cols += target_cols
feature_cols

['AirTemp_degC',
 'Longwave_Wm-2',
 'Latent_Wm-2',
 'Sensible_Wm-2',
 'Shortwave_Wm-2',
 'lightExtinct_m-1',
 'ShearStress_Nm-2',
 'day_of_year',
 'time_of_day',
 'ice',
 'snow',
 'snowice',
 'temp_conv04_0.5',
 'temp_conv04_1.0',
 'temp_conv04_1.5',
 'temp_conv04_2.0',
 'temp_conv04_2.5',
 'temp_conv04_3.0',
 'temp_conv04_3.5',
 'temp_conv04_4.0',
 'temp_conv04_4.5',
 'temp_conv04_5.0',
 'temp_conv04_5.5',
 'temp_conv04_6.0',
 'temp_conv04_6.5',
 'temp_conv04_7.0',
 'temp_conv04_7.5',
 'temp_conv04_8.0',
 'temp_conv04_8.5',
 'temp_conv04_9.0',
 'temp_conv04_9.5',
 'temp_conv04_10.0',
 'temp_conv04_10.5',
 'temp_conv04_11.0',
 'temp_conv04_11.5',
 'temp_conv04_12.0',
 'temp_conv04_12.5',
 'temp_conv04_13.0',
 'temp_conv04_13.5',
 'temp_conv04_14.0',
 'temp_conv04_14.5',
 'temp_conv04_15.0',
 'temp_conv04_15.5',
 'temp_conv04_16.0',
 'temp_conv04_16.5',
 'temp_conv04_17.0',
 'temp_conv04_17.5',
 'temp_conv04_18.0',
 'temp_conv04_18.5',
 'temp_conv04_19.0',
 'temp_conv04_19.5',
 'temp_co

In [12]:
target_cols

['temp_mix05_0.5',
 'temp_mix05_1.0',
 'temp_mix05_1.5',
 'temp_mix05_2.0',
 'temp_mix05_2.5',
 'temp_mix05_3.0',
 'temp_mix05_3.5',
 'temp_mix05_4.0',
 'temp_mix05_4.5',
 'temp_mix05_5.0',
 'temp_mix05_5.5',
 'temp_mix05_6.0',
 'temp_mix05_6.5',
 'temp_mix05_7.0',
 'temp_mix05_7.5',
 'temp_mix05_8.0',
 'temp_mix05_8.5',
 'temp_mix05_9.0',
 'temp_mix05_9.5',
 'temp_mix05_10.0',
 'temp_mix05_10.5',
 'temp_mix05_11.0',
 'temp_mix05_11.5',
 'temp_mix05_12.0',
 'temp_mix05_12.5',
 'temp_mix05_13.0',
 'temp_mix05_13.5',
 'temp_mix05_14.0',
 'temp_mix05_14.5',
 'temp_mix05_15.0',
 'temp_mix05_15.5',
 'temp_mix05_16.0',
 'temp_mix05_16.5',
 'temp_mix05_17.0',
 'temp_mix05_17.5',
 'temp_mix05_18.0',
 'temp_mix05_18.5',
 'temp_mix05_19.0',
 'temp_mix05_19.5',
 'temp_mix05_20.0',
 'temp_mix05_20.5',
 'temp_mix05_21.0',
 'temp_mix05_21.5',
 'temp_mix05_22.0',
 'temp_mix05_22.5',
 'temp_mix05_23.0',
 'temp_mix05_23.5',
 'temp_mix05_24.0',
 'temp_mix05_24.5',
 'temp_mix05_25.0']

#### Create a utility object
An object of the Utils class provides all the utility functions, such as splitting the data into train and test sets and normalizing the data.

In [13]:
'''
Utility instance - to perform data processing, train test split
'''
utils = Utils(num_features=len(feature_cols), inp_cols=feature_cols, target_cols=target_cols, date_col=date_col,
              input_window=input_window, output_window=output_window, num_out_features=output_size, stride=stride)

#### Read the dataset

In [14]:
'''
Read data
'''
df = pd.read_csv(path+file)

#### Train Test split
Ideally, a 3-way split is done: train, val and test. The validation split is generally used to tune the hyper-parameters during training. Once the hyper-parameters are tuned, the model is re-trained on the train+val data. To keep the notebook short and simple, hyper-parameter tuning is not included.

In [15]:
'''
Split data into train and test
'''
df_train, df_test = utils.train_test_split(df, split_ratio=split_ratio)
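
For time series, the split is usually chronological rather than random, which matches the description of `split_ratio` given earlier (the first part of the data goes to train, the remainder to test). The sketch below illustrates that idea on a toy array; `chronological_split` is a hypothetical stand-in and may differ from what `utils.train_test_split` actually does.

```python
import numpy as np


def chronological_split(rows, split_ratio):
    """Split time-ordered rows: the first part for training, the rest for testing."""
    split_idx = int(len(rows) * split_ratio)
    return rows[:split_idx], rows[split_idx:]


rows = np.arange(10)  # stand-in for 10 time-ordered rows
train, test = chronological_split(rows, split_ratio=0.8)
print(train)  # [0 1 2 3 4 5 6 7]
print(test)   # [8 9]
```

Keeping the order intact means the test period lies strictly after the training period, so no future information reaches the model during training.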

In [16]:
def cycle_encode(x, period):
    """Map a cyclic variable (e.g. day of year) to its sine and cosine components."""
    angle = 2 * np.pi * x / period
    return np.sin(angle), np.cos(angle)

In [17]:
train_doy_sin, train_doy_cos = cycle_encode(df_train.day_of_year.values, 365)

In [18]:
train_tod_sin, train_tod_cos = cycle_encode(df_train.time_of_day.values, 24)

In [19]:
test_doy_sin, test_doy_cos = cycle_encode(df_test.day_of_year.values, 365)

In [20]:
test_tod_sin, test_tod_cos = cycle_encode(df_test.time_of_day.values, 24)
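
The reason for replacing `day_of_year` and `time_of_day` with sine and cosine pairs is that the raw integers hide the wrap-around: day 365 and day 1 are neighbours in time but far apart numerically. The following sketch, using only the standard library, illustrates this; the helper names are hypothetical and not part of the notebook.

```python
import math


def cycle_encode_scalar(x, period):
    """Sine and cosine components of a single cyclic value."""
    angle = 2 * math.pi * x / period
    return math.sin(angle), math.cos(angle)


def encoded_distance(a, b, period):
    """Euclidean distance between two values after cyclic encoding."""
    sin_a, cos_a = cycle_encode_scalar(a, period)
    sin_b, cos_b = cycle_encode_scalar(b, period)
    return math.hypot(sin_a - sin_b, cos_a - cos_b)


# Day 365 and day 1 end up closer together than day 1 and day 180
print(encoded_distance(365, 1, 365) < encoded_distance(1, 180, 365))  # True
```

Both components are needed: the sine alone cannot tell apart two days that sit symmetrically around its peak, while the (sin, cos) pair identifies a unique position on the cycle.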

In [21]:
feature_cols.remove('day_of_year')
feature_cols.remove('time_of_day')
feature_cols

['AirTemp_degC',
 'Longwave_Wm-2',
 'Latent_Wm-2',
 'Sensible_Wm-2',
 'Shortwave_Wm-2',
 'lightExtinct_m-1',
 'ShearStress_Nm-2',
 'ice',
 'snow',
 'snowice',
 'temp_conv04_0.5',
 'temp_conv04_1.0',
 'temp_conv04_1.5',
 'temp_conv04_2.0',
 'temp_conv04_2.5',
 'temp_conv04_3.0',
 'temp_conv04_3.5',
 'temp_conv04_4.0',
 'temp_conv04_4.5',
 'temp_conv04_5.0',
 'temp_conv04_5.5',
 'temp_conv04_6.0',
 'temp_conv04_6.5',
 'temp_conv04_7.0',
 'temp_conv04_7.5',
 'temp_conv04_8.0',
 'temp_conv04_8.5',
 'temp_conv04_9.0',
 'temp_conv04_9.5',
 'temp_conv04_10.0',
 'temp_conv04_10.5',
 'temp_conv04_11.0',
 'temp_conv04_11.5',
 'temp_conv04_12.0',
 'temp_conv04_12.5',
 'temp_conv04_13.0',
 'temp_conv04_13.5',
 'temp_conv04_14.0',
 'temp_conv04_14.5',
 'temp_conv04_15.0',
 'temp_conv04_15.5',
 'temp_conv04_16.0',
 'temp_conv04_16.5',
 'temp_conv04_17.0',
 'temp_conv04_17.5',
 'temp_conv04_18.0',
 'temp_conv04_18.5',
 'temp_conv04_19.0',
 'temp_conv04_19.5',
 'temp_conv04_20.0',
 'temp_conv04_20.5',

#### Normalize the data
Standard (z-score) normalization: the data is scaled to zero mean and unit standard deviation.

In [22]:
utils.inp_cols

['AirTemp_degC',
 'Longwave_Wm-2',
 'Latent_Wm-2',
 'Sensible_Wm-2',
 'Shortwave_Wm-2',
 'lightExtinct_m-1',
 'ShearStress_Nm-2',
 'ice',
 'snow',
 'snowice',
 'temp_conv04_0.5',
 'temp_conv04_1.0',
 'temp_conv04_1.5',
 'temp_conv04_2.0',
 'temp_conv04_2.5',
 'temp_conv04_3.0',
 'temp_conv04_3.5',
 'temp_conv04_4.0',
 'temp_conv04_4.5',
 'temp_conv04_5.0',
 'temp_conv04_5.5',
 'temp_conv04_6.0',
 'temp_conv04_6.5',
 'temp_conv04_7.0',
 'temp_conv04_7.5',
 'temp_conv04_8.0',
 'temp_conv04_8.5',
 'temp_conv04_9.0',
 'temp_conv04_9.5',
 'temp_conv04_10.0',
 'temp_conv04_10.5',
 'temp_conv04_11.0',
 'temp_conv04_11.5',
 'temp_conv04_12.0',
 'temp_conv04_12.5',
 'temp_conv04_13.0',
 'temp_conv04_13.5',
 'temp_conv04_14.0',
 'temp_conv04_14.5',
 'temp_conv04_15.0',
 'temp_conv04_15.5',
 'temp_conv04_16.0',
 'temp_conv04_16.5',
 'temp_conv04_17.0',
 'temp_conv04_17.5',
 'temp_conv04_18.0',
 'temp_conv04_18.5',
 'temp_conv04_19.0',
 'temp_conv04_19.5',
 'temp_conv04_20.0',
 'temp_conv04_20.5',

In [23]:
'''
Data Scaling
'''
df_train = utils.normalize(df_train)

In [24]:
df_test = utils.normalize(df_test, use_stat=True)
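
The two calls above suggest that the statistics are computed on the training split and then reused for the test split (`use_stat=True`). That reading is an assumption about `utils.normalize`. The sketch below shows the general technique on a toy array; `fit_stats`, `standardize` and `destandardize` are hypothetical helpers.

```python
import numpy as np


def fit_stats(x):
    """Per-column mean and standard deviation, computed on the training data only."""
    return x.mean(axis=0), x.std(axis=0)


def standardize(x, mean, std):
    """Scale columns with the given statistics."""
    return (x - mean) / std


def destandardize(z, mean, std):
    """Map standardized values back to the original units."""
    return z * std + mean


train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
test = np.array([[4.0, 40.0]])

mean, std = fit_stats(train)
train_z = standardize(train, mean, std)
test_z = standardize(test, mean, std)  # reuse the training statistics

print(train_z.mean(axis=0))  # approximately [0. 0.]
print(np.allclose(destandardize(test_z, mean, std), test))  # True
```

Reusing the training statistics keeps information about the test period out of the preprocessing, and the inverse transform is what allows predictions to be reported in the original units.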

In [25]:
feature_cols += ['doy_sin', 'doy_cos', 'tod_sin', 'tod_cos']

In [26]:
utils.inp_cols

['AirTemp_degC',
 'Longwave_Wm-2',
 'Latent_Wm-2',
 'Sensible_Wm-2',
 'Shortwave_Wm-2',
 'lightExtinct_m-1',
 'ShearStress_Nm-2',
 'ice',
 'snow',
 'snowice',
 'temp_conv04_0.5',
 'temp_conv04_1.0',
 'temp_conv04_1.5',
 'temp_conv04_2.0',
 'temp_conv04_2.5',
 'temp_conv04_3.0',
 'temp_conv04_3.5',
 'temp_conv04_4.0',
 'temp_conv04_4.5',
 'temp_conv04_5.0',
 'temp_conv04_5.5',
 'temp_conv04_6.0',
 'temp_conv04_6.5',
 'temp_conv04_7.0',
 'temp_conv04_7.5',
 'temp_conv04_8.0',
 'temp_conv04_8.5',
 'temp_conv04_9.0',
 'temp_conv04_9.5',
 'temp_conv04_10.0',
 'temp_conv04_10.5',
 'temp_conv04_11.0',
 'temp_conv04_11.5',
 'temp_conv04_12.0',
 'temp_conv04_12.5',
 'temp_conv04_13.0',
 'temp_conv04_13.5',
 'temp_conv04_14.0',
 'temp_conv04_14.5',
 'temp_conv04_15.0',
 'temp_conv04_15.5',
 'temp_conv04_16.0',
 'temp_conv04_16.5',
 'temp_conv04_17.0',
 'temp_conv04_17.5',
 'temp_conv04_18.0',
 'temp_conv04_18.5',
 'temp_conv04_19.0',
 'temp_conv04_19.5',
 'temp_conv04_20.0',
 'temp_conv04_20.5',

In [27]:
df_train['doy_sin'] = train_doy_sin
df_train['doy_cos'] = train_doy_cos
df_test['doy_sin'] = test_doy_sin
df_test['doy_cos'] = test_doy_cos

In [28]:
df_train['tod_sin'] = train_tod_sin
df_train['tod_cos'] = train_tod_cos
df_test['tod_sin'] = test_tod_sin
df_test['tod_cos'] = test_tod_cos

In [29]:
'''
convert the mean and std to torch
'''
utils.y_mean = torch.tensor(utils.y_mean, device=device)
utils.y_std = torch.tensor(utils.y_std, device=device)

In [30]:
utils.y_mean

tensor([10.0881, 10.1474, 10.1955, 10.2307, 10.2536, 10.2640, 10.2621, 10.2534,
        10.2343, 10.2041, 10.1698, 10.1208, 10.0467,  9.9483,  9.8600,  9.7364,
         9.6203,  9.4911,  9.3375,  9.1913,  9.0196,  8.8059,  8.5462,  8.2926,
         8.0832,  7.8677,  7.6297,  7.4296,  7.2486,  7.0241,  6.7801,  6.5824,
         6.4099,  6.2339,  6.0484,  5.8629,  5.7071,  5.6068,  5.5393,  5.5037,
         5.4894,  5.4739,  5.4826,  5.5073,  5.5378,  5.5720,  5.6060,  5.6388,
         5.6481,  5.6179], device='cuda:0', dtype=torch.float64)

In [31]:
utils.y_std

tensor([9.2798, 9.2152, 9.1558, 9.0999, 9.0472, 8.9927, 8.9378, 8.8856, 8.8320,
        8.7717, 8.7136, 8.6422, 8.5339, 8.4035, 8.2942, 8.1591, 8.0415, 7.8854,
        7.7071, 7.5447, 7.3595, 7.1618, 6.9438, 6.7437, 6.5731, 6.3606, 6.0943,
        5.8815, 5.6644, 5.3624, 5.0458, 4.8157, 4.6251, 4.4228, 4.1847, 3.9359,
        3.7434, 3.6271, 3.5353, 3.4676, 3.3976, 3.2930, 3.2298, 3.1876, 3.1528,
        3.1208, 3.0906, 3.0600, 3.0296, 2.8981], device='cuda:0',
       dtype=torch.float64)

In [32]:
utils.num_features = len(utils.inp_cols)

#### Create train and test samples
Each sample is created with a sliding window: one window position = one lookback window + one horizon window = one sample.

In [33]:
'''
Prepare data : 1 training sample = lookback window + horizon window
'''
Xtrain, Ytrain = utils.windowed_dataset(df_train)
Xtest, Ytest = utils.windowed_dataset(df_test)

In [34]:
Xtrain.shape

(31399, 72, 64)

In [35]:
Ytrain.shape

(31399, 72, 50)

In [36]:
Xtest.shape

(20885, 72, 64)

In [37]:
Ytest.shape

(20885, 72, 50)
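
The shapes above read as (samples, time steps, variables). A minimal sketch of how such arrays can be built with a sliding window follows. `make_windows` is a hypothetical helper, and `utils.windowed_dataset` may differ in detail. Under this sketch, the number of samples is `(n_rows - input_window - output_window) // stride + 1`, so 31399 training samples would correspond to 31542 training rows.

```python
import numpy as np


def make_windows(features, targets, input_window, output_window, stride=1):
    """Build (X, Y) samples: X holds the lookback rows, Y the horizon rows that follow."""
    n_rows = features.shape[0]
    window = input_window + output_window
    n_samples = (n_rows - window) // stride + 1
    X = np.stack([features[i * stride : i * stride + input_window]
                  for i in range(n_samples)])
    Y = np.stack([targets[i * stride + input_window : i * stride + window]
                  for i in range(n_samples)])
    return X, Y


# Toy data: 200 time steps, 64 input features, 50 target variables
features = np.random.rand(200, 64)
targets = np.random.rand(200, 50)

X, Y = make_windows(features, targets, input_window=72, output_window=72)
print(X.shape)  # (57, 72, 64)
print(Y.shape)  # (57, 72, 50)
```

With `stride = 1`, consecutive samples overlap in all but one time step, which yields many samples at the cost of strong correlation between them.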

#### Datatype conversion to torch

In [38]:
'''
Convert data into torch type
'''
X_train, Y_train, X_test, Y_test = utils.numpy_to_torch(Xtrain, Ytrain, Xtest, Ytest)

## 3. Modeling

#### Define the model

In [39]:
'''
Create the seq2seq model
'''
model = seq2seq(input_size = X_train.shape[2], 
                hidden_size = hidden_feature_size, 
                output_size=output_size,
                model_type=model_type,
                num_layers = num_layers,
                utils=utils,
                dropout=dropout,
                device=device
               )

#### Train the model

In [None]:
'''
Train the model
'''
config = {
    "batch_size": batch_size,
    "epochs": epochs,
    "learning_rate": learning_rate,
    "eval_freq": eval_freq,
    "batch_shuffle": batch_shuffle,
    "dropout":dropout,
    "num_layers": num_layers,
    "hidden_feature_size": hidden_feature_size,
    "model_type": model_type,
    "teacher_forcing_ratio": teacher_forcing_ratio,
    "max_lr": max_lr,
    "div_factor": div_factor,
    "pct_start": pct_start,
    "anneal_strategy": anneal_strategy,
    "final_div_factor": final_div_factor,
    "dataset": file,
    "split_ratio":split_ratio,
    "input_window":input_window,
    "output_window":output_window,
    "early_stop_thres":thres,
    "early_stop_delta":delta,
    "early_stop":early_stop,
    "weight_decay":weight_decay
}
loss, test_rmse, train_rmse = model.train_model(X_train, 
                                                Y_train,
                                                X_test,
                                                Y_test,
                                                target_len = output_window,
                                                config = config,
                                                training_prediction = training_prediction,  
                                                dynamic_tf = False,
                                                project_name = wandb_project,
                                                run_name = wandb_run,
                                                save_code = save_code)

[Training progress output trimmed: the outer progress bar counts 100 epochs and the inner bar counts 982 batches per epoch. The captured log ends at batch 930/982 of the first epoch, after about 1 h 16 min, at about 2.8 s per batch early on and about 5.8 s per batch later.]

#### Plot the train and test curves

In [None]:
# RMSE per epoch on a log scale, so late-stage improvements stay visible
plt.figure(figsize=(5, 4), dpi=150)
plt.plot(train_rmse, lw=2.0, label='train_rmse')
plt.plot(test_rmse, lw=2.0, label='test_rmse')
plt.yscale("log")
plt.grid(True, alpha=0.2)
plt.legend()
plt.show()

#### Save or load the model

In [None]:
# Set load = True to restore previously saved weights instead of saving the current model
load = False

# Path of the saved weights to restore; only used when load = True
MODEL_PATH = "./models/model_weights_test_run_2023-03-21_22:42:45.577525"

In [None]:

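if load:
    # Restore saved weights; map_location keeps this working on CPU-only machines
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
else:
    # The file name uses the first 22 characters of the wandb run name
    MODEL_PATH = './models/model_weights_{}'.format(wandb_run[:22])
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), MODEL_PATH)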
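
The example weights path above embeds a timestamp with colons, which some file systems (Windows in particular) do not accept in file names. A minimal sketch of one way to build a portable name follows. The helper `safe_model_path` and its arguments are hypothetical and not part of this notebook or of `utils`.

```python
import datetime
from pathlib import Path


def safe_model_path(prefix, when, models_dir="./models"):
    """Build a weights path whose file name contains no ':' characters.

    Hypothetical helper for illustration; not part of the notebook's utils.
    """
    # Use '-' instead of ':' between hours, minutes and seconds
    stamp = when.strftime("%Y-%m-%d_%H-%M-%S")
    return Path(models_dir) / "model_weights_{}_{}".format(prefix, stamp)


example_path = safe_model_path("test_run", datetime.datetime(2023, 3, 21, 22, 42, 45))
print(example_path)
```

The resulting path can be passed to `torch.save` and `torch.load` in place of a string.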

In [None]:
'''
Evaluate the trained model on the train and test data
'''
train_eval_dict = model.evaluate_batch(X_train.to(device), Y_train.to(device))
test_eval_dict = model.evaluate_batch(X_test.to(device), Y_test.to(device))

In [None]:
X_test.shape

In [None]:
X_train.shape

In [None]:
Y_train.shape

In [None]:
Y_test.shape

In [None]:
test_eval_dict['y_true'][1001][2]

In [None]:
test_eval_dict['y_pred'][1001][2]

## 4. Plotting and Evaluation

In [None]:
'''
Build the tables used to plot the T+n predictions
'''
train_gt = train_eval_dict['y_true']
train_gt_df = pd.DataFrame(train_gt.cpu().numpy()[:,:,0])
train_gt_values = np.append(train_gt_df[0].values, train_gt_df.iloc[-1,1:]) # ground-truth values for train data

test_gt = test_eval_dict['y_true']
test_gt_df = pd.DataFrame(test_gt.cpu().numpy()[:,:,0])
test_gt_values = np.append(test_gt_df[0].values, test_gt_df.iloc[-1,1:]) # ground-truth values for test data

train_pred = train_eval_dict['y_pred'] # model predicted values for train data
test_pred = test_eval_dict['y_pred'] # model predicted values for test data

df_train_comp = df_train
#df_train_comp=df_train_comp.rename(columns = {'time':'Date'})
#print(df_train_comp.Date)

df_test_comp = df_test
#df_test_comp=df_test_comp.rename(columns = {'time':'Date'})
#print(df_test_comp.Date)

print(df_train_comp.shape)
print(train_pred.shape)
print(train_gt_values.shape)

train_T_pred_table, train_plot_df, plot_train_gt_values = utils.predictionTable(df_train_comp, train_pred, train_gt_values)

test_T_pred_table, test_plot_df, plot_test_gt_values = utils.predictionTable(df_test_comp, test_pred, test_gt_values)
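
The `np.append` calls above rebuild one continuous ground-truth series from overlapping target windows: the first value of every window, followed by the remaining values of the last window. The sketch below illustrates this on toy data, assuming consecutive windows are shifted by exactly one time step. The function name `rebuild_series` is hypothetical.

```python
import numpy as np


def rebuild_series(windows):
    """Rebuild a 1-D series from windows of shape (n_windows, horizon).

    Assumes each window starts one time step after the previous one.
    """
    windows = np.asarray(windows)
    # First value of every window, then the tail of the last window
    return np.append(windows[:, 0], windows[-1, 1:])


# Toy series cut into overlapping windows of length 3
series = np.arange(10.0)
horizon = 3
windows = np.stack([series[i:i + horizon] for i in range(len(series) - horizon + 1)])

rebuilt = rebuild_series(windows)
print(windows.shape, rebuilt.shape)
```

With a stride larger than one, the first-column trick would skip values, so the reconstruction only holds under the stride-one assumption.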

In [None]:
'''
Generate the plots on train data
'''

# Horizons to plot: the T+1, T+7 and T+14 predictions against the ground truth
horizon_range = [1, 7, 14]

utils.plotTable(train_plot_df, plot_train_gt_values, horizon_range)

In [None]:
'''
Generate the plots on test data
'''

# Horizons to plot: the T+1, T+7 and T+14 predictions against the ground truth
horizon_range = [1, 7, 14]

utils.plotTable(test_plot_df, plot_test_gt_values, horizon_range)

#### Compute the RMSE values

In [None]:
'''
Compute the train RMSE

- One RMSE value per forecast horizon. The index is the horizon n of the T+n prediction

'''
rmse_values = []
for i in range(output_window):
    rmse_values.append(utils.compute_rmse(i, train_T_pred_table, train_gt_values))
rmse_values = pd.DataFrame(rmse_values, columns=['RMSE'], index=range(1,output_window+1))
rmse_values

In [None]:
'''
Compute the test RMSE

- One RMSE value per forecast horizon. The index is the horizon n of the T+n prediction

'''
test_rmse_values = []
for i in range(output_window):
    test_rmse_values.append(utils.compute_rmse(i, test_T_pred_table, test_gt_values))
test_rmse_values = pd.DataFrame(test_rmse_values, columns=['RMSE'], index=range(1,output_window+1))
test_rmse_values
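
The implementation of `utils.compute_rmse` is not shown here. As a rough illustration of what a per-horizon RMSE means, the sketch below assumes predictions and targets are already aligned in arrays of shape `(n_windows, horizon)`. The function `rmse_per_horizon` is hypothetical and may differ from how `utils` aligns predictions with dates.

```python
import numpy as np


def rmse_per_horizon(y_true, y_pred):
    """Return one RMSE per forecast step for arrays of shape (n_windows, horizon)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Average the squared error over windows (axis 0), keeping one value per horizon
    return np.sqrt(np.mean((y_pred - y_true) ** 2, axis=0))


# Toy data: T+1 is predicted exactly, T+2 is always off by 2
y_true = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
y_pred = np.array([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]])

rmse = rmse_per_horizon(y_true, y_pred)
print(rmse)
```

Entry `rmse[n - 1]` corresponds to the T+n prediction, matching the 1-based index used in the tables above.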