In [None]:
# Automatically reload edited Python modules (e.g. the local `panda` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import ScalarFormatter
from sklearn.metrics import r2_score
from tqdm import tqdm

from panda.utils.plot_utils import apply_custom_style

# Apply the project-wide plot style from the shared YAML config; presumably this updates
# matplotlib rcParams (the color cycle read into DEFAULT_COLORS below comes from rcParams).
apply_custom_style("../../config/plotting.yaml")

In [None]:
# Output directory for figures saved by this notebook (relative to the notebook's working directory).
fig_save_dir = os.path.join("../../figures", "eval_metrics")

In [None]:
# Colors of the active matplotlib color cycle (after the custom style is applied); used for the histogram colors.
DEFAULT_COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Root of the evaluation outputs: taken from $WORK (empty string when unset, making these paths relative).
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")
eval_results_dir = os.path.join(WORK_DIR, "eval_results_distributional_3200")

# Repository assets, including the precomputed full-trajectory Lyapunov exponents.
ASSETS_DIR = "../../assets"
full_traj_lyap_r_fpath = os.path.join(ASSETS_DIR, "max_lyap_r_test_zeroshot.json")

# Evaluation split and run sub-directory to read metrics from.
data_split = "test_zeroshot"
run_name = "lyap"

### Process Saved Metrics

In [None]:
# Which Chronos results directory to read: "chronos" (deterministic runs) or
# "chronos_nondeterministic".
use_chronos_deterministic = False
if use_chronos_deterministic:
    chronos_dirname = "chronos"
else:
    chronos_dirname = "chronos_nondeterministic"

print(f"Using {chronos_dirname} for chronos metrics")


def get_sorted_metric_fnames(save_dir):
    """List the distributional-metrics JSON files in ``save_dir``, ordered by window index.

    A file qualifies when its name ends with ``.json`` and contains ``distributional_metrics``.
    Files are ordered by the integer ``N`` of a ``window-N`` fragment in the name; names
    without such a fragment sort last (keeping their ``os.listdir`` relative order).

    Returns an empty list when ``save_dir`` does not exist.
    """
    if not os.path.exists(save_dir):
        return []

    window_pattern = re.compile(r"window-(\d+)")

    def window_index(fname):
        match = window_pattern.search(fname)
        return float("inf") if match is None else int(match.group(1))

    candidates = [
        name
        for name in os.listdir(save_dir)
        if name.endswith(".json") and "distributional_metrics" in name
    ]
    candidates.sort(key=window_index)
    return candidates


run_suffix = run_name or ""


def _eval_dir(*parts):
    """Join eval_results_dir / <parts> / data_split / run_suffix."""
    return os.path.join(eval_results_dir, *parts, data_split, run_suffix)


# Per-model directories holding the "*distributional_metrics*.json" files.
metrics_save_dirs = {
    "Panda": _eval_dir("panda", "panda-21M"),
    "Chronos 20M SFT": _eval_dir(chronos_dirname, "chronos_t5_mini_ft-0"),
    "Chronos 20M": _eval_dir(chronos_dirname, "chronos_mini_zeroshot"),
    "Chronos 200M": _eval_dir(chronos_dirname, "chronos_base_zeroshot"),
    "Dynamix": _eval_dir("dynamix", "dynamix"),
}
model_run_names = list(metrics_save_dirs)

# model name -> metric file names sorted by window index ([] when the directory is missing)
metrics_fnames = {}
for model_name, save_dir in metrics_save_dirs.items():
    print(f"Loading {model_name} metrics from: {save_dir}")
    metrics_fnames[model_name] = get_sorted_metric_fnames(save_dir)
    print(f"Found {len(metrics_fnames[model_name])} {model_name} metrics files: {metrics_fnames[model_name]}")

In [None]:
def filter_none(values):
    """Return a new list with every ``None`` dropped (other falsy values are kept)."""
    return list(filter(lambda item: item is not None, values))


