# Quickstart - Run TiRex on GiftEval

This notebook shows how to run [TiRex](https://github.com/NX-AI/tirex) on the gift-eval benchmark.

Make sure you download the gift-eval benchmark and set the `GIFT_EVAL` environment variable (the data root directory, read by `Dataset` below; a `.env` file also works) before running this notebook.


## Setup Instructions

Before proceeding, ensure you have the following:
(Note: You need an NVIDIA GPU with [CUDA compute capability >= 8.0](https://developer.nvidia.com/cuda-gpus).)

1. **Optional but suggested: Install the conda environment specified in the TiRex repo**

```bash
git clone https://github.com/NX-AI/tirex
conda env create --file ./tirex/requirements_py26.yaml
conda activate tirex
```

2. **Install TiRex**

```bash
git clone https://github.com/NX-AI/tirex  # if not cloned before
cd tirex
pip install .  # install tirex
```

3. **Install additional dependencies needed for the GiftEval benchmark**

```bash
pip install gluonts python-dotenv datasets
```

In [None]:
#
# This is the data.py file from GiftEval, fixed so that it can be run with numpy >= 2.0.0
# (the frequency-alias lookup tables were extended to also accept the newer pandas aliases).
#

import os
import math
from functools import cached_property
from enum import Enum
from pathlib import Path
from typing import Iterable, Iterator

import datasets
from dotenv import load_dotenv
from gluonts.dataset import DataEntry
from gluonts.dataset.common import ProcessDataEntry
from gluonts.dataset.split import TestData, TrainingDataset, split
from gluonts.itertools import Map
from gluonts.time_feature import norm_freq_str
from gluonts.transform import Transformation
from pandas.tseries.frequencies import to_offset
import pyarrow.compute as pc
from toolz import compose

# Fraction of the shortest series that the rolling test windows should cover
# (see ``Dataset.windows``).
TEST_SPLIT = 0.1
# Upper bound on the number of rolling evaluation windows per dataset.
MAX_WINDOW = 20

# Base prediction length for the M4 datasets, keyed by normalized frequency
# string. Both the legacy aliases ("A", "H") and the aliases used by newer
# pandas versions ("Y", "h") are present so either resolves.
M4_PRED_LENGTH_MAP = dict(
    A=6,
    Q=8,
    M=18,
    W=13,
    D=14,
    H=48,
    h=48,  # newer-pandas hourly alias
    Y=6,  # newer-pandas yearly alias
)

# Base prediction length for every non-M4 dataset, keyed by normalized
# frequency string (legacy and newer pandas aliases side by side).
PRED_LENGTH_MAP = dict(
    M=12,
    W=8,
    D=30,
    H=48,
    T=48,
    S=60,
    h=48,  # newer-pandas hourly alias
    s=60,  # newer-pandas secondly alias
    min=48,  # newer-pandas minutely alias
)

# Prediction lengths for the TFB benchmark. Not consulted by
# ``Dataset.prediction_length``; kept for parity with GiftEval's data.py.
TFB_PRED_LENGTH_MAP = dict(
    A=6,
    H=48,
    Q=8,
    D=14,
    M=18,
    W=13,
    U=8,
    T=8,
    min=8,  # newer-pandas minutely alias
    us=8,  # newer-pandas microsecond alias
    Y=6,  # newer-pandas yearly alias
    h=48,  # newer-pandas hourly alias
)


class Term(Enum):
    """Forecast-horizon category of a GiftEval task.

    The member value is the string used in task configs; ``multiplier`` is the
    factor applied to a dataset's base prediction length.
    """

    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"

    @property
    def multiplier(self) -> int:
        """Factor by which the base prediction length is scaled for this term."""
        factor_by_value = {"short": 1, "medium": 10, "long": 15}
        return factor_by_value[self.value]


def itemize_start(data_entry: DataEntry) -> DataEntry:
    """Unwrap the entry's ``start`` value into a native Python scalar, in place.

    The HF dataset is read with ``with_format("numpy")``, so ``start``
    presumably arrives as a numpy scalar; ``.item()`` converts it. The same
    (mutated) mapping is returned so this can be chained with ``compose``.
    """
    start_scalar = data_entry["start"]
    data_entry["start"] = start_scalar.item()
    return data_entry


class MultivariateToUnivariate(Transformation):
    """Explode each multivariate entry into one univariate entry per dimension.

    For every input entry, iterating ``entry[field]`` yields one row per
    dimension. Each row becomes a shallow copy of the entry with ``field``
    replaced by that row and ``item_id`` suffixed with ``_dim<k>``. All other
    fields are shared with the original entry.
    """

    def __init__(self, field):
        # Name of the entry field to split, e.g. "target".
        self.field = field

    def __call__(
        self, data_it: Iterable[DataEntry], is_train: bool = False
    ) -> Iterator:
        # ``is_train`` is part of the transformation call signature; the split
        # is the same for training and inference, so it is not consulted.
        for data_entry in data_it:
            item_id = data_entry["item_id"]
            for dim, row in enumerate(data_entry[self.field]):
                univariate_entry = data_entry.copy()
                univariate_entry[self.field] = row
                univariate_entry["item_id"] = f"{item_id}_dim{dim}"
                yield univariate_entry


class Dataset:
    """A GiftEval dataset loaded from disk and exposed as GluonTS splits.

    Data is read with ``datasets.load_from_disk`` from
    ``<value of $storage_env_var>/<name>`` in numpy format and wrapped as a
    lazily mapped GluonTS dataset. ``load_dotenv`` is called first, so the
    variable may also come from a ``.env`` file.

    Args:
        name: Dataset directory relative to the storage root, e.g.
            ``"m4_yearly"`` or ``"electricity/H"``.
        term: Forecast horizon as a ``Term`` or its value (``"short"``,
            ``"medium"``, ``"long"``); scales ``prediction_length``.
        to_univariate: If True, split each multivariate series into one
            univariate series per dimension.
        storage_env_var: Name of the environment variable holding the
            GiftEval data root.

    Raises:
        ValueError: If ``storage_env_var`` is not set.
    """

    def __init__(
        self,
        name: str,
        term: Term | str = Term.SHORT,
        to_univariate: bool = False,
        storage_env_var: str = "GIFT_EVAL",
    ):
        load_dotenv()
        storage_root = os.getenv(storage_env_var)
        if storage_root is None:
            # Without this check Path(None) fails with an opaque TypeError.
            raise ValueError(
                f"Environment variable {storage_env_var!r} is not set. Point it at "
                "the downloaded GiftEval data directory (or define it in a .env file)."
            )
        storage_path = Path(storage_root)
        self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(
            "numpy"
        )
        process = ProcessDataEntry(
            self.freq,
            one_dim_target=self.target_dim == 1,
        )

        # toolz.compose applies right-to-left: itemize_start runs before process.
        self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)
        if to_univariate:
            self.gluonts_dataset = MultivariateToUnivariate("target").apply(
                self.gluonts_dataset
            )

        self.term = Term(term)
        self.name = name

    @cached_property
    def prediction_length(self) -> int:
        """Forecast horizon: base length for the frequency times ``term.multiplier``.

        The frequency is normalized with ``to_offset(...).name`` and
        ``norm_freq_str``. A trailing ``"E"`` (period-end aliases such as
        ``"ME"`` in newer pandas) is stripped so both alias styles resolve to
        the same table entry. M4 datasets use ``M4_PRED_LENGTH_MAP``; all
        others use ``PRED_LENGTH_MAP``.
        """
        freq = norm_freq_str(to_offset(self.freq).name)
        if freq.endswith("E"):
            freq = freq[:-1]
        pred_len = (
            M4_PRED_LENGTH_MAP[freq] if "m4" in self.name else PRED_LENGTH_MAP[freq]
        )
        return self.term.multiplier * pred_len

    @cached_property
    def freq(self) -> str:
        """Frequency string stored on the first record of the dataset."""
        return self.hf_dataset[0]["freq"]

    @cached_property
    def target_dim(self) -> int:
        """Number of target dimensions (1 for a 1-D target array)."""
        target = self.hf_dataset[0]["target"]
        return target.shape[0] if len(target.shape) > 1 else 1

    @cached_property
    def past_feat_dynamic_real_dim(self) -> int:
        """Number of past dynamic real features (0 if the field is absent)."""
        first_record = self.hf_dataset[0]
        if "past_feat_dynamic_real" not in first_record:
            return 0
        past_feat_dynamic_real = first_record["past_feat_dynamic_real"]
        return (
            past_feat_dynamic_real.shape[0]
            if len(past_feat_dynamic_real.shape) > 1
            else 1
        )

    @cached_property
    def windows(self) -> int:
        """Number of rolling test windows, clamped to ``[1, MAX_WINDOW]``.

        Chosen so the windows cover about ``TEST_SPLIT`` of the shortest
        series. M4 datasets always use a single window.
        """
        if "m4" in self.name:
            return 1
        w = math.ceil(TEST_SPLIT * self._min_series_length / self.prediction_length)
        return min(max(1, w), MAX_WINDOW)

    @cached_property
    def _min_series_length(self) -> int:
        """Length of the shortest series in the dataset.

        For multivariate targets only the first dimension of each record is
        measured (``list_slice(..., 0, 1)``).
        """
        if self.hf_dataset[0]["target"].ndim > 1:
            lengths = pc.list_value_length(
                pc.list_flatten(
                    pc.list_slice(self.hf_dataset.data.column("target"), 0, 1)
                )
            )
        else:
            lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
        return min(lengths.to_numpy())

    @cached_property
    def sum_series_length(self) -> int:
        """Total number of observations over all series (and all dimensions)."""
        if self.hf_dataset[0]["target"].ndim > 1:
            lengths = pc.list_value_length(
                pc.list_flatten(self.hf_dataset.data.column("target"))
            )
        else:
            lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
        return sum(lengths.to_numpy())

    @property
    def training_dataset(self) -> TrainingDataset:
        """Series truncated before the last ``windows + 1`` prediction horizons."""
        training_dataset, _ = split(
            self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)
        )
        return training_dataset

    @property
    def validation_dataset(self) -> TrainingDataset:
        """Series truncated before the last ``windows`` prediction horizons."""
        validation_dataset, _ = split(
            self.gluonts_dataset, offset=-self.prediction_length * self.windows
        )
        return validation_dataset

    @property
    def test_data(self) -> TestData:
        """``windows`` consecutive, non-overlapping test instances per series.

        Instances start ``windows * prediction_length`` steps before the end
        and are spaced ``prediction_length`` apart.
        """
        _, test_template = split(
            self.gluonts_dataset, offset=-self.prediction_length * self.windows
        )
        test_data = test_template.generate_instances(
            prediction_length=self.prediction_length,
            windows=self.windows,
            distance=self.prediction_length,
        )
        return test_data


In [None]:
from typing import Any
from dataclasses import dataclass
import json
import logging
from pathlib import Path
from gluonts.ev.metrics import (
    MSE,
    MAE,
    MASE,
    MAPE,
    SMAPE,
    MSIS,
    RMSE,
    NRMSE,
    ND,
    MeanWeightedSumQuantileLoss,
)
from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality
import json
import pandas as pd

# Avoid excessive logging: filter out the "mean prediction is not stored"
# warning on the ``gluonts.model.forecast`` logger.
class WarningFilter(logging.Filter):
    """Logging filter that drops records whose message contains a substring."""

    def __init__(self, text_to_filter):
        super().__init__()
        # Records whose formatted message contains this text are dropped.
        self.text_to_filter = text_to_filter

    def filter(self, record):
        """Return False (drop the record) if its message contains ``text_to_filter``."""
        message = record.getMessage()
        if self.text_to_filter in message:
            return False
        return True


gts_logger = logging.getLogger("gluonts.model.forecast")
_mean_not_stored_filter = WarningFilter(
    "The mean prediction is not stored in the forecast data"
)
gts_logger.addFilter(_mean_not_stored_filter)

# Datasets evaluated on the short term (space-separated "<family>[/<freq>]").
SHORT_DATA = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/15T ett2/H ett2/D ett2/W jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
# Datasets additionally evaluated on the medium and long terms.
MED_LONG_DATA = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
# Maps (lower-cased) dataset family names to the keys used in result labels
# and in dataset_properties.json.
PRETTY_NAMES = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}
# Deduplicate while preserving first-seen order. A ``set`` would make the
# evaluation order (and the row order of the results CSV) depend on string
# hash randomization, i.e. differ from run to run.
ALL_DATASETS = list(dict.fromkeys(SHORT_DATA.split() + MED_LONG_DATA.split()))

