In [None]:
import matplotlib.pyplot as plt
import torch
from plotly.subplots import make_subplots

import cvs_fit
from biophysical_models import models
from time_series_prediction import settings

In [None]:
import plotly.io as pio
# Use the 'plotly_white' template for every plotly figure created below.
PLOTLY_TEMPLATE = 'plotly_white'
pio.templates.default = PLOTLY_TEMPLATE

In [None]:
# Optional: uncomment to switch the project's device setting to 'cuda' before the
# model and data are loaded below (presumably moves computation to the GPU —
# confirm in time_series_prediction.settings).
# settings.switch_device('cuda')

In [None]:
# Model class: SmithCardioVascularSystem wrapped by models.add_bp_metrics.
cvs_cls = models.add_bp_metrics(models.SmithCardioVascularSystem)
# cvs_fit.load_data returns three items: y (presumably the observed time
# series — it is what gets plotted and trained on below), the observation
# matrix, and the model instance to be fitted.
loaded = cvs_fit.load_data(cvs_cls)
y, obs_matrix, cvs = loaded

In [None]:
# Freeze every parameter of the model first ...
for p in cvs.parameters():
    p.requires_grad_(False)

# ... then re-enable gradients only for the attributes being fitted,
# addressed by their dotted path on `cvs` (same order as before).
for attr_path in (
    'v_tot',
    'mt.r', 'tc.r', 'av.r', 'pv.r',
    'pul.r', 'sys.r',
    'vc.e_es', 'pa.e_es', 'pu.e_es', 'ao.e_es', 'lvf.e_es', 'rvf.e_es',
    'lvf.p_0', 'rvf.p_0',
):
    target = cvs
    for part in attr_path.split('.'):
        target = getattr(target, part)
    target.requires_grad_(True)

In [None]:
# Filter built by cvs_fit.get_enkf (presumably an ensemble Kalman filter) over
# the model `cvs` and the observation matrix.
enkf_options = dict(
    n_particles=100,
    obs_noise=5,
    init_proc_noise=1e-5,
    max_step=1e-2,
    # atol=1e-9,  # left disabled, as in the original call
)
kf = cvs_fit.get_enkf(obs_matrix, cvs, **enkf_options)

In [None]:
# Plot the model's heart-rate function cvs.e.f_hr over t in [0, 500].
t = torch.linspace(0, 500, 500)
with torch.no_grad():
    # f_hr may depend on parameters that now require grad; evaluating under
    # no_grad keeps the results free of autograd history, so matplotlib can
    # convert them to arrays.
    hr = [cvs.e.f_hr(x) for x in t]

fig_hr, ax_hr = plt.subplots()
ax_hr.plot(t, hr)
ax_hr.set(xlabel='t', ylabel='f_hr(t)', title='Heart-rate function cvs.e.f_hr')
plt.show()

In [None]:
# Plot the slice of y that is used for training (rows 1..499).
obs_plot = y[1:500, :]
fig_obs, ax_obs = plt.subplots()
# Label the x-axis with the actual row indices of y, since the slice starts at 1.
ax_obs.plot(range(1, 1 + len(obs_plot)), obs_plot)
ax_obs.set(xlabel='Row index in y', ylabel='Observed value', title='Observations y[1:500, :]')
plt.show()

In [None]:
# Fit the filter/model on rows 1..499 of y.
# NOTE(review): the positional 20 is presumably the number of epochs — confirm
# against cvs_fit / kf.train.
train_obs = y[1:500, :]
kf.train(
    train_obs,
    20,
    dt=1,
    subseq_len=20,
    print_timing=True,
    lr_alpha=1e-1,
    lr_beta=1e-3,
    save_checkpoints='step',
    auto_scale_lr=True,
)

In [None]:
# Show the figure the filter keeps on its private `_fig` attribute.
# NOTE(review): relies on a private attribute of the filter object.
training_fig = kf._fig
training_fig.update_layout(template='plotly_white')
training_fig.show()

