# Evaluation: DGP Waves

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from mmfm.utils_eval import (
    load_all_fm_models,
    process_all_fm_models,
    get_model_battery_of_best_models,
    predict_on_testset_mmfm,
    predict_on_testset_fsi,
    plot_results,
)
from pathlib import Path
import pandas as pd
import numpy as np
from mmfm.data import dgp_waves_data
import pickle
import re

In [None]:
# ---- Run configuration ---------------------------------------------------------
PROD = True
minimum_seeds = 5
dimension = 2
batch_size = None
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
path_name_fsi = "/home/rohbeckm/scratch/results/dgp_waves/results_fsi"

# ---- Data-generating process -----------------------------------------------------
DGP = "e"
# Condition whose later timepoints are held out (only t in {0, 0.25, 0.5} is used for it
# during validation below).
eval_class = 5.0
# Conditions 1.0, 2.0, ..., 10.0
label_list = np.arange(1.0, 11.0)

# ---- Dataset / model settings used to filter the results table --------------------
ns_per_t_and_c = 50
coupling = "cot"
data_std = 0.025
off_diagonal = 0.0
train_test_split = 0.5
embedding_type = "free"
classifier_free = False
model_type = "mmfm"

### Load all MMFM models

In [None]:
# Load the flat results table of all trained models for this DGP.
# The on-disk cache under data/ is deliberately bypassed here (use_cache = False), so this
# cell always re-reads the raw results and refreshes the cache that later cells read.
use_cache = False
cache_dir = Path("data")
df_cache_path = cache_dir / f"df_gt_{DGP}.parquet"
grouping_cache_path = cache_dir / f"grouping_columns_{DGP}.pkl"
performance_cache_path = cache_dir / f"performance_columns_{DGP}.pkl"

if use_cache and df_cache_path.exists():
    df = pd.read_parquet(df_cache_path)
    # flow_variance is cached as strings; cast back to float when every value parses,
    # otherwise keep the strings (the values are not guaranteed to be numeric).
    try:
        df["flow_variance"] = df["flow_variance"].astype(float)
    except ValueError:
        pass
    with open(grouping_cache_path, "rb") as f:
        grouping_columns = pickle.load(f)
    with open(performance_cache_path, "rb") as f:
        performance_columns = pickle.load(f)
else:
    df, grouping_columns, performance_columns = load_all_fm_models(
        path="/home/rohbeckm/scratch/results/dgp_waves/results_mmfm",
        production=PROD,
        dgp=DGP,
        embedding_type=embedding_type,
        coupling=coupling,
    )

    # Write the cache. to_parquet does not create missing directories.
    cache_dir.mkdir(parents=True, exist_ok=True)
    # flow_variance is stored as str so that the parquet column has a single type.
    df["flow_variance"] = df["flow_variance"].astype(str)
    df.to_parquet(df_cache_path)
    # Restore the float representation in memory when every value parses.
    try:
        df["flow_variance"] = df["flow_variance"].astype(float)
    except ValueError:
        pass

    with open(grouping_cache_path, "wb") as f:
        pickle.dump(grouping_columns, f)
    with open(performance_cache_path, "wb") as f:
        pickle.dump(performance_columns, f)

verbose = False


### Preprocess all MMFM models

In [None]:
# Keep only the runs that match the data/model configuration of the config cell.
df = df.loc[
    (df["ns_per_t_and_c"] == ns_per_t_and_c)
    & (df["coupling"] == coupling)
    & (df["data_std"] == data_std)
    & (df["off_diagonal"] == off_diagonal)
    & (df["train_test_split"] == train_test_split)
    & (df["embedding_type"] == embedding_type)
    & (df["model_type"] == model_type)
]

if not classifier_free:
    df = df.loc[~df["classifier_free"]]

# Keep the validation (non-training) evaluation rows only.
# assign() builds a new frame whose "train" column really is bool. In pandas >= 2,
# `df.loc[:, "train"] = ...` writes in place and keeps the previous dtype (and warns on a
# filtered slice). "~" on an integer/object column would then give -1/-2 instead of a mask.
df = df.assign(train=df["train"].astype(bool))
df = df.loc[~df["train"]]

