<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/Demo_ODE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import libraries
This section may look intimidating to new users, but it only needs to be modified to access advanced features.

In [None]:
!git clone --depth 1 https://github.com/profteachkids/CHE2064.git &> /dev/null
!pip install DotMap &> /dev/null

import sys
sys.path.insert(1, "/content/CHE2064") #Path to CHE module imports

import jax
import jax.numpy as jnp
# JAX defaults to 32-bit single precision; enable 64-bit before any arrays are created.
# Use the `jax.config` attribute: the `from jax.config import config` submodule
# import is no longer available in newer JAX releases.
jax.config.update("jax_enable_x64", True)

from scipy.integrate import solve_ivp

from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default='plotly_dark'

from dotmap import DotMap
from tools.dynamics import onoff_val, ramp, VX

## Model

Models are written using easy-to-read DotMap (dictionary) structures. The *dynamics.VX* class provides methods for converting between DotMap structures and the flat arrays required by ODE solvers.

In [None]:
def q1in(t):
  """Volumetric flowrates into the 1st tank at a (scalar) time t.

  Returns a jnp array with 3 entries, one stream per component, each built
  from a constant 0.1 base plus time-dependent ramp / onoff_val profiles.
  """
  base = 0.1
  comp1 = base + ramp(0.1, 100, 150)(t) - ramp(0.08, 200, 250)(t)
  comp2 = base + onoff_val(0.05, 10, 50)(t)
  comp3 = base + onoff_val(0.03, 300, sharp=0.1)(t)
  return jnp.array([comp1, comp2, comp3])
  
def q2in(t):
  """Volumetric flowrates into the 2nd tank at a (scalar) time t.

  Returns a jnp array with 3 entries, one stream per component; each entry is
  an onoff_val profile evaluated at t.
  """
  profiles = (onoff_val(0.02, 40, 60),
              onoff_val(0.01, 80, 100),
              onoff_val(0.02, 120, sharp=0.1))
  return jnp.array([profile(t) for profile in profiles])

def model(t, v, s, dv):
  """Right-hand side of the two-tank mixing ODE.

  Args:
    t: time (passed to the inlet-flow functions q1in / q2in).
    v: DotMap of dynamic variables; v.m1 and v.m2 are the component masses
       held in tank 1 and tank 2.
    s: DotMap of static parameters: s.rho (pure-component densities),
       s.A1 / s.A2 (tank cross-section areas), s.Cv1 / s.Cv2 (discharge
       coefficients).
    dv: DotMap structured identically to v; its fields are overwritten with
       the time derivatives.

  Returns:
    dv, with dv.m1 and dv.m2 set to the rates of change of v.m1 and v.m2.
  """
  # dv is an automatically generated DotMap, structured identically to the
  # dynamic variables defined in v

  V1 = jnp.sum(v.m1 / s.rho)  # Liquid volume and density in each tank
  rho1 = jnp.sum(v.m1) / V1   # Ideal mixing
  V2 = jnp.sum(v.m2 / s.rho)
  rho2 = jnp.sum(v.m2) / V2

  # Liquid flows in/out of tanks: Cv * sqrt(height differential)
  h1 = V1 / s.A1
  h2 = V2 / s.A2
  # Signed flow between the tanks: positive is tank 1 -> tank 2, negative is
  # backflow tank 2 -> tank 1.
  # NOTE(review): d/dx sqrt(|x|) is unbounded at x = 0, so the automatic
  # Jacobian may be non-finite exactly when h1 == h2.
  q12=s.Cv1*jnp.copysign(jnp.sqrt(jnp.abs(h1-h2)),h1-h2)
  q2 = s.Cv2*jnp.sqrt(h2)  # Outflow from tank 2

  w1 = v.m1/jnp.sum(v.m1) # Weight fractions
  w2 = v.m2/jnp.sum(v.m2)

  # Mass flow of each component from tank 1 to 2. The transferred liquid has
  # the density and composition of the tank it LEAVES: tank 1 for forward
  # flow, tank 2 for backflow. Always using tank 1's properties would move
  # liquid of the wrong composition whenever h2 > h1.
  m12 = jnp.where(q12 >= 0, rho1 * q12 * w1, rho2 * q12 * w2)

  # Change in masses: In - Out
  dv.m1 = s.rho*q1in(t) - m12
  dv.m2 = s.rho*q2in(t) + m12 - rho2*q2 * w2
  return dv


