# LagrangeBench dataset generation

This code was used to generate the 7 datasets from the LagrangeBench paper.
It builds on and extends the functionality of the original code by Fabian Fritz in [this notebook](taylor-green-vortex_3d_sph.ipynb).

In [1]:
from argparse import Namespace
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import jax.numpy as jnp
import h5py
import numpy as np
from jax import grad, jit, vmap, random
from jax_md import space
from jax_md.partition import Sparse


from jax import jit, vmap, grad, ops
from jax_md import space, partition

import jax.numpy as jnp
import numpy as onp

from functools import partial, namedtuple

from lagrangebench.case_setup import partition

2024-02-11 03:18:16.315188: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
class Kernel:
    """The kernel object is a polymorphic base class for all derived kernels."""

    def __init__(self, h, dim):
        self._dim = dim
        self._h = h
        self._one_over_h = 1.0 / h

    def w(self, r):
        """Evaluates the kernel at the radial coordinate r."""
        return self._w(r)

    def grad_w(self, r):
        """Evaluates the kernel gradient at the radial coordinate r by utilization of automatic differentation."""
        return grad(self.w)(r)

class QuinticKernel(Kernel):
    """The quintic kernel function of Morris."""

    def __init__(self, h, dim = 3):
        Kernel.__init__(self, h, dim)
        self._normalized_cutoff = 3.0
        self.cutoff = self._normalized_cutoff * h
        self._sigma_1d = 120 * self._one_over_h
        self._sigma_2d = 7.0 / 478.0 / jnp.pi * self._one_over_h * self._one_over_h
        self._sigma_3d = 3.0 / 359.0 / jnp.pi * self._one_over_h * self._one_over_h * self._one_over_h
        self._sigma = jnp.where(dim == 3, self._sigma_3d, jnp.where(dim == 2, self._sigma_2d, self._sigma_1d))

    def _w(self, r):

        q = r * self._one_over_h

        q1 = jnp.maximum(0.0, 1.0 - q)
        q2 = jnp.maximum(0.0, 2.0 - q)
        q3 = jnp.maximum(0.0, 3.0 - q)

        return self._sigma * (q3 * q3 * q3 * q3 * q3 - 6.0 * q2 * q2 * q2 * q2 * q2 + 15.0 * q1 * q1 * q1 * q1 * q1) 

def write_h5(data_dict, path):
    """Write a dict of numpy or jax arrays to a .h5 file"""
    hf = h5py.File(path, "w")
    for k, v in data_dict.items():
        hf.create_dataset(k, data=np.array(v))
    hf.close()

def pos_init_cartesian_2d(box_size, dx):
    n = np.array((box_size / dx).round(), dtype=int)
    grid = np.meshgrid(range(n[0]), range(n[1]), indexing="xy")
    r = (jnp.vstack(list(map(jnp.ravel, grid))).T + 0.5) * dx
    return r

def pos_init_cartesian_3d(box_size, dx):
    n = np.array((box_size / dx).round(), dtype=int)
    grid = np.meshgrid(range(n[0]), range(n[1]), range(n[2]), indexing="xy")
    r = (jnp.vstack(list(map(jnp.ravel, grid))).T + 0.5) * dx
    return r