print(df.shape)

In [None]:
# Post-process the filtered MMFM runs with the project helper (plots and verbose output on).
# `minimum_seeds` is passed through; presumably configurations with fewer seeds are
# dropped (see process_all_fm_models).
mmfm_data_cols = [
    "ns_per_t_and_c",
    "coupling",
    "data_std",
    "dgp",
    "interpolation",
    "guidance",
    "embedding_type",
    "classifier_free",
]
df, grouping_columns, performance_columns = process_all_fm_models(
    df,
    grouping_columns,
    performance_columns,
    minimum_seeds=minimum_seeds,
    data_cols=mmfm_data_cols,
    plot=1,
    verbose=1,
)

In [None]:
# FSI is only evaluated with natural cubic splines, so restrict the MMFM runs to
# cubic interpolation as well to keep the comparison like-for-like.
df_cubic = df[df["interpolation"].eq("cubic")]

### Find best model on Validation Set

In [None]:
# Select the best cubic COT-MMFM configuration on the validation rows, then evaluate
# it on the test set.

# Default weight of 1 for every row. assign() returns a new frame instead of writing
# into a filtered slice of `df` (which raises SettingWithCopyWarning).
df_cubic = df_cubic.assign(weight=1)

# Extra validation (condition, time) pair.
# For a fair comparison the FSI baseline receives this sample as training data.
add_time_cond = (5, 0.55)

# Validation rows: the extra (condition, time) pair, the observed early timepoints of
# the held-out class `eval_class`, and all observed timepoints of the other classes.
df_cubic_valid = df_cubic.loc[
    ((df_cubic["c"] == add_time_cond[0]) & (df_cubic["time"].isin([add_time_cond[1]])))
    | ((df_cubic["c"] == eval_class) & (df_cubic["time"].isin([0, 0.25, 0.5])))
    | ((df_cubic["c"] != eval_class) & (df_cubic["time"].isin([0, 0.25, 0.5, 0.75, 1.0])))
]
# Up-weight the extra validation timepoint relative to the timepoints seen during training.
# NOTE(review): the FM and linear-MMFM sections use a weight of 100 here and
# n_top_seeds=5 instead of 3 — confirm the asymmetry is intended.
df_cubic_valid = df_cubic_valid.assign(
    weight=np.where((df_cubic_valid["c"] == add_time_cond[0]) & (df_cubic_valid["time"] == add_time_cond[1]), 10, 1)
)

for select_by in [
    "mean_diff_l2_mean",
]:
    print(f"Selecting by {select_by}")
    df_top_valid, model_battery, model_states, model_guidances = get_model_battery_of_best_models(
        df_cubic_valid,
        DGP,
        grouping_columns,
        select_by=select_by,
        # Respect the CPU fallback of the config cell instead of hard-coding "cuda".
        device=device.type,
        n_top_models=1,
        n_top_seeds=3,
        model_string="dgp_waves",
        label_list=label_list,
        average_out_seed=True,
    )

    X_test, y_test, t_test, n_classes, timepoints, all_classes = dgp_waves_data(
        coupling, batch_size, dimension, off_diagonal, data_std, ns_per_t_and_c, dgp=DGP, return_data="test"
    )

    df_cubic_results_mmfm, trajectories_test_mmfm = predict_on_testset_mmfm(
        model_battery=model_battery,
        model_states=model_states,
        model_guidances=model_guidances,
        X_test=X_test,
        y_test=y_test,
        t_test=t_test,
        device=device,
        steps=101,
        method="rk4",
    )
    df_cubic_results_mmfm["model"] = "COT-MMFM"

    plot_results(X_test, y_test, t_test, trajectories_test_mmfm, ncols=5, n_classes=n_classes, plot_ode="u_sine")

In [None]:
# Evaluate the FSI baseline (fitted models for seeds 0-2) on the same test data as MMFM.
add_time_cond = (5, 0.55)

