# Distributional Metrics Analysis

This notebook analyzes and visualizes distributional metrics across multiple context windows and metrics computation runs.

## Overview

The notebook processes distributional metrics (Hellinger distance and KL divergence) from multiple evaluation result directories, accumulates them across different prediction windows, and generates comparative visualizations and statistics.

## Key Features

### 1. **Multi-Directory Data Loading**
- Loads metrics from multiple evaluation directories specified in `eval_results_dir_lst`
- Automatically discovers the metric files in each directory, sorts them by context-window index (the `window-N` part of the filename), and averages each metric across all context windows.

### 2. **Models Analyzed**
- **Panda**: 21M parameter Panda model
- **Chronos 20M SFT**: Fine-tuned Chronos 20M model
- **Chronos 20M**: Chronos 20M model
- **Chronos 200M**: Chronos 200M model
- **DynaMix**: DynaMix model

### 3. **Metrics Computed**
- **Average Hellinger Distance**: Average per-dimension spectral Hellinger distance, using the implementation from the DynaMix authors.
- **KL Divergence**: Geometric misalignment, using the implementation from the DynaMix authors.
- **Prediction Time**: Computational efficiency metrics (currently not written to, so will always be 0.0).

### 4. **Analysis Types**
- Distribution histograms for individual metrics
- Model-to-model comparisons (differences in KLD and Hellinger distance)
- Statistics across multiple prediction horizons (512, 1024, 2048, 3072, 3584)
- Support for both prediction horizon and full trajectory evaluations

## Configuration

Key parameters:
- `eval_results_dir_lst`: List of directories containing evaluation results
- `data_split`: Data split to analyze (default: "test_zeroshot")
- `run_name`: Optional run suffix for organizing results (default: "fdiv")
- `use_chronos_deterministic`: Flag for deterministic vs non-deterministic Chronos results

## Output

- Figures saved to `../../figures/eval_metrics/`
- Statistical summaries printed to stdout
- Mean ± standard deviation for all metrics across prediction horizons


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from panda.utils.plot_utils import apply_custom_style

apply_custom_style("../../config/plotting.yaml")

In [None]:
# Directory that the savefig calls in this notebook write into; created if it does not exist yet.
fig_save_dir = os.path.join("../../figures", "eval_metrics")
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Colors of matplotlib's current axes property cycle; the histogram cells index into this list per model.
DEFAULT_COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Root directory for data and evaluation results, read from the WORK environment variable.
# When WORK is unset this falls back to "", which makes the paths below relative.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")
# eval_results_dir = os.path.join(WORK_DIR, "eval_results_distributional_long")
# Every directory listed here is scanned for metric files; comment entries in or out to change the set.
eval_results_dir_lst = [
    os.path.join(WORK_DIR, "eval_results_distributional_long"),
    os.path.join(WORK_DIR, "eval_results_distributional_3072"),
    # os.path.join(WORK_DIR, "old_eval_results_backup/eval_results_distributional_longest"),
    # os.path.join(WORK_DIR, "old_eval_results_backup/eval_results_distributional"),
]
# Path components appended after each model's directory when locating its metric files.
data_split = "test_zeroshot"
run_name = "fdiv"
# run_name = "fdiv_kld-gmm"

In [None]:
# Selects which Chronos results folder is read: deterministic or non-deterministic runs.
use_chronos_deterministic = True
if use_chronos_deterministic:
    chronos_dirname = "chronos"
else:
    chronos_dirname = "chronos_nondeterministic"

print(f"Using {chronos_dirname} for chronos metrics")


def get_sorted_metric_fnames(save_dir):
    """List the distributional-metrics JSON files in ``save_dir``, ordered by window index.

    A file qualifies when its name ends with ".json" and contains
    "distributional_metrics". Files are ordered by the integer following
    "window-" in the name; names without such a tag sort last.

    Args:
        save_dir: Directory to scan.

    Returns:
        The sorted list of matching file names (not full paths), or an empty
        list when ``save_dir`` does not exist.
    """
    if not os.path.exists(save_dir):
        return []

    def window_index(name):
        match = re.search(r"window-(\d+)", name)
        if match is None:
            # No window tag: push the file to the end of the ordering.
            return float("inf")
        return int(match.group(1))

    candidates = []
    for entry in os.listdir(save_dir):
        if "distributional_metrics" in entry and entry.endswith(".json"):
            candidates.append(entry)
    candidates.sort(key=window_index)
    return candidates