def accumulate_metrics_lyap(metrics_fnames, metrics_save_dir):
    """Aggregate max-Lyapunov-exponent metrics and prediction times over several metric files.

    Each JSON file in ``metrics_fnames`` (resolved relative to ``metrics_save_dir``) maps a
    prediction-interval key to an iterable of ``(system_name, system_entry)`` pairs. Each
    ``system_entry`` is read for ``["prediction_horizon"]["max_lyap_gt"]``,
    ``["prediction_horizon"]["max_lyap_pred"]``, ``["prediction_time"]`` and, when present,
    ``["full_trajectory"]["max_lyap_full_traj"]``.

    Values for the same (interval, system) are averaged across files, ignoring ``None``.

    Args:
        metrics_fnames: metric file names (e.g. from ``get_sorted_metric_fnames``).
        metrics_save_dir: directory containing those files.

    Returns:
        dict with
          - ``"max_lyap"``: ``{"gt" | "pred" | "full_trajectory": {interval: {system: float | None}}}``;
            each per-key mapping is a ``defaultdict(dict)``, so a missing interval reads as ``{}``.
          - ``"prediction_time"``: ``{system: float | None}``, averaged over all files and intervals.
    """
    lyap_keys = ("gt", "pred", "full_trajectory")
    # lyap_key -> pred_interval -> system_name -> list of values (one per file)
    max_lyap_accum = {key: defaultdict(lambda: defaultdict(list)) for key in lyap_keys}
    # system_name -> list of prediction times (across all files and intervals)
    prediction_time_accum = defaultdict(list)

    for fname in metrics_fnames:
        with open(os.path.join(metrics_save_dir, fname), "rb") as f:
            raw_metrics = json.load(f)

        # JSON object keys are always strings, but later cells index intervals with ints
        # (e.g. 3200). Convert numeric keys; keep a non-numeric key as-is instead of crashing.
        file_metrics = {}
        for raw_key, value in raw_metrics.items():
            try:
                file_metrics[int(raw_key)] = value
            except (TypeError, ValueError):
                file_metrics[raw_key] = value

        print(f"number of prediction intervals in {fname}: {len(file_metrics)}")
        for pred_interval, data in file_metrics.items():
            for system_name, system_entry in tqdm(data, desc=f"Processing {pred_interval}"):
                horizon = system_entry["prediction_horizon"]
                max_lyap_accum["gt"][pred_interval][system_name].append(horizon["max_lyap_gt"])
                max_lyap_accum["pred"][pred_interval][system_name].append(horizon["max_lyap_pred"])
                if "full_trajectory" in system_entry:
                    max_lyap_accum["full_trajectory"][pred_interval][system_name].append(
                        system_entry["full_trajectory"]["max_lyap_full_traj"]
                    )
                prediction_time_accum[system_name].append(system_entry["prediction_time"])

    # Mean across files for each (interval, system), skipping None; None when nothing remains.
    max_lyap = {key: defaultdict(dict) for key in lyap_keys}
    for key in lyap_keys:
        for pred_interval, per_system in max_lyap_accum[key].items():
            for system_name, values in per_system.items():
                filtered = filter_none(values)
                max_lyap[key][pred_interval][system_name] = float(np.mean(filtered)) if filtered else None

    prediction_time = {}
    for system_name, times in prediction_time_accum.items():
        filtered = filter_none(times)
        # Plain float (not np.float64) for consistency with the max_lyap values.
        prediction_time[system_name] = float(np.mean(filtered)) if filtered else None

    return {
        "max_lyap": max_lyap,
        "prediction_time": prediction_time,
    }


# Aggregate every model's per-window metric files into a single summary dict per model.
metrics_by_modelname = {}
for model_name, model_dir in metrics_save_dirs.items():
    print(f"Accumulating {model_name} metrics...")
    metrics_by_modelname[model_name] = metrics = accumulate_metrics_lyap(metrics_fnames[model_name], model_dir)
    print(f"Accumulated {model_name} metrics")

