**Load necessary modules**

In [1]:
import numpy as np 
from scipy.integrate import odeint
import os, sys 
from pathlib import Path
from os.path import dirname, realpath
script_dir = Path(dirname(realpath('.')))
module_dir = str(script_dir)
sys.path.insert(0, module_dir + '/modules')
import utility as ut
import surrogate as sr
import matplotlib.pyplot as plt
import pandas as pd
import ipywidgets as widgets
from IPython.display import display

**Play with different errors and corresponding $\tau_f$**

In [2]:
L63_data_folder = '../data/L63-trajectories'
data_folder = '../data/L63-SMLR-0'
test = np.load('{}/test.npy'.format(L63_data_folder))[:, :, :]
ba = sr.BatchRunAnalyzer_SMLR(data_folder)
forecast_till = 500
t = np.arange(0., forecast_till*ba.dt, dt) / ba.Lyapunov_time
dim = 3

def update_plot(random_trajectory, w, b, realization, error_threshold):    
    # random_trajectory = int(random_trajectory)
    fig = plt.figure(figsize=(12, 12))
    ax_paths = [fig.add_subplot(321), fig.add_subplot(323), fig.add_subplot(325)]
    ax_se, ax_mse = [fig.add_subplot(322), fig.add_subplot(324)]
    ax_all = ax_paths + [ax_se, ax_mse]
    # for ax in ax_all:
    #     ax.clear()
    model = ba.get_model(w, b, realization)
    predicted = model.multistep_forecast(test[random_trajectory][:, 0], forecast_till)
    tse, se, tmse, mse = model.compute_forecast_time(test[random_trajectory][:, :forecast_till], error_threshold, ba.dt, ba.Lyapunov_time)
    for i, ax in enumerate(ax_paths):
        ax.plot(t, test[random_trajectory][i, :forecast_till], label='truth')
        ax.plot(t, predicted[i, :forecast_till], label='predicted')
        ax.legend(loc='upper right')
        if i == dim-1:
            ax.set_xlabel('t')
        if i == 0:
            ax.set_title('Trajectories')
        ax.set_ylabel(r'$x_{}$'.format(i+1))

    ax_se.plot(t, se[0])
    ax_se.axvline(x=tse[0], c='black', label=r'$\tau_f$')
    ax_se.set_title(r'$e(t_n) =\frac{\|u_{\rm true}(t_n) - u_{\rm predicted}(t_n)\|_2^2}{\|u_{\rm true}(t_n)\|_2^2},\quad$' + '\t' +\
                    r'$\tau_f={:.2f}$'.format(tse[0]))
    ax_se.set_ylabel('squared error')
    ax_se.legend()
    ax_mse.plot(t, mse[0])
    ax_mse.axvline(x=tmse[0], c='black', label=r'$\tau_f$')
    ax_mse.set_title(r'$\bar{e}(t_n)=\frac{1}{n}\sum_{i=1}^ne(t_i),\quad$' + '\t' +\
                    r'$\tau_f={:.2f}$'.format(tmse[0]))
    ax_mse.set_ylabel('mean squared error')
    ax_mse.set_xlabel('t')
    ax_mse.legend()
    plt.subplots_adjust(hspace=0.3)
    plt.show()

random_trajectory_slider = widgets.IntSlider(value=1, min=0, max=500, step=1, description='test path', continuous_update=False)
w_slider = widgets.IntSlider(value=1, min=ba.w_idx[0], max=ba.w_idx[-1], step=1, description='w', continuous_update=False)
b_slider = widgets.IntSlider(value=1, min=ba.b_idx[0], max=ba.b_idx[-1], step=1, description='b', continuous_update=False)
realization_slider = widgets.IntSlider(value=1, min=ba.random_idx[0], max=ba.random_idx[-1], step=1, description='realization', continuous_update=False)
error_threshold_slider = widgets.FloatSlider(value=1, min=0.01, max=1.0, step=0.01, description='error limit', continuous_update=False)
interactive_plot = widgets.interactive(update_plot, random_trajectory=random_trajectory_slider, w=w_slider, b=b_slider, realization=realization_slider,\
                                      error_threshold=error_threshold_slider)
output = interactive_plot.children[-1]
output.layout.height = '1000px'
interactive_plot

NameError: name 'dt' is not defined

In [None]:
def update_plot(freq):
    x = np.linspace(0, 2 * np.pi, 1000)
    y = np.sin(freq * x)
    
    plt.figure(figsize=(8, 4))
    plt.plot(x, y)
    plt.title(f'Sine Wave with Frequency {freq}')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.grid(True)
    plt.show()

freq_slider = widgets.FloatSlider(
    value=1.0,
    min=0.1,
    max=10.0,
    step=0.1,
    description='Frequency:',
    continuous_update=False
)



interactive_plot = widgets.interactive(update_plot, freq=freq_slider)
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot