## JAX-Fluids Tutorial: Parallel Simulations

JAX-Fluids supports parallel simulations on GPUs and TPUs. A homogeneous domain decomposition strategy is used to partition the computational domain and distribute it over the specified XLA devices. The decomposition is specified by the number of devices along each axis direction, denoted $(S_x, S_y, S_z)$.

For example, $(S_x, S_y, S_z) = (2, 2, 2)$ corresponds to a simulation on 8 XLA devices, where the domain is split in half along each spatial axis. $(S_x, S_y, S_z) = (8, 1, 1)$ would also use 8 XLA devices, but in this case the computational domain would be split only in the x-direction. In general, the decomposition uses $S_x S_y S_z$ devices. By default, $(S_x, S_y, S_z) = (1, 1, 1)$ is used, i.e., a single-device simulation. Only active axes can be split.
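The arithmetic behind these examples can be sketched in plain Python. The helper `check_decomposition` below is hypothetical and not part of the JAX-Fluids API; it assumes that an axis with a single cell is inactive and that a homogeneous split requires the cell count along each axis to be divisible by the number of blocks along that axis.

```python
from math import prod


def check_decomposition(cells, split):
    """Return (number of XLA devices, cells per block) for a domain split.

    Hypothetical helper, not part of JAX-Fluids. `cells` and `split` are
    (x, y, z) tuples. Assumed rules: an axis with one cell is inactive and
    cannot be split, and each split must divide the cell count evenly.
    """
    block = []
    for axis, n_cells, n_split in zip("xyz", cells, split):
        if n_split < 1:
            raise ValueError(f"split_{axis} must be at least 1")
        if n_cells == 1 and n_split > 1:
            raise ValueError(f"axis {axis} is inactive and cannot be split")
        if n_cells % n_split != 0:
            raise ValueError(f"{n_cells} cells in {axis} are not divisible by {n_split}")
        block.append(n_cells // n_split)
    # One XLA device per block: S_x * S_y * S_z devices in total.
    return prod(split), tuple(block)


print(check_decomposition((256, 256, 256), (2, 2, 2)))  # (8, (128, 128, 128))
print(check_decomposition((256, 256, 256), (8, 1, 1)))  # (8, (32, 256, 256))
```

Under these assumptions, a 2D case with cells `(256, 256, 1)` could only use splits of the form `(S_x, S_y, 1)`.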

Parallel simulations require the "decomposition" keyword within the "domain" section of the case setup file, which specifies the domain decomposition. For example, the "domain" settings for a 3D simulation in which the domain is split into $2 \times 2 \times 2$ blocks look as follows:

```json
{
    "domain": {
        "x": {
            "cells": 256,
            "range": [0.0, 1.0]
        },
        "y": {
            "cells": 256,
            "range": [0.0, 1.0]
        },
        "z": {
            "cells": 256,
            "range": [0.0, 1.0]
        },
        "decomposition": {
            "split_x": 2,
            "split_y": 2,
            "split_z": 2
        }
    }
}
```
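Before launching, it can be useful to check that the requested decomposition matches the available XLA devices. The sketch below uses hypothetical helpers (`read_decomposition`, `check_device_count`) that are not part of JAX-Fluids; it assumes that the run should use exactly $S_x S_y S_z$ devices and that a missing split key defaults to 1. Only `jax.device_count()` is a real JAX call.

```python
def read_decomposition(case_setup):
    """Return (split_x, split_y, split_z) from a parsed case setup dict.

    Hypothetical helper. Assumption: a missing "decomposition" block or a
    missing split key means 1, mirroring the (1, 1, 1) default.
    """
    decomposition = case_setup["domain"].get("decomposition", {})
    return tuple(decomposition.get(f"split_{axis}", 1) for axis in "xyz")


def check_device_count(case_setup, n_devices=None):
    """Raise if the requested split does not use exactly n_devices devices."""
    if n_devices is None:
        import jax  # imported lazily; device_count() counts devices on all hosts

        n_devices = jax.device_count()
    split = read_decomposition(case_setup)
    needed = split[0] * split[1] * split[2]
    if needed != n_devices:
        raise RuntimeError(
            f"decomposition {split} needs {needed} devices, found {n_devices}"
        )
    return split
```

With the case setup loaded via `json.load` (for example from a placeholder file `case.json`), `check_device_count(case_setup)` returns `(2, 2, 2)` for the settings above on a system with 8 XLA devices and raises a `RuntimeError` otherwise.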

In multi-host settings, the JAX distributed system must be initialized before a JAX-Fluids simulation can start. This is done by calling `jax.distributed.initialize()`; see https://docs.jax.dev/en/latest/_autosummary/jax.distributed.initialize.html for details.
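A minimal multi-host sketch follows. Only `jax.distributed.initialize()`, `jax.process_index()`, `jax.process_count()`, `jax.local_device_count()` and `jax.device_count()` are real JAX functions; the environment variables `COORDINATOR_ADDRESS`, `NUM_PROCESSES` and `PROCESS_ID` are hypothetical names that your launcher would have to export. When they are absent, `jax.distributed.initialize()` is called without arguments, which lets JAX try to auto-detect the cluster configuration (supported, for example, on Cloud TPU and SLURM).

```python
import os


def distributed_kwargs(env=os.environ):
    """Build keyword arguments for jax.distributed.initialize().

    Assumption: the launcher exports the hypothetical variables
    COORDINATOR_ADDRESS, NUM_PROCESSES and PROCESS_ID. If they are absent,
    an empty dict is returned so that JAX can auto-detect the cluster.
    """
    if "COORDINATOR_ADDRESS" not in env:
        return {}
    return {
        "coordinator_address": env["COORDINATOR_ADDRESS"],  # "host:port" of process 0
        "num_processes": int(env["NUM_PROCESSES"]),
        "process_id": int(env["PROCESS_ID"]),
    }


def initialize_multihost(env=os.environ):
    """Initialize the JAX distributed system once per process."""
    import jax  # imported lazily so distributed_kwargs works without JAX

    jax.distributed.initialize(**distributed_kwargs(env))
    # After initialization, device_count() is global and local_device_count() per host.
    print(
        f"process {jax.process_index()} of {jax.process_count()}: "
        f"{jax.local_device_count()} local / {jax.device_count()} global devices"
    )
```

Each process would call `initialize_multihost()` once at the top of its run script, before creating any JAX-Fluids objects or running other JAX computations. The global `jax.device_count()` reported afterwards is presumably the number that $S_x S_y S_z$ in the case setup should match.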