In [None]:
import os
# Set in the very first cell, ahead of the heavy numerical imports below.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen when several packages bundle their own OpenMP copy — confirm it
# is still required in the target environment.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [None]:
# import packages
import os
import argparse
import logging
import time
import numpy as np
import numpy.random as npr
import matplotlib
# matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import random


from datetime import timedelta
import scipy.stats
from sklearn.decomposition import PCA
from torch.utils.data import Dataset, DataLoader

import copy

from scipy.interpolate import splev, splrep, interp1d

import sklearn
from sklearn.neighbors import NearestNeighbors

from prettytable import PrettyTable

from torchdiffeq import odeint_adjoint as odeint_adjoint
from torchdiffeq import odeint


In [None]:
from adaptive_select import create_adap_data_buffer
from interpolate_windows import interpolate_window, interpolate_window_adap
from RNN_ODE_functions import fit_non_adap_models_grids, pred_non_adap_models_grids, train_non_adap_models
from RNN_ODE_Adap_functions import fit_adap_models, pred_adap_models, train_adap_models
from network import RNNODE, OutputNN

import data


In [None]:
def seed_everything(seed=1234):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU plus the current and all CUDA devices), exports ``PYTHONHASHSEED``
    and switches cuDNN / PyTorch into their deterministic modes.

    Args:
        seed: Integer seed applied to all generators (default 1234).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every one of these takes the seed as its single positional argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN auto-tuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

## generate training and testing data

In [None]:
# Draw three datasets from the project-local `data` module, keeping only the
# first element of each returned pair.
# NOTE(review): the integer argument is presumably the number of samples — confirm.
# NOTE(review): seed_everything() is defined above but not called in any visible
# cell before these draws, so they may not be reproducible across kernel restarts.
# NOTE(review): train_x / valid_x / test_x are not read by the later visible
# cells, which use train_windows0_list / test_windows0_list instead.
train_x,_ = data.get_data(128)
valid_x,_ = data.get_data(128)
test_x,_ = data.get_data(1024)

In [None]:
# Grid-resolution constants.
num_fine_grid = 1000
num_layers = 3
# The fine grid is halved once per layer above the first: 1000 // 2**2 = 250.
num_coarse_grid = num_fine_grid // 2**(num_layers-1)


## RNN-ODE One-Step Prediction Experiment (N=64)

In [None]:
# Hyper-parameters for the non-adaptive RNN-ODE run.
# (A previous setting, num_grids = 65 for N = 64, is no longer active.)

# -- integration grid / solver --
num_grids = 500
method = "naiveEuler"
time_scale = 10
rescale_const = 1.
buffer_start_steps = 2

# -- model size --
obs_dim = 1
n_hidden = 128
n_latent = 128

# -- optimisation / logging --
n_iter = 400
batch_size = 64
verbose = 2

# Echo the settings that matter most when reading the output later.
print("Experiment Configuration:")
print(f"Window length (N): {num_fine_grid}")
print(f"Number of grids: {num_grids}")
print(f"Method: {method}")
print(f"Iterations: {n_iter}")

In [None]:
# Per-replica test RMSE of the non-adaptive RNN-ODE model
testing_rmse_list = []

print("Starting RNN-ODE One-Step Prediction Experiment...")
print("=" * 60)

# NOTE(review): num_rep, train_windows0_list, train_ts0_list, test_windows0_list,
# test_ts0_list and MyDataset are not defined in any visible cell; they must
# already exist in the kernel for this cell to run.

