# IDAKLU-Jax interface

PyBaMM provides two mechanisms to interface battery models with JAX. The first (`JaxSolver`) implements PyBaMM models directly in native JAX, and so provides the greatest flexibility. However, these models can be very slow to compile, especially on their first run, and can require large amounts of memory.

The second (the IDAKLU-Jax interface) instead provides a JAX-compliant interface to the IDAKLU solver. IDAKLU is a fast (compiled) solver based on SUNDIALS. By exposing the IDAKLU solver to JAX, we provide a fast solver that can interface with third-party JAX-compatible libraries, such as NumPyro.

Despite these advantages, the approach has some limitations. The most notable is that model derivatives are limited to first order (i.e. sensitivities), since the IDAKLU solver does not support automatic differentiation.

## Set up a basic DFN model

To demonstrate the IDAKLU-Jax interface, we first set up a basic model, in this case the DFN model. We provide two `inputs` to the model and specify a list of variables of interest (`output_variables`). Specifying `output_variables` is required by the IDAKLU-Jax interface; providing `inputs` is not.

In [None]:
import pybamm
import numpy as np
import jax
import jax.numpy as jnp

inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
}

model = pybamm.lithium_ion.DFN()
geometry = model.default_geometry
param = model.default_parameter_values
param.update({key: "[input]" for key in inputs.keys()})
param.process_geometry(geometry)
param.process_model(model)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
t_eval = np.linspace(0, 360, 10)
idaklu_solver = pybamm.IDAKLUSolver(rtol=1e-6, atol=1e-6)

# Declare which variables to track
output_variables = [
    "Terminal voltage [V]",
    "Discharge capacity [A.h]",
    "Loss of lithium inventory [%]",
]

We jaxify the IDAKLU solver using the same arguments we would pass to a native IDAKLU solve. The only difference is that the `jaxify()` method returns an `IDAKLUJax` object instead of a `Solution` object. We keep a reference to this object and request a JAX expression from it with the `get_jaxpr()` method, as below.

In [None]:
# This is how we would normally solve the model using IDAKLU
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    output_variables=output_variables,
    calculate_sensitivities=True,
)

# Instead, we Jaxify the IDAKLU solver using the same arguments...
jax_solver = idaklu_solver.jaxify(
    model,
    t_eval,
    inputs=inputs,
    output_variables=output_variables,
    calculate_sensitivities=True,
)

# ... and then obtain the JAX interface function
f = jax_solver.get_jaxpr()

The JAX expression (named `f` in our example) can be used like any other native JAX expression. This means it can be included in broader JAX expressions, and can even be JIT compiled (although at present only CPU compilation is supported, not GPU or TPU). The only limitation is that derivatives cannot be taken beyond first order.

Here is the most basic usage example:

In [None]:
# Print all output variables, evaluated over the provided time vector
data = f(t_eval, inputs)
print(data)
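Because `f` behaves like a native JAX function, it can sit inside larger JAX expressions. The sketch below illustrates the idea without PyBaMM: `toy_f` is a hypothetical stand-in that mimics the `(t, inputs)` calling convention and, as an assumption, returns one column per output variable. It is not a battery model, and `mean_voltage` is likewise an illustrative wrapper.

```python
import jax
import jax.numpy as jnp


def toy_f(t, inputs):
    # Hypothetical stand-in for the IDAKLU-Jax function `f` (NOT a battery model):
    # takes times and a dict of inputs, returns one column per output variable
    current = inputs["Current function [A]"]
    porosity = inputs["Separator porosity"]
    voltage = 4.2 - 0.1 * current * t / 3600.0 - 0.05 * porosity  # toy "voltage"
    capacity = current * t / 3600.0  # toy "discharge capacity [A.h]"
    return jnp.stack([voltage, capacity], axis=-1)


# Compose the stand-in into a broader JAX expression and JIT-compile the result
@jax.jit
def mean_voltage(t, inputs):
    return jnp.mean(toy_f(t, inputs)[:, 0])


t_eval = jnp.linspace(0.0, 360.0, 10)
inputs = {"Current function [A]": 0.222, "Separator porosity": 0.3}
print(toy_f(t_eval, inputs).shape)
print(mean_voltage(t_eval, inputs))
```

Here `jax.jit` compiles pure JAX code; with the real IDAKLU-Jax function, compilation is supported on CPU only.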

In addition to JAX compatibility, the IDAKLU-Jax interface provides several helper functions. Notably, the `get_var` helper isolates a single variable from the JAX expression, and the `get_vars` helper isolates multiple variables provided as a list.

In [None]:
# Isolate a single variable
data = jax_solver.get_var(f, "Terminal voltage [V]")(t_eval, inputs)
print(data)

# Isolate two variables from the solver
data = jax_solver.get_vars(
    f,
    [
        "Terminal voltage [V]",
        "Discharge capacity [A.h]",
    ],
)(t_eval, inputs)
print(data)

As with any JAX expression, we build functional expressions and only evaluate the outermost one. In other words, we wrap the expression `f` within an enclosing function, `jax_solver.get_var(f, ...)`, to form a new functional JAX expression. We evaluate that expression by passing our arguments `(t_eval, inputs)` at the end of the expression, so that they are passed to the outermost function.
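The wrapping pattern can be sketched in a few lines of plain JAX. Both `toy_f` and `select_var` below are hypothetical: they only illustrate how a helper can return a new function that slices out one output column, and are not PyBaMM's implementation of `get_var`. Nothing is evaluated until the returned function is called with `(t, inputs)`.

```python
import jax.numpy as jnp

output_variables = ["Terminal voltage [V]", "Discharge capacity [A.h]"]


def toy_f(t, inputs):
    # Hypothetical stand-in returning one column per output variable
    current = inputs["Current function [A]"]
    voltage = 4.2 - 0.1 * current * t / 3600.0
    capacity = current * t / 3600.0
    return jnp.stack([voltage, capacity], axis=-1)


def select_var(f, name):
    """Hypothetical helper: wrap `f` in a new function that returns only `name`."""
    index = output_variables.index(name)

    def wrapped(t, inputs):
        # The inner function `f` is only called when `wrapped` is evaluated
        return f(t, inputs)[..., index]

    return wrapped


voltage_fn = select_var(toy_f, "Terminal voltage [V]")  # builds a function only
t_eval = jnp.linspace(0.0, 360.0, 10)
inputs = {"Current function [A]": 0.222}
voltage = voltage_fn(t_eval, inputs)  # evaluation happens here
print(voltage.shape)
```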

To compute the Jacobian matrix (the matrix of derivatives of the output variables with respect to each input parameter), use JAX's forward-mode (`jax.jacfwd`) and reverse-mode (`jax.jacrev`) Jacobian functions.

When calling these functions, `argnums=1` signifies that we take the Jacobian with respect to the second argument (indexing from 0), i.e. `inputs`. Since `inputs` is a dictionary of input parameters, the result is also a dictionary, containing the derivatives with respect to each key (input parameter). The two methods (`jacfwd` and `jacrev`) produce the same output; only the mode of differentiation (forward or reverse) differs.

In [None]:
# Calculate the Jacobian matrix (via forward-mode autodiff)
out = jax.jacfwd(f, argnums=1)(t_eval, inputs)
print(out)

# Calculate the Jacobian matrix (via reverse-mode autodiff)
out = jax.jacrev(f, argnums=1)(t_eval, inputs)
print(out)
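To check that `jacfwd` and `jacrev` agree and return a dictionary keyed like `inputs`, the sketch below applies both to a hypothetical stand-in, `toy_f` (not a PyBaMM model and not the IDAKLU-Jax API). Because every input is a scalar, each dictionary entry has the same shape as the output, here `(len(t_eval), 2)`.

```python
import jax
import jax.numpy as jnp


def toy_f(t, inputs):
    # Hypothetical stand-in with the same (t, inputs-dict) signature as `f`
    current = inputs["Current function [A]"]
    porosity = inputs["Separator porosity"]
    voltage = 4.2 - 0.1 * current * t / 3600.0 - 0.05 * porosity
    capacity = current * t / 3600.0
    return jnp.stack([voltage, capacity], axis=-1)


t_eval = jnp.linspace(0.0, 360.0, 10)
inputs = {"Current function [A]": 0.222, "Separator porosity": 0.3}

jac_fwd = jax.jacfwd(toy_f, argnums=1)(t_eval, inputs)
jac_rev = jax.jacrev(toy_f, argnums=1)(t_eval, inputs)

# Both results are dicts keyed like `inputs`; forward and reverse mode agree
for key in inputs:
    print(key, jac_fwd[key].shape, bool(jnp.allclose(jac_fwd[key], jac_rev[key])))
```

In this toy, the derivative of the capacity column with respect to the current is simply `t / 3600`, which gives an easy hand check.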

The gradient (`grad`), on the other hand, requires the function to return a scalar value. The function must therefore be called with scalar arguments (including a scalar time) and can only be evaluated for one output variable at a time. These restrictions can be overcome by combining the `get_var` function with `vmap` (which provides vector-mapping over time), as demonstrated below.

In [None]:
# Example evaluation using the `grad` function
data = jax.vmap(
    jax.grad(
        jax_solver.get_var(f, "Terminal voltage [V]"),
        argnums=1,  # take derivative with respect to `inputs`
    ),
    in_axes=(0, None),  # map time over the 0th dimension and do not map inputs
)(t_eval, inputs)
print(data)
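The `vmap`-of-`grad` pattern should give the same numbers as the Jacobian of the single-variable function. The sketch below checks this on a hypothetical stand-in, `toy_voltage`, which plays the role of `jax_solver.get_var(f, "Terminal voltage [V]")`; it illustrates the JAX mechanics under that assumption and is not PyBaMM code.

```python
import jax
import jax.numpy as jnp


def toy_f(t, inputs):
    # Hypothetical stand-in with the same (t, inputs-dict) signature as `f`
    current = inputs["Current function [A]"]
    porosity = inputs["Separator porosity"]
    voltage = 4.2 - 0.1 * current * t / 3600.0 - 0.05 * porosity
    capacity = current * t / 3600.0
    return jnp.stack([voltage, capacity], axis=-1)


def toy_voltage(t, inputs):
    # For a scalar `t` this returns a scalar voltage, the shape `jax.grad` requires
    return toy_f(t, inputs)[..., 0]


t_eval = jnp.linspace(0.0, 360.0, 10)
inputs = {"Current function [A]": 0.222, "Separator porosity": 0.3}

# `grad` needs a scalar output, so map it over time with `vmap`
grads = jax.vmap(jax.grad(toy_voltage, argnums=1), in_axes=(0, None))(t_eval, inputs)

# The same numbers from the Jacobian of the vector-valued function
jac = jax.jacfwd(toy_voltage, argnums=1)(t_eval, inputs)
for key in inputs:
    print(key, bool(jnp.allclose(grads[key], jac[key])))
```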

## A use-case example

As a use-case example, consider a fitting procedure in which we compare simulation data against experimental data by computing the sum-of-squared errors (SSE) between the two. Many fitting procedures converge more quickly (in fewer iterations) if both the value *and gradient* of the SSE function are provided. Using JAX expressions, we can derive both with little effort.
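The quantities involved can be written out explicitly. Writing $\boldsymbol{\theta}$ for the model inputs, $V(t_i;\boldsymbol{\theta})$ for the modelled terminal voltage at time $t_i$ and $\hat{V}_i$ for the corresponding data point (notation introduced here for illustration), the SSE and its gradient are:

$$
\mathrm{SSE}(\boldsymbol{\theta}) = \sum_{i=1}^{N} \left( V(t_i;\boldsymbol{\theta}) - \hat{V}_i \right)^2,
\qquad
\frac{\partial\,\mathrm{SSE}}{\partial \theta_j} = 2 \sum_{i=1}^{N} \left( V(t_i;\boldsymbol{\theta}) - \hat{V}_i \right) \frac{\partial V(t_i;\boldsymbol{\theta})}{\partial \theta_j}
$$

The gradient needs only the first-order sensitivities $\partial V / \partial \theta_j$, which the IDAKLU-Jax interface provides, so this use-case sits within its first-order limitation.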

*Note*: We do not need to map over time when calling `value_and_grad` in this example since the `sse` function returns a scalar (despite taking a vector as input).

In [None]:
# Simulate some experimental data using our original parameter settings
data = sim["Terminal voltage [V]"](t_eval)

# Sum-of-squared errors between the modelled voltage and the data
def sse(t, inputs):
    modelled = jax_solver.get_var(f, "Terminal voltage [V]")(t, inputs)
    return jnp.sum((modelled - data) ** 2)

# Provide some predicted model inputs (these could come from a fitting procedure)
inputs_pred = {
    "Current function [A]": 0.150,
    "Separator porosity": 0.333,
}

# Get the value and gradient of the SSE function
value, gradient = jax.value_and_grad(sse, argnums=1)(t_eval, inputs_pred)
print(f"{value=}, {gradient=}")
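As a sketch of how `value_and_grad` might drive a fitting procedure, the loop below runs plain gradient descent on the `inputs` dictionary. So that it runs without PyBaMM, it uses a hypothetical, deliberately well-conditioned stand-in, `toy_voltage`, with synthetic data; the learning rate and step count are chosen for this toy problem only. With the real model you would substitute `jax_solver.get_var(f, "Terminal voltage [V]")`, and in practice a dedicated optimiser is usually preferable to hand-rolled gradient descent.

```python
import jax
import jax.numpy as jnp


def toy_voltage(t, inputs):
    # Hypothetical stand-in for jax_solver.get_var(f, "Terminal voltage [V]");
    # NOT a battery model, just linear in both inputs so the fit is well posed
    return (
        4.0
        - 2.0 * inputs["Current function [A]"] * t / 360.0
        - inputs["Separator porosity"]
    )


t_eval = jnp.linspace(0.0, 360.0, 10)
true_inputs = {"Current function [A]": 0.222, "Separator porosity": 0.3}
data = toy_voltage(t_eval, true_inputs)  # synthetic "experimental" data


def sse(t, inputs):
    return jnp.sum((toy_voltage(t, inputs) - data) ** 2)


value_and_grad_fn = jax.value_and_grad(sse, argnums=1)

params = {"Current function [A]": 0.150, "Separator porosity": 0.333}
learning_rate = 0.02  # tuned for this toy problem only
losses = []
for step in range(200):
    value, gradient = value_and_grad_fn(t_eval, params)
    losses.append(float(value))
    # Plain gradient descent, applied leaf-by-leaf to the inputs dictionary
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, gradient
    )

print(f"SSE: {losses[0]:.3e} -> {losses[-1]:.3e}")
print({key: float(val) for key, val in params.items()})
```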

All of the above expressions can be JIT compiled (for CPU) using `jax.jit`. In practice, however, the compiled function calls back into the Python interface of the IDAKLU solver, so JIT support is provided mainly for downstream compatibility (e.g. where `jax.jit` is applied outside the user's control). For speed and efficiency, we recommend avoiding `jax.jit`.
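The sketch below is only an analogy for "calling back into Python from JIT-compiled code". It uses JAX's general-purpose `jax.pure_callback` to run plain NumPy code (the hypothetical `numpy_solver`) from inside a jitted function. We do not assume that this is how IDAKLU-Jax is implemented; the point is that compilation does not remove the Python round trip.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_solver(x):
    # Hypothetical plain-NumPy routine standing in for a solver reached via Python
    x = np.asarray(x)
    return np.cos(x).astype(x.dtype)


def wrapped(x):
    # The compiled program hands `x` back to the Python callback at run time
    return jax.pure_callback(
        numpy_solver, jax.ShapeDtypeStruct(x.shape, x.dtype), x
    )


jitted = jax.jit(wrapped)
x = jnp.linspace(0.0, 1.0, 5)
print(jitted(x))
```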