fsi_frames = []
for seed in range(3):
    stem = "_".join(
        str(part)
        for part in (
            "dgp_waves",
            DGP,
            seed,
            ns_per_t_and_c,
            train_test_split,
            off_diagonal,
            data_std,
            dimension,
        )
    )
    if add_time_cond:
        # e.g. (5, 0.55) -> "50.55"
        stem += "_" + re.sub(r"[(), ]", "", str(add_time_cond))
    results_path = Path(path_name_fsi) / stem
    print(results_path)

    fsi_frames.append(
        predict_on_testset_fsi(
            results_path,
            X_test,
            y_test,
            t_test,
            seed=seed,
            coupling=coupling,
            n_classes=len(label_list),
            plot_results=True,
            ncols=5,
        )
    )

df_cubic_results_fsi = pd.concat(fsi_frames).reset_index(drop=True)

## OT-CFM Models

In [None]:
# Evaluate models with model_type == "fm": select the best cubic-interpolation
# configuration on the validation rows and evaluate it on the test set.
# The results table is read from the cache written by the MMFM loading cell. If the cache
# is missing, it is rebuilt with the same loader arguments as that cell.
cache_dir = Path("data")
df_cache_path = cache_dir / f"df_gt_{DGP}.parquet"
grouping_cache_path = cache_dir / f"grouping_columns_{DGP}.pkl"
performance_cache_path = cache_dir / f"performance_columns_{DGP}.pkl"

if df_cache_path.exists():
    df = pd.read_parquet(df_cache_path)
    # flow_variance is cached as strings; cast back to float when every value parses.
    try:
        df["flow_variance"] = df["flow_variance"].astype(float)
    except ValueError:
        pass
    with open(grouping_cache_path, "rb") as f:
        grouping_columns = pickle.load(f)
    with open(performance_cache_path, "rb") as f:
        performance_columns = pickle.load(f)
else:
    df, grouping_columns, performance_columns = load_all_fm_models(
        path="/home/rohbeckm/scratch/results/dgp_waves/results_mmfm",
        production=PROD,
        dgp=DGP,
        embedding_type=embedding_type,
        coupling=coupling,
    )

    # Write the cache (to_parquet does not create missing directories).
    cache_dir.mkdir(parents=True, exist_ok=True)
    df["flow_variance"] = df["flow_variance"].astype(str)
    df.to_parquet(df_cache_path)
    try:
        df["flow_variance"] = df["flow_variance"].astype(float)
    except ValueError:
        pass

    with open(grouping_cache_path, "wb") as f:
        pickle.dump(grouping_columns, f)
    with open(performance_cache_path, "wb") as f:
        pickle.dump(performance_columns, f)

verbose = False

df = df.loc[
    (df["ns_per_t_and_c"] == ns_per_t_and_c)
    & (df["coupling"] == coupling)
    & (df["data_std"] == data_std)
    & (df["off_diagonal"] == off_diagonal)
    & (df["train_test_split"] == train_test_split)
    & (df["embedding_type"] == embedding_type)
    & (df["model_type"] == "fm")
]

if not classifier_free:
    df = df.loc[~df["classifier_free"]]

# Keep the validation rows only. assign() guarantees a real bool column, so "~" is a
# logical NOT (see the MMFM filter cell for why `.loc[:, col] =` is avoided).
df = df.assign(train=df["train"].astype(bool))
df = df.loc[~df["train"]]

print(df.shape)

df, grouping_columns, performance_columns = process_all_fm_models(
    df,
    grouping_columns,
    performance_columns,
    plot=False,
    verbose=False,
    minimum_seeds=minimum_seeds,
    data_cols=[
        "ns_per_t_and_c",
        "coupling",
        "data_std",
        "dgp",
        "interpolation",
        "guidance",
        "embedding_type",
        "classifier_free",
    ],
)

print(df.shape)

# FSI is only evaluated with natural cubic splines, so restrict the FM runs to
# cubic interpolation as well. Default weight of 1 for every row.
df_cubic = df.loc[df["interpolation"] == "cubic"].assign(weight=1)

