In [None]:
import re
from decimal import Decimal

import matplotlib.pyplot as plt
import numpy as np
import pysindy as ps
from pysindy.feature_library import PolynomialLibrary, FourierLibrary, CustomLibrary

from generate_data import generate_population_data, generate_discrete_population_data

In [None]:
# Sampling period: forwarded to the discrete-time data generator and used to
# derive the print precision of the fitted model.
T = 1e-2

# Time grid: 1001 evenly spaced points on [0, 100].
# NOTE(review): this grid has a spacing of 0.1, not T. The T-consistent grid
# would be np.linspace(0, 100, int(100/T)+1) - confirm which one is intended.
t = np.linspace(0, 100, num=1001)

# Initial states for the training and the validation trajectory.
x0, x0_val = [50, 50], [100, 30]

In [None]:
# Discrete-time trajectories: training start state first, validation second.
# T is forwarded unchanged as the generator's `T` keyword argument.
xk, xk_val = (
    generate_discrete_population_data(t, start_state, T=T)
    for start_state in (x0, x0_val)
)

In [None]:
# Continuous-time data on the grid `t`. The generator returns a pair; the
# second element is kept as `x_dot` for the training run and discarded for
# the validation run.
training_run = generate_population_data(t, x0)
validation_run = generate_population_data(t, x0_val)

x, x_dot = training_run
x_val, _ = validation_run

In [None]:
# Fit one continuous-time SINDy model to the training trajectory and print it.

# Printed decimals: 3 plus one per decade that T lies below 1 (5 for T = 1e-2).
pr = int(3 - np.log10(T))
model = ps.SINDy(
    # Alternatives tried here: PolynomialLibrary(degree=1) and
    # FourierLibrary(n_frequencies=2) / FourierLibrary(n_frequencies=1).
    feature_library=PolynomialLibrary(degree=2),
    # Other thresholds tried: 0.001, 0.01, 0.1.
    optimizer=ps.STLSQ(threshold=10.),
    feature_names=[f'x{i+1}' for i in range(len(x[0]))],
    discrete_time=False,
)
# Pass the sample times explicitly. Without `t`, the derivative is estimated
# with the library's default timestep instead of the real spacing of the grid,
# which rescales every identified coefficient.
model.fit(x=x, t=t)
model.print(precision=pr)

In [None]:
# Build a "c1 f1 + c2 f2 + ..." string for the second identified equation,
# keeping only coefficients whose magnitude exceeds 1e-6.
second_equation = model.coefficients()[1]
term_names = model.get_feature_names()
coeffs = ' + '.join(
    '%.3E' % Decimal(str(value)) + ' ' + term_names[idx]
    for idx, value in enumerate(second_equation)
    if abs(value) > 0.000001
)

In [None]:
# Sweep feature libraries and STLSQ thresholds, then print one LaTeX table per
# state variable: library, threshold, identified equation and validation MSE.
t = np.linspace(0, 100, 1001)

# (row label, feature library) in the order the table groups are printed.
libraries = [
    ('Trygonometryczna (st. 2)', FourierLibrary(n_frequencies=2)),
    ('Trygonometryczna (st. 1)', FourierLibrary(n_frequencies=1)),
    ('Wielomiany (st. 2)', PolynomialLibrary(degree=2)),
    ('Liniowa', PolynomialLibrary(degree=1)),
]


def _sci_to_latex(text):
    r"""Rewrite '%.3E'-style exponents found in ``text`` as LaTeX powers of ten.

    A zero exponent is dropped, positive single-digit exponents become
    ``\cdot 10^d`` and every other exponent is braced (``\cdot 10^{-3}``,
    ``\cdot 10^{12}``), so exponents of any width are converted.
    """
    def _power(match):
        exponent = int(match.group(1))
        if exponent == 0:
            return ''
        if 0 < exponent < 10:
            return rf'\cdot 10^{exponent}'
        return r'\cdot 10^{' + str(exponent) + '}'

    return re.sub(r'E([+-]\d+)', _power, text)


# Phase 1: fit and validate every (library, threshold) pair exactly once.
# The result does not depend on which state variable is being tabulated.
sweep = []
for lib_idx, (label, library) in enumerate(libraries):
    rows = []
    for step in range(3):
        # Polynomial libraries use 0.001..0.1; Fourier libraries use 0.1..10.
        threshold = 10**(step-3) if lib_idx >= 2 else 100*10**(step-3)
        model = ps.SINDy(
            feature_library=library,
            optimizer=ps.STLSQ(threshold=threshold),
            feature_names=[f'x{k+1}' for k in range(len(x0))])
        model.fit(x=x, x_dot=x_dot, t=t)
        try:
            x_sim = model.simulate(x0=x0_val, t=t)
            mse = ((x_sim - x_val)**2).mean(axis=0)
            errors = ['%.3E' % Decimal(str(mse[n])) for n in range(len(x0))]
        except Exception:
            # A model that cannot be simulated or compared is reported as an
            # infinite error; KeyboardInterrupt is deliberately not caught.
            errors = [r'\infty'] * len(x0)
        rows.append((threshold, model.get_feature_names(), model.coefficients(), errors))
    sweep.append((label, rows))

