In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle
from collections import defaultdict
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
from dysts.metrics import compute_metrics
from scipy.io import loadmat
from tqdm import tqdm

from panda.chronos.pipeline import ChronosPipeline
from panda.patchtst.pipeline import PatchTSTPipeline
from panda.utils.data_utils import safe_standardize
from panda.utils.plot_utils import DEFAULT_MARKERS, apply_custom_style

apply_custom_style("../config/plotting.yaml")

In [None]:
# Dataset root is $WORK/physics-datasets. If WORK is unset the prefix is the
# empty string, so base_dir silently becomes "/physics-datasets".
WORK = os.environ.get("WORK", "")
base_dir = f"{WORK}/physics-datasets"
# Colors of the active matplotlib property cycle; indexed below to give each
# model a fixed color.
DEFAULT_COLORS = list(plt.rcParams["axes.prop_cycle"].by_key()["color"])

In [None]:
# Metric names requested from `compute_metrics` in the electronic-circuit cells.
metrics = ["mse", "mae", "smape"]

# Output directory for the figures saved by the cells below (created if missing).
fig_dir = "../figures/realdata"
os.makedirs(fig_dir, exist_ok=True)

## Load Model Checkpoints

In [None]:
# PatchTST pipeline loaded in predict mode; its forecasts are labelled "Panda"
# in the comparison plots below.
# NOTE(review): the checkpoint path is absolute and machine-specific, and the
# device is pinned to cuda:2 -- both must be edited to run elsewhere.
# run_name = "pft_chattn_emb_w_poly-0"  # NOTE: this is still the best
run_name = "pft_polyfeats_fromscratch_repro-0"

pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final",
    device_map="cuda:2",
)

In [None]:
# Chronos pipeline loaded from a local checkpoint (labelled "Chronos 20M SFT"
# in the plots below). Note that `run_name` is rebound here.
# NOTE(review): absolute, machine-specific checkpoint path; device pinned to cuda:3.
run_name = "chronos_t5_mini_ft-0"
chronos_ft = ChronosPipeline.from_pretrained(
    f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final",
    device_map="cuda:3",
    torch_dtype=torch.float32,
)

# Keyword arguments splatted into `forecast` for this model: `transpose` is
# consumed by `forecast` itself, the remaining keys are forwarded to
# `chronos_ft.predict`.
chronos_ft_kwargs = {
    "transpose": True,
    "limit_prediction_length": False,
    "num_samples": 1,
    "deterministic": True,
}

In [None]:
# Chronos pipeline loaded by model name rather than from a local checkpoint
# (labelled "Chronos 20M" in the plots below). Device pinned to cuda:4.
run_name = "amazon/chronos-t5-mini"
chronos_zs = ChronosPipeline.from_pretrained(run_name, device_map="cuda:4", torch_dtype=torch.float32)

# Same settings as `chronos_ft_kwargs`: `transpose` is consumed by `forecast`,
# the remaining keys are forwarded to `chronos_zs.predict`.
chronos_zs_kwargs = {
    "transpose": True,
    "limit_prediction_length": False,
    "num_samples": 1,
    "deterministic": True,
}

## Forecast and Plot Utils