# A falsy run_name (None or "") becomes an empty final path component.
run_suffix = run_name or ""

# First two path components of each model's results folder, relative to an eval results directory.
# The data split and run suffix are appended to every entry below.
_model_subdirs = {
    "Panda": ("panda", "panda-21M"),
    "Chronos 20M SFT": (chronos_dirname, "chronos_t5_mini_ft-0"),
    "Chronos 20M": (chronos_dirname, "chronos_mini_zeroshot"),
    "Chronos 200M": (chronos_dirname, "chronos_base_zeroshot"),
    "DynaMix": ("dynamix", "dynamix"),
}
model_path_templates = {
    display_name: (*subdirs, data_split, run_suffix) for display_name, subdirs in _model_subdirs.items()
}
model_run_names = list(model_path_templates)

# For every model: one {"save_dir", "fnames"} entry per eval results directory that holds metric files.
metrics_dirs_and_fnames = {display_name: [] for display_name in model_run_names}

for eval_results_dir in eval_results_dir_lst:
    print(f"\n{'=' * 80}")
    print(f"Processing eval_results_dir: {eval_results_dir}")
    print(f"{'=' * 80}")

    for model_name, path_parts in model_path_templates.items():
        save_dir = os.path.join(eval_results_dir, *path_parts)
        print(f"Loading {model_name} metrics from: {save_dir}")
        found_fnames = get_sorted_metric_fnames(save_dir)

        if not found_fnames:
            print("  No metrics files found")
            continue

        metrics_dirs_and_fnames[model_name].append(dict(save_dir=save_dir, fnames=found_fnames))
        print(f"  Found {len(found_fnames)} files: {found_fnames}")

# Report how many files and directories were collected per model.
print(f"\n{'=' * 80}")
print("Summary of collected metrics:")
print(f"{'=' * 80}")
for model_name in model_run_names:
    collected_entries = metrics_dirs_and_fnames[model_name]
    total_dirs = len(collected_entries)
    total_files = sum(len(entry["fnames"]) for entry in collected_entries)
    print(f"{model_name}: {total_files} files from {total_dirs} director(ies)")

In [None]:
# Peek at a single metrics file: the first file of the first directory collected for Panda.
first_dir_entry = metrics_dirs_and_fnames["Panda"][0]
example_dir, example_fnames = first_dir_entry["save_dir"], first_dir_entry["fnames"]
metrics_fpath = os.path.join(example_dir, example_fnames[0])
with open(metrics_fpath, "rb") as f:
    raw_metrics = json.load(f)

# JSON object keys are always strings; turn them back into integers.
metrics = {}
for key, value in raw_metrics.items():
    metrics[int(key)] = value

print(metrics.keys())

In [None]:
def accumulate_metrics(dirs_and_fnames_list):
    """Average the distributional metrics of one model over all of its metric files.

    Every file is a JSON object whose keys are prediction intervals (converted
    to ``int``) and whose values are iterated as ``(system_name, system_entry)``
    pairs. For each horizon present in a ``system_entry`` the average Hellinger
    distance and the KL divergence are collected, then averaged per
    (horizon, prediction interval, system) over every file in which the system
    appears.

    Args:
        dirs_and_fnames_list: List of dicts with 'save_dir' and 'fnames' keys.

    Returns:
        A dict with the keys "avg_hellinger" and "kld". Each maps
        horizon -> prediction interval -> system name -> mean value as a float,
        or None when every collected value for that system was None.
    """
    horizon_keys = ("prediction_horizon", "full_trajectory")
    metric_keys = ("avg_hellinger_distance", "kl_divergence")

    # collected[metric][horizon][pred_interval][system_name] -> list of raw values
    collected = {}
    for metric_key in metric_keys:
        collected[metric_key] = {
            horizon_key: defaultdict(lambda: defaultdict(list)) for horizon_key in horizon_keys
        }

    for entry in dirs_and_fnames_list:
        source_dir, file_names = entry["save_dir"], entry["fnames"]

        print(f"\n  Processing directory: {source_dir}")

        for file_name in file_names:
            with open(os.path.join(source_dir, file_name), "rb") as handle:
                loaded = json.load(handle)

            by_interval = {}
            for raw_key, payload in loaded.items():
                by_interval[int(raw_key) if isinstance(raw_key, str) else raw_key] = payload

            print(f"    Processing {file_name}: {len(by_interval)} prediction interval(s)")

            for interval, systems in by_interval.items():
                for system_name, system_entry in tqdm(systems, desc=f"    Interval {interval}"):
                    for horizon_key in horizon_keys:
                        if horizon_key not in system_entry:
                            continue
                        for metric_key in metric_keys:
                            collected[metric_key][horizon_key][interval][system_name].append(
                                system_entry[horizon_key][metric_key]
                            )

    def average_per_system(per_horizon):
        """Reduce each system's list of values to its mean, ignoring None entries."""
        averaged = {horizon_key: defaultdict(dict) for horizon_key in horizon_keys}
        for horizon_key in horizon_keys:
            for interval, systems in per_horizon[horizon_key].items():
                for system_name, samples in systems.items():
                    kept = [sample for sample in samples if sample is not None]
                    if kept:
                        averaged[horizon_key][interval][system_name] = float(np.mean(kept))
                    else:
                        averaged[horizon_key][interval][system_name] = None
        return averaged

    return {
        "avg_hellinger": average_per_system(collected["avg_hellinger_distance"]),
        "kld": average_per_system(collected["kl_divergence"]),
    }


