In [None]:
import numpy as np
import pysindy as ps
from pysindy.feature_library import PolynomialLibrary, FourierLibrary
import matplotlib.pyplot as plt

from generate_data import generate_hiv_data, generate_discrete_hiv_data

In [None]:
# Time grid: 0..100 days (the plot's time axis is labelled in days) with 1001 samples.
t = np.linspace(0, 100, 1001)
# Sampling period of the discrete-time data, derived from the grid (0.1 here) so the
# discrete trajectories stay aligned with `t`, which is later used as their time axis.
T = t[1] - t[0]

# Initial states of the 5 state variables for the training and validation runs.
x0 = [1, 1, 1, 1, 1]
x0_val = [2, 1, 2, 2, 3]

# Control inputs; validation uses a different signal than training.
u = lambda t: 0.1*np.sin(t)
u_val = lambda t: 0.1*np.cos(2*t)

# Continuous-time trajectories (x_dot is only needed by a continuous-time fit).
x, x_dot = generate_hiv_data(t=t, x0=x0, u=u)
x_val, _ = generate_hiv_data(t=t, x0=x0_val, u=u_val)

# Discrete-time trajectories; the control is passed pre-sampled on the grid.
xk = generate_discrete_hiv_data(t=t, x0=x0, T=T, u=u(t))
xk_val = generate_discrete_hiv_data(t=t, x0=x0_val, T=T, u=u_val(t))

# The simulation and plotting cells compare/plot this row-by-row against `t`.
assert len(xk_val) == len(t), "discrete validation data must have one row per grid point"

In [None]:
# Discrete-time SINDy: the model learns a one-step map x[k] -> x[k+1] driven by u[k].
# State names are taken from the array actually being fitted (`xk`), plus the control "u".
n_states = xk.shape[1]
model = ps.SINDy(
    feature_library=PolynomialLibrary(degree=3),  # other degrees noted: 2, 1
    # feature_library=FourierLibrary(n_frequencies=1),  # other n_frequencies noted: 2, 1
    optimizer=ps.STLSQ(threshold=1e-4),  # other thresholds noted: 0.001, 0.01, 0.1
    feature_names=[f"x{i + 1}" for i in range(n_states)] + ["u"],
    discrete_time=True,
)
# The control is sampled on the same grid `t` that was used to generate `xk`.
model.fit(x=xk, u=u(t))
# Continuous-time alternative (requires discrete_time=False), using the generated derivatives:
# model.fit(x=x, x_dot=x_dot, t=t, u=u(t))
model.print()

In [None]:
# Free-run the fitted discrete-time model from the validation initial state.
# For a discrete-time SINDy model `t` is the number of steps: one per grid point.
x_sim = model.simulate(x0=x0_val, t=len(t), u=u_val(t))
# Mean over time (axis 0) -> one squared-error value per state column.
mse = np.mean(np.square(x_sim - xk_val), axis=0)
mse

In [None]:
# Validation figure: one panel per state (reference discrete trajectory vs. the
# free-running SINDy simulation), plus a bottom panel with the validation control.
labels_true = [
    "Zdrowe CD4+", "Zainfekowane CD4+", "Prekursory LTC",
    "Pomocniczo-niezależne LTC", "Pomocniczo-zależne LTC"]

labels_model = [
    "Model zdrowych CD4+", "Model zainfekowanych CD4+", "Model prekursorów LTC",
    "Model pomocniczo-niezależnych LTC", "Model pomocniczo-zależnych LTC"]

# One panel per labelled state plus one for the control; 3 inches of height per panel.
n_panels = len(labels_true) + 1
fig, axs = plt.subplots(n_panels, 1, figsize=(10, 3 * n_panels), sharex=True)

for i, (ax, label_true, label_model) in enumerate(zip(axs, labels_true, labels_model)):
    # Two consecutive default-cycle colours per state: reference solid, model dashed.
    ax.plot(t, xk_val[:, i], label=label_true, color=f"C{2 * i}")
    ax.plot(t, x_sim[:, i], "--", label=label_model, color=f"C{2 * i + 1}")
    ax.set_ylabel("Stężenie", fontsize=16)
    ax.legend(fontsize=14)
    ax.grid()
    ax.tick_params(axis="both", labelsize=14)

ax_u = axs[-1]
ax_u.plot(t, u_val(t), "g--", alpha=0.7)
ax_u.set_ylabel("Sterowanie", fontsize=16)
ax_u.set_xlabel("Czas [dni]", fontsize=16)
ax_u.grid()
ax_u.tick_params(axis="both", labelsize=14)
# sharex=True: setting the limits on one axes applies them to every panel.
ax_u.set_xlim(0, max(t))

fig.tight_layout()
plt.show()