In [2]:
import sys
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from datafold.appfold import EDMD
from scipy.integrate import solve_ivp
from sklearn.pipeline import Pipeline
from datafold.dynfold.transform import TSCRadialBasis
from datafold.pcfold import GaussianKernel, TSCDataFrame

In [71]:
def solve_pendulum(initial_conditions, t_eval):
    npts, _ = initial_conditions.shape

    C = np.array([[1, 0, 0, 0], [0, 0, 1, 0]])
    
    def y_dot(t, y):
        g = 9.8 
        L = 1.5 # Length of pendulum
        m = 1.0 #mass of bob (kg)
        M = 5.0  # mass of cart (kg)
        d1 = 1.0
        d2 = 0.5
        
        
        u = np.random.uniform(-4,4,1)
     
        x_ddot = u[0] - m*L*y[3]*y[3] * np.cos(y[2] ) + m*g*np.cos(y[2]) *  np.sin(y[2])
        x_ddot = x_ddot / ( M+m-m* np.sin(y[2])* np.sin(y[2]) )
        theta_ddot = -g/L * np.cos(y[2] ) -  np.sin( y[2]) / L*x_ddot

        damping_x =  - d1*y[1]
        damping_theta =  - d2*y[3]
        
        x_next = [y[1], x_ddot + damping_x, y[3], theta_ddot + damping_theta, u[0]]
        
        return np.array(x_next)

    #assert initial_conditions.ndim == 4
    assert initial_conditions.shape[1] == 5


    time_series_dfs = []
    ytime_series_dfs = []
    utime_series_dfs = []
    
    for ic in initial_conditions:
        solution = solve_ivp(
            y_dot, t_span=(t_eval[0], t_eval[-1]), y0=ic, t_eval=t_eval
        )

        data_u_values = solution["y"][4]
        data_values = solution["y"][:4]
        
        mdim, ndim = data_values.shape
        
        x_data = data_values[:, 0:ndim-1]
        sub_y = data_values[:, 1:ndim]
        u_data = data_u_values[0:ndim-1]
        
        y_data = C@sub_y
        intermidiate_time = solution["t"]
        
        time_x = intermidiate_time[0:ndim-1]
        
        solution = pd.DataFrame(
            data=x_data.T,
            index=time_x,
            columns=["x1", "x2", "x3", "x4"],
        )
        
        y_values_df = pd.DataFrame(
            data=y_data.T,
            index=time_x,
            columns=["y1", "y2"],
        )

       
        u_values_df = pd.DataFrame(
            data=u_data.T,
            index=time_x,
            columns=["u1"],
        )

        utime_series_dfs.append(u_values_df)
        time_series_dfs.append(solution)
        ytime_series_dfs.append(y_values_df)


    return TSCDataFrame.from_frame_list(time_series_dfs), TSCDataFrame.from_frame_list(ytime_series_dfs), TSCDataFrame.from_frame_list(utime_series_dfs)

In [4]:
horizon = 4
maxT = 5
DT = 0.2 
np.random.seed(55)
t_s = np.arange(0, maxT+DT, DT)

In [73]:
x0 = np.array([
        [0.0],
        [0.0],
        [0.3],
        [0.0]
    ])
x = x0.T.reshape(-1, 4)

In [5]:
initial_conditions = np.array(
    [np.random.uniform(-2, 2, horizon), np.random.uniform(-2, 2, horizon),  np.random.uniform(-2, 2, horizon),  np.random.uniform(-2, 2, horizon), np.random.uniform(-4,4,horizon)]
).T.reshape(-1, 5)

In [75]:
tsc_data, y_data, u_data= solve_pendulum(initial_conditions, t_s)

print(f"time delta: {tsc_data.delta_time}")
print(f"#time series: {tsc_data.n_timeseries}")
print(f"#time steps per time series: {tsc_data.n_timesteps}")
print(f"(n_samples, n_features): {tsc_data.shape}")
print(f"time interval {tsc_data.time_interval()}")
print(f"Same time values: {tsc_data.is_same_time_values()}")
print("")
print("Data snippet:")
tsc_data