for header_line in (f"\n{'=' * 80}", "Accumulating metrics for all models:", f"{'=' * 80}"):
    print(header_line)

# Model display name -> accumulated metrics, as returned by accumulate_metrics.
metrics_by_modelname = {}
for model_name in model_run_names:
    print(f"\nAccumulating {model_name} metrics...")
    metrics_by_modelname[model_name] = metrics = accumulate_metrics(metrics_dirs_and_fnames[model_name])
    print(f"Completed {model_name} metrics accumulation")

In [None]:
# Re-index from model -> metric to metric -> model so that cells can look up one metric across models.
metrics = {}
for metric_key in ["avg_hellinger", "kld"]:
    metrics[metric_key] = {m: metrics_by_modelname[m][metric_key] for m in model_run_names}

In [None]:
# Prediction intervals available for Panda (keys of its per-interval dict); shown as the cell output.
metrics["avg_hellinger"]["Panda"]["prediction_horizon"].keys()

In [None]:
# Sanity check: count missing (None) and NaN average Hellinger values for Panda at interval 3584.
panda_hellinger = metrics["avg_hellinger"]["Panda"]["prediction_horizon"][3584]
values = list(panda_hellinger.values())
present = [v for v in values if v is not None]
num_nones = len(values) - len(present)
num_nans = sum(np.isnan(v) for v in present)
print(f"Number of None values: {num_nones}")
print(f"Number of NaN values: {num_nans}")
print(f"Number of values: {len(values)}")

In [None]:
# Sanity check: count missing (None) and NaN KL divergence values for Panda at interval 3584.
panda_kld = metrics["kld"]["Panda"]["prediction_horizon"][3584]
values = list(panda_kld.values())
present = [v for v in values if v is not None]
num_nones = len(values) - len(present)
num_nans = sum(np.isnan(v) for v in present)
print(f"Number of None values: {num_nones}")
print(f"Number of NaN values: {num_nans}")
print(f"Number of values: {len(values)}")

In [None]:
# Prediction interval and evaluation horizon used by the Hellinger histogram in this cell.
pred_length = 3584
horizon_name = "prediction_horizon"

# Which Chronos variants the Hellinger histogram draws.
show_chronos_zs = False
show_chronos_sft = True


def filter_nans(values):
    """Return the usable entries of ``values`` as a 1-D float array.

    Entries that are None or NaN are dropped. The NaN test is applied after the
    conversion to ``float``, so NaNs carried by any numeric type (for example
    ``np.float32``) are removed, not only those that are ``float`` instances.

    Args:
        values: Iterable of numbers and/or None.

    Returns:
        ``np.ndarray`` of dtype float holding the remaining values in their
        original order; empty when nothing remains.
    """
    kept = []
    for v in values:
        if v is None:
            continue
        as_float = float(v)
        if np.isnan(as_float):
            continue
        kept.append(as_float)
    return np.array(kept, dtype=float)


