<a href="https://colab.research.google.com/github/JA4S/JANC/blob/main/examples/janc_amr_example1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install JANC and import relevant libraries

In [None]:
# Copyright © 2025 Haocheng Wen, Faxuan Luo
# SPDX-License-Identifier: MIT

!pip install git+https://github.com/JA4S/JANC.git
!wget https://raw.githubusercontent.com/JA4S/JANC/main/examples/9sp-19r-H2-Air.yaml

In [None]:
from janc.thermodynamics import thermo
from janc.solver import solver
from jaxamr import amr
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'gpu')

# Set grid and AMR parameters

In [None]:
# Domain extent along x and y.
Lx = 0.05
Ly = 0.0125

# Number of base-grid (level 0) cells along x and y.
nx = 2000
ny = 500

base_grid = {'Lx': Lx, 'Ly': Ly, 'Nx': nx, 'Ny': ny}

# One [x-direction, y-direction] entry per refinement level, coarsest first.
n_block = [
    [1, 1],     # Level 0
    [100, 25],  # Level 1
    [2, 2],     # Level 2
    [2, 2],     # Level 3
    [2, 2],     # Level 4
]

template_node_num = 3

buffer_num = 5

# Per-criterion thresholds handed to the AMR module.
refinement_tolerance = {
    'density': 20.0,
    'velocity': 0.5,
}

amr_config = {
    'base_grid': base_grid,
    'n_block': n_block,
    'template_node_num': template_node_num,
    'buffer_num': buffer_num,
    'refinement_tolerance': refinement_tolerance,
}

amr.set_amr(amr_config)

# Grid size in refinement levels: entry 0 is the base spacing and each further
# level halves it. Built directly as lists (one entry per n_block level) so
# dx/dy never change type within the cell.
dx = [Lx/nx / (2.0**level) for level in range(len(n_block))]
dy = [Ly/ny / (2.0**level) for level in range(len(n_block))]

# Set thermodynamics (thermo & chemical properties) of gas mixture

In [None]:
# Settings handed to the thermodynamics module. The mechanism entry names the
# YAML file downloaded by the install cell; the key spelling is reproduced
# exactly as the configuration uses it.
thermo_config = dict(
    is_detailed_chemistry=True,
    thermo_model='nasa7',
    mechanism_diretory='9sp-19r-H2-Air.yaml',
)
thermo.set_thermo(thermo_config)

# Set boundary conditions

In [None]:
# All four domain edges use the same 'slip_wall' boundary type.
boundary_config = {
    f'{side}_boundary': 'slip_wall'
    for side in ('left', 'right', 'bottom', 'up')
}

# Initializations

In [None]:
# Two solvers are built from the same thermo and boundary settings:
# 'amr' mode is used for the refined levels, 'base' mode for level 0.
advance_one_step, rhs = solver.set_solver(
    thermo_config,
    boundary_config,
    solver_mode='amr',
)
advance_one_step_L0, rhs0 = solver.set_solver(
    thermo_config,
    boundary_config,
    solver_mode='base',
)

In [None]:
def initial_conditions():
    """Build the initial fields on the base grid.

    The whole domain starts at the ambient state (Penv, Tenv, Yenv). A hot,
    high-pressure ignition region is then imposed: a strip covering the first
    ``ignition_width`` cells in x over the full height, plus three circular
    patches centred on the strip's right edge at 1/4, 2/4 and 3/4 of the
    height (only the half outside the strip adds new cells).

    Reads the module-level grid sizes ``nx`` and ``ny``.

    Returns:
        U_init: rho, rho*u, rho*v, E and rho*Y (8 species) concatenated
            along axis 0.
        aux_init: gamma and T concatenated along axis 0.
    """
    # Ambient and ignition states. Temperatures are written as floats so that
    # jnp.full below creates a floating-point field; an integer fill value
    # would produce an integer-typed temperature array.
    Penv = 1.0*101325
    Tenv = 300.0
    yH2env = 0.028
    yO2env = 0.226
    Pignition = 75*101325
    Tignition = 3500.0
    # set the mass fractions for the species (except inert species N2, which leaves 8 species to set)
    Yenv = jnp.array([yH2env, yO2env, 0, 0, 0, 0, 0, 0])

    # set ignition zone (rectangular shape)
    ignition_width = 80
    ignition_height = ny
    Y_init = jnp.broadcast_to(Yenv[:, None, None], (8, nx, ny))
    T_init = jnp.full((1, nx, ny), Tenv)
    T_init = T_init.at[:, 0:ignition_width, 0:ignition_height].set(Tignition)
    P_init = jnp.full((1, nx, ny), Penv)
    P_init = P_init.at[:, 0:ignition_width, 0:ignition_height].set(Pignition)

    # set ignition zone (circle shape)
    # 'ij' indexing gives (nx, ny) arrays: x_idx varies along axis 0 and
    # y_idx along axis 1, matching the layout of T_init / P_init.
    x_idx, y_idx = jnp.meshgrid(jnp.linspace(0, nx, nx),
                                jnp.linspace(0, ny, ny),
                                indexing='ij')
    radius = ignition_width//2
    # three semicircle ignition zone to induce detonation cells
    quarter_height = ignition_height//4
    mask = jnp.zeros((1, nx, ny), dtype=bool)
    for k in (1, 2, 3):
        distance = jnp.sqrt((x_idx[None, :, :] - ignition_width)**2
                            + (y_idx[None, :, :] - quarter_height*k)**2)
        mask = mask | (distance <= radius)
    T_init = T_init.at[mask].set(Tignition)
    P_init = P_init.at[mask].set(Pignition)

    # get relevant thermo properties from temperature and species mass fractions
    _, gamma_init, h_init, R_init, _ = thermo.get_thermo(T_init, Y_init)

    rho_init = P_init/(R_init*T_init)
    # No kinetic term is added: both momentum components are zero initially.
    E_init = rho_init*h_init - P_init
    rhou_init = jnp.zeros((1, nx, ny))
    rhov_init = jnp.zeros((1, nx, ny))

    # concatenate the conservative variables U, and thermo variables aux(gamma,T)
    U_init = jnp.concatenate([rho_init, rhou_init, rhov_init, E_init, rho_init*Y_init], axis=0)
    aux_init = jnp.concatenate([gamma_init, T_init], axis=0)
    return U_init, aux_init


U, aux = initial_conditions()

# Plot coordinates spanning the domain, one sample per base-grid cell.
x = jnp.linspace(0, Lx, nx)
y = jnp.linspace(0, Ly, ny)
X, Y = jnp.meshgrid(x, y, indexing='ij')

# Filled contour of U[0], the first conservative variable (density).
plt.figure(figsize=(36, 7.5))
density_contours = plt.contourf(X, Y, U[0], levels=50, cmap='viridis')
plt.colorbar(density_contours)
plt.axis('equal')

# Level-0 data: U and aux stacked along the variable axis, with a leading
# block axis of length 1 (the base grid is a single block).
blk_data0 = jnp.concatenate([U, aux], axis=0)[None, ...]

# Block bookkeeping for level 0: one block at index (0, 0), with all four
# neighbor entries set to -1.
blk_info0 = dict(
    number=1,
    index=jnp.array([0, 0]),
    glob_index=jnp.array([[0, 0]]),
    neighbor_index=jnp.array([[-1, -1, -1, -1]]),
)


# AMR Loop

In [None]:
dt = 1e-9  # time step

nt = 200  # computation steps

amr_update_step = 5  # AMR update steps

# Build refinement levels 1-3 once, each from the level below it, before time
# stepping starts (previously guarded by a flag re-checked on every iteration).
blk_data1, blk_info1, max_blk_num1 = amr.initialize(1, blk_data0, blk_info0, 'density', dx[1], dy[1])
blk_data2, blk_info2, max_blk_num2 = amr.initialize(2, blk_data1, blk_info1, 'density', dx[2], dy[2])
blk_data3, blk_info3, max_blk_num3 = amr.initialize(3, blk_data2, blk_info2, 'density', dx[3], dy[3])

