# JAX-Fluids: Automatic Differentiation demo
This demo shows how to differentiate through a simple 1D shock simulation with JAX-Fluids. It then checks the automatic-differentiation (AD) gradient against central finite-difference gradients.

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from nlfvs.input_reader import InputReader
from nlfvs.initializer import Initializer
from nlfvs.simulation_manager import SimulationManager

In [None]:
# Read the case and numerical setup; `with` closes the files after parsing.
with open("01_numerical_setup_sod.json") as case_file:
    case_dict = json.load(case_file)
with open("numerical_setup.json") as setup_file:
    numerical_setup_dict = json.load(setup_file)

input_reader = InputReader(case_dict, numerical_setup_dict)
initializer  = Initializer(input_reader)
sim_manager  = SimulationManager(input_reader)

# PRE SHOCK CONDITIONS (right state, ahead of the shock)
# Chained assignment: both sides share the same ratio of specific heats.
# (`gamma_L, gamma_R = 1.4` would try to unpack a float and raise TypeError.)
gamma_L = gamma_R = 1.4
rho_R = p_R = 1.0
a_R   = np.sqrt(gamma_R * p_R / rho_R)  # speed of sound of the right state
u_R   = 0.0                             # right state is at rest
M_R   = u_R / a_R

@jax.jit
def fun(M_S: float = 2.0):
    """Simulate a 1D shock of Mach number ``M_S`` and return a scalar entropy measure.

    The left (post-shock) state is computed from the Rankine-Hugoniot
    relations using the module-level right (pre-shock) state
    (``rho_R``, ``p_R``, ``u_R``, ``a_R``, ``M_R``, ``gamma_L``, ``gamma_R``).
    The two states are placed on either side of x = 0.5 on a uniform grid of
    ``case_dict["nx"]`` cells on [0, 1] and advanced with
    ``sim_manager.feed_forward``.

    Args:
        M_S: Shock Mach number. It is the traced (differentiable) input of
            this jitted function, so ``jax.grad`` can be taken w.r.t. it.

    Returns:
        Scalar: mean over cells of ``rho * p / rho**gamma_L`` at the last
        stored snapshot minus the same quantity at the first one.
    """
    # Trajectory settings forwarded to feed_forward.
    traj_length = 5
    time_step   = 1e-2
    res = case_dict["nx"]

    # Uniform grid on [0, 1]: cell faces, then cell centers.
    x_cf   = jnp.linspace(0, 1, num=res+1)
    x_cc = 0.5 * (x_cf[1:] + x_cf[:-1])

    # POST SHOCK RANKINE HUGONIOT CONDITIONS
    p_L   = p_R * ( 1/(gamma_L + 1) * (gamma_R * (M_R - M_S)**2 + 1) + jnp.sqrt( (1/(gamma_L + 1) * (gamma_R * (M_R - M_S)**2 + 1))**2 - (gamma_L-1)/(gamma_L+1) * ((M_R-M_S)**2 * 2 * gamma_R/(gamma_R - 1) - 1) )) 
    rho_L = rho_R *  (gamma_R - 1)/(gamma_L - 1) * ( p_L / p_R + (gamma_L - 1)/ (gamma_L + 1) ) / ( p_L / p_R * (gamma_R - 1) / (gamma_L + 1) + (gamma_R + 1) / (gamma_L + 1) ) 
    u_L   = a_R * ( rho_R/rho_L * (M_R - M_S) + M_S )

    # INITIAL BUFFER
    # Layout (1, 5, res, 1, 1); variable index 0 = density, 1 = velocity,
    # 4 = pressure (indices 2 and 3 stay zero). Cells with x > 0.5 get the
    # right state, the rest the left state.
    prime_init      = jnp.zeros((1, 5, res, 1, 1))
    prime_init      = prime_init.at[0,0,:,0,0].set(jnp.where(x_cc > 0.5, rho_R, rho_L))
    prime_init      = prime_init.at[0,1,:,0,0].set(jnp.where(x_cc > 0.5, u_R, u_L))
    prime_init      = prime_init.at[0,4,:,0,0].set(jnp.where(x_cc > 0.5, p_R, p_L))
    levelset_init   = None

    # FORWARD SIMULATION
    # NOTE(review): the meaning of the trailing positional arguments
    # (0.0, 1, None, None) is defined by SimulationManager.feed_forward.
    data_series, _ = sim_manager.feed_forward(
        prime_init, 
        levelset_init, 
        traj_length, 
        time_step, 
        0.0, 1, None, None)
    data_series = data_series[0]

    # COMPUTE SCALAR OUTPUT QUANTITY
    # Assumes data_series is indexed (snapshot, variable, ...) with the same
    # variable ordering as prime_init (0 = density, 4 = pressure) -- TODO confirm.
    entropy = data_series[:,4] / data_series[:,0]**gamma_L
    total_entropy = jnp.mean(data_series[-1,0] * entropy[-1] - data_series[0,0] * entropy[0]) 
    return total_entropy

In [None]:
# RUN THE FUNCTION ONCE FOR COMPILATION
M_S = 2.0
print("TOTAL ENTROPY:", fun(M_S))

# COMPUTE AUTO-DIFFERENTIATION GRADIENT
# Differentiate w.r.t. the same scalar input used above. Wrapping it in a
# shape-(1,) array would trigger a new trace for a different input shape and
# yield a shape-(1,) gradient (and a 2-D err_array) instead of a scalar.
fun_val_and_grad = jax.value_and_grad(fun)
_, grad_ad = fun_val_and_grad(M_S)
print("JAX GRADIENT:", grad_ad)

# COMPUTE FINITE-DIFFERENCE GRADIENTS
# Central differences: truncation error should shrink like eps**2 until
# floating-point round-off dominates at small eps.
eps_list = [1e-1, 3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
err_list = []
for eps in eps_list:
    grad_fd = (fun(M_S + eps) - fun(M_S - eps)) / (2 * eps)
    print("EPS GRADIENT:", grad_fd)
    err_list.append(np.abs(grad_ad - grad_fd))
eps_array = np.array(eps_list)
err_array = np.array(err_list)

In [None]:
# PLOTTING
# Log-log plot of |AD gradient - central FD gradient| versus eps, together
# with an eps**2 reference line anchored at half of the first error value.
reference = 0.5 * err_array[0] / eps_array[0]**2 * eps_array**(2)

fig, ax = plt.subplots()
ax.plot(eps_array, err_array, marker="o", label="AD vs. central FD")
ax.plot(eps_array, reference, color="black", label=r"$\propto \epsilon^2$")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel(r"$\epsilon$")
ax.set_ylabel(r"$\vert g_{AD} - g_{FD}^{\epsilon} \vert_1$")
ax.legend()
# plt.savefig("./figs/gradient_check_single.png")
plt.show()
plt.close(fig)