In [None]:
# Regroup per-model results by metric: metrics[metric_name][model_name] is the value returned by
# accumulate_metrics_lyap, for metric_name in {"max_lyap", "prediction_time"}.
# NOTE(review): no copy is made, so metrics["max_lyap"][m] IS metrics_by_modelname[m]["max_lyap"].
# The filtering loop below therefore also mutates metrics_by_modelname in place, and later cells
# read the filtered dicts back through metrics_by_modelname. It also makes this cell
# non-idempotent: re-running it finds nothing left to filter, so every nan_masks entry becomes None.
metrics = {
    k: {m: metrics_by_modelname[m][k] for m in metrics_save_dirs.keys()} for k in ["max_lyap", "prediction_time"]
}

# nan_masks[model][key][pred_interval] is None when no invalid values were found, otherwise a dict
# mapping each system name -> True if its value was a finite number (kept), False if dropped.
nan_masks = {
    model_name: {pred_key: {} for pred_key in ["gt", "pred", "full_trajectory"]}
    for model_name in metrics_save_dirs.keys()
}

# Remove None, NaN, and Inf values
for model_name in metrics_save_dirs.keys():
    for key in ["gt", "pred", "full_trajectory"]:
        if key in metrics["max_lyap"][model_name]:
            for pred_interval in metrics["max_lyap"][model_name][key]:
                system_dict = metrics["max_lyap"][model_name][key][pred_interval]
                # Count and filter None, NaN, and Inf values
                num_nones = sum(1 for v in system_dict.values() if v is None)
                num_nans = sum(1 for v in system_dict.values() if v is not None and np.isnan(v))
                num_infs = sum(1 for v in system_dict.values() if v is not None and np.isinf(v))
                if num_nones > 0 or num_nans > 0 or num_infs > 0:
                    print(
                        f"{model_name} - {key} - {pred_interval}: {num_nones} Nones, {num_nans} NaNs, {num_infs} Infs"
                    )
                    nan_masks[model_name][key][pred_interval] = {
                        s: v is not None and np.isfinite(v) for s, v in system_dict.items()
                    }
                else:
                    # Nothing invalid: no mask needed.
                    nan_masks[model_name][key][pred_interval] = None
                # Replace (in the shared dict, see NOTE above) with only the finite values.
                metrics["max_lyap"][model_name][key][pred_interval] = {
                    s: v for s, v in system_dict.items() if v is not None and np.isfinite(v)
                }

## Max Lyapunov Exponent Comparison

In [None]:
pred_length = 1024
model_type = "Panda"
use_full_traj_gt = False
show_figure = True

if use_full_traj_gt:
    gt_key, gt_key_name = "full_trajectory", "Full Trajectory"
else:
    gt_key, gt_key_name = "gt", "Ground Truth"
pred_key, pred_key_name = "pred", "Prediction"

gt_dict = metrics["max_lyap"][model_type][gt_key].get(pred_length, {})
pred_dict = metrics["max_lyap"][model_type][pred_key].get(pred_length, {})

# Pair reference and predicted exponents for systems present in both dicts.
system_names = set(gt_dict) & set(pred_dict)
paired = [(gt_dict[s], pred_dict[s]) for s in system_names]
x_raw = [gt_val for gt_val, _ in paired]
y_raw = [pred_val for _, pred_val in paired]

# Keep only pairs where both values are finite.
finite = [(xi, yi) for xi, yi in paired if np.isfinite(xi) and np.isfinite(yi)]
x = [xi for xi, _ in finite]
y = [yi for _, yi in finite]
num_invalid = len(paired) - len(finite)

# R^2 of the predictions against the reference (sklearn: y_true first).
if x and y:
    r2 = r2_score(x, y)
else:
    r2 = float("nan")
print(f"Filtered out {num_invalid} invalid (nan/inf) pairs from {len(x_raw)} total.")
print(f"{model_type}: {pred_key_name} vs {gt_key_name} at L_pred={pred_length}, R^2={r2:.3f}")