# Extra validation (condition, time) pair.
# For a fair comparison the FSI baseline receives this sample as training data.
add_time_cond = (5, 0.55)

# Validation rows: the extra pair, the observed early timepoints of `eval_class`,
# and all observed timepoints of the other classes.
df_cubic_valid = df_cubic.loc[
    ((df_cubic["c"] == add_time_cond[0]) & (df_cubic["time"].isin([add_time_cond[1]])))
    | ((df_cubic["c"] == eval_class) & (df_cubic["time"].isin([0, 0.25, 0.5])))
    | ((df_cubic["c"] != eval_class) & (df_cubic["time"].isin([0, 0.25, 0.5, 0.75, 1.0])))
]
# Up-weight the extra validation timepoint relative to the timepoints seen during training.
# NOTE(review): the cubic COT-MMFM section uses a weight of 10 and n_top_seeds=3 — confirm.
df_cubic_valid = df_cubic_valid.assign(
    weight=np.where((df_cubic_valid["c"] == add_time_cond[0]) & (df_cubic_valid["time"] == add_time_cond[1]), 100, 1)
)

for select_by in [
    "mean_diff_l2_mean",
]:
    print(f"Selecting by {select_by}")
    df_top_valid, model_battery, model_states, model_guidances = get_model_battery_of_best_models(
        df_cubic_valid,
        DGP,
        grouping_columns,
        select_by=select_by,
        # Respect the CPU fallback of the config cell instead of hard-coding "cuda".
        device=device.type,
        n_top_models=1,
        n_top_seeds=5,
        model_string="dgp_waves",
        label_list=label_list,
        average_out_seed=True,
    )

    X_test, y_test, t_test, n_classes, timepoints, all_classes = dgp_waves_data(
        coupling, batch_size, dimension, off_diagonal, data_std, ns_per_t_and_c, dgp=DGP, return_data="test"
    )

    df_cubic_results_fm, trajectories_test = predict_on_testset_mmfm(
        model_battery=model_battery,
        model_states=model_states,
        model_guidances=model_guidances,
        X_test=X_test,
        y_test=y_test,
        t_test=t_test,
        device=device,
        steps=101,
        method="rk4",
    )
    df_cubic_results_fm["model"] = "FM"

    plot_results(X_test, y_test, t_test, trajectories_test, ncols=5, n_classes=n_classes, plot_ode="u_sine")


## Linear COT-MMFM

In [None]:
# Evaluate MMFM models (model_type == "mmfm") with *linear* interpolation: select the
# best configuration on the validation rows and evaluate it on the test set.
# The results table is read from the cache written by the MMFM loading cell. If the cache
# is missing, it is rebuilt with the same loader arguments as that cell.
cache_dir = Path("data")
df_cache_path = cache_dir / f"df_gt_{DGP}.parquet"
grouping_cache_path = cache_dir / f"grouping_columns_{DGP}.pkl"
performance_cache_path = cache_dir / f"performance_columns_{DGP}.pkl"

if df_cache_path.exists():
    df = pd.read_parquet(df_cache_path)
    # flow_variance is cached as strings; cast back to float when every value parses.
    try:
        df["flow_variance"] = df["flow_variance"].astype(float)
    except ValueError:
        pass
    with open(grouping_cache_path, "rb") as f:
        grouping_columns = pickle.load(f)
    with open(performance_cache_path, "rb") as f:
        performance_columns = pickle.load(f)
else:
    df, grouping_columns, performance_columns = load_all_fm_models(
        path="/home/rohbeckm/scratch/results/dgp_waves/results_mmfm",
        production=PROD,
        dgp=DGP,
        embedding_type=embedding_type,
        coupling=coupling,
    )

    # Write the cache (to_parquet does not create missing directories).
    cache_dir.mkdir(parents=True, exist_ok=True)
    df["flow_variance"] = df["flow_variance"].astype(str)
    df.to_parquet(df_cache_path)
    try:
        df["flow_variance"] = df["flow_variance"].astype(float)
    except ValueError:
        pass

    with open(grouping_cache_path, "wb") as f:
        pickle.dump(grouping_columns, f)
    with open(performance_cache_path, "wb") as f:
        pickle.dump(performance_columns, f)