In [None]:
def forecast(
    model,
    context: np.ndarray,
    prediction_length: int,
    transpose: bool = False,
    standardize: bool = True,
    differenced: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Run `model.predict` on one context window and map the output back to data space.

    Args:
        model: Object exposing `predict(tensor, prediction_length, verbose=False, **kwargs)`
            whose return value supports `.squeeze().cpu().numpy()`.
        context: The context to forecast from, shape (n_timesteps, n_features).
        prediction_length: Number of steps requested; longer model outputs are truncated.
        transpose: If True the model is fed `context.T` and its output is transposed back.
        standardize: If True the model input is passed through `safe_standardize` along
            axis 0, and the prediction is de-normalized against the same series.
        differenced: If True the model sees first differences of `context`; the
            prediction is integrated with a cumulative sum starting at `context[-1]`.
        **kwargs: Forwarded unchanged to `model.predict`.

    Returns:
        The forecast, at most `prediction_length` rows (one column per feature when 2-D).
    """
    # Series the model actually sees; also the reference for de-normalization.
    reference = np.diff(context.copy(), axis=0) if differenced else context
    model_input = reference.copy()
    if standardize:
        model_input = safe_standardize(model_input, axis=0)
    if transpose:
        model_input = model_input.T

    context_tensor = torch.from_numpy(model_input).float()
    raw_output = model.predict(context_tensor, prediction_length, verbose=False, **kwargs)
    pred = raw_output.squeeze().cpu().numpy()
    if transpose:
        pred = pred.T

    # Undo the preprocessing in reverse order: de-normalize, then integrate.
    if standardize:
        pred = safe_standardize(pred, axis=0, context=reference, denormalize=True)
    if differenced:
        pred = np.cumsum(pred, axis=0) + context[-1]

    # prediction length may be shorter than model output length
    if pred.ndim == 2:
        return pred[:prediction_length, :]
    return pred[:prediction_length]


def compute_rollout_metrics(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    starts: np.ndarray | list[int] | None = None,
    num_windows: int | None = None,
    step: int = 64,
    metrics: list[str] = ["mse", "mae", "smape"],
    **kwargs,
) -> tuple[
    dict[str, np.ndarray],
    dict[str, np.ndarray],
    np.ndarray | list[int],
    list[np.ndarray],
]:
    """
    Forecast from several context windows of `data` and score each forecast chunk-wise.

    For every start index a context of `context_length` rows is passed to `forecast`;
    the prediction is split into consecutive chunks of `step` rows and each chunk is
    scored against the matching rows of `data` with `compute_metrics`.

    Args:
        model: Forwarded to `forecast`.
        data: Full series; windows are sliced along axis 0.
        context_length: Number of rows in each context window.
        prediction_length: Number of rows forecast after each context window.
        starts: Explicit window start indices. Mutually exclusive with `num_windows`.
        num_windows: Number of start indices to draw with `np.random.randint`
            (unseeded) when `starts` is not given.
        step: Chunk length used for scoring. If it does not divide
            `prediction_length`, the last chunk is shorter.
        metrics: Metric names passed to `compute_metrics(include=...)`. The default
            list is shared between calls but never mutated here.
        **kwargs: Forwarded to `forecast`.

    Returns:
        (mean_metrics, std_metrics, starts, predictions): per-metric arrays with one
        entry per chunk holding the mean and the standard error
        (std / sqrt(num_windows)) over windows, the start indices used, and the list
        of full predictions (one per window).

    Raises:
        ValueError: If neither `starts` nor `num_windows` is given, or `step` < 1.
        AssertionError: If both are given, or a start leaves no room for the window.
    """
    if step < 1:
        raise ValueError("step must be a positive integer")

    # Largest start index for which context + prediction still fit inside `data`.
    last_valid_start = len(data) - context_length - prediction_length

    if starts is not None:
        assert num_windows is None, "num_windows must be None if starts is provided"
        num_windows = len(starts)
    else:
        if num_windows is None:
            raise ValueError("num_windows must be provided if starts is not provided")
        starts = np.random.randint(0, last_valid_start, num_windows)

    assert len(starts) == num_windows, "starts must be a list of length num_windows"
    # `<=`: a window starting exactly at last_valid_start ends on the final row.
    assert max(starts) <= last_valid_start, (
        "starts must leave room for context_length + prediction_length rows of data"
    )

    # Ceil division: range(0, prediction_length, step) below visits this many chunks.
    # Floor division under-allocates whenever step does not divide prediction_length.
    num_chunks = -(-prediction_length // step)
    full_metrics = defaultdict(lambda: np.zeros((num_windows, num_chunks)))

    predictions = []
    for s in tqdm(range(num_windows), desc="Sampling contexts", total=num_windows):
        start = starts[s]
        forecast_start = start + context_length
        context = data[start:forecast_start]
        prediction = forecast(model, context, prediction_length, **kwargs)
        for chunk_idx, lo in enumerate(range(0, prediction_length, step)):
            # Clip so a shorter final chunk compares equally sized pred and gt.
            hi = min(lo + step, prediction_length)
            pred = prediction[lo:hi]
            gt = data[forecast_start + lo : forecast_start + hi]
            submetrics = compute_metrics(pred, gt, include=metrics)
            for k, v in submetrics.items():
                full_metrics[k][s, chunk_idx] += v
        predictions.append(prediction)

    mean_metrics = {k: v.mean(axis=0) for k, v in full_metrics.items()}
    std_metrics = {k: v.std(axis=0) / np.sqrt(num_windows) for k, v in full_metrics.items()}
    return mean_metrics, std_metrics, starts, predictions


def plot_model_prediction(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    transpose: bool = False,
    standardize: bool = True,
    save_path: str | None = None,
    color: str = "red",
    **kwargs,
):
    """
    Forecast from the first `context_length` rows of `data` and plot the result.

    The figure has a 3-D panel of the first three features on the left and three
    stacked time-series panels (features 0-2) on the right, each showing context,
    ground truth and prediction.

    Args:
        model: Forwarded to `forecast`.
        data: Series of shape (n_timesteps, n_features); the first three features are plotted.
        context_length: Number of leading rows used as context.
        prediction_length: Number of rows forecast and compared with the ground truth.
        transpose: Forwarded to `forecast`.
        standardize: Forwarded to `forecast`.
        save_path: If given the figure is saved there, otherwise it is shown.
        color: Line color of the prediction.
        **kwargs: Forwarded to `forecast`.
    """
    horizon_end = context_length + prediction_length
    context = data[:context_length]
    groundtruth = data[context_length:horizon_end]
    prediction = forecast(model, context, prediction_length, transpose, standardize, **kwargs)

    # The context trace runs one sample into the forecast window so that it
    # connects to the first ground-truth point.
    context_ts = np.arange(context_length + 1)
    pred_ts = np.arange(context_length, horizon_end)

    fig = plt.figure(figsize=(15, 4))
    outer_grid = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)
    gs = outer_grid[1].subgridspec(3, 1, height_ratios=[1 / 3] * 3, wspace=0, hspace=0)

    ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
    traces_3d = (
        (context, {"alpha": 0.5, "color": "black", "label": "Context"}),
        (groundtruth, {"linestyle": "-", "color": "black", "label": "Groundtruth"}),
        (prediction, {"color": color, "label": "Prediction"}),
    )
    for series, line_style in traces_3d:
        ax_3d.plot(*series.T[:3], **line_style)
    ax_3d.legend(loc="upper right", fontsize=8)
    ax_3d.set_xlabel("$x_1$")
    ax_3d.set_ylabel("$x_2$")
    ax_3d.set_zlabel("$x_3$")

    # Make clean projection
    ax_3d.grid(False)
    ax_3d.set_facecolor("white")
    for clear_ticks in (ax_3d.set_xticks, ax_3d.set_yticks, ax_3d.set_zticks):
        clear_ticks([])
    ax_3d.axis("off")

    axes_1d = []
    for i in range(3):
        ax = fig.add_subplot(gs[i, 0])
        ax.plot(context_ts, data[: context_length + 1, i], alpha=0.5, color="black")
        ax.plot(pred_ts, groundtruth[:, i], linestyle="-", color="black")
        ax.plot(pred_ts, prediction[:, i], color=color)
        ax.set_ylabel(f"$x_{i + 1}$")
        ax.set_aspect("auto")
        axes_1d.append(ax)
    axes_1d[-1].set_xlabel("Time")

    if save_path is not None:
        plt.savefig(save_path)
    else:
        plt.show()
    plt.close()

In [None]:
def plot_forecast_3d(
    data: np.ndarray,
    predictions_dict: dict[str, np.ndarray],
    context_length: int,
    prediction_length: int,
    figsize: tuple[int, int] = (6, 6),
    show_legend: bool = True,
    legend_kwargs: dict[str, Any] = {},
    save_path: str | None = None,
):
    """
    Plot context, ground truth and model forecasts in the space of the first three features.

    Args:
        data: Series of shape (n_timesteps, n_features); only the first three features are used.
        predictions_dict: Maps a model name (used as legend label) to its forecast.
        context_length: Number of leading rows treated as context.
        prediction_length: Number of rows after the context drawn as ground truth.
        figsize: Figure size passed to `plt.figure`.
        show_legend: Whether to draw a legend.
        legend_kwargs: Extra keyword arguments for `ax.legend` (read only).
        save_path: If given the figure is also saved there before being shown.
    """
    horizon_end = context_length + prediction_length
    # One extra context sample so the faint context trace meets the ground truth.
    context = data[: context_length + 1, :3]
    groundtruth = data[context_length:horizon_end, :3]

    plt.figure(figsize=figsize)
    ax = plt.axes(projection="3d")
    ax._axis3don = False

    # Bounding box over everything that gets drawn; it anchors the axis arrows.
    drawn = [context, groundtruth]
    drawn.extend(pred[:, :3] for pred in predictions_dict.values())
    xmin, ymin, zmin = np.min(np.array([d.min(axis=0) for d in drawn]), axis=0)
    xmax, ymax, zmax = np.max(np.array([d.max(axis=0) for d in drawn]), axis=0)

    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.set_visible(False)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    ax.plot3D(*context.T, alpha=0.1, color="black", zorder=1)
    ax.plot3D(
        *groundtruth.T,
        alpha=0.8,
        color="black",
        linestyle="-",
        zorder=2,
        label="Ground Truth",
    )
    for model_name, prediction in predictions_dict.items():
        # The "Panda" forecast is drawn on top of the other models.
        draw_order = 10 if model_name == "Panda" else 1
        ax.plot3D(*prediction[:, :3].T, label=model_name, zorder=draw_order)
    if show_legend:
        ax.legend(**legend_kwargs)

    # Three arrows from the (xmin, ymax, zmin) corner, one along each axis.
    arrow_vectors = (
        (xmax - xmin, 0, 0),
        (0, -ymax + ymin, 0),
        (0, 0, zmax - zmin),
    )
    for vector in arrow_vectors:
        ax.quiver(
            xmin,
            ymax,
            zmin,
            *vector,
            color="black",
            arrow_length_ratio=0.05,
            zorder=5,
        )

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
def plot_metric_comparison(
    model_metrics: dict[str, tuple[dict[str, np.ndarray], dict[str, np.ndarray]]],
    prediction_length: int,
    compute_metrics_time_interval: int,
    metric_name: str = "smape",
    colors: list[str] = DEFAULT_COLORS,
    markers: list[str] = DEFAULT_MARKERS,
    title: str | None = None,
    figsize: tuple[float, float] = (4, 3),
    show_legend: bool = True,
    legend_kwargs: dict[str, Any] = {},
    ylim: tuple[float | None, float | None] | None = None,
    save_path: str | None = None,
    metric_name_mapping: dict[str, str] = {"smape": "sMAPE"},
):
    """
    Plot comparison between different models on a given metric.

    Each model is drawn as a line of its mean metric against prediction length,
    with a shaded band of +/- its spread.

    Args:
        model_metrics: Dictionary with model names as keys and tuples of
            (mean_metrics, std_metrics) as values; both map metric names to arrays.
        prediction_length: Length of prediction; sets the last x position.
        compute_metrics_time_interval: Time interval for computing metrics; the
            x positions are interval, 2 * interval, ..., up to prediction_length.
        metric_name: Name of the metric to plot.
        colors: Line colors; reused cyclically if there are more models than colors.
        markers: Line markers; reused cyclically like `colors`.
        title: Optional figure title.
        figsize: Figure size passed to `plt.figure`.
        show_legend: Whether to draw a legend.
        legend_kwargs: Extra keyword arguments for `plt.legend` (read only). A
            framed legend is the default; pass `frameon` here to override it.
        ylim: Optional (bottom, top) passed to `plt.ylim`.
        save_path: Path to save the figure.
        metric_name_mapping: Display names for metrics (read only); metrics not
            listed are shown upper-cased.
    """
    plt.figure(figsize=figsize)
    ts = np.arange(
        compute_metrics_time_interval,
        prediction_length + compute_metrics_time_interval,
        compute_metrics_time_interval,
    )

    for i, (model_name, (mean_metrics, std_metrics)) in enumerate(model_metrics.items()):
        # Wrap around so more models than colors/markers does not raise IndexError.
        color = colors[i % len(colors)]
        marker = markers[i % len(markers)]
        center = mean_metrics[metric_name]
        spread = std_metrics[metric_name]
        plt.plot(ts, center, color=color, marker=marker, label=model_name)
        plt.fill_between(ts, center - spread, center + spread, alpha=0.1, color=color)

    metric_name_title = metric_name_mapping.get(metric_name, metric_name.upper())

    plt.ylabel(metric_name_title, fontweight="bold")
    plt.xlabel("Prediction Length", fontweight="bold")
    if show_legend:
        # Merge so a caller-supplied `frameon` overrides the default instead of
        # raising a duplicate-keyword TypeError.
        plt.legend(**{"frameon": True, **legend_kwargs})
    plt.xticks(ts)
    if title is not None:
        plt.title(title, fontweight="bold")
    if ylim is not None:
        plt.ylim(*ylim)
    # Lay out last, so the title and final limits are taken into account.
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

# Double Pendulum

In [None]:
# Load one double-pendulum trajectory (selected by SPLIT and INDEX).
SPLIT = "train"
INDEX = 0

fpath = f"{base_dir}/double_pendulum_chaotic/train_and_test_split/dpc_dataset_traintest_4_200_csv/{SPLIT}/{INDEX}.csv"
pendulum_data = np.loadtxt(fpath)
print(pendulum_data.shape)

# Keep every 10th row and only the last four columns. Nothing is detrended
# here; the forecasting cells below use this subsampled array directly.
subsampled_pendulum_data = pendulum_data[::10, -4:]
print(subsampled_pendulum_data.shape)
# The plots below use the full-resolution trajectory. Columns are plotted in
# pairs, with the odd column on the horizontal axis and the negated even
# column on the vertical axis.
## The position of the pivot point (mostly constant)
plt.plot(pendulum_data[:, 1], -pendulum_data[:, 0])

## The position of the tip of the first pendulum
plt.plot(pendulum_data[:, 3], -pendulum_data[:, 2])

## The position of the tip of the second pendulum
plt.plot(pendulum_data[:, 5], -pendulum_data[:, 4])

In [None]:
# Window sizes, in rows of the subsampled pendulum series.
context_length = 512
prediction_length = 128

# Chunk length passed as `step` to compute_rollout_metrics and used for the
# x positions in plot_metric_comparison.
compute_metrics_time_interval = 16

# Forecast the raw series rather than its first differences.
differenced = False

Panda

In [None]:
# Single forecast from the first `context_length` rows. The two extra keyword
# arguments are forwarded unchanged to `pft_model.predict`.
pft_prediction = forecast(
    pft_model,
    subsampled_pendulum_data[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
    differenced=differenced,
)

Chronos Finetune

In [None]:
# Same context window as the Panda forecast, using the fine-tuned Chronos model.
chronos_ft_prediction = forecast(
    chronos_ft,
    subsampled_pendulum_data[:context_length],
    prediction_length,
    **chronos_ft_kwargs,
    differenced=differenced,
)

In [None]:
# 2-D plot of subsampled columns 3 (horizontal) and 2 (vertical, negated):
# faint context, solid ground truth, and the Panda forecast on top.
plt.figure(figsize=(4, 4))

# Must match the window the forecasts above were made from (they start at row 0).
start_time = 0
## The position of the tip of the second pendulum
plt.plot(
    subsampled_pendulum_data[start_time : start_time + context_length, 3],
    -subsampled_pendulum_data[start_time : start_time + context_length, 2],
    alpha=0.1,
    color="black",
)
# Ground truth over the forecast horizon.
plt.plot(
    subsampled_pendulum_data[
        start_time + context_length : start_time + context_length + prediction_length,
        3,
    ],
    -subsampled_pendulum_data[
        start_time + context_length : start_time + context_length + prediction_length,
        2,
    ],
    alpha=0.8,
    color="black",
    linestyle="-",
)
# get rid of the ticks
plt.xticks([])
plt.yticks([])
plt.plot(pft_prediction[:, 3], -pft_prediction[:, 2])
# plt.plot(chronos_ft_prediction[:, 3], -chronos_ft_prediction[:, 2])
plt.tight_layout()
plt.savefig(f"{fig_dir}/double_pendulum_forecasts.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Draw 20 distinct window start indices with a seeded generator, so that all
# three models below are evaluated on the same windows.
rseed = 123
num_windows_pendulum = 20
rng = np.random.default_rng(rseed)
# Candidates are 0 .. len - context_length - prediction_length - 1, so every
# window fits inside the series.
pendulum_start_times = rng.choice(
    len(subsampled_pendulum_data) - context_length - prediction_length,
    size=num_windows_pendulum,
    replace=False,
)
print(pendulum_start_times)

In [None]:
# Panda rollout metrics over the shared pendulum windows. `differenced` is not
# passed, so `forecast` uses its default (False).
pft_mean_metrics, pft_std_metrics, _, pft_predictions = compute_rollout_metrics(
    pft_model,
    subsampled_pendulum_data,
    context_length,
    prediction_length=prediction_length,
    starts=pendulum_start_times,
    step=compute_metrics_time_interval,
    sliding_context=True,
    limit_prediction_length=False,
)

In [None]:
# Fine-tuned Chronos rollout metrics over the same pendulum windows.
chronos_ft_mean_metrics, chronos_ft_std_metrics, _, chronos_ft_predictions = compute_rollout_metrics(
    chronos_ft,
    subsampled_pendulum_data,
    context_length,
    prediction_length=prediction_length,
    starts=pendulum_start_times,
    step=compute_metrics_time_interval,
    **chronos_ft_kwargs,
)

In [None]:
# Chronos (loaded by model name) rollout metrics over the same pendulum windows.
chronos_zs_mean_metrics, chronos_zs_std_metrics, _, chronos_zs_predictions = compute_rollout_metrics(
    chronos_zs,
    subsampled_pendulum_data,
    context_length,
    prediction_length=prediction_length,
    starts=pendulum_start_times,
    step=compute_metrics_time_interval,
    **chronos_zs_kwargs,
)

In [None]:
# sanity check: overlay every rollout forecast on the full series, one panel
# per feature (first three features), one color per model.
fig, axes = plt.subplots(3, 1, figsize=(10, 4), sharex=True)
plt.subplots_adjust(hspace=0.0)

total_ts = np.arange(len(subsampled_pendulum_data))
rollouts_by_model = (
    (pft_predictions, DEFAULT_COLORS[0]),
    (chronos_ft_predictions, DEFAULT_COLORS[1]),
    (chronos_zs_predictions, DEFAULT_COLORS[2]),
)
for i, ax in enumerate(axes.flatten()):
    ax.plot(
        total_ts,
        subsampled_pendulum_data[:, i],
        color="black",
        alpha=0.2,
    )
    for j, start_time in enumerate(pendulum_start_times):
        # All models forecast the same horizon, so one time axis serves them all.
        pred_ts = np.arange(
            start_time + context_length,
            start_time + context_length + prediction_length,
        )
        for model_predictions, model_color in rollouts_by_model:
            ax.plot(pred_ts, model_predictions[j][:, i], color=model_color, alpha=0.1)
    # get rid of the ticks (once per panel, not once per window)
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

In [None]:
# Display name -> (mean, spread) metric dictionaries from the rollout cells
# above; insertion order sets the color/marker order in the plot.
model_metrics = {
    "Panda": (pft_mean_metrics, pft_std_metrics),
    "Chronos 20M SFT": (chronos_ft_mean_metrics, chronos_ft_std_metrics),
    "Chronos 20M": (chronos_zs_mean_metrics, chronos_zs_std_metrics),
}

# sMAPE against prediction length for the double pendulum, y axis starting at 0.
plot_metric_comparison(
    model_metrics,
    prediction_length,
    compute_metrics_time_interval,
    metric_name="smape",
    ylim=(0, None),
    save_path=f"{fig_dir}/double_pendulum_comparison_smape.pdf",
)

In [None]:
# Persist the double-pendulum metrics shown in the comparison plot above.
save_dir = "../outputs/double_pendulum"
os.makedirs(save_dir, exist_ok=True)
# The context manager flushes and closes the file even if pickling fails;
# a bare open() handle passed to pickle.dump is never closed explicitly.
with open(os.path.join(save_dir, "model_metrics.pkl"), "wb") as metrics_file:
    pickle.dump(model_metrics, metrics_file)

# Eigenworms

In [None]:
# Load one worm recording plus the eigenworm basis used by reconstruct_worm.
INDEX = 9
fpath = f"{base_dir}/worm_behavior/data/worm_{INDEX}.pkl"
# NOTE(review): allow_pickle=True unpickles the file, which can execute
# arbitrary code -- only load files from a trusted source.
worm_data = np.load(fpath, allow_pickle=True)
eigenworms = loadmat(f"{base_dir}/worm_behavior/data/EigenWorms.mat")["EigenWorms"]

# de-NaN the data with linear interpolation
# (each column is filled independently, in place, from its non-NaN samples)
time_idx = np.arange(len(worm_data))
for d in range(worm_data.shape[1]):
    mask = np.isnan(worm_data[:, d])
    if mask.any():
        valid = ~mask
        worm_data[:, d] = np.interp(time_idx, time_idx[valid], worm_data[valid, d])
assert not np.isnan(worm_data).any()

# Drops the first 2048 rows. The stride is 1, so despite the variable name
# the series is not actually subsampled.
worm_data_subsampled = worm_data[2048::1]
print(worm_data_subsampled.shape)

### Make Video

In [None]:
from IPython.display import HTML
from matplotlib.animation import FuncAnimation


def reconstruct_worm(coeffs, eigenworms, segment_length=1.0):
    """
    Reconstruct a worm from its coefficients and the eigenworms.

    Args:
        coeffs: The coefficients of the worm (n_timesteps, n_eigenworms)
        eigenworms: The eigenworms (n_features, n_eigenworms)
        segment_length: The length of each segment of the worm.

    Returns:
        The reconstructed worm.
    """
    T, nworms = coeffs.shape
    n_segments = eigenworms.shape[0]
    basis = eigenworms[:, :nworms]
    theta = coeffs @ basis.T

    x = np.zeros((T, n_segments + 1))
    y = np.zeros((T, n_segments + 1))
    x[:, 1:] = segment_length * np.cos(theta)
    y[:, 1:] = segment_length * np.sin(theta)

    return x.cumsum(axis=1), y.cumsum(axis=1)


def animate_worm(x, y, num_frames=200, interval=50, save_path=None):
    """
    Animate a worm centerline as a tapered, filled body.

    Args:
        x: Array of x coordinates with shape (T, n_segments+1)
        y: Array of y coordinates with shape (T, n_segments+1)
        num_frames: Maximum number of frames; frames are spread evenly over the T time steps
        interval: Time between frames in milliseconds
        save_path: If given, the animation is also written there with the ffmpeg writer

    Returns:
        HTML animation that can be displayed in the notebook
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # Fixed limits for the whole animation, padded by 10% of the data range.
    x_lo, x_hi = x.min(), x.max()
    y_lo, y_hi = y.min(), y.max()
    pad_x = (x_hi - x_lo) * 0.1
    pad_y = (y_hi - y_lo) * 0.1
    ax.set_xlim(x_lo - pad_x, x_hi + pad_x)
    ax.set_ylim(y_lo - pad_y, y_hi + pad_y)
    ax.set_aspect("equal")
    ax.set_title("Worm Movement")

    # Artists updated on every frame: centerline, body polygon, frame counter.
    (line,) = ax.plot([], [], "b-", lw=2)
    fill = ax.fill([], [], color="blue")
    body = fill[0]
    time_text = ax.text(0.02, 0.95, "", transform=ax.transAxes)

    n_points = x.shape[1]
    max_width = 3  # Maximum width of the worm body

    def _half_width(i):
        # Product of a rising and a falling sigmoid over [-1, 1]: widest in the
        # middle of the body, tapering toward both ends.
        arg = 2 * i / (n_points - 1) - 1
        rising = 1 / (1 + np.exp(-8 * (arg + 0.7)))
        falling = 1 - 1 / (1 + np.exp(-8 * (arg - 0.7)))
        return max_width * (rising * falling)

    width_profile = np.zeros(n_points)
    for i in range(n_points):
        width_profile[i] = _half_width(i)

    def init():
        line.set_data([], [])
        body.set_xy(np.zeros((0, 2)))
        time_text.set_text("")
        return line, body, time_text

    def update(frame):
        xs, ys = x[frame], y[frame]
        line.set_data(xs, ys)

        # Unit normal of each segment: the tangent rotated by 90 degrees.
        dx = np.diff(xs)
        dy = np.diff(ys)
        lengths = np.sqrt(dx**2 + dy**2)
        nx = -dy / lengths
        ny = dx / lengths

        # Offset every point except the last by its half-width along the normal.
        # One side is traversed head to tail, the other tail to head, which
        # closes the outline.
        base_x, base_y = xs[: n_points - 1], ys[: n_points - 1]
        offset_x = width_profile[: n_points - 1] * nx
        offset_y = width_profile[: n_points - 1] * ny
        top_edge = np.column_stack((base_x + offset_x, base_y + offset_y))
        bottom_edge = np.column_stack((base_x - offset_x, base_y - offset_y))[::-1]

        body.set_xy(np.concatenate((top_edge, bottom_edge), axis=0))
        time_text.set_text(f"Frame: {frame}")

        return line, body, time_text

    # Use a subset of frames if there are too many
    total_frames = min(num_frames, len(x))
    frame_indices = np.linspace(0, len(x) - 1, total_frames, dtype=int)

    anim = FuncAnimation(fig, update, frames=frame_indices, init_func=init, blit=True, interval=interval)
    if save_path is not None:
        anim.save(save_path, writer="ffmpeg")
    plt.close()
    return HTML(anim.to_jshtml())

In [None]:
# # Create and display the animation
# x, y = reconstruct_worm(worm_data_subsampled, eigenworms)
# worm_animation = animate_worm(x[:1000], y[:1000], save_path="../figures/wormanim.mp4")

# worm_animation

### Forecast Worms

In [None]:
# Window sizes for the worm forecasts (rebinds the pendulum values).
context_length = 512
prediction_length = 128

# Forecast first differences of the series and integrate the prediction.
differenced = True

In [None]:
# Single Panda forecast from the first `context_length` rows of the worm series.
pft_prediction = forecast(
    pft_model,
    worm_data_subsampled[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
    differenced=differenced,
)

In [None]:
# Same context window, forecast with the fine-tuned Chronos model.
chronos_ft_prediction = forecast(
    chronos_ft,
    worm_data_subsampled[:context_length],
    prediction_length,
    **chronos_ft_kwargs,
    differenced=differenced,
)

In [None]:
# 3-D view of the first three worm features; only the Panda forecast is drawn
# (the fine-tuned Chronos entry is commented out).
plot_forecast_3d(
    worm_data_subsampled,
    {
        "Panda": pft_prediction,
        # "Chronos 20M SFT": chronos_ft_prediction,
    },
    context_length,
    prediction_length,
    show_legend=False,
    save_path=f"{fig_dir}/worm_comparison.pdf",
)

In [None]:
# sanity check
# plot_model_prediction re-runs the forecast internally and returns None, so
# these calls are made only for the figures they display.
_ = plot_model_prediction(
    pft_model,
    worm_data_subsampled,
    context_length=context_length,
    prediction_length=prediction_length,
    sliding_context=True,
    limit_prediction_length=False,
    differenced=differenced,
    color=DEFAULT_COLORS[0],
)
_ = plot_model_prediction(
    chronos_ft,
    worm_data_subsampled,
    context_length=context_length,
    prediction_length=prediction_length,
    **chronos_ft_kwargs,
    differenced=differenced,
    color=DEFAULT_COLORS[1],
)

In [None]:
# Longer horizon and coarser metric chunks for the worm rollout evaluation
# (rebinds the values used for the single forecasts above).
compute_metrics_time_interval = 64
prediction_length = 512

# Deterministic window starts every 1280 rows; the exclusive upper bound
# guarantees each window fits inside the series.
worms_start_times = np.arange(0, len(worm_data_subsampled) - context_length - prediction_length, 1280)
num_windows_worms = len(worms_start_times)

In [None]:
# Panda rollout metrics over the worm windows (rebinds the pendulum results).
pft_mean_metrics, pft_std_metrics, _, pft_predictions = compute_rollout_metrics(
    pft_model,
    worm_data_subsampled,
    context_length,
    prediction_length,
    starts=worms_start_times,
    step=compute_metrics_time_interval,
    sliding_context=True,
    limit_prediction_length=False,
    differenced=differenced,
)

In [None]:
# Fine-tuned Chronos rollout metrics over the same worm windows.
chronos_ft_mean_metrics, chronos_ft_std_metrics, _, chronos_ft_predictions = compute_rollout_metrics(
    chronos_ft,
    worm_data_subsampled,
    context_length,
    prediction_length,
    starts=worms_start_times,
    step=compute_metrics_time_interval,
    differenced=differenced,
    **chronos_ft_kwargs,
)

In [None]:
# Chronos (loaded by model name) rollout metrics over the same worm windows.
chronos_zs_mean_metrics, chronos_zs_std_metrics, _, chronos_zs_predictions = compute_rollout_metrics(
    chronos_zs,
    worm_data_subsampled,
    context_length,
    prediction_length,
    starts=worms_start_times,
    step=compute_metrics_time_interval,
    differenced=differenced,
    **chronos_zs_kwargs,
)

In [None]:
# Display name -> (mean, spread) metric dictionaries from the worm rollout cells.
model_metrics = {
    "Panda": (pft_mean_metrics, pft_std_metrics),
    "Chronos 20M SFT": (chronos_ft_mean_metrics, chronos_ft_std_metrics),
    "Chronos 20M": (chronos_zs_mean_metrics, chronos_zs_std_metrics),
}

# sMAPE against prediction length for the worm data.
plot_metric_comparison(
    model_metrics,
    prediction_length,
    compute_metrics_time_interval,
    metric_name="smape",
    save_path=f"{fig_dir}/worms_comparison_smape.pdf",
)

In [None]:
# Persist the eigenworm metrics shown in the comparison plot above.
save_dir = "../outputs/eigenworms"
os.makedirs(save_dir, exist_ok=True)
# The context manager flushes and closes the file even if pickling fails;
# a bare open() handle passed to pickle.dump is never closed explicitly.
with open(os.path.join(save_dir, "model_metrics.pkl"), "wb") as metrics_file:
    pickle.dump(model_metrics, metrics_file)

# Electronic Circuit

In [None]:
# Load the network structure file and one circuit recording, selected by
# `subdir` and `fname`.
netfpath = f"{base_dir}/electronic_circuit/Structure/Net_1.dat"
subdir = "R1"
fname = "ST_100_3"
fpath = f"{base_dir}/electronic_circuit/{subdir}/{fname}.dat"
net = np.loadtxt(netfpath)
circuit_data = np.loadtxt(fpath)
print(net.shape, circuit_data.shape)

In [None]:
# Window sizes for the circuit forecasts (rebinds the worm values). The
# circuit calls below do not pass `differenced`, so the raw series is used.
context_length = 512
prediction_length = 512

In [None]:
# Single Panda forecast from the first `context_length` rows of the recording.
pft_prediction = forecast(
    pft_model,
    circuit_data[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
)

In [None]:
# Same context window, forecast with the fine-tuned Chronos model.
chronos_ft_prediction = forecast(
    chronos_ft,
    circuit_data[:context_length],
    prediction_length,
    **chronos_ft_kwargs,
)

In [None]:
# Same context window, forecast with the Chronos model loaded by name.
chronos_zs_prediction = forecast(
    chronos_zs,
    circuit_data[:context_length],
    prediction_length,
    **chronos_zs_kwargs,
)

In [None]:
# 3-D view of the first three circuit features with all three forecasts.
# With show_legend=False the legend_kwargs are not used, and nothing is saved
# because save_path is commented out.
plot_forecast_3d(
    circuit_data,
    {
        "Panda": pft_prediction,
        "Chronos 20M SFT": chronos_ft_prediction,
        "Chronos 20M": chronos_zs_prediction,
    },
    context_length,
    prediction_length,
    show_legend=False,
    legend_kwargs={"loc": "center right", "frameon": True},
    # save_path=f"{fig_dir}/circuit_comparison_{subdir}_{fname}.pdf",
)

In [None]:
# sanity check
# One figure per model. The first call uses the default prediction color
# ("red"); the other two use the model's palette color.
_ = plot_model_prediction(
    pft_model,
    circuit_data,
    context_length,
    prediction_length,
    sliding_context=True,
    limit_prediction_length=False,
)
_ = plot_model_prediction(
    chronos_ft,
    circuit_data,
    context_length,
    prediction_length,
    **chronos_ft_kwargs,
    color=DEFAULT_COLORS[1],
)
_ = plot_model_prediction(
    chronos_zs,
    circuit_data,
    context_length,
    prediction_length,
    **chronos_zs_kwargs,
    color=DEFAULT_COLORS[2],
)

### Coupling Strength Scaling Law

In [None]:
# File names (not full paths) of every entry in the R1 directory; os.listdir
# returns them in arbitrary order.
fpaths = os.listdir(f"{base_dir}/electronic_circuit/R1")
# Filled by the next cell: experiment key -> list of file names.
ec_fpaths = defaultdict(list)

In [None]:
# Group the file names by experiment. Names are split on "_": field 1 is
# parsed as the integer coupling strength, and the first character of field 2
# is parsed as the integer experiment key (e.g. "ST_100_3.dat" -> key 3).
# NOTE(review): only the first character is read, so this assumes
# single-digit experiment keys -- confirm.
# A fresh local mapping is built on every run so the cell is idempotent;
# appending to `ec_fpaths` in place would duplicate entries on re-execution.
grouped_fpaths = defaultdict(list)
for fpath in fpaths:
    grouped_fpaths[int(fpath.split("_")[2][0])].append(fpath)

# subset the data by coupling strength
coupling_strength_interval = 10
ec_fpaths = {}
for k, v in grouped_fpaths.items():
    # Sort by coupling strength, then keep every 10th file.
    by_strength = sorted(v, key=lambda x: int(x.split("_")[1]))
    ec_fpaths[k] = by_strength[::coupling_strength_interval]
print(ec_fpaths)

coupling_strengths_lst = [
    int(fpath.split("_")[1]) for fpath in ec_fpaths[1]
]  # Assume same coupling strength for ec splits
print(coupling_strengths_lst)

In [None]:
# The forecast horizon is scored at n_steps evenly spaced lengths; `step` is
# the spacing between them.
n_steps = 8
step = prediction_length // n_steps

# Output directory for the per-model circuit metric pickles.
save_dir = "../outputs/electronic_circuit"
os.makedirs(save_dir, exist_ok=True)

In [None]:
# metrics_by_model[model][experiment][metric] is an (n_steps, n_files) array:
# rows index the evaluation length, columns the files of that experiment
# (one per retained coupling strength).
metrics_by_model = {
    "Panda": {k: {m: np.zeros((n_steps, len(v))) for m in metrics} for k, v in ec_fpaths.items()},
    "Chronos 20M SFT": {k: {m: np.zeros((n_steps, len(v))) for m in metrics} for k, v in ec_fpaths.items()},
    "Chronos 20M": {k: {m: np.zeros((n_steps, len(v))) for m in metrics} for k, v in ec_fpaths.items()},
}

# For every retained file, forecast once per model from the first
# `context_length` rows. The directory is hard-coded to R1, matching the
# listing that `ec_fpaths` was built from.
for k, v in tqdm(ec_fpaths.items()):
    for i, fpath in tqdm(enumerate(v), desc=f"Processing experiment {k}", total=len(v)):
        circuit_data = np.loadtxt(f"{base_dir}/electronic_circuit/R1/{fpath}")
        pft_prediction = forecast(
            pft_model,
            circuit_data[:context_length],
            prediction_length,
            limit_prediction_length=False,
            sliding_context=True,
        )
        chronos_ft_prediction = forecast(
            chronos_ft,
            circuit_data[:context_length],
            prediction_length,
            **chronos_ft_kwargs,
        )
        chronos_zs_prediction = forecast(
            chronos_zs,
            circuit_data[:context_length],
            prediction_length,
            **chronos_zs_kwargs,
        )

        # Metrics are cumulative: row `chunk` scores the forecast from step 0 up
        # to step j + step, not just the chunk itself. This differs from
        # compute_rollout_metrics, which scores each chunk separately.
        for chunk, j in enumerate(np.arange(0, prediction_length, prediction_length // n_steps)):
            target = circuit_data[context_length : context_length + j + step]

            curr_preds_by_model = {
                "Panda": pft_prediction[0 : j + step],
                "Chronos 20M SFT": chronos_ft_prediction[0 : j + step],
                "Chronos 20M": chronos_zs_prediction[0 : j + step],
            }
            for model_name, pred in curr_preds_by_model.items():
                # Rebinds the module-level name `model_metrics` used by earlier cells.
                model_metrics = compute_metrics(pred, target, include=metrics)
                for metric in metrics:
                    metrics_by_model[model_name][k][metric][chunk, i] = model_metrics[metric]

In [None]:
# save the metrics: one pickle per model, named "<model name>_metrics.pkl"
for model_name in metrics_by_model.keys():
    # The context manager flushes and closes each file even if pickling fails;
    # bare open() handles in a loop are never closed explicitly.
    with open(os.path.join(save_dir, f"{model_name}_metrics.pkl"), "wb") as metrics_file:
        pickle.dump(metrics_by_model[model_name], metrics_file)

In [None]:
# Reload the per-model metrics from disk, so the plots below can be
# regenerated without re-running the forecasts.
saved_metrics_path_dict = {
    "Panda": os.path.join(save_dir, "Panda_metrics.pkl"),
    "Chronos 20M SFT": os.path.join(save_dir, "Chronos 20M SFT_metrics.pkl"),
    "Chronos 20M": os.path.join(save_dir, "Chronos 20M_metrics.pkl"),
}

# NOTE(review): pickle.load can execute arbitrary code; these paths are expected
# to be files written by this notebook -- do not point them at untrusted data.
metrics_by_model = {}
for model_name, path in saved_metrics_path_dict.items():
    # The context manager closes the file handle once unpickling is done.
    with open(path, "rb") as metrics_file:
        metrics_by_model[model_name] = pickle.load(metrics_file)

In [None]:
# averaged over experiments
mean_metrics = defaultdict(dict)
std_metrics = defaultdict(dict)

# averaged over experiments at the middle coupling strength
mean_metrics_middle = defaultdict(dict)
std_metrics_middle = defaultdict(dict)

for model_name, model_metrics in metrics_by_model.items():
    for m in metrics:
        # Stack the per-experiment arrays along a new leading axis; the shape is
        # unpacked as (experiments, evaluation lengths, coupling strengths).
        # NOTE(review): assumes every experiment kept the same number of files,
        # otherwise the stacked array is ragged -- confirm.
        metrics_arr = np.array([model_metrics[k][m] for k in ec_fpaths])
        num_exp, _, num_coupling = metrics_arr.shape
        # Mean and standard error (std / sqrt(num_exp)) across experiments.
        mean_metrics[model_name][m] = np.mean(metrics_arr, axis=0)
        std_metrics[model_name][m] = np.std(metrics_arr, axis=0) / np.sqrt(num_exp)
        # Same statistics restricted to the middle column of the coupling axis.
        mean_metrics_middle[model_name][m] = np.mean(metrics_arr[..., num_coupling // 2], axis=0)
        std_metrics_middle[model_name][m] = np.std(metrics_arr[..., num_coupling // 2], axis=0) / np.sqrt(num_exp)

In [None]:
# averaged over experiments at the middle coupling strength
# (repackaged as name -> (mean, spread), the layout plot_metric_comparison expects)
model_metrics = {
    "Panda": (mean_metrics_middle["Panda"], std_metrics_middle["Panda"]),
    "Chronos 20M SFT": (
        mean_metrics_middle["Chronos 20M SFT"],
        std_metrics_middle["Chronos 20M SFT"],
    ),
    "Chronos 20M": (
        mean_metrics_middle["Chronos 20M"],
        std_metrics_middle["Chronos 20M"],
    ),
}

# sMAPE against evaluation length; the x spacing equals the spacing of the
# cumulative evaluation lengths used when the metrics were computed.
plot_metric_comparison(
    model_metrics,
    prediction_length,
    prediction_length // n_steps,
    metric_name="smape",
    save_path=f"{fig_dir}/circuit_comparison_smape_@50.pdf",
    legend_kwargs={"loc": "upper left"},
)

In [None]:
# sMAPE against coupling strength, one line per model with a shaded band of
# +/- the standard error across experiments.
plt.figure(figsize=(4, 3))

# -1 selects the last row of the metric arrays, i.e. the longest evaluation length.
metric_rollout_length_idx = -1

for model_name in metrics_by_model.keys():
    data = mean_metrics[model_name]["smape"][metric_rollout_length_idx]
    std = std_metrics[model_name]["smape"][metric_rollout_length_idx]

    # One value per retained coupling strength.
    assert len(data) == len(coupling_strengths_lst)

    plt.plot(coupling_strengths_lst, data, label=model_name)
    plt.fill_between(coupling_strengths_lst, data - std, data + std, alpha=0.2)

plt.xlabel("Coupling Strength", fontweight="bold")
plt.ylabel("sMAPE", fontweight="bold")
plt.legend(loc="upper right", frameon=True)
plt.tight_layout()
plt.savefig(f"{fig_dir}/circuit_coupling_strength_scaling.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Heatmap of log(sMAPE_Panda / sMAPE_ChronosSFT) for one experiment:
# evaluation length on the x axis, coupling strength on the y axis.
# Negative values mean Panda has the lower sMAPE.
plt.figure(figsize=(3, 6))
# plt.title(r"% $\Delta$sMAPE", fontweight="bold")
plt.title(r"Log sMAPE Ratio", fontweight="bold")

# Define the starting coupling strength and range
start_coupling = 0
coupling_strengths = np.array(coupling_strengths_lst)
coupling_mask = coupling_strengths >= start_coupling
coupling_range = coupling_strengths[coupling_mask]

pred_length_start_idx = 0

# Select the experiment explicitly instead of relying on a loop variable leaked
# from an earlier cell. The last key of `ec_fpaths` is the value those loops
# leave behind, so the plotted experiment is unchanged.
# NOTE(review): confirm that a single experiment (rather than the average over
# experiments) is what this figure is meant to show.
k = list(ec_fpaths)[-1]

# Extract the data for the specified coupling range
chronos_data = metrics_by_model["Chronos 20M SFT"][k]["smape"][:, coupling_mask]
our_model_data = metrics_by_model["Panda"][k]["smape"][:, coupling_mask]

# Despite its name, `percentage_error` holds the log ratio of the two sMAPEs.
percentage_error = np.log(our_model_data[pred_length_start_idx:, :] / chronos_data[pred_length_start_idx:, :])
print(percentage_error.shape)
# Find the maximum absolute value to center the colormap at zero
vmax = np.abs(percentage_error).max()

# Transpose the data for swapping axes
percentage_error = percentage_error.T

# Flip the y-axis by using origin='upper' and adjusting the extent
# Now prediction length is on x-axis and coupling strength is on y-axis
plt.imshow(
    percentage_error,
    cmap="RdBu",
    label=f"Type-{k}",
    aspect="auto",
    vmin=-vmax,
    vmax=vmax,
    extent=(0, percentage_error.shape[1], coupling_range[-1], coupling_range[0]),
    origin="upper",
)
cbar = plt.colorbar(format="%.1f", shrink=0.75)
plt.xlabel("Prediction Length", fontweight="bold")
# With the extent above, column c spans [c, c + 1] on the x axis, so each tick
# goes at c + 0.5 to sit under the centre of its column rather than its left edge.
plt.xticks(
    np.arange(n_steps - pred_length_start_idx) + 0.5,
    [str(i) for i in np.arange(0, prediction_length, prediction_length // n_steps) + prediction_length // n_steps][
        pred_length_start_idx:
    ],
    rotation=45,
)
plt.ylabel("Coupling Strength", fontweight="bold", labelpad=-2)
plt.tight_layout()
plt.savefig(
    f"{fig_dir}/circuit_coupling_scaling_heatmap_transposed.pdf",
    bbox_inches="tight",
)
plt.show()