# Quick Start: Running the Toto Foundation Model on the BOOM Benchmark

This notebook shows how to run Toto on the BOOM benchmark.

Make sure you download the BOOM benchmark and set the `BOOM` environment variable correctly before running this notebook.
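
Checking the variable up front can save a failed run later. The sketch below assumes only that the benchmark path is exposed through the `BOOM` environment variable, as stated above. The helper name `resolve_boom_path` is hypothetical and not part of Toto or GiftEval.

```python
import os


def resolve_boom_path(env=None):
    """Return the BOOM benchmark path from the environment, or raise a clear error.

    `env` defaults to os.environ; passing a plain dict makes the check easy to test.
    """
    env = os.environ if env is None else env
    path = env.get("BOOM")
    if not path:
        raise RuntimeError("The BOOM environment variable is not set.")
    return path
```

Calling `resolve_boom_path()` with no arguments reads the real environment and fails fast with a readable message if the variable is missing.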

We will use the `Dataset` class from GiftEval to load the data and run the model.

In [None]:
# Set the working directory to the Toto module
%cd ../../toto

Download the BOOM datasets. Calling `download_boom_benchmark` also sets the `BOOM` environment variable to the correct path, which the evaluation below requires.

In [None]:
import os
import json
import csv
import time
from dotenv import load_dotenv
from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality
from gift_eval.data import Dataset
from dataset_utils import download_boom_benchmark
from model.toto import Toto
from inference.gluonts_predictor import TOTOPredictor, Multivariate
import torch

boom_path = "ChangeMe"
download_boom_benchmark(boom_path)
load_dotenv()

# Load the per-dataset properties (frequency, domain, number of variates)
with open("../boom/dataset_properties.json") as f:
    dataset_properties_map = json.load(f)
all_datasets = list(dataset_properties_map.keys())
print(len(all_datasets))

In [None]:
from gluonts.ev.metrics import (
    MSE,
    MAE,
    MASE,
    MAPE,
    SMAPE,
    MSIS,
    RMSE,
    NRMSE,
    ND,
    MeanWeightedSumQuantileLoss,
)

# Instantiate the metrics
metrics = [
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
]
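
To make the last metric in the list more concrete, here is a minimal pure-Python sketch of a mean weighted sum quantile loss. It assumes the common definition (pinball loss scaled by 2, summed over all points, divided by the summed absolute labels, then averaged over quantile levels). It is an illustration, not GluonTS's implementation, and the function names and sample numbers are made up.

```python
def quantile_loss(labels, forecasts, q):
    """Sum of pinball losses at quantile level q, scaled by 2."""
    return sum(
        2.0 * abs((y - f) * ((1.0 if y <= f else 0.0) - q))
        for y, f in zip(labels, forecasts)
    )


def mean_weighted_sum_quantile_loss(labels, quantile_forecasts):
    """Average over quantile levels of (summed quantile loss / summed |label|).

    `quantile_forecasts` maps each quantile level to its list of forecasts.
    """
    abs_label_sum = sum(abs(y) for y in labels)
    weighted = [
        quantile_loss(labels, forecasts, q) / abs_label_sum
        for q, forecasts in quantile_forecasts.items()
    ]
    return sum(weighted) / len(weighted)


# Illustrative values only
labels = [10.0, 20.0]
quantile_forecasts = {0.5: [12.0, 18.0], 0.9: [15.0, 25.0]}
print(mean_weighted_sum_quantile_loss(labels, quantile_forecasts))
```

At the 0.5 level the scaled pinball loss reduces to the absolute error, which is why the median forecast is also used for MAE-style metrics.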

## Toto Predictor
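Load the Toto model and define a predictor wrapper that halves `samples_per_batch` whenever a CUDA out-of-memory error occurs.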

In [None]:
toto = Toto.from_pretrained('Datadog/Toto-Open-Base-1.0')
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
toto.to(DEVICE)


class TOTOModelPredictorWrapper:
    def __init__(self, model, prediction_length, context_length, mode, num_samples=128, use_kv_cache=True):
        """
        Initialize the predictor wrapper with specified parameters.

        Args:
            model: The PyTorch model to be used for predictions.
            prediction_length: The length of the prediction horizon.
            context_length: The length of the context window.
            mode: Mode of prediction (e.g., `Multivariate(batch_size=...)`).
            num_samples: Total number of samples to generate. Also used as the
                initial samples_per_batch, which is halved on OOM errors.
            use_kv_cache: Whether to use key-value caching.
        """
        self.model = model
        self.prediction_length = prediction_length
        self.context_length = context_length
        self.mode = mode
        self.num_samples = num_samples
        self.use_kv_cache = use_kv_cache
        self.samples_per_batch = num_samples
        self.predictor = None
        self._adjusted = False  # Tracks whether adjustment has been done

        self._initialize_predictor()

    def _initialize_predictor(self):
        """
        Initialize the TOTOPredictor with the current samples_per_batch.
        """
        self.predictor = TOTOPredictor.create_for_eval(
            model=self.model,
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            mode=self.mode,
            samples_per_batch=self.samples_per_batch,
        )

    def predict(self, gluonts_test_data: tuple):
        """
        Perform prediction, halving samples_per_batch and retrying if CUDA OOM errors occur.
        """
        # Adjust samples_per_batch only on the first prediction call
        if not self._adjusted:
            print("Initializing predictor with samples_per_batch =", self.samples_per_batch)

            while self.samples_per_batch >= 1:
                try:
                    print(f"Attempting prediction with samples_per_batch = {self.samples_per_batch}")
                    # Consume the generator here to catch any exceptions
                    predictions = list(
                        self.predictor.predict(
                            gluonts_test_data, use_kv_cache=self.use_kv_cache, num_samples=self.num_samples
                        )
                    )

                    self._adjusted = True

                    return predictions  # Success

                except Exception as e:
                    print("Caught exception during prediction.")
                    if "CUDA out of memory" in str(e):
                        print(
                            f"Out of memory with samples_per_batch = {self.samples_per_batch}. Reducing samples_per_batch."
                        )
                        torch.cuda.empty_cache()  # Clear cache before retrying
                        self.samples_per_batch = self.samples_per_batch // 2
                        if self.samples_per_batch < 1:
                            raise RuntimeError(
                                "Unable to perform prediction even with minimal samples_per_batch due to OOM."
                            ) from e
                        self._initialize_predictor()
                    else:
                        raise  # Re-raise unexpected exceptions with the original traceback
        # For subsequent calls, we can just return the generator
        return self.predictor.predict(gluonts_test_data, use_kv_cache=self.use_kv_cache, num_samples=self.num_samples)
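
The retry strategy used by the wrapper can be seen in isolation in the sketch below. It is a simplified stand-in, not the wrapper itself: `run_with_backoff` and `fake_predict` are hypothetical names, and a plain `MemoryError` simulates the CUDA out-of-memory condition so the example runs without a GPU.

```python
def run_with_backoff(run_fn, initial_batch, is_oom=lambda exc: isinstance(exc, MemoryError)):
    """Call run_fn(batch), halving batch after each out-of-memory failure.

    Returns the result together with the batch size that finally worked.
    """
    batch = initial_batch
    while batch >= 1:
        try:
            return run_fn(batch), batch
        except Exception as exc:
            if not is_oom(exc):
                raise  # Unexpected errors are not retried
            batch //= 2
    raise RuntimeError("Out of memory even at batch size 1.")


def fake_predict(batch):
    """Stand-in for a prediction call that only fits in memory for batch <= 32."""
    if batch > 32:
        raise MemoryError("simulated OOM")
    return f"ok with batch={batch}"


print(run_with_backoff(fake_predict, 128))  # tries 128, then 64, then 32
```

Halving converges in a logarithmic number of attempts, and remembering the working batch size (as the wrapper does with `_adjusted`) avoids repeating the search on later calls.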

## Evaluation

Now that we have our predictor class, we can use it to predict on the BOOM benchmark datasets. GluonTS's `evaluate_model` helper runs the predictor on the test data and returns the metric results as a table that can be indexed by metric name. We follow the naming conventions explained in the [README](../README.md) file and store the results in a CSV file called `all_results.csv` under the `results/toto` folder.

The first column in the CSV file is the dataset config name, which combines the dataset name, the frequency, and the term:

```python
f"{dataset_name}/{freq}/{term}"
```
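
As a sketch of how such a config name can be assembled, the helper below mirrors the name-resolution logic of the evaluation loop: dataset names of the form `name/freq` carry their own frequency, and other names look it up in the properties map. The function name `build_config_name` is hypothetical, and the sample entries and frequencies are illustrative only.

```python
def build_config_name(ds_name, term, properties, pretty_names):
    """Build "<dataset>/<freq>/<term>" from a raw dataset name."""
    if "/" in ds_name:
        # The frequency is encoded in the dataset name itself
        ds_key, ds_freq = ds_name.split("/")[:2]
        ds_key = pretty_names.get(ds_key.lower(), ds_key.lower())
    else:
        # Otherwise, look the frequency up in the properties map
        ds_key = pretty_names.get(ds_name.lower(), ds_name.lower())
        ds_freq = properties[ds_key]["frequency"]
    return f"{ds_key}/{ds_freq}/{term}"


# Illustrative inputs only
pretty_names = {"saugeenday": "saugeen"}
properties = {"saugeen": {"frequency": "D"}}
print(build_config_name("saugeenday", "short", properties, pretty_names))
print(build_config_name("example_ds/5T", "long", properties, pretty_names))
```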


In [None]:
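default_context_length = 2048
torch.set_float32_matmul_precision("high")

# NOTE: `med_long_datasets` is used in the loop below but is not defined elsewhere in this
# notebook. Placeholder (assumption): a space-separated string naming the datasets that
# should also be evaluated on the "medium" and "long" terms. With an empty string, only
# the "short" term is evaluated.
med_long_datasets = ""

output_dir = "../results/toto"
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)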

pretty_names = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Define the path for the CSV file
csv_file_path = os.path.join(output_dir, "all_results.csv")

toto_model = toto  # Reuse the model loaded earlier instead of loading a second copy
with open(csv_file_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)

    # Write the header
    writer.writerow(
        [
            "dataset",
            "model",
            "eval_metrics/MSE[mean]",
            "eval_metrics/MSE[0.5]",
            "eval_metrics/MAE[0.5]",
            "eval_metrics/MASE[0.5]",
            "eval_metrics/MAPE[0.5]",
            "eval_metrics/sMAPE[0.5]",
            "eval_metrics/MSIS",
            "eval_metrics/RMSE[mean]",
            "eval_metrics/NRMSE[mean]",
            "eval_metrics/ND[0.5]",
            "eval_metrics/mean_weighted_sum_quantile_loss",
            "domain",
            "num_variates",
        ]
    )

# Iterate over all available datasets
for ds_name in all_datasets:
    print(f"Processing dataset: {ds_name}")
    terms = ["short", "medium", "long"]
    for term in terms:
        # Medium and long terms are only evaluated for the datasets listed in med_long_datasets
        if term in ("medium", "long") and ds_name not in med_long_datasets.split():
            continue

        if "/" in ds_name:
            ds_key = ds_name.split("/")[0]
            ds_freq = ds_name.split("/")[1]
            ds_key = ds_key.lower()
            ds_key = pretty_names.get(ds_key, ds_key)
        else:
            ds_key = ds_name.lower()
            ds_key = pretty_names.get(ds_key, ds_key)
            ds_freq = dataset_properties_map[ds_key]["frequency"]

        ds_config = f"{ds_key}/{ds_freq}/{term}"

        # Initialize the dataset. Toto supports multivariate forecasting, so the data
        # does not need to be converted to univariate series.
        to_univariate = False
        dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate)

        # Inspect the first test entry to infer the number of variates and the series length
        first_entry = next(iter(dataset.test_data))
        target_shape = first_entry[0]["target"].shape

        if len(target_shape) == 2:
            # Multivariate target: (num_variates, series_length)
            batch_size, context_length = target_shape
        else:
            # Univariate target: (series_length,)
            batch_size = 1
            context_length = target_shape[0]

        predictor_wrapper = TOTOModelPredictorWrapper(
            model=toto_model,
            prediction_length=dataset.prediction_length,
            context_length=min(default_context_length, context_length),
            mode=Multivariate(batch_size=batch_size),  # Adjust based on use case
        )

        print(f"{ds_name=}, {term=}, {ds_key=}")

        # Determine seasonality from the dataset frequency
        season_length = get_seasonality(dataset.freq)

        # Evaluate results
        res = evaluate_model(
            predictor_wrapper,
            test_data=dataset.test_data,
            metrics=metrics,
            batch_size=512,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=season_length,
        )

        # Append the results to the CSV file
        with open(csv_file_path, "a", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(
                [
                    ds_config,
                    "toto",
                    res["MSE[mean]"][0],
                    res["MSE[0.5]"][0],
                    res["MAE[0.5]"][0],
                    res["MASE[0.5]"][0],
                    res["MAPE[0.5]"][0],
                    res["sMAPE[0.5]"][0],
                    res["MSIS"][0],
                    res["RMSE[mean]"][0],
                    res["NRMSE[mean]"][0],
                    res["ND[0.5]"][0],
                    res["mean_weighted_sum_quantile_loss"][0],
                    dataset_properties_map[ds_key]["domain"],
                    dataset_properties_map[ds_key]["num_variates"],
                ]
            )

        print(f"Results for {ds_name} have been written to {csv_file_path}")

## Results

Running the above cell generates a CSV file called `all_results.csv` under the `results/toto` folder containing the results for the Toto model on the BOOM benchmark. The following cell loads and displays it:


In [None]:
import pandas as pd

df = pd.read_csv("../results/toto/all_results.csv")
df
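
Because the term is the last component of the dataset config name, results can be grouped per term by splitting that column. The sketch below uses only the standard library and an in-memory sample. The helper name `mean_metric_by_term` is hypothetical, and the sample rows are dummy values for illustration, not benchmark results.

```python
import csv
import io
from collections import defaultdict


def mean_metric_by_term(csv_text, metric):
    """Average one metric column per term ("short", "medium", "long")."""
    sums, counts = defaultdict(float), defaultdict(int)
    for row in csv.DictReader(io.StringIO(csv_text)):
        # The config name has the form "<dataset>/<freq>/<term>"
        term = row["dataset"].rsplit("/", 1)[-1]
        sums[term] += float(row[metric])
        counts[term] += 1
    return {term: sums[term] / counts[term] for term in sums}


# Dummy rows for illustration only
sample = """dataset,model,eval_metrics/MASE[0.5]
example_a/D/short,toto,1.0
example_b/H/short,toto,3.0
example_a/D/long,toto,2.0
"""
print(mean_metric_by_term(sample, "eval_metrics/MASE[0.5]"))
```

To run it on the real file, read `../results/toto/all_results.csv` as text and pass its contents as `csv_text`.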