In [None]:
import numpy as np
import pysindy as ps
from pysindy.feature_library import PolynomialLibrary, FourierLibrary, CustomLibrary

import matplotlib.pyplot as plt
from IPython.display import display, Latex

from generate_data import generate_population_data

In [None]:
# Time grid: 1001 evenly spaced samples on [0, 100], i.e. a step of 0.1
# (the plot below labels this axis in days).
t = np.linspace(start=0, stop=100, num=1001)

# Initial state; the plot legend maps column 0 to prey and column 1 to predators.
x0 = [50, 50]

# Trajectories sampled on `t` plus their time derivatives from the project's
# generator (presumably exact derivatives — see generate_data to confirm).
x, x_dot = generate_population_data(t, x0)

In [None]:
# SINDy hyperparameters. Alternative settings listed alongside these values:
# polynomial degree 1 or 2; STLSQ threshold 0.001 / 0.01 / 0.1; and a
# FourierLibrary(n_frequencies=1 or 2) in place of the polynomial library.
POLY_DEGREE = 2
STLSQ_THRESHOLD = 0.001

n_states = len(x[0])
state_names = [f'x{k}' for k in range(1, n_states + 1)]

model = ps.SINDy(
    feature_library=PolynomialLibrary(degree=POLY_DEGREE),
    optimizer=ps.STLSQ(threshold=STLSQ_THRESHOLD),
    feature_names=state_names,
)
# Derivatives are passed in explicitly, so pysindy does not differentiate x numerically.
model.fit(x=x, x_dot=x_dot, t=t)
model.print()

In [None]:
# Integrate the identified model from the same initial condition on the same
# time grid, then compare it with the training trajectory.
x_sim = model.simulate(x0=x0, t=t)

# If the integrator stops early (e.g. the identified model blows up), x_sim
# has fewer rows than x. Compare only the simulated span instead of failing
# with a broadcasting error, and make the early stop visible.
n_sim = len(x_sim)
if n_sim < len(x):
    print(f'Uwaga: symulacja przerwana przy t = {t[n_sim - 1]:.2f} '
          f'({n_sim} z {len(x)} próbek)')

# Per-state mean squared error over the compared samples.
mse_per_state = ((x_sim - x[:n_sim]) ** 2).mean(axis=0)
mse_x1, mse_x2 = mse_per_state[0], mse_per_state[1]
print(f'Błąd średniokwadratowy x1: {mse_x1:.2f}, x2: {mse_x2:.2f}')

In [None]:
# Observed trajectories (solid) against the SINDy simulation (dashed).
fig, ax = plt.subplots()
ax.plot(t, x[:, 0], label="Populacja ofiar")
ax.plot(t, x[:, 1], label="Populacja drapieżników")

# The simulation may have stopped early, so plot it against its own time span.
t_sim = t[:len(x_sim)]
ax.plot(t_sim, x_sim[:, 0], "r--", label="Symulacja modelu ofiar")
ax.plot(t_sim, x_sim[:, 1], "k--", label="Symulacja modelu drapieżników")
ax.legend()

# Scale the y-axis to the observed data of BOTH populations with 40% headroom;
# a diverging simulation is clipped instead of flattening the measured curves.
ax.set_ylim(0, np.max(x) * 1.4)
ax.set_xlim(0, np.max(t))
ax.set(
    xlabel="Czas [dni]",
    ylabel="Liczba osobników",
    title="Dane a symulacja modelu SINDy",
)
ax.grid()
plt.show()