<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/Demo_ODE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone --depth 1 https://github.com/profteachkids/CHE2064.git
!pip install DotMap
import sys
sys.path.insert(1, "/content/CHE2064")

Cloning into 'CHE2064'...
remote: Enumerating objects: 79, done.[K
remote: Counting objects: 100% (79/79), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 79 (delta 47), reused 54 (delta 43), pack-reused 0[K
Unpacking objects: 100% (79/79), done.
Collecting DotMap
  Downloading https://files.pythonhosted.org/packages/52/47/9ca39d01b872c1bf2da0f0031cb3c4e3a016170c181e34d889253e404d59/dotmap-1.3.17-py3-none-any.whl
Installing collected packages: DotMap
Successfully installed DotMap-1.3.17


In [2]:
import jax
import jax.numpy as jnp
# Enable float64 before any arrays are created. `jax.config.update` is the supported
# API; the old `from jax.config import config` import is deprecated and no longer
# works on current JAX releases.
jax.config.update("jax_enable_x64", True)
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots
import plotly.io as pio
from dotmap import DotMap
import tools.tree_array_transform as tat
pio.templates.default='plotly_dark'


In [49]:

def switch(f_orig, start=0., end=jnp.inf, sharp=100):
    """Smoothly gate ``f_orig`` so it is active only between ``start`` and ``end``.

    The returned function evaluates
    ``(sigmoid(sharp*(t-start)) - sigmoid(sharp*(t-end))) * f_orig(t-start, ...)``,
    i.e. a differentiable window (edges of steepness ``sharp``) multiplied by
    ``f_orig`` expressed in time measured from ``start``.

    Args:
        f_orig: callable ``f_orig(tau, *args, **kwargs)``; ``tau`` is ``t - start``.
        start: time at which the window opens.
        end: time at which the window closes (default: never).
        sharp: steepness of the sigmoid edges.

    Returns:
        Callable ``g(t, *args, **kwargs)``; extra arguments are forwarded to ``f_orig``.
    """
    def gated(t, *args, **kwargs):
        rise = jax.nn.sigmoid(sharp * (t - start))
        fall = jax.nn.sigmoid(sharp * (t - end))
        window = rise - fall
        return window * f_orig(t - start, *args, **kwargs)

    return gated

def onhold(f_orig, start, hold, sharp=100):
    """Run ``f_orig`` from ``start`` for ``hold`` time units, then freeze its net change.

    The result is ``switch(f_orig, start)(t) - switch(f_orig, start + hold)(t)``:
    the same open-ended profile, minus a copy of it restarted ``hold`` later.
    For ``t > start + hold`` this is ``f_orig(t-start) - f_orig(t-start-hold)``, which
    is the constant ``f_orig(hold) - f_orig(0)`` when ``f_orig`` is linear (a ramp
    that rises for ``hold`` and then holds its value). For a constant ``f_orig`` it
    is a pulse of width ``hold``. Edges are smoothed with steepness ``sharp``.

    Returns:
        Callable ``g(t, *args, **kwargs)``; extra arguments are forwarded to ``f_orig``.
    """
    leading = switch(f_orig, start=start, sharp=sharp)
    trailing = switch(f_orig, start=start + hold, sharp=sharp)

    def f(t, *args, **kwargs):
        return leading(t, *args, **kwargs) - trailing(t, *args, **kwargs)

    return f

def q1in(t):
    """Feed-flow schedule for the three components entering tank 1.

    The model multiplies this by ``s.rho``, so each entry is presumably a
    per-component volumetric flow (units not stated here).
      - component 0: 0.1, ramps up by 0.05 over t=100..150, holds, then ramps
        back down over t=200..250;
      - component 1: 0.1, stepped to 0.15 between t=10 and t=50;
      - component 2: 0.1, stepped to 0.15 between t=300 and t=325.
    All transitions are smoothed by ``switch``/``onhold``.

    Returns:
        jnp array of shape (3,).
    """
    def ramp(tau):
        return 0.05*tau/50

    def step(tau):
        return 0.05

    flow0 = 0.1 + onhold(ramp, 100, 50)(t) - onhold(ramp, 200, 50)(t)
    flow1 = 0.1 + switch(step, 10, 50)(t)
    flow2 = 0.1 + switch(step, 300, 325)(t)
    return jnp.array([flow0, flow1, flow2])

def model(t, v, s, dv):
    """Component mass balances for a single gravity-drained mixing tank.

    Args:
        t: time.
        v: state container; ``v.m1`` holds the mass of each component in the tank.
        s: parameters; ``s.rho`` (per-component densities), ``s.Cv1`` (outlet
           coefficient), ``s.A1`` (presumably the tank cross-section, since
           ``V1/s.A1`` is used as the level).
        dv: container that receives the derivatives (presumably supplied by the
           ``tat`` transform); ``dv.m1`` is overwritten.

    Returns:
        ``dv`` with ``dv.m1 = s.rho*q1in(t) - rho1*q1out*w1``.
    """
    masses = v.m1
    total_mass = jnp.sum(masses)
    V1 = jnp.sum(masses / s.rho)            # volumes of the components simply add
    rho1 = total_mass / V1                  # mixture density
    q1out = s.Cv1 * jnp.sqrt(V1 / s.A1)     # outflow driven by sqrt of the level
    w1 = masses / total_mass                # mass fractions leave with the outflow

    dv.m1 = s.rho * q1in(t) - rho1 * q1out * w1
    return dv


In [54]:

# Parameters (s) and initial state (v) for the tank model, then integrate.
s=DotMap()
s.A1 = 2.0                                # presumably tank cross-section (model uses V1/A1 as the level)
s.Cv1 = 0.5                               # outlet coefficient: q1out = Cv1*sqrt(V1/A1)
s.rho = jnp.array([1000., 900., 800.])    # per-component densities

v=DotMap()
v.m1 = jnp.array([150., 150., 200.])      # initial mass of each component

m = tat.VX(v, s)
tend=400.
model_f = jax.jit(m.transform(model))
# Radau evaluates the Jacobian repeatedly; jitting jacfwd makes each evaluation a
# single compiled call instead of re-running the forward-mode transformation in Python.
model_jac = jax.jit(jax.jacfwd(model_f, 1))

res = solve_ivp(model_f, (0.,tend), m.x, method='Radau', dense_output=True, jac=model_jac)
# Stop here rather than plotting a partial/garbage trajectory if the solver failed.
assert res.success, res.message


In [55]:

# Sample the dense ODE solution on a uniform grid and plot: total mass, feed flows,
# component masses and mass fractions. The solution gets its own name (m_t) so the
# tat.VX object `m` from the solve cell is not silently replaced by an array.
t=jnp.linspace(0,tend,1000)
m_t=res.sol(t)                 # (n_states, len(t)): component masses over time
msum=jnp.sum(m_t,axis=0)       # total mass over time
colors=['red', 'green','blue']
q1in_vec=jax.vmap(q1in)
q1in_t=q1in_vec(t)             # (len(t), 3): feed flows, evaluated once rather than per trace

fig=make_subplots(2,2, subplot_titles=('Total mass', 'Feed flows q1in',
                                       'Component masses', 'Mass fractions'))

fig.add_scatter(x=t, y=msum, mode='lines', name='total', line_color='rgb(200,200,200)', row=1,col=1)

for i, color in enumerate(colors):
  fig.add_scatter(x=t,y=q1in_t[:,i], line_color=color,showlegend=False, row=1, col=2)
  fig.add_scatter(x=t, y=m_t[i], mode='lines', name=f'{i}', line_color=color, row=2,col=1)
  fig.add_scatter(x=t, y=m_t[i]/msum, mode='lines', line_color=color, showlegend=False, row=2,col=2)

fig.update_xaxes(title_text='t')
fig.update_layout(width=600, height=600)
fig.show()