# Quick Start: Running TimesFM models on the BOOM benchmark

This notebook is adapted from the [GiftEval repository](https://github.com/SalesforceAIResearch/gift-eval/tree/main/notebooks) and shows how to run TimesFM-2.0 on BOOM.

Make sure you download the BOOM benchmark and set the `BOOM` environment variable correctly before running this notebook.

We will use the `Dataset` class to load the data and run the model. If you have not already, please check out the [dataset.ipynb](./dataset.ipynb) notebook to learn more about the `Dataset` class. The evaluation loop below runs over every dataset listed in `all_datasets`; to run on a subset, restrict that list before starting the loop.

Download BOOM datasets. Calling `download_boom_benchmark` also sets the `BOOM` environment variable with the correct path, which is needed for running the evals below.

In [None]:
import os
import json
from dotenv import load_dotenv
from dataset_utils import download_boom_benchmark

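boom_path = "ChangeMe"  # Placeholder: set this to your BOOM download directory
download_boom_benchmark(boom_path)
load_dotenv()

# Use a context manager so the file handle is closed after loading
with open("boom/dataset_properties.json") as f:
    dataset_properties_map = json.load(f)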
all_datasets = list(dataset_properties_map.keys())
print(len(all_datasets))
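
Before running the evaluation cells, it can help to confirm that the `BOOM` environment variable really points at an existing directory. The following is a minimal sketch that assumes only that `BOOM` holds a filesystem path; the helper name `check_storage_env` is hypothetical and is not part of `dataset_utils`.

```python
import os


def check_storage_env(var_name: str = "BOOM") -> str:
    """Return the path stored in `var_name`, or raise a descriptive error.

    Hypothetical helper for illustration; not part of dataset_utils.
    """
    path = os.environ.get(var_name)
    if not path:
        raise RuntimeError(f"Environment variable {var_name!r} is not set.")
    if not os.path.isdir(path):
        raise RuntimeError(f"{var_name}={path!r} is not an existing directory.")
    return path
```

Calling `check_storage_env()` right after the download step fails early with a readable message instead of a later error inside the evaluation loop.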

In [None]:
from gluonts.ev.metrics import (
    MAE,
    MAPE,
    MASE,
    MSE,
    MSIS,
    ND,
    NRMSE,
    RMSE,
    SMAPE,
    MeanWeightedSumQuantileLoss,
)

# Instantiate the metrics
metrics = [
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
]

## TimesFM Predictor

For foundation models, we need to implement a wrapper containing the model and use the wrapper to generate predictions.

This is just meant to be a simple wrapper to get you started. Feel free to use your own custom implementation to wrap any model.

### Let's first load the TimesFM model

In [None]:
import timesfm

# If you are using the JAX version, uncomment this block instead:
# tfm = timesfm.TimesFm(
#     hparams=timesfm.TimesFmHparams(
#         backend="gpu",
#         per_core_batch_size=32,
#         num_layers=50,
#         horizon_len=128,
#         context_len=2048,
#         use_positional_embedding=False,
#         output_patch_len=128,
#     ),
#     checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-2.0-500m-jax"),
# )

# If you are using the PyTorch version:
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        num_layers=50,
        horizon_len=128,
        context_len=2048,
        use_positional_embedding=False,
        output_patch_len=128,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-2.0-500m-pytorch"),
)

In [None]:
from typing import List
import numpy as np
from tqdm.auto import tqdm
from gluonts.itertools import batcher
from gluonts.model import Forecast
from gluonts.model.forecast import QuantileForecast


class TimesFmPredictor:

    def __init__(
        self,
        tfm,
        prediction_length: int,
        ds_freq: str,
        *args,
        **kwargs,
    ):
        self.tfm = tfm
        self.prediction_length = prediction_length
        if self.prediction_length > self.tfm.horizon_len:
            self.tfm.horizon_len = (
                (self.prediction_length + self.tfm.output_patch_len - 1) // self.tfm.output_patch_len
            ) * self.tfm.output_patch_len
            print("Jitting for new prediction length.")
        self.freq = timesfm.freq_map(ds_freq)
        print("frequency key:", ds_freq)
        print("frequency:", self.freq)

    def predict(self, test_data_input, batch_size: int = 128) -> List[Forecast]:
        forecast_outputs = []
        for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
            context = []
            for entry in batch:
                arr = np.array(entry["target"])
                context.append(arr)
            freqs = [self.freq] * len(context)
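            _, full_preds = self.tfm.forecast(context, freqs, normalize=True)
            # Keep the first prediction_length steps and drop index 0 of the last
            # axis (the mean forecast), leaving only the quantile forecasts.
            full_preds = full_preds[:, 0 : self.prediction_length, 1:]
            # Reorder to (batch, num_quantiles, prediction_length) for QuantileForecast
            forecast_outputs.append(full_preds.transpose((0, 2, 1)))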
        forecast_outputs = np.concatenate(forecast_outputs)

        # Convert the quantile arrays into gluonts QuantileForecast objects
        forecasts = []
        for item, ts in zip(forecast_outputs, test_data_input):
            forecast_start_date = ts["start"] + len(ts["target"])
            forecasts.append(
                QuantileForecast(
                    forecast_arrays=item,
                    forecast_keys=list(map(str, self.tfm.quantiles)),
                    start_date=forecast_start_date,
                )
            )

        return forecasts
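
When the requested prediction length exceeds the model horizon, the constructor above rounds the horizon up to the next multiple of `output_patch_len`. The same ceiling arithmetic is shown in isolation below; the function name `round_up_to_patch` is hypothetical and used only for illustration.

```python
def round_up_to_patch(prediction_length: int, output_patch_len: int) -> int:
    """Smallest multiple of output_patch_len that is >= prediction_length."""
    # Integer ceiling division: number of output patches needed
    num_patches = (prediction_length + output_patch_len - 1) // output_patch_len
    return num_patches * output_patch_len


print(round_up_to_patch(300, 128))  # 3 patches of 128 steps -> 384
print(round_up_to_patch(128, 128))  # already a multiple -> 128
```

The predictor then slices the forecast back down to the first `prediction_length` steps, so the extra steps produced by rounding up are discarded.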

## Evaluation

Now that we have our predictor class, we can use it to forecast on the BOOM benchmark datasets. We will use the `evaluate_model` helper function, which evaluates the model on the test data and returns the computed metrics. Following the naming conventions explained in the [README](../README.md) file, we store the results in a CSV file called `all_results.csv` under the `timesfm_2_0_500m` subfolder of the output directory.

The first column in the CSV file is the dataset config name, which combines the dataset name, frequency, and term:

```python
f"{dataset_name}/{freq}/{term}"
```
Note that whenever the model yields NaN forecasts, we try to replace its results with the baseline results.
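
The dataset config name in the first column can be built and split back into its parts as sketched here. The helper names and the example values (`example_dataset`, `5T`) are placeholders rather than names from the benchmark, and the split assumes that the frequency and term never contain a `/`.

```python
from typing import Tuple


def make_config_name(dataset_name: str, freq: str, term: str) -> str:
    """Build the config name used in the first CSV column."""
    return f"{dataset_name}/{freq}/{term}"


def split_config_name(config_name: str) -> Tuple[str, str, str]:
    """Split a config name into (dataset_name, freq, term).

    Splitting from the right keeps any '/' inside the dataset name intact.
    """
    dataset_name, freq, term = config_name.rsplit("/", 2)
    return dataset_name, freq, term


name = make_config_name("example_dataset", "5T", "short")
print(name)                     # example_dataset/5T/short
print(split_config_name(name))  # ('example_dataset', '5T', 'short')
```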

In [None]:
import logging


class WarningFilter(logging.Filter):
    def __init__(self, text_to_filter):
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record):
        return self.text_to_filter not in record.getMessage()


gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))
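
To see what such a filter does, here is a self-contained sketch that attaches the same kind of substring filter to a throwaway logger and records which messages get through. The logger name `demo.forecast` and the `ListHandler` class are assumptions made for this demonstration only.

```python
import logging


class SubstringFilter(logging.Filter):
    """Drop every record whose message contains `text_to_filter`."""

    def __init__(self, text_to_filter: str):
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False tells the logger to discard the record
        return self.text_to_filter not in record.getMessage()


class ListHandler(logging.Handler):
    """Collect emitted messages in a list so they can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


demo_logger = logging.getLogger("demo.forecast")  # hypothetical logger name
demo_logger.propagate = False  # keep the demo output out of the root logger
handler = ListHandler()
demo_logger.addHandler(handler)
demo_logger.addFilter(SubstringFilter("mean prediction is not stored"))

demo_logger.warning("The mean prediction is not stored in the forecast data")
demo_logger.warning("Some other warning")
print(handler.messages)  # ['Some other warning']
```

A filter attached to a logger applies only to records logged through that logger itself, not to records from its child loggers, which is why the cell above targets `gluonts.model.forecast` directly.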

### Ordering all dataset settings from lowest to highest prediction length minimizes the number of JIT recompilations. This is not necessary for the PyTorch version.

In [None]:
import csv
import os

import pandas as pd

from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality

from gift_eval.data import Dataset


### Evaluating on all settings

In [None]:
model_name = "timesfm_2_0_500m"
output_dir = f"ChangeMe/{model_name}"
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)

# Define the path for the CSV file
csv_file_path = os.path.join(output_dir, "all_results.csv")

with open(csv_file_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)

    # Write the header
    writer.writerow(
        [
            "dataset",
            "model",
            "eval_metrics/MSE[mean]",
            "eval_metrics/MSE[0.5]",
            "eval_metrics/MAE[0.5]",
            "eval_metrics/MASE[0.5]",
            "eval_metrics/MAPE[0.5]",
            "eval_metrics/sMAPE[0.5]",
            "eval_metrics/MSIS",
            "eval_metrics/RMSE[mean]",
            "eval_metrics/NRMSE[mean]",
            "eval_metrics/ND[0.5]",
            "eval_metrics/mean_weighted_sum_quantile_loss",
            "domain",
            "num_variates",
            "dataset_size",
        ]
    )
for ds_num, ds_name in enumerate(all_datasets):
    dataset_term = dataset_properties_map[ds_name]["term"]
    terms = ["short", "medium", "long"]
    for term in terms:
        if (term == "medium" or term == "long") and dataset_term == "short":
            continue
        ds_freq = dataset_properties_map[ds_name]["frequency"]
        ds_config = f"{ds_name}/{ds_freq}/{term}"
        # Convert multivariate datasets to univariate series; keep univariate ones as-is
        target_dim = Dataset(name=ds_name, term=term, to_univariate=False, storage_env_var="BOOM").target_dim
        to_univariate = target_dim != 1
        dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate, storage_env_var="BOOM")
        season_length = get_seasonality(dataset.freq)
        dataset_size = len(dataset.test_data)
        print(f"Dataset size: {dataset_size}")
        predictor = TimesFmPredictor(
            tfm=tfm,
            prediction_length=dataset.prediction_length,
            ds_freq=dataset.freq,
        )
        # Evaluate the model; errors mentioning NaN are handled below by
        # falling back to the seasonal naive scores
        try:
            res = evaluate_model(
                predictor,
                test_data=dataset.test_data,
                metrics=metrics,
                batch_size=1024,
                axis=None,
                mask_invalid_label=True,
                allow_nan_forecast=False,
                seasonality=season_length,
            )
        except Exception as e:
            if "NaN" in str(e):
                print(f"replacing results of {ds_name} with seasonal naive scores due to NaN values")
                # Load the seasonal naive results and strip the "eval_metrics/"
                # prefix so the columns match the keys used when writing the row
                res = pd.read_csv("ChangeMe/seasonalnaive/all_results.csv")
                prefix = "eval_metrics/"
                res.columns = [col[len(prefix):] if col.startswith(prefix) else col for col in res.columns]
                res = res[res["dataset"] == ds_config]
                res = res.reset_index(drop=True)
            else:
                raise

        # Append the results to the CSV file
        with open(csv_file_path, "a", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(
                [
                    ds_config,
                    model_name,
                    res["MSE[mean]"][0],
                    res["MSE[0.5]"][0],
                    res["MAE[0.5]"][0],
                    res["MASE[0.5]"][0],
                    res["MAPE[0.5]"][0],
                    res["sMAPE[0.5]"][0],
                    res["MSIS"][0],
                    res["RMSE[mean]"][0],
                    res["NRMSE[mean]"][0],
                    res["ND[0.5]"][0],
                    res["mean_weighted_sum_quantile_loss"][0],
                    dataset_properties_map[ds_name]["domain"],
                    dataset_properties_map[ds_name]["num_variates"],
                    dataset_size,
                ]
            )
    
        print(f"Results for {ds_name} have been written to {csv_file_path}")
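
The term-skipping rule inside the loop above (datasets whose `term` property is `short` are evaluated only on the short horizon) can be isolated into a small function. This is a sketch of that rule only; the name `terms_to_evaluate` is hypothetical.

```python
from typing import List


def terms_to_evaluate(dataset_term: str) -> List[str]:
    """Return the forecast terms the evaluation loop runs for a dataset."""
    if dataset_term == "short":
        # Short-only datasets skip the medium and long settings
        return ["short"]
    return ["short", "medium", "long"]


print(terms_to_evaluate("short"))     # ['short']
print(terms_to_evaluate("med_long"))  # ['short', 'medium', 'long']
```

Here `"med_long"` is just an example of any value other than `"short"`; the loop treats every such value the same way.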

## Results

Running the above cell will generate a CSV file called `all_results.csv` in the `output_dir` folder, containing the results for the TimesFM model on the BOOM benchmark. We can display the CSV file using the following code:

In [None]:
import pandas as pd
df = pd.read_csv(os.path.join(output_dir, "all_results.csv"))
df