In [1]:
# Reload edited modules (e.g. a locally installed jwave) before each cell runs.
# Comments are kept off the magic lines so they are not parsed as magic arguments.
%load_ext autoreload
%autoreload 2

In [2]:
from jax import random
from jwave.geometry import Domain
import jax
from jax import numpy as jnp

# Fixed PRNG key so every random field below is reproducible across runs.
seed = random.PRNGKey(42)

# 2-D domain. Presumably (grid size N, spacing dx) -- confirm against
# jwave.geometry.Domain; `domain.N` is read later to get the dimensionality.
domain = Domain(
    (32, 35),
    (0.5, 0.6),
)

# Query coordinate at which every field in this notebook is evaluated.
x = jnp.array([1.0, 2.0])

In [3]:
from jwave.core import operator, Field

In [4]:
from jwave.discretization import Arbitrary
from jax import numpy as jnp

# `stax` moved from `jax.experimental` to `jax.example_libraries` in JAX 0.2.25,
# and the old path was later removed. Prefer the new location and fall back so
# the cell also runs on older JAX versions.
try:
    from jax.example_libraries import stax
except ImportError:
    from jax.experimental import stax

# Fully-connected network (two hidden ReLU layers of width 1024, scalar output)
# used as a neural-field representation of `u`.
init_random_params, predict = stax.serial(
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1),
)


def init_params(seed, domain):
    """Initialise the network weights for a field defined on ``domain``.

    Args:
        seed: JAX PRNG key used to draw the initial weights.
        domain: object exposing ``N``; the network input size is ``len(domain.N)``.

    Returns:
        The stax parameter pytree. The stax init function returns
        ``(output_shape, params)``, and only ``params`` is kept.
    """
    _output_shape, params = init_random_params(seed, (len(domain.N),))
    return params


def get_fun(params, x):
    """Evaluate the network with weights ``params`` at coordinate ``x``."""
    return predict(params, x)


arbitrary_discr = Arbitrary(domain, get_fun, init_params)
arbitrary_field = arbitrary_discr.random_field(seed)
u_arbitrary = Field(arbitrary_discr, params=arbitrary_field, name='u')

In [5]:
from jwave.discretization import RealFourierSeries
# Second representation of `u` on the same domain, with random parameters
# drawn from the same PRNG key as the neural-network field.
fourier_discr = RealFourierSeries(domain)
fourier_field = fourier_discr.random_field(seed)
u_fourier = Field(
    fourier_discr,
    params=fourier_field,
    name='u',
)

# `__call__` (non-jittable)

In [6]:
# Evaluate the neural-network field at `x` through `Field.__call__`.
u_arbitrary(x)

DeviceArray([0.04352659], dtype=float32)

In [7]:
# Evaluate the Fourier-series field at `x` through `Field.__call__`.
u_fourier(x)

DeviceArray([-0.89064735], dtype=float32)

# `get_field` (jittable)

In [8]:
# `get_field()` returns a function of (params, x); the parameters are passed
# explicitly instead of being read from the Field object.
u_arbitrary.get_field()(arbitrary_field, x)

DeviceArray([0.04352659], dtype=float32)

In [9]:
# Same explicit-parameter evaluation for the Fourier-series field.
u_fourier.get_field()(fourier_field, x)

DeviceArray([-0.89064735], dtype=float32)

# `add_scalar`: a custom `@operator` (`new_op`) that adds a constant to a field

In [13]:
@operator()
def new_op(u):
    """Return the field ``u`` with the constant 2.0 added to it."""
    shift = 2.0
    return u + shift

In [14]:
# Apply the operator, then evaluate output field 0 at `x`, supplying the
# operator's global parameters and the input field's parameters keyed by name.
out_field = new_op(u=u_arbitrary)
global_params = out_field.get_global_params()
field_fn = out_field.get_field(0)
field_fn(global_params, {"u": arbitrary_field}, x)

DeviceArray([2.0435266], dtype=float32)

In [16]:
# Same operator applied to the Fourier-series field; only the input changes.
out_field = new_op(u=u_fourier)
global_params = out_field.get_global_params()
field_fn = out_field.get_field(0)
field_fn(global_params, {"u": fourier_field}, x)

DeviceArray([1.1093526], dtype=float32)