# Neural Network Training Using JAX-Fluids

This notebook sketches one way to train neural networks that are embedded in the JAX-Fluids solver. We consider the following toy problem:

Given a target velocity profile $u(y)$, i.e., the velocity in the x-direction as a function of $y$, we want to learn the forcing term $S_x(y)$ that produces it. We neglect the convective fluxes and only account for the dissipative fluxes $F^d$:

$\frac{\partial U}{\partial t} = \nabla \cdot F^d + S, \qquad S = (0, S_x, 0, 0, 0)^T,$

i.e., the forcing only acts on the x-momentum equation.

The forcing is modeled by a neural network $NN_{\theta}$ which receives the y-coordinates as input.

$S_x = NN_{\theta}(y)$

where $\theta$ denotes the weights and biases of the neural network.

The true forcing is $S_x(y) = \sin(2 \pi y)$.

In [None]:
import json

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import matplotlib.pyplot as plt

from flax import linen as nn
import optax

from jaxfluids import InputManager, InitializationManager, SimulationManager
from jaxfluids.feed_forward.data_types import FeedForwardSetup
from jaxfluids.data_types.ml_buffers import ParametersSetup
from jaxfluids.callbacks import Callback
from jaxfluids_postprocess import load_data

The neural network is a simple MLP with two hidden layers of 32 neurons each and tanh activations.

In [None]:
class MLP(nn.Module):
    """Small fully connected network applied element-wise to its input.

    Every input element is treated as a single scalar feature: a trailing
    axis of size one is appended, the result passes through two hidden
    ``Dense(32)`` layers with tanh activations and a final linear
    ``Dense(1)`` layer, and the trailing axis is removed again. The output
    therefore has the same shape as the input.
    """

    @nn.compact
    def __call__(self, x):
        # (...,) -> (..., 1): one feature per element.
        hidden = jnp.expand_dims(x, axis=-1)
        for width in (32, 32):
            hidden = nn.tanh(nn.Dense(width)(hidden))
        # (..., 1) -> (...,): drop the singleton feature axis.
        return jnp.squeeze(nn.Dense(1)(hidden), axis=-1)

The forcing is implemented via a callback. After the right-hand side has been computed, the MLP is evaluated at the $y$-coordinates of the mesh grid, and the resulting forcing is added to the x-momentum component of the right-hand-side buffer (`rhs_buffers`).

In [None]:
class ForceCallback(Callback):
    """Callback that adds a learned forcing to the conservative right-hand side.

    The wrapped Flax ``model`` is evaluated on the y-coordinates returned by
    ``domain_information.compute_device_mesh_grid()``. It uses the parameters
    found in ``ml_setup.parameters.callbacks``, and its output is added to
    component 1 of ``rhs_buffers.euler_buffers.conservatives``. Presumably
    that component is the x-momentum, matching the (density, u, v, w, p)
    primitive ordering used elsewhere in this notebook.
    """

    def __init__(self, model):
        super().__init__()
        # Flax module whose ``apply(params, y)`` returns the forcing.
        self.model = model

    def after_compute_rhs(self, rhs_buffers, ml_setup, **kwargs):
        """Return a copy of ``rhs_buffers`` with the network forcing added."""
        # Only the y-coordinates are fed to the network; x is not needed.
        _, mesh_y = self.domain_information.compute_device_mesh_grid()
        forcing = self.model.apply(ml_setup.parameters.callbacks, mesh_y)

        # The buffers are immutable named tuples, so rebuild them with
        # ``_replace`` around a functionally updated conservatives array.
        euler = rhs_buffers.euler_buffers
        forced_conservatives = euler.conservatives.at[1].add(forcing)
        return rhs_buffers._replace(
            euler_buffers=euler._replace(conservatives=forced_conservatives)
        )

The training data is generated by running the forward simulation defined by `shear_flow.json` and `numerical_setup.json`.

In [None]:
def generate_training_data():
    """Run the reference simulation and return where its output was written.

    The setup is read from ``shear_flow.json`` and ``numerical_setup.json``.
    The function initializes the solver buffers, runs the simulation, and
    returns ``output_writer.save_path_case`` of the simulation manager, which
    can be passed to ``load_training_data``.
    """
    inputs = InputManager("shear_flow.json", "numerical_setup.json")
    initializer = InitializationManager(inputs)
    simulation = SimulationManager(inputs)

    simulation.simulate(initializer.initialization())
    return simulation.output_writer.save_path_case

