In [None]:
import pandas as pd
from darts import TimeSeries
from darts.models import TFTModel
from darts.dataprocessing.transformers import Scaler
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from darts.utils.likelihood_models import QuantileRegression
import sys
import json
sys.path.append('../Helper/')
from dataPreprocessing import get_untransformed_exog, TRAIN_DATA_PATH_1990S
from PyTorchModular import darts_optuna,save_model_hyper_params, train_valid_split_darts, load_prediction
from torch.nn import MSELoss
from torch.optim import AdamW

In [None]:

# Read the raw training CSV and run it through the project helper
# `get_untransformed_exog`; `data` holds that helper's output unchanged.
data = get_untransformed_exog(pd.read_csv(TRAIN_DATA_PATH_1990S))

# Derive a date-indexed working frame without mutating `data`: the
# 'observation_date' column is parsed as month/year and becomes the index.
df = (
    data
    .assign(observation_date=lambda frame: pd.to_datetime(frame['observation_date'], format='%m/%Y'))
    .set_index('observation_date')
)

# Everything except the forecast target 'fred_PCEPI' is kept as exogenous input.
df_exog = df.drop(columns='fred_PCEPI')

In [None]:
# Sizes shared by the tuning and training cells.
# VALID_SIZE is passed to the tuning and split helpers (presumably the number of
# held-out validation observations - confirm in PyTorchModular).
# HORIZON is used as the model's output_chunk_length.
VALID_SIZE, HORIZON = 3, 1

In [None]:
def _int_range(low, high):
    """Return a fresh `(int, low, high, {'step': 1, 'log': False})` search-space entry."""
    return (int, low, high, {'step': 1, 'log': False})


# Hyper-parameter ranges handed to `darts_optuna`. Entries appear to follow
# (type, low, high, sampling options) - confirm against the helper's implementation.
search_space = dict(
    input_chunk_length=_int_range(1, 36),
    lstm_layers=_int_range(1, 32),
    num_attention_heads=_int_range(1, 4),
    hidden_size=_int_range(1, 32),
    dropout=(float, 0., 0.5, {'step': 0.1, 'log': False}),
    optimizer_kwargs={"lr": (float, 1e-5, 1e-1, {'step': None, 'log': True})},
    loss_fn=('categorical', ["QuantileRegression", "MSE"]),
)

# Model arguments that stay fixed across trials; also merged into the reloaded
# parameters before the final model is built.
invariates = dict(
    output_chunk_length=HORIZON,
    optimizer_cls=AdamW,
    add_encoders={"cyclic": {"future": ["month", "quarter"]}},
)

# Run the search, then persist the winning configuration to JSON.
best_params = darts_optuna(
    TFTModel, 'TFT', search_space, invariates,
    df['fred_PCEPI'].copy(), df_exog.copy(),
    VALID_SIZE, HORIZON,
    n_trials=2, patience=5, tol=1e-4, verbose=True,
)
save_model_hyper_params('best_params_tft.json', best_params)

In [None]:
# Reload the tuned hyper-parameters written by the search cell.
with open('best_params_tft.json', 'r', encoding='utf-8') as f:
    loaded_params = json.load(f)

INPUT_SIZE = loaded_params['input_chunk_length']

# Read the loss choice with .get() so a file without 'loss_fn' no longer raises
# KeyError ahead of the membership guard below. Only the 'MSE' branch produces a
# deterministic model: every other stored value gets a quantile likelihood, and
# when the key is absent TFTModel is left on its default likelihood (quantile
# regression per the darts docs - confirm for the installed version).
loss_choice = loaded_params.get('loss_fn')
PROBABILISTIC = loss_choice != 'MSE'

# Swap the stored loss *name* for the objects TFTModel expects.
if 'loss_fn' in loaded_params:
    if loss_choice == 'MSE':
        loaded_params['loss_fn'] = MSELoss()
        loaded_params['likelihood'] = None
    else:
        loaded_params['loss_fn'] = None
        loaded_params['likelihood'] = QuantileRegression((0.25, 0.5, 0.75))

# Re-apply the fixed (non-tuned) model arguments.
loaded_params.update(invariates)

# The learning rate is stored as a flat "lr" key; move it into the nested
# optimizer_kwargs dict. pop() raises KeyError if "lr" is missing, exactly as
# the previous index-then-delete did.
loaded_params['optimizer_kwargs'] = {"lr": loaded_params.pop("lr")}

print(loaded_params)

# Stop training once val_loss has failed to improve by at least 1e-5 for 25 checks.
early_stopper = EarlyStopping(
    monitor="val_loss",
    patience=25,
    min_delta=1e-5,
    mode='min',
)

loaded_params["pl_trainer_kwargs"] = {
    "accelerator": 'auto',
    "callbacks": [early_stopper],
}

In [None]:
# Split the target and its covariates with the project helper; INPUT_SIZE is
# passed along with VALID_SIZE (presumably so the validation slice carries enough
# history for one input window - confirm in PyTorchModular).
train_target, valid_target, train_exog, valid_exog = train_valid_split_darts(
    df_exog.copy(),
    df['fred_PCEPI'].copy(),
    VALID_SIZE,
    INPUT_SIZE,
)

# Each scaler is fitted on the training slice only and then reused on the
# validation slice.
target_scaler = Scaler()
train_target = target_scaler.fit_transform(train_target)
valid_target = target_scaler.transform(valid_target)

exog_scaler = Scaler()
train_exog = exog_scaler.fit_transform(train_exog)
valid_exog = exog_scaler.transform(valid_exog)

model = TFTModel(**loaded_params)

# The epoch count is a generous upper bound; the EarlyStopping callback in
# pl_trainer_kwargs is expected to end training well before it is reached.
model.fit(
    series=train_target,
    past_covariates=train_exog,
    val_series=valid_target,
    val_past_covariates=valid_exog,
    epochs=10000,
)

In [None]:
# NOTE(review): disabled draft of a per-test-date prediction helper; nothing in
# this cell executes. Before re-enabling: `pd.concat` takes a list of objects as
# its first argument (both calls below pass two positional frames instead), and
# the function collects `preds_arr` without returning it.
# def darts_predict(model:object,input_size,horizon, train_target:pd.Series,train_exog_df:pd.DataFrame, test_target:pd.Series,
#                   test_exog_df:pd.DataFrame, exog_scaler, target_scaler, probabilistic=False):

#     test_dates= test_target.index.copy()

#     combined_target_df= pd.concat(train_target,test_target)
#     combined_exog_df= pd.concat(train_exog_df,test_exog_df)

#     preds_arr=[]
#     for test_date in test_dates:
        
#         x,y= load_prediction(input_size, combined_exog_df, combined_target_df, exog_scaler, target_scaler, test_date)

#         if probabilistic:
#             y_preds= model.predict(n=horizon,series=y,past_covariates=x,num_samples=100).values()
#         else:
#             y_preds= model.predict(n=horizon,series=y,past_covariates=x).values()
        
#         preds_arr.append(y_preds)

# # darts_predict(model,)