# Chronos Time Series Prediction Playbook

[Hugging Face model card: amazon/chronos-bolt-tiny](https://huggingface.co/amazon/chronos-bolt-tiny)

In [None]:
def extend_path():
    """Put the parent of the current working directory at the front of ``sys.path``.

    This lets the notebook import packages that live one directory above it.
    The message is printed on every call; the path itself is only inserted
    when it is not already present.
    """
    import sys
    from pathlib import Path

    package_root = str(Path.cwd().parent)
    print(f"Adding {package_root} to sys.path")
    if package_root in sys.path:
        return
    sys.path.insert(0, package_root)


extend_path()

In [None]:
import torch
from chronos import BaseChronosPipeline

In [None]:
from pathlib import Path
from services import PredictionVizualizationProvider

# Instantiate the visualization provider imported from `services` above.
# NOTE(review): it is not referenced by any other cell visible here.
predviz_provider = PredictionVizualizationProvider()

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Model identifier passed to `BaseChronosPipeline.from_pretrained` and stored
# in the benchmark output.
MODEL_PATH = "amazon/chronos-bolt-tiny"
# NOTE(review): not referenced by any cell visible here — confirm it is still needed.
N_TIMESERIES = 1

In [None]:
# Load the pretrained pipeline named by MODEL_PATH, placed on DEVICE and
# requesting bfloat16. `FluxForecastingBenchmarker.benchmark` reads this
# module-level `pipeline` directly, so this cell must run before benchmarking.
pipeline: BaseChronosPipeline = BaseChronosPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_PATH,
    device_map=DEVICE,  # use "cpu" for CPU inference and "mps" for Apple Silicon
    dtype=torch.bfloat16,
)

In [None]:
class Evaluator:
    """Evaluator class for computing forecasting metrics."""

    @staticmethod
    def mae(pred: torch.Tensor, tgt: torch.Tensor) -> float:
        """Compute Mean Absolute Error (MAE).

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.

        Returns:
            float: MAE value.
        """
        return torch.mean(torch.abs(pred - tgt)).item()

    @staticmethod
    def rmse(pred: torch.Tensor, tgt: torch.Tensor) -> float:
        """Compute Root Mean Squared Error (RMSE).

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.

        Returns:
            float: RMSE value.
        """
        return torch.sqrt(torch.mean((pred - tgt) ** 2)).item()

    @staticmethod
    def nrmse(pred: torch.Tensor, tgt: torch.Tensor) -> float:
        """Compute Normalized Root Mean Squared Error (NRMSE).

        Normalized by the range of the target (max - min).

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.

        Returns:
            float: NRMSE value.
        """
        rmse_val = torch.sqrt(torch.mean((pred - tgt) ** 2))
        tgt_range = torch.max(tgt) - torch.min(tgt)
        if tgt_range == 0:
            return float("inf")
        return (rmse_val / tgt_range).item()

    @staticmethod
    def nd(pred: torch.Tensor, tgt: torch.Tensor) -> float:
        """Compute Normalized Deviation (ND).

        Sum of absolute errors divided by sum of absolute target values.
        Also known as WAPE (Weighted Absolute Percentage Error).

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.

        Returns:
            float: ND value.
        """
        sum_abs_error = torch.sum(torch.abs(pred - tgt))
        sum_abs_tgt = torch.sum(torch.abs(tgt))
        if sum_abs_tgt == 0:
            return float("inf")
        return (sum_abs_error / sum_abs_tgt).item()

    @staticmethod
    def mape(pred: torch.Tensor, tgt: torch.Tensor) -> float:
        """Compute Mean Absolute Percentage Error (MAPE).

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.

        Returns:
            float: MAPE value as percentage.
        """
        # Avoid division by zero by adding small epsilon where tgt is zero
        epsilon = 1e-8
        safe_tgt = torch.where(tgt == 0, epsilon, tgt)
        return (100.0 * torch.abs(pred - tgt) / torch.abs(safe_tgt)).mean().item()

    @staticmethod
    def smape(pred: torch.Tensor, tgt: torch.Tensor) -> float:
        """Compute Symmetric Mean Absolute Percentage Error (sMAPE).

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.

        Returns:
            float: sMAPE value.
        """
        return (
            (100.0 * torch.abs(pred - tgt) / ((torch.abs(tgt) + torch.abs(pred)) / 2))
            .mean()
            .item()
        )

    @staticmethod
    def mase(
        pred: torch.Tensor, tgt: torch.Tensor, context: torch.Tensor | None = None
    ) -> float:
        """Compute Mean Absolute Scaled Error (MASE).

        Uses naive forecast on context (history) as benchmark if available,
        otherwise uses naive forecast on target (less robust).

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.
            context (torch.Tensor | None): Context/History tensor.

        Returns:
            float: MASE value.
        """
        mae_pred = torch.mean(torch.abs(pred - tgt))

        if context is not None and context.numel() > 1:
            # Use context (history) for naive error scale
            # Calculate mean absolute difference of the context (in-sample naive error)
            ctx = context.squeeze()
            if ctx.ndim > 1:
                ctx = ctx.view(-1)
            scale = torch.mean(torch.abs(ctx[1:] - ctx[:-1]))
        else:
            # Fallback: use target (less robust as it's out-of-sample)
            naive_pred = tgt[:-1]
            naive_tgt = tgt[1:]
            scale = torch.mean(torch.abs(naive_pred - naive_tgt))

        if scale == 0:
            return float("inf")

        return (mae_pred / scale).item()

    @staticmethod
    def directional_accuracy(pred: torch.Tensor, tgt: torch.Tensor) -> float:
        """Compute Directional Accuracy (percentage of correct direction predictions).

        Measures if the predicted direction of change matches the actual direction.

        Args:
            pred (torch.Tensor): Predicted values tensor.
            tgt (torch.Tensor): Target values tensor.

        Returns:
            float: Directional accuracy as percentage (0-100).
        """
        if len(pred) < 2 or len(tgt) < 2:
            return 0.0
        pred_changes = torch.sign(pred[1:] - pred[:-1])
        tgt_changes = torch.sign(tgt[1:] - tgt[:-1])
        correct = torch.sum(pred_changes == tgt_changes).item()
        total = len(pred_changes)
        return (correct / total) * 100.0


class Utils:
    """Helpers for post-processing quantile forecasts."""

    @staticmethod
    def median_forecast(forecast: torch.Tensor) -> torch.Tensor:
        """Select the middle quantile of a quantile forecast.

        The slice taken is index ``n_quantiles // 2`` along the last axis,
        which is the median level when the quantile levels are symmetric and
        odd in number.

        Args:
            forecast (torch.Tensor): Forecast tensor of shape [N, prediction_length, n_quantiles].

        Returns:
            torch.Tensor: Tensor of shape [N, prediction_length].
        """
        middle = forecast.size(-1) // 2
        return forecast[:, :, middle]

In [None]:
from typing import Any, Generator


class FluxTrace:
    """A class to iterate over time series data in sliding windows for forecasting.

    This class allows generating context and target pairs from a time series trace
    for training or evaluating forecasting models.
    """

    def __init__(
        self,
        trace: torch.Tensor,
        prediction_length: int,
        context_length: int | None = None,
        window: int | None = None,
    ) -> None:
        """Initialize the FluxTrace iterator.

        Args:
            trace (torch.Tensor): The full time series data tensor.
            prediction_length (int): Length of the prediction horizon.
            context_length (int | None): Length of the context window. If None,
                defaults to trace length minus prediction_length.
            window (int | None): Step size for sliding the window. If None,
                defaults to context_length (non-overlapping windows).
        """
        self.trace = trace
        self.prediction_length = prediction_length
        self.context_length = context_length
        self.window = window
        self.metrics: list[dict[str, Any]] = []

    def __iter__(self) -> Generator[tuple[torch.Tensor, torch.Tensor, int], None, None]:
        """Iterate over the trace yielding context and target pairs.

        Yields:
            tuple[torch.Tensor, torch.Tensor]: A tuple of (context, target) tensors.
                - context: Tensor of shape [1, context_length] for model input.
                - target: Tensor of shape [prediction_length] for ground truth.
        """
        trace_length = self.trace.shape[-1]
        context_length = self.context_length or trace_length - self.prediction_length
        start = 0
        end = context_length
        while end + self.prediction_length <= trace_length:
            yield (
                self.trace[..., start:end].unsqueeze(0),
                self.trace[..., end : end + self.prediction_length],
                start,
            )
            start += self.window or context_length
            end += self.window or context_length

    def record(
        self,
        forecast: torch.Tensor,
        target: torch.Tensor,
        context: torch.Tensor | None = None,
        metadata: dict[str, Any] = {},
    ) -> None:
        """Record evaluation metrics for a given forecast and ground truth.

        Args:
            forecast (torch.Tensor): Forecasted values tensor.
            target (torch.Tensor): Ground truth values tensor.
            context (torch.Tensor | None): Context values tensor (history).
            metadata (dict[str, Any]): Additional metadata for the evaluation. Defaults to empty dict.
        """
        mae = Evaluator.mae(forecast, target)
        rmse = Evaluator.rmse(forecast, target)
        nrmse = Evaluator.nrmse(forecast, target)
        nd = Evaluator.nd(forecast, target)
        mape = Evaluator.mape(forecast, target)
        smape = Evaluator.smape(forecast, target)
        mase = Evaluator.mase(forecast, target, context)
        directional_acc = Evaluator.directional_accuracy(forecast, target)

        self.metrics.append(
            {
                "MAE": mae,
                "RMSE": rmse,
                "NRMSE": nrmse,
                "ND": nd,
                "MAPE": mape,
                "sMAPE": smape,
                "MASE": mase,
                "Directional Accuracy": directional_acc,
                "metadata": metadata,
            }
        )

    def forecast_summary(self) -> dict[str, float]:
        """Compute average metrics over all recorded forecasts.

        Returns:
            dict[str, float]: A dictionary with average metrics.
        """
        summary = {
            "MAE": 0.0,
            "RMSE": 0.0,
            "NRMSE": 0.0,
            "ND": 0.0,
            "MAPE": 0.0,
            "sMAPE": 0.0,
            "MASE": 0.0,
            "Directional Accuracy": 0.0,
        }
        n = len(self.metrics)
        if n == 0:
            return summary

        for metric in self.metrics:
            summary["MAE"] += metric["MAE"]
            summary["RMSE"] += metric["RMSE"]
            summary["NRMSE"] += metric["NRMSE"]
            summary["ND"] += metric["ND"]
            summary["MAPE"] += metric["MAPE"]
            summary["sMAPE"] += metric["sMAPE"]
            summary["MASE"] += metric["MASE"]
            summary["Directional Accuracy"] += metric["Directional Accuracy"]

        for key in summary:
            summary[key] /= n

        return summary


class Scaler:
    """Z-score scaling helpers for context windows and their forecasts."""

    @staticmethod
    def setnorm(ctx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Standardise a context tensor along its last dimension.

        Args:
            ctx (torch.Tensor): Input context tensor.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - normed: ``(ctx - mean) / std``.
                - mean: Mean over the last dimension (dimension kept).
                - std: Standard deviation over the last dimension (dimension
                  kept) plus 1e-8, so the division never hits exactly zero.
        """
        mean = torch.mean(ctx, dim=-1, keepdim=True)
        std = torch.std(ctx, dim=-1, keepdim=True) + 1e-8
        return (ctx - mean) / std, mean, std

    @staticmethod
    def denorm(
        pred: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        """Map a standardised prediction back to the original scale.

        Args:
            pred (torch.Tensor): Normalized prediction tensor.
            mean (torch.Tensor): Mean tensor returned by ``setnorm``.
            std (torch.Tensor): Standard deviation tensor returned by ``setnorm``.

        Returns:
            torch.Tensor: ``pred * std + mean``.
        """
        rescaled = pred * std
        return rescaled + mean

In [None]:
from functools import cache
import os

import numpy as np


class FluxType:
    """Integer identifiers for the flux quantities.

    ``FluxTraceProvider.load_flux_energy_data`` uses ``ENERGY_FLUX`` as the
    column index into the loaded data file.
    """

    # NOTE(review): presumably also column indices of the data file; only
    # ENERGY_FLUX is read in the code visible here — confirm.
    ELECTRON_FLUX: int = 0
    ENERGY_FLUX: int = 1
    ION_FLUX: int = 2


class FluxTraceProvider:
    """Read flux trace files from a directory and serve them as tensors.

    Loaded traces and the file count are memoised per instance, so dropping
    the provider also releases the tensors it loaded.
    """

    FLUX_TYPE: FluxType = FluxType()
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        dir: Path,
        filename_convention: str = "fluxes_{iteration}.dat",
    ) -> None:
        """Data Access Provider for flux traces

        Args:
            dir (Path): The path to the directory where the flux traces are stored
            filename_convention (str, optional): The filename convention. Defaults to "fluxes_{iteration}.dat".
        """
        self.dir = dir
        self.filename_convention = filename_convention
        # Per-instance caches. A functools cache on the methods would key on
        # `self` and keep every provider and its tensors alive indefinitely.
        self._trace_cache: dict[int, torch.Tensor] = {}
        self._n_files: int | None = None

    def load_flux_energy_data(self, iteration: int) -> torch.Tensor:
        """Load the energy-flux column of one trace file.

        The file name is ``filename_convention`` formatted with ``iteration``.
        The result is cached, so repeated calls return the same tensor object.

        Args:
            iteration (int): Iteration number substituted into the file name.

        Returns:
            torch.Tensor: 1-D tensor of the energy-flux column on ``DEVICE``.
        """
        trace = self._trace_cache.get(iteration)
        if trace is None:
            file_path: Path = self.dir / self.filename_convention.format(
                iteration=iteration
            )
            # ndmin=2 keeps a single-row file two-dimensional so the column
            # selection below does not fail on it.
            data: np.ndarray = np.loadtxt(file_path, ndmin=2)
            trace = torch.from_numpy(data[:, self.FLUX_TYPE.ENERGY_FLUX]).to(
                self.DEVICE
            )  # return only energy fluxes
            self._trace_cache[iteration] = trace
        return trace

    def _count_trace_files(self) -> int:
        """Count the files in ``dir`` whose names match the filename convention."""
        try:
            pattern = self.filename_convention.format(iteration="*")
        except (ValueError, KeyError, IndexError):
            # The convention cannot be turned into a glob (e.g. it uses a
            # numeric format spec), so fall back to counting every entry.
            return len(os.listdir(self.dir))
        return sum(1 for path in self.dir.glob(pattern) if path.is_file())

    def __len__(self) -> int:
        """Get the number of flux trace files available.

        Only files matching ``filename_convention`` are counted, so unrelated
        directory entries do not inflate the length. The count is computed
        once and then reused.

        Returns:
            int: Number of flux trace files.
        """
        if self._n_files is None:
            self._n_files = self._count_trace_files()
        return self._n_files

In [None]:
from torch.utils.data import Dataset, DataLoader


class FluxDataset(Dataset):
    """Map-style dataset whose items are ``FluxTrace`` window iterators."""

    def __init__(
        self,
        provider: FluxTraceProvider,
        prediction_length: int,
        context_length: int | None = None,
        window: int | None = None,
    ) -> None:
        """Store the provider and the windowing settings applied to every trace.

        Args:
            provider (FluxTraceProvider): Source of the raw traces.
            prediction_length (int): Length of prediction horizon.
            context_length (int | None): Length of context window.
            window (int | None): Step size for sliding window.
        """
        self.window = window
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.provider: FluxTraceProvider = provider

    def __len__(self) -> int:
        """Return the number of traces reported by the provider."""
        return len(self.provider)

    def __getitem__(self, idx: int) -> FluxTrace:
        """Load trace ``idx`` from the provider and wrap it in a ``FluxTrace``."""
        return FluxTrace(
            trace=self.provider.load_flux_energy_data(idx),
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            window=self.window,
        )


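# Example usage (FluxDataset takes no `iterations` argument; its items are
# FluxTrace objects, which the default collate function cannot stack, so the
# DataLoader needs a collate_fn that keeps them as a list):
# dataset = FluxDataset(
#     provider=fluxtrace_provider,
#     prediction_length=PREDICTION_LEN,
#     context_length=CONTEXT_LEN,
#     window=10,
# )
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=list)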

In [None]:
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from tqdm import tqdm
import json


class Benchmark(BaseModel):
    """Serializable record of one benchmark run.

    ``FluxForecastingBenchmarker.save_benchmark`` fills this model and dumps
    it to a JSON file.
    """

    # Identifier of the model the benchmark was run with.
    model: str
    # Windowing settings copied from the dataset that was benchmarked.
    prediction_length: int
    context_length: int | None = None
    window: int | None = None
    # Start and end of the run as POSIX timestamps; None when not recorded.
    benchmark_start_timestamp: float | None = Field(
        default=None, description="Posix Timestamp."
    )
    benchmark_end_timestamp: float | None = Field(
        default=None, description="Posix Timestamp."
    )
    # One dict of averaged metrics per benchmarked trace.
    metrics: list[dict[str, Any]] = []


class FluxForecastingBenchmarker:
    """Benchmark the module-level ``pipeline`` on a ``FluxDataset``.

    ``benchmark`` forecasts every sliding window of each trace and collects
    one averaged-metrics dict per trace; ``save_benchmark`` writes the run to
    a JSON file in ``save_dir``.
    """

    def __init__(
        self,
        dataset: FluxDataset,
        model: str,
        save_dir: Path = Path(".").resolve().parent / "data" / "benchmarks",
        benchmark_file_convention: str = "{timestamp}_benchmark_{model}.json",
    ) -> None:
        """Store the configuration and create the output directory.

        Args:
            dataset (FluxDataset): Traces to benchmark on.
            model (str): Model identifier written to the benchmark file and
                used in its file name.
            save_dir (Path): Directory for benchmark files; created if missing.
            benchmark_file_convention (str): File name template with
                ``{timestamp}`` and ``{model}`` placeholders.
        """
        self.dataset = dataset
        self.model = model
        self.benchmark_file_convention = benchmark_file_convention
        self.save_dir = save_dir
        self.save_dir.mkdir(parents=True, exist_ok=True)
        print(f"Benchmarker will save results to: {self.save_dir}")

        self.metrics: list[dict[str, Any]] = []
        self.benchmark_start_time: datetime | None = None
        self.benchmark_end_time: datetime | None = None

    @staticmethod
    def _collate(batch: list[FluxTrace]) -> list[FluxTrace]:
        """Keep a batch as a plain list; FluxTrace objects cannot be stacked."""
        return batch

    def _evaluate_trace(self, fluxtrace: FluxTrace) -> dict[str, float]:
        """Forecast every window of one trace and return its averaged metrics."""
        for ctx, tgt, _ in fluxtrace:
            # Normalize context
            normed_ctx, mean, std = Scaler.setnorm(ctx)

            # Generate forecast
            with torch.no_grad():
                forecast: torch.Tensor = pipeline.predict(
                    normed_ctx,
                    prediction_length=self.dataset.prediction_length,
                )

            # Swap the last two axes so the quantile axis comes last, as
            # Utils.median_forecast expects. Then put the forecast on the
            # device of the normalisation statistics.
            # NOTE(review): predict() is believed to return CPU tensors even
            # for a CUDA context — confirm; .to() is a no-op when the devices
            # already match.
            forecast = forecast.permute(0, 2, 1).to(mean.device)

            # Denormalize forecast
            denormed_forecast = Scaler.denorm(forecast, mean, std)

            fluxtrace.record(
                forecast=Utils.median_forecast(denormed_forecast).squeeze(0),
                target=tgt.squeeze(0),
                context=ctx.squeeze(0),
            )

        return fluxtrace.forecast_summary()

    def benchmark(self, batch_size: int = 1, stop_after: int | None = None) -> None:
        """Run the benchmark and store one metrics summary per trace.

        Previous results are discarded first. Traces are visited in shuffled
        order, and every trace of every batch is evaluated.

        Args:
            batch_size (int): Number of traces the DataLoader fetches at once.
            stop_after (int | None): Maximum number of traces to evaluate;
                None evaluates the whole dataset.
        """
        from itertools import islice

        # reset
        self.metrics = []
        self.benchmark_start_time = None
        self.benchmark_end_time = None

        dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=self._collate,
        )

        total = len(self.dataset)
        traces = (trace for batch in dataloader for trace in batch)
        if stop_after is not None:
            # islice stops before pulling the next item, so no trace beyond
            # the limit is loaded from disk.
            limit = max(stop_after, 0)
            traces = islice(traces, limit)
            total = min(limit, total)

        self.benchmark_start_time = datetime.now()
        for fluxtrace in tqdm(traces, total=total):
            self.metrics.append(self._evaluate_trace(fluxtrace))
        self.benchmark_end_time = datetime.now()

    def save_benchmark(self) -> None:
        """Save the benchmark to a file.

        The file name is built from ``benchmark_file_convention`` using the
        current local time and the model name with ``/`` replaced by ``_``.
        Fields that are None or equal to their default are omitted.
        """
        start = self.benchmark_start_time
        end = self.benchmark_end_time
        benchmark: Benchmark = Benchmark(
            model=self.model,
            prediction_length=self.dataset.prediction_length,
            context_length=self.dataset.context_length,
            window=self.dataset.window,
            benchmark_start_timestamp=start.timestamp() if start else None,
            benchmark_end_timestamp=end.timestamp() if end else None,
            metrics=self.metrics,
        )

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_model = self.model.replace("/", "_")
        filename = self.benchmark_file_convention.format(
            model=safe_model, timestamp=timestamp
        )

        # NOTE(review): metrics may contain inf, which json.dump writes as
        # `Infinity` — not valid strict JSON; confirm consumers accept it.
        with open(self.save_dir / filename, "w", encoding="utf-8") as f:
            json.dump(
                benchmark.model_dump(exclude_none=True, exclude_defaults=True),
                f,
                indent=4,
            )

In [None]:
# Fix torch's global RNG seed.
# NOTE(review): presumably this makes the shuffled trace order of the
# benchmarker's DataLoader (shuffle=True) repeatable — confirm.
torch.manual_seed(42)

# Raw flux files are read from data/flux/raw, two directory levels above the
# current working directory.
fluxtrace_provider = FluxTraceProvider(
    dir=Path(".").resolve().parent.parent / "data" / "flux" / "raw"
)
# 128-step context windows, 64-step forecasts, window start advancing by 32.
dataset = FluxDataset(
    provider=fluxtrace_provider,
    prediction_length=64,
    context_length=128,
    window=32,
)
# Uses the benchmarker's default save_dir and file name convention.
benchmarker = FluxForecastingBenchmarker(dataset, model=MODEL_PATH)

In [None]:
# Evaluate at most 10 traces (default batch_size=1), then write the collected
# metrics to a JSON file in the benchmarker's save_dir.
benchmarker.benchmark(stop_after=10)
benchmarker.save_benchmark()