verbose = False

df = df.loc[
    (df["ns_per_t_and_c"] == ns_per_t_and_c)
    & (df["coupling"] == coupling)
    & (df["data_std"] == data_std)
    & (df["off_diagonal"] == off_diagonal)
    & (df["train_test_split"] == train_test_split)
    & (df["embedding_type"] == embedding_type)
    & (df["model_type"] == "mmfm")
]

if not classifier_free:
    df = df.loc[~df["classifier_free"]]

# Keep the validation rows only. assign() guarantees a real bool column, so "~" is a
# logical NOT (see the MMFM filter cell for why `.loc[:, col] =` is avoided).
df = df.assign(train=df["train"].astype(bool))
df = df.loc[~df["train"]]

print(df.shape)

df, grouping_columns, performance_columns = process_all_fm_models(
    df,
    grouping_columns,
    performance_columns,
    plot=False,
    verbose=False,
    minimum_seeds=minimum_seeds,
    data_cols=[
        "ns_per_t_and_c",
        "coupling",
        "data_std",
        "dgp",
        "interpolation",
        "guidance",
        "embedding_type",
        "classifier_free",
    ],
)

print(df.shape)

# This section deliberately uses the *linear* interpolation runs of MMFM
# (the cubic ones are evaluated in the COT-MMFM section). Default weight of 1 per row.
df_linear = df.loc[df["interpolation"] == "linear"].assign(weight=1)

# Extra validation (condition, time) pair.
# For a fair comparison the FSI baseline receives this sample as training data.
add_time_cond = (5, 0.55)

# Validation rows: the extra pair, the observed early timepoints of `eval_class`,
# and all observed timepoints of the other classes.
df_linear_valid = df_linear.loc[
    ((df_linear["c"] == add_time_cond[0]) & (df_linear["time"].isin([add_time_cond[1]])))
    | ((df_linear["c"] == eval_class) & (df_linear["time"].isin([0, 0.25, 0.5])))
    | ((df_linear["c"] != eval_class) & (df_linear["time"].isin([0, 0.25, 0.5, 0.75, 1.0])))
]
# Up-weight the extra validation timepoint relative to the timepoints seen during training.
# NOTE(review): the cubic COT-MMFM section uses a weight of 10 and n_top_seeds=3 — confirm.
df_linear_valid = df_linear_valid.assign(
    weight=np.where((df_linear_valid["c"] == add_time_cond[0]) & (df_linear_valid["time"] == add_time_cond[1]), 100, 1)
)

for select_by in [
    "mean_diff_l2_mean",
]:
    print(f"Selecting by {select_by}")
    df_top_valid, model_battery, model_states, model_guidances = get_model_battery_of_best_models(
        df_linear_valid,
        DGP,
        grouping_columns,
        select_by=select_by,
        # Respect the CPU fallback of the config cell instead of hard-coding "cuda".
        device=device.type,
        n_top_models=1,
        n_top_seeds=5,
        model_string="dgp_waves",
        label_list=label_list,
        average_out_seed=True,
    )

    X_test, y_test, t_test, n_classes, timepoints, all_classes = dgp_waves_data(
        coupling, batch_size, dimension, off_diagonal, data_std, ns_per_t_and_c, dgp=DGP, return_data="test"
    )

    # Result variable name kept as-is because the summary cell concatenates it.
    df_cubic_results_pcfm, trajectories_test = predict_on_testset_mmfm(
        model_battery=model_battery,
        model_states=model_states,
        model_guidances=model_guidances,
        X_test=X_test,
        y_test=y_test,
        t_test=t_test,
        device=device,
        steps=101,
        method="rk4",
    )
    df_cubic_results_pcfm["model"] = "PCFM"

    plot_results(X_test, y_test, t_test, trajectories_test, ncols=5, n_classes=n_classes, plot_ode="u_sine")


