In [1]:
import numpy as np
import pysindy as ps
from pysindy.feature_library import PolynomialLibrary, FourierLibrary, CustomLibrary
import matplotlib.pyplot as plt

from generate_data import generate_tracking_data, generate_discrete_tracking_data
from decimal import Decimal

In [2]:
# Sampling grid: 201 points on [0, 20], i.e. a step of 0.1 that equals the
# discrete-time sampling period T used for the xk / xk_val data below.
T = 1e-1
t = np.linspace(0, 20, 201)


def u(t):
    """Training control input: small-amplitude sine."""
    return 0.001 * np.sin(t)


def u_val(t):
    """Validation control input: cosine at twice the training frequency."""
    return 0.001 * np.cos(2 * t)


# Initial states of the training and validation trajectories.
x0 = [-0.1, 0.2, -0.1]
x0_val = [0.1, 0.1, 0]

# Continuous-time data: states and their derivatives (the validation
# derivatives are not needed and are discarded).
x, x_dot = generate_tracking_data(t=t, x0=x0, u=u)
x_val, _ = generate_tracking_data(t=t, x0=x0_val, u=u_val)

# Discrete-time data sampled with period T; the control is passed as its
# sampled values rather than as a callable.
xk = generate_discrete_tracking_data(t=t, x0=x0, T=T, u=u(t))
xk_val = generate_discrete_tracking_data(t=t, x0=x0_val, T=T, u=u_val(t))

In [3]:
# Single-trajectory fit: cubic polynomial library over (x1, x2, x3, u) with
# STLSQ sparsification at threshold 1e-4. x_dot is passed explicitly, so
# SINDy uses the derivatives returned by generate_tracking_data instead of
# differentiating x numerically.
# Variants tried earlier (swap in to reproduce):
#   feature_library: PolynomialLibrary(degree=2 | 1), FourierLibrary(n_frequencies=1 | 2)
#   STLSQ threshold: 0.001 | 0.01 | 0.1
#   discrete time:   ps.SINDy(..., discrete_time=True) with model.fit(x=xk, u=u(t))
model = ps.SINDy(
    optimizer=ps.STLSQ(threshold=1e-4),
    feature_library=PolynomialLibrary(degree=3),
    feature_names=[f'x{k}' for k in (1, 2, 3)] + ['u'],
)
model.fit(x=x, x_dot=x_dot, u=u(t))
model.print(precision=4)

(x1)' = 0.0176 1 + -0.7335 x1 + -0.1854 x2 + 0.9012 x3 + 1.8208 x1^2 + -0.9303 x1 x2 + -0.9548 x1 x3 + 0.3733 x1 u + 0.6321 x2^2 + 0.7406 x2 x3 + -0.7847 x2 u + 0.3974 x3^2 + 0.4567 x3 u + 7.4121 x1^3 + -4.6573 x1^2 x2 + -3.5038 x1^2 x3 + 1.4821 x1 x2^2 + 3.1935 x1 x2 x3 + 1.4959 x1 x3^2 + -0.7614 x2^3 + -1.3874 x2^2 x3 + -1.4824 x2 x3^2 + -1.9615 x2 x3 u + -0.2438 x3^3
(x2)' = 1.0000 x3
(x3)' = -0.0018 1 + -4.2178 x1 + 0.0184 x2 + -0.3815 x3 + -20.8075 u + -0.9065 x1^2 + 0.0619 x1 x2 + 0.2215 x1 x3 + -0.2481 x1 u + -0.0642 x2^2 + -0.1039 x2 x3 + -1.1625 x2 u + -0.0909 x3^2 + -0.1010 x3 u + -4.9494 x1^3 + 1.6055 x1^2 x2 + 0.8893 x1^2 x3 + -0.0966 x1 x2^2 + -0.8048 x1 x2 x3 + 1.2453 x1 x2 u + -0.3880 x1 x3^2 + 0.0745 x2^3 + 0.1864 x2^2 x3 + 2.1178 x2^2 u + 0.3346 x2 x3^2 + 0.2398 x2 x3 u + 0.0585 x3^3


In [4]:
# Multi-trajectory training set: 500 random initial states in [-0.1, 0.1]^3,
# each run driven by a sine control whose amplitude is x0[0] / 100 of that run.
# Every name created here is specific to this cell (`*_multi`), so the cell no
# longer overwrites `t`, `x0`, `x`, `x_dot` and `u_val` from the data-generation
# cell. The LaTeX-table and plotting cells below still read those originals.
rng_multi = np.random.default_rng(0)  # fixed seed so Restart & Run All reproduces the fit