if show_figure:
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(x, y, color="black", s=5, alpha=0.1)
    ax.set_xlabel(gt_key_name, fontweight="bold")
    ax.set_ylabel(pred_key_name, fontweight="bold")
    ax.set_title(rf"{model_type} $\lambda_{{\max}}$ ($L_{{\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")

    # Reference diagonal y=x spanning the joint data range (unit range when there is no data).
    xy_min, xy_max = (min(x + y), max(x + y)) if (x and y) else (0, 1)
    (h1,) = ax.plot([xy_min, xy_max], [xy_min, xy_max], "r--", label=r"$y=x$")
    handles, labels = [h1], [r"$y=x$"]

    # Least-squares line of prediction on reference; listed first in the legend.
    if len(x) > 1 and len(y) > 1:
        m, b = np.polyfit(x, y, 1)
        x_fit = np.array([xy_min, xy_max])
        y_fit = m * x_fit + b
        if abs(b) < 1e-10:
            eqn_str = rf"$y = {m:.2f}x$"
        else:
            sign = "+" if b >= 0 else "-"
            eqn_str = rf"$y = {m:.2f}x {sign} {abs(b):.2f}$"
        eqn_r2_label = eqn_str if np.isnan(r2) else rf"{eqn_str}  $(R^2 = {r2:.3f})$"
        (h2,) = ax.plot(x_fit, y_fit, "r-", linewidth=1.5, label=eqn_r2_label)
        handles.insert(0, h2)
        labels.insert(0, eqn_r2_label)

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(ScalarFormatter(useMathText=True))
    ax.ticklabel_format(style="sci", axis="both", scilimits=(0, 0))
    ax.legend(handles=handles, labels=labels, loc="lower right", fontsize=8, frameon=True)
    fig.tight_layout()
    # To export: fig.savefig(os.path.join(fig_save_dir, f"max_lyap_r_full_pred_{pred_length}_{model_type}.pdf"), bbox_inches="tight")

    plt.show()

In [None]:
# Sanity check: count non-finite predicted exponents at L_pred=3200 for Dynamix and Panda.
for model_name in ("Dynamix", "Panda"):
    pred_values = np.array(list(metrics["max_lyap"][model_name]["pred"][3200].values()))
    print(f"{model_name}:")
    nan_count = np.isnan(pred_values).sum()
    print(f"  num nans in tmp: {nan_count}")
    inf_count = np.isinf(pred_values).sum()
    print(f"  num infs in tmp: {inf_count}")
    print(f"  total number of elements in tmp: {pred_values.size}")


In [None]:
pred_length = 3200
use_full_traj_gt = False
# Display the figure inline; when False the figure is still built and saved to PDF, then closed.
show_figure = True

gt_key, gt_key_name = ("full_trajectory", "Full Trajectory") if use_full_traj_gt else ("gt", "Ground Truth")
pred_key, pred_key_name = "pred", "Prediction"

# Models left out of this histogram, and models drawn with a hatch pattern.
excluded_model_types = {"Chronos 20M SFT", "Chronos 20M", "Chronos 200M"}
hatched_model_types = {"Dynamix"}

# Always start a fresh figure so nothing is drawn onto a figure left over from another cell.
fig = plt.figure(figsize=(4, 4))

# Increase hatch linewidth so it appears in PDF (global rcParam: also affects later figures).
plt.rcParams["hatch.linewidth"] = 2.0

model_types = list(metrics_by_modelname.keys())
x_min = 0
x_max = 8
bins = np.linspace(x_min, x_max, 40)
colors = DEFAULT_COLORS[:4] + ["#FFB5B8"]
print(colors)