In [None]:
def load_training_data(data_path):
    """Load the reference simulation and build the training inputs.

    Parameters
    ----------
    data_path : str
        Output directory of the reference run, e.g. the value returned by
        ``generate_training_data``.

    Returns
    -------
    primitives_init : jax.Array
        Primitive state of the first stored snapshot, stacked along axis 0
        as (density, u, v, w, pressure). The w-component is set to zero.
    velX_target : array
        x-velocity of the last stored snapshot (the training target).
    y : array
        y-coordinates of the cell centers as returned by ``load_data``.
    """
    jxf_data = load_data(data_path, quantities=["density", "velocity", "pressure"])

    # Only the y-coordinates are needed downstream.
    _, y, _ = jxf_data.cell_centers
    density = jxf_data.data["density"]
    velocity = jxf_data.data["velocity"]
    pressure = jxf_data.data["pressure"]

    # The leading axis appears to index stored snapshots (0 = first,
    # -1 = last); velocity carries its component axis right after it.
    first_density = density[0]
    primitives_init = jnp.stack([
        first_density,
        velocity[0, 0],
        velocity[0, 1],
        # w is zeroed explicitly rather than read from the data;
        # presumably the case is 2D. TODO confirm.
        jnp.zeros_like(first_density),
        pressure[0],
    ], axis=0)

    # Reuse the already-fetched velocity array instead of looking it up again.
    velX_target = velocity[-1, 0]

    return primitives_init, velX_target, y

We define the step function, which consists of the following steps:
- Execute the forward pass by calling the feed_forward method of the SimulationManager
- Calculate the loss: here, the sum of squared differences between the target velocity profile and the velocity profile obtained from the forward pass
- Compute the gradient of the loss with respect to the network parameters
- Perform a gradient descent step

Note: Gradients are propagated through the entire simulation.

In [None]:
def get_step_function(
    model,
    optimizer,
    *,
    outer_steps=200,
    physical_timestep_size=1.5e-3,
    case_setup_path="shear_flow.json",
    numerical_setup_path="numerical_setup.json",
):
    """Build a training step that differentiates through the whole simulation.

    Parameters
    ----------
    model : flax.linen.Module
        Network that is evaluated inside ``ForceCallback``.
    optimizer : optax.GradientTransformation
        Optimizer used for the parameter update.
    outer_steps : int, optional
        Number of scanned solver steps per forward pass (default 200).
    physical_timestep_size : float, optional
        Fixed time step of the forward pass (default 1.5e-3).
    case_setup_path, numerical_setup_path : str, optional
        JSON setup files of the case (defaults match the reference run).

    Returns
    -------
    callable
        ``step_function(params, opt_state, primitives_init, velX_target)``
        returning ``(params, opt_state, loss, grad, velX)``.
    """
    # Close the setup files deterministically instead of leaking the handles.
    with open(case_setup_path, "r") as f:
        case_setup_dict = json.load(f)
    with open(numerical_setup_path, "r") as f:
        numerical_setup_dict = json.load(f)
    # Switch off the custom forcing of the setup (presumably the analytic
    # forcing that produced the target data) so that the forcing added by
    # ForceCallback is what gets learned.
    numerical_setup_dict["active_forcings"]["is_custom_forcing"] = False

    force_cb = ForceCallback(model)

    input_manager_train = InputManager(
        case_setup_dict, numerical_setup_dict
    )
    sim_manager_train = SimulationManager(
        input_manager_train,
        callbacks=force_cb
    )

    feed_forward_setup = FeedForwardSetup(
        outer_steps=outer_steps, inner_steps=1, is_scan=True
    )

    def loss_fn(params, primitives_init, velX_target):
        """Run the forward pass and return (sum of squared u-errors, final u)."""
        parameters_setup = ParametersSetup(callbacks=params)

        solution_array, _ = sim_manager_train._feed_forward(
            primes_init=primitives_init,
            physical_timestep_size=physical_timestep_size,
            t_start=0.0,
            feed_forward_setup=feed_forward_setup,
            ml_parameters=parameters_setup
        )

        # Final snapshot, with halo cells stripped so it matches the target.
        primitives = solution_array["primitives"][-1]
        domain_information = sim_manager_train.domain_information
        nhx, nhy, nhz = domain_information.domain_slices_conservatives
        primitives = primitives[..., nhx, nhy, nhz]
        velX = primitives[1]

        loss = jnp.sum(jnp.square(velX - velX_target))
        return loss, velX

    # Build the gradient transform once rather than on every step.
    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def step_function(params, opt_state, primitives_init, velX_target):
        """Perform one optimizer step on ``params``."""
        (loss, velX), grad = value_and_grad_fn(params, primitives_init, velX_target)
        # Pass params so optimizers that need them (e.g. weight decay) work.
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss, grad, velX

    return step_function