# Run experiment for each repetition - following spiral_example structure
rep = 0
while rep < num_rep:
    print(f"\n------the {rep+1}-th replica:--------")

    # Get training and testing data for this repetition
    train_windows0 = train_windows0_list[rep]
    train_ts0 = train_ts0_list[rep]
    test_windows0 = test_windows0_list[rep]
    test_ts0 = test_ts0_list[rep]

    # Hold out 50 randomly chosen training windows for validation
    valid_index = np.sort(npr.choice(train_windows0.shape[0], 50, replace=False))
    valid_windows0 = train_windows0[valid_index, :, :]
    valid_ts0 = train_ts0[valid_index, :]

    # The remaining windows form the training set (cloned so the time shift
    # below does not touch the original tensors)
    train_index = np.delete(np.arange(train_windows0.shape[0]), valid_index)
    train_windows1 = train_windows0[train_index, :, :].clone()
    train_ts1 = train_ts0[train_index, :].clone()

    # Add buffer steps like in spiral_example: the first buffer_start_steps
    # timestamps get the offsets [-buffer_start_steps*deltat, ..., -deltat]
    if buffer_start_steps > 0:
        deltat = train_ts1[0][-1] - train_ts1[0][-2]
        train_ts1[:,:buffer_start_steps] = (train_ts1[:,:buffer_start_steps] +
                                           torch.arange(-deltat*buffer_start_steps, 0, deltat))

    # Create dataset and dataloader
    train_dataset1 = MyDataset(train_windows1, train_ts1)
    train_loader1 = DataLoader(train_dataset1, shuffle=True, batch_size=batch_size)

    input_steps1 = train_windows1.shape[1]

    # Train RNN-ODE model - exact parameters from spiral_example
    flag, odefunc1, outputfunc1 = train_non_adap_models(
        train_loader1, input_steps1, valid_windows0, valid_ts0, num_grids,
        verbose, n_iter=n_iter, method=method, thres1=6.5e2, thres2=6.5e2, weight=(1,0),
        obs_dim=obs_dim, n_hidden=n_hidden, n_latent=n_latent, num_train_windows=len(train_index),
        time_scale=time_scale, buffer_start_steps=buffer_start_steps
    )

    if not flag:
        # NOTE(review): `rep` is not advanced here, so the same replica is
        # retried (with a freshly drawn validation split) until training
        # succeeds; the loop never ends if training keeps failing.
        print(f"Training failed for repetition {rep+1}, skipping...")
        continue
    else:
        rep += 1
        print(f"Training successful for repetition {rep}")

    # Evaluate standard one-step prediction (RMSE from fit_non_adap_models_grids)
    # Fixed: the literal was '\\n', which printed a backslash and an 'n'
    # instead of starting a new line.
    print('\nfitting error of testing data:')
    _, fit_L2_err_test1, _ = fit_non_adap_models_grids(
        odefunc1, outputfunc1, test_windows0, test_ts0, num_grids, method=method,
        buffer_start_steps=buffer_start_steps, n_latent=n_latent, obs_dim=obs_dim,
        rescale_const=rescale_const, time_scale=time_scale
    ) # this is RMSE

    # Store results
    testing_rmse = fit_L2_err_test1.mean().item()  # This is RMSE
    testing_rmse_list.append(testing_rmse)

    print(f"Testing RMSE: {testing_rmse:.6f}")

print("\n" + "="*60)
print("EXPERIMENT COMPLETED")
print("="*60)
print(f"Successful repetitions: {len(testing_rmse_list)}/{num_rep}")

In [None]:

# Display final results of the non-adaptive RNN-ODE experiment
if len(testing_rmse_list) > 0:
    testing_rmse_array = np.array(testing_rmse_list)
    # Squared per-replica RMSE; kept as a notebook-level name for later cells
    testing_mse_array = testing_rmse_array ** 2

    # Fixed: "\\n" printed a literal backslash-n instead of a blank line.
    print("\n" + "="*60)
    print("RNN-ODE ONE-STEP PREDICTION RESULTS (N=64)")
    print("="*60)
    print(f"Successful repetitions: {len(testing_rmse_array)}/{num_rep}")
    print(f"Success rate: {len(testing_rmse_array)/num_rep*100:.1f}%")
    print()

    print("TESTING L2-error RESULTS:")
    print("-" * 35)
    print(f"Mean: {testing_rmse_array.mean():.6f}")
    print(f"Std:  {testing_rmse_array.std():.6f}")
    print(f"Min:  {testing_rmse_array.min():.6f}")
    print(f"Max:  {testing_rmse_array.max():.6f}")

    print()
else:
    print("\nNo successful repetitions completed!")

In [None]:
# Focused grid search for the adaptive threshold.
# NOTE(review): grid_search_threshold_updated is not defined or imported in
# any visible cell; it must already exist in the kernel.
print("\n".join([
    "=" * 60,
    "FOCUSED THRESHOLD SEARCH (0.04 - 0.05)",
    "=" * 60,
    "Searching in narrow range based on previous results...",
    "",
]))

(
    optimal_threshold_new,
    achieved_steps_new,
    search_results_new,
    valid_results_new,
) = grid_search_threshold_updated(target_steps=43, tolerance=0.5)

