In [1]:
import os

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import networkx as nx
from scipy.optimize import minimize

from cdcm import *
from cdcm_utils.derivatives import set_derivative
from figures import *

In [2]:
dt = 1.
class SomeSystem(System):

    def __init__(self, name, **kwargs):
        super().__init__(name=name, **kwargs)

    def define_internal_nodes(self, dt, **kwargs):
                       
        theta = Parameter(
            value=1.,
            name='theta',
            units=None,
        )
        alpha = Parameter(
            value=0.001,
            name='alpha',
            units=None,
        )
        y_cap = Variable(
            value=0.,
            name='y_cap',
            units=None
        )
        y = Parameter(
            value=0.,
            name='y',
            units=None
        )
        x = State(
            value=0.,
            name='x',
            units=None
        )
        loss = Variable(
            value=10.,
            name='loss',
            units=None,
        )

        @make_function(x)
        def calc_x(x):
            return x+0.1
        
        @make_function(y_cap)
        def calc_y_cap(x=x,t=theta):
            return x*t+5.

        @make_function(loss)
        def calc_loss(y=y,yc=y_cap):
            res = jnp.square(yc - y).mean()
            return res

In [3]:
dt = 1.0
with System(name='diff_sys') as diff_sys:
    clock = make_clock(dt=dt, units='seconds')
    some_sys = SomeSystem(name='some_sys', dt=clock.dt)

In [4]:
ds,ss = diff_sys,diff_sys.some_sys

update_seq = set_derivative(
    ds,
    ss.loss,
    ss.theta,
    "dldt",
    update_seq=True
)

def update_loss_grad(update_seq):
    for n in update_seq:
        n.forward()

In [5]:
def calibrate_theta(tol,simulator,dt,max_iter,i=0.):
    
    def event():
        print('loss,theta:',ss.loss.value,ss.theta.value)
        t = ss.theta.value
        a = ss.alpha.value
        dldt = ds.dldt.value
        new_t = t - a*dldt
        ss.theta.value = new_t
        update_loss_grad(update_seq)
        ss.calc_loss.forward()
        
    print(f'calibration starts at t = {i*dt}')
    iter=1
    while ss.loss.value>tol and iter<max_iter:
        simulator.add_event(i*dt,event)
        simulator.forward()
        iter += 1
    print(f"calibration over with {iter} iterations")
    print('loss,theta:',ss.loss.value,ss.theta.value)
    

In [6]:
def update_x_y():
    ss.x.value = np.array([10.,20.])
    ss.y.value = np.array([6.,7.])

In [7]:
simulator = Simulator(ds)
tol=1e-5
for i in range(4):
    simulator.forward()
    simulator.transition()
    if i == 3:
        simulator.add_event(i*dt, update_x_y)
        simulator.forward()
        calibrate_theta(tol,simulator,dt,100,i=i)

calibration starts at t=3.0
loss,theta: 202.5 1.0
loss,theta: 50.624992 0.54999995
loss,theta: 12.656246 0.32499996
loss,theta: 3.1640625 0.21249998
loss,theta: 0.7910156 0.15624997
loss,theta: 0.19775324 0.12812497
loss,theta: 0.049438477 0.11406249
loss,theta: 0.012359619 0.10703124
loss,theta: 0.0030899048 0.10351562
loss,theta: 0.0007724762 0.1017578
loss,theta: 0.00019311905 0.100878894
loss,theta: 4.8279762e-05 0.100439444
loss,theta: 1.2069941e-05 0.10021972
calibration over with 14 iterations
loss,theta: 3.0174851e-06 0.10010985
