In [None]:
import numpy as np
import pysindy as ps
from pysindy.feature_library import PolynomialLibrary, FourierLibrary
import matplotlib.pyplot as plt

from generate_data import generate_tracking_data, generate_discrete_tracking_data
from decimal import Decimal

In [74]:
T = 1e-1
t = np.linspace(0, 20, 201)

x0 = [-0.1, 0.2, -0.1]
x0_val = [0.1, 0.1, 0]

u = lambda t: 0.001*np.sin(t)
u_val = lambda t: 0.001*np.cos(2*t)

x, x_dot = generate_tracking_data(t=t, x0=x0, u=u)
x_val, _ = generate_tracking_data(t=t, x0=x0_val, u=u_val)

xk = generate_discrete_tracking_data(t=t, x0=x0, T=T, u=u(t))
xk_val = generate_discrete_tracking_data(t=t, x0=x0_val, T=T, u=u_val(t))

In [None]:
x0s = np.random.uniform(-0.1, 0.1, (500, 3))

xss = []
x_dots = []
us = []

for x0 in x0s:
    u_test = lambda t: x0[0] / 100. * np.sin(t)
    x = generate_discrete_tracking_data(t=t, x0=x0, T=T, u=u_test(t))
    # x, x_dot = generate_tracking_data(t=t, x0=x0, u=u)
    xss.append(x)
    x_dots.append(x_dot)
    us.append(u_test(t))

In [None]:
model = ps.SINDy(
    feature_library=PolynomialLibrary(degree=3), # 2, 1
    # feature_library=FourierLibrary(n_frequencies=1), # 2, 1
    optimizer=ps.STLSQ(threshold=1e-7), # 0.001, 0.01, 0.1
    feature_names=[f'x{i+1}' for i in range(3)]+['u'],
    discrete_time=True
    )
model.fit(x=xss, u=us, multiple_trajectories=True)
model.print(precision=4)

(x1)[k+1] = 0.9123 x1[k] + 0.1000 x3[k] + -0.0215 u[k] + 0.0470 x1[k]^2 + -0.0088 x1[k] x3[k] + -0.0019 x2[k]^2 + 0.3846 x1[k]^3 + -0.1000 x1[k]^2 x3[k] + 0.0280 x1[k]^2 u[k] + 0.0455 x1[k] u[k]^2 + -0.0004 x3[k] u[k]^2
(x2)[k+1] = 1.0000 x2[k] + 0.1000 x3[k]
(x3)[k+1] = -0.4208 x1[k] + 0.9604 x3[k] + -2.0967 u[k] + -0.0470 x1[k]^2 + -0.3564 x1[k]^3 + 0.6265 x1[k]^2 u[k] + 4.6000 x1[k] u[k]^2 + 6.1100 u[k]^3


In [93]:
x_sim = model.simulate(x0=x0_val, t=201, u=u_val(t))
mse = ((x_sim - x_val)**2).mean(axis=0)
# mse1 = '%.3E' % Decimal(str(mse[0]))
# mse2 = '%.3E' % Decimal(str(mse[1]))
# mse3 = '%.3E' % Decimal(str(mse[2]))
mse1, mse2, mse3 = mse[0], mse[1], mse[2]
print(f'Błąd średniokwadratowy x1: {mse1}, x2: {mse2}, x3: {mse3}')

Błąd średniokwadratowy x1: 4.076221959308485e-05, x2: 4.169042593878539e-05, x3: 0.0001680509350714841
