# TFT
https://arxiv.org/abs/1912.09363

## Imports:

In [58]:
import pandas as pd
from darts import TimeSeries
from pathlib import Path
from darts.models.forecasting.tft_model import TFTModel
from darts.utils.model_selection import train_test_split
import sys
from sklearn.preprocessing import MinMaxScaler
sys.path.append('../Helper/')
from dataPreprocessing import add_time_features, rank_features_ccf, get_untransformed_exog
from darts.dataprocessing.transformers import Scaler
from torch.nn import MSELoss
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import MeanSquaredError
from hyperparameters import tune_hyperparameters, load_best_hyperparameters, save_best_hyperparameters

# Data Preprocessing:

In [59]:
cwd=Path.cwd()
data= pd.read_csv(cwd.parent.parent / 'Data' /'Train'/'train1990s.csv')
data['observation_date'] = pd.to_datetime(data['observation_date'], format='%m/%Y')
data= get_untransformed_exog(data)
ranked_cols= rank_features_ccf(data.drop('observation_date',axis=1),targetCol='fred_PCEPI')
data_ranked= pd.DataFrame()
data_ranked['observation_date']= data.loc[:,'observation_date']
data_ranked['fred_PCEPI']= data.loc[:,'fred_PCEPI']
add_time_features(data_ranked)
data_ranked.loc[:,ranked_cols[:25]]= data.loc[:,ranked_cols[:25]].copy()
display(data_ranked)

2025-04-14 20:32:13,438 - INFO - Added time features: year, month, quarter. DataFrame shape: (408, 5)


Unnamed: 0,observation_date,fred_PCEPI,year,month,quarter,fred_PCUOMFGOMFG,fred_PPIACO,fred_APU000074714,BrentOil_Open,CrudeOilWTI_Open,...,food_price_indices_data_f_Food Price Index,fred_CSUSHPISA,CMO-Historical-Data-Monthly_TSP,food_price_indices_data_f_Meat,BrentOil_Price,CMO-Historical-Data-Monthly_Agriculture,CMO-Historical-Data-Monthly_DAP,CrudeOilWTI_Price,CMO-Historical-Data-Monthly_Metals_minerals,Copper_High
0,1990-01-01,58.553,1990,1,1,112.700,114.900,1.042,20.20,21.81,...,64.4,76.897,109.50,74.3,20.06,55.582121,138.500,22.68,41.358420,1.1250
1,1990-02-01,58.811,1990,2,1,112.200,114.400,1.037,19.90,22.61,...,64.7,77.053,112.00,76.8,19.48,55.814067,139.750,21.54,40.591515,1.1075
2,1990-03-01,59.033,1990,3,1,112.300,114.200,1.023,19.48,21.73,...,64.0,77.201,112.00,78.5,18.38,56.249916,146.000,20.28,44.825494,1.2240
3,1990-04-01,59.157,1990,4,2,112.600,114.100,1.044,18.49,20.20,...,66.0,77.278,127.50,81.2,17.12,56.575714,155.200,18.54,44.470335,1.2050
4,1990-05-01,59.290,1990,5,2,113.100,114.600,1.061,17.20,18.46,...,64.6,77.297,127.50,81.8,16.24,56.049268,149.125,17.40,44.834773,1.1925
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
403,2023-08-01,120.965,2023,8,3,251.093,257.680,3.955,84.78,81.64,...,121.8,309.182,450.63,114.5,86.83,108.741176,528.750,83.63,99.637595,4.0345
404,2023-09-01,121.387,2023,9,3,252.368,258.934,3.988,86.22,83.63,...,121.7,311.279,461.50,113.4,92.20,110.751493,527.900,90.79,101.157146,3.9085
405,2023-10-01,121.421,2023,10,4,249.119,255.192,3.782,90.48,90.88,...,120.7,313.257,468.13,112.0,85.02,109.301730,534.750,81.02,98.459930,3.7615
406,2023-11-01,121.415,2023,11,4,246.497,252.856,3.500,84.78,81.31,...,120.6,314.132,462.63,111.6,80.86,111.018307,535.630,75.96,101.158840,3.8638