# Valid average Hellinger distances per model at the selected horizon and prediction length.
avg_hellinger = {
    model_key: filter_nans(metrics["avg_hellinger"][model_key][horizon_name][pred_length].values())
    for model_key in model_run_names
}

# First four colors of the default cycle plus a light pink for the fifth model.
colors = [*DEFAULT_COLORS[:4], "#FFB5B8"]
print(colors)

num_bins = 50
plt.figure(figsize=(4, 4))
all_hellinger = np.concatenate(list(avg_hellinger.values()))
print(f"min hellinger: {all_hellinger.min()}, max hellinger: {all_hellinger.max()}")
# Shared bin edges so that all models are histogrammed on the same grid.
bins = np.histogram_bin_edges(all_hellinger, bins=num_bins)

# Increase hatch linewidth so it appears in PDF
plt.rcParams["hatch.linewidth"] = 2.0

alpha_val = 0.6

# Labels that the show_* switches exclude from the plot.
hidden_labels = set()
if not show_chronos_zs:
    hidden_labels.update(["Chronos 200M", "Chronos 20M"])
if not show_chronos_sft:
    hidden_labels.add("Chronos 20M SFT")

for i, (label, vals) in enumerate(avg_hellinger.items()):
    if label in hidden_labels:
        continue

    # The color index follows the model's position in the dict, including hidden models.
    shared_hist_kwargs = dict(
        bins=bins,
        color=colors[i],
        edgecolor=colors[i],
        alpha=alpha_val,
        zorder=10 - i,
        histtype="stepfilled",
        label=label,
    )
    if label != "DynaMix":
        plt.hist(vals, **shared_hist_kwargs)
        continue

    # For hatches to appear in PDF, the hatch is set on the returned patches by hand.
    n, bins_edges, patches = plt.hist(vals, linewidth=2, **shared_hist_kwargs)
    for patch in patches:
        patch.set_hatch("////")
        patch.set_edgecolor("hotpink")  # Use hotpink for hatch visibility

plt.ylabel("Count", fontweight="bold")
plt.legend(loc="upper right", frameon=True, fontsize=10)
plt.title(
    f"Avg Hellinger ($L_{{\\mathrm{{pred}}}} = {pred_length}$) Last $2048$",
    fontweight="bold",
)
plt.tight_layout()
hellinger_fig_path = os.path.join(fig_save_dir, f"avg_hellinger_distribution_{horizon_name}_{pred_length}.pdf")
plt.savefig(hellinger_fig_path, bbox_inches="tight")
plt.show()

In [None]:
# Prediction interval and evaluation horizon used by the KL divergence histogram in this cell.
pred_length, horizon_name = 3584, "prediction_horizon"

# Whether the zero-shot Chronos models are drawn.
show_chronos_zs = False


def pos_vals(vals):
    """Return, as a list, the entries of ``vals`` that are not None and strictly greater than zero."""
    kept = []
    for candidate in vals:
        if candidate is None or not candidate > 0:
            continue
        kept.append(candidate)
    return kept


# Positive KL divergences per model; pos_vals drops None, NaN and non-positive values.
kld_dict = {}
for model_key in model_run_names:
    kld_dict[model_key] = pos_vals(metrics["kld"][model_key][horizon_name][pred_length].values())

all_kld_pos = np.concatenate(list(kld_dict.values()))
num_bins = 50
if len(all_kld_pos) > 0:
    # histogram_bin_edges returns num_bins + 1 edges, i.e. exactly num_bins bins, matching how the
    # Hellinger histogram is binned. It also widens the range when all values are identical, which
    # a plain linspace(min, max, ...) would turn into zero-width bins.
    bins = np.histogram_bin_edges(all_kld_pos, bins=num_bins)
else:
    bins = num_bins
    print("No positive values found")

plt.figure(figsize=(4, 4))

# Increase hatch linewidth so it appears in PDF
plt.rcParams["hatch.linewidth"] = 2.0