In [None]:
# Static parameters
s=DotMap()
s.A1 = 2.0 # Tank cross-section areas
s.A2 = 1.5
s.Cv1 = 0.5 # Discharge coefficients
s.Cv2 = 0.5
s.rho = jnp.array([950., 900., 800.]) # Pure component densities
s.tend = 400. # Simulation end time

# Variables (dynamic)
v=DotMap()
v.m1 = jnp.array([150., 150., 200.]) # Initial masses of each component
v.m2 = jnp.array([100., 100., 150.]) # in each tank

# This is where the "magic" occurs.  It is a one-liner!
# VX class provides methods for converting between DotMap (dictionary) structures
# and array structures.  And yes, the DotMap structures can have nested DotMaps in
# nested Lists with nested arrays.
vx = VX(v, s)

# VX transformed model is JAX compiled
# With automatic Jacobians from JAX, fully-implicit Radau method works very well.
model_f = jax.jit(vx.transform(model))
# The simulation end time is stored in the static parameters as s.tend
# (a bare `tend` is not defined, which fails on a fresh kernel).
res = solve_ivp(model_f, (0., s.tend), vx.x, method='Radau', dense_output=True,
                jac=jax.jacfwd(model_f, 1))


No GPU/TPU found, falling back to CPU.



## Plots
Generating good-looking plots does take effort.

In [None]:
# Dense time grid over the whole simulation (end time is stored as s.tend).
t=jnp.linspace(0,s.tend,1000)

# transform solution from array form back to DotMap structure of the dynamic
# variables.
r = vx.soltov(res.sol(t))
m1sum=jnp.sum(r.m1,axis=0)  # Sum over components -> total mass vs. time
m2sum=jnp.sum(r.m2,axis=0)
fig=make_subplots(2,4)
colors=['red', 'green','blue']
q1in_vec=jax.vmap(q1in)
q2in_vec=jax.vmap(q2in)
# Inlet-flow histories do not depend on the component loop below:
# evaluate them once instead of once per component.
q1in_t = q1in_vec(t)
q2in_t = q2in_vec(t)

# Total mass in each tank
fig.add_scatter(x=t, y=m1sum, mode='lines', name='total', line_color='rgb(200,200,200)', row=1,col=1)
fig.add_scatter(x=t, y=m2sum, mode='lines', name='Total', line_color='rgb(200,200,200)',
                showlegend=False, row=1,col=3)

for i in range(3):
  # External flows into each tank
  fig.add_scatter(x=t,y=q1in_t[:,i], line_color=colors[i],showlegend=False, row=1, col=2)
  fig.add_scatter(x=t,y=q2in_t[:,i], line_color=colors[i],showlegend=False, row=1, col=4)

  # Component masses for each tank
  fig.add_scatter(x=t, y=r.m1[i], mode='lines',
                  name=f'Comp {i+1}', line_color=colors[i], row=2,col=1)
  fig.add_scatter(x=t, y=r.m2[i], mode='lines',
                  line_color=colors[i], showlegend=False, row=2,col=3)

  # Mass fractions for each tank
  fig.add_scatter(x=t, y=r.m1[i]/m1sum, mode='lines', line_color=colors[i], showlegend=False, row=2,col=2)
  fig.add_scatter(x=t, y=r.m2[i]/m2sum, mode='lines', line_color=colors[i],
                  showlegend=False, row=2,col=4)

fig.add_annotation(text='Tank 1', xanchor='center', xref='paper', yref='paper',
                   font=dict(size=20), x=0.25, y=1.15, showarrow=False)
fig.add_annotation(text='Tank 2', xanchor='center', xref='paper', yref='paper',
                   font=dict(size=20), x=0.75, y=1.15, showarrow=False)

# Share y-ranges between the two tanks for the same quantity. Subplot axes are
# numbered row-major: y1..y4 on the top row, y5..y8 on the bottom row.
fig.update_layout(width=1000, height=600, yaxis4_matches='y2', yaxis3_matches='y1', yaxis7_matches='y5',
                  yaxis8_matches='y6')
fig.show()