# Q4 – Model Ranking by Average KL Divergence

This notebook summarizes model performance across multiple data-generating processes (DGPs) using **average KL divergence** as the primary evaluation metric. We rank models by their ability to align with true forecast distributions, across different forecast horizons and context lengths.

### Objective
- Rank models by average KL divergence over selected DGPs and forecast days.
- Compare performance across both price and return forecasts.
- Identify which architectures generalize best across heterogeneous settings.

### Evaluation Setup
- Models are evaluated on five representative DGPs: `gbm_low_vol`, `gbm_high_vol`, `t_garch`, `mixture_normal`, and `seasonal`.
- KL divergence is computed at Days 2, 12, and 22 (column indices 0, 10, and 20 of the return arrays, set in `selected_days`).
- Scores are averaged per model, DGP, and context length.
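The mapping from column indices to day labels, and the averaging over days, can be illustrated on a toy table. This is a minimal sketch: the KL values are made-up placeholders (not results), `model_a`/`model_b` are hypothetical names, and only the column names mirror those used later in the notebook.

```python
import pandas as pd

# Zero-based column indices used in the notebook and their "Day" labels
selected_days = [0, 10, 20]
day_labels = [f"Day {d + 2}" for d in selected_days]

# Toy per-day KL values (placeholders, not results)
df_kl = pd.DataFrame({
    "context_length": [22] * 6,
    "dgp_type": ["seasonal"] * 6,
    "model_name": ["model_a"] * 3 + ["model_b"] * 3,
    "day": day_labels * 2,
    "kl_divergence": [0.1, 0.2, 0.3, 1.0, 2.0, 3.0],
})

# One average per (context length, DGP, model), taken over the selected days
df_avg = (
    df_kl.groupby(["context_length", "dgp_type", "model_name"])["kl_divergence"]
    .mean()
    .reset_index(name="avg_kl")
)
print(df_avg)
```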

### Key Outputs
- 📄 Tables:
  - Average KL per model × DGP × context length (sorted)
- 📊 Plots:
  - Bar charts showing average KL across models (separately for prices and returns)

### Notes
- Both price and return targets are analyzed.
- Only models in `selected_model_names` are included.
- KL is always computed on **returns**, even for price forecasts (returns are derived before evaluation).
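Deriving returns from sampled price paths is a one-line array operation. The sketch below shows it on a toy array of shape `(n_samples, horizon)`. The `base_price` option is a hypothetical variant, not what the notebook does: it prepends the last observed price so the first forecast step also gets a return. Because it shifts which column corresponds to which day, it is not a drop-in replacement.

```python
import numpy as np

def prices_to_returns(price_paths, base_price=None):
    """Convert sampled price paths (n_samples, horizon) to simple returns.

    Without base_price, the first step has no predecessor, so the result has
    horizon - 1 columns. With base_price (hypothetical option), that price is
    prepended to every path and the result has horizon columns.
    """
    price_paths = np.asarray(price_paths, dtype=float)
    if base_price is not None:
        first_col = np.full((price_paths.shape[0], 1), base_price, dtype=float)
        price_paths = np.hstack([first_col, price_paths])
    # r_t = P_t / P_{t-1} - 1, computed column-wise for all sample paths at once
    return price_paths[:, 1:] / price_paths[:, :-1] - 1

paths = np.array([[100.0, 110.0, 99.0],
                  [100.0, 90.0, 99.0]])
print(prices_to_returns(paths))
```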

This analysis provides a global view of model quality, highlighting which architectures deliver consistently accurate forecasts across different temporal and statistical regimes.


In [1]:
# Packages
import pickle
import numpy as np
import pandas as pd
from pathlib import Path

from utils.evaluation import (
    compute_kl_divergence,
    dataframe_to_latex
)

from utils.plotting import (
    plot_model_comparison_bar_avg
)

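# Compatibility shim for the TimesFM 2.5 forecasts: pickles written under NumPy 2.x
# reference `numpy._core.numeric`, which does not exist in NumPy 1.x, so alias it
# to `numpy.core.numeric` before unpickling (presumably the cause of the issue).
import sys
import numpy.core.numeric as numeric
sys.modules['numpy._core.numeric'] = numeric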

### Models List

Models can be added or removed from the following list.
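Because the loader later keeps only runs whose `model_name` appears in this list, a misspelled entry silently drops a model. A set comparison catches that. The sketch below uses toy names; in the notebook, the "found" names would come from the run configurations before the model filter is applied.

```python
# Toy names for illustration; in the notebook these would be
# selected_model_names and the model names parsed from the run files
requested = ["chronos_model_tiny", "lag_llama_model", "toto_model"]
found = {"chronos_model_tiny", "toto_model", "tirex_model"}

missing = sorted(set(requested) - found)  # requested but no runs loaded
unused = sorted(found - set(requested))   # runs available but not requested
print("Selected but not found:", missing)
print("Found but not selected:", unused)
```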

In [2]:
# Selected Models for Analysis
selected_model_names = [
    "chronos_model_tiny",
    "chronos_model_mini",
    "chronos_model_base",
    "lag_llama_model",
    "moirai_model_small",
    "moirai_model_base",
    "moirai_model_small_2_0",   # NEW
    "moirai_model_small_1_1",   # NEW
    "moirai_model_base_1_1",    # NEW
    "toto_model",
    "tirex_model",
    "timesfm_model_small",
    "timesfm_model_large",
    "timesfm_model_2_5"        # NEW
]

In [3]:
# Paths and Setup
results_dir = Path("results_q4_model")
tables_dir = results_dir / "tables"
plots_dir = results_dir / "plots_model_comparison"
tables_dir.mkdir(parents=True, exist_ok=True)
plots_dir.mkdir(parents=True, exist_ok=True)

forecast_dir = Path("forecasts")
run_dir = Path("runfiles")
datasets_dir = Path("datasets")

selected_days = [0, 10, 20]
ordered_days = [f"Day {d+2}" for d in selected_days]
context_lengths = [22, 66, 252]

dgp_types_kl = ["gbm_low_vol", "gbm_high_vol", "t_garch", "mixture_normal", "seasonal"]

### Loading the Forecasts

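We load each forecast file together with its run configuration (model name, dataset, target type, and context length), skipping temperature-tuning runs and files that cannot be read.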
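The run files are read as plain `key = value` lines. The sketch below isolates that parsing step on a made-up run file: the keys match those the loader reads, but the exact file layout is an assumption. It uses `ast.literal_eval` rather than `eval`, so numbers, lists, and quoted strings are converted without executing arbitrary code, and any other value is kept as a plain string.

```python
import ast

def parse_run_config(text):
    """Parse 'key = value' lines into a dict (assumed run-file layout)."""
    config = {}
    for line in text.splitlines():
        if "=" not in line:
            continue
        key, value = (part.strip() for part in line.split("=", 1))
        try:
            # Python literals only: ints, floats, lists, quoted strings, True/False/None
            config[key] = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError):
            # Unquoted text such as `seasonal` is kept as a string
            config[key] = value.strip("\"'")
    return config

example = """model_name = "chronos_model_tiny"
dataset_name = seasonal
target_type = 'returns'
context_length = 22
"""
print(parse_run_config(example))
```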

In [4]:
# Load Forecasts
import ast

forecast_files = sorted(forecast_dir.glob("forecast_*.pkl"))
results = []

for forecast_file in forecast_files:
    run_name = forecast_file.stem
    run_file = run_dir / f"{run_name}.txt"
    if not run_file.exists():
        continue

    run_text = run_file.read_text()

    # Skip temperature-tuning runs
    if "temperature" in run_text:
        continue

    # Parse "key = value" lines; literal_eval converts numbers, lists and quoted
    # strings without executing code, otherwise the raw string is kept
    run_config = {}
    for line in run_text.splitlines():
        if "=" in line:
            key, value = [x.strip() for x in line.split("=", 1)]
            try:
                run_config[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError):
                run_config[key] = value.strip("\"'")

    # Each pickle holds (low, median, high, samples, base_price); skip unreadable files
    try:
        with open(forecast_file, "rb") as f:
            low, median, high, samples, base_price = pickle.load(f)
    except Exception:
        continue

    results.append({
        "run_name": run_name,
        "model_name": run_config["model_name"],
        "dgp_type": run_config["dataset_name"],
        "target_type": run_config["target_type"],
        "context_length": run_config["context_length"],
        "samples": samples,
        "low": low,
        "median": median,
        "high": high,
        "base_price": base_price
    })

# Filter Results by Selected Models
results = [r for r in results if r["model_name"] in selected_model_names]

price_results = [r for r in results if r["target_type"] == "prices"]
return_results = [r for r in results if r["target_type"] == "returns"]

In [5]:
print("Unique model names found in runfiles:")
for name in sorted(set(r["model_name"] for r in results)):
    print(f"'{name}'")

Unique model names found in runfiles:
'chronos_model_base'
'chronos_model_mini'
'chronos_model_tiny'
'lag_llama_model'
'moirai_model_base'
'moirai_model_base_1_1'
'moirai_model_small'
'moirai_model_small_1_1'
'moirai_model_small_2_0'
'timesfm_model_2_5'
'timesfm_model_large'
'timesfm_model_small'
'tirex_model'
'toto_model'


### Defining Functions

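We define two helper functions: `compute_kl_dataframe` computes the KL divergence between DGP returns and model returns for each run and selected day, and `save_avg_kl_table` averages these values over the selected days and saves one LaTeX table per context length.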
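The KL estimator itself, `compute_kl_divergence`, is imported from `utils.evaluation`, and its implementation is not shown here. For intuition, the sketch below is one common way to estimate KL(P || Q) from two sets of samples: a histogram plug-in estimator on shared bins with a small `eps` to keep empty bins finite. It is an assumption for illustration, not necessarily the method used in `utils.evaluation`; `kl_from_samples`, `bins`, and `eps` are hypothetical names.

```python
import numpy as np

def kl_from_samples(p_samples, q_samples, bins=50, eps=1e-10):
    """Histogram (plug-in) estimate of KL(P || Q) from two 1-D sample sets."""
    p_samples = np.asarray(p_samples, dtype=float)
    q_samples = np.asarray(q_samples, dtype=float)

    # Shared bin edges over the combined range so both histograms align
    lo = min(p_samples.min(), q_samples.min())
    hi = max(p_samples.max(), q_samples.max())
    if hi == lo:
        return 0.0  # every sample is the same constant
    edges = np.linspace(lo, hi, bins + 1)

    p_counts, _ = np.histogram(p_samples, bins=edges)
    q_counts, _ = np.histogram(q_samples, bins=edges)

    # Counts -> probabilities; eps avoids log(0) and division by zero in empty bins
    p = p_counts / p_counts.sum() + eps
    q = q_counts / q_counts.sum() + eps
    p /= p.sum()
    q /= q.sum()
    return float(np.sum(p * np.log(p / q)))

# Toy "true" returns versus a well-calibrated and an over-narrow forecast
rng = np.random.default_rng(0)
true_returns = rng.normal(0.0, 0.01, size=5000)
good_model = rng.normal(0.0, 0.01, size=5000)
narrow_model = rng.normal(0.0, 0.003, size=5000)
print(kl_from_samples(true_returns, good_model))
print(kl_from_samples(true_returns, narrow_model))
```

The narrow forecast scores far worse because it assigns almost no probability to regions where the true returns still have mass. With a plug-in estimator, the result also depends on the number of bins and on `eps`, so absolute values are only comparable under a fixed setup.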

In [6]:
# Compute KL Divergence
def compute_kl_dataframe(results_subset):
    """One row per (run, selected day): KL between DGP returns (p) and model returns (q)."""
    kl_rows = []

    for item in results_subset:
        if item["dgp_type"] not in dgp_types_kl:
            continue

        # KL is always evaluated on returns, so convert sampled price paths first
        model_returns = item["samples"]
        if item["target_type"] == "prices":
            model_returns = model_returns[:, 1:] / model_returns[:, :-1] - 1

        dgp_path = datasets_dir / f"{item['dgp_type']}_returns_paths.npy"
        if not dgp_path.exists():
            continue
        dgp_returns = np.load(dgp_path)

        for day_index in selected_days:
            try:
                p = dgp_returns[:, day_index]
                q = model_returns[:, day_index]
                kl = compute_kl_divergence(p, q)
            except Exception:
                # Skip days beyond the forecast horizon or failed KL estimates
                continue

            kl_rows.append({
                "context_length": item["context_length"],
                "dgp_type": item["dgp_type"],
                "model_name": item["model_name"],
                "day": f"Day {day_index + 2}",
                "kl_divergence": kl
            })

    return pd.DataFrame(kl_rows).round(4)

df_kl_prices = compute_kl_dataframe(price_results)
df_kl_returns = compute_kl_dataframe(return_results)

In [7]:
# Save Average KL Table
def save_avg_kl_table(df_kl, label):
    # Average KL over the selected days for each (context length, DGP, model)
    df_avg = (
        df_kl[df_kl["day"].isin(ordered_days)]
        .groupby(["context_length", "dgp_type", "model_name"])["kl_divergence"]
        .mean()
        .reset_index(name="avg_kl")
    )

    for context in context_lengths:
        df_context = df_avg[df_avg["context_length"] == context]

        # Multi-index (context, DGP, model) rows sorted by average KL for LaTeX formatting
        table = (
            df_context.set_index(["context_length", "dgp_type", "model_name"])
            .sort_values("avg_kl")
        )

        filename = f"q4_avg_kl_{label}_context{context}.tex"
        dataframe_to_latex(table, tables_dir / filename, preserve_index_order=True)

In [8]:
# Save Tables
save_avg_kl_table(df_kl_prices, "prices")
save_avg_kl_table(df_kl_returns, "returns")

### Plotting

Here we plot only the average-KL bar charts, one for price targets and one for return targets.
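The internals of `plot_model_comparison_bar_avg` (from `utils.plotting`) are not shown, so it may aggregate differently. The sketch below shows one plausible way to obtain a global ranking like the one plotted: average the per-DGP KL values for each model and sort ascending, since lower KL means a better rank. The data are toy placeholders and the model names are hypothetical.

```python
import pandas as pd

# Toy per-DGP averages for one context length (placeholders, not results)
df_avg = pd.DataFrame({
    "model_name": ["model_a", "model_a", "model_b", "model_b", "model_c", "model_c"],
    "dgp_type": ["seasonal", "t_garch"] * 3,
    "avg_kl": [0.2, 1.0, 0.1, 0.5, 2.0, 4.0],
})

# Average across DGPs, then sort ascending: the first entry is the best model
ranking = (
    df_avg.groupby("model_name")["avg_kl"]
    .mean()
    .sort_values()
)
print(ranking)
```

A plain mean lets a single very poor DGP dominate a model's score. Taking the median across DGPs, or averaging per-DGP ranks, would be more robust alternatives.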

In [9]:
# Generate Plots (Bar by Avg KL)
plot_model_comparison_bar_avg(df_kl_prices, plots_dir, "prices", selected_days)
plot_model_comparison_bar_avg(df_kl_returns, plots_dir, "returns", selected_days)

### Interpretation: Model Ranking by Average KL at Context Length 22

We rank models by their average KL divergence across five DGPs. Lower values indicate better alignment between model forecasts and true future distributions. Rankings are reported separately for price and return targets.
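To build intuition for KL magnitudes, it helps to look at a case with a closed form. The sketch below assumes that `compute_kl_divergence(p, q)` estimates KL(P || Q), with P the DGP and Q the model, following the argument order in the notebook. It also assumes Gaussian distributions, which holds only as a simplification: several DGPs here (for example `t_garch` and `mixture_normal`) are not Gaussian. For P = N(mu_p, sigma_p^2) and Q = N(mu_q, sigma_q^2), KL(P || Q) = log(sigma_q / sigma_p) + (sigma_p^2 + (mu_p - mu_q)^2) / (2 sigma_q^2) - 1/2.

```python
import math

def kl_gaussian(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form KL(P || Q) for P = N(mu_p, sigma_p^2) and Q = N(mu_q, sigma_q^2)."""
    return (
        math.log(sigma_q / sigma_p)
        + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
        - 0.5
    )

# Same mean; only the forecast spread differs from the true spread
for scale in [0.5, 0.8, 1.0, 1.25, 2.0]:
    kl = kl_gaussian(0.0, 1.0, 0.0, scale)
    print(f"sigma_q / sigma_p = {scale:4.2f} -> KL = {kl:.3f}")
```

Under this assumption, a forecast half as wide as the truth scores about 0.81, while one twice as wide scores about 0.32. KL(P || Q) therefore penalizes under-dispersed forecasts more heavily than over-dispersed ones.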

**Price Forecasts**

- The best-performing models are the three Chronos variants. Chronos tiny (0.04), base (0.05), and mini (0.08) top the ranking, driven by excellent performance on gbm_high_vol and seasonal DGPs.

- Lag-Llama performs competitively, especially on high- and moderate-volatility DGPs (avg KL 0.11–0.35). Toto and Moirai base are also strong contenders on t_garch and mixture_normal.

- Overall, Tirex, Toto, and Moirai base form the mid tier, with KL values of roughly 0.4–0.8 depending on the DGP. TimesFM large performs similarly.

- Lower-tier models show rising KL divergence across DGPs. TimesFM small and Moirai small often exceed 1.5, indicating unstable forecasts under short contexts.

- The worst scores come from Chronos mini and base on gbm_low_vol (KL ~4.9), and Moirai small on seasonal (KL ~5.1), revealing sharp degradation in specific low-volatility or structured regimes.

**Return Forecasts**

- Lag-Llama dominates return forecasting. It ranks first across all DGPs with near-zero KL (0.01–0.04), showing unmatched consistency and precision.

- Toto is also a top performer, particularly strong on gbm_low_vol and mixture_normal (KL ~0.04–0.06), though it struggles on t_garch (KL 4.62).

- Moirai small and base rank next, achieving KL < 0.2 on simpler DGPs like seasonal and gbm_high_vol, but degrade significantly on t_garch and mixture_normal.

- Tirex and TimesFM large maintain moderate performance, generally in the 0.2–0.5 KL range, but spike sharply on t_garch.

- TimesFM small is consistently worse than its large counterpart, particularly on gbm_low_vol and mixture_normal.

- Chronos performs worst by far. All three versions exceed KL 9.0 on most DGPs, with peaks above 15 on gbm_low_vol and mixture_normal. Chronos struggles across all return scenarios.

**Summary**

- For price forecasting, Chronos leads the ranking, especially on high-volatility processes.

- For return forecasting, Lag-Llama is by far the best, with Toto close behind on simpler DGPs.

- Chronos fails completely on return forecasts at short context.

- Moirai, Tirex, and TimesFM perform moderately, but are sensitive to both model scale and DGP volatility.

- t_garch remains the hardest DGP overall, exposing the weaknesses of nearly all models except Lag-Llama.
