In [1]:
!pip install pytorch-forecasting

Collecting pytorch-forecasting
  Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)
Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)
  Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)
  Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-fo

In [8]:
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from pytorch_forecasting.data.timeseries import _coerce_to_dict

In [9]:
def _coerce_to_list(obj):
    """Coerce object to list.

    None is coerced to empty list, otherwise list constructor is used.
    """
    if obj is None:
        return []
    if isinstance(obj, str):
        return [obj]
    return list(obj)


class TimeSeries(Dataset):
    """PyTorch Dataset for time series data stored in pandas DataFrame.

    One item of the dataset is one complete time series instance, i.e. all rows
    of ``data`` belonging to one group (or all of ``data`` if ``group`` is empty).

    Parameters
    ----------
    data : pd.DataFrame
        data frame with sequence data.
        Column names must all be str, and contain str as referred to below.
    data_future : pd.DataFrame, optional, default=None
        data frame with future data.
        Column names must all be str, and contain str as referred to below.
        May contain only columns that are in time, group, weight, known, or static.
    time : str, optional, default = first col not in group, weight, target, static.
        integer typed column denoting the time index within ``data``.
        This column is used to determine the sequence of samples.
        If there are no missing observations,
        the time index should increase by ``+1`` for each subsequent sample.
        The first time_idx for each series does not necessarily
        have to be ``0`` but any value is allowed.
    target : str or List[str], optional, default = None
        column(s) in ``data`` denoting the forecasting target.
        Can be categorical or numerical dtype.
        If ``None``, no column is treated as target and ``y`` has zero columns.
    group : List[str], optional, default = None
        list of column names identifying a time series instance within ``data``.
        This means that the ``group`` together uniquely identify an instance,
        and ``group`` together with ``time`` uniquely identify a single observation
        within a time series instance.
        If ``None``, the dataset is assumed to be a single time series.
    weight : str, optional, default=None
        column name for weights.
        If ``None``, it is assumed that there is no weight column.
    num : list of str, optional, default = None
        list of numerical variables in ``data``.
        Stored on the instance; not otherwise used by this class.
    cat : list of str, optional, default = None
        list of categorical variables in ``data``.
        Columns listed here get type "C" in the metadata, all others "F".
    known : list of str, optional, default = None
        list of variables that change over time and are known in the future.
        Columns listed here are marked "K" in the metadata, all others "U",
        and are the only feature columns filled from ``data_future``.
    unknown : list of str, optional, default = None
        list of variables that are not known in the future.
        Stored on the instance; not otherwise used by this class.
    static : list of str, optional, default = None
        list of variables that do not change over time.
        Their values are read from the first row of each series.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        data_future: Optional[pd.DataFrame] = None,
        time: Optional[str] = None,
        target: Optional[Union[str, List[str]]] = None,
        group: Optional[List[str]] = None,
        weight: Optional[str] = None,
        num: Optional[List[Union[str, List[str]]]] = None,
        cat: Optional[List[Union[str, List[str]]]] = None,
        known: Optional[List[Union[str, List[str]]]] = None,
        unknown: Optional[List[Union[str, List[str]]]] = None,
        static: Optional[List[Union[str, List[str]]]] = None,
    ):

        self.data = data
        self.data_future = data_future
        self.target = _coerce_to_list(target)
        self.group = _coerce_to_list(group)
        self.weight = weight
        self.num = _coerce_to_list(num)
        self.cat = _coerce_to_list(cat)
        self.known = _coerce_to_list(known)
        self.unknown = _coerce_to_list(unknown)
        self.static = _coerce_to_list(static)

        # Apply the documented default for ``time``. Without it, ``time=None``
        # would only fail later, on the first ``__getitem__`` call.
        if time is None:
            reserved = self.group + [self.weight] + self.target + self.static
            time = next((col for col in data.columns if col not in reserved), None)
        self.time = time

        # Every column that is not time/group/weight/target is a feature column.
        non_feature_cols = [self.time] + self.group + [self.weight] + self.target
        self.feature_cols = [
            col for col in data.columns if col not in non_feature_cols
        ]

        if self.group:
            self._groups = self.data.groupby(self.group).groups
            self._group_ids = list(self._groups.keys())
        else:
            self._groups = {"_single_group": self.data.index}
            self._group_ids = ["_single_group"]

        # Group the future frame once here instead of on every ``__getitem__``.
        if self.data_future is not None and self.group:
            self._future_groups = self.data_future.groupby(self.group).groups
        else:
            self._future_groups = None

        self._prepare_metadata()

    def _prepare_metadata(self):
        """Prepare metadata for the dataset.

        Stores in ``self.metadata`` a dict that contains:

        * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] }
          Names of columns for y, x, and static features.
          List elements are in same order as column dimensions.
          Columns not appearing are assumed to be named (x0, x1, etc.),
          (y0, y1, etc.), (st0, st1, etc.).
        * ``col_type``: dict[str, str]
          maps column names to data types "F" (numerical) and "C" (categorical).
          Column names not occurring are assumed "F".
        * ``col_known``: dict[str, str]
          maps column names to "K" (future known) or "U" (future unknown).
          Column names not occurring are assumed "K".
        """
        self.metadata = {
            "cols": {
                "y": self.target,
                "x": self.feature_cols,
                "st": self.static,
            },
            "col_type": {},
            "col_known": {},
        }

        all_cols = self.target + self.feature_cols + self.static
        for col in all_cols:
            self.metadata["col_type"][col] = "C" if col in self.cat else "F"

            self.metadata["col_known"][col] = "K" if col in self.known else "U"

    def __len__(self) -> int:
        """Return number of time series in the dataset."""
        return len(self._group_ids)

    def _future_rows(self, group_id):
        """Return the rows of ``data_future`` for one series, or None if unset.

        A series without any future rows yields an empty frame.
        """
        if self.data_future is None:
            return None
        if not self.group:
            return self.data_future

        future_index = self._future_groups.get(group_id)
        if future_index is None:
            return self.data_future.iloc[0:0]
        return self.data_future.loc[future_index]

    def _merge_future(self, data, future_data):
        """Align past and future rows of one series on their joint time axis.

        Returns a dict with ``t``, ``x``, ``y`` (and ``weights`` if a weight
        column is configured), all with one row per joint time point.
        Entries that neither frame provides are NaN.
        """
        past_times = data[self.time].values
        future_times = future_data[self.time].values
        # np.unique returns the sorted union of the time points
        combined_times = np.unique(np.concatenate([past_times, future_times]))
        n_timepoints = len(combined_times)

        row_of_time = {t: i for i, t in enumerate(combined_times)}
        past_rows = [row_of_time[t] for t in past_times]
        future_rows = [row_of_time[t] for t in future_times]

        x_merged = np.full((n_timepoints, len(self.feature_cols)), np.nan)
        y_merged = np.full((n_timepoints, len(self.target)), np.nan)

        # extract the value arrays once, not once per row
        x_past = data[self.feature_cols].values
        y_past = data[self.target].values
        for i, row in enumerate(past_rows):
            x_merged[row] = x_past[i]
            y_merged[row] = y_past[i]

        # Only known feature columns are taken from the future frame; at
        # overlapping time points the future value replaces the past value.
        for col in self.known:
            if col in self.feature_cols and col in future_data.columns:
                col_idx = self.feature_cols.index(col)
                future_values = future_data[col].values
                for i, row in enumerate(future_rows):
                    x_merged[row, col_idx] = future_values[i]

        merged = {
            "t": combined_times,
            "x": torch.tensor(x_merged, dtype=torch.float32),
            "y": torch.tensor(y_merged, dtype=torch.float32),
        }

        if self.weight:
            # Always build weights on the joint time axis so they stay aligned
            # with ``t``, even when the future frame has no weight column.
            weights_merged = np.full(n_timepoints, np.nan)
            past_weights = data[self.weight].values
            for i, row in enumerate(past_rows):
                weights_merged[row] = past_weights[i]

            if self.weight in future_data.columns:
                future_weights = future_data[self.weight].values
                for i, row in enumerate(future_rows):
                    weights_merged[row] = future_weights[i]

            merged["weights"] = torch.tensor(weights_merged, dtype=torch.float32)

        return merged

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """Get time series data for given index.

        It returns:

        * ``t``: ``numpy.ndarray`` of shape (n_timepoints,)
          Time index for each time point. Aligned with ``y`` and ``x``.
          If ``data_future`` is set, this is the sorted union of past and
          future time points.
        * ``y``: tensor of shape (n_timepoints, n_targets)
          Target values for each time point. Rows are time points, aligned with ``t``.
        * ``x``: tensor of shape (n_timepoints, n_features)
          Features for each time point. Rows are time points, aligned with ``t``.
        * ``group``: tensor of shape (1,)
          Position of the series among the group ids of this dataset.
          Unlike a string hash, this is the same in every process and run.
        * ``st``: tensor of shape (n_static_features)
          Static features.
        * ``cutoff_time``: largest value of the time column in ``data``
          for this series.

        Optionally, the following str-keyed entry can be included:

        * ``weights``: tensor of shape (n_timepoints), only if weight is not None
        """
        group_id = self._group_ids[index]
        # normalises negative indices and yields a plain non-negative int
        series_pos = range(len(self._group_ids))[index]

        if self.group:
            data = self.data.loc[self._groups[group_id]]
        else:
            data = self.data

        cutoff_time = data[self.time].max()

        result = {
            "t": data[self.time].values,
            "y": torch.tensor(data[self.target].values),
            "x": torch.tensor(data[self.feature_cols].values),
            "group": torch.tensor([series_pos]),
            "st": torch.tensor(data[self.static].iloc[0].values if self.static else []),
            "cutoff_time": cutoff_time,
        }

        future_data = self._future_rows(group_id)
        if future_data is not None:
            result.update(self._merge_future(data, future_data))
        elif self.weight:
            result["weights"] = torch.tensor(
                data[self.weight].values, dtype=torch.float32
            )

        return result

    def get_metadata(self) -> Dict:
        """Return metadata about the dataset.

        Returns
        -------
        Dict
            Dictionary containing:
            - cols: column names for y, x, and static features
            - col_type: mapping of columns to their types (F/C)
            - col_known: mapping of columns to their future known status (K/U)
        """
        return self.metadata

In [17]:
from typing import Dict, List, Optional, Union

from lightning.pytorch import LightningDataModule
from sklearn.preprocessing import RobustScaler, StandardScaler
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_forecasting.data.encoders import (
    EncoderNormalizer,
    NaNLabelEncoder,
    TorchNormalizer,
)

NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer]


class EncoderDecoderTimeSeriesDataModule(LightningDataModule):
    """
    Lightning DataModule for processing time series data in an encoder-decoder format.

    This module handles preprocessing, splitting, and batching of time series data
    for use in deep learning models. It supports categorical and continuous features,
    various scalers, and automatic target normalization.

    Parameters
    ----------
    time_series_dataset : TimeSeries
        The dataset containing time series data.
    max_encoder_length : int, default=30
        Maximum length of the encoder input sequence.
    min_encoder_length : Optional[int], default=None
        Minimum length of the encoder input sequence.
        Defaults to `max_encoder_length` if not specified.
    max_prediction_length : int, default=1
        Maximum length of the decoder output sequence.
    min_prediction_length : Optional[int], default=None
        Minimum length of the decoder output sequence.
        Defaults to `max_prediction_length` if not specified.
    min_prediction_idx : Optional[int], default=None
        Minimum index from which predictions start.
    allow_missing_timesteps : bool, default=False
        Whether to allow missing timesteps in the dataset.
    add_relative_time_idx : bool, default=False
        Whether to add a relative time index feature.
    add_target_scales : bool, default=False
        Whether to add target scaling information.
    add_encoder_length : Union[bool, str], default="auto"
        Whether to include encoder length information.
    target_normalizer :
        Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None],
         default="auto"
        Normalizer for the target variable. If "auto", uses `RobustScaler`.

    categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None
        Dictionary of categorical encoders.

    scalers :
    Optional[Dict[str, Union[StandardScaler, RobustScaler,
                        TorchNormalizer, EncoderNormalizer]]], default=None
        Dictionary of feature scalers.

    randomize_length : Union[None, Tuple[float, float], bool], default=False
        Whether to randomize input sequence length.
    batch_size : int, default=32
        Batch size for DataLoader.
    num_workers : int, default=0
        Number of workers for DataLoader.
    train_val_test_split : tuple, default=(0.7, 0.15, 0.15)
        Proportions for train, validation, and test dataset splits.
    """

    def __init__(
        self,
        time_series_dataset: TimeSeries,
        max_encoder_length: int = 30,
        min_encoder_length: Optional[int] = None,
        max_prediction_length: int = 1,
        min_prediction_length: Optional[int] = None,
        min_prediction_idx: Optional[int] = None,
        allow_missing_timesteps: bool = False,
        add_relative_time_idx: bool = False,
        add_target_scales: bool = False,
        add_encoder_length: Union[bool, str] = "auto",
        target_normalizer: Union[
            NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None
        ] = "auto",
        categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None,
        scalers: Optional[
            Dict[
                str,
                Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer],
            ]
        ] = None,
        randomize_length: Union[None, Tuple[float, float], bool] = False,
        batch_size: int = 32,
        num_workers: int = 0,
        train_val_test_split: tuple = (0.7, 0.15, 0.15),
    ):
        """Store the configuration and split feature columns by type.

        The metadata of ``time_series_dataset`` is read once; the positions of
        its ``x`` columns are sorted into ``categorical_indices`` (column type
        "C") and ``continuous_indices`` (everything else).
        """
        super().__init__()
        self.time_series_dataset = time_series_dataset
        self.time_series_metadata = time_series_dataset.get_metadata()

        self.max_encoder_length = max_encoder_length
        # Fall back to the maximum only when the minimum is omitted; an
        # explicit 0 must not be replaced, which a plain ``or`` would do.
        self.min_encoder_length = (
            max_encoder_length if min_encoder_length is None else min_encoder_length
        )
        self.max_prediction_length = max_prediction_length
        self.min_prediction_length = (
            max_prediction_length
            if min_prediction_length is None
            else min_prediction_length
        )
        self.min_prediction_idx = min_prediction_idx

        self.allow_missing_timesteps = allow_missing_timesteps
        self.add_relative_time_idx = add_relative_time_idx
        self.add_target_scales = add_target_scales
        self.add_encoder_length = add_encoder_length
        self.randomize_length = randomize_length

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_test_split = train_val_test_split

        if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto":
            self.target_normalizer = RobustScaler()
        else:
            self.target_normalizer = target_normalizer

        self.categorical_encoders = _coerce_to_dict(categorical_encoders)
        self.scalers = _coerce_to_dict(scalers)

        self.categorical_indices = []
        self.continuous_indices = []
        # filled lazily by the ``metadata`` property
        self._metadata = None

        for idx, col in enumerate(self.time_series_metadata["cols"]["x"]):
            if self.time_series_metadata["col_type"].get(col) == "C":
                self.categorical_indices.append(idx)
            else:
                self.continuous_indices.append(idx)

    def _prepare_metadata(self):
        """Prepare metadata for model initialisation.

        Returns
        -------
        dict
            dictionary containing the following keys:

                * ``encoder_cat``: Number of categorical variables in the encoder.
                    Computed as ``len(self.categorical_indices)``, which counts the
                    categorical feature indices.
                * ``encoder_cont``: Number of continuous variables in the encoder.
                    Computed as ``len(self.continuous_indices)``, which counts the
                    continuous feature indices.
                * ``decoder_cat``: Number of categorical variables in the decoder that
                    are known in advance.
                    Computed by filtering ``self.time_series_metadata["cols"]["x"]``
                    where col_type == "C"(categorical) and col_known == "K" (known)
                * ``decoder_cont``:  Number of continuous variables in the decoder that
                    are known in advance.
                    Computed by filtering ``self.time_series_metadata["cols"]["x"]``
                    where col_type == "F"(continuous) and col_known == "K"(known)
                * ``target``: Number of target variables.
                    Computed as ``len(self.time_series_metadata["cols"]["y"])``, which
                    gives the number of output target columns..
                * ``static_categorical_features``: Number of static categorical features
                    Computed by filtering ``self.time_series_metadata["cols"]["st"]``
                    (static features) where col_type == "C" (categorical).
                * ``static_continuous_features``: Number of static continuous features
                    Computed as difference of
                    ``len(self.time_series_metadata["cols"]["st"])`` (static features)
                    and static_categorical_features that gives static continuous feature
                * ``max_encoder_length``: maximum encoder length
                    Taken directly from `self.max_encoder_length`.
                * ``max_prediction_length``: maximum prediction length
                    Taken directly from `self.max_prediction_length`.
                * ``min_encoder_length``: minimum encoder length
                    Taken directly from `self.min_encoder_length`.
                * ``min_prediction_length``: minimum prediction length
                    Taken directly from `self.min_prediction_length`.

        """
        encoder_cat_count = len(self.categorical_indices)
        encoder_cont_count = len(self.continuous_indices)

        decoder_cat_count = len(
            [
                col
                for col in self.time_series_metadata["cols"]["x"]
                if self.time_series_metadata["col_type"].get(col) == "C"
                and self.time_series_metadata["col_known"].get(col) == "K"
            ]
        )
        decoder_cont_count = len(
            [
                col
                for col in self.time_series_metadata["cols"]["x"]
                if self.time_series_metadata["col_type"].get(col) == "F"
                and self.time_series_metadata["col_known"].get(col) == "K"
            ]
        )

        target_count = len(self.time_series_metadata["cols"]["y"])
        metadata = {
            "encoder_cat": encoder_cat_count,
            "encoder_cont": encoder_cont_count,
            "decoder_cat": decoder_cat_count,
            "decoder_cont": decoder_cont_count,
            "target": target_count,
        }
        if self.time_series_metadata["cols"]["st"]:
            static_cat_count = len(
                [
                    col
                    for col in self.time_series_metadata["cols"]["st"]
                    if self.time_series_metadata["col_type"].get(col) == "C"
                ]
            )
            static_cont_count = (
                len(self.time_series_metadata["cols"]["st"]) - static_cat_count
            )

            metadata["static_categorical_features"] = static_cat_count
            metadata["static_continuous_features"] = static_cont_count
        else:
            metadata["static_categorical_features"] = 0
            metadata["static_continuous_features"] = 0

        metadata.update(
            {
                "max_encoder_length": self.max_encoder_length,
                "max_prediction_length": self.max_prediction_length,
                "min_encoder_length": self.min_encoder_length,
                "min_prediction_length": self.min_prediction_length,
            }
        )

        return metadata

    @property
    def metadata(self):
        """Compute metadata for model initialization.

        This property returns a dictionary containing the shapes and key information
        related to the time series model. The metadata includes:

        * ``encoder_cat``: Number of categorical variables in the encoder.
        * ``encoder_cont``: Number of continuous variables in the encoder.
        * ``decoder_cat``: Number of categorical variables in the decoder that are
                            known in advance.
        * ``decoder_cont``:  Number of continuous variables in the decoder that are
                            known in advance.
        * ``target``: Number of target variables.

        If static features are present, the following keys are added:

        * ``static_categorical_features``: Number of static categorical features
        * ``static_continuous_features``: Number of static continuous features

        It also contains the following information:

        * ``max_encoder_length``: maximum encoder length
        * ``max_prediction_length``: maximum prediction length
        * ``min_encoder_length``: minimum encoder length
        * ``min_prediction_length``: minimum prediction length
        """
        if self._metadata is None:
            self._metadata = self._prepare_metadata()
        return self._metadata

    def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]:
        """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.

        Preprocessing steps
        --------------------

        * Converts target (`y`) and features (`x`) to `torch.float32`.
        * Masks time points that are at or before the cutoff time.
        * Splits features into categorical and continuous subsets based on
            predefined indices.


        TODO: add scalers, target normalizers etc.
        """
        processed_data = []

        for idx in indices:
            sample = self.time_series_dataset[idx.item()]

            target = sample["y"]
            features = sample["x"]
            times = sample["t"]
            cutoff_time = sample["cutoff_time"]

            time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)

            if isinstance(target, torch.Tensor):
                target = target.float()
            else:
                target = torch.tensor(target, dtype=torch.float32)

            if isinstance(features, torch.Tensor):
                features = features.float()
            else:
                features = torch.tensor(features, dtype=torch.float32)

            # TODO: add scalers, target normalizers etc.

            categorical = (
                features[:, self.categorical_indices]
                if self.categorical_indices
                else torch.zeros((features.shape[0], 0))
            )
            continuous = (
                features[:, self.continuous_indices]
                if self.continuous_indices
                else torch.zeros((features.shape[0], 0))
            )

            processed_data.append(
                {
                    "features": {"categorical": categorical, "continuous": continuous},
                    "target": target,
                    "static": sample.get("st", None),
                    "group": sample.get("group", torch.tensor([0])),
                    "length": len(target),
                    "time_mask": time_mask,
                    "times": times,
                    "cutoff_time": cutoff_time,
                }
            )

        return processed_data

    class _ProcessedEncoderDecoderDataset(Dataset):
        """PyTorch Dataset for processed encoder-decoder time series data.

        Parameters
        ----------
        processed_data : List[Dict[str, Any]]
            List of preprocessed time series samples.
        windows : List[Tuple[int, int, int, int]]
            List of window tuples containing
            (series_idx, start_idx, enc_length, pred_length).
        add_relative_time_idx : bool, default=False
            Whether to include relative time indices.
        """

        def __init__(
            self,
            processed_data: List[Dict[str, Any]],
            windows: List[Tuple[int, int, int, int]],
            add_relative_time_idx: bool = False,
        ):
            self.processed_data = processed_data
            self.windows = windows
            self.add_relative_time_idx = add_relative_time_idx

        def __len__(self):
            return len(self.windows)

        def __getitem__(self, idx):
            """Retrieve a processed time series window for dataloader input.

            x : dict
                Dictionary containing model inputs:

                * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features)
                  Categorical features for the encoder.
                * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features)
                  Continuous features for the encoder.
                * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features)
                  Categorical features for the decoder.
                * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features)
                  Continuous features for the decoder.
                * ``encoder_lengths`` : tensor of shape (1,)
                  Length of the encoder sequence.
                * ``decoder_lengths`` : tensor of shape (1,)
                  Length of the decoder sequence.
                * ``decoder_target_lengths`` : tensor of shape (1,)
                  Length of the decoder target sequence.
                * ``groups`` : tensor of shape (1,)
                  Group identifier for the time series instance.
                * ``encoder_time_idx`` : tensor of shape (enc_length,)
                  Time indices for the encoder sequence.
                * ``decoder_time_idx`` : tensor of shape (pred_length,)
                  Time indices for the decoder sequence.
                * ``target_scale`` : tensor of shape (1,)
                  Scaling factor for the target values.
                * ``encoder_mask`` : tensor of shape (enc_length,)
                  Boolean mask indicating valid encoder time points.
                * ``decoder_mask`` : tensor of shape (pred_length,)
                  Boolean mask indicating valid decoder time points.

                  If static features are present, the following keys are added:

                * ``static_categorical_features`` : tensor of shape
                                                    (1, n_static_cat_features), optional
                  Static categorical features, if available.
                * ``static_continuous_features`` : tensor of shape (1, 0), optional
                  Placeholder for static continuous features (currently empty).

            y : tensor of shape ``(pred_length, n_targets)``
                Target values for the decoder sequence.
            """
            series_idx, start_idx, enc_length, pred_length = self.windows[idx]
            data = self.processed_data[series_idx]

            end_idx = start_idx + enc_length + pred_length
            encoder_indices = slice(start_idx, start_idx + enc_length)
            decoder_indices = slice(start_idx + enc_length, end_idx)

            target_scale = data["target"][encoder_indices]
            target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()
            if torch.isnan(target_scale) or target_scale == 0:
                target_scale = torch.tensor(1.0)

            encoder_mask = (
                data["time_mask"][encoder_indices]
                if "time_mask" in data
                else torch.ones(enc_length, dtype=torch.bool)
            )
            decoder_mask = (
                data["time_mask"][decoder_indices]
                if "time_mask" in data
                else torch.zeros(pred_length, dtype=torch.bool)
            )

            x = {
                "encoder_cat": data["features"]["categorical"][encoder_indices],
                "encoder_cont": data["features"]["continuous"][encoder_indices],
                "decoder_cat": data["features"]["categorical"][decoder_indices],
                "decoder_cont": data["features"]["continuous"][decoder_indices],
                "encoder_lengths": torch.tensor(enc_length),
                "decoder_lengths": torch.tensor(pred_length),
                "decoder_target_lengths": torch.tensor(pred_length),
                "groups": data["group"],
                "encoder_time_idx": torch.arange(enc_length),
                "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length),
                "target_scale": target_scale,
                "encoder_mask": encoder_mask,
                "decoder_mask": decoder_mask,
            }
            if data["static"] is not None:
                x["static_categorical_features"] = data["static"].unsqueeze(0)
                x["static_continuous_features"] = torch.zeros((1, 0))

            y = data["target"][decoder_indices]
            if y.ndim == 1:
                y = y.unsqueeze(-1)

            return x, y

    def _create_windows(
        self, processed_data: List[Dict[str, Any]]
    ) -> List[Tuple[int, int, int, int]]:
        """Generate sliding windows for training, validation, and testing.

        Returns
        -------
        List[Tuple[int, int, int, int]]
            A list of tuples, where each tuple consists of:
            - ``series_idx`` : int
              Index of the time series in `processed_data`.
            - ``start_idx`` : int
              Start index of the encoder window.
            - ``enc_length`` : int
              Length of the encoder input sequence.
            - ``pred_length`` : int
              Length of the decoder output sequence.
        """
        windows = []

        for idx, data in enumerate(processed_data):
            sequence_length = data["length"]

            if sequence_length < self.max_encoder_length + self.max_prediction_length:
                continue

            effective_min_prediction_idx = (
                self.min_prediction_idx
                if self.min_prediction_idx is not None
                else self.max_encoder_length
            )

            max_prediction_idx = sequence_length - self.max_prediction_length + 1

            if max_prediction_idx <= effective_min_prediction_idx:
                continue

            for start_idx in range(
                0, max_prediction_idx - effective_min_prediction_idx
            ):
                if (
                    start_idx + self.max_encoder_length + self.max_prediction_length
                    <= sequence_length
                ):
                    windows.append(
                        (
                            idx,
                            start_idx,
                            self.max_encoder_length,
                            self.max_prediction_length,
                        )
                    )

        return windows

    def setup(self, stage: Optional[str] = None):
        """Prepare the datasets for training, validation, testing, or prediction.

        Parameters
        ----------
        stage : Optional[str], default=None
            Specifies the stage of setup. Can be one of:
            - ``"fit"`` : Prepares training and validation datasets.
            - ``"test"`` : Prepares the test dataset.
            - ``"predict"`` : Prepares the dataset for inference.
            - ``None`` : Prepares all datasets.
        """
        total_series = len(self.time_series_dataset)
        self._split_indices = torch.randperm(total_series)

        self._train_size = int(self.train_val_test_split[0] * total_series)
        self._val_size = int(self.train_val_test_split[1] * total_series)

        self._train_indices = self._split_indices[: self._train_size]
        self._val_indices = self._split_indices[
            self._train_size : self._train_size + self._val_size
        ]
        self._test_indices = self._split_indices[self._train_size + self._val_size :]

        if stage is None or stage == "fit":
            if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
                self.train_processed = self._preprocess_data(self._train_indices)
                self.val_processed = self._preprocess_data(self._val_indices)

                self.train_windows = self._create_windows(self.train_processed)
                self.val_windows = self._create_windows(self.val_processed)

                self.train_dataset = self._ProcessedEncoderDecoderDataset(
                    self.train_processed, self.train_windows, self.add_relative_time_idx
                )
                self.val_dataset = self._ProcessedEncoderDecoderDataset(
                    self.val_processed, self.val_windows, self.add_relative_time_idx
                )
                # print(self.val_dataset[0])

        elif stage is None or stage == "test":
            if not hasattr(self, "test_dataset"):
                self.test_processed = self._preprocess_data(self._test_indices)
                self.test_windows = self._create_windows(self.test_processed)

                self.test_dataset = self._ProcessedEncoderDecoderDataset(
                    self.test_processed, self.test_windows, self.add_relative_time_idx
                )
        elif stage == "predict":
            predict_indices = torch.arange(len(self.time_series_dataset))
            self.predict_processed = self._preprocess_data(predict_indices)
            self.predict_windows = self._create_windows(self.predict_processed)
            self.predict_dataset = self._ProcessedEncoderDecoderDataset(
                self.predict_processed, self.predict_windows, self.add_relative_time_idx
            )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over ``self.train_dataset``."""
        loader_options = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "shuffle": True,
            "collate_fn": self.collate_fn,
        }
        return DataLoader(self.train_dataset, **loader_options)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over ``self.val_dataset``."""
        loader_options = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "collate_fn": self.collate_fn,
        }
        return DataLoader(self.val_dataset, **loader_options)

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def predict_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over ``self.predict_dataset``."""
        loader_options = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "collate_fn": self.collate_fn,
        }
        return DataLoader(self.predict_dataset, **loader_options)

    @staticmethod
    def collate_fn(batch):
        x_batch = {
            "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]),
            "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]),
            "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]),
            "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]),
            "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]),
            "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]),
            "decoder_target_lengths": torch.stack(
                [x["decoder_target_lengths"] for x, _ in batch]
            ),
            "groups": torch.stack([x["groups"] for x, _ in batch]),
            "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]),
            "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]),
            "target_scale": torch.stack([x["target_scale"] for x, _ in batch]),
            "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]),
            "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]),
        }

        if "static_categorical_features" in batch[0][0]:
            x_batch["static_categorical_features"] = torch.stack(
                [x["static_categorical_features"] for x, _ in batch]
            )
            x_batch["static_continuous_features"] = torch.stack(
                [x["static_continuous_features"] for x, _ in batch]
            )

        y_batch = torch.stack([y for _, y in batch])
        return x_batch, y_batch