# The ground-truth reference (gt_dict / gt_lyaps_to_plot) comes from the first plotted model.
# Later models (e.g. Dynamix, whose NaN predictions were filtered out) only share a subset of
# systems with it, and the GT histogram must not shrink to that subset.
gt_dict = None
gt_lyaps_to_plot = []
for i, model_type in enumerate(model_types):
    if model_type in excluded_model_types:
        continue
    print(model_type)

    is_reference_model = gt_dict is None
    if is_reference_model:
        gt_dict = metrics["max_lyap"][model_type][gt_key].get(pred_length, {})
    pred_dict = metrics["max_lyap"][model_type][pred_key].get(pred_length, {})
    system_names = set(gt_dict) & set(pred_dict)

    gt_lyaps = [gt_dict[s] for s in system_names]
    pred_lyaps = [pred_dict[s] for s in system_names]
    if is_reference_model:
        gt_lyaps_to_plot = gt_lyaps

    print(f"{model_type}: {pred_key_name} vs {gt_key_name} at L_pred={pred_length}")
    print(f"  {gt_key_name}: mean={np.mean(gt_lyaps):.3f}, std={np.std(gt_lyaps):.3f}")
    print(f"  {pred_key_name}: mean={np.mean(pred_lyaps):.3f}, std={np.std(pred_lyaps):.3f}")

    color = colors[i % len(colors)]
    hist_kwargs = dict(
        bins=bins,
        color=color,
        alpha=0.6,
        label=model_type,
        density=False,
        histtype="stepfilled",
        linewidth=2,
        zorder=5 - i,
    )
    if model_type in hatched_model_types:
        # For hatches to appear in PDF, set them on each patch manually. The hatch takes the
        # patch edge color, hence the contrasting edge color.
        patches = plt.hist(pred_lyaps, edgecolor=color, **hist_kwargs)[2]
        for patch in patches:
            patch.set_hatch("////")
            patch.set_edgecolor("hotpink")
    else:
        plt.hist(pred_lyaps, **hist_kwargs)

# Plot ground truth histogram (all systems of the reference model)
plt.hist(
    gt_lyaps_to_plot,
    bins=bins,
    color="black",
    alpha=0.8,
    label=gt_key_name,
    density=False,
    linestyle="--",
    histtype="step",
    linewidth=2,
    zorder=10,
)

