In [1]:
import os

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import networkx as nx
import scipy.optimize as optimize

from cdcm import *
from cdcm_utils.derivatives import set_derivative
from figures import *

In [2]:
# Module-level time step.
# NOTE(review): the next cell rebinds `dt = 1.0` before any use, so this
# assignment looks redundant.
dt = 1.
class SomeSystem(System):
    """Toy linear model ``y_cap = x * theta[0] + theta[1]`` with a mean-squared-error loss.

    Nodes created in ``define_internal_nodes``:
      * ``theta`` (Parameter): the two model coefficients, initially [1., 1.].
      * ``alpha`` (Parameter): 0.001; not used by any function in this class.
      * ``y_cap`` (Variable): model prediction ``x * theta[0] + theta[1]``.
      * ``y`` (Parameter): observed target value(s).
      * ``x`` (State): model input; its function ``calc_x`` returns ``x + 0.1``.
      * ``loss`` (Variable): ``mean((y_cap - y) ** 2)``.
    """

    def __init__(self, name, **kwargs):
        """Forward ``name`` and any extra keyword arguments (e.g. ``dt``) to ``System``."""
        super().__init__(name=name, **kwargs)

    def define_internal_nodes(self, dt, **kwargs):
        """Create the parameters, state, variables and functions of the model.

        ``dt`` is accepted but not used in this body.
        NOTE(review): presumably the CDCM ``System`` base class calls this
        during construction and registers every node created here on the
        system (they are later reached as ``some_sys.theta`` etc.) — confirm.
        """

        # Coefficients [slope, intercept] that calibration adjusts.
        theta = Parameter(
            value=np.array([1., 1.]),
            name='theta',
            units=None,
        )
        # NOTE(review): not referenced by any function below.
        alpha = Parameter(
            value=0.001,
            name='alpha',
            units=None,
        )
        # Prediction, computed by calc_y_cap.
        y_cap = Variable(
            value=0.,
            name='y_cap',
            units=None
        )
        # Observed target; a scalar here, replaced by arrays from an event later.
        y = Parameter(
            value=0.,
            name='y',
            units=None
        )
        # Model input, held as a State.
        x = State(
            value=0.,
            name='x',
            units=None
        )
        # Loss, computed by calc_loss.
        loss = Variable(
            value=10.,
            name='loss',
            units=None,
        )
        
        # make_function(node) wraps the function as a CDCM function node whose
        # output is `node`; the default arguments appear to declare the input
        # nodes. Function names matter: they are accessed as attributes
        # (e.g. ss.calc_loss) and appear in the derivative update sequence.
        # Presumably gives the next value of the state x on each transition
        # (x + 0.1) — confirm against CDCM State semantics.
        @make_function(x)
        def calc_x(x=x):
            return x+0.1
            

        # Linear prediction: slope t[0], intercept t[1].
        @make_function(y_cap)
        def calc_y_cap(x=x,t=theta):
            return x*t[0]+t[1]

        # Mean squared error. jax.numpy is used, presumably so set_derivative
        # can differentiate through this function — confirm.
        @make_function(loss)
        def calc_loss(y=y,yc=y_cap):
            res = jnp.square(yc - y).mean()
            return res

In [3]:
# Simulation time step (the clock below counts in seconds).
dt = 1.0

# Top-level system `diff_sys`: a clock plus the model subsystem `some_sys`,
# which receives the clock's dt node.
with System(name='diff_sys') as diff_sys:
    clock = make_clock(units='seconds', dt=dt)
    some_sys = SomeSystem(dt=clock.dt, name='some_sys')

In [4]:
# Short aliases: `ds` is the top-level system, `ss` the model subsystem.
ds = diff_sys
ss = ds.some_sys

# Register d(loss)/d(theta) on `ds` under the name "dldt" (read later as
# ds.dldt.value). With update_seq=True the call also returns the ordered list
# of function nodes that must be re-run to refresh that derivative.
update_seq = set_derivative(
    ds, ss.loss, ss.theta, "dldt", update_seq=True,
)

def update_loss_grad(update_seq):
    """Re-evaluate the loss-gradient graph.

    Calls ``forward()`` on each node of ``update_seq`` in the given order, so
    every node sees freshly computed inputs. Returns None.
    """
    for node in update_seq:
        node.forward()

In [5]:
# Show the order in which the derivative's function nodes are evaluated.
print([step.name for step in update_seq])