In [None]:
# Simulate the model without autograd; the save_traj and verbose flags are
# switched on first.
with torch.no_grad():
    for flag in ('save_traj', 'verbose'):
        setattr(cvs, flag, True)
    t_sol, sol = cvs.simulate(500, 50)

In [None]:
# List every state-dict entry whose fitted value differs from a freshly
# constructed SmithCardioVascularSystem.
# NOTE(review): the reference is the constructor default, which may differ from
# the initial values produced by cvs_fit.load_data — confirm that is intended.
d1 = models.SmithCardioVascularSystem().state_dict()
d2 = kf.transition_function.state_dict()
for key, ref in d1.items():
    if key not in d2:
        print(f'{key:25s} not present in the fitted model')
        continue
    # Detach, move to CPU and upcast so that a model fitted on another device
    # (or with a different dtype) can still be compared. Previously a device
    # mismatch raised a RuntimeError that silently skipped every entry.
    ref = ref.detach().cpu().double()
    fit = d2[key].detach().cpu().double()
    if ref.shape != fit.shape:
        print(f'{key:25s} shape changed: {tuple(ref.shape)} -> {tuple(fit.shape)}')
        continue
    if torch.equal(ref, fit):
        continue
    diff = fit - ref
    if ref.numel() == 1:
        print(f'{key:25s} {ref.item():12.3f} {fit.item():12.3f} {diff.item():12.3f}')
    else:
        # Non-scalar entries cannot be shown as a single number; report the
        # largest absolute change instead.
        print(f'{key:25s} {"(tensor)":>12s} {"(tensor)":>12s} {diff.abs().max().item():12.3f}')

In [None]:
# Persist the fitted model parameters.
# NOTE(review): the "20" in the filename presumably matches the 20 passed to
# kf.train — confirm, and update both together.
params_path = 'params20.to'
torch.save(cvs.state_dict(), params_path)

In [None]:
# Trajectories of the trainable parameters across the saved checkpoints
# (training used save_checkpoints='step'). Each group becomes one subplot row,
# in the order listed here.
PARAM_GROUPS = [
    ('v_tot', ['v_tot']),
    ('Valve resistances', ['mt.r', 'tc.r', 'av.r', 'pv.r']),
    ('Circulation resistances', ['pul.r', 'sys.r']),
    ('e_es', ['vc.e_es', 'pa.e_es', 'pu.e_es', 'ao.e_es', 'lvf.e_es', 'rvf.e_es']),
    ('p_0', ['lvf.p_0', 'rvf.p_0']),
]
# Checkpoints per epoch; only used to place the dotted epoch separators.
# NOTE(review): hard-coded (as before) — presumably the number of subsequences
# kf.train makes per epoch; revisit if the training slice or subseq_len change.
STEPS_PER_EPOCH = 25


def _param_trajectory(name):
    """Return state-dict entry `name` at every checkpoint, as Python floats.

    float() works for scalar tensors on any device and with or without
    requires_grad, so plotly never has to convert torch tensors itself.
    """
    return [float(ckpt['model_state_dict'][name]) for ckpt in kf.checkpoints]


fig = make_subplots(len(PARAM_GROUPS), 1, shared_xaxes=True)
for row, (title, names) in enumerate(PARAM_GROUPS, start=1):
    for name in names:
        fig.add_scatter(y=_param_trajectory(name), name=name, row=row, col=1)
    fig.update_yaxes(title_text=title, row=row)

# Mark the last checkpoint of each epoch.
for i in range(STEPS_PER_EPOCH - 1, len(kf.checkpoints), STEPS_PER_EPOCH):
    fig.add_vline(i, line_color='black', line_dash='dot')
fig.update_xaxes(
    row=len(PARAM_GROUPS),
    title_text='Training iterations (epochs marked with dotted lines)',
)
fig.update_layout(height=800)