# Navier-Stokes Data Generator

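This notebook is adapted from the [JAXPI](https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/ns_tori) `ns_tori` example. It generates the reference data for the Navier–Stokes-on-a-torus benchmark: 2-D incompressible flow on the periodic square $[0, 2\pi)^2$. The output is written to `../data/ns.npy`.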

In [None]:
import jax
import jax.numpy as jnp

from jax import vmap

import jax_cfd.base as cfd
import jax_cfd.base.grids as grids
import jax_cfd.spectral as spectral

from jax_cfd.spectral import utils as spectral_utils

import os

# Output location: a "data" directory inside the parent of the current
# working directory, i.e. <cwd>/../data.
current_dir = os.getcwd()
parent_dir = os.path.split(current_dir)[0]  # same as os.path.dirname
data_dir = os.path.join(parent_dir, "data")

In [None]:
%%time 
# physical parameters
viscosity = 1e-2
max_velocity = 3
grid = grids.Grid((64, 64), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
dt = 5e-4

# setup step function using crank-nicolson runge-kutta order 4
smooth = True # use anti-aliasing 
step_fn = spectral.time_stepping.crank_nicolson_rk4(
    spectral.equations.NavierStokes2D(viscosity, grid, smooth=smooth), dt)

# final_time = 1.0
outer_steps = 21
inner_steps = 10

trajectory_fn = cfd.funcutils.trajectory(
    cfd.funcutils.repeated(step_fn, inner_steps), outer_steps)

# create an initial velocity field and compute the fft of the vorticity.
# the spectral code assumes an fft'd vorticity for an initial state
v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(0), grid, max_velocity, 2)
vorticity0 = cfd.finite_differences.curl_2d(v0).data
vorticity_hat0 = jnp.fft.rfftn(vorticity0)

_, trajectory = trajectory_fn(vorticity_hat0)


In [None]:
# Transform the spectral trajectory back to physical space. `trajectory`
# holds the rfft'd vorticity snapshots stacked along axis 0.
w = jnp.fft.irfftn(trajectory, axes=(1, 2))

# Recover the velocity components (u, v) from each vorticity snapshot.
velocity_solve = spectral_utils.vorticity_to_velocity(grid)
u_hat, v_hat = vmap(velocity_solve)(trajectory)
u = vmap(jnp.fft.irfftn)(u_hat)
v = vmap(jnp.fft.irfftn)(v_hat)

# Uniform periodic coordinates x_i = lower + i * L / n, with the right
# endpoint excluded. Each axis uses its own size and extent, so a non-square
# grid or a different domain is labelled correctly.
(x_lower, x_upper), (y_lower, y_upper) = grid.domain
x = x_lower + jnp.arange(grid.shape[0]) * (x_upper - x_lower) / grid.shape[0]
y = y_lower + jnp.arange(grid.shape[1]) * (y_upper - y_lower) / grid.shape[1]
# Snapshot times: one entry per outer step, spaced dt * inner_steps apart.
t = dt * jnp.arange(outer_steps) * inner_steps

# The first snapshot is also stored on its own as the initial condition.
# Note: this rebinds v0, the velocity field from the simulation cell.
u0 = u[0, :, :]
v0 = v[0, :, :]
w0 = w[0, :, :]

data = {'w': w, 'u': u, 'v': v, 'u0': u0, 'v0': v0, 'w0': w0,
        'x': x, 'y': y, 't': t, 'viscosity': viscosity}

# np.save (which jnp.save re-exports) does not create missing directories.
os.makedirs(data_dir, exist_ok=True)
# The dict is pickled as a 0-d object array. Read it back with
# np.load(path, allow_pickle=True).item(). The JAX arrays are converted to
# NumPy first so that loading the file does not require JAX.
jnp.save(os.path.join(data_dir, 'ns.npy'), jax.device_get(data))