# Metrics computed for every dataset/term. evaluate_dataset reads the
# resulting columns by name (e.g. "MSE[mean]", "MAE[0.5]", "MSIS").
METRICS = [
    # Squared error of both the mean forecast and the 0.5-quantile forecast.
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    # Probabilistic accuracy over the nine deciles 0.1, 0.2, ..., 0.9
    # (k / 10 is correctly rounded, so these equal the decimal literals).
    MeanWeightedSumQuantileLoss(
        quantile_levels=[k / 10 for k in range(1, 10)]
    ),
]

# Per-dataset metadata (e.g. "domain", "num_variates", "frequency") keyed by
# the normalized dataset key; the file is expected in the current working
# directory.
_DATASET_PROPERTIES_PATH = Path.cwd() / "dataset_properties.json"
try:
    with open(_DATASET_PROPERTIES_PATH, "r", encoding="utf-8") as f:
        dataset_properties_map = json.load(f)
except FileNotFoundError as err:
    # Keep the original cause attached and say where we looked.
    raise ValueError(
        f"Cannot find the required dataset_properties.json file at {_DATASET_PROPERTIES_PATH}!"
    ) from err


def gift_eval_dataset_iter():
    """Yield one evaluation task per (dataset, term) combination.

    Every dataset in ``ALL_DATASETS`` is evaluated on the ``"short"`` term.
    Datasets listed in ``MED_LONG_DATA`` are also evaluated on ``"medium"``
    and ``"long"``. Each yielded dict has the keys ``ds_name``, ``ds_key``,
    ``ds_freq`` and ``term``, matching the keyword parameters of
    ``evaluate_dataset``.

    ``ds_key`` is the lower-cased dataset family name mapped through
    ``PRETTY_NAMES``. ``ds_freq`` is the ``/<freq>`` suffix of the name when
    present, otherwise the ``"frequency"`` entry in ``dataset_properties_map``.

    Raises:
        KeyError: If a dataset without a frequency suffix is missing from
            ``dataset_properties_map``.
    """
    # Build the membership set once instead of re-splitting per iteration.
    med_long_names = set(MED_LONG_DATA.split())
    for ds_name in ALL_DATASETS:
        # Key and frequency depend only on the name, so resolve them once.
        name_parts = ds_name.split("/")
        ds_key = name_parts[0].lower()
        ds_key = PRETTY_NAMES.get(ds_key, ds_key)
        if len(name_parts) > 1:
            ds_freq = name_parts[1]
        else:
            ds_freq = dataset_properties_map[ds_key]["frequency"]

        terms = ["short", "medium", "long"] if ds_name in med_long_names else ["short"]
        for term in terms:
            yield {
                "ds_name": ds_name,
                "ds_key": ds_key,
                "ds_freq": ds_freq,
                "term": term,
            }


