# JAX-Fluids: Rusanov-NN 
This notebook demonstrates how you can use neural networks during a simulation with JAX-Fluids. We simulate a 2D Riemann problem with a data-driven variant of the classical Rusanov (Local Lax-Friedrichs) Riemann solver. We will compare the performance of the Rusanov scheme with the Rusanov-NN scheme.
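
For reference, the classical Rusanov flux averages the left and right physical fluxes at a cell face and adds a dissipation term scaled by the largest local wave speed. The sketch below shows this for the 1D Euler equations with an ideal gas, in plain NumPy. The function names and the value `GAMMA = 1.4` are illustrative assumptions and not part of the JAX-Fluids API. A data-driven variant would presumably replace the scalar `alpha` with a learned quantity.

```python
import numpy as np

GAMMA = 1.4  # assumed ratio of specific heats (ideal gas)

def primitive_to_conservative(rho, u, p):
    """Convert (density, velocity, pressure) to U = (rho, rho*u, E)."""
    E = p / (GAMMA - 1.0) + 0.5 * rho * u**2
    return np.array([rho, rho * u, E])

def pressure(U):
    """Ideal-gas pressure recovered from the conservative state."""
    rho, mom, E = U
    return (GAMMA - 1.0) * (E - 0.5 * mom**2 / rho)

def euler_flux(U):
    """Physical flux of the 1D Euler equations."""
    rho, mom, E = U
    u = mom / rho
    p = pressure(U)
    return np.array([mom, mom * u + p, (E + p) * u])

def max_wave_speed(U):
    """Largest characteristic speed |u| + c of a state."""
    rho, mom, _ = U
    return abs(mom / rho) + np.sqrt(GAMMA * pressure(U) / rho)

def rusanov_flux(U_L, U_R):
    """Classical Rusanov (Local Lax-Friedrichs) flux at a cell face."""
    alpha = max(max_wave_speed(U_L), max_wave_speed(U_R))
    central = 0.5 * (euler_flux(U_L) + euler_flux(U_R))
    dissipation = 0.5 * alpha * (U_R - U_L)
    return central - dissipation

# Sod-type left and right states, chosen only for illustration
U_L = primitive_to_conservative(1.0, 0.0, 1.0)
U_R = primitive_to_conservative(0.125, 0.0, 0.1)
F_face = rusanov_flux(U_L, U_R)
```

For identical left and right states the dissipation term vanishes and the function returns the physical flux, which is a quick consistency check for any implementation.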

If we want to use a machine-learning-supported numerical model in JAX-Fluids, we have two options:
- Encode the network and the pre-trained network weights directly into the corresponding function
- Pass the network and the network weights into JAX-Fluids via the buffer_dictionary

In particular, the simulate() method of the SimulationManager takes the buffer_dictionary as input, which holds the initial physical fields. Under the key machinelearning_modules, the buffer_dictionary holds two further dictionaries:
- buffer_dictionary
    - machinelearning_modules
        - ml_parameters_dict
        - ml_networks_dict
        
ml_parameters_dict and ml_networks_dict are passed to most of the compute-heavy subroutines (e.g. cell-face reconstruction, Riemann solver, right-hand-side evaluation) in JAX-Fluids. Therefore, user-specified parameters and networks can simply be added to the machinelearning_modules dictionary. It is the user's task to implement how these neural networks are then used in JAX-Fluids.
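
The nesting described above can be pictured with plain Python dictionaries. This is a schematic sketch only: the placeholder values, the `dummy_network` function, and the `lookup` helper are hypothetical and do not reflect how JAX-Fluids accesses the entries internally. The key `riemannsolver` is the one used later in this notebook.

```python
def dummy_network(params, x):
    """Hypothetical stand-in for a network's apply function."""
    return params["scale"] * x

# Schematic layout; the real buffers hold arrays and Haiku objects
buffer_dictionary = {
    "material_fields": "...",  # placeholder for the initial physical fields
    "machinelearning_modules": {
        "ml_parameters_dict": {"riemannsolver": {"scale": 2.0}},
        "ml_networks_dict": {"riemannsolver": dummy_network},
    },
}

def lookup(buffer_dictionary, key):
    """Fetch the parameters and the network registered under the same key."""
    ml_modules = buffer_dictionary["machinelearning_modules"]
    return ml_modules["ml_parameters_dict"][key], ml_modules["ml_networks_dict"][key]

params, network = lookup(buffer_dictionary, "riemannsolver")
result = network(params, 3.0)
```

Keeping parameters and callables in two separate dictionaries, joined by a shared key, means the weights can be swapped without touching the network definition.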

In this notebook, we demonstrate how the Rusanov-NN Riemann solver can be used by passing a multi-layer perceptron (built in Haiku) and pre-trained weights to JAX-Fluids.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pickle
import jax
import jax.numpy as jnp
import haiku as hk
from jaxfluids import InputReader, Initializer, SimulationManager
from jaxfluids.post_process import load_data, create_lineplot

# Simulation Setup
We load the case setup file and the numerical setup file.

In [None]:
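import json
# Context managers make sure both files are closed after reading
with open("07_case_setup_riemann2D.json") as file:
    case_setup = json.load(file)
with open("07_numerical_setup_riemann2D.json") as file:
    numerical_setup = json.load(file)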

## HLLC 128x128
First, we conduct a highly resolved simulation (128x128 cells) with the HLLC scheme, which serves as the reference solution. We choose the HLLC Riemann solver in the numerical setup and set the resolution in x- and y-direction in the case setup accordingly.

In [None]:
numerical_setup["conservatives"]["convective_fluxes"]["riemann_solver"] = "HLLC"
for xi in ["x", "y"]:
    case_setup["domain"][xi]["cells"] = 128
input_reader = InputReader(case_setup, numerical_setup)
initializer  = Initializer(input_reader)
sim_manager  = SimulationManager(input_reader)
buffer_dictionary = initializer.initialization()
sim_manager.simulate(buffer_dictionary)
path = sim_manager.output_writer.save_path_domain
quantities = ["density", "velocity", "pressure"]
cell_centers_128, _, times_128, data_dict_128 = load_data(path, quantities)
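
The same two setup edits recur for each of the three runs in this notebook. One way to keep them in one place is a small helper that returns modified copies of both dictionaries. The helper name `configure` and the stand-in dictionaries are hypothetical; only the keys are taken from the cell above.

```python
import copy

def configure(case_setup, numerical_setup, riemann_solver, cells):
    """Return modified copies of both setup dicts; the inputs stay untouched."""
    case = copy.deepcopy(case_setup)
    numerical = copy.deepcopy(numerical_setup)
    numerical["conservatives"]["convective_fluxes"]["riemann_solver"] = riemann_solver
    for xi in ["x", "y"]:
        case["domain"][xi]["cells"] = cells
    return case, numerical

# Minimal stand-ins for the JSON files, containing only the keys touched here
demo_case = {"domain": {"x": {"cells": 64}, "y": {"cells": 64}}}
demo_numerical = {"conservatives": {"convective_fluxes": {"riemann_solver": "HLLC"}}}

case_32, numerical_32 = configure(demo_case, demo_numerical, "RUSANOV", 32)
```

Working on copies avoids carrying settings from one run into the next, which can happen when a shared dictionary is edited in place.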

## Rusanov 32x32
Second, we conduct a coarse simulation (32x32 cells) with the classical Rusanov scheme. We choose the Rusanov Riemann solver in the numerical setup file and set the resolution in x- and y-direction in the case setup file accordingly.

In [None]:
numerical_setup["conservatives"]["convective_fluxes"]["riemann_solver"] = "RUSANOV"
for xi in ["x", "y"]:
    case_setup["domain"][xi]["cells"] = 32
input_reader = InputReader(case_setup, numerical_setup)
initializer  = Initializer(input_reader)
sim_manager  = SimulationManager(input_reader)
buffer_dictionary = initializer.initialization()
sim_manager.simulate(buffer_dictionary)
path = sim_manager.output_writer.save_path_domain
quantities = ["density", "velocity", "pressure"]
cell_centers_32, _, times_32, data_dict_32 = load_data(path, quantities)

## Rusanov-NN 32x32
Third, we conduct a coarse simulation (32x32 cells) with the Rusanov-NN scheme. We choose the Rusanov-NN Riemann solver in the numerical setup file and set the resolution in x- and y-direction in the case setup file accordingly.

In [None]:
numerical_setup["conservatives"]["convective_fluxes"]["riemann_solver"] = "RUSANOVNN"
for xi in ["x", "y"]:
    case_setup["domain"][xi]["cells"] = 32
input_reader = InputReader(case_setup, numerical_setup)
initializer  = Initializer(input_reader)
sim_manager  = SimulationManager(input_reader)
buffer_dictionary = initializer.initialization()

The buffer_dictionary is the dictionary returned by initializer.initialization(). It has the following keys:
- material_fields
- levelset_quantities
- mass_flow_forcing
- machinelearning_modules
- time_control

The item machinelearning_modules holds network parameters and network callables. It is again a dictionary with the following two sub-keys:
- machinelearning_modules
    - ml_parameters_dict
    - ml_networks_dict

In [None]:
print("Buffer dictionary:\n  ", buffer_dictionary.keys())
print("Machine-learning modules:\n   ", buffer_dictionary["machinelearning_modules"].keys())

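## Build network in Haiku
We build a simple multi-layer perceptron in Haiku: two hidden layers with 32 units each and ReLU activations, followed by a single output unit. The exponential applied to the network output keeps it strictly positive.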

In [None]:
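def net_fn(x_in):
    """Multi-layer perceptron with a strictly positive scalar output."""
    # Take index 0 of the last axis and reverse the axis order, so that the
    # leading axis of x_in becomes the last axis, on which hk.Linear acts
    x = jnp.transpose(x_in[:, :, :, 0])
    mlp = hk.Sequential([
        hk.Linear(32), jax.nn.relu,
        hk.Linear(32), jax.nn.relu,
        hk.Linear(1),
    ])
    # The exponential guarantees a positive output
    x_out = jnp.exp(mlp(x))
    # Restore the original axis order and re-append the trailing axis
    x_out = jnp.expand_dims(jnp.transpose(x_out), axis=-1)
    return x_out
# Convert to a pure (init, apply) pair; apply needs no random key
net = hk.without_apply_rng(hk.transform(net_fn))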
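
To make the axis handling in net_fn easier to follow, the sketch below reproduces the same sequence of operations in plain JAX, without Haiku. The weights are random placeholders rather than the pre-trained ones, and the number of input channels (`n_features = 5`) and the grid size are assumptions chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Random placeholder weights; the notebook loads pre-trained ones instead."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(subkey, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def net_apply(params, x_in):
    # (features, Nx, Ny, 1) -> (Ny, Nx, features): features move to the last axis
    x = jnp.transpose(x_in[:, :, :, 0])
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    x = jnp.exp(x @ w + b)  # the exponential keeps the output strictly positive
    # (Ny, Nx, 1) -> (1, Nx, Ny) -> (1, Nx, Ny, 1)
    return jnp.expand_dims(jnp.transpose(x), axis=-1)

n_features = 5  # assumed number of input channels, for illustration only
params = init_mlp(jax.random.PRNGKey(0), [n_features, 32, 32, 1])
x_in = jnp.ones((n_features, 8, 6, 1))
x_out = net_apply(params, x_in)
```

The network acts pointwise: every grid location is mapped from its feature vector to one positive scalar, so the spatial shape of the output matches that of the input.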

## Load network parameters 
We load the pre-trained network parameters from a pickle file.

In [None]:
with open("07_RusanovNN_params.pkl", "rb") as file:
    ckpt = pickle.load(file)
    params = ckpt["params"]
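
The cell above assumes that the pickle file stores a dictionary with the key params. As a sketch of that layout, the following round trip writes and reads such a checkpoint. The file name `demo_params.pkl` and the parameter values are hypothetical placeholders; real Haiku parameters are nested mappings of arrays.

```python
import os
import pickle
import tempfile

# Placeholder parameter tree: module name -> parameter name -> values
ckpt = {"params": {"linear": {"w": [[0.1, 0.2]], "b": [0.0]}}}

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "demo_params.pkl")  # hypothetical file name
    with open(ckpt_path, "wb") as file:
        pickle.dump(ckpt, file)
    with open(ckpt_path, "rb") as file:
        restored = pickle.load(file)

params = restored["params"]
```

Since unpickling can execute arbitrary code, checkpoint files should only be loaded from trusted sources.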

In [None]:
# Register the parameters and the network under the same key
ml_params_dict = {"riemannsolver": params}
buffer_dictionary["machinelearning_modules"]["ml_parameters_dict"] = ml_params_dict
ml_networks_dict = hk.data_structures.to_immutable_dict({"riemannsolver": net})
buffer_dictionary["machinelearning_modules"]["ml_networks_dict"] = ml_networks_dict
# Run the Rusanov-NN simulation and load the results
sim_manager.simulate(buffer_dictionary)
path = sim_manager.output_writer.save_path_domain
quantities = ["density", "velocity", "pressure"]
cell_centers_NN32, _, times_NN32, data_dict_NN32 = load_data(path, quantities)

## Visualize the results
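We visualize the absolute velocity of the HLLC (128x128), Rusanov (32x32), and Rusanov-NN (32x32) simulations at six points in time. All rows share the color scale of the HLLC result.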

In [None]:
plot_times = [0.05, 0.15, 0.2, 0.5, 1.0, 2.0]
cmap = "seismic"
fig, ax = plt.subplots(figsize=(15,10), nrows=3, ncols=6)
for ii, plot_time in enumerate(plot_times):
    plot_id = np.argmin(np.abs(times_128 - plot_time))
    abs_vel_128 = np.sqrt(np.sum(data_dict_128["velocity"][plot_id,:,:,:,0]**2, axis=0))
    vmin = np.min(abs_vel_128)
    vmax = np.max(abs_vel_128)
    ax[0,ii].imshow(abs_vel_128.T, origin="lower", vmin=vmin, vmax=vmax, cmap=cmap)
    plot_id = np.argmin(np.abs(times_32 - plot_time))
    abs_vel_32 = np.sqrt(np.sum(data_dict_32["velocity"][plot_id,:,:,:,0]**2, axis=0))
    ax[1,ii].imshow(abs_vel_32.T, origin="lower", vmin=vmin, vmax=vmax, cmap=cmap)
    plot_id = np.argmin(np.abs(times_NN32 - plot_time))
    abs_vel_NN32 = np.sqrt(np.sum(data_dict_NN32["velocity"][plot_id,:,:,:,0]**2, axis=0))
    ax[2,ii].imshow(abs_vel_NN32.T, origin="lower", vmin=vmin, vmax=vmax, cmap=cmap)
for axi in ax.flatten():
    axi.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
ax[0,0].set_ylabel("HLLC 128x128", fontsize=12)
ax[1,0].set_ylabel("RUSANOV", fontsize=12)
ax[2,0].set_ylabel("RUSANOV-NN", fontsize=12)
plt.show()
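
Beyond the visual comparison, the coarse results can be compared with the reference quantitatively. The sketch below picks the snapshot closest to a target time, computes the velocity magnitude, coarsens a fine field by block averaging, and evaluates a relative L1 error. All helper names are hypothetical, the data are synthetic stand-ins, and the assumed array layout (three velocity components, then x, y, and a trailing axis of length one) is inferred from the indexing in the plotting cell. Block averaging is only one possible way to bring the fine field onto the coarse grid.

```python
import numpy as np

def nearest_snapshot(times, t):
    """Index of the stored snapshot closest to time t."""
    return int(np.argmin(np.abs(times - t)))

def velocity_magnitude(velocity):
    """Magnitude of a (3, Nx, Ny, 1) velocity field as an (Nx, Ny) array."""
    return np.sqrt(np.sum(velocity[:, :, :, 0] ** 2, axis=0))

def block_average(field, factor):
    """Coarsen an (Nx, Ny) field by averaging factor x factor blocks."""
    nx, ny = field.shape
    blocks = field.reshape(nx // factor, factor, ny // factor, factor)
    return blocks.mean(axis=(1, 3))

def relative_l1_error(coarse, reference):
    """Sum of absolute differences, normalized by the reference."""
    return np.sum(np.abs(coarse - reference)) / np.sum(np.abs(reference))

# Synthetic stand-ins for the loaded data: (time, component, x, y, z)
times = np.array([0.0, 0.1, 0.2, 0.3])
rng = np.random.default_rng(0)
vel_fine = rng.random((4, 3, 8, 8, 1))

idx = nearest_snapshot(times, 0.19)
mag_fine = velocity_magnitude(vel_fine[idx])
mag_ref = block_average(mag_fine, 4)
err_self = relative_l1_error(mag_ref, mag_ref)
```

With the notebook data, the coarsening factor would be 4 (128 to 32 cells per direction), and the error of each 32x32 run would be measured against the coarsened HLLC field at the same time.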