# Report the outcome of the search.
print("\n" + "=" * 50)
print("FOCUSED GRID SEARCH RESULTS")
print("=" * 50)
print(f"Optimal threshold: {optimal_threshold_new:.4f}")
print(f"Achieved avg steps: {achieved_steps_new:.2f}")
print("Target steps: 43")
print("Search precision: 0.0001")
print("Search range: [0.04, 0.05]")
print(f"Total candidates tested: {len(search_results_new)}")
print("=" * 50)

Using fine-grained grid search, the optimal threshold is 0.0398.

## RNN-ODE Adaptive One-Step Prediction Experiment (threshold=0.0358)

In [None]:
# Settings for the adaptive RNN-ODE experiment.
thres = 0.0358         # threshold passed to create_adap_data_buffer
num_fine_adap = 64     # fine-grid size used by the adaptive buffer
num_layers_adap = 3    # number of refinement layers

# Single banner echoing the settings above.
print("\n".join((
    "RNN-ODE Adaptive One-Step Prediction Experiment",
    f"Threshold: {thres}",
    f"Fine grid: {num_fine_adap}",
    f"Layers: {num_layers_adap}",
    "=" * 60,
)))

In [None]:
# Storage for adaptive testing RMSE results (one entry per successful replica)
testing_rmse_adap_list = []
num_grids_adap_list = []  # Track actual grid counts (one entry per successful replica)

print("Starting RNN-ODE Adaptive One-Step Prediction Experiment...")
print("=" * 60)

# NOTE(review): num_rep, train_windows0_list, train_ts0_list, test_windows0_list,
# test_ts0_list and MyDataset_adap are not defined in any visible cell; they
# must already exist in the kernel for this cell to run.