plt.xlabel(r"$\lambda_{\max}$", fontweight="bold")
plt.ylabel("Count", fontweight="bold")
plt.title(rf"$\lambda_{{\max}}$ Distribution ($L_{{\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
plt.xlim(x_min, x_max)
plt.legend(loc="upper right", fontsize=10, frameon=True)
plt.yscale("log")
# Lay out after switching to log scale so the final tick labels are accounted for.
plt.tight_layout()

# savefig does not create missing directories.
os.makedirs(fig_save_dir, exist_ok=True)
save_path = os.path.join(fig_save_dir, f"max_lyap_r_pred_{pred_length}_comparison.pdf")
print(f"Saving to {save_path}")
plt.savefig(save_path, bbox_inches="tight", dpi=300)
if show_figure:
    plt.show()
else:
    plt.close(fig)


## Summary and Statistical Significance Tests

### Load Lyapunov Exponents of Full Trajectories

In [None]:
# Load the Rosenstein max Lyapunov exponents of the full trajectories.
# `with` closes the file deterministically instead of leaking the handle.
with open(full_traj_lyap_r_fpath) as fh:
    # NOTE(review): "4096" presumably selects the trajectory length the exponents were computed on — confirm.
    full_traj_lyap_r_lst = json.load(fh)["4096"]

In [None]:
# Number of entries loaded from the full-trajectory Lyapunov file (shown as cell output).
len(full_traj_lyap_r_lst)

In [None]:
# Map system name (entry[0]) -> Rosenstein max Lyapunov exponent of the full trajectory.
full_traj_lyap_r_dict = dict(
    (item[0], item[1]["max_lyap_rosenstein"]) for item in full_traj_lyap_r_lst
)
print(len(full_traj_lyap_r_dict))

In [None]:
# Make full_traj_lyap_r_dict follow the same key order as
# metrics_by_modelname["Dynamix"]["max_lyap"]["pred"][3200], keeping only systems present in both.
reordered_full_traj = {}
for name in metrics_by_modelname["Dynamix"]["max_lyap"]["pred"][3200]:
    if name in full_traj_lyap_r_dict:
        reordered_full_traj[name] = full_traj_lyap_r_dict[name]
full_traj_lyap_r_dict = reordered_full_traj

In [None]:
# Inspect which systems remain in the full-trajectory dict (shown as cell output).
full_traj_lyap_r_dict.keys()

In [None]:
# Systems with a Dynamix prediction at L_pred=3200, in dict insertion order.
example_metrics_system_names = [*metrics_by_modelname["Dynamix"]["max_lyap"]["pred"][3200]]
# Systems that also have a full-trajectory Rosenstein exponent.
overlap_keys = full_traj_lyap_r_dict.keys() & set(example_metrics_system_names)
print(len(overlap_keys))

### Predictions vs. Ground Truth

In [None]:
# Prediction interval(s) L_pred to evaluate below; must match the integer interval keys
# produced by accumulate_metrics_lyap.
pred_lengths = [3200]
# If True, compare predictions against the full-trajectory exponent ("full_trajectory" metrics)
# instead of the ground-truth prediction-horizon exponent ("gt").
use_full_traj_gt = False

In [None]:
# R^2 of each model's predicted max Lyapunov exponent against the reference exponent
# (sklearn r2_score with the reference as y_true), for every L_pred in pred_lengths.
if use_full_traj_gt:
    gt_key = "full_trajectory"
    gt_key_name = "Full Trajectory"
else:
    gt_key = "gt"
    gt_key_name = "Ground Truth"
# Defined here so this cell does not depend on a variable left over from a plotting cell.
pred_key = "pred"

print(f"(Prediction) vs {gt_key_name}")

for model_name in metrics_save_dirs.keys():
    for pred_length in pred_lengths:
        print(f"{model_name} | Prediction Length L_pred = {pred_length}")

        gt_dict = metrics["max_lyap"][model_name][gt_key].get(pred_length, {})
        pred_dict = metrics["max_lyap"][model_name][pred_key].get(pred_length, {})

        # Pair values by system name (only systems present in both dicts).
        system_names = set(gt_dict) & set(pred_dict)

        # Keep pairs where both values convert to finite floats; None, NaN and +/-Inf are invalid.
        x, y = [], []
        num_invalid = 0
        for system in system_names:
            try:
                xi, yi = float(gt_dict[system]), float(pred_dict[system])
            except (TypeError, ValueError):
                num_invalid += 1
                continue
            if np.isfinite(xi) and np.isfinite(yi):
                x.append(xi)
                y.append(yi)
            else:
                num_invalid += 1

        # NOTE: pearsonr could be swapped in here to report a Pearson correlation instead.
        r2 = r2_score(x, y) if x else float("nan")
        print(f"R^2={r2:.3f}")

In [None]:
from scipy.stats import pearsonr

# Prediction lengths to exclude from the table. Interval keys are ints (see
# accumulate_metrics_lyap); a comparison against the string "128" would never match, so nothing
# was being excluded. An empty set keeps that behaviour; add e.g. 128 here to drop that interval.
skip_pred_lengths = set()

results = defaultdict(dict)
for model_name in metrics_by_modelname.keys():
    print(f"Model: {model_name}")
    lyaps_for_model = metrics_by_modelname[model_name]["max_lyap"]["pred"]
    lyaps_for_gt = metrics_by_modelname[model_name]["max_lyap"]["gt"]
    pred_lengths = list(lyaps_for_model.keys())
    print(pred_lengths)
    for pred_length in pred_lengths:
        if pred_length in skip_pred_lengths:
            continue
        pred_by_system = lyaps_for_model[pred_length]
        gt_by_system = lyaps_for_gt.get(pred_length, {})

        # Pair gt and prediction by system name. Pairing by list position misaligns (or fails on a
        # shape mismatch) whenever the gt and pred dicts had different systems filtered out.
        paired_systems = [s for s in pred_by_system if s in gt_by_system]

        # Explicit float dtype avoids object arrays (None becomes NaN and is filtered below).
        gt_lyaps_array = np.array([gt_by_system[s] for s in paired_systems], dtype=np.float64)
        model_lyaps_array = np.array([pred_by_system[s] for s in paired_systems], dtype=np.float64)
        print(
            f"number of nans in (gt, model): {np.sum(np.isnan(gt_lyaps_array))}, {np.sum(np.isnan(model_lyaps_array))}"
        )
        print(f"shapes of (gt_lyaps_array, model_lyaps_array): {gt_lyaps_array.shape}, {model_lyaps_array.shape}")

        # Filter out nan and inf values before computing metrics
        valid_mask = np.isfinite(gt_lyaps_array) & np.isfinite(model_lyaps_array)
        gt_lyaps_filtered = gt_lyaps_array[valid_mask]
        model_lyaps_filtered = model_lyaps_array[valid_mask]

        print(f"Filtered to {len(gt_lyaps_filtered)} valid pairs from {len(gt_lyaps_array)} total")

        # Pearson correlation (with p-value) and R^2; pearsonr needs at least two points.
        if len(gt_lyaps_filtered) >= 2:
            pearson_result = pearsonr(gt_lyaps_filtered, model_lyaps_filtered)
            r2_result = r2_score(gt_lyaps_filtered, model_lyaps_filtered)
            print(f"pearson_result: {pearson_result}")
            print(f"r2_result: {r2_result}")
            results[model_name][pred_length] = {
                "pearson_corr": float(f"{pearson_result.statistic:.3f}"),  # type: ignore
                "pearson_corr pval": float(f"{pearson_result.pvalue:.3e}"),  # type: ignore
                "r2": float(f"{r2_result:.3f}"),
            }
        else:
            print("Fewer than 2 valid pairs after filtering")
            results[model_name][pred_length] = {
                "pearson_corr": float("nan"),
                "pearson_corr pval": float("nan"),
                "r2": float("nan"),
            }

pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)
pd.set_option("display.width", None)
pd.set_option("display.max_colwidth", None)
pd.DataFrame(results).T

