Install timesfm 1.3.0 by replacing the Git-based dependency with the released package:

````
uv remove "timesfm[torch] @ git+https://github.com/google-research/timesfm.git"
uv add timesfm
````

In [None]:
from fusiontimeseries.lib.config import FTSConfig
import torch

# Ask PyTorch to release cached, currently unused GPU memory before the model
# is loaded in the next cell.
torch.cuda.empty_cache()

# Shared benchmark configuration. `prediction_length` is set to 64 here and is
# later passed to the model as its forecast horizon (`horizon_len`).
fts_config = FTSConfig()
fts_config.prediction_length = 64

In [None]:
from timesfm.timesfm_torch import TimesFmTorch
from timesfm.timesfm_base import TimesFmHparams, TimesFmCheckpoint

# Hugging Face checkpoint to benchmark. Also reused later to name the results
# directory and files.
repo_id = "google/timesfm-2.0-500m-pytorch"

# The architecture hyper-parameters must match the checkpoint being loaded.
# The published loading recipe for the TimesFM 2.0 500m checkpoints fixes
# `num_layers=50` and `use_positional_embedding=False`; enabling the positional
# embedding feeds the network inputs it was not trained with.
hparams = TimesFmHparams(
    # NOTE(review): TimesFmHparams documents `backend` as "cpu" / "gpu" / "tpu".
    # If `fts_config.device` holds a torch-style string such as "cuda", confirm
    # the model still ends up on the GPU.
    backend=fts_config.device,  # type: ignore
    per_core_batch_size=fts_config.batch_size,
    horizon_len=fts_config.prediction_length,
    context_len=fts_config.context_length,
    num_layers=50,
    use_positional_embedding=False,
)

# Build the PyTorch model wrapper with the checkpoint from the Hugging Face hub.
tfm = TimesFmTorch(
    hparams=hparams, checkpoint=TimesFmCheckpoint(huggingface_repo_id=repo_id)
)

In [None]:
from fusiontimeseries.lib.dataset import TimeseriesDataset

# Load the benchmark flux traces. Later cells iterate this object as
# {distribution: {sample_id: trace}} and read each trace's `.energy_flux`.
benchmark_flux_traces = TimeseriesDataset.get_benchmark_flux_traces(config=fts_config)

In [None]:
from typing import Literal
import numpy as np

# Number of observed leading points handed to the model as the initial context;
# everything from this index onwards is predicted.
START_TIMESTAMP: int = 80

# Rolled-out series per distribution split, keyed by sample id. Each entry has
# the same length as the corresponding ground-truth trace.
forecasts: dict[Literal["id", "ood"], dict[int, np.ndarray]] = {"id": {}, "ood": {}}

for split_name, split_traces in benchmark_flux_traces.items():
    for trace_id, trace in split_traces.items():
        observed = np.array(trace.energy_flux)
        total_steps = len(observed)

        # Seed the rollout with the observed prefix, then extend it
        # autoregressively: every new forecast chunk is appended and becomes
        # part of the context for the next call.
        rollout = observed[:START_TIMESTAMP]
        while len(rollout) < total_steps:
            # Only the first returned array (the point forecast) is used; it
            # carries a leading batch axis of size 1, removed by squeeze(0).
            next_chunk, _ = tfm.forecast(
                inputs=[rollout],
                freq=[1],
                forecast_context_len=fts_config.context_length,
                normalize=True,
            )
            rollout = np.concatenate([rollout, next_chunk.squeeze(0)], axis=0)

        # The final chunk may overshoot the trace; trim to the true length.
        forecasts[split_name][trace_id] = rollout[:total_steps]

In [None]:
from datetime import datetime
from fusiontimeseries.benchmarking.benchmark_utils import rmse_with_standard_error

# Run identifier; reused in the names of the results file and the plots.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Per-distribution metrics. The "rsme" / "rsme_se" key spelling is kept as-is
# so that the JSON schema of previously written result files does not change.
metrics: dict[str, dict[str, float]] = {"id": {}, "ood": {}}

# Number of trailing points averaged for both target and forecast.
tail = fts_config.pred_tail_timestamps

for distribution, samples in benchmark_flux_traces.items():
    y_true: list[float] = []
    y_pred: list[float] = []

    for sample_id, flux_trace in samples.items():
        time_series = np.array(flux_trace.energy_flux)
        # Compare the mean of the last `tail` points of the ground truth with
        # the mean of the last `tail` points of the rolled-out forecast.
        y_true.append(time_series[-tail:].mean())
        y_pred.append(forecasts[distribution][sample_id][-tail:].mean())

    # Presumably returns (RMSE, standard error of the RMSE) over the samples.
    rmse, rmse_se = rmse_with_standard_error(np.array(y_true), np.array(y_pred))

    # Store built-in floats: NumPy scalars such as np.float32 are not
    # JSON-serializable and would break the later `json.dump`.
    metrics[distribution]["rsme"] = float(rmse)
    metrics[distribution]["rsme_se"] = float(rmse_se)

# Final payload: metrics per distribution plus run metadata, in the same key
# order as before ("id", "ood", "timestamp", "prediction_length").
results: dict[str, object] = {
    **metrics,
    "timestamp": timestamp,
    "prediction_length": fts_config.prediction_length,
}
results

In [None]:
import json
from pathlib import Path

# Hugging Face repo ids contain "/", which cannot be used inside a single file
# or directory name, so it is swapped for "_".
model_name_clean = repo_id.replace("/", "_")

# Results are stored under <current working directory>/results/<model>/.
data_dir = Path(".").resolve().joinpath("results", model_name_clean)
data_dir.mkdir(parents=True, exist_ok=True)

# One JSON file per run; the timestamp prefix keeps runs from overwriting
# each other.
results_file = data_dir / f"{timestamp}_{model_name_clean}_benchmark_results.json"
with results_file.open("w") as out:
    json.dump(results, out, indent=2)

print(f"Results saved to: {results_file}")

In [None]:
# create and save plots of forecasts vs true values
import matplotlib.pyplot as plt

# Plots go into a "plots" sub-directory next to the JSON results.
plot_dir = data_dir / "plots"
plot_dir.mkdir(parents=True, exist_ok=True)

for split_name, split_traces in benchmark_flux_traces.items():
    for trace_id, trace in split_traces.items():
        # One figure per sample: ground truth as a solid line, forecast dashed.
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(np.array(trace.energy_flux), label="True Values")
        ax.plot(forecasts[split_name][trace_id], label="Forecast", linestyle="--")

        ax.set_title(
            f"Forecast vs True Values ({split_name.upper()}) - Sample {trace_id}"
        )
        ax.set_xlabel("Time")
        ax.set_ylabel("Energy Flux")
        ax.legend()
        ax.grid()

        plot_file = (
            plot_dir
            / f"{timestamp}_{model_name_clean}_{split_name}_sample_{trace_id}_forecast.png"
        )
        fig.savefig(plot_file)
        # Close the figure explicitly so figures do not accumulate in memory
        # across the loop.
        plt.close(fig)