time delta: 0.19999999999999998
#time series: 4
#time steps per time series: 25
(n_samples, n_features): (100, 4)
time interval (0.0, 4.800000000000001)
Same time values: True

Data snippet:


Unnamed: 0_level_0,feature,x1,x2,x3,x4
ID,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.0,-1.627567,0.124495,-1.566609,-1.963444
0,0.2,-1.599145,0.171672,-1.920977,-1.503521
0,0.4,-1.551874,0.304757,-2.143501,-0.689425
0,0.6,-1.480132,0.406300,-2.193560,0.181079
0,0.8,-1.390326,0.494410,-2.078854,0.936642
...,...,...,...,...,...
3,4.0,-2.676658,0.118494,3.625157,0.472130
3,4.2,-2.639689,0.272116,3.829220,1.539491
3,4.4,-2.562018,0.498261,4.223350,2.332111
3,4.6,-2.448786,0.561720,4.722060,2.504192


In [76]:
func_rbf = TSCRadialBasis(
            kernel=GaussianKernel(epsilon=0.17), center_type="initial_condition"
        )
dict_step = [
    (
        "rbf", func_rbf
        ,
    )
]

pipe = Pipeline(dict_step)
pipe.fit(tsc_data)
XLIFT = pipe.transform(tsc_data)

In [77]:
XLIFT

Unnamed: 0_level_0,feature,rbf0,rbf1,rbf2,rbf3
ID,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,0.0,1.000000e+00,3.453473e-34,1.487707e-19,2.320007e-31
0,0.2,3.677292e-01,5.222521e-34,1.052663e-15,3.175069e-30
0,0.4,2.836576e-03,1.841557e-32,2.590534e-10,4.748545e-27
0,0.6,3.120819e-07,5.994467e-31,2.069625e-06,7.757432e-24
0,0.8,4.732719e-12,1.045420e-29,2.176058e-04,2.606348e-21
...,...,...,...,...,...
3,4.0,3.864178e-44,6.845316e-37,4.779958e-50,8.388373e-19
3,4.2,6.315724e-55,1.030864e-39,1.871272e-51,3.289314e-19
3,4.4,2.070442e-68,1.762085e-45,1.555617e-57,8.144810e-24
3,4.6,7.651218e-78,1.257763e-49,3.632320e-65,1.989403e-28


In [78]:
pipe.fit(y_data)
YLIFT = pipe.transform(y_data)

In [79]:
edmd_rbf = EDMD(dict_steps=dict_step, include_id_state=True).fit(
    X=y_data
)  
YLIFT = edmd_rbf.predict(y_data.initial_states())

In [80]:
YLIFT

Unnamed: 0_level_0,feature,y1,y2
ID,time,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.0,-1.599145,-1.920977
0,0.2,-1.546764,-2.016986
0,0.4,-1.496887,-2.062818
0,0.6,-1.449182,-2.069344
0,0.8,-1.403496,-2.045451
...,...,...,...
3,4.0,-2.130663,4.138514
3,4.2,-2.135654,4.227419
3,4.4,-2.141295,4.314377
3,4.6,-2.147569,4.399595


In [81]:
CT = np.linalg.lstsq(XLIFT.loc[0], tsc_data.loc[0])

  CT = np.linalg.lstsq(XLIFT.loc[0], tsc_data.loc[0])


In [82]:
xlift1 = XLIFT.loc[0].T
ylift1 = YLIFT.loc[0].T
u_1 = u_data.loc[0].T

In [83]:
xlift1

