In [1]:
# install required packages (only need to run once)
!pip install pytorch-lightning pytorch-forecasting

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.2-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-forecasting
  Downloading pytorch_forecasting-1.4.0-py3-none-any.whl.metadata (14 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.7.4-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)
  Downloading lightning-2.5.2-py3-none-any.whl.metadata (38 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4

In [6]:
import os
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt

import pytorch_lightning as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE
import torchmetrics
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Monkey-patch CUDA availability checks to always be False
torch.cuda.is_available = lambda: False
torch.cuda.device_count = lambda: 0


def get_attention_weights(
    model: TemporalFusionTransformer,
    input_data: np.ndarray or torch.Tensor,
    prediction_length: int = 1,
    max_encoder_length: int = 32,
):
    """
    Extracts encoder and decoder attention weights from a saved TFT model.

    Parameters:
        model: a loaded TemporalFusionTransformer
        input_data: numpy array or torch tensor of shape (1, max_encoder_length, num_features)
        prediction_length: number of future steps the model predicts
        max_encoder_length: number of past steps used for encoding
    Returns:
        encoder_attention: Tensor of shape (batch=1, num_heads, time, time)
        decoder_attention: Tensor of shape (batch=1, num_heads, time, time)
    """
    # model.eval()
    # convert numpy to tensor if needed and ensure float32
    if isinstance(input_data, np.ndarray):
        x_cont_encoder = torch.from_numpy(input_data).float()
    else:
        x_cont_encoder = input_data.float()
    # verify shape
    assert x_cont_encoder.ndim == 3 and x_cont_encoder.size(1) == max_encoder_length, \
        f"Expected input_data shape (1, {max_encoder_length}, num_features), got {tuple(x_cont_encoder.shape)}"

    # prepare decoder continuous inputs as zeros
    num_features = x_cont_encoder.size(2)
    x_cont_decoder = torch.zeros((1, prediction_length, num_features), dtype=torch.float32)

    n_static_cat = len(getattr(model.hparams, "static_categoricals", []))
    n_static_real = len(getattr(model.hparams, "static_reals", []))
    n_time_known_cat = len(getattr(model.hparams, "time_varying_known_categoricals", []))
    n_time_unknown_cat = len(getattr(model.hparams, "time_varying_unknown_categoricals", []))
    n_time_cat = n_time_known_cat + n_time_unknown_cat

    # 4) zero‐fill your cats
    x_static_cat   = torch.zeros((1, n_static_cat), dtype=torch.long)
    x_static_real  = torch.zeros((1, len(model.hparams.static_reals)))
    x_encoder_cat  = torch.zeros((1, max_encoder_length, n_time_cat), dtype=torch.long)
    x_decoder_cat  = torch.zeros((1, prediction_length, n_time_cat), dtype=torch.long)

    # build input dict for TFT forward
    x = {
        "static_cat":   x_static_cat,
        "static_real":  x_static_real,
        "encoder_cat":  x_encoder_cat,
        "decoder_cat":  x_decoder_cat,
        "encoder_cont": x_cont_encoder,
        "decoder_cont": x_cont_decoder,
        "encoder_lengths": torch.tensor([max_encoder_length]),
        "decoder_lengths": torch.tensor([prediction_length]),
        "target_scale": torch.ones((1, 1, 1)),
    }
    with torch.no_grad():
        out = model(x)
    return out.get("encoder_attention"), out.get("decoder_attention")


import torch
from pytorch_forecasting import TemporalFusionTransformer

def load_tft_strict_cpu(checkpoint_path: str) -> TemporalFusionTransformer:
    """
    Load a TFT Lightning checkpoint strictly on CPU.
    Avoids any .cpu() calls that trigger torchmetrics CUDA init.
    """
    # 1) Load full checkpoint (allow unpickle) onto CPU
    ckpt = torch.load(
        checkpoint_path,
        map_location=torch.device("cpu"),
        weights_only=False
    )
    # 2) Extract Lightning hyperparameters
    hparams = ckpt.get("hyper_parameters", ckpt.get("hparams", {}))
    # 3) Instantiate fresh TFT on CPU (no GPUs involved)
    model = TemporalFusionTransformer(**hparams)
    # 4) Load weights & buffers (already on CPU)
    model.load_state_dict(ckpt["state_dict"])
    # 5) Patch every torchmetrics.Metric to live on CPU
    for module in model.modules():
        if isinstance(module, torchmetrics.Metric):
            module._device = torch.device("cpu")
    # 6) Set to inference mode (no further .cpu needed)
    model.eval()
    return model

In [3]:
df = pd.read_csv('/content/drive/MyDrive/tft/data/df_pca_n.csv')

In [4]:
df

Unnamed: 0,pca1,pca2,pca3,pca4,pca5,pca6,pca7,pca8,pca11,pca13,pca14,pca16,pca17,pca18,pca19,pca20,pca21,close
0,-4.728898,-1.115874,-0.350434,0.465851,0.346966,-0.063788,-0.840479,-0.644939,0.120015,1.910322,1.305534,0.669070,0.307071,1.389222,1.563639,0.564337,0.155247,28212.73
1,-4.635488,-2.741018,1.690103,-2.328273,-0.770616,1.601905,0.038216,-1.187851,1.623920,0.795812,-1.426525,1.030457,0.065907,1.300028,1.824407,-0.132587,0.037102,28127.82
2,-4.694583,-1.193204,0.440777,1.554080,0.044771,0.620009,0.986779,0.868953,1.621704,-0.432284,0.463382,0.740617,0.691632,1.340005,1.675177,0.299795,0.201588,28169.00
3,-4.708888,-2.340780,0.054083,-0.864007,0.050861,-0.809183,-2.061673,0.437444,1.108165,0.823587,0.540933,0.793279,0.287486,1.366688,1.618464,0.438763,0.133583,28128.59
4,-4.715028,-1.494159,0.141522,1.021576,0.024130,0.061481,0.593962,-0.776617,1.957156,0.502589,1.086233,0.736398,0.540106,1.357617,1.638440,0.410520,0.196583,28150.00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
210480,5.715374,2.329135,-1.077321,0.317249,0.370631,-0.451116,0.247526,-2.197046,-0.225612,-0.190178,0.455132,0.526352,-0.586859,-0.757855,-2.109945,-0.840042,0.248779,94360.00
210481,5.719356,2.395326,-1.081455,-0.022791,0.384120,0.969201,1.308615,-1.453685,-0.851483,-0.207386,0.093235,0.253536,-0.464944,-0.765870,-2.096844,-0.825267,0.472029,94408.05
210482,5.710340,1.108185,-1.326260,-1.562592,0.287421,1.651703,1.574665,-0.300371,1.421647,-1.088612,0.157923,0.380004,-0.325366,-0.754873,-2.078163,-0.854962,0.357805,94276.00
210483,5.704047,0.310273,-1.231921,-1.317224,0.126251,-1.785598,-1.010372,-0.279763,1.316855,0.653663,0.165989,0.261319,0.163953,-0.790322,-2.071569,-0.848133,0.353897,94159.83


In [7]:
look_back = 32
pred_len = 1
# Alternatively, load an existing checkpoint directly:
ckpt_path = "/content/drive/MyDrive/tft/models/best_tft.ckpt"
loaded_model = load_tft_strict_cpu(ckpt_path)
print(f"Loaded model from {ckpt_path}")

# prepare a sample window
feature_cols = [col for col in df.columns if col not in ["time_idx", "series", "close"]]
window = df[feature_cols].iloc[-look_back:].to_numpy()[None, :, :]
input_tensor = torch.from_numpy(window)

# extract attention weights
enc_attn, dec_attn = get_attention_weights(
    loaded_model,
    input_tensor,
    prediction_length=pred_len,
    max_encoder_length=look_back
)

print(f"Encoder attention shape: {enc_attn.shape}")
print(f"Decoder attention shape: {dec_attn.shape}")

/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.


Loaded model from /content/drive/MyDrive/tft/models/best_tft.ckpt


IndexError: index 0 is out of bounds for dimension 1 with size 0