In [None]:
# Raw nested results (model -> pred_length -> {"pearson_corr", "pearson_corr pval", "r2"}).
results

### Predictions vs. Full Trajectory

In [None]:
# Full-trajectory Rosenstein exponents as a list, in full_traj_lyap_r_dict key order.
full_lyaps = list(full_traj_lyap_r_dict.values())

In [None]:
# Max-Lyapunov metric groups available for Dynamix (gt / pred / full_trajectory).
metrics_by_modelname["Dynamix"]["max_lyap"].keys()

In [None]:
# NOTE(review): this cell is fully commented out, so "Predictions vs. Full Trajectory" currently
# produces no table. If re-enabled, fix these before trusting the numbers:
#   - `pred_length == "128"` never matches, because the pred_length keys are ints.
#   - masked_full_lyaps iterates over every key of the nan mask (True and False alike), so it
#     does not actually drop the systems whose prediction was invalid.
#   - model_lyaps and masked_full_lyaps are paired by list position, not by system name;
#     align both on one shared list of system names instead.
# from scipy.stats import pearsonr

# results = defaultdict(dict)
# for model_name in metrics_by_modelname.keys():
#     print(f"Model: {model_name}")
#     lyaps_for_model = metrics_by_modelname[model_name]["max_lyap"]["pred"]
#     for pred_length in lyaps_for_model.keys():
#         if pred_length == "128":  # TODO: remove when we have 128 predictions
#             continue
#         model_lyaps = lyaps_for_model[pred_length]
#         model_lyaps = [model_lyaps[sys] for sys in model_lyaps.keys()]
#         if nan_masks[model_name]["pred"][pred_length] is not None:
#             # mask full_lyaps with nan_masks[model_name]["pred"][pred_length]
#             # full_lyaps is a list, but we need to filter by system names
#             # Convert full_lyaps back to dict first
#             full_lyaps_dict = {sys: full_traj_lyap_r_dict[sys] for sys in full_traj_lyap_r_dict.keys()}
#             masked_full_lyaps = [full_lyaps_dict[sys] for sys in nan_masks[model_name]["pred"][pred_length].keys()]
#         else:
#             masked_full_lyaps = full_lyaps
#         assert len(model_lyaps) == len(masked_full_lyaps), f"{len(model_lyaps)} != {len(masked_full_lyaps)}"

#         # Convert to numpy arrays with explicit float dtype to avoid object dtype issues
#         full_lyaps_array = np.array(masked_full_lyaps, dtype=np.float64)
#         model_lyaps_array = np.array(model_lyaps, dtype=np.float64)

#         # measure correlation and wilcoxon signed rank test statistics and pvalues
#         result = pearsonr(full_lyaps_array, model_lyaps_array)
#         results[model_name][pred_length] = {
#             "corr": float(f"{result.statistic:.3f}"),  # type: ignore
#             "pval": float(f"{result.pvalue:.3e}"),  # type: ignore
#         }


# pd.DataFrame(results).T