alpha_val = 0.6
# NOTE: `colors` is not defined in this cell; it is the list built in the Hellinger histogram cell.
for i, (label, vals) in enumerate(kld_dict.items()):
    if not show_chronos_zs and label in ["Chronos 200M", "Chronos 20M"]:
        continue

    # Also skip Chronos 200M to make the plot less crowded (applies even when show_chronos_zs is True)
    if label == "Chronos 200M":
        continue

    shared_hist_kwargs = dict(
        bins=bins,
        color=colors[i],
        edgecolor=colors[i],
        alpha=alpha_val,
        histtype="stepfilled",
        label=label,
        zorder=10 - i,
    )
    if label == "DynaMix":
        # For hatches to appear in PDF, we need to manually set the patches
        n, bins_edges, patches = plt.hist(vals, linewidth=2, **shared_hist_kwargs)
        # Apply hatch to each patch with contrasting edge color
        for patch in patches:
            patch.set_hatch("////")
            patch.set_edgecolor("hotpink")  # Use hotpink for hatch visibility
    else:
        plt.hist(vals, **shared_hist_kwargs)

plt.yscale("log")
plt.ylabel("Count", fontweight="bold")
plt.legend(loc="upper right", frameon=True, fontsize=10)
plt.title(
    f"KL Divergence ($L_{{\\mathrm{{pred}}}} = {pred_length}$) Last $2048$",
    fontweight="bold",
)
plt.tight_layout()
plt.savefig(os.path.join(fig_save_dir, f"kld_distribution_{horizon_name}_{pred_length}_log.pdf"), bbox_inches="tight")
plt.show()

In [None]:
pred_length = 3584
horizon_name = "prediction_horizon"

# Positive KL divergences for each model. Kept for inspection only: every list is filtered
# independently, so positions in one model's list do not correspond to the same system in another's.
full_kld_dict = {
    model_key: pos_vals(metrics["kld"][model_key][horizon_name][pred_length].values()) for model_key in model_run_names
}

# Compute the difference between Chronos SFT and Panda per system. Values are paired by system
# name (not by list position) so each difference compares the same system; a system is used only
# when both models have a positive, non-missing KL divergence for it.
sft_kld = metrics["kld"]["Chronos 20M SFT"][horizon_name][pred_length]
panda_kld = metrics["kld"]["Panda"][horizon_name][pred_length]
kld_diff = np.array(
    [
        sft_kld[system_name] - panda_kld[system_name]
        for system_name in sft_kld
        if system_name in panda_kld and len(pos_vals([sft_kld[system_name], panda_kld[system_name]])) == 2
    ],
    dtype=float,
)