t_multi = np.linspace(0, 200, 201)
x0s = rng_multi.uniform(-0.1, 0.1, (500, 3))

xss = []
x_dots = []
us = []

for x0_multi in x0s:
    def u_multi(t, amplitude=x0_multi[0] / 100.):
        """Control input for this run; the amplitude is bound when defined."""
        return amplitude * np.sin(t)

    x_multi, x_dot_multi = generate_tracking_data(t=t_multi, x0=x0_multi, u=u_multi)
    xss.append(x_multi)
    x_dots.append(x_dot_multi)
    us.append(u_multi(t_multi))

In [7]:
# Multi-trajectory fit: cubic polynomial library over (x1, x2, x3, u), trained
# jointly on every run collected in xss / x_dots / us. With threshold=0, STLSQ's
# hard threshold never removes a coefficient.
# Variants tried earlier: STLSQ threshold 0.001 | 0.01 | 0.1; libraries
# PolynomialLibrary(degree=2 | 1) and FourierLibrary(n_frequencies=1 | 2).
model = ps.SINDy(
    optimizer=ps.STLSQ(threshold=0),
    feature_library=PolynomialLibrary(degree=3),
    feature_names=[f'x{k}' for k in (1, 2, 3)] + ['u'],
)
model.fit(x=xss, x_dot=x_dots, u=us, multiple_trajectories=True)
model.print(precision=3)

(x1)' = -0.877 x1 + 1.000 x3 + -0.215 u + 0.470 x1^2 + -0.088 x1 x3 + -0.019 x2^2 + 3.846 x1^3 + -1.000 x1^2 x3 + 0.280 x1^2 u + 0.470 x1 u^2 + 0.630 u^3
(x2)' = 1.000 x3
(x3)' = -4.208 x1 + -0.396 x3 + -20.967 u + -0.470 x1^2 + -3.564 x1^3 + 6.265 x1^2 u + 46.000 x1 u^2 + 61.100 u^3


In [None]:
# LaTeX tables comparing SINDy models: one table per state x_k and one row per
# (feature library, STLSQ threshold) pair. Each row shows:
#   - the library,
#   - the threshold,
#   - the first two identified terms of the equation for x_k,
#   - the validation error E_k, i.e. the per-state mean squared error between
#     a simulation started from x0_val and x_val.
# Reads x, x_dot, t, u, u_val, x0_val and x_val from the data-generation cell.
# The Polish labels are table text and are printed verbatim.


def _latex_exponent(exponent):
    """Superscript body for 10^e: a bare single digit for 0..9, braced otherwise."""
    return str(exponent) if 0 <= exponent <= 9 else '{' + str(exponent) + '}'


def _latex_sci(value, digits=3):
    r"""Format a number like '%.{digits}E', but as LaTeX with a decimal comma.

    Examples: -0.7335 -> '-7,335\cdot 10^{-1}', 1234.0 -> '1,234\cdot 10^3',
    1.0 -> '1,000', 5e-10 -> '5,000\cdot 10^{-10}'.
    Non-finite inputs map to '\infty', '-\infty' or '\mathrm{NaN}'.
    """
    if np.isnan(value):
        return r'\mathrm{NaN}'
    if np.isinf(value):
        return r'\infty' if value > 0 else r'-\infty'
    mantissa, exponent = f'{value:.{digits}E}'.split('E')
    mantissa = mantissa.replace('.', ',')
    exponent = int(exponent)
    if exponent == 0:
        return mantissa
    return mantissa + r'\cdot 10^' + _latex_exponent(exponent)


def _latex_threshold(threshold):
    """Format a threshold as a power of ten, e.g. 1e-08 -> '10^{-8}', 1e-10 -> '10^{-10}'."""
    mantissa, exponent = f'{threshold:.0E}'.split('E')
    power = '10^' + _latex_exponent(int(exponent))
    return power if mantissa == '1' else mantissa + r'\cdot ' + power


def _latex_feature(name, n_states):
    r"""Translate a pysindy feature name into LaTeX.

    The constant feature '1' maps to '' so its coefficient stands alone;
    'sin(1 x2)' -> '\sin(x_2)'; 'x1^2 u' -> 'x_1^2 u'.
    """
    if name == '1':
        return ''
    name = name.replace('(1 ', '(')  # drop the unit frequency of Fourier terms
    for k in range(1, n_states + 1):
        name = name.replace(f'x{k}', f'x_{k}')
    return name.replace('sin', r'\sin').replace('cos', r'\cos')


