<a href="https://colab.research.google.com/github/JA4S/JANC/blob/main/examples/janc_basic_example2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install JANC and import relevant libraries

In [None]:
# Install JANC directly from its GitHub repository.
# The leading "!" is an IPython/Colab shell escape, so this line only works inside a notebook.
!pip install git+https://github.com/JA4S/JANC.git

In [None]:
from janc.preprocess.nondim import x0, P0, T0, t0
import janc.thermodynamics.thermo as thermo
import janc.solver.solver as solver

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Configure JAX before any arrays are created.
# Enable 64-bit floating point (JAX uses 32-bit by default).
jax.config.update("jax_enable_x64", True)
# Select the GPU backend as the default platform.
# NOTE(review): presumably this fails on a machine without a usable GPU backend — confirm before running on CPU-only hosts.
jax.config.update('jax_platform_name', 'gpu')

# Example : H2-Air Premixed RDC

# Set grid

In [None]:
# Domain extents divided by x0 (imported from janc.preprocess.nondim;
# presumably the reference length used for non-dimensionalisation).
Lx, Ly = 0.20/x0, 0.08/x0

# Number of grid cells along x and y.
nx, ny = 2000, 800

# Uniform cell sizes in the same non-dimensional units as Lx and Ly.
dx, dy = Lx/nx, Ly/ny

# Set thermodynamics (thermo & chemical properties) of gas mixture

In [None]:
# Thermodynamic / chemistry settings handed to janc's thermo module.
# NOTE: the key 'mechanism_diretory' is a runtime dictionary key and is spelled
# exactly as written here; do not "correct" it without checking which key
# thermo.set_thermo actually reads.
thermo_config = {'is_detailed_chemistry':True,
        'thermo_model':'nasa7',
        'mechanism_diretory':'9sp-19r-H2-Air.yaml'}
# Register the configuration with the thermo module (module-level side effect).
thermo.set_thermo(thermo_config)

# User-defined boundary conditions and source terms

In [None]:
# NOTE(review): this cell repeats the grid definition from the "Set grid" cell
# above with identical values; it only matters if this cell is run on its own.
# Domain extents divided by x0 (imported from janc.preprocess.nondim).
Lx = 0.20/x0
Ly = 0.08/x0

# Number of grid cells along x and y.
nx = 2000
ny = 800

# Uniform cell sizes.
dx = Lx/nx
dy = Ly/ny

In [None]:
## Inlet boundary
# Injection equivalence ratio.
ratio = 1.0
# Injected H2 and O2 mass fractions as functions of the equivalence ratio.
# NOTE(review): 1/8 is presumably the stoichiometric H2/O2 mass ratio and 0.232
# the O2 mass fraction of air — confirm.
yH2inj = (ratio/8*0.232)/(1 + ratio/8*0.232)
yO2inj = 0.232/(1 + ratio/8*0.232)

## Outlet boundary
# Back pressure: 101325 divided by the reference pressure P0.
Pb = 101325/P0

# Injected composition vector of length 8: H2 and O2 first, the remaining
# 8-2 entries set to a tiny positive value (1e-20).
Yinj = jnp.concatenate([jnp.array([yH2inj,yO2inj]),1e-20*jnp.ones((8-2))],axis=0)
# Reshape to (8, 1, 1) ...
Yinj_cor = jnp.expand_dims(Yinj,(1,2))
# ... and repeat along x so the shape becomes (8, nx+2*3, 1), i.e. the x-extent
# including 3 padded cells on each side.
Yinj_cor = jnp.tile(Yinj_cor,(1,nx+2*3,1))
# Hard-coded length of the composition vector above.
na = 8


def inj_model(p):
    """Pressure-feedback injection model evaluated for the global ``Yinj_cor``.

    For the given interior pressure ``p`` the injector Mach number ``M`` is
    found by 20 bisection steps on the bracket [0, 1]; the injected velocity,
    temperature, enthalpy and gamma follow from the last midpoint.

    Parameters:
        p: array of interior pressures next to the injection plane. All
           intermediate arrays take their shape from ``p``.

    Returns:
        (v_cor, T_cor, h_cor, gamma): injection velocity ``V3``, temperature
        ``p/(R*rho_cor)``, and the enthalpy and gamma returned by
        ``thermo.get_thermo(T_cor, Yinj_cor)``.

    Notes:
        * The literal ``1.0`` factors are presumably the non-dimensional
          stagnation pressure/temperature of the feed, and ``A1``/``A3`` the
          injector/chamber area ratio — TODO confirm.
        * The original also evaluated the model once at ``M = 0`` before the
          loop. That divides by ``MFC == 0`` and every value it produced was
          overwritten by the first loop iteration, so it has been removed.
    """
    A1 = 1
    A3 = 5
    A2 = A3-A1
    R = thermo.get_R(Yinj_cor)
    gamma = 1.29
    C0 = jnp.sqrt(gamma*R*1.0)
    P3 = p
    p_cor = p

    # Coefficients of the quadratic A*V3**2 + B*V3 + C = 0 that do not depend
    # on M; hoisted out of the bisection loop.
    A = 0.5
    C = -gamma/(gamma-1)*R*1.0

    # Bisection bracket for the Mach number.
    M1 = jnp.zeros_like(p)
    M2 = jnp.ones_like(p)

    for _ in range(20):
        M = 0.5*(M1+M2)
        P1 = 1.0*(1+(gamma-1)/2*M**2)**(-gamma/(gamma-1))
        V1 = M*(1+(gamma-1)/2*M**2)**(-0.5)*C0
        MFC = A1*1.0/jnp.sqrt(1.0)*jnp.sqrt(gamma/R)*M*(1+(gamma-1)/2*M**2)**(-(gamma+1)/2/(gamma-1))
        B = gamma/(gamma-1)*P3*A3/MFC
        # Positive root of the quadratic in V3.
        V3 = (-B+jnp.sqrt(B**2-4*A*C))/(2*A)
        P2 = (MFC*(V3-V1)-P1*A1+P3*A3)/A2

        # P2 >= P1: lower the upper bound to M; otherwise raise the lower bound.
        M2 = jax.lax.select(P2>=P1,M,M2)
        M1 = jax.lax.select(P2<P1,M,M1)

    # Injected state from the last midpoint evaluated in the loop.
    rho_cor = MFC/V3/A3
    v_cor = V3
    T_cor = p_cor/(R*rho_cor)
    # gamma is re-read from the thermo model at the injected temperature.
    _, gamma, h_cor, _, _ = thermo.get_thermo(T_cor,Yinj_cor)
    return v_cor, T_cor, h_cor, gamma



def left_boundary(U_periodic_pad,aux_periodic_pad):
    """Left boundary hook that keeps the periodic padding as it is.

    Per the original notes for this hook:
      * ``U_periodic_pad`` has shape [flux_num, nx+2*3, ny+2*3] and
        ``aux_periodic_pad`` has shape [2, nx+2*3, ny+2*3]
        (aux[0]: gamma, aux[1]: temperature).
      * The top and bottom boundaries are already padded with periodic values.
      * A non-periodic implementation would read the left boundary with
        ``U_periodic_pad[:,3:4,:]`` and *replace* the left ghost cells
        without changing the shape of the inputs.

    Returns:
        The two inputs, unchanged, as a tuple ``(U, aux)``.
    """
    unchanged = (U_periodic_pad, aux_periodic_pad)
    return unchanged

def right_boundary(U_periodic_pad,aux_periodic_pad):
    """Right boundary hook that keeps the periodic padding as it is.

    Per the original notes for this hook:
      * ``U_periodic_pad`` has shape [flux_num, nx+2*3, ny+2*3] and
        ``aux_periodic_pad`` has shape [2, nx+2*3, ny+2*3]
        (aux[0]: gamma, aux[1]: temperature).
      * The top and bottom boundaries are already padded with periodic values.
      * A non-periodic implementation would read the right boundary with
        ``U_periodic_pad[:,-4:-3,:]`` and *replace* the right ghost cells
        without changing the shape of the inputs.

    Returns:
        The two inputs, unchanged, as a tuple ``(U, aux)``.
    """
    unchanged = (U_periodic_pad, aux_periodic_pad)
    return unchanged

def bottom_boundary(U_periodic_pad_with_x_boundaries,aux_periodic_pad_with_x_boundaries):
    """Injection (bottom) boundary for the padded-array boundary interface.

    The inputs must already have their left/right boundaries replaced; the
    first three y-layers still hold periodic padding and are overwritten here
    with three copies of a single boundary state.

    Per cell along x, using the first physical row (index 3:4 along y):
      * ``p_in >= 1.0``: the interior state is mirrored with the normal
        velocity ``v`` negated (injection blocked, wall-like behaviour).
      * otherwise: the state comes from ``inj_model(p_in)`` with the global
        composition ``Yinj_cor`` and zero x-velocity.

    Parameters:
        U_periodic_pad_with_x_boundaries: conserved variables; channel 0 is
            density, 1 and 2 the momenta, 3 the energy, 4: the partial densities.
        aux_periodic_pad_with_x_boundaries: channel 0 gamma, channel 1 temperature.

    Returns:
        (U, aux) with the same shapes as the inputs.

    NOTE(review): this relies on the one-argument ``inj_model(p)``; a later
    cell in this notebook redefines ``inj_model(p, Y_inj)``, after which this
    function can no longer be called.
    """
    state_periodic = U_periodic_pad_with_x_boundaries
    gamma_periodic = aux_periodic_pad_with_x_boundaries[0:1,:,:]
    T_periodic = aux_periodic_pad_with_x_boundaries[1:2,:,:]

    # The physical boundary row is at index 3:4; indices 0:3 are padding.
    state_in = state_periodic[:,:,3:4]
    gamma_in = gamma_periodic[:,:,3:4]
    T_in = T_periodic[:,:,3:4]

    # Primitive variables of the first interior row.
    rho_in = state_in[0:1,:,:]
    u_in = state_in[1:2,:,:]/rho_in
    v_in = state_in[2:3,:,:]/rho_in
    Y_in = state_in[4:,:,:]/rho_in
    _, _, h_in, R_in, _ = thermo.get_thermo(T_in,Y_in)
    p_in = rho_in*R_in*T_in

    # Candidate injection state driven by the interior pressure.
    u_temp = jnp.zeros_like(u_in)
    Y_temp = Yinj_cor
    v_temp, T_temp, h_temp, gamma_temp = inj_model(p_in)
    R_temp = thermo.get_R(Y_temp)
    rho_temp = p_in/(R_temp*T_temp)

    # True where the interior pressure is at or above 1.0: injection is blocked.
    mask_in = (p_in >= 1.0)
    rho_cor_in = jax.lax.select(mask_in,rho_in,rho_temp)
    u_cor_in = jax.lax.select(mask_in,u_in,u_temp)
    v_cor_in = jax.lax.select(mask_in,-v_in,v_temp)
    T_cor_in = jax.lax.select(mask_in,T_in,T_temp)
    p_cor_in = p_in
    h_cor_in = jax.lax.select(mask_in,h_in,h_temp)
    # lax.select needs a mask of the same shape as its operands.
    na = Yinj_cor.shape[0]
    Y_cor_in = jax.lax.select(jnp.tile(mask_in,(na,1,1)),Y_in,Y_temp)
    gamma_cor_in = jax.lax.select(mask_in,gamma_in,gamma_temp)

    # Re-assemble conserved variables: total energy = rho*h - p + kinetic energy.
    U_lower_bound_state = jnp.concatenate([rho_cor_in, rho_cor_in * u_cor_in, rho_cor_in * v_cor_in,
                     rho_cor_in*h_cor_in - p_cor_in + 0.5 * rho_cor_in * (u_cor_in ** 2 + v_cor_in ** 2),
                     rho_cor_in * Y_cor_in], axis=0)
    aux_lower_bound_state = jnp.concatenate([gamma_cor_in,T_cor_in], axis=0)

    # Replace the three ghost layers with copies of the boundary state.
    U = jnp.concatenate([jnp.tile(U_lower_bound_state,(1,1,3)),
                         U_periodic_pad_with_x_boundaries[:,:,3:]],axis=2)
    aux = jnp.concatenate([jnp.tile(aux_lower_bound_state,(1,1,3)),
                           aux_periodic_pad_with_x_boundaries[:,:,3:]],axis=2)
    return U,aux

def up_boundary(U_periodic_pad_with_x_boundaries,aux_periodic_pad_with_x_boundaries):
    """Pressure-outlet (top) boundary for the padded-array boundary interface.

    The inputs must already have their left/right boundaries replaced; the
    last three y-layers still hold periodic padding and are overwritten here
    with three copies of a single outlet state.

    Where the normal Mach number ``v/a`` of the last physical row is below 1,
    the pressure is set to the global back pressure ``Pb`` and the density is
    rescaled so that ``p/rho`` is unchanged; elsewhere the interior state is
    copied. Velocities and mass fractions are always taken from the interior.

    Returns:
        (U, aux) with the same shapes as the inputs.
    """
    U_pad = U_periodic_pad_with_x_boundaries
    aux_pad = aux_periodic_pad_with_x_boundaries

    # Last physical row is at -4:-3 along y; -3: is padding.
    edge_state = U_pad[:,:,-4:-3]
    edge_gamma = aux_pad[0:1,:,-4:-3]
    edge_T = aux_pad[1:2,:,-4:-3]

    # Primitive variables of that row.
    rho = edge_state[0:1,:,:]
    u = edge_state[1:2,:,:]/rho
    v = edge_state[2:3,:,:]/rho
    Y = edge_state[4:,:,:]/rho
    R = thermo.get_R(Y)
    p = rho*(R*edge_T)
    sound_speed = jnp.sqrt(edge_gamma*p/rho)

    # Cells whose outflow is slower than the speed of sound get the back pressure.
    subsonic = (v/sound_speed < 1)
    rho_ghost = jax.lax.select(subsonic, Pb / (p / rho),rho)
    p_ghost = jax.lax.select(subsonic, Pb*jnp.ones_like(p),p)
    T_ghost = jax.lax.select(subsonic, p_ghost/(rho_ghost*R),edge_T)
    _, gamma_ghost, h_ghost, _, _ = thermo.get_thermo(T_ghost,Y)

    # Conserved variables of the ghost state: energy = rho*h - p + kinetic energy.
    energy_ghost = rho_ghost*h_ghost - p_ghost + 0.5 * rho_ghost * (u ** 2 + v ** 2)
    U_layer = jnp.concatenate([rho_ghost, rho_ghost * u, rho_ghost * v,
                               energy_ghost,
                               rho_ghost * Y], axis=0)
    aux_layer = jnp.concatenate([gamma_ghost,T_ghost], axis=0)

    # Drop the three padded layers and append three copies of the ghost state.
    U = jnp.concatenate([U_pad[:,:,:-3], jnp.tile(U_layer,(1,1,3))],axis=2)
    aux = jnp.concatenate([aux_pad[:,:,:-3], jnp.tile(aux_layer,(1,1,3))],axis=2)
    return U, aux

# Table of the four pad-based boundary callables defined above, keyed by side.
boundary_conditions = dict(
    left_boundary=left_boundary,
    right_boundary=right_boundary,
    bottom_boundary=bottom_boundary,
    up_boundary=up_boundary,
)

# Wrapper exposing the table under the 'boundary_conditions' key.
# NOTE(review): boundary_set is not referenced again in the visible part of this notebook.
boundary_set = dict(boundary_conditions=boundary_conditions)

# Set boundary conditions

Here we present a simple example of how to implement user-defined boundary conditions.
In JANC, user-defined boundary conditions are easy to implement: no .py file needs to be modified.
All you need to do is define a function right here that meets the following requirements:

In [None]:
# Template of a user-defined boundary function:
#def usr_boundary(U_bd,aux_bd,theta=None):

## U_bd and aux_bd are the 3 grid layers nearest to the boundary.

## For example, if this function is used on the right boundary,
## the shapes of U_bd and aux_bd are (variable_num,3,ny) and (2,3,ny),
## which means U_bd = U[:,-3:,:].
## If this function is used on the bottom boundary, U_bd = U[:,:,0:3].

## theta is a pytree (dict) containing any parameters you might need to define your function.

## The outputs of this function are 3 layers of ghost cells,
## and they should have the same shapes as the inputs.
#return U, aux


## Inlet boundary
# Injection equivalence ratio.
ratio = 1.0
# Injected H2 and O2 mass fractions as functions of the equivalence ratio.
# NOTE(review): 1/8 is presumably the stoichiometric H2/O2 mass ratio and 0.232
# the O2 mass fraction of air — confirm.
yH2inj = (ratio/8*0.232)/(1 + ratio/8*0.232)
yO2inj = 0.232/(1 + ratio/8*0.232)

from janc.solver.aux_func import U_to_prim
def bottom_boundary(U_bd, aux_bd, theta=None):
    """User-defined injection (bottom) boundary returning 3 ghost layers.

    Per cell along x, using the layer nearest to the boundary:
      * interior pressure >= 1.0 (injection pressure): slip wall, injection is
        blocked; the interior state is mirrored with the normal velocity negated.
      * interior pressure < 1.0: velocity inlet; the state comes from
        ``inj_model(p_in, Y_inj)`` with zero x-velocity.

    Parameters:
        U_bd: conserved variables of the 3 layers nearest to the bottom
            boundary, shape (flux_num, nx, 3), i.e. ``U[:,:,0:3]``.
        aux_bd: auxiliary variables, shape (2, nx, 3);
            ``aux_bd[0:1]`` is gamma and ``aux_bd[1:2]`` is temperature.
        theta: dict of parameters; ``theta['Yinj']`` holds the mass fractions
            of the injected propellants. Required despite the ``None`` default.

    Returns:
        (U, aux): ghost-cell values with the same shapes as the inputs.
    """
    # Layer adjacent to the boundary.
    U_in = U_bd[:,:,0:1]
    aux_in = aux_bd[:,:,0:1]
    rho_in,u_in,v_in,Y_in,p_in,_ = U_to_prim(U_in,aux_in)
    # aux channel 0 is gamma, channel 1 is temperature.
    # (The original read channel 1 for both, so gamma silently became T.)
    gamma_in = aux_in[0:1,:,0:1]
    T_in = aux_in[1:2,:,0:1]
    _, _, h_in, _, _ = thermo.get_thermo(T_in,Y_in)

    #interior pressure < injection pressure: velocity inlet
    u_inj = jnp.zeros_like(u_in)
    Y_inj = theta['Yinj']
    v_inj, T_inj, h_inj, gamma_inj = inj_model(p_in,Y_inj)
    R_inj = thermo.get_R(Y_inj)
    rho_inj = p_in/(R_inj*T_inj)

    #interior pressure >= injection pressure: slip wall
    #injection is blocked
    mask_block = (p_in >= 1.0)
    rho_cor_in = jax.lax.select(mask_block,rho_in,rho_inj)
    u_cor_in = jax.lax.select(mask_block,u_in,u_inj)
    v_cor_in = jax.lax.select(mask_block,-v_in,v_inj)
    T_cor_in = jax.lax.select(mask_block,T_in,T_inj)
    p_cor_in = p_in
    h_cor_in = jax.lax.select(mask_block,h_in,h_inj)
    # lax.select needs a mask of the same shape as its operands.
    Y_cor_in = jax.lax.select(jnp.tile(mask_block,(Y_inj.shape[0],1,1)),Y_in,Y_inj)
    gamma_cor_in = jax.lax.select(mask_block,gamma_in,gamma_inj)

    # Re-assemble conserved variables: total energy = rho*h - p + kinetic energy.
    U_lower_bound_state = jnp.concatenate([rho_cor_in, rho_cor_in * u_cor_in, rho_cor_in * v_cor_in,
                        rho_cor_in*h_cor_in - p_cor_in + 0.5 * rho_cor_in * (u_cor_in ** 2 + v_cor_in ** 2),
                        rho_cor_in * Y_cor_in], axis=0)
    aux_lower_bound_state = jnp.concatenate([gamma_cor_in,T_cor_in], axis=0)

    # All three ghost layers receive the same boundary state.
    U = jnp.tile(U_lower_bound_state,(1,1,3))
    aux = jnp.tile(aux_lower_bound_state,(1,1,3))
    return U, aux


#injection model from Fievisohn et al.2017, see https://doi.org/10.2514/1.B36103 for details
#this is a pressure feedback function, using pressure from the interior grid to infer the velocity et al.
#(like velocity inlet in ANSYS FLUENT)
def inj_model(p,Y_inj):
    """Pressure-feedback injection model for an injected composition ``Y_inj``.

    For the given interior pressure ``p`` the injector Mach number ``M`` is
    found by 20 bisection steps on the bracket [0, 1]; the injected velocity,
    temperature, enthalpy and gamma follow from the last midpoint.

    Parameters:
        p: array of interior pressures next to the injection plane. All
           intermediate arrays take their shape from ``p``.
        Y_inj: mass fractions of the injected mixture, passed to
           ``thermo.get_R`` and ``thermo.get_thermo``.

    Returns:
        (v_cor, T_cor, h_cor, gamma): injection velocity ``V3``, temperature
        ``p/(R*rho_cor)``, and the enthalpy and gamma returned by
        ``thermo.get_thermo(T_cor, Y_inj)``.

    Notes:
        * The literal ``1.0`` factors are presumably the non-dimensional
          stagnation pressure/temperature of the feed, and ``A1``/``A3`` the
          injector/chamber area ratio — TODO confirm against the reference.
        * The original also evaluated the model once at ``M = 0`` before the
          loop. That divides by ``MFC == 0`` and every value it produced was
          overwritten by the first loop iteration, so it has been removed.
    """
    A1 = 1
    A3 = 5
    A2 = A3-A1
    R = thermo.get_R(Y_inj)
    gamma = 1.29
    C0 = jnp.sqrt(gamma*R*1.0)
    P3 = p
    p_cor = p

    # Coefficients of the quadratic A*V3**2 + B*V3 + C = 0 that do not depend
    # on M; hoisted out of the bisection loop.
    A = 0.5
    C = -gamma/(gamma-1)*R*1.0

    # Bisection bracket for the Mach number.
    M1 = jnp.zeros_like(p)
    M2 = jnp.ones_like(p)

    for _ in range(20):
        M = 0.5*(M1+M2)
        P1 = 1.0*(1+(gamma-1)/2*M**2)**(-gamma/(gamma-1))
        V1 = M*(1+(gamma-1)/2*M**2)**(-0.5)*C0
        MFC = A1*1.0/jnp.sqrt(1.0)*jnp.sqrt(gamma/R)*M*(1+(gamma-1)/2*M**2)**(-(gamma+1)/2/(gamma-1))
        B = gamma/(gamma-1)*P3*A3/MFC
        # Positive root of the quadratic in V3.
        V3 = (-B+jnp.sqrt(B**2-4*A*C))/(2*A)
        P2 = (MFC*(V3-V1)-P1*A1+P3*A3)/A2

        # P2 >= P1: lower the upper bound to M; otherwise raise the lower bound.
        M2 = jax.lax.select(P2>=P1,M,M2)
        M1 = jax.lax.select(P2<P1,M,M1)

    # Injected state from the last midpoint evaluated in the loop.
    rho_cor = MFC/V3/A3
    v_cor = V3
    T_cor = p_cor/(R*rho_cor)
    # gamma is re-read from the thermo model at the injected temperature.
    _, gamma, h_cor, _, _ = thermo.get_thermo(T_cor,Y_inj)
    return v_cor, T_cor, h_cor, gamma


## Inlet boundary
# Injection equivalence ratio.
ratio = 1.0
# Injected H2 and O2 mass fractions as functions of the equivalence ratio.
# NOTE(review): 1/8 is presumably the stoichiometric H2/O2 mass ratio and 0.232
# the O2 mass fraction of air — confirm.
yH2inj = (ratio/8*0.232)/(1 + ratio/8*0.232)
yO2inj = 0.232/(1 + ratio/8*0.232)

# Injected composition vector of length 8: H2 and O2 first, the remaining
# 8-2 entries set to a tiny positive value (1e-20).
Yinj = jnp.concatenate([jnp.array([yH2inj,yO2inj]),1e-20*jnp.ones((8-2))],axis=0)
# Composition repeated along x with shape (8, nx+2*3, 1), i.e. including
# 3 padded cells on each side. (The original assigned Yinj_cor twice in a row;
# the first assignment was immediately overwritten and has been dropped.)
Yinj_cor = jnp.tile(jnp.expand_dims(Yinj,(1,2)),(1,nx+2*3,1))

## Outlet boundary: pressure_outlet
# JANC has a built-in pressure_outlet boundary condition;
# theta should contain 'Pb' when it is used.
# Back pressure: 101325 divided by the reference pressure P0.
Pb = 101325/P0

# Parameters passed to the boundary functions:
#   'Yinj': injected composition repeated along x, shape (8, nx, 1)
#   'Pb'  : back pressure for the pressure outlet
theta = {'Yinj': jnp.tile(jnp.expand_dims(Yinj,(1,2)),(1,nx,1)),
         'Pb': Pb}

# One entry per side: either the name of a built-in condition or a user-defined callable.
boundary_config = {'left_boundary':'periodic',
                   'right_boundary':'periodic',
                   'bottom_boundary':bottom_boundary,
                   'up_boundary':'pressure_outlet'}

# Initializations

In [None]:
# Thermodynamic / chemistry settings handed to the solver.
# NOTE: the key 'mechanism_diretory' is a runtime dictionary key and is spelled
# exactly as written here; do not "correct" it without checking which key JANC reads.
thermo_config = {'is_detailed_chemistry':True,
         'thermo_model':'nasa7',
         'mechanism_diretory':'9sp-19r-H2-Air.yaml'}

# Build the solver functions from the thermo and boundary configurations.
advance_one_step, rhs = solver.set_solver(thermo_config,boundary_config)
# advance_one_step: time-advance function; advances the current state by one time step dt.
# rhs: right-hand side of the Euler equations, dU/dt = rhs.
# Normally advance_one_step is all you need. For machine-learning tasks,
# rhs can be embedded in a differentiable optimization loop.