In [3]:
def SPH(displacement_fn, g_ext_fn, args):
    """Smoothed Particle Hydrodynamics solver"""
    dx, dim, dt, is_bc_trick, is_rho_evol, artificial_alpha, p_bg_factor, u_ref, rho_ref, mass, viscosity = args.dx, args.dim, args.dt, args.is_bc_trick, args.density_evolution, args.artificial_alpha, args.p_bg_factor, args.u_ref, args.rho_ref, args.mass, args.viscosity

    kernel_fn = QuinticKernel(h=dx, dim=dim)
    c_ref = 10.0 * u_ref
    p_ref = rho_ref * c_ref * c_ref
    p_ref = rho_ref * c_ref * c_ref
    p_background = p_bg_factor * p_ref

    @jit
    def _pressure_fn(density, density_reference, speed_of_sound, background_pressure):
        return speed_of_sound * speed_of_sound * (density - density_reference) + background_pressure
    pressure_fn = partial(_pressure_fn, density_reference=rho_ref, speed_of_sound=c_ref, background_pressure=p_background)

    @jit
    def _density_fn(pressure, density_reference, speed_of_sound, background_pressure):
        return (pressure - background_pressure) / speed_of_sound / speed_of_sound + density_reference
    density_fn = partial(_density_fn, density_reference=rho_ref, speed_of_sound=c_ref, background_pressure=p_background)


    def forward(state, neighbors):
        r, tag, u, dudt, rho, p = state["r"], state["tag"], state["u"], state["dudt"], state["rho"], state["p"]
        N = len(r)
        i_s, j_s = neighbors.idx
        r_i_s, r_j_s = r[i_s], r[j_s]
        dr_i_j = vmap(displacement_fn)(r_i_s, r_j_s)
        dist = space.distance(dr_i_j)
        w_dist = vmap(kernel_fn.w)(dist)

        grad_w_dist =  vmap(kernel_fn.grad_w)(dist)[:, None] * dr_i_j / (dist[:, None] + jnp.finfo(float).eps)
        g_ext = g_ext_fn(r)

        # density and pressure
        if is_rho_evol:
            drhodt = rho * ops.segment_sum((mass / rho)[j_s] * ((u[i_s] - u[j_s]) * grad_w_dist).sum(axis=1), i_s, N)
            rho = rho + dt * drhodt
        else:
            rho = mass * ops.segment_sum(w_dist, i_s, N)
        p = vmap(pressure_fn)(rho)

        if is_bc_trick: 
            # Based on: "A generalized wall boundary condition [...]", Adami, Hu, Adams, 2012
            w_j_s_fluid = w_dist * jnp.where(tag[j_s] == 0, 1.0, 0.0)
            # no-slip boundary condition
            u_wall = ops.segment_sum(w_j_s_fluid[:, None] * u[j_s], i_s, N) / ( ops.segment_sum(w_j_s_fluid, i_s, N)[:, None] + jnp.finfo(float).eps)
            u = jnp.where(tag[:, None] > 0, 2 * u - u_wall, u)

            p_wall = ( ops.segment_sum(w_j_s_fluid * p[j_s], i_s, N) + (g_ext * ops.segment_sum((rho[j_s] * w_j_s_fluid)[:, None] * dr_i_j, i_s, N)).sum(axis=1)) / (ops.segment_sum(w_j_s_fluid, i_s, N) + jnp.finfo(float).eps)
            p = jnp.where(tag > 0, p_wall, p)
            rho = vmap(density_fn)(p)

        def acceleration_fn(r_ij,d_ij,rho_i, rho_j, u_i, u_j,p_i,p_j):
            p_ij = (rho_j * p_i + rho_i * p_j) / (rho_i + rho_j)
            return ((mass / rho_i) ** 2 + (mass / rho_j) ** 2) / mass * kernel_fn.grad_w(d_ij) / (d_ij + jnp.finfo(float).eps) * (-p_ij * r_ij + viscosity * (u_i - u_j))

        res = vmap(acceleration_fn)(dr_i_j, dist, rho[i_s], rho[j_s], u[i_s], u[j_s], p[i_s], p[j_s])
        dudt = ops.segment_sum(res, i_s, N)

        if artificial_alpha != 0.0:
            numerator = (mass * artificial_alpha * dx * c_ref * ((u[i_s] - u[j_s]) * dr_i_j).sum(axis=1))[:, None] * grad_w_dist
            denominator = ((rho[i_s] + rho[j_s]) / 2 * (dist**2 + 0.01 * dx**2))[:, None]
            res = jnp.where((tag[j_s] == 0) * (tag[i_s] == 0), 1.0, 0.0)[:, None] * numerator / denominator
            dudt_artif = ops.segment_sum(res, i_s, N)
        else:
            dudt_artif = jnp.zeros_like(dudt)

        return {"r": r, "tag": tag, "u": u, "dudt": dudt + g_ext + dudt_artif, "rho": rho, "p": p}

    return forward

## Particle Relaxations