## Constants:

In [46]:
scaler1=Scaler(MinMaxScaler())
scaler2=Scaler(MinMaxScaler())
scaler3=Scaler(MinMaxScaler())
INPUT_SIZE=12
HORIZON=3
HIDDEN_SIZE=4

VALID_SIZE=3

In [47]:
def choose_top_n_features(df:pd.DataFrame, n:int):

    return df.iloc[:,:n].copy(deep=True)

def create_time_series_objects(df,future_covariates_cols, target_col= "fred_PCEPI", date_col="observation_date"):

    future_covariates= df[future_covariates_cols]
    past_covariates= df.drop(future_covariates_cols,axis=1).drop(target_col,axis=1).drop(date_col,axis=1)
    target= df[[date_col,target_col]]

    return future_covariates, past_covariates, target

def darts_get_inference_set(past_covariates,future_covariates,target, valid_size, input_size, time_col="observation_date", freq= "MS"):
    X_train_future= TimeSeries.from_dataframe(future_covariates.iloc[:-valid_size,:],time_col=time_col,freq=freq)
    X_train_past= TimeSeries.from_dataframe(past_covariates.iloc[:-valid_size,:],time_col=time_col,freq=freq)
    y_train= TimeSeries.from_dataframe(target.iloc[:-valid_size,:],time_col=time_col,freq=freq)

    X_valid_future= TimeSeries.from_dataframe(future_covariates.iloc[-valid_size-input_size:,:],time_col=time_col,freq=freq)
    X_valid_past= TimeSeries.from_dataframe(past_covariates.iloc[-valid_size-input_size:,:],time_col=time_col,freq=freq)
    y_valid= TimeSeries.from_dataframe(target.iloc[-valid_size-input_size:,:],time_col=time_col,freq=freq)

    return X_train_past, X_valid_past, X_train_future, X_valid_future, y_train, y_valid

## Create TimeSeries objects:

In [48]:
future_covariates= data_ranked[["observation_date","year","month", "quarter"]]
past_covariates= data_ranked.drop(["fred_PCEPI","year","month", "quarter"],axis=1)
target= data_ranked[["observation_date","fred_PCEPI"]]

In [49]:

X_train_past, X_valid_past, X_train_future, X_valid_future, y_train, y_valid=darts_get_inference_set(past_covariates,future_covariates,target, valid_size=VALID_SIZE,input_size=INPUT_SIZE)

## Feature Scaling:

In [50]:

X_train= scaler1.fit_transform(X_train_past)
X_train_future= scaler2.fit_transform(X_train_future)
y_train= scaler3.fit_transform(y_train)


X_valid=scaler1.transform(X_valid_past)
X_valid_future= scaler2.transform(X_valid_future)
y_valid=scaler3.transform(y_valid)


# Training:
To optimize:
* number of features
* hidden_size [1,2,4,8,16,32]
* lstm layers [1,2,4]
* num_attention_heads [1,2,4,8]
* lr [1e-2,1e-3,1e-4,1e-5]


In [51]:
valid_metric= MeanSquaredError()

In [None]:
early_stopper = EarlyStopping(
    monitor="val_MeanSquaredError",
    patience=15,
    min_delta=1e-3,
    mode='min',
)
model_checkpoint= ModelCheckpoint('.', 'TFT_best_weights', monitor="val_MeanSquaredError",save_top_k=1, mode="min", save_weights_only=True)

In [53]:
TFT_params= {
    'input_chunk_length':INPUT_SIZE,
    'output_chunk_length':HORIZON,
    'hidden_size':HIDDEN_SIZE,
    'batch_size':128,
    'use_static_covariates':False,
    'add_relative_index':True,
    'likelihood':None,
    'loss_fn' : MSELoss(),
    'torch_metrics':valid_metric,
    'save_checkpoints':True,
    'random_state':42, 
    'pl_trainer_kwargs':{
        "accelerator":"gpu",
        "devices": -1,
        "callbacks": [early_stopper, model_checkpoint]
        },
    'dropout':0.1
}

In [54]:
model = TFTModel(**TFT_params)