In the training loop, we use the Adam optimizer with a learning rate of $10^{-3}$.

In [None]:
def _true_forcing(y):
    """Analytic reference forcing sin(2*pi*y), used for plotting only."""
    return jnp.sin(2 * jnp.pi * y)


def _plot_training_progress(step, loss, loss_history, y, primitives_init,
                            velX_target, velX, model, params):
    """Print the current loss and plot velocity, forcing and loss history."""
    title_str = f"STEP = {step}, LOSS = {loss:.2e}"
    print(title_str)

    fig, ax = plt.subplots(ncols=3, figsize=(24, 5))
    fig.suptitle(title_str)

    # Velocity profiles along y at the first x- and z-index.
    ax[0].plot(y, velX_target[0, :, 0], label="target")
    ax[0].plot(y, primitives_init[1, 0, :, 0], color="gray", linestyle="--", label=r"$t_0$")
    ax[0].plot(y, velX[0, :, 0], linestyle="--", label=r"$t_N$")
    ax[0].legend()
    ax[0].set_ylim([-0.2, 0.2])
    ax[0].set_xlabel(r"$y$")
    ax[0].set_ylabel(r"$u$")

    ax[1].plot(y, _true_forcing(y), label="True forcing")
    ax[1].plot(y, model.apply(params, y), linestyle="--", label="NN forcing")
    ax[1].legend()
    ax[1].set_ylim([-1.25, 1.25])
    ax[1].set_xlabel(r"$y$")
    ax[1].set_ylabel(r"$F_x$")

    ax[2].plot(jnp.array(loss_history))
    ax[2].set_xlabel("Steps")
    ax[2].set_ylabel("Loss")
    ax[2].set_yscale("log")

    plt.show()
    # Release the figure explicitly so figures do not pile up on backends
    # that keep them open after display.
    plt.close(fig)


def train_model(data_path, *, num_steps=10000, learning_rate=1e-3,
                plot_every=100, seed=42):
    """Train the MLP forcing so the simulated final x-velocity matches the data.

    Parameters
    ----------
    data_path : str
        Output directory of the reference simulation
        (see ``generate_training_data``).
    num_steps : int, optional
        Number of optimization steps (default 10000).
    learning_rate : float, optional
        Adam learning rate (default 1e-3).
    plot_every : int, optional
        Print the loss and plot the current state every ``plot_every``
        steps (default 100). Values <= 0 disable plotting.
    seed : int, optional
        Seed for the parameter initialization (default 42).

    Returns
    -------
    params
        Trained network parameters (previously discarded).
    loss_history : jax.Array
        Loss value of every training step.
    """
    primitives_init, velX_target, y = load_training_data(data_path)

    model = MLP()
    # The dummy input only fixes parameter shapes; the MLP acts element-wise.
    params = model.init(jax.random.key(seed), jnp.ones((1, 1, 1)))
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    step_fn = jax.jit(get_step_function(model, optimizer))

    loss_history = []
    for step in range(num_steps):
        params, opt_state, loss, _, velX = step_fn(
            params, opt_state, primitives_init, velX_target
        )
        loss_history.append(loss)

        if plot_every > 0 and step % plot_every == 0:
            _plot_training_progress(
                step, loss, loss_history, y, primitives_init,
                velX_target, velX, model, params,
            )

    return params, jnp.array(loss_history)


As we can see, the neural network quickly learns a forcing that yields a good approximation of the target velocity profile. However, because the underlying process is dissipative, the final velocity profile is not very sensitive to small-scale changes in the forcing term; the learned forcing may therefore deviate from the true forcing even where the velocity profiles agree well.

In [None]:
# Run the reference simulation once, then fit the network forcing to its
# final velocity profile. Binding the result to a name keeps the notebook
# from echoing it as cell output.
data_path = generate_training_data()
training_result = train_model(data_path)