In [None]:
def initial_conditions():
    """Build the initial conserved state and auxiliary fields on the nx-by-ny grid.

    The whole domain starts at the ambient pressure, temperature and
    composition with zero velocity. The injected H2/O2 composition (global
    ``yH2inj``/``yO2inj``) fills the region ``x < nx//2``, ``y < 288`` cells,
    and a high-pressure, high-temperature patch of 108 x 288 cells in the
    lower-left corner serves as the ignition source.

    Returns:
        (U_init, aux_init): ``U_init`` stacks density, x-momentum, y-momentum,
        energy and the 8 partial densities along axis 0; ``aux_init`` stacks
        gamma and temperature along axis 0.
    """
    # Ambient state, scaled by the reference values P0 and T0.
    p_ambient = 1*101325/P0
    T_ambient = 300/T0
    y_h2_ambient = 0
    y_o2_ambient = 0.232

    # Ignition patch state.
    p_ignite = 20*101325/P0
    T_ignite = 3000/T0

    # Composition vectors (8 entries each).
    Y_ambient = jnp.array([y_h2_ambient,y_o2_ambient,0,0,0,0,0,0])
    Y_premixed = jnp.array([yH2inj,yO2inj,0,0,0,0,0,0])

    # Patch sizes in grid cells.
    ignition_width = 108
    ignition_height = 288
    patch_x = slice(0, ignition_width)
    patch_y = slice(0, ignition_height)
    fill_x = slice(0, nx//2)

    # Mass fractions: ambient everywhere, premixed fill in the lower-left region.
    fill_block = jnp.broadcast_to(Y_premixed[:,None,None],(8, nx//2, ignition_height))
    Y_field = jnp.broadcast_to(Y_ambient[:,None,None],(8,nx,ny))
    Y_field = Y_field.at[:, fill_x, patch_y].set(fill_block)

    # Temperature and pressure: ambient everywhere, raised inside the ignition patch.
    T_field = jnp.full((1,nx,ny),T_ambient).at[:,patch_x,patch_y].set(T_ignite)
    p_field = jnp.full((1,nx,ny),p_ambient).at[:,patch_x,patch_y].set(p_ignite)

    _,gamma_field,h_field,R_field,_ = thermo.get_thermo(T_field,Y_field)

    # Density from p = rho*R*T; energy is rho*h - p because the fluid is at rest.
    rho_field = p_field/(R_field*T_field)
    energy_field = rho_field*h_field - p_field
    zero_momentum = jnp.zeros((1,nx,ny))

    U_init = jnp.concatenate(
        [rho_field, zero_momentum, zero_momentum, energy_field, rho_field*Y_field],
        axis=0)
    aux_init = jnp.concatenate([gamma_field,T_field],axis=0)
    return U_init, aux_init

# Build the initial state and show the initial temperature field.
U, aux = initial_conditions()
plt.figure(figsize=(10, 4))
# Coordinates: nx points on [0, Lx] and ny points on [0, Ly].
x = jnp.linspace(0, Lx, nx)
y = jnp.linspace(0, Ly, ny)
# indexing='ij' gives X and Y the shape (nx, ny), matching the field arrays.
X, Y = jnp.meshgrid(x, y, indexing='ij')
# aux[-1] is the temperature channel (aux stacks gamma, then temperature).
plt.contourf(X, Y, aux[-1], levels=50, cmap='viridis')

# Main loop of time advance

In [None]:
## Minimal use of advance_one_step:
# Time step: 5e-9 divided by the reference time t0.
dt = 5e-9/t0
# Number of time steps.
nt = 20000
# The solver state stacks the conserved variables and the auxiliary variables along axis 0.
field = jnp.concatenate([U,aux],axis=0)
# Advance the state step by step; tqdm only displays a progress bar.
for step in tqdm(range(nt),desc="progress", unit="step"):
  field = advance_one_step(field,dx,dy,dt,theta)

# Plot

In [None]:
# Contour plot of the last channel of the final state.
plt.figure(figsize=(10, 4))
# field[-1] is the last stacked channel.
# NOTE(review): presumably temperature, assuming advance_one_step keeps the [U, aux] layout — confirm.
plt.contourf(X, Y, field[-1], levels=50, cmap='viridis')
#plt.clim(0, 4)
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.tight_layout()
# Same scale on both axes so the domain is not distorted.
plt.axis('equal')
plt.show()