In [11]:
from lightning.pytorch import Trainer
import pandas as pd
import torch
import torch.nn as nn

from pytorch_forecasting.metrics import MAE, SMAPE

# Seed NumPy's global RNG so the synthetic series are identical on every
# Restart & Run All; without it each run trains on different data.
SEED = 42
np.random.seed(SEED)

num_series = 100
seq_length = 50
data_list = []
for i in range(num_series):
    # Noisy sine wave; the same underlying curve for every series.
    x = np.arange(seq_length)
    y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)
    category = i % 5
    # One random value per series, repeated on each of its rows.
    static_value = np.random.rand()
    # One-step-ahead setup: feature "x" is the value at t, target "y" the
    # value at t + 1, hence only seq_length - 1 rows per series.
    for t in range(seq_length - 1):
        data_list.append(
            {
                "series_id": i,
                "time_idx": t,
                "x": y[t],
                "y": y[t + 1],
                "category": category,
                "future_known_feature": np.cos(t / 10),
                "static_feature": static_value,
                "static_feature_cat": i % 3,
            }
        )
# Build the frame once from the collected records.
data_df = pd.DataFrame(data_list)
data_df.head()

Unnamed: 0,series_id,time_idx,x,y,category,future_known_feature,static_feature,static_feature_cat
0,0,0,0.023138,0.249834,0,1.0,0.454668,0
1,0,1,0.249834,0.213821,0,0.995004,0.454668,0
2,0,2,0.213821,0.671829,0,0.980067,0.454668,0
3,0,3,0.671829,0.781042,0,0.955336,0.454668,0
4,0,4,0.781042,0.706092,0,0.921061,0.454668,0