time,0.0,0.2,0.4,0.6,0.8,1.0,1.2,1.4,1.6,1.8,...,3.0,3.2,3.4,3.6,3.8,4.0,4.2,4.4,4.6,4.8
feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
rbf0,1.0,0.3677292,0.002836576,3.120819e-07,4.732719e-12,1.243932e-15,4.921604e-16,9.417486e-13,1.297582e-08,3.470029e-05,...,7.895907e-08,2.900979e-10,2.918052e-11,2.32522e-10,2.37525e-08,3.669196e-06,0.0001693023,0.001309704,0.001767346,0.0004867951
rbf1,3.453473e-34,5.2225209999999995e-34,1.841557e-32,5.994467e-31,1.04542e-29,6.368852e-28,5.039023e-25,2.364088e-21,2.689521e-19,1.89499e-19,...,2.483509e-25,1.5143309999999999e-24,3.554027e-23,1.716905e-21,2.972134e-20,5.3530019999999995e-20,8.891045999999999e-21,4.553142e-22,3.353034e-23,8.483356999999999e-24
rbf2,1.487707e-19,1.052663e-15,2.590534e-10,2.069625e-06,0.0002176058,0.0009490134,0.001192834,0.0003156519,9.260194e-06,2.264896e-08,...,1.025918e-05,0.0001539127,0.0003346357,0.0001459269,1.616219e-05,4.530019e-07,8.456058e-09,6.129474e-10,6.771227e-10,7.050082e-09
rbf3,2.3200070000000003e-31,3.175069e-30,4.7485450000000005e-27,7.757432e-24,2.606348e-21,4.995477e-19,1.333146e-16,3.389126e-14,8.715887e-14,1.672961e-15,...,3.2574089999999996e-19,7.685752e-18,1.748246e-16,2.448053e-15,5.471184e-15,1.106528e-15,3.3727450000000005e-17,7.697066e-19,6.962546e-20,4.877168e-20


In [84]:
ylift1

time,0.0,0.2,0.4,0.6,0.8,1.0,1.2,1.4,1.6,1.8,...,3.0,3.2,3.4,3.6,3.8,4.0,4.2,4.4,4.6,4.8
feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
y1,-1.599145,-1.546764,-1.496887,-1.449182,-1.403496,-1.359763,-1.31795,-1.278034,-1.239987,-1.203774,...,-1.021914,-0.996905,-0.973234,-0.950838,-0.929654,-0.909627,-0.8907,-0.872823,-0.855945,-0.84002
y2,-1.920977,-2.016986,-2.062818,-2.069344,-2.045451,-1.998339,-1.933795,-1.856439,-1.769934,-1.677168,...,-1.085543,-0.990136,-0.896959,-0.806217,-0.718041,-0.632505,-0.54964,-0.469442,-0.391882,-0.316914


In [85]:
u_1

time,0.0,0.2,0.4,0.6,0.8,1.0,1.2,1.4,1.6,1.8,...,3.0,3.2,3.4,3.6,3.8,4.0,4.2,4.4,4.6,4.8
feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
u1,3.884893,3.772222,3.754394,3.735635,3.723777,3.567958,3.541276,3.354129,3.388082,3.372329,...,3.371552,3.529631,3.603134,3.567727,3.574334,3.530007,3.466649,3.473235,3.582037,3.686965


In [86]:
Xliftcombined = xlift1.append(u_1)

In [87]:
Xliftcombined