for step in tqdm(range(nt), desc="Progress", unit="step"):

    # Update the refined levels every amr_update_step steps. Step 0 is skipped
    # because the levels were initialized just above.
    if step > 0 and step % amr_update_step == 0:
        blk_data1, blk_info1, max_blk_num1 = amr.update(1, blk_data0, blk_info0, 'density', dx[1], dy[1], blk_data1, blk_info1, max_blk_num1)
        blk_data2, blk_info2, max_blk_num2 = amr.update(2, blk_data1, blk_info1, 'density', dx[2], dy[2], blk_data2, blk_info2, max_blk_num2)
        blk_data3, blk_info3, max_blk_num3 = amr.update(3, blk_data2, blk_info2, 'density', dx[3], dy[3], blk_data3, blk_info3, max_blk_num3)

    # Crossover advance: per outer step, level 3 takes 8 sub-steps of dt/8,
    # level 2 takes 4 of dt/4, level 1 takes 2 of dt/2 and level 0 one of dt.
    for _ in range(2):
        for _ in range(2):
            for _ in range(2):
                blk_data3 = advance_one_step(3, blk_data2, dx[3], dy[3], dt/8.0, blk_data3, blk_info3)
            blk_data2 = advance_one_step(2, blk_data1, dx[2], dy[2], dt/4.0, blk_data2, blk_info2)
        blk_data1 = advance_one_step(1, blk_data0, dx[1], dy[1], dt/2.0, blk_data1, blk_info1)
    # The level-0 solver takes a single block; wrap the result to restore the
    # leading block axis.
    blk_data0 = jnp.array([advance_one_step_L0(blk_data0[0], dx[0], dy[0], dt)])

    # Synchronous advance (alternative scheme, left disabled):
    #blk_data3 = solver.rk2(3, blk_data2, dx[3], dy[3], dt/8.0, blk_data3, blk_info3)
    #blk_data2 = solver.rk2(2, blk_data1, dx[2], dy[2], dt/8.0, blk_data2, blk_info2)
    #blk_data1 = solver.rk2(1, blk_data0, dx[1], dy[1], dt/8.0, blk_data1, blk_info1)
    #blk_data0 = solver.rk2_L0(blk_data0, dx[0], dy[0], dt/8.0)

    # Transfer each level's result onto the next coarser level, finest first.
    blk_data2 = amr.interpolate_fine_to_coarse(3, blk_data2, blk_data3, blk_info3)
    blk_data1 = amr.interpolate_fine_to_coarse(2, blk_data1, blk_data2, blk_info2)
    blk_data0 = amr.interpolate_fine_to_coarse(1, blk_data0, blk_data1, blk_info1)


# Visualization

In [None]:
from jaxamr import amraux

# Index along axis 1 of the block data; 0 is density (first entry of U).
component = 0

# (data, info) per refinement level, coarsest first, so finer levels are
# drawn after (on top of) coarser ones.
levels = [
    (blk_data0, blk_info0),  # Level 0
    (blk_data1, blk_info1),  # Level 1
    (blk_data2, blk_info2),  # Level 2
    (blk_data3, blk_info3),  # Level 3
]

# Density Contour
plt.figure(figsize=(36, 7.5))
ax = plt.gca()
# The color range is taken from the level-0 data only.
vmin = jnp.min(blk_data0[:, component])
vmax = jnp.max(blk_data0[:, component])
vrange = (vmin, vmax)
for level_data, level_info in levels:
    # The return value of the last call is used as the colorbar mappable.
    mappable = amraux.plot_block_data(level_data[:, component], level_info, ax, vrange)

plt.colorbar(mappable, ax=ax, label='Density')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.axis('equal')
plt.show()


# AMR Level Contour: every block is filled with the index of its level.
plt.figure(figsize=(36, 7.5))
ax = plt.gca()
vrange = (0, 3)
for level, (level_data, level_info) in enumerate(levels):
    mappable = amraux.plot_block_data(level*jnp.ones_like(level_data[:, component]), level_info, ax, vrange)

plt.colorbar(mappable, ax=ax, label='Refinement Level')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.axis('equal')
plt.show()