In [18]:
# Wrap the synthetic frame in the TimeSeries dataset, declaring the role of
# each column: "y" is the target, rows are grouped by "series_id" and ordered
# by "time_idx". Note that "static_feature" is listed both as numeric and as
# static, and "static_feature_cat" both as categorical and as static.
dataset = TimeSeries(
    data=data_df,
    time="time_idx",
    target="y",
    group=["series_id"],
    num=["x", "future_known_feature", "static_feature"],
    cat=["category", "static_feature_cat"],
    known=["future_known_feature"],
    unknown=["x", "category"],
    static=["static_feature", "static_feature_cat"],
)

In [19]:
# Data module over the dataset above: windows of 30 encoder steps followed by
# a single prediction step, served in batches of 32. One encoder is supplied
# per categorical column and one scaler per numeric column.
data_module = EncoderDecoderTimeSeriesDataModule(
    time_series_dataset=dataset,
    max_encoder_length=30,
    max_prediction_length=1,
    batch_size=32,
    categorical_encoders={
        "category": NaNLabelEncoder(add_nan=True),
        "static_feature_cat": NaNLabelEncoder(add_nan=True),
    },
    scalers={
        "x": StandardScaler(),
        "future_known_feature": StandardScaler(),
        "static_feature": StandardScaler(),
    },
    target_normalizer=TorchNormalizer(),
)

In [None]:
from typing import Dict, List, Optional, Union

from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
import torch
from torch.optim import Optimizer


class BaseModel(LightningModule):
    """Shared Lightning scaffold for time series forecasting models.

    Subclasses implement ``forward`` and return a dict with a
    ``"prediction"`` entry; this class supplies the train/val/test/predict
    steps, metric logging, and optimizer / scheduler construction.
    """

    def __init__(
        self,
        loss: nn.Module,
        logging_metrics: Optional[List[nn.Module]] = None,
        optimizer: Optional[Union[Optimizer, str]] = "adam",
        optimizer_params: Optional[Dict] = None,
        lr_scheduler: Optional[str] = None,
        lr_scheduler_params: Optional[Dict] = None,
    ):
        """
        Base model for time series forecasting.

        Parameters
        ----------
        loss : nn.Module
            Loss function to use for training.
        logging_metrics : Optional[List[nn.Module]], optional
            List of metrics to log during training, validation, and testing.
        optimizer : Optional[Union[Optimizer, str]], optional
            Optimizer to use for training.
            Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`.
        optimizer_params : Optional[Dict], optional
            Parameters for the optimizer.
        lr_scheduler : Optional[str], optional
            Learning rate scheduler to use.
            Supported values: "reduce_lr_on_plateau", "step_lr".
        lr_scheduler_params : Optional[Dict], optional
            Parameters for the learning rate scheduler.
        """
        super().__init__()
        self.loss = loss
        # ``None`` defaults are replaced by fresh containers per instance.
        self.logging_metrics = logging_metrics if logging_metrics is not None else []
        self.optimizer = optimizer
        self.optimizer_params = optimizer_params if optimizer_params is not None else {}
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_params = (
            lr_scheduler_params if lr_scheduler_params is not None else {}
        )

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : Dict[str, torch.Tensor]
            Dictionary containing input tensors

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing output tensors

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.
        """
        raise NotImplementedError("Forward method must be implemented by subclass.")

    def _shared_step(
        self,
        batch: Tuple[Dict[str, torch.Tensor], torch.Tensor],
        stage: str,
        on_step: bool,
    ) -> torch.Tensor:
        """
        Run forward + loss + logging for one batch of the given stage.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
            ``(x, y)`` pair of model inputs and targets.
        stage : str
            Logging prefix: "train", "val" or "test". The loss is logged as
            ``f"{stage}_loss"``.
        on_step : bool
            Whether the loss is additionally logged per step.

        Returns
        -------
        torch.Tensor
            The loss for this batch.
        """
        x, y = batch
        y_hat = self(x)["prediction"]
        loss = self.loss(y_hat, y)
        self.log(
            f"{stage}_loss",
            loss,
            on_step=on_step,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log_metrics(y_hat, y, prefix=stage)
        return loss

    def training_step(
        self, batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int
    ) -> STEP_OUTPUT:
        """
        Training step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
            Batch of data containing input and target tensors.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        STEP_OUTPUT
            Dictionary with the training loss under ``"loss"``.
        """
        # Training loss is logged both per step and per epoch.
        return {"loss": self._shared_step(batch, "train", on_step=True)}

    def validation_step(
        self, batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int
    ) -> STEP_OUTPUT:
        """
        Validation step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
            Batch of data containing input and target tensors.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        STEP_OUTPUT
            Dictionary with the validation loss under ``"val_loss"``.
        """
        return {"val_loss": self._shared_step(batch, "val", on_step=False)}

    def test_step(
        self, batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int
    ) -> STEP_OUTPUT:
        """
        Test step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
            Batch of data containing input and target tensors.
        batch_idx : int
            Index of the batch.

        Returns
        -------
        STEP_OUTPUT
            Dictionary with the test loss under ``"test_loss"``.
        """
        return {"test_loss": self._shared_step(batch, "test", on_step=False)}

    def predict_step(
        self,
        batch: Tuple[Dict[str, torch.Tensor], torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """
        Prediction step for the model.

        Parameters
        ----------
        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
            Batch of data; only the input dict is used, the second element
            is ignored.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int
            Index of the dataloader.

        Returns
        -------
        Dict[str, torch.Tensor]
            The unmodified output of ``forward`` (the whole dict, not just
            the ``"prediction"`` tensor).
        """
        x, _ = batch
        return self(x)

    def configure_optimizers(self) -> Dict:
        """
        Configure the optimizer and learning rate scheduler.

        Returns
        -------
        Dict
            Dictionary containing the optimizer and scheduler configuration.
        """
        optimizer = self._get_optimizer()
        if self.lr_scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self._get_scheduler(optimizer)
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Plateau scheduling needs a metric to watch; use the validation
            # loss logged by ``validation_step``.
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val_loss",
                },
            }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def _get_optimizer(self) -> Optimizer:
        """
        Get the optimizer based on the specified optimizer name and parameters.

        Returns
        -------
        Optimizer
            The optimizer instance. An ``Optimizer`` passed to the constructor
            is returned as-is; ``optimizer_params`` only apply to string names.

        Raises
        ------
        ValueError
            If the name is not "adam"/"sgd" (case-insensitive) or the value is
            neither a string nor an ``Optimizer``.
        """
        if isinstance(self.optimizer, str):
            optimizer_name = self.optimizer.lower()
            if optimizer_name == "adam":
                return torch.optim.Adam(self.parameters(), **self.optimizer_params)
            if optimizer_name == "sgd":
                return torch.optim.SGD(self.parameters(), **self.optimizer_params)
            raise ValueError(f"Optimizer {self.optimizer} not supported.")
        if isinstance(self.optimizer, Optimizer):
            return self.optimizer
        raise ValueError(
            "Optimizer must be either a string or "
            "an instance of torch.optim.Optimizer."
        )

    def _get_scheduler(
        self, optimizer: Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:
        """
        Get the lr scheduler based on the specified scheduler name and params.

        Parameters
        ----------
        optimizer : Optimizer
            The optimizer instance.

        Returns
        -------
        torch.optim.lr_scheduler._LRScheduler
            The learning rate scheduler instance.

        Raises
        ------
        ValueError
            If the scheduler name is not one of the supported values.
        """
        scheduler_name = self.lr_scheduler.lower()
        if scheduler_name == "reduce_lr_on_plateau":
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, **self.lr_scheduler_params
            )
        if scheduler_name == "step_lr":
            return torch.optim.lr_scheduler.StepLR(
                optimizer, **self.lr_scheduler_params
            )
        raise ValueError(f"Scheduler {self.lr_scheduler} not supported.")

    def log_metrics(
        self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val"
    ) -> None:
        """
        Log additional metrics during training, validation, or testing.

        Parameters
        ----------
        y_hat : torch.Tensor
            Predicted output tensor.
        y : torch.Tensor
            Target output tensor.
        prefix : str
            Prefix for the logged metrics (e.g., "train", "val", "test").
        """
        for metric in self.logging_metrics:
            # Each metric is logged per epoch as "<prefix>_<MetricClassName>".
            self.log(
                f"{prefix}_{metric.__class__.__name__}",
                metric(y_hat, y),
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                logger=True,
            )


In [None]:
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer


class TFT(BaseModel):
    """Simplified Temporal-Fusion-Transformer-style encoder/decoder model.

    Gated variable selection, an LSTM encoder/decoder pair, optional static
    context and a self-attention layer, followed by a two-layer output head.
    """

    def __init__(
        self,
        loss: nn.Module,
        logging_metrics: Optional[List[nn.Module]] = None,
        optimizer: Optional[Union[Optimizer, str]] = "adam",
        optimizer_params: Optional[Dict] = None,
        lr_scheduler: Optional[str] = None,
        lr_scheduler_params: Optional[Dict] = None,
        hidden_size: int = 64,
        num_layers: int = 2,
        attention_head_size: int = 4,
        dropout: float = 0.1,
        metadata: Optional[Dict] = None,
        output_size: int = 1,
    ):
        """
        Build the network layers from the data module's metadata.

        Parameters
        ----------
        loss, logging_metrics, optimizer, optimizer_params, lr_scheduler,
        lr_scheduler_params
            Forwarded unchanged to ``BaseModel``.
        hidden_size : int
            Width of the LSTM, attention and output-head layers.
        num_layers : int
            Number of stacked LSTM layers in encoder and decoder.
        attention_head_size : int
            Number of attention heads (passed as ``num_heads``).
        dropout : float
            Dropout used in the LSTMs and the attention layer.
        metadata : Optional[Dict]
            Required despite the ``None`` default. Must provide
            "max_encoder_length", "max_prediction_length", "encoder_cont",
            "encoder_cat", "static_categorical_features" and
            "static_continuous_features"; the last four are used as feature
            counts.
        output_size : int
            Size of the last dimension of the prediction.

        Raises
        ------
        ValueError
            If ``metadata`` is not given.
        """
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
        )
        # Every layer size below is derived from the metadata, so fail with a
        # clear message instead of "'NoneType' object is not subscriptable".
        if metadata is None:
            raise ValueError(
                "`metadata` is required to build TFT; "
                "pass the data module's metadata dict."
            )

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention_head_size = attention_head_size
        self.dropout = dropout
        self.metadata = metadata
        self.output_size = output_size

        self.max_encoder_length = self.metadata["max_encoder_length"]
        self.max_prediction_length = self.metadata["max_prediction_length"]
        self.encoder_cont = self.metadata["encoder_cont"]
        self.encoder_cat = self.metadata["encoder_cat"]
        self.static_categorical_features = self.metadata["static_categorical_features"]
        self.static_continuous_features = self.metadata["static_continuous_features"]

        # The encoder feature counts size both the encoder and decoder layers.
        total_feature_size = self.encoder_cont + self.encoder_cat
        total_static_size = (
            self.static_categorical_features + self.static_continuous_features
        )

        # Sigmoid-gated per-feature weights, one network per side.
        self.encoder_var_selection = nn.Sequential(
            nn.Linear(total_feature_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, total_feature_size),
            nn.Sigmoid(),
        )

        self.decoder_var_selection = nn.Sequential(
            nn.Linear(total_feature_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, total_feature_size),
            nn.Sigmoid(),
        )

        # No static projection at all when the metadata reports no static features.
        self.static_context_linear = (
            nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None
        )

        self.lstm_encoder = nn.LSTM(
            input_size=total_feature_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )

        self.lstm_decoder = nn.LSTM(
            input_size=total_feature_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )

        self.self_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=attention_head_size,
            dropout=dropout,
            batch_first=True,
        )

        self.pre_output = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def _get_or_empty(
        self, x: Dict[str, torch.Tensor], key: str, *leading_shape: int
    ) -> torch.Tensor:
        """Return ``x[key]`` or, if absent, a zero-feature tensor of
        shape ``(*leading_shape, 0)`` on the model's device."""
        if key in x:
            return x[key]
        return torch.zeros(*leading_shape, 0, device=self.device)

    def _static_context(
        self, x: Dict[str, torch.Tensor], batch_size: int
    ) -> Optional[torch.Tensor]:
        """Project the static features to ``(batch_size, hidden_size)``.

        Returns ``None`` when the model has no static projection or the batch
        carries no static feature columns.
        """
        if self.static_context_linear is None:
            return None

        weight_dtype = self.static_context_linear.weight.dtype
        # Continuous first, then categorical.
        static_parts = [
            self._get_or_empty(x, "static_continuous_features", batch_size, 1),
            self._get_or_empty(x, "static_categorical_features", batch_size, 1),
        ]
        # Flatten each part to (batch, n_features) before joining, so features
        # are concatenated along the feature axis whether the part arrives as
        # (batch, 1, n) or (batch, n), and empty parts drop out naturally.
        static_input = torch.cat(
            [part.flatten(start_dim=1).to(dtype=weight_dtype) for part in static_parts],
            dim=1,
        )
        if static_input.size(1) == 0:
            return None
        return self.static_context_linear(static_input)

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the TFT model.

        Parameters
        ----------
        x : Dict[str, torch.Tensor]
            Dictionary containing input tensors:
            - encoder_cat: Categorical encoder features
            - encoder_cont: Continuous encoder features (required)
            - decoder_cat: Categorical decoder features
            - decoder_cont: Continuous decoder features
            - static_categorical_features: Static categorical features
            - static_continuous_features: Static continuous features
            Missing optional entries are treated as having zero features.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing output tensors:
            - prediction: Prediction output (batch_size, prediction_length, output_size)
        """
        encoder_cont = x["encoder_cont"]
        batch_size = encoder_cont.shape[0]

        encoder_cat = self._get_or_empty(
            x, "encoder_cat", batch_size, self.max_encoder_length
        )
        decoder_cat = self._get_or_empty(
            x, "decoder_cat", batch_size, self.max_prediction_length
        )
        decoder_cont = self._get_or_empty(
            x, "decoder_cont", batch_size, self.max_prediction_length
        )

        encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)
        decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)

        static_context = self._static_context(x, batch_size)

        # Variable selection: scale every feature by its learned gate.
        encoder_input = encoder_input * self.encoder_var_selection(encoder_input)
        decoder_input = decoder_input * self.decoder_var_selection(decoder_input)

        # The decoder LSTM starts from the encoder's final hidden/cell state.
        encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)
        decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))

        if static_context is not None:
            # (batch, 1, hidden) broadcasts over the time axis.
            context_per_step = static_context.unsqueeze(1)
            encoder_output = encoder_output + context_per_step
            decoder_output = decoder_output + context_per_step

        sequence = torch.cat([encoder_output, decoder_output], dim=1)

        # Only the attention query is shifted by the static context; keys and
        # values use the plain sequence.
        query = sequence
        if static_context is not None:
            query = sequence + static_context.unsqueeze(1)
        attended_output, _ = self.self_attention(query, sequence, sequence)

        # Keep only the decoder positions for the prediction head.
        decoder_attended = attended_output[:, -self.max_prediction_length :, :]

        output = nn.functional.relu(self.pre_output(decoder_attended))
        prediction = self.output_layer(output)

        return {"prediction": prediction}


In [22]:
# Build the TFT with an MSE loss, MAE/SMAPE as extra logged metrics, and Adam
# with a reduce-on-plateau scheduler. The data module's metadata supplies the
# sequence lengths and feature counts that size the network's layers.
model = TFT(
    loss=nn.MSELoss(),
    logging_metrics=[MAE(), SMAPE()],
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    lr_scheduler="reduce_lr_on_plateau",
    lr_scheduler_params={"mode": "min", "factor": 0.1, "patience": 10},
    hidden_size=64,
    num_layers=2,
    attention_head_size=4,
    dropout=0.1,
    metadata=data_module.metadata,
)

print("\nTraining model...")
# Short 5-epoch run on a single device.
trainer = Trainer(max_epochs=5, accelerator="auto", devices=1, enable_progress_bar=True)

trainer.fit(model, data_module)

print("\nEvaluating model...")
test_metrics = trainer.test(model, data_module)

# Manual sanity check on one test batch, outside the Trainer: eval mode and
# no_grad so dropout is disabled and no graph is built.
model.eval()
with torch.no_grad():
    test_batch = next(iter(data_module.test_dataloader()))
    x_test, y_test = test_batch
    y_pred = model(x_test)

    print("\nPrediction shape:", y_pred["prediction"].shape)
    print("First prediction values:", y_pred["prediction"][0].cpu().numpy())
    print("First true values:", y_test[0].cpu().numpy())
print("\nTFT model test complete!")

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name                  | Type               | Params | Mode 
---------------------------------------------------------------------
0 | loss                  | MSELoss            | 0      | train
1 | encoder_var_selection | Sequential         | 709    | train
2 | decoder_var_selection | Sequential         


Training model...


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (42) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.



Evaluating model...


Testing: |          | 0/? [00:00<?, ?it/s]


Prediction shape: torch.Size([32, 1, 1])
First prediction values: [[-0.06341379]]
First true values: [[0.08132173]]

TFT model test complete!
