In [1]:
from scipy.integrate import solve_ivp
from dotmap import DotMap
import jax.numpy as jnp
import jax
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default='plotly_dark'

In [25]:
def switch(f_orig, start=0., end=jnp.inf, sharp=100):
    """Wrap ``f_orig`` in a smooth, differentiable on/off window.

    The returned callable ``f(t, y=None)`` evaluates
    ``(sigmoid(sharp*(t-start)) - sigmoid(sharp*(t-end))) * f_orig(t - start, y)``.

    - For ``start < t < end`` the weight is about 1.
    - Outside that interval the weight is about 0.
    - At ``t == start`` the weight is about 0.5.
    - ``f_orig`` sees a clock restarted at ``start``.
    - ``sharp`` controls edge steepness.
    - With the default ``end=inf`` the window never closes.

    It is built on ``jax.nn.sigmoid``, so it stays differentiable under JAX
    transforms such as ``jax.jacfwd``.
    """
    def gated(t, y=None):
        opening = jax.nn.sigmoid(sharp * (t - start))
        closing = jax.nn.sigmoid(sharp * (t - end))
        return (opening - closing) * f_orig(t - start, y)

    return gated

def qin(t, h, base=0.1, step=0.1, t_on=100., t_off=200.):
    """Inflow into level 1 at time ``t``.

    The flow is a constant ``base`` plus an extra ``step`` that ``switch``
    turns on smoothly around ``t_on`` and off around ``t_off``. At each edge
    the extra flow is about half of ``step``.

    The defaults reproduce the original hard-coded scenario: 0.1 throughout,
    plus 0.1 between t=100 and t=200. Existing ``qin(t, h)`` calls are
    therefore unchanged.

    Args:
        t: Current time.
        h: Current state. It is forwarded through ``switch`` but ignored by
            the pulse function; it is accepted to keep the ODE-style
            signature.
        base: Constant inflow.
        step: Height of the added pulse.
        t_on: Time at which the pulse switches on.
        t_off: Time at which the pulse switches off.

    Returns:
        The inflow value (a JAX scalar, because ``switch`` uses ``jax.nn.sigmoid``).
    """
    # switch() shifts time by t_on before calling the inner function; the
    # pulse here is constant, so that shift does not matter.
    pulse = switch(lambda _t, _h: step, t_on, t_off)
    return base + pulse(t, h)

def dh(t, h, params=None):
    """Time derivative of the two levels ``h = [h1, h2]``.

    - Level 1 receives ``qin(t, h)``.
    - Level 1 exchanges flow with level 2 through
      ``q = Cv1 * sign(h1 - h2) * sqrt(|h1 - h2|)``. This is positive
      (1 -> 2) when ``h1 > h2`` and reverses otherwise.
    - Level 2 drains at ``Cv2 * sqrt(h2)``.
    - Each net flow is divided by ``A1`` / ``A2`` (presumably the
      cross-sectional areas; confirm against the model source).

    Args:
        t: Current time (forwarded to ``qin``).
        h: Length-2 array ``[h1, h2]``.
        params: Object exposing ``A1, A2, Cv1, Cv2``. When ``None`` it falls
            back to the notebook-global ``s``, looked up at call time.
            Existing ``dh(t, h)`` calls therefore behave as before, including
            those from ``solve_ivp`` and ``jax.jacfwd(dh, 1)``.

    Returns:
        ``jnp`` array ``[dh1/dt, dh2/dt]``.
    """
    p = s if params is None else params
    h1, h2 = h[0], h[1]
    diff = h1 - h2
    # Signed square-root law: the flow direction follows the level difference.
    # NOTE(review): the derivative of sqrt(|diff|) is unbounded at diff == 0,
    # so the autodiff Jacobian is ill-defined (inf/NaN) exactly when the
    # levels are equal.
    q = p.Cv1 * jnp.copysign(jnp.sqrt(jnp.abs(diff)), diff)
    # Clip at 0. Newton iterates of an implicit solver (e.g. Radau) may probe
    # h2 < 0, where a bare sqrt yields NaN and poisons the step. For
    # h2 >= 0 this is identical to the unclipped form.
    q_out = p.Cv2 * jnp.sqrt(jnp.maximum(h2, 0.))
    dh1 = (qin(t, h) - q) / p.A1
    dh2 = (q - q_out) / p.A2
    return jnp.asarray([dh1, dh2])

In [26]:
# Model parameters. They live on the global `s`, which `dh` reads by default.
s = DotMap()
s.A1 = 0.5      # divisor of the level-1 net flow in dh
s.A2 = 2.       # divisor of the level-2 net flow in dh
s.Cv1 = 0.1     # coefficient of the h1 <-> h2 flow
s.Cv2 = 0.2     # coefficient of the level-2 outflow
s.h_init = jnp.array([1., 1.5])
s.tend = 300.
# Upper bound on the solver step. Without it, Radau may grow its step across
# the smoothed inflow pulse in qin and under-resolve or skip it.
s.max_step = 5.
s.n_plot = 100  # number of points at which the dense solution is sampled

# Radau is implicit, so give it the exact Jacobian d(dh)/dh from JAX autodiff.
res = solve_ivp(dh, (0., s.tend), s.h_init, method='Radau', dense_output=True,
                jac=jax.jacfwd(dh, 1), max_step=s.max_step)
assert res.success, res.message

t = jnp.linspace(0, s.tend, s.n_plot)
# Evaluate the dense interpolant once; its rows are h1 and h2.
h1, h2 = res.sol(t)

fig = make_subplots()
fig.add_scatter(x=t, y=h1, mode='lines', name='h1')
fig.add_scatter(x=t, y=h2, mode='lines', name='h2')
fig.update_layout(title='Levels h1 and h2 vs time', xaxis_title='t', yaxis_title='level')
fig.show()