In [None]:
import numpy as np
import pysindy as ps
from pysindy.feature_library import PolynomialLibrary, FourierLibrary
import matplotlib.pyplot as plt

from generate_data import generate_tracking_data, generate_discrete_tracking_data

In [None]:
# Time step handed to the discrete-time data generator.
T = 0.1
# 201 samples on [0, 20]; the resulting grid spacing (20 / 200) equals T.
t = np.linspace(0, 20, num=201)

# Initial states for the training and the validation trajectories.
x0 = [-0.1, 0.2, -0.1]
x0_val = [0.1, 0.1, 0]


def u(t):
    """Control input for the training run: 0.001 * sin(t)."""
    return 0.001 * np.sin(t)


def u_val(t):
    """Control input for the validation run: 0.001 * cos(2 * t)."""
    return 0.001 * np.cos(2 * t)

In [None]:
# Continuous-time data: this generator is given the control as a callable.
# It returns two values; the second one is only kept for the training run.
# NOTE(review): x_dot is presumably the state derivative - confirm in generate_data.
x, x_dot = generate_tracking_data(x0=x0, u=u, t=t)
x_val, _ = generate_tracking_data(x0=x0_val, u=u_val, t=t)

# Discrete-time data: here the control is passed already sampled on the grid t,
# together with the time step T.
xk = generate_discrete_tracking_data(x0=x0, u=u(t), t=t, T=T)
xk_val = generate_discrete_tracking_data(x0=x0_val, u=u_val(t), t=t, T=T)

In [None]:
# Candidate terms: polynomial features up to degree 3.
# (Alternative that was tried: FourierLibrary(n_frequencies=1).)
library = PolynomialLibrary(degree=3)

# Sequentially thresholded least squares with threshold 1e-5.
optimizer = ps.STLSQ(threshold=1e-5)

# Names used when printing the model: x1, x2, x3 and the control u.
names = [f'x{k}' for k in range(1, 4)] + ['u']

model = ps.SINDy(
    feature_library=library,
    optimizer=optimizer,
    feature_names=names,
    discrete_time=True,
)

# Discrete-time fit on the sampled states and the sampled training control.
# (The continuous-time variant would be: model.fit(x=x, x_dot=x_dot, t=t, u=u(t)).)
model.fit(x=xk, u=u(t))
model.print(precision=3)

In [None]:
# Roll the identified model forward from the validation initial state.
# Note that `t` is passed as the integer len(t), not as the time vector.
n_steps = len(t)
u_val_samples = u_val(t)
x_sim = model.simulate(x0=x0_val, t=n_steps, u=u_val_samples)

# Mean squared error against the validation data, averaged over axis 0,
# which leaves one value per column.
residual = x_sim - xk_val
mse = (residual**2).mean(axis=0)
print(f'Błąd średniokwadratowy x1: {mse[0]}, x2: {mse[1]}, x3: {mse[2]}')

In [None]:
# Four stacked panels sharing the x axis: one per state column (validation data
# vs. simulated model) and a last one for the validation control signal.
fig, axs = plt.subplots(4, 1, figsize=(10, 8), sharex=True)

# (validation-data label, model label) for each of the three state panels.
labels = [
    ("Kąt natarcia", "Model kąta natarcia"),
    ("Kąt nachylenia", "Model kąta nachylenia"),
    ("Współczynnik nachylenia", "Model współczynnika nachylenia"),
]
control_label = "Sterowanie"
colors = ['b', 'orange', 'y']
sim_line_colors = ['k', 'purple', 'r']

for i, ax in enumerate(axs[:3]):
    data_label, model_label = labels[i]
    # Attach labels to the lines themselves so the legend cannot get out of
    # sync with the plotting order.
    ax.plot(t, xk_val[:, i], color=colors[i], label=data_label)
    ax.plot(t, x_sim[:, i], color=sim_line_colors[i], linestyle='--',
            label=model_label)
    ax.set_ylabel("Rad", fontsize=16)
    ax.legend(fontsize=14)
    ax.grid()
    # Limits follow the time grid instead of a hard-coded (0, 20).
    ax.set_xlim(t[0], t[-1])
    ax.tick_params(axis='both', labelsize=14)

control_ax = axs[3]
control_ax.plot(t, u_val(t), 'g--', alpha=0.7)
control_ax.set_ylabel(control_label, fontsize=16)
control_ax.set_xlabel("Czas [s]", fontsize=16)
control_ax.grid()
control_ax.tick_params(axis='both', labelsize=14)

plt.tight_layout()
plt.show()