# Run experiment for each repetition
rep = 0
while rep < num_rep:
    print(f"\n------Adaptive the {rep+1}-th replica:--------")

    # Get training and testing data for this repetition
    train_windows0 = train_windows0_list[rep]
    train_ts0 = train_ts0_list[rep]
    test_windows0 = test_windows0_list[rep]
    test_ts0 = test_ts0_list[rep]

    # NOTE(review): this first split is overwritten further down and
    # train_index is never read in this cell. It is kept because dropping the
    # draw would change the random stream seen by everything after it.
    valid_index = np.sort(npr.choice(train_windows0.shape[0], 50, replace=False))
    train_index = np.delete(np.arange(train_windows0.shape[0]), valid_index)

    # Create adaptive data buffers - exactly like spiral_example
    train_windows0_adap_buffer, train_ts0_adap_buffer, train_len0_adap_buffer, train_windows_adap_len = create_adap_data_buffer(
        train_windows0, train_ts0, thres, num_fine_adap, num_layers_adap)
    test_windows0_adap_buffer, test_ts0_adap_buffer, test_len0_adap_buffer, _ = create_adap_data_buffer(
        test_windows0, test_ts0, thres, num_fine_adap, num_layers_adap)

    # train_windows, 1st half
    train_windows0_adap_buffer_trunc, train_ts0_adap_buffer_trunc, _, _, train_time_index = create_adap_data_buffer(
        train_windows0[:,:(train_ts0.shape[1]+1)//2,:], train_ts0[:,:(train_ts0.shape[1]+1)//2], thres, num_fine_adap//2, num_layers_adap, return_index=1)
    # train_windows, 2nd half
    _, train_ts0_adap_buffer_trunc2, train_len0_adap_buffer_trunc2, _ = create_adap_data_buffer(
        train_windows0[:,-(train_ts0.shape[1]+1)//2:,:], train_ts0[:,-(train_ts0.shape[1]+1)//2:], thres, num_fine_adap//2, num_layers_adap, buffer_start_steps=0)

    print(f'Adaptive method, thres = {thres:.4f}, number of grids = {train_windows_adap_len:.1f}')

    # Create validation set (this draw is the one actually used below)
    valid_index = np.sort(npr.choice(train_windows0.shape[0], 50, replace=False))
    valid_windows0 = train_windows0[valid_index,:,:]
    valid_ts0 = train_ts0[valid_index,:]

    # Prepare adaptive validation data
    valid_windows0_adap_buffer = train_windows0_adap_buffer[valid_index,:,:]
    valid_ts0_adap_buffer = train_ts0_adap_buffer[valid_index,:]
    valid_windows0_adap_buffer_trunc = train_windows0_adap_buffer_trunc[valid_index,:,:]
    valid_ts0_adap_buffer_trunc = train_ts0_adap_buffer_trunc[valid_index,:]

    # Create second half data for prediction validation - exactly like spiral_example
    valid_ts_adap_trunc2 = valid_ts0[:, (valid_ts0.shape[1]-1)//2::1]
    valid_len_adap_trunc2 = [(valid_ts0.shape[1]-1)//2+1] * valid_windows0.shape[0]

    # Create adaptive dataset
    # NOTE(review): the loader is built from all windows, including the ones
    # selected for validation above — confirm this is intended.
    train_dataset_adap = MyDataset_adap(train_windows0_adap_buffer, train_ts0_adap_buffer, train_len0_adap_buffer)
    train_loader_adap = DataLoader(train_dataset_adap, shuffle=True, batch_size=batch_size)

    input_steps_adap = train_windows0_adap_buffer.shape[1]

    # Train adaptive RNN-ODE model - follow spiral_example parameters exactly
    flag, odefunc_adap, outputfunc_adap = train_adap_models(
        train_loader_adap, input_steps_adap, verbose,
        valid_window=valid_windows0, valid_ts=valid_ts0,
        valid_window_adap=valid_windows0_adap_buffer, valid_ts_adap=valid_ts0_adap_buffer,
        valid_window_adap_trunc=valid_windows0_adap_buffer_trunc, valid_ts_adap_trunc=valid_ts0_adap_buffer_trunc,
        buffer_start_steps=buffer_start_steps,
        weight=(1,0), n_iter=n_iter, thres1=1.5e3,
        valid_len=np.array(train_len0_adap_buffer)[valid_index],
        valid_ts_adap_trunc2=valid_ts_adap_trunc2, valid_len_adap_trunc2=valid_len_adap_trunc2,
        pred_len=(train_ts0.shape[1]-1)//2, rescale_const=rescale_const,
        n_latent=n_latent, obs_dim=obs_dim, time_scale=time_scale
    )

    if not flag:
        # `rep` is not advanced: the same replica is retried.
        continue
    else:
        rep += 1
        print(f"Adaptive training successful for repetition {rep}")

    # Fixed: the grid count used to be appended before the success check, so
    # every retried replica added an extra entry and this list no longer
    # lined up with testing_rmse_adap_list (which the comparison plots assume).
    num_grids_adap_list.append(train_windows_adap_len)

    # Evaluate adaptive one-step prediction using fitting function - exactly like spiral_example
    interp_kind = 'cubic'
    # Fixed: '\\n' printed a literal backslash-n instead of a newline.
    print('\nAdaptive fitting error of testing data (cubic):')
    _, fit_L2_err_adap_test, _ = fit_adap_models(
        odefunc_adap, outputfunc_adap, test_windows0, test_ts0,
        test_windows0_adap_buffer, test_ts0_adap_buffer, test_len0_adap_buffer,
        buffer_start_steps, interp_kind=interp_kind, rescale_const=rescale_const,
        n_latent=n_latent, obs_dim=obs_dim, time_scale=time_scale
    )

    # Store results
    testing_rmse_adap = (fit_L2_err_adap_test).mean().item()  # mean RMSE (no conversion to MSE happens here)
    testing_rmse_adap_list.append(testing_rmse_adap)

    print(f"Adaptive Testing RMSE: {testing_rmse_adap:.6f}")

print("\n" + "="*60)
print("ADAPTIVE EXPERIMENT COMPLETED")
print("="*60)
print(f"Successful repetitions: {len(testing_rmse_adap_list)}/{num_rep}")

In [None]:
# Display adaptive results (RMSE, squared RMSE and grid-count statistics)
if len(testing_rmse_adap_list) > 0:
    testing_rmse_adap_array = np.array(testing_rmse_adap_list)
    num_grids_adap_array = np.array(num_grids_adap_list)

    # Squared per-replica RMSE
    testing_mse_adap_array = testing_rmse_adap_array ** 2

    # Fixed: "\\n" printed a literal backslash-n instead of a blank line.
    print("\n" + "="*60)
    print("RNN-ODE ADAPTIVE ONE-STEP PREDICTION RESULTS")
    print("="*60)
    print(f"Successful repetitions: {len(testing_rmse_adap_array)}/{num_rep}")
    print(f"Success rate: {len(testing_rmse_adap_array)/num_rep*100:.1f}%")
    print(f"Average grid points: {num_grids_adap_array.mean():.1f} ± {num_grids_adap_array.std():.1f}")
    print()

    print("ADAPTIVE TESTING RMSE RESULTS:")
    print("-" * 35)
    print(f"Mean: {testing_rmse_adap_array.mean():.6f}")
    print(f"Std:  {testing_rmse_adap_array.std():.6f}")
    print(f"Min:  {testing_rmse_adap_array.min():.6f}")
    print(f"Max:  {testing_rmse_adap_array.max():.6f}")

    print("ADAPTIVE TESTING MSE RESULTS:")
    print("-" * 35)
    print(f"Mean: {testing_mse_adap_array.mean():.6f}")
    print(f"Std:  {testing_mse_adap_array.std():.6f}")
    print(f"Min:  {testing_mse_adap_array.min():.6f}")
    print(f"Max:  {testing_mse_adap_array.max():.6f}")
    print()
    


In [None]:
# Create visualization plots comparing the adaptive and non-adaptive runs.
#
# Fixed: this cell used to read testing_mse_adap_list / testing_mse_list,
# which no cell defines (the experiment cells fill testing_rmse_adap_list /
# testing_rmse_list). The MSE arrays are now derived from those RMSE lists by
# squaring, the same way the result-summary cells compute them.
if len(testing_rmse_adap_list) > 0 and len(testing_rmse_list) > 0:
    testing_mse_adap_array = np.array(testing_rmse_adap_list) ** 2
    testing_mse_nonadap_array = np.array(testing_rmse_list) ** 2
    num_grids_adap_array = np.array(num_grids_adap_list)

    # Create subplots
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('RNN-ODE: Adaptive vs Non-Adaptive Method Comparison', fontsize=16, fontweight='bold')

    # Plot 1: MSE Comparison Bar Chart
    ax1 = axes[0, 0]
    methods = ['Non-Adaptive', 'Adaptive']
    mse_means = [testing_mse_nonadap_array.mean(), testing_mse_adap_array.mean()]
    mse_stds = [testing_mse_nonadap_array.std(), testing_mse_adap_array.std()]
    colors = ['lightcoral', 'lightblue']

    bars = ax1.bar(methods, mse_means, yerr=mse_stds, capsize=5, color=colors, alpha=0.7, edgecolor='black')
    ax1.set_ylabel('Testing MSE')
    ax1.set_title('Average Testing MSE Comparison')
    ax1.grid(True, alpha=0.3)

    # Add value labels on bars
    for i, (mean, std) in enumerate(zip(mse_means, mse_stds)):
        ax1.text(i, mean + std + 0.001, f'{mean:.5f}', ha='center', va='bottom', fontweight='bold')

    # Plot 2: Individual Repetition Comparison
    ax2 = axes[0, 1]
    # Only the replicas present in both runs can be compared one-to-one
    min_reps = min(len(testing_mse_nonadap_array), len(testing_mse_adap_array))
    rep_indices = np.arange(1, min_reps + 1)

    ax2.plot(rep_indices, testing_mse_nonadap_array[:min_reps], 'o-',
             color='red', linewidth=2, markersize=8, label='Non-Adaptive', alpha=0.7)
    ax2.plot(rep_indices, testing_mse_adap_array[:min_reps], 's-',
             color='blue', linewidth=2, markersize=8, label='Adaptive', alpha=0.7)

    ax2.set_xlabel('Repetition')
    ax2.set_ylabel('Testing MSE')
    ax2.set_title('MSE by Repetition')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xticks(rep_indices)

    # Plot 3: Grid Points Comparison
    ax3 = axes[1, 0]
    grid_methods = ['Non-Adaptive\n(Fixed)', 'Adaptive\n(Average)']
    grid_means = [num_grids, num_grids_adap_array.mean()]
    grid_stds = [0, num_grids_adap_array.std()]
    colors_grid = ['orange', 'green']

    bars_grid = ax3.bar(grid_methods, grid_means, yerr=grid_stds, capsize=5,
                       color=colors_grid, alpha=0.7, edgecolor='black')
    ax3.set_ylabel('Number of Grid Points')
    ax3.set_title('Grid Points Comparison')
    ax3.grid(True, alpha=0.3)

    # Add value labels and reduction percentage
    reduction_pct = ((num_grids - num_grids_adap_array.mean()) / num_grids * 100)
    ax3.text(0, grid_means[0] + 2, f'{grid_means[0]:.0f}', ha='center', va='bottom', fontweight='bold')
    # Fixed: '\\n' rendered a literal backslash-n inside the label instead of
    # breaking it onto two lines.
    ax3.text(1, grid_means[1] + grid_stds[1] + 2, f'{grid_means[1]:.1f}\n({reduction_pct:.1f}% reduction)',
             ha='center', va='bottom', fontweight='bold')

    # Plot 4: MSE vs Grid Points Scatter
    ax4 = axes[1, 1]
    # Non-adaptive points (all have same grid count)
    ax4.scatter([num_grids] * len(testing_mse_nonadap_array), testing_mse_nonadap_array,
               color='red', s=100, alpha=0.7, label='Non-Adaptive', marker='o')
    # Adaptive points (requires one grid count per adaptive replica)
    ax4.scatter(num_grids_adap_array, testing_mse_adap_array,
               color='blue', s=100, alpha=0.7, label='Adaptive', marker='s')

    ax4.set_xlabel('Number of Grid Points')
    ax4.set_ylabel('Testing MSE')
    ax4.set_title('MSE vs Grid Points')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Create additional detailed comparison plot
    fig2, ax = plt.subplots(1, 1, figsize=(12, 8))

    # Box plot comparison
    data_to_plot = [testing_mse_nonadap_array, testing_mse_adap_array]
    box_plot = ax.boxplot(data_to_plot, labels=['Non-Adaptive', 'Adaptive'],
                         patch_artist=True, notch=True)

    # Customize box plot colors
    colors = ['lightcoral', 'lightblue']
    for patch, color in zip(box_plot['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    # Add individual points with a little horizontal jitter.
    # Fixed: the loop variable used to be called `data`, which overwrote the
    # imported `data` module at notebook level and broke any later re-run of
    # the data-generation cell.
    for i, samples in enumerate(data_to_plot):
        jittered_x = np.random.normal(i+1, 0.04, size=len(samples))
        ax.scatter(jittered_x, samples, alpha=0.6, s=50, color='darkred' if i==0 else 'darkblue')

    ax.set_ylabel('Testing MSE')
    ax.set_title('Testing MSE Distribution: Non-Adaptive vs Adaptive', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # Add statistics text
    stats_text = f'''Non-Adaptive: μ={testing_mse_nonadap_array.mean():.6f}, σ={testing_mse_nonadap_array.std():.6f}
Adaptive: μ={testing_mse_adap_array.mean():.6f}, σ={testing_mse_adap_array.std():.6f}
Improvement: {((testing_mse_nonadap_array.mean() - testing_mse_adap_array.mean()) / testing_mse_nonadap_array.mean() * 100):.1f}%
Grid Reduction: {((num_grids - num_grids_adap_array.mean()) / num_grids * 100):.1f}%'''

    ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    plt.tight_layout()
    plt.show()

else:
    print("Insufficient data for plotting. Need both adaptive and non-adaptive results.")

In [None]:
# LSTM One-Step Prediction Experiment (N=64)
testing_rmse_lstm_list = []

# Original note: the RNN-ODE configuration above can be reused, with the values
# below serving as defaults when it is not defined.
# NOTE(review): these assignments run unconditionally, so they overwrite the
# earlier configuration cell (num_grids 500 -> 65, n_iter 400 -> 40,
# obs_dim 1 -> 2) for every cell executed afterwards — confirm this is intended.
# NOTE(review): train_LSTM and fit_LSTM_grids are not imported in any visible
# cell; they must already exist in the kernel.
num_grids = 65
verbose = 2
n_iter = 40
batch_size = 64
obs_dim = 2
n_latent = 128
time_scale = 10
buffer_start_steps = 2

print("Starting LSTM One-Step Prediction Experiment...")
print("=" * 60)

rep = 0
while rep < num_rep:
    print(f"\n------LSTM the {rep+1}-th replica:--------")
    # Data for this replica
    train_windows0 = train_windows0_list[rep]
    train_ts0 = train_ts0_list[rep]
    test_windows0 = test_windows0_list[rep]
    test_ts0 = test_ts0_list[rep]

    # Validation set (50 windows sampled from the training data)
    valid_index = np.sort(npr.choice(train_windows0.shape[0], 50, replace=False))
    valid_windows0 = train_windows0[valid_index, :, :]
    valid_ts0 = train_ts0[valid_index, :]

    # Training set = everything left after removing the validation windows
    train_index = np.delete(np.arange(train_windows0.shape[0]), valid_index)
    train_windows1 = train_windows0[train_index, :, :].clone()
    train_ts1 = train_ts0[train_index, :].clone()

    # Buffer handling consistent with the RNN-ODE run (only the timestamps are
    # shifted): offsets are [-buffer_start_steps*deltat, ..., -deltat]
    if buffer_start_steps > 0:
        deltat = train_ts1[0, -1] - train_ts1[0, -2]
        train_ts1[:, :buffer_start_steps] = (
            train_ts1[:, :buffer_start_steps] +
            torch.arange(-deltat * buffer_start_steps, 0, deltat)
        )

    # DataLoader
    train_dataset1 = MyDataset(train_windows1, train_ts1)
    train_loader1 = DataLoader(train_dataset1, shuffle=True, batch_size=batch_size)

    input_steps1 = train_windows1.shape[1]

    # Train the LSTM
    flag, lstm1 = train_LSTM(
        train_loader1, input_steps1, valid_windows0, valid_ts0, num_grids,
        verbose, n_iter=n_iter, n_latent=n_latent, obs_dim=obs_dim,
        buffer_start_steps=buffer_start_steps, time_scale=time_scale
    )

    if not flag:
        # `rep` is not advanced here, so the same replica is retried.
        print(f"Training failed for repetition {rep+1}, skipping...")
        continue
    else:
        rep += 1
        print(f"Training successful for repetition {rep}")

    # Test-set fitting RMSE (original note: as with RNN-ODE, the fit_*
    # helpers return RMSE)
    print("\nLSTM fitting error of testing data:")
    _, fit_L2_err_test_lstm, _ = fit_LSTM_grids(
        lstm1, test_windows0, test_ts0, num_grids,
        buffer_start_steps=buffer_start_steps, verbose=1,
        n_latent=n_latent, obs_dim=obs_dim
    )
    testing_rmse = fit_L2_err_test_lstm.mean().item()
    testing_rmse_lstm_list.append(testing_rmse)
    print(f"LSTM Testing RMSE: {testing_rmse:.6f}")

print("\n" + "=" * 60)
print("LSTM EXPERIMENT COMPLETED")
print("=" * 60)
print(f"Successful repetitions: {len(testing_rmse_lstm_list)}/{num_rep}")


# LEM Baseline

In [None]:
from LEM_functions import train_LEM, fit_LEM_grids

# Per-replica test RMSE of the LEM baseline
testing_rmse_lem_list = []

print("Starting LEM One-Step Prediction Experiment...")
print("=" * 60)

rep = 0
while rep < num_rep:
    print(f"\n------LEM the {rep+1}-th replica:--------")

    # Windows and timestamps belonging to this replica
    train_windows0 = train_windows0_list[rep]
    train_ts0 = train_ts0_list[rep]
    test_windows0 = test_windows0_list[rep]
    test_ts0 = test_ts0_list[rep]

    # Hold out 50 random training windows for validation
    valid_index = np.sort(npr.choice(train_windows0.shape[0], 50, replace=False))
    valid_windows0 = train_windows0[valid_index, :, :]
    valid_ts0 = train_ts0[valid_index, :]

    # Everything that was not held out is used for training
    train_index = np.delete(np.arange(train_windows0.shape[0]), valid_index)
    train_windows1 = train_windows0[train_index, :, :].clone()
    train_ts1 = train_ts0[train_index, :].clone()

    # Shift the leading buffer timestamps, same handling as the RNN-ODE run
    if buffer_start_steps > 0:
        deltat = train_ts1[0, -1] - train_ts1[0, -2]
        train_ts1[:, :buffer_start_steps] = (
            train_ts1[:, :buffer_start_steps]
            + torch.arange(-deltat * buffer_start_steps, 0, deltat)
        )

    # Wrap the training windows in a shuffled loader
    train_dataset1 = MyDataset(train_windows1, train_ts1)
    train_loader1 = DataLoader(train_dataset1, shuffle=True, batch_size=batch_size)
    input_steps1 = train_windows1.shape[1]

    # Fit the LEM model (original note: LEM_functions.py defaults are
    # lr=0.00904, nhid=16 — not verifiable from this notebook)
    flag, lem = train_LEM(
        train_loader1, input_steps1, valid_windows0, valid_ts0, num_grids,
        verbose, n_iter=n_iter, obs_dim=obs_dim,
        buffer_start_steps=buffer_start_steps, time_scale=time_scale,
        num_train_windows=len(train_index)
    )

    # On failure `rep` stays put, so the same replica is attempted again
    if not flag:
        print(f"Training failed for repetition {rep+1}, skipping...")
        continue
    rep += 1
    print(f"Training successful for repetition {rep}")

    # Fitting error on the test windows (uniform grid)
    print("\nLEM fitting error of testing data:")
    _, fit_L2_err_test_lem, _ = fit_LEM_grids(
        lem, test_windows0, test_ts0, num_grids,
        buffer_start_steps=buffer_start_steps, verbose=1, obs_dim=obs_dim
    )
    testing_rmse = fit_L2_err_test_lem.mean().item()
    testing_rmse_lem_list.append(testing_rmse)
    print(f"LEM Testing RMSE: {testing_rmse:.6f}")

print("\n" + "=" * 60)
print("LEM EXPERIMENT COMPLETED")
print("=" * 60)
print(f"Successful repetitions: {len(testing_rmse_lem_list)}/{num_rep}")

In [None]:
# Summary statistics of the per-replica LEM test RMSE.
rmse = np.array(testing_rmse_lem_list, dtype=float)

# Guard against an empty result list, as the other result cells do:
# min()/max() raise ValueError on a zero-size array.
if rmse.size > 0:
    print("RMSE RESULTS:")
    print("-----------------------------------")
    print(f"Mean: {rmse.mean():.6f}")
    print(f"Std:  {rmse.std(ddof=0):.6f}")  # population standard deviation
    print(f"Min:  {rmse.min():.6f}")
    print(f"Max:  {rmse.max():.6f}")
else:
    print("No successful LEM repetitions completed!")


In [None]:
import matplotlib.pyplot as plt
# Plot-selection settings.
# NOTE(review): none of these three names is read by the visible code below —
# confirm whether they are still needed.
dataset = "train" # options: "long", "train", "test"

num_to_plot = 3 # how many sequences/windows to visualize
indices = None # e.g., [0, 5, 10]; if None, random pick

def to_numpy(x):
    """Return ``x`` as a NumPy array.

    Torch tensors are detached from the autograd graph and moved to the CPU
    before conversion; anything else goes through ``np.asarray``.
    """
    if not isinstance(x, torch.Tensor):
        return np.asarray(x)
    return x.detach().cpu().numpy()

def plot_fhn(x_seq, t_seq, title_prefix=""):
    """Plot one trajectory as time series and, when two columns exist, a phase portrait.

    Args:
        x_seq: 2-D tensor or array-like; column 0 is drawn as ``v`` and
            column 1 (when present) as ``w``.
        t_seq: Time values used as the x-axis of the time-series panels.
        title_prefix: Text placed in front of every panel title.

    Returns:
        None. The figures are shown with ``plt.show()``.
    """
    states = to_numpy(x_seq)
    times = to_numpy(t_seq)
    n_channels = states.shape[1]
    has_w = n_channels >= 2

    _fig, axes = plt.subplots(
        2 if has_w else 1,
        1,
        figsize=(8, 4 if n_channels == 1 else 6),
        sharex=True,
    )
    # With a single panel, subplots() hands back a bare Axes, so wrap it to
    # keep the indexing below uniform.
    if n_channels == 1:
        axes = [axes]

    # v(t) panel
    v_ax = axes[0]
    v_ax.plot(times, states[:, 0], color='tab:blue', lw=2)
    v_ax.set_title(f"{title_prefix} v(t)")
    v_ax.set_ylabel("v")
    v_ax.grid(alpha=0.3)

    if not has_w:
        # Only one panel: it carries the time label itself.
        v_ax.set_xlabel("time")
    else:
        # w(t) panel
        w_ax = axes[1]
        w_ax.plot(times, states[:, 1], color='tab:orange', lw=2)
        w_ax.set_title(f"{title_prefix} w(t)")
        w_ax.set_xlabel("time")
        w_ax.set_ylabel("w")
        w_ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    # The phase portrait needs both state dimensions.
    if not has_w:
        return

    plt.figure(figsize=(5, 5))
    plt.plot(states[:, 0], states[:, 1], lw=2, color='tab:green')
    plt.xlabel("v")
    plt.ylabel("w")
    plt.title(f"{title_prefix} Phase Portrait (w vs v)")
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()

  

In [None]:
# Visualize the first window (and its timestamps) of the first replica.
# NOTE(review): train_windows0_list / train_ts0_list are not defined in any
# visible cell; they must already exist in the kernel.
plot_fhn(train_windows0_list[0][0], train_ts0_list[0][0], "Train window #0 -")