['calc_x', 'calc_y_cap', 'calc_pdy_capdtheta', 'calc_pdlossdy_cap', 'calc_dldt']


In [6]:
def calibrate_theta(simulator, dt, i=0):
    """Fit ``ss.theta`` with BFGS so that ``ss.loss`` is minimised at step ``i``.

    Each trial theta is applied through a simulator event scheduled at
    ``i * dt``, the same mechanism the driver loop uses to inject data. The
    model is then run forward, and the loss gradient is refreshed by running
    the ``update_seq`` nodes and read from ``ds.dldt``.

    After the optimiser finishes, the optimum ``sol.x`` is applied the same
    way. This leaves the model in the calibrated state rather than at whatever
    trial point BFGS happened to evaluate last.

    Parameters
    ----------
    simulator : Simulator
        Simulator driving the system; must provide ``add_event`` and ``forward``.
    dt : float
        Simulation time step.
    i : int, optional
        Index of the step at which calibration happens (events go at ``i * dt``).

    Returns
    -------
    scipy.optimize.OptimizeResult
        The BFGS result; ``sol.x`` is the calibrated theta now set on ``ss``.
    """
    event_time = i * dt

    def set_theta(t):
        # Copy so the model never aliases an array owned by the optimiser.
        theta = np.array(t, dtype=float)

        def event():
            ss.theta.value = theta

        return event

    def objective_fn(t):
        simulator.add_event(event_time, set_theta(t))
        simulator.forward()
        print('loss,theta:', ss.loss.value, ss.theta.value)
        update_loss_grad(update_seq)
        # scipy works in float64; convert the (possibly JAX float32) gradient.
        dldt = np.array(ds.dldt.value, dtype=float)
        ss.calc_loss.forward()
        return float(ss.loss.value), dldt

    print(f'calibration starts at t = {event_time}')
    print('loss,theta:', ss.loss.value, ss.theta.value)
    sol = optimize.minimize(
        objective_fn,
        ss.theta.value,
        jac=True,
        method='BFGS',
        options={'disp': True},
    )
    # The last point BFGS evaluated is not guaranteed to be sol.x, so apply
    # the optimum explicitly and refresh the model with it.
    simulator.add_event(event_time, set_theta(sol.x))
    simulator.forward()
    print(f"calibration over with new theta {sol.x}")
    return sol

In [7]:
def update_x_y_theta():
    """Simulator event: load a new observation and reset theta.

    Sets the input ``x`` to [10, 20] and the target ``y`` to [6, 7] on ``ss``.
    Resets ``theta`` to [1, 6], which ``calibrate_theta`` then uses as its
    initial guess. Fresh arrays are built on every call, so repeated events
    never share an array.
    """
    new_x = np.array([10., 20.])
    new_y = np.array([6., 7.])
    new_theta = np.array([1., 6.])
    ss.x.value, ss.y.value, ss.theta.value = new_x, new_y, new_theta

In [8]:
simulator = Simulator(ds)

N_STEPS = 5           # number of simulation steps to run
CALIBRATION_STEP = 3  # step at which new data arrive and theta is re-fitted

for i in range(N_STEPS):
    simulator.forward()
    simulator.transition()
    if i != CALIBRATION_STEP:
        continue
    # Inject the new observation (x, y) and reset theta, evaluate the model,
    # then calibrate theta against the new data.
    simulator.add_event(i*dt, update_x_y_theta)
    simulator.forward()
    calibrate_theta(simulator, dt, i=i)

calibration starts at t = 3.0
loss,theta: 230.5 [1. 6.]
loss,theta: 230.5 [1. 6.]
loss,theta: 0.75942075 [-0.00816169  5.93909023]
loss,theta: 0.05691846 [0.05471526 5.75444268]
loss,theta: 0.03592367 [0.06498169 5.59784613]
loss,theta: 0.0060410444 [0.08589765 5.24423292]
loss,theta: 1.9456394e-07 [0.10003145 4.99994031]
loss,theta: 1.2718215e-09 [0.10000294 4.99998839]
loss,theta: 1.0345581e-14 [0.09999998 5.0000003 ]
Optimization terminated successfully.
         Current function value: 0.000000
         Iterations: 7
         Function evaluations: 8
         Gradient evaluations: 8
calibration over with new theta [0.09999998 5.0000003 ]