In [None]:
# Stack the test-set metrics of all evaluated models into one long table.
df_results = pd.concat(
    [df_cubic_results_mmfm, df_cubic_results_fsi, df_cubic_results_fm, df_cubic_results_pcfm],
    axis=0,
    ignore_index=True,
)

# A row counts as "training" when its (condition, time) marginal was observed during
# training: every timepoint of the other classes, but only the first three timepoints
# of the held-out class `eval_class`.
is_eval_row = df_results["c"] == eval_class
seen_other_class = ~is_eval_row & df_results["time"].isin([0, 0.25, 0.5, 0.75, 1.0])
seen_eval_class = is_eval_row & df_results["time"].isin([0, 0.25, 0.5])
df_results["training"] = seen_other_class | seen_eval_class

#### Results only on holdout timepoints

In [None]:
# Holdout-only metrics per (model, marginal); the cell displays the Wasserstein
# statistics (max/mean/std as rows after transposing).
holdout_stats = (
    df_results.loc[~df_results["training"]]
    .groupby(["model", "marginal"])
    .agg(
        {
            "mmd": ["mean", "std", "max"],
            "mmd_median": ["mean", "std", "max"],
            "wasserstein": ["mean", "std", "max"],
            "mean_diff_l1": ["mean", "std", "max"],
            "mean_diff_l2": ["mean", "std", "max"],
        }
    )
    .round(3)
    .sort_index(axis=1)
)
holdout_stats["wasserstein"].T

#### Holdout results, split by evaluation class vs. other classes (training rows are aggregated but filtered out before display)

In [None]:
# Per-model metrics grouped by training/holdout timepoint and by whether the row belongs
# to the held-out evaluation class (`eval_class`, not a hard-coded label). Only the
# holdout rows are displayed.
df_r = (
    df_results.assign(is_eval_class=df_results["c"] == eval_class)
    .groupby(["model", "training", "is_eval_class"])
    .agg(
        {
            "mmd": ["mean", "std", "max"],
            "mmd_median": ["mean", "std", "max"],
            "wasserstein": ["mean", "std", "max"],
            "mean_diff_l1": ["mean", "std", "max"],
            "mean_diff_l2": ["mean", "std", "max"],
        }
    )
    .round(2)[["mean_diff_l2", "wasserstein"]]
    .reset_index()
)
df_r = df_r.loc[~df_r["training"]]
df_r

# Put everything into one figure

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.integrate import odeint
from mmfm.data import u_sine, u
import cloudpickle
from mmfm.utils import COLORMAP10, ThickerLine2D

# Paper-style plotting defaults ("science"/"no-latex" are presumably SciencePlots styles).
plt.style.use(["science", "no-latex"])

params = dict.fromkeys(
    (
        "axes.labelsize",
        "xtick.labelsize",
        "ytick.labelsize",
        "legend.fontsize",
        "legend.title_fontsize",
    ),
    18,
)
params.update({"axes.titlesize": 22, "figure.titlesize": 30})
plt.rcParams.update(params)

# Three panels sharing both axes: true field, FSI trajectories, COT-MMFM trajectories.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 5), sharex=True, sharey=True)

# Training data for the background scatter, including the extra (condition, time) sample.
# NOTE: this call rebinds `label_list` from the config cell to the value returned here.
add_time_cond = (5, 0.55)
(
    train_loader,
    X_train,
    y_train,
    t_train,
    X_valid,
    y_valid,
    t_valid,
    n_classes,
    label_list,
) = dgp_waves_data(
    coupling,
    batch_size,
    dimension,
    off_diagonal,
    data_std,
    ns_per_t_and_c,
    dgp=DGP,
    return_data="train-valid",
    add_time_cond=add_time_cond,
    filter_beginning_end=False,
)