# Set up the GiftEval evaluation
def evaluate_dataset(predictor, ds_name, ds_key, ds_freq, term):
    """Evaluate ``predictor`` on one GiftEval dataset/term and return a result row.

    Args:
        predictor: Object exposing ``set_prediction_len``, ``set_ds_freq``,
            ``model_id`` and a ``predict`` method usable by GluonTS
            ``evaluate_model``.
        ds_name: Dataset directory name passed to ``Dataset``.
        ds_key: Normalized key into ``dataset_properties_map``.
        ds_freq: Frequency string used in the config label and passed to the
            predictor.
        term: ``"short"``, ``"medium"`` or ``"long"``.

    Returns:
        Dict with the ``"<ds_key>/<ds_freq>/<term>"`` label, the model id, the
        selected metrics, and the dataset's domain and number of variates.
    """
    print(f"Processing dataset: {ds_name}")
    ds_config = f"{ds_key}/{ds_freq}/{term}"

    # Load once; only multivariate datasets need to be reloaded as a
    # per-dimension (univariate) view.
    dataset = Dataset(name=ds_name, term=term, to_univariate=False)
    if dataset.target_dim != 1:
        dataset = Dataset(name=ds_name, term=term, to_univariate=True)

    predictor.set_prediction_len(dataset.prediction_length)
    predictor.set_ds_freq(ds_freq)
    season_length = get_seasonality(dataset.freq)

    # Run the forecasts and compute all METRICS over the test windows.
    res = evaluate_model(
        predictor,
        test_data=dataset.test_data,
        metrics=METRICS,
        batch_size=1024,
        axis=None,
        mask_invalid_label=True,
        allow_nan_forecast=False,
        seasonality=season_length,
    )
    result = {
        "dataset": ds_config,
        "model": predictor.model_id,
        "eval_metrics/MSE[mean]": res["MSE[mean]"][0],
        "eval_metrics/MSE[0.5]": res["MSE[0.5]"][0],
        "eval_metrics/MAE[0.5]": res["MAE[0.5]"][0],
        "eval_metrics/MASE[0.5]": res["MASE[0.5]"][0],
        "eval_metrics/MAPE[0.5]": res["MAPE[0.5]"][0],
        "eval_metrics/sMAPE[0.5]": res["sMAPE[0.5]"][0],
        "eval_metrics/MSIS": res["MSIS"][0],
        "eval_metrics/RMSE[mean]": res["RMSE[mean]"][0],
        "eval_metrics/NRMSE[mean]": res["NRMSE[mean]"][0],
        "eval_metrics/ND[0.5]": res["ND[0.5]"][0],
        "eval_metrics/mean_weighted_sum_quantile_loss": res["mean_weighted_sum_quantile_loss"][0],
        "domain": dataset_properties_map[ds_key]["domain"],
        "num_variates": dataset_properties_map[ds_key]["num_variates"],
    }
    return result