model.fit(y_train,past_covariates=X_train_past,future_covariates=X_train_future,val_series=y_valid, val_past_covariates=X_valid_past,val_future_covariates=X_valid_future, epochs=10000)#,past_covariates=X_train,future_covariates=X_train_future,

2025-04-14 13:48:01,889 - INFO - Train dataset contains 391 samples.
2025-04-14 13:48:01,941 - INFO - Time series values are 64-bits; casting model to float64.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\kevin\AppData\Local\Programs\Python\Python312\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory H:\My Drive\UniversityOfLeeds\CS\Year4\GroupProject\repo\COMP5530M-Group-Project-Inflation-Forecasting\Training\TFT exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | criterion                         | MSELoss                          | 0      | train
1  | train_criterion                   | MSELoss                          | 0      | train
2  | val_criterion     

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\kevin\AppData\Local\Programs\Python\Python312\Lib\site-packages\pytorch_lightning\core\module.py:512: You called `self.log('val_MeanSquaredError', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

c:\Users\kevin\AppData\Local\Programs\Python\Python312\Lib\site-packages\pytorch_lightning\core\module.py:512: You called `self.log('train_MeanSquaredError', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

TFTModel(output_chunk_shift=0, hidden_size=4, lstm_layers=1, num_attention_heads=4, full_attention=False, feed_forward=GatedResidualNetwork, dropout=0.1, hidden_continuous_size=8, categorical_embedding_sizes=None, add_relative_index=True, loss_fn=MSELoss(), likelihood=None, norm_type=LayerNorm, use_static_covariates=False, input_chunk_length=12, output_chunk_length=3, batch_size=128, torch_metrics=MeanSquaredError(), save_checkpoints=True, random_state=42, pl_trainer_kwargs={'accelerator': 'gpu', 'devices': -1, 'callbacks': [<pytorch_lightning.callbacks.early_stopping.EarlyStopping object at 0x000001557F1EF710>, <pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x000001557F1ED220>]})

In [55]:
import torch
torch.save(torch.load(model_checkpoint.best_model_path, weights_only=False)["state_dict"],'sate_dict.pt')

In [56]:
model.load('sate_dict.pt')



OrderedDict([('encoder_vsn.flattened_grn.resample_norm.mask',
              tensor([ 0.0027, -0.0040,  0.0038,  0.0005, -0.0005, -0.0034,  0.0039,  0.0039,
                       0.0036,  0.0033,  0.0038, -0.0040,  0.0039, -0.0038, -0.0037,  0.0040,
                      -0.0039, -0.0040, -0.0037,  0.0035,  0.0039, -0.0036, -0.0039,  0.0040,
                      -0.0033,  0.0036, -0.0034,  0.0021,  0.0025,  0.0030], device='cuda:0',
                     dtype=torch.float64)),
             ('encoder_vsn.flattened_grn.resample_norm.norm.weight',
              tensor([1.0038, 1.0030, 1.0039, 1.0034, 1.0039, 0.9968, 1.0040, 0.9961, 1.0039,
                      1.0035, 1.0039, 0.9960, 1.0039, 0.9963, 0.9963, 1.0040, 0.9961, 0.9960,
                      0.9963, 1.0037, 1.0040, 1.0036, 0.9961, 1.0040, 0.9968, 0.9964, 1.0039,
                      0.9964, 1.0038, 0.9962], device='cuda:0', dtype=torch.float64)),
             ('encoder_vsn.flattened_grn.resample_norm.norm.bias',
             

In [57]:
scaler3.inverse_transform(model.predict(3,series=y_train,past_covariates=X_valid_past, future_covariates=X_valid_future)).to_dataframe()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

component,fred_PCEPI
observation_date,Unnamed: 1_level_1
2023-10-01,88.185044
2023-11-01,87.39401
2023-12-01,86.999285


In [None]:
def objective(trial):

    
    
       

In [None]:
from darts.explainability import TFTExplainer

explainer = TFTExplainer(
    tft_model,
    background_series=data_transformed[1],
    background_future_covariates=dynamic_covariates_transformed[1],
)
explainability_result = explainer.explain()