time,0.0,0.2,0.4,0.6,0.8,1.0,1.2,1.4,1.6,1.8,...,3.0,3.2,3.4,3.6,3.8,4.0,4.2,4.4,4.6,4.8
feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
rbf0,1.0,0.3677292,0.002836576,3.120819e-07,4.732719e-12,1.243932e-15,4.921604e-16,9.417486e-13,1.297582e-08,3.470029e-05,...,7.895907e-08,2.900979e-10,2.918052e-11,2.32522e-10,2.37525e-08,3.669196e-06,0.0001693023,0.001309704,0.001767346,0.0004867951
rbf1,3.453473e-34,5.2225209999999995e-34,1.841557e-32,5.994467e-31,1.04542e-29,6.368852e-28,5.039023e-25,2.364088e-21,2.689521e-19,1.89499e-19,...,2.483509e-25,1.5143309999999999e-24,3.554027e-23,1.716905e-21,2.972134e-20,5.3530019999999995e-20,8.891045999999999e-21,4.553142e-22,3.353034e-23,8.483356999999999e-24
rbf2,1.487707e-19,1.052663e-15,2.590534e-10,2.069625e-06,0.0002176058,0.0009490134,0.001192834,0.0003156519,9.260194e-06,2.264896e-08,...,1.025918e-05,0.0001539127,0.0003346357,0.0001459269,1.616219e-05,4.530019e-07,8.456058e-09,6.129474e-10,6.771227e-10,7.050082e-09
rbf3,2.3200070000000003e-31,3.175069e-30,4.7485450000000005e-27,7.757432e-24,2.606348e-21,4.995477e-19,1.333146e-16,3.389126e-14,8.715887e-14,1.672961e-15,...,3.2574089999999996e-19,7.685752e-18,1.748246e-16,2.448053e-15,5.471184e-15,1.106528e-15,3.3727450000000005e-17,7.697066e-19,6.962546e-20,4.877168e-20
u1,3.884893,3.772222,3.754394,3.735635,3.723777,3.567958,3.541276,3.354129,3.388082,3.372329,...,3.371552,3.529631,3.603134,3.567727,3.574334,3.530007,3.466649,3.473235,3.582037,3.686965


In [88]:
G = Xliftcombined@Xliftcombined.T

In [89]:
G

feature,rbf0,rbf1,rbf2,rbf3,u1
feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
rbf0,1.13722,1.842961e-23,1.65818e-11,7.678045e-20,5.527303
rbf1,1.842961e-23,1.120866e-37,4.009029e-24,2.406506e-32,1.899667e-18
rbf2,1.65818e-11,4.009029e-24,2.627888e-06,1.2170340000000001e-17,0.01188243
rbf3,7.678045e-20,2.406506e-32,1.2170340000000001e-17,8.785285e-27,4.48072e-13
u1,5.527303,1.899667e-18,0.01188243,4.48072e-13,310.8555


In [90]:
V = ylift1@(Xliftcombined.T)

In [91]:
V

feature,rbf0,rbf1,rbf2,rbf3,u1
feature,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
y1,-2.254041,-6.535738e-19,-0.00423,-1.622116e-13,-101.409194
y2,-2.770581,-8.634809999999999e-19,-0.005848,-2.270354e-13,-117.098821


In [92]:
print(G.shape, V.shape)

(5, 5) (2, 5)


In [93]:
AB = np.linalg.lstsq(G.T, V.T)

  AB = np.linalg.lstsq(G.T, V.T)


In [94]:
AB_val = AB[0].T

In [95]:
AB_val.shape

(2, 5)

In [96]:
A = AB_val[:, :4]

In [97]:
B =AB_val[:, 4]

In [98]:
print(AB_val)
print(A)
print(B)

[[-4.76473676e-01  5.42580425e-11 -2.08903115e+02  6.44873115e-10
  -3.09768662e-01]
 [-8.06982353e-01  1.84335103e-10 -7.09724194e+02  2.19087638e-09
  -3.35220467e-01]]
[[-4.76473676e-01  5.42580425e-11 -2.08903115e+02  6.44873115e-10]
 [-8.06982353e-01  1.84335103e-10 -7.09724194e+02  2.19087638e-09]]
[-0.30976866 -0.33522047]


In [6]:
initial_conditions

array([[-1.62756685,  0.12449532, -1.56660906, -1.96344423,  3.88489265],
       [ 1.88662368, -0.85782306,  1.06864018,  0.47324845, -0.02604062],
       [-0.06456008,  1.45052151, -1.79428516,  1.27483733, -1.18144859],
       [-1.02990919, -1.83559939,  1.10286615,  1.59434289,  2.93667996]])