@dataclass
class TiRexGiftEvalWrapper():
    model: Any
    freq: str = None
    pred_len: int = 32

    def set_ds_freq(self, freq):
        self.freq = freq

    def set_prediction_len(self, pred_len):
        self.pred_len = pred_len

    def predict(self, test_data_input):
        forecasts = self.model.forecast_gluon(test_data_input, prediction_length=self.pred_len, output_type="gluonts")
        return forecasts

    @property
    def model_id(self):
        return "TiRex"

### Run the Evaluations
(Make sure to execute the previous cells first so that all dependent classes and functions are defined.)

In [None]:
import os
import pandas as pd
from tirex import load_model, ForecastModel

# Uncomment to set the GiftEval data path if it is not already set (or defined
# in a .env file). The name must match Dataset's default storage_env_var,
# "GIFT_EVAL"; a hyphenated "GIFT-EVAL" would be ignored.
# os.environ["GIFT_EVAL"] = "/path/to/gifteval"

model: ForecastModel = load_model("NX-AI/TiRex")

wrapped_model = TiRexGiftEvalWrapper(model)
results = []
for task in gift_eval_dataset_iter():
    task_result = evaluate_dataset(wrapped_model, **task)
    results.append(task_result)
    print(task_result)
results = pd.DataFrame(results)
print(results)
results.to_csv("./all_results.csv", index=False)