plt.figure(figsize=(4, 4))
plt.hist(kld_diff, bins=30, color="gray", edgecolor="black", alpha=0.7, histtype="stepfilled")
plt.axvline(0, color="k", linestyle="dotted", linewidth=1.5)
plt.xlabel("$D_{{KL}}$ (Chronos SFT - Panda)", fontweight="bold")
plt.ylabel("Count", fontweight="bold")
plt.title(f"Difference in $D_{{KL}}$ ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
# Set the log scale before tight_layout so the layout is computed for the final tick labels.
plt.yscale("log")
plt.tight_layout()
plt.show()

In [None]:
# pred_lengths = [128, 256, 512, 1024]
pred_lengths = [512, 1024, 2048, 3072, 3584]
horizon_name = "prediction_horizon"
print(f"horizon_name: {horizon_name}")

# (label, minuend model, subtrahend model): the statistic is computed on model1 - model2 per system.
pairs = [
    ("Chronos 20M SFT - Panda", "Chronos 20M SFT", "Panda"),
    ("Chronos 20M - Chronos 20M SFT", "Chronos 20M", "Chronos 20M SFT"),
    ("Chronos 200M - Chronos 20M SFT", "Chronos 200M", "Chronos 20M SFT"),
    ("DynaMix - Panda", "DynaMix", "Panda"),
]
for label, key1, key2 in pairs:
    print(f"{label} KLD Diff:")
    for pred_length in pred_lengths:
        # Only the two models being compared are needed.
        kld1 = metrics["kld"][key1][horizon_name][pred_length]
        kld2 = metrics["kld"][key2][horizon_name][pred_length]

        # Pair values by system name so both sides of a difference refer to the same system, and
        # drop systems for which either value is missing (None) or NaN.
        valid_pairs = [
            (kld1[system_name], kld2[system_name])
            for system_name in kld1
            if system_name in kld2
            and kld1[system_name] is not None
            and kld2[system_name] is not None
            and not np.isnan(kld1[system_name])
            and not np.isnan(kld2[system_name])
        ]
        if valid_pairs:
            diff = np.array([v1 - v2 for v1, v2 in valid_pairs])
            print(f"  Prediction length {pred_length}: (mean ± std) = {diff.mean():.2f} ± {diff.std():.2f}")
        else:
            print(f"  Prediction length {pred_length}: No valid data")

In [None]:
# pred_lengths = [128, 256, 512, 1024]
pred_lengths = [512, 1024, 2048, 3072, 3584]
horizon_name = "prediction_horizon"
print(f"horizon_name: {horizon_name}")

# (label, minuend model, subtrahend model): the statistic is computed on model1 - model2 per system.
pairs = [
    ("Chronos 20M SFT - Panda", "Chronos 20M SFT", "Panda"),
    ("Chronos 20M - Chronos 20M SFT", "Chronos 20M", "Chronos 20M SFT"),
    ("Chronos 200M - Chronos 20M SFT", "Chronos 200M", "Chronos 20M SFT"),
    ("DynaMix - Panda", "DynaMix", "Panda"),
]
for label, key1, key2 in pairs:
    print(f"{label} Hellinger Diff:")
    for pred_length in pred_lengths:
        # Only the two models being compared are needed.
        hellinger1 = metrics["avg_hellinger"][key1][horizon_name][pred_length]
        hellinger2 = metrics["avg_hellinger"][key2][horizon_name][pred_length]

        # Pair values by system name so both sides of a difference refer to the same system, and
        # drop systems for which either value is missing (None) or NaN.
        valid_pairs = [
            (hellinger1[system_name], hellinger2[system_name])
            for system_name in hellinger1
            if system_name in hellinger2
            and hellinger1[system_name] is not None
            and hellinger2[system_name] is not None
            and not np.isnan(hellinger1[system_name])
            and not np.isnan(hellinger2[system_name])
        ]
        if valid_pairs:
            diff = np.array([v1 - v2 for v1, v2 in valid_pairs])
            print(f"  Prediction length {pred_length}: (mean ± std) = {diff.mean():.2f} ± {diff.std():.2f}")
        else:
            # Name the prediction length so the gap can be located.
            print(f"  Prediction length {pred_length}: No valid data")

In [None]:
# pred_lengths = [128, 256, 512, 1024]
pred_lengths = [512, 1024, 2048, 3072, 3584]
# horizon_names = ["prediction_horizon", "full_trajectory"]
horizon_names = ["prediction_horizon"]
separator = "-" * 100
for horizon_name in horizon_names:
    for banner_line in (separator, "KL Divergence", separator, f"horizon_name: {horizon_name}"):
        print(banner_line)
    for model_key in model_run_names:
        print(f"Model: {model_key}")
        for pred_length in pred_lengths:
            per_system_kld = metrics["kld"][model_key][horizon_name][pred_length]
            # Keep only systems with a usable value (neither None nor NaN).
            kld_values_filtered = [v for v in per_system_kld.values() if not (v is None or np.isnan(v))]
            if not kld_values_filtered:
                print(f"  Prediction length {pred_length}: No valid data")
                continue
            mean_kld, std_kld = np.mean(kld_values_filtered), np.std(kld_values_filtered)
            print(f"  Prediction length {pred_length}: (mean ± std) = {mean_kld:.2f} ± {std_kld:.2f}")

In [None]:
# pred_lengths = [128, 256, 512, 1024]
pred_lengths = [512, 1024, 2048, 3072, 3584]
# horizon_names = ["prediction_horizon", "full_trajectory"]
horizon_names = ["prediction_horizon"]
separator = "-" * 100
for horizon_name in horizon_names:
    for banner_line in (separator, "Avg Hellinger Distance", separator, f"horizon_name: {horizon_name}"):
        print(banner_line)
    for model_key in model_run_names:
        print(f"Model: {model_key}")
        for pred_length in pred_lengths:
            per_system_hell = metrics["avg_hellinger"][model_key][horizon_name][pred_length]
            # Keep only systems with a usable value (neither None nor NaN).
            hell_values_filtered = [v for v in per_system_hell.values() if not (v is None or np.isnan(v))]
            if not hell_values_filtered:
                print(f"  Prediction length {pred_length}: No valid data")
                continue
            mean_hell, std_hell = np.mean(hell_values_filtered), np.std(hell_values_filtered)
            print(f"  Prediction length {pred_length}: (mean ± std) = {mean_hell:.2f} ± {std_hell:.2f}")