def _latex_equation(coefficients, feature_names, threshold, n_states, max_terms=2):
    r"""LaTeX right-hand side of one identified equation, truncated for the table.

    Terms with |coefficient| > threshold are kept in library order. If more
    than `max_terms` survive, the rest are elided with '\dots'. If none survive,
    the result is '0,000'.
    """
    terms = []
    for coeff, name in zip(coefficients, feature_names):
        if abs(coeff) > threshold:
            feature = _latex_feature(name, n_states)
            number = _latex_sci(coeff)
            terms.append(f'{number} {feature}' if feature else number)
    if not terms:
        return '0,000'
    eq = ' + '.join(terms[:max_terms])
    if len(terms) > max_terms:
        eq += r'\dots'
    return eq.replace('+ -', '- ')


n_states = len(x0_val)
library_names = ['Trygonometryczna (st. 2)', 'Trygonometryczna (st. 1)',
                 'Wielomiany (st. 3)', 'Wielomiany (st. 2)', 'Liniowa']
libraries = [FourierLibrary(n_frequencies=2), FourierLibrary(n_frequencies=1),
             PolynomialLibrary(degree=3), PolynomialLibrary(degree=2),
             PolynomialLibrary(degree=1)]

# Fit and simulate every (library, threshold) model exactly once. The results
# are shared by all per-state tables below instead of being refitted per state.
# Layout: [(library name, [(threshold, coefficients, feature names, MSE or None), ...]), ...]
table_results = []
for lib_idx, (lib_name, library) in enumerate(zip(library_names, libraries)):
    rows = []
    for k in range(3):
        # Fourier libraries (first two) scan ~1e-8..1e-6; polynomial ones ~1e-10..1e-8.
        threshold = 10**(k - 10) if lib_idx >= 2 else 100 * 10**(k - 10)
        model = ps.SINDy(
            feature_library=library,
            optimizer=ps.STLSQ(threshold=threshold),
            feature_names=[f'x{s + 1}' for s in range(n_states)] + ['u'],
        )
        model.fit(x=x, x_dot=x_dot, u=u(t))
        try:
            x_sim = model.simulate(x0=x0_val, t=t, u=u_val)
            mse = ((x_sim - x_val) ** 2).mean(axis=0)
        except Exception:  # e.g. a diverging identified model; reported as E = infinity
            mse = None
        rows.append((threshold, model.coefficients(), model.get_feature_names(), mse))
    table_results.append((lib_name, rows))

for x_num in range(n_states):
    print(rf'Biblioteka funkcji & Próg & $\Dot{{x}}_{x_num + 1}$ & $E_{x_num + 1}$ \\')
    for lib_name, rows in table_results:
        print(r'\hline')
        for threshold, coefficients, names, mse in rows:
            eq = _latex_equation(coefficients[x_num], names, threshold, n_states)
            # A failed or non-finite simulation error is shown as infinity.
            if mse is None or not np.isfinite(mse[x_num]):
                err = r'\infty'
            else:
                err = _latex_sci(mse[x_num])
            print(rf'{lib_name} & ${_latex_threshold(threshold)}$ & ${eq}$ & ${err}$ \\')
    print('\n\n')

In [None]:
# x_sim = model.simulate(x0=x0_val, t=t, u=u_val)
# mse = ((x_sim - x_val)**2).mean(axis=0)
# print(f'Błąd średniokwadratowy x1: {mse[0]}, x2: {mse[1]}, x3: {mse[2]}')

In [None]:
# Validation trajectory on the left axis, with the validation control input
# overlaid as a dashed green line on a secondary right-hand axis.
fig, ax1 = plt.subplots()
ax1.plot(t, x_val[:, 0])
ax1.plot(t, x_val[:, 1])
ax1.plot(t, x_val[:, 2], 'y')
ax1.grid()
ax1.legend(["Kąt natarcia", "Kąt nachylenia", "Współczynnik nachylenia"])
ax1.set_xlim(0, max(t))
# NOTE(review): the time unit "dni" (days) looks unusual for these dynamics;
# confirm against generate_tracking_data.
ax1.set_xlabel("Czas [dni]")
ax1.set_ylabel("Kąt [Rad]")

u_plot = u_val(t)
ax2 = ax1.twinx()
ax2.spines['right'].set_color('green')
ax2.yaxis.label.set_color('green')
ax2.tick_params(axis='y', colors='green')
ax2.plot(t, u_plot, 'g--', alpha=0.4)
ax2.set_ylabel("Sterowanie")
# Squeeze the control trace into the bottom third of the axes. The axis spans
# 6x the peak, and the signal's full [-peak, peak] range sits at the bottom,
# so negative control values stay visible instead of being clipped at 0.
u_peak = float(np.max(np.abs(u_plot)))
if u_peak > 0:
    ax2.set_ylim(-u_peak, 5 * u_peak)
plt.show()