In [1]:
import math

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [2]:
import exponax as ex

In [3]:
class TaylorGreenVorticity(eqx.Module):
    """Exact vorticity of a decaying Taylor-Green-type vortex on the 2π-periodic square.

    The field is

        ω(t, x, y) = 2 sin(x) cos(y) exp(-2 ν t).

    Its stream function ψ = sin(x) cos(y) satisfies ω = -Δψ = 2ψ. The
    nonlinear advection term J(ψ, ω) therefore vanishes identically, and the
    only dynamics is viscous decay at rate ν|k|² = 2ν. That makes it an exact
    solution of the 2D incompressible Navier-Stokes equations in vorticity
    form, suitable as a reference for validating a numerical solver.
    """

    nu: float  # viscosity ν (the ``diffusivity`` passed to __init__)

    def __init__(self, domain_extent: float, diffusivity: float):
        """
        Args:
            domain_extent: Side length of the periodic square domain. Must be
                2π (up to floating-point tolerance) because the formula uses
                unit wavenumbers.
            diffusivity: Viscosity ν controlling the exponential decay rate.

        Raises:
            ValueError: If ``domain_extent`` is not (numerically) 2π.
        """
        # Use a tolerance instead of `!=`: a 2π that passed through float32
        # (e.g. a jnp scalar) differs from the float64 constant in its last
        # bits and would otherwise be rejected.
        if not math.isclose(float(domain_extent), 2 * math.pi, rel_tol=1e-6):
            raise ValueError("Domain extent must be 2 * pi")
        self.nu = diffusivity

    def __call__(self, t: float, x: jax.Array) -> jax.Array:
        """Evaluate the vorticity at time ``t`` on coordinates ``x``.

        Args:
            t: Time at which to evaluate the solution.
            x: Coordinate array whose leading axis indexes the spatial
                dimension (``x[0]`` = first coordinate, ``x[1]`` = second).
                Presumably the output of ``ex.make_grid(2, 2π, N)``; confirm
                for other grid constructors.

        Returns:
            Array with a leading axis of length 1 (a single channel) and the
            same trailing shape as ``x``.
        """
        decay = jnp.exp(-2 * self.nu * t)
        # Slices (0:1, 1:2) rather than integer indices keep the leading
        # channel axis, so the result has shape (1, ...) like a 1-channel state.
        return 2 * jnp.sin(x[0:1]) * jnp.cos(x[1:2]) * decay

In [4]:
# Coordinate grid for the experiment. Positional arguments are presumably
# (number of spatial dims, domain extent, points per dim); the same values
# are used for the stepper and the analytical reference below.
grid = ex.make_grid(
    2,
    2 * jnp.pi,
    60,
)

2024-03-20 10:14:47.179970: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.52). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [5]:
# Analytical reference solution on the 2π-periodic domain with ν = 0.1.
tg = TaylorGreenVorticity(domain_extent=2 * jnp.pi, diffusivity=0.1)

In [6]:
# Initial condition: the exact vorticity field evaluated at t = 0.
ic = tg(t=0.0, x=grid)

In [7]:
# Numerical Navier-Stokes stepper in vorticity form. Its setup must match the
# analytical reference: same domain (2π), same grid size (60), same ν (0.1).
# The fourth positional argument is presumably the time step dt.
ns_stepper = ex.stepper.NavierStokesVorticity(
    2,           # spatial dimensions
    2 * jnp.pi,  # domain extent
    60,          # grid points per dimension
    0.1,         # presumably dt
    diffusivity=0.1,  # must equal the diffusivity given to `tg`
)

In [8]:
def rel_error(pred, ref):
    """Relative error ``||pred - ref|| / ||ref||``.

    Both norms are ``jnp.linalg.norm`` with default arguments, i.e. the
    2-norm over all entries of the array. If ``ref`` is all zeros the
    division yields inf/nan rather than raising.
    """
    return jnp.linalg.norm(pred - ref) / jnp.linalg.norm(ref)

In [9]:
# A single stepper call advances by one time step (presumably dt = 0.1),
# so the prediction should match the exact vorticity at t = 0.1.
rel_error(pred=ns_stepper(ic), ref=tg(t=0.1, x=grid))

Array(2.3546949e-07, dtype=float32)

In [10]:
# 10 steps of (presumably) dt = 0.1 reach t = 1.0.
rel_error(pred=ex.repeat(ns_stepper, 10)(ic), ref=tg(t=1.0, x=grid))

Array(1.3551877e-06, dtype=float32)

In [11]:
# 100 steps of (presumably) dt = 0.1 reach t = 10.0.
rel_error(pred=ex.repeat(ns_stepper, 100)(ic), ref=tg(t=10.0, x=grid))

Array(1.2082977e-05, dtype=float32)