# Quick Start: Running Chronos and Chronos-Bolt models on the BOOM benchmark

This notebook is adapted from the [GiftEval repository](https://github.com/SalesforceAIResearch/gift-eval/tree/main/notebooks) and shows how to run Chronos and Chronos-Bolt models on the BOOM benchmark.

Make sure you download the BOOM benchmark and set the `BOOM` environment variable correctly before running this notebook.

We will use the `Dataset` class from GiftEval to load the data and run the model.

First, download the BOOM datasets. Calling `download_boom_benchmark` also sets the `BOOM` environment variable to the correct path, which the evaluation below requires.

In [None]:
import os
import json
from dotenv import load_dotenv
from dataset_utils import download_boom_benchmark

boom_path = "ChangeMe"  # replace with the path to use for the BOOM benchmark
download_boom_benchmark(boom_path)
load_dotenv()

# Load the dataset properties; the context manager closes the file afterwards
with open("eval/boom/dataset_properties.json") as f:
    dataset_properties_map = json.load(f)
all_datasets = list(dataset_properties_map.keys())
print(len(all_datasets))

Now we define the metrics that we want to compute for each dataset.

In [None]:
from gluonts.ev.metrics import (
    MAE,
    MAPE,
    MASE,
    MSE,
    MSIS,
    ND,
    NRMSE,
    RMSE,
    SMAPE,
    MeanWeightedSumQuantileLoss,
)

# Instantiate the metrics
metrics = [
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
]
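
To make the last metric in the list more concrete, here is a rough, dependency-free sketch of a mean weighted quantile loss: the pinball loss at each quantile level is summed over time, divided by the sum of absolute target values, and averaged across levels. The function names and the factor of 2 are assumptions made for illustration, and GluonTS's `MeanWeightedSumQuantileLoss` may differ in details such as how it aggregates across series.

```python
from typing import Dict, List


def quantile_loss(y_true: List[float], y_pred: List[float], q: float) -> float:
    # Pinball loss summed over time steps; the factor 2 is an assumed convention
    # that makes the q = 0.5 case equal to the summed absolute error.
    return 2.0 * sum(abs((p - y) * ((y <= p) - q)) for y, p in zip(y_true, y_pred))


def mean_weighted_quantile_loss(
    y_true: List[float], quantile_preds: Dict[float, List[float]]
) -> float:
    # Normalise each level's loss by the total absolute target
    # (assumed to be non-zero here), then average over the quantile levels.
    denom = sum(abs(y) for y in y_true)
    losses = [quantile_loss(y_true, preds, q) / denom for q, preds in quantile_preds.items()]
    return sum(losses) / len(losses)
```

With only the 0.5 level, this sketch reduces to the summed absolute error divided by the summed absolute target.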

## Chronos Predictor

For foundation models, we need to implement a wrapper containing the model and use the wrapper to generate predictions.

This is just a simple wrapper to get you started; feel free to use your own custom implementation to wrap any model.

In [None]:
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import torch
from chronos import BaseChronosPipeline, ForecastType
from gluonts.itertools import batcher
from gluonts.model import Forecast
from gluonts.model.forecast import QuantileForecast, SampleForecast
from tqdm.auto import tqdm


@dataclass
class ModelConfig:
    quantile_levels: Optional[List[float]] = None
    forecast_keys: List[str] = field(init=False)
    statsforecast_keys: List[str] = field(init=False)
    intervals: Optional[List[int]] = field(init=False)

    def __post_init__(self):
        self.forecast_keys = ["mean"]
        self.statsforecast_keys = ["mean"]
        if self.quantile_levels is None:
            self.intervals = None
            return

        intervals = set()

        for quantile_level in self.quantile_levels:
            interval = round(200 * (max(quantile_level, 1 - quantile_level) - 0.5))
            intervals.add(interval)
            side = "hi" if quantile_level > 0.5 else "lo"
            self.forecast_keys.append(str(quantile_level))
            self.statsforecast_keys.append(f"{side}-{interval}")

        self.intervals = sorted(intervals)


class ChronosPredictor:
    def __init__(
        self,
        model_path,
        num_samples: int,
        prediction_length: int,
        *args,
        **kwargs,
    ):
        print("prediction_length:", prediction_length)
        self.pipeline = BaseChronosPipeline.from_pretrained(
            model_path,
            *args,
            **kwargs,
        )
        self.prediction_length = prediction_length
        self.num_samples = num_samples
        print("forecast type:", self.pipeline.forecast_type)

    def predict(self, test_data_input, batch_size: int = 1024) -> List[Forecast]:
        pipeline = self.pipeline
        predict_kwargs = {"num_samples": self.num_samples} if pipeline.forecast_type == ForecastType.SAMPLES else {}
        while True:
            try:
                # Generate forecast samples
                forecast_outputs = []
                for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
                    context = [torch.tensor(entry["target"]) for entry in batch]
                    forecast_outputs.append(
                        pipeline.predict(
                            context,
                            prediction_length=self.prediction_length,
                            **predict_kwargs,
                        ).numpy()
                    )
                forecast_outputs = np.concatenate(forecast_outputs)
                break
            except torch.cuda.OutOfMemoryError:
                if batch_size <= 1:
                    # Cannot shrink the batch any further; surface the error
                    raise
                print(f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size // 2}")
                batch_size //= 2

        # Convert forecast samples into gluonts Forecast objects
        forecasts = []
        for item, ts in zip(forecast_outputs, test_data_input):
            forecast_start_date = ts["start"] + len(ts["target"])

            if pipeline.forecast_type == ForecastType.SAMPLES:
                forecasts.append(SampleForecast(samples=item, start_date=forecast_start_date))
            elif pipeline.forecast_type == ForecastType.QUANTILES:
                forecasts.append(
                    QuantileForecast(
                        forecast_arrays=item,
                        forecast_keys=list(map(str, pipeline.quantiles)),
                        start_date=forecast_start_date,
                    )
                )

        return forecasts
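
The out-of-memory handling in `predict` follows a simple pattern: attempt the whole pass at the current batch size and halve it on failure. A minimal, framework-free sketch of that pattern is given here. `run_with_backoff` and `process_batch` are hypothetical names, and `MemoryError` stands in for `torch.cuda.OutOfMemoryError`.

```python
from typing import Callable, List, Sequence, Tuple


def run_with_backoff(
    items: Sequence,
    process_batch: Callable[[Sequence], List],
    batch_size: int,
) -> Tuple[List, int]:
    """Process all items in batches, halving the batch size on MemoryError."""
    while batch_size >= 1:
        try:
            outputs: List = []
            for start in range(0, len(items), batch_size):
                outputs.extend(process_batch(items[start : start + batch_size]))
            # Return the results together with the batch size that worked
            return outputs, batch_size
        except MemoryError:
            # Discard partial results and retry the full pass with smaller batches
            batch_size //= 2
    raise RuntimeError("Out of memory even with a batch size of 1")
```

Restarting the full pass after a failure keeps the logic simple at the cost of redoing the batches that had already succeeded.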

## Evaluation

Now that we have our predictor class, we can use it to predict on the BOOM datasets. We will use the `evaluate_model` function from `gluonts` to evaluate the model, then store the results in a CSV file called `all_results.csv` in the output directory (`output_dir` in the code below, which includes the model name).

The first column in the CSV file is the dataset config name, which combines the dataset name, frequency, and term:

```python
f"{dataset_name}/{freq}/{term}"
```
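
When post-processing the results, it can be handy to split the config name back into its parts. The following is a small sketch, assuming the frequency and term themselves contain no `/`; `parse_ds_config`, `DatasetConfig`, and any example names are hypothetical.

```python
from typing import NamedTuple


class DatasetConfig(NamedTuple):
    name: str
    freq: str
    term: str


def parse_ds_config(ds_config: str) -> DatasetConfig:
    # Split from the right so that a "/" inside the dataset name is preserved
    name, freq, term = ds_config.rsplit("/", 2)
    return DatasetConfig(name, freq, term)
```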


In [None]:
import logging


class WarningFilter(logging.Filter):
    def __init__(self, text_to_filter):
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record):
        return self.text_to_filter not in record.getMessage()


gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))

In [None]:
import csv
import os

from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality

from gift_eval.data import Dataset

# Iterate over all available datasets

model_name = "chronos_bolt_base"
output_dir = f"ChangeMe/{model_name}"
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)

# Define the path for the CSV file
csv_file_path = os.path.join(output_dir, "all_results.csv")

with open(csv_file_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)

    # Write the header
    writer.writerow(
        [
            "dataset",
            "model",
            "eval_metrics/MSE[mean]",
            "eval_metrics/MSE[0.5]",
            "eval_metrics/MAE[0.5]",
            "eval_metrics/MASE[0.5]",
            "eval_metrics/MAPE[0.5]",
            "eval_metrics/sMAPE[0.5]",
            "eval_metrics/MSIS",
            "eval_metrics/RMSE[mean]",
            "eval_metrics/NRMSE[mean]",
            "eval_metrics/ND[0.5]",
            "eval_metrics/mean_weighted_sum_quantile_loss",
            "domain",
            "num_variates",
            "dataset_size",
        ]
    )

for ds_num, ds_name in enumerate(all_datasets):
    print(f"Processing dataset: {ds_name} ({ds_num + 1} of {len(all_datasets)})")
    dataset_term = dataset_properties_map[ds_name]["term"]
    terms = ["short", "medium", "long"]
    for term in terms:
        if (term == "medium" or term == "long") and dataset_term == "short":
            continue

        ds_freq = dataset_properties_map[ds_name]["frequency"]
        ds_config = f"{ds_name}/{ds_freq}/{term}"

        # Initialize the dataset: load it once to inspect the target dimension,
        # then convert to univariate only if the target has more than one dimension
        to_univariate = (
            Dataset(name=ds_name, term=term, to_univariate=False, storage_env_var="BOOM").target_dim != 1
        )

        dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate, storage_env_var="BOOM")

        season_length = get_seasonality(dataset.freq)
        dataset_size = len(dataset.test_data)
        print(f"Dataset size: {dataset_size}")
        predictor = ChronosPredictor(
            # use "amazon/chronos-t5-base" for the corresponding original Chronos model
            model_path="amazon/chronos-bolt-base",
            num_samples=20,
            prediction_length=dataset.prediction_length,
            device_map="cuda:0",
        )
        
        res = evaluate_model(
            predictor,
            test_data=dataset.test_data,
            metrics=metrics,
            batch_size=1024,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=season_length,
        )

        # Append the results to the CSV file
        with open(csv_file_path, "a", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(
                [
                    ds_config,
                    model_name,
                    res["MSE[mean]"][0],
                    res["MSE[0.5]"][0],
                    res["MAE[0.5]"][0],
                    res["MASE[0.5]"][0],
                    res["MAPE[0.5]"][0],
                    res["sMAPE[0.5]"][0],
                    res["MSIS"][0],
                    res["RMSE[mean]"][0],
                    res["NRMSE[mean]"][0],
                    res["ND[0.5]"][0],
                    res["mean_weighted_sum_quantile_loss"][0],
                    dataset_properties_map[ds_name]["domain"],
                    dataset_properties_map[ds_name]["num_variates"],
                    dataset_size,
                ]
            )

        print(f"Results for {ds_name} have been written to {csv_file_path}")
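
The term-selection rule inside the loop (datasets marked `short` are evaluated only on the short term, all others on all three terms) can be factored into a small helper. This is one possible refactoring rather than part of the original notebook, and `terms_to_evaluate` is a hypothetical name.

```python
from typing import List

ALL_TERMS = ["short", "medium", "long"]


def terms_to_evaluate(dataset_term: str) -> List[str]:
    # Datasets marked "short" skip the medium and long terms
    if dataset_term == "short":
        return ["short"]
    return list(ALL_TERMS)
```

The inner loop could then read `for term in terms_to_evaluate(dataset_term):` without the explicit `continue`.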

## Results

Running the above cell generates a CSV file called `all_results.csv` in the `output_dir` folder, containing the results for the selected Chronos model on the BOOM benchmark. We can display the CSV file using the following code:

In [None]:
import pandas as pd
df = pd.read_csv(output_dir + "/all_results.csv")
df