def plot_background(ax, X, y, t, arrows=False, legend=False, conditions=None):
    """Scatter 2-D samples coloured by condition, optionally over the true flow lines.

    Args:
        ax: Matplotlib axes to draw on.
        X: Sample coordinates; reshaped to (-1, 2), so the last axis holds (x, y).
        y: Condition label per sample (one entry per reshaped row of X). Rows whose
            condition is NaN are dropped.
        t: Time per sample (one entry per reshaped row of X).
        arrows: If True, also draw, for every condition, the curve obtained by
            integrating ``u_sine`` from (0, 0) over 101 times in [0, 1], with four
            direction arrows along it.
        legend: Forwarded to ``seaborn.scatterplot`` as its ``legend`` argument.
        conditions: Conditions whose curves are drawn when ``arrows`` is True.
            Defaults to the module-level ``label_list``.

    Returns:
        The same axes object.
    """
    df = pd.DataFrame(X.reshape(-1, 2)).assign(condition=y.reshape(-1, 1), time=t.reshape(-1, 1))
    df.columns = ["x", "y", "condition", "time"]
    df = df.loc[~df["condition"].isna()]
    sns.scatterplot(data=df, x="x", y="y", hue="condition", ax=ax, s=10, palette=COLORMAP10, legend=legend)

    if arrows:
        if conditions is None:
            conditions = label_list
        # Own time grid for the ODE solution. It does not shadow the `t` argument.
        t_grid = np.linspace(0, 1, 101)
        for c in conditions:
            sol = odeint(u_sine, (0, 0), t_grid, args=(c,))
            ax.plot(sol[:, 0], sol[:, 1], color="gray", alpha=0.25)
            for idx in [20, 40, 60, 80]:
                # Only `color` is given: combining it with fc/ec makes matplotlib warn
                # and use `color` for both, so the arrows are gray either way.
                ax.arrow(
                    sol[idx, 0],
                    sol[idx, 1],
                    sol[idx + 1, 0] - sol[idx, 0],
                    sol[idx + 1, 1] - sol[idx, 1],
                    color="gray",
                    alpha=0.25,
                    head_width=0.05,
                    head_length=0.05,
                )
    return ax


#
# Figure 1 is true vector field
#
ax[0] = plot_background(ax[0], X_train, y_train, t_train, arrows=True)
ax[0].set_title("True Vector Field (Phase Diagram)\nwith Training Data")

#
# Figure 2 shows FSI interpolation
#
ax[1] = plot_background(ax[1], X_train, y_train, t_train, arrows=False)
ax[1].set_title("Predicted Trajectory FSI")

plot_conditions = [1, 5, 10]  # conditions whose trajectories are drawn
n_plot_samples = 5  # test samples drawn per condition
# FSI query times covering both endpoints t=0 and t=1 (101 points, like steps=101 for MMFM).
fsi_t_grid = np.linspace(0.0, 1.0, 101)

# Only the model of the first seed is drawn.
seed = 0
filename = f"dgp_waves_{DGP}_{seed}_{ns_per_t_and_c}_{train_test_split}_{off_diagonal}_{data_std}_{dimension}"
if add_time_cond:
    filename = filename + "_" + re.sub(r"[(), ]", "", str(add_time_cond))
results_path = Path(path_name_fsi) / filename

fsi_model_files = {"ot": "model_ot_fsi.pkl", "cot": "model_cot_fsi.pkl", "None": "model_fsi.pkl"}
if coupling not in fsi_model_files:
    raise ValueError("Coupling not recognized.")
# NOTE: cloudpickle.load executes code from the file; only load self-produced results.
with open(results_path / fsi_model_files[coupling], "rb") as f:
    fsi_model = cloudpickle.load(f)

for condition in plot_conditions:
    color_conditions = list(label_list).index(condition)
    # Test-sample indices of this condition (computed once per condition).
    sample_c = np.where(y_test[:, 0] == condition)[0]
    for sample in range(n_plot_samples):
        x0 = X_test[sample_c[sample], 0][None, :]
        trajectory = np.full((len(fsi_t_grid), 2), np.nan)
        for i, t_query in enumerate(fsi_t_grid):
            trajectory[i] = fsi_model.interpolate_from_x0(X=x0, y=condition, t_query=float(t_query))
        # One polyline per sample (same single colour as drawing it segment by segment).
        ax[1].plot(trajectory[:, 0], trajectory[:, 1], color=COLORMAP10[color_conditions], lw=1)

