In [None]:
import numpy as np
import pysindy as ps
from pysindy.feature_library import PolynomialLibrary, FourierLibrary, CustomLibrary, ConcatLibrary
import matplotlib.pyplot as plt

from decimal import Decimal
from generate_data import generate_population_data, generate_discrete_population_data

In [None]:
# Time grids for the training and validation experiments:
# 1001 evenly spaced points on [0, 100] days (step 0.1).
t_end, n_samples = 100, 1001
t = np.linspace(0, t_end, n_samples)
t_val = np.linspace(0, t_end, n_samples)

# Initial populations [prey, predators] for the training and validation runs.
x0 = [50, 50]
x0_val = [100, 30]


def u(t):
    """Training control input: sin(t)."""
    return np.sin(t)


def u_val(t):
    """Validation control input: cos(2t), deliberately different from training."""
    return np.cos(2 * t)

In [None]:
# Continuous-time training data: states and their time derivatives as returned
# by generate_population_data, driven by the training control u.
x, x_dot = generate_population_data(t, x0, u)
# Validation trajectory: different initial state and control, sampled on the
# validation grid t_val (previously the training grid t was used here).
x_val, _ = generate_population_data(t_val, x0_val, u_val)

In [None]:
# Sampling period passed to the discrete-time data generator.
T = 1e-1
# Discrete-time training and validation data; the validation set uses its own
# grid t_val and evaluates the validation control on that same grid.
# NOTE(review): xk / xk_val are not used by the continuous-time model fitted
# below; they only matter for a discrete_time=True SINDy variant.
xk = generate_discrete_population_data(t, x0, T=T, u=u(t))
xk_val = generate_discrete_population_data(t_val, x0_val, T=T, u=u_val(t_val))

In [None]:
# Number of decimal places used when printing the identified equations,
# derived from the sampling period T (T = 0.1 -> 4 decimals).
pr = int(3 - np.log10(T))

# Alternatives tried during development:
#   - feature_library=FourierLibrary(n_frequencies=1)
#   - a discrete-time model: ps.SINDy(..., discrete_time=True) fitted with
#     model.fit(x=xk, u=u(t)). Note that the control must be the sampled array
#     u(t), not the callable u, and model.simulate then expects a step count
#     and a control array instead of a time grid and a callable.
model = ps.SINDy(
    feature_library=PolynomialLibrary(degree=2),
    optimizer=ps.STLSQ(threshold=0.0001),
    # One name per state variable plus the single control input.
    feature_names=[f'x{i+1}' for i in range(len(x0))] + ['u'],
)

# Continuous-time fit using the supplied derivatives. The time grid is passed
# explicitly so pysindy validates it against x instead of assuming a default.
model.fit(x=x, t=t, x_dot=x_dot, u=u(t))
model.print(precision=pr)

In [None]:
def _format_mse(errors):
    """Format per-state errors as 'x1: ..., x2: ...' for any number of states."""
    return ', '.join(f'x{i+1}: {e}' for i, e in enumerate(errors))


# Training fit: simulate the identified model from the training initial state
# and compute the mean squared error of each state column.
x_sim = model.simulate(x0=x0, t=t, u=u)
mse = ((x_sim - x) ** 2).mean(axis=0)
print(f'Błąd średniokwadratowy {_format_mse(mse)}')

# Validation: x_val was generated from a different initial state and control
# input, so this error measures generalization rather than fit quality.
x_val_sim = model.simulate(x0=x0_val, t=t_val, u=u_val)
mse_val = ((x_val_sim - x_val) ** 2).mean(axis=0)
print(f'Błąd średniokwadratowy (walidacja) {_format_mse(mse_val)}')

In [None]:
# Measured populations (solid) vs. the identified model's simulation (dashed),
# with the control signal on a secondary y-axis.
fig, ax1 = plt.subplots()
ax1.plot(t, x[:, 0], label="Populacja ofiar")
ax1.plot(t, x[:, 1], label="Populacja drapieżników")
ax1.plot(t, x_sim[:, 0], "r--", label="Model – ofiary")
ax1.plot(t, x_sim[:, 1], "k--", label="Model – drapieżniki")
# Headroom above the largest measured population (both species, not only prey).
ax1.set_ylim(0, np.max(x) * 1.4)
ax1.set_xlim(0, np.max(t))
ax1.grid()
ax1.legend()
ax1.set_xlabel("Czas [dni]")
ax1.set_ylabel("Liczba osobników")

u_plot = u(t)
ax2 = ax1.twinx()
ax2.spines['right'].set_color('green')
ax2.yaxis.label.set_color('green')
ax2.tick_params(axis='y', colors='green')
ax2.plot(t, u_plot, 'g--', alpha=0.4)
ax2.set_ylabel("Sterowanie")
# The lower limit starts at the signal minimum so the whole control trace is
# visible (the former fixed limits (1, 10) hid sin(t) almost entirely). The
# upper limit of 10 is kept, which squeezes the trace toward the bottom of the
# axes, below the population curves.
ax2.set_ylim(np.min(u_plot), max(np.max(u_plot), 10))
plt.show()