<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/Tank_PID_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone --depth 1 https://github.com/profteachkids/CHE2064.git &> /dev/null
!pip install DotMap &> /dev/null
import sys
sys.path.insert(1, "/content/CHE2064") #Path to CHE module imports

In [2]:
import jax
import jax.numpy as jnp
# JAX defaults to 32-bit floats; enable 64-bit (double precision) before any
# arrays are created. Go through `jax.config` rather than importing the
# `jax.config` submodule (`from jax.config import config`), which newer JAX
# releases deprecate. `config` is still bound for any later cell that uses it.
config = jax.config
config.update("jax_enable_x64", True)
try:
  from jax.experimental.host_callback import id_print
except ImportError:
  # host_callback is no longer shipped with recent JAX releases; provide a
  # minimal stand-in with the same basic contract (print `arg`, also under jit,
  # and return it). NOTE(review): `id_print` is not used in the visible cells.
  def id_print(arg, **kwargs):
    """Fallback for host_callback.id_print: print `arg` via jax.debug.print and return it."""
    jax.debug.print("{}", arg)
    return arg

from scipy.integrate import solve_ivp

from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default='plotly_dark'

from dotmap import DotMap
from tools.dynamics import onoff_val, ramp, VX


In [115]:
def hsp(t):
  """Tank level setpoint as a function of time `t`.

  Returns 1 plus `onoff_val(1., start=20., end=40., sharp=10)` evaluated at `t`.
  That is presumably a smoothed unit pulse that switches on near t=20 and off
  near t=40; confirm against tools.dynamics.onoff_val. The PID model in `dy`
  uses the first and second time derivatives of this function, so it must be
  differentiable by JAX.
  """
  pulse = onoff_val(1., start=20., end=40., sharp=10)
  return 1. + pulse(t)

# Setpoint rate d(hsp)/dt and acceleration d2(hsp)/dt2 via automatic
# differentiation. jax.grad needs a scalar output, so call these with scalar t.
dhsp = jax.grad(hsp)
d2hsp = jax.grad(jax.grad(hsp))

@jax.jit
def dy(t, y):
  """Right-hand side of the closed-loop tank + PID ODE system, for solve_ivp.

  State y = [h, qin]: tank level and inflow (the manipulated variable).

    Tank balance:         A*dh/dt = qin - Cv*sqrt(h)
    PID, velocity form:   dqin/dt = Kc*(de/dt + e/taui + taud*d2e/dt2),
                          e = hsp - h

  d2e/dt2 = d2hsp - d2h/dt2, and d2h/dt2 = (dqin/dt - Cv/(2*sqrt(h))*dh/dt)/A
  itself contains dqin/dt. Moving that part to the left-hand side gives the
  (1 + Kc*taud/A) divisor below. With taud = 0 this reduces to PI control.

  Reads the module-level parameters Kc, taui, taud, A, Cv and the setpoint
  functions hsp, dhsp, d2hsp. Under @jax.jit, global Python values are captured
  when the function is traced, so re-define `dy` after changing them.

  Args:
    t: time (scalar).
    y: state [h, qin].
  Returns:
    [dh/dt, dqin/dt].

  NOTE(review): sqrt(h) requires h > 0. A large Kc*taud combined with a sharp
  setpoint edge gives a strong derivative kick; check that h stays positive.
  """
  h = y[0]
  qin = y[1]
  dh = (qin - Cv*jnp.sqrt(h))/A

  e = hsp(t) - h          # control error
  de = dhsp(t) - dh       # de/dt
  # Part of d2e/dt2 that does not involve dqin/dt (that part is folded into the divisor).
  d2e_explicit = d2hsp(t) + Cv/(2*A*jnp.sqrt(h))*dh

  dqin = (Kc*de + Kc/taui*e + Kc*taud*d2e_explicit)/(1 + Kc*taud/A)
  return [dh, dqin]

# PID controller tuning
Kc = 1.      # proportional gain
taui = 0.2   # integral time
taud = 3     # derivative time
# Tank: A*dh/dt = qin - Cv*sqrt(h)
A = 1.       # tank cross-sectional area
Cv = 2.      # outlet coefficient: outflow = Cv*sqrt(h)

h0 = 1.
# Steady state (dh/dt = 0) requires qin = Cv*sqrt(h); no division by A.
qin0 = Cv*jnp.sqrt(h0)
tend = 60

# Implicit Radau solver with the exact state Jacobian from forward-mode autodiff.
res = solve_ivp(dy, (0., tend), [h0, qin0], method='Radau', dense_output=True,
                jac=jax.jacfwd(dy, 1))
assert res.success, res.message  # don't plot a silently failed integration

t_plot = jnp.linspace(0, tend, 1000)
h, qin = res.sol(t_plot)
fig = make_subplots(rows=1, cols=2)
fig.add_scatter(x=t_plot, y=h, row=1, col=1, name='height')
fig.add_scatter(x=t_plot, y=hsp(t_plot), row=1, col=1, name='hsp')
fig.add_scatter(x=t_plot, y=qin, row=1, col=2, name='qin')
fig.update_xaxes(title_text='t')
fig.update_yaxes(title_text='h', row=1, col=1)
fig.update_yaxes(title_text='qin', row=1, col=2)
fig.update_layout(width=1000, height=500)

In [5]:
# Display the solve_ivp result: status/message, nfev/njev/nlu, time grid t and states y.
# NOTE(review): the saved output below is stale (cell executed before the simulation
# cell's latest run). It ends at t=5 and starts at qin=1, whereas the simulation cell
# integrates to tend=60 starting from qin0=2 (Cv=2, h0=1). Restart & Run All to refresh.
res

  message: 'The solver successfully reached the end of the integration interval.'
     nfev: 44
     njev: 1
      nlu: 10
      sol: <scipy.integrate._ivp.common.OdeSolution object at 0x7f4254c769b0>
   status: 0
  success: True
        t: array([0.        , 0.0109078 , 0.0218156 , 0.04861618, 0.1654074 ,
       1.1129144 , 5.        ])
 t_events: None
        y: array([[1.        , 1.00000996, 1.00002081, 1.00004759, 1.00016431,
        1.00111074, 1.00498405],
       [1.        , 1.00094949, 1.00100686, 1.00102307, 1.00108149,
        1.00155361, 1.00348345]])
 y_events: None