The trajectories in, e.g., 2D Taylor-Green vortex differ in the initial particle configuration. To get different physical configurations, for each trajectory we run an SPH relaxations of 5000 steps starting from cartesian coordinates plus some Gaussian noise on top. Obtaining a relaxed state for a 2D TGV simulation is demonstrated below.

In [4]:
args = Namespace(
    case="Rlx", 
    solver="SPH",
    dim=2, 
    dx=0.02, 
    dt=0.0,
    t_end=0.2,
    box_size=np.array([1.0, 1.0]), 
    pbc=np.array([True, True]),
    seed=0,
    write_h5=True,
    write_every=5000,
    r0_noise_factor=0.25,
    viscosity=0.01,
    relax_pbc=True,
    data_path="data_relaxed",
    p_bg_factor=0.0,
    is_bc_trick=False,
    density_evolution=False, 
    artificial_alpha=0.0,
    free_slip=False,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=5000,
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.zeros_like(r)
bc_fn = lambda state: state
key = random.PRNGKey(args.seed)

displacement_fn, shift_fn = space.periodic(side=args.box_size)
r_init = pos_init_cartesian_2d(args.box_size, args.dx)
noise = args.r0_noise_factor * args.dx * random.normal(key, r_init.shape)
r_init = shift_fn(r_init, noise)

state = {
    "r": r_init,
    "u": jnp.zeros(r_init.shape),
    "tag": jnp.zeros(r_init.shape[0]),
    "dudt": jnp.zeros(r_init.shape),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}


# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 5000, u_max = 0.0000
5000 / 5000, u_max = 0.0092


## 2D TGV

In [5]:
args = Namespace(
    case="TGV", 
    solver="SPH",
    dim=2, 
    dx=0.02, 
    dt=0.0004,
    t_end=5,
    box_size=np.array([1.0, 1.0]), 
    pbc=np.array([True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.01,
    relax_pbc=True,
    data_path="datasets/2D_TGV_2500_10kevery100",
    p_bg_factor=0.0,
    is_bc_trick=False,
    density_evolution=False, 
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 12500
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.zeros_like(r)
bc_fn = lambda state: state
key = random.PRNGKey(args.seed)

# load relaxed state
r_init = np.array(h5py.File("data_relaxed/state_5000.h5", "r")['r'])
state = {
    "r": r_init,
    "u": jnp.array([ -1.0 * jnp.cos(2.0 * jnp.pi * r_init[:,0]) * jnp.sin(2.0 * jnp.pi * r_init[:,1]),
                     +1.0 * jnp.sin(2.0 * jnp.pi * r_init[:,0]) * jnp.cos(2.0 * jnp.pi * r_init[:,1])]).T,
    "tag": jnp.zeros(r_init.shape[0]),
    "dudt": jnp.zeros(r_init.shape),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)
noise = args.r0_noise_factor * args.dx * random.normal(key, state["r"].shape)
state["r"] = shift_fn(state["r"], noise)

# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 1000, u_max = 1.0000
100 / 1000, u_max = 0.9785
200 / 1000, u_max = 0.9655
300 / 1000, u_max = 0.9315
400 / 1000, u_max = 0.9077
500 / 1000, u_max = 0.8852
600 / 1000, u_max = 0.8356
700 / 1000, u_max = 0.8147
800 / 1000, u_max = 0.7770
900 / 1000, u_max = 0.7494
1000 / 1000, u_max = 0.7230


## 3D TGV

In [6]:
args = Namespace(
    case="TGV", 
    solver="SPH",
    dim=3, 
    dx=0.314159265, 
    dt=0.005,
    t_end=30,
    box_size=np.array([2*np.pi, 2*np.pi, 2*np.pi]), 
    pbc=np.array([True, True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.02,
    relax_pbc=True,
    data_path="datasets/3D_TGV_8000_10kevery100",
    p_bg_factor=0.0,
    is_bc_trick=False,
    density_evolution=False, 
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 6000
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.zeros_like(r)
bc_fn = lambda state: state
key = random.PRNGKey(args.seed)

r_init = pos_init_cartesian_3d(args.box_size, args.dx)  # replace with relaxed state
state = {
    "r": r_init,
    "u": jnp.array([+jnp.sin(r_init[:,0]) * jnp.cos(r_init[:,1]) * jnp.cos(r_init[:,2]),
                    -jnp.cos(r_init[:,0]) * jnp.sin(r_init[:,1]) * jnp.cos(r_init[:,2]),
                    0*r_init[:,1]]).T,
    "tag": jnp.zeros(r_init.shape[0]),
    "dudt": jnp.zeros(r_init.shape),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)
noise = args.r0_noise_factor * args.dx * random.normal(key, state["r"].shape)
state["r"] = shift_fn(state["r"], noise)

# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 1000, u_max = 0.9638
100 / 1000, u_max = 0.9366
200 / 1000, u_max = 0.7744
300 / 1000, u_max = 0.9897
400 / 1000, u_max = 0.8363
500 / 1000, u_max = 0.6221
600 / 1000, u_max = 0.6590
700 / 1000, u_max = 0.6121
800 / 1000, u_max = 0.4958
900 / 1000, u_max = 0.4383
1000 / 1000, u_max = 0.3921


# 2D RPF

In [7]:
args = Namespace(
    case="RPF", 
    solver="SPH",
    dim=2,
    dx=0.025,
    dt=0.0005,
    t_end=2050,
    box_size=np.array([1, 2.]), 
    pbc=np.array([True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.01,
    relax_pbc=True,
    data_path="datasets/2D_RPF_3200_20kevery100",
    p_bg_factor=0.05,
    is_bc_trick=False,
    density_evolution=False, 
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 4.100.000
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.where(r[:, 1] > 1.0, -1.0, 1.0)[:, None] * jnp.array([1.0, 0.0])
bc_fn = lambda state: state
key = random.PRNGKey(args.seed)

r_init = pos_init_cartesian_2d(args.box_size, args.dx)  # replace with relaxed state
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": jnp.zeros(r_init.shape[0]),
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)

# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 1000, u_max = 0.0000
100 / 1000, u_max = 0.0500
200 / 1000, u_max = 0.1000
300 / 1000, u_max = 0.1500
400 / 1000, u_max = 0.2000
500 / 1000, u_max = 0.2500
600 / 1000, u_max = 0.3000
700 / 1000, u_max = 0.3500
800 / 1000, u_max = 0.4000
900 / 1000, u_max = 0.4500
1000 / 1000, u_max = 0.5001


## 3D RPF

In [8]:
args = Namespace(
    case="RPF", 
    solver="SPH",
    dim=3,
    dx=0.05,
    dt=0.001,
    t_end=2050,
    box_size=np.array([1, 2, 0.5]), 
    pbc=np.array([True, True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.1,
    relax_pbc=True,
    data_path="datasets/3D_RPF_8000_10kevery100",
    p_bg_factor=0.02,
    is_bc_trick=False,
    density_evolution=False, 
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 2.050.000
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.where(r[:, 1] > 1.0, -1.0, 1.0)[:, None] * jnp.array([1.0, 0.0, 0.0])
bc_fn = lambda state: state
key = random.PRNGKey(args.seed)

r_init = pos_init_cartesian_3d(args.box_size, args.dx)  # replace with relaxed state
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": jnp.zeros(r_init.shape[0]),
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)
noise = args.r0_noise_factor * args.dx * random.normal(key, state["r"].shape)
state["r"] = shift_fn(state["r"], noise)

# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 1000, u_max = 0.0000
100 / 1000, u_max = 0.1000
200 / 1000, u_max = 0.1989
300 / 1000, u_max = 0.2934
400 / 1000, u_max = 0.3814
500 / 1000, u_max = 0.4620
600 / 1000, u_max = 0.5355
700 / 1000, u_max = 0.6023
800 / 1000, u_max = 0.6630
900 / 1000, u_max = 0.7181
1000 / 1000, u_max = 0.7682


## 2D LDC

In [9]:
args = Namespace(
    case="LDC", 
    solver="SPH",
    dim=2,
    dx=0.02,
    dt=0.0004,
    t_end=850,
    box_size=np.array([1.12, 1.12]), 
    pbc=np.array([True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.01,
    relax_pbc=True,
    data_path="datasets/2D_LDC_2708_10kevery100",
    p_bg_factor=0.01,
    is_bc_trick=True,
    density_evolution=False, 
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 2.125.000
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.zeros_like(r)
def bc_fn(state):
    u_lid = jnp.array([1.0, 0.0])
    state["u"] = jnp.where(state["tag"][:, None] == 1, 0, state["u"])
    state["u"] = jnp.where(state["tag"][:, None] == 2, u_lid, state["u"])
    return state
key = random.PRNGKey(args.seed)

r_init = pos_init_cartesian_2d(args.box_size, args.dx)  # replace with relaxed state
# tags: {'0': water, '1': solid wall, '2': moving wall}
tag = jnp.ones(len(r_init), dtype=int)
tag = jnp.where(jnp.where(jnp.abs(r_init - r_init.mean(axis=0)).max(axis=1) < 0.5, True, False), 0, tag)
tag = jnp.where(jnp.where(r_init[:, 1] > 1 + 3 * args.dx, True, False), 2, tag)
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": tag,
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)

# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 1000, u_max = 1.0000
100 / 1000, u_max = 1.0000
200 / 1000, u_max = 1.0000
300 / 1000, u_max = 1.0000
400 / 1000, u_max = 1.0000
500 / 1000, u_max = 1.0000
600 / 1000, u_max = 1.0000
700 / 1000, u_max = 1.0000
800 / 1000, u_max = 1.0000
900 / 1000, u_max = 1.0000
1000 / 1000, u_max = 1.0000


## 3D LDC

In [10]:
args = Namespace(
    case="LDC", 
    solver="SPH",
    dim=3,
    dx=0.041666667,
    dt=0.0009,
    t_end=1850,
    box_size=np.array([1.25, 1.25, 0.5]), 
    pbc=np.array([True, True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.01,
    relax_pbc=True,
    data_path="datasets/3D_LDC_8160_10kevery100",
    p_bg_factor=0.01,
    is_bc_trick=True,
    density_evolution=False, 
    artificial_alpha=0.0,
    u_ref=1.0,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 2.125.000
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.zeros_like(r)
def bc_fn(state):
    u_lid = jnp.array([1.0, 0.0, 0.0])
    state["u"] = jnp.where(state["tag"][:, None] == 1, 0, state["u"])
    state["u"] = jnp.where(state["tag"][:, None] == 2, u_lid, state["u"])
    return state
key = random.PRNGKey(args.seed)

r_init = pos_init_cartesian_3d(args.box_size, args.dx)  # replace with relaxed state
# tags: {'0': water, '1': solid wall, '2': moving wall}
tag = jnp.ones(len(r_init), dtype=int)
tag = jnp.where(jnp.where(jnp.abs(r_init - r_init.mean(axis=0)).max(axis=1) < 0.5, True, False), 0, tag)
tag = jnp.where(jnp.where(r_init[:, 1] > 1 + 3 * args.dx, True, False), 2, tag)
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": tag,
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)

# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 1000, u_max = 1.0000
100 / 1000, u_max = 1.0000
200 / 1000, u_max = 1.0000
300 / 1000, u_max = 1.0000
400 / 1000, u_max = 1.0000
500 / 1000, u_max = 1.0000
600 / 1000, u_max = 1.0000
700 / 1000, u_max = 1.0000
800 / 1000, u_max = 1.0000
900 / 1000, u_max = 1.0000
1000 / 1000, u_max = 1.0000


## 2D DAM

In [11]:
args = Namespace(
    case="DAM", 
    solver="SPH",
    dim=2,
    dx=0.02,
    dt=0.0003,
    t_end=12,
    box_size=np.array([5.486, 2.12]), 
    pbc=np.array([True, True]),
    seed=0,
    write_h5=True,
    write_every=100,
    r0_noise_factor=0.0,
    viscosity=0.00005,
    relax_pbc=True,
    data_path="datasets/3D_LDC_5740_20kevery100",
    p_bg_factor=0.01,
    is_bc_trick=True,
    density_evolution=False, 
    artificial_alpha=0.1,
    u_ref=2**0.5,
    rho_ref=1.0,
    sequence_length=1000,  # this one should go to 40.000
)
args.mass = args.dx ** args.dim * args.rho_ref
args.dt = args.dt if args.dt > 0 else 0.25*args.dx / (11*args.u_ref)

g_ext_fn = lambda r: jnp.ones_like(r) * jnp.array([0.0, 1.0])
def bc_fn(state):
    state["u"] = jnp.where(state["tag"][:, None] == 1, 0, state["u"])
    return state
key = random.PRNGKey(args.seed)

def init_dam():
    L_wall = 5.366
    H_wall = 2.0
    L = 2.0
    H = 1.0
    dx = args.dx
    dx3 = 3 * args.dx
    dx6 = 6 * args.dx

    r_fluid = dx3 + pos_init_cartesian_2d(np.array([L, H]), dx)
    # horizontal and vertical blocks
    vertical = pos_init_cartesian_2d(np.array([dx3, H_wall + dx6]), dx)
    horiz = pos_init_cartesian_2d(np.array([L_wall, dx3]), dx)
    # wall: left, bottom, right, top
    wall_l = vertical.copy()
    wall_b = horiz.copy() + np.array([dx3, 0.0])
    wall_r = vertical.copy() + np.array([L_wall + dx3, 0.0])
    wall_t = horiz.copy() + np.array([dx3, H_wall + dx3])

    res = np.concatenate([wall_l, wall_b, wall_r, wall_t, r_fluid])
    # tag the walls as "1" and the fluid as "0"
    tag = np.ones(len(wall_l) + len(wall_b) + len(wall_r) + len(wall_t))
    tag = np.concatenate([tag, np.zeros(len(r_fluid))])
    return res, tag

r_init, tag = init_dam()  # replace with relaxed fluid state
state = {
    "r": r_init,
    "u": jnp.zeros_like(r_init),
    "tag": tag,
    "dudt": jnp.zeros_like(r_init),
    "rho": jnp.ones(r_init.shape[0]),
    "p": jnp.zeros(r_init.shape[0]),
}

displacement_fn, shift_fn = space.periodic(side=args.box_size)

# Initialize a neighbor list
neighbor_fn = partition.neighbor_list(
    displacement_fn,
    args.box_size,
    r_cutoff=3 * args.dx,
    backend="jaxmd_vmap",
    capacity_multiplier=1.25,
    mask_self=False,
    format=Sparse,
    num_particles_max=state["r"].shape[0],
    pbc=args.pbc,
)
neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

# create data directory
os.makedirs(args.data_path, exist_ok=True)

@partial(jit, static_argnums=(3, 4, 5))
def advance(dt, state, neighbors, sph, shift_fn, bc_fn):
    state["u"] += 1.0 * dt * state["dudt"]
    state["r"] = shift_fn(state["r"], 1.0 * dt * state["u"])
    num_particles = (state["tag"] != -1).sum()
    neighbors = neighbors.update(state["r"], num_particles=num_particles)

    # update time-derivatives
    state = sph(state, neighbors)
    state = bc_fn(state)
    return state, neighbors

solver = SPH(displacement_fn, g_ext_fn, args)
for step in range(args.sequence_length + 1):
    if step % args.write_every == 0:
        write_h5(state, os.path.join(args.data_path, f"state_{step}.h5"))
    state, neighbors = advance(args.dt, state, neighbors, solver, shift_fn, bc_fn)

    # Check whether the edge list is too small and if so, create longer one
    if neighbors.did_buffer_overflow:
        neighbors = neighbor_fn.allocate(state["r"], num_particles=(state["tag"] != -1).sum())

    if step % args.write_every == 0: 
        u_max = jnp.sqrt(jnp.square(state["u"]).sum(axis=1)).max()
        print(f"{step} / {args.sequence_length}, u_max = {u_max:.4f}")


0 / 1000, u_max = 0.0000
100 / 1000, u_max = 1.8534
200 / 1000, u_max = 0.8289
300 / 1000, u_max = 0.7241
400 / 1000, u_max = 0.8026
500 / 1000, u_max = 0.7806
600 / 1000, u_max = 0.7751
700 / 1000, u_max = 0.7977
800 / 1000, u_max = 0.8188
900 / 1000, u_max = 0.8487
1000 / 1000, u_max = 0.8456