#
# Figure 3 shows MMFM Predicted Trajectory
#
ax[2] = plot_background(ax[2], X_train, y_train, t_train, arrows=False, legend=True)
ax[2].set_title("Predicted Trajectory COT-MMFM")

# NOTE(review): assumes the test set is ordered by condition with 50 samples each, so
# label k occupies columns [50 * k, 50 * (k + 1)) — confirm against dgp_waves_data.
n_test_per_condition = 50
mmfm_trajectories = trajectories_test_mmfm[0][seed]
for condition in plot_conditions:
    color_conditions = list(label_list).index(condition)
    start = color_conditions * n_test_per_condition
    sub_trajectories = mmfm_trajectories[:, start : start + n_test_per_condition]
    for sample in range(n_plot_samples):
        # Every time step of the predicted trajectory is drawn, as one polyline.
        ax[2].plot(
            sub_trajectories[:, sample, 0],
            sub_trajectories[:, sample, 1],
            lw=1,
            color=COLORMAP10[color_conditions],
        )

# Single figure-level legend, taken from the only panel plotted with a seaborn legend.
handles, labels = ax[2].get_legend_handles_labels()
by_label = dict(zip(labels, handles))
fig.legend(
    by_label.values(),
    by_label.keys(),
    loc="center left",
    bbox_to_anchor=(0.86, 0.5),
    handler_map={plt.Line2D: ThickerLine2D()},
)

for k in range(3):
    ax[k].set_xlabel("x")
    ax[k].set_ylabel("y")
    ax[k].set_xlim(-0.15, 2.15)
    ax[k].set_ylim(-3, 2.5)

# The per-axes legend of the third subplot is replaced by the figure legend above.
ax[2].get_legend().remove()
plt.tight_layout()
fig.subplots_adjust(right=0.85)

# Light "+" markers on a regular background grid (0.1 spacing in x), one scatter per axes.
grid_x, grid_y = np.meshgrid(np.linspace(-0.1, 2.1, 23), np.linspace(-3.1, 2.2, 24))
for a in ax.flatten():
    a.scatter(grid_x.ravel(), grid_y.ravel(), color="gray", s=2, alpha=0.5, marker="+")

# Save figure as png
plt.savefig("/home/rohbeckm/code/mmfm/figures/fig2_waves_extrapol.png", bbox_inches="tight")
plt.show()

In [None]:
from mmfm.mmfm_utils import plot_results_mmfm

# Reload train/valid data for the all-conditions plot (no add_time_cond argument,
# unlike the combined-figure cell).
train_loader, X_train, y_train, t_train, X_valid, y_valid, t_valid, n_classes, label_list = dgp_waves_data(
    coupling, batch_size, dimension, off_diagonal, data_std, ns_per_t_and_c, dgp=DGP, return_data="train-valid"
)

# Indices of the first two test samples of every condition (conditions in sorted order).
# np.concatenate tolerates conditions with fewer than two samples; np.array on such a
# ragged list of index arrays raises instead.
idx_per_condition = [np.where(y_test[:, 0] == c)[0][:2] for c in np.unique(y_test[:, 0])]
idx_plot = np.concatenate(idx_per_condition).tolist() if idx_per_condition else []

plot_results_mmfm(
    X=X_train,
    y=y_train,
    t=t_train,
    trajectory=trajectories_test_mmfm[0][0],
    idx_plot=idx_plot,
    # NOTE(review): the fallback of 9 differs from the 10 labels in the config cell — confirm.
    n_classes=n_classes if n_classes is not None else 9,
    title="",
    save=True,
    filepath="/home/rohbeckm/code/mmfm/figures/figure_waves_mmfm_all_extrapol.png",
    s=5,
    ncols=5,
    plot_ode="u_sine",
    paper_style=True,
)