# Phase 2: print one table body per state variable from the cached results.
for x_num in range(len(x0)):
    q = r'Biblioteka funkcji & Próg & $\Dot{x}'
    print(f'{q}_{x_num+1}$ & $E_{x_num+1}$ \\\\')
    for label, rows in sweep:
        print('\\hline')
        for threshold, feature_names, coefficients, errors in rows:
            # Keep only the terms that survived thresholding.
            terms = [
                '%.3E' % Decimal(str(coeff)) + ' ' + feature_names[j]
                for j, coeff in enumerate(coefficients[x_num])
                if abs(coeff) > threshold
            ]
            if len(terms) > 2:
                # Long equations are cut after two terms.
                eq = terms[0] + ' + ' + terms[1] + r'\dots'
            elif terms:
                eq = ' + '.join(terms)
            else:
                eq = '0,000'
            # Drop the constant feature name "1", switch to a decimal comma,
            # tidy signs, then turn feature names into LaTeX.
            eq = (
                eq.replace(' 1 +', ' +')
                .replace('.', ',')
                .replace('+ -', '- ')
                .replace('sin', r'\sin')
                .replace('cos', r'\cos')
                .replace('(1 x1)', '(x_1)')
                .replace('(1 x2)', '(x_2)')
                .replace('x2', 'x_2')
                .replace('x1', 'x_1')
                .replace('(1 u)', '(u)')
            )
            eq = _sci_to_latex(eq)
            E = _sci_to_latex(errors[x_num])
            threshold_cell = (str(threshold).replace('.', ',') + ' &').replace(',0 &', ' &')
            print(f"{label} & {threshold_cell} ${eq}$ & {E.replace('.', ',')} \\\\")
    print('\n\n')

In [None]:
# Simulate the most recently fitted model from the validation initial state
# and report the mean squared error of each state variable against x_val.
# simulate() needs an array of time points for a continuous-time model and a
# positive step count for a discrete-time one, so pick the argument to match
# the model instead of always passing an integer.
sim_t = len(t) if getattr(model, 'discrete_time', False) else t
x_sim = model.simulate(x0=x0_val, t=sim_t)
mse_x1 = ((x_sim[:, 0] - x_val[:, 0])**2).mean()
mse_x2 = ((x_sim[:, 1] - x_val[:, 1])**2).mean()
print(f'Błąd średniokwadratowy x1: {mse_x1:.2f}, x2: {mse_x2:.2f}')

In [None]:
# Overlay the validation trajectories (default line style) with the simulated
# model output (dashed red / dashed black).
curves = (
    (x_val[:, 0], (), "Populacja ofiar"),
    (x_val[:, 1], (), "Populacja drapieżników"),
    (x_sim[:, 0], ("r--",), "Model ofiar"),
    (x_sim[:, 1], ("k--",), "Model drapieżników"),
)
for values, fmt, label in curves:
    plt.plot(t, values, *fmt, label=label)
plt.legend()

# The upper y-limit is taken from the first column of the training data `x`
# with 40% headroom.
# NOTE(review): the plotted curves are x_val / x_sim - confirm that scaling by
# `x` is intended.
y_top = max(x[:, 0]*1.4)
plt.ylim(0, y_top)
plt.xlim(0, t.max())
plt.xlabel("Czas [dni]")
plt.ylabel("Liczba osobników")
plt.grid()
plt.show()

In [None]:
# Standalone figure of the generated trajectory started from [50, 50].
t = np.linspace(0, 100, num=1001)

x0 = [50, 50]
x, _ = generate_population_data(t, x0)

for column, label in enumerate(["Populacja ofiar", "Populacja drapieżników"]):
    plt.plot(t, x[:, column], label=label)
plt.legend()

# Leave 20% headroom above the largest value in the first column.
y_top = max(x[:, 0]*1.2)
plt.ylim(0, y_top)
plt.xlim(0, t.max())
plt.xlabel("Czas [dni]")
plt.ylabel("Liczba osobników")
plt.grid()
# Optional export of the figure:
# plt.savefig(f'imgs/proces-1 {x0[0]}-{x0[1]}.png', dpi=300, bbox_inches='tight')
plt.show()