# Quickstart - Run Kairos on GiftEval

This notebook shows how to run [Kairos](https://github.com/foundation-model-research/Kairos) on the gift-eval benchmark.

Make sure you download the gift-eval benchmark and set the `GIFT-EVAL` environment variable correctly before running this notebook.


## Setup Instructions

1. **Optional but suggested: Install conda environment specifed in Kairos Repo**

```bash
conda create --name kairos python=3.10
conda activate kairos
```

2. **Install Kairos**

```bash
git clone https://github.com/foundation-model-research/Kairos.git
cd Kairos
pip install -r requirements.txt
```

3. **Install additional dependecies needed for GiftEval benchmark**

```bash
pip install gluonts==0.14.4 dotenv datasets==3.5.1
```

## Setting up the data and metrics

In [9]:
import csv
import os

from gluonts.model import evaluate_model, Forecast
from gluonts.model.forecast import SampleForecast
from gluonts.itertools import batcher
from gluonts.time_feature import get_seasonality
from gluonts.ev.metrics import (
    MAE,
    MAPE,
    MASE,
    MSE,
    MSIS,
    ND,
    NRMSE,
    RMSE,
    SMAPE,
    MeanWeightedSumQuantileLoss,
)

import logging

from gift_eval.data import Dataset

import json

from dotenv import load_dotenv

from Kairos.tsfm.model.kairos import AutoModel
from tqdm.auto import tqdm
from typing import List
import numpy as np
import torch
import pandas as pd
import re

# Load environment variables
load_dotenv()

# short_datasets = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/15T ett2/H ett2/D ett2/W jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
short_datasets = "m4_weekly"

# med_long_datasets = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
med_long_datasets = "bizitobs_l2c/H"

# Get union of short and med_long datasets
all_datasets = list(set(short_datasets.split() + med_long_datasets.split()))

dataset_properties_map = json.load(open("dataset_properties.json"))

# Instantiate the metrics
metrics = [
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(
        quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    ),
]

## KairosPredictor

In [10]:
def pad_or_truncate(sequence, max_length=2048, pad_value=np.nan):
    """
    Pads or truncates a sequence on the left to a specified max_length.

    Args:
        sequence (list or np.ndarray): The input sequence.
        max_length (int): The target length.
        pad_value (int or float): The value to use for padding, defaults to np.nan.

    Returns:
        np.ndarray: A NumPy array of length max_length.
    """
    seq_np = np.array(sequence)
    current_length = len(seq_np)

    if current_length < max_length:
        # If the current length is less than the target, calculate the required padding
        padding_size = max_length - current_length
        # Use np.pad to add padding to the left
        # (padding_size, 0) means pad `padding_size` elements at the beginning of the first (and only) axis
        return np.pad(seq_np, (padding_size, 0), 'constant', constant_values=pad_value)
    else:
        # If the current length is greater than or equal to the target, truncate to the last max_length elements
        return seq_np[-max_length:]

class KairosPredictor:
    def __init__(
        self,
        model_path,
        prediction_length: int,
        *args,
        **kwargs,
    ):
        print("prediction_length:", prediction_length)
        self.prediction_length = prediction_length
        # 1. Check for CUDA availability and set the primary device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # 2. Load the model
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

        # 3. Move the model to the primary device
        self.model.to(self.device)

    def predict(self, test_data_input, batch_size: int = 256) -> List[Forecast]:
        self.model.eval()
        model = self.model
        while True:
            try:
                # Generate forecast samples
                forecast_outputs = []
                with torch.no_grad():
                    for batch in tqdm(batcher(test_data_input, batch_size=batch_size)):
                        context = [torch.tensor(pad_or_truncate(entry["target"], max_length=2048)) for entry in batch]
                        forecast_outputs.append(
                            model(
                                past_target=torch.stack(context).to(self.device),
                                prediction_length=self.prediction_length,
                                generation=True,
                                infer_is_positive=True,
                                force_flip_invariance=True,
                            )["prediction_outputs"].detach().cpu().numpy()
                        )
                forecast_outputs = np.concatenate(forecast_outputs)
                break
            except torch.cuda.OutOfMemoryError:
                print(
                    f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size // 2}"
                )
                batch_size //= 2

        # Convert forecast samples into gluonts Forecast objects
        forecasts = []
        for item, ts in zip(forecast_outputs, test_data_input):
            forecast_start_date = ts["start"] + len(ts["target"])
            forecasts.append(
                SampleForecast(samples=item, start_date=forecast_start_date)
            )

        return forecasts

## Evaluation

In [11]:
class WarningFilter(logging.Filter):
    def __init__(self, text_to_filter):
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record):
        return self.text_to_filter not in record.getMessage()


gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(
    WarningFilter("The mean prediction is not stored in the forecast data")
)

pretty_names = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Model Configuration
# model_name = "Kairos_10m"
# model_path = "mldi-lab/Kairos_10m"
# model_name = "Kairos_23m"
# model_path = "mldi-lab/Kairos_23m"
model_name = "Kairos_50m"
model_path = "mldi-lab/Kairos_50m"

# set the output directory and CSV file path
output_dir = f"../results/{model_name}"
os.makedirs(output_dir, exist_ok=True)
csv_file_path = os.path.join(output_dir, "all_results.csv")

completed_datasets = set()
# 1. Check if the results file exists and read the completed datasets to allow resuming
if os.path.exists(csv_file_path):
    print(f"'{csv_file_path}' exists. Reading completed datasets...")
    with open(csv_file_path, "r", newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)
        for row in reader:
            if row:
                completed_datasets.add(row[0])
    print(f"Found {len(completed_datasets)} completed datasets.")

# 2. If the file doesn't exist, create it and write the header
else:
    with open(csv_file_path, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)

        # Write the header
        writer.writerow(
            [
                "dataset",
                "model",
                "eval_metrics/MSE[mean]",
                "eval_metrics/MSE[0.5]",
                "eval_metrics/MAE[0.5]",
                "eval_metrics/MASE[0.5]",
                "eval_metrics/MAPE[0.5]",
                "eval_metrics/sMAPE[0.5]",
                "eval_metrics/MSIS",
                "eval_metrics/RMSE[mean]",
                "eval_metrics/NRMSE[mean]",
                "eval_metrics/ND[0.5]",
                "eval_metrics/mean_weighted_sum_quantile_loss",
                "domain",
                "num_variates",
            ]
        )

for ds_num, ds_name in enumerate(all_datasets):
    ds_key = ds_name.split("/")[0]
    print(f"Processing dataset: {ds_name} ({ds_num + 1} of {len(all_datasets)})")
    terms = ["short", "medium", "long"]
    for term in terms:
        if (
            term == "medium" or term == "long"
        ) and ds_name not in med_long_datasets.split():
            continue

        if "/" in ds_name:
            ds_key = ds_name.split("/")[0]
            ds_freq = ds_name.split("/")[1]
            ds_key = ds_key.lower()
            ds_key = pretty_names.get(ds_key, ds_key)
        else:
            ds_key = ds_name.lower()
            ds_key = pretty_names.get(ds_key, ds_key)
            ds_freq = dataset_properties_map[ds_key]["frequency"]
        ds_config = f"{ds_key}/{ds_freq}/{term}"

        if ds_config in completed_datasets:
            print(f"Skipping already completed dataset: {ds_config}")
            continue

        # Initialize the dataset
        to_univariate = (
            False
            if Dataset(name=ds_name, term=term, to_univariate=False).target_dim == 1
            else True
        )
        dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate)
        season_length = get_seasonality(dataset.freq)
        print(f"Dataset size: {len(dataset.test_data)}")
        predictor = KairosPredictor(
            model_path=model_path,
            prediction_length=dataset.prediction_length,
        )
        # Measure the time taken for evaluation
        res = evaluate_model(
            predictor,
            test_data=dataset.test_data,
            metrics=metrics,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=season_length,
        )

        # Append the results to the CSV file
        with open(csv_file_path, "a", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(
                [
                    ds_config,
                    model_name,
                    res["MSE[mean]"][0],
                    res["MSE[0.5]"][0],
                    res["MAE[0.5]"][0],
                    res["MASE[0.5]"][0],
                    res["MAPE[0.5]"][0],
                    res["sMAPE[0.5]"][0],
                    res["MSIS"][0],
                    res["RMSE[mean]"][0],
                    res["NRMSE[mean]"][0],
                    res["ND[0.5]"][0],
                    res["mean_weighted_sum_quantile_loss"][0],
                    dataset_properties_map[ds_key]["domain"],
                    dataset_properties_map[ds_key]["num_variates"],
                ]
            )

        print(f"Results for {ds_name} have been written to {csv_file_path}")

Processing dataset: m4_weekly (1 of 2)
Dataset size: 359
prediction_length: 13
Using device: cuda


2it [00:01,  1.86it/s]
359it [00:06, 53.83it/s]
  res["MSE[mean]"][0],
  res["MSE[0.5]"][0],
  res["MAE[0.5]"][0],
  res["MASE[0.5]"][0],
  res["MAPE[0.5]"][0],
  res["sMAPE[0.5]"][0],
  res["MSIS"][0],
  res["RMSE[mean]"][0],
  res["NRMSE[mean]"][0],
  res["ND[0.5]"][0],
  res["mean_weighted_sum_quantile_loss"][0],


Results for m4_weekly have been written to ../results/Kairos_50m/all_results.csv
Processing dataset: bizitobs_l2c/H (2 of 2)
Dataset size: 42
prediction_length: 48
Using device: cuda


1it [00:00,  4.22it/s]
42it [00:00, 57.98it/s]


Results for bizitobs_l2c/H have been written to ../results/Kairos_50m/all_results.csv
Dataset size: 7
prediction_length: 480
Using device: cuda


1it [00:00,  1.06it/s]
7it [00:00, 54.59it/s]


Results for bizitobs_l2c/H have been written to ../results/Kairos_50m/all_results.csv
Dataset size: 7
prediction_length: 720
Using device: cuda


1it [00:01,  1.35s/it]
7it [00:00, 53.93it/s]

Results for bizitobs_l2c/H have been written to ../results/Kairos_50m/all_results.csv





## Optional: Sort by lexicographic order, frequency, and predicted length of the dataset

In [12]:
def process_and_sort_csv(PATH):
    """
    Sort by lexicographic order, frequency, and predicted length of the dataset

    Args:
        PATH (str): The directory path containing the CSV file.
    """
    input_filepath=f'{PATH}/all_results.csv'
    output_filepath=f'{PATH}/all_results.csv'
    # Check if the input file exists
    if not os.path.exists(input_filepath):
        print(f"Error: File not found '{input_filepath}'. Please ensure the file exists and the path is correct.")
        return

    # 1. Load the CSV file
    df = pd.read_csv(input_filepath)
    print("Successfully loaded file.")

    # 2. Define the function for sorting
    def get_sort_key(dataset_name):
        """
        Generates a tuple for sorting based on the dataset string.
        Sorting rules:
        1. Before the first '/': Lexicographical order
        2. Before the second '/': By number, then by S, T, H, D, W, M order
        3. After the second '/': By short, medium, long order
        """
        try:
            part1, part2, part3 = dataset_name.split('/')
        except ValueError:
            # Handle strings that do not match the 'part1/part2/part3' format, placing them at the end
            return (dataset_name, float('inf'), float('inf'), float('inf'))

        # Define the sort mapping for the second part
        time_unit_map = {'S': 0, 'T': 1, 'H': 2, 'D': 3, 'W': 4, 'M': 5}
        # Use regex to separate the number and the unit
        match = re.match(r'(\d*)(\w+)', part2)
        if match:
            num_str, unit = match.groups()
            num = int(num_str) if num_str else 0 # Default to 0 if no number prefix
            unit_order = time_unit_map.get(unit, float('inf'))
        else:
            num, unit_order = float('inf'), float('inf') # Place at the end if format doesn't match

        # Define the sort mapping for the third part
        length_map = {'short': 0, 'medium': 1, 'long': 2}
        length_order = length_map.get(part3, float('inf')) # Place at the end if not found

        # Return a tuple; Pandas will sort by the elements of the tuple in order
        return (part1, unit_order, num, length_order)
    
    # 3. Apply the sorting logic
    # Create a temporary column to store the sort keys
    df['sort_key'] = df['dataset'].apply(get_sort_key)

    # Sort by this temporary column, then drop it
    df_sorted = df.sort_values(by='sort_key').drop(columns='sort_key')
    print("Data sorting complete.")

    # 4. Save the processed data to a new file
    df_sorted.to_csv(output_filepath, index=False)
    print(f"Processing complete! Results saved to '{output_filepath}'.")
    print("\nPreview of the first 5 rows of the processed data:")
    print(df_sorted.head())

process_and_sort_csv(output_dir)

Successfully loaded file.
Data sorting complete.
Processing complete! Results saved to '../results/Kairos_50m/all_results.csv'.

Preview of the first 5 rows of the processed data:
                 dataset       model  eval_metrics/MSE[mean]  \
1   bizitobs_l2c/H/short  Kairos_50m               57.918098   
2  bizitobs_l2c/H/medium  Kairos_50m               63.733904   
3    bizitobs_l2c/H/long  Kairos_50m               76.498069   
0      m4_weekly/W/short  Kairos_50m           350282.823003   

   eval_metrics/MSE[0.5]  eval_metrics/MAE[0.5]  eval_metrics/MASE[0.5]  \
1              58.555671               4.736005                0.480060   
2              63.673946               4.684196                0.474470   
3              76.727667               5.158500                0.556074   
0          342247.636980             303.264569                2.431396   

   eval_metrics/MAPE[0.5]  eval_metrics/sMAPE[0.5]  eval_metrics/MSIS  \
1                0.451625                 0.553508