In [None]:
# Automatically reload edited modules (e.g. jwave) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from jax import random
from jwave.geometry import Domain
import jax
from jax import numpy as jnp

# Reproducibility settings used throughout the notebook.
SEED = 42      # root PRNG seed
N_SEEDS = 20   # number of independent keys split from the root key

# One key per random field; seeds[0] and seeds[1] initialise the `u` and `v` fields.
seeds = random.split(random.PRNGKey(SEED), N_SEEDS)

# First tuple is presumably the grid shape (exposed as `domain.N`), second the
# grid spacing — TODO confirm against jwave.geometry.Domain.
domain = Domain((32, 35), (.5, .6))
# Sample 2-D evaluation point (one coordinate per entry of `domain.N`).
x = jnp.array([1., 2.])

In [None]:
from jwave.core import operator, Field

In [None]:
from jwave.discretization import Arbitrary
from jax import numpy as jnp

# Newer jax releases moved `stax` from `jax.experimental` to
# `jax.example_libraries`; fall back to the old location on older jax.
try:
    from jax.example_libraries import stax
except ImportError:
    from jax.experimental import stax

# Small MLP: two hidden layers of 32 ReLU units and a single output unit.
init_random_params, predict = stax.serial(
    stax.Dense(32), stax.Relu,
    stax.Dense(32), stax.Relu,
    stax.Dense(1))


def init_params(seed, domain):
    """Initialise MLP parameters for inputs of dimension ``len(domain.N)``.

    The stax initialiser returns ``(output_shape, params)``; only the
    parameters are kept.
    """
    _, params = init_random_params(seed, (len(domain.N),))
    return params


def get_fun(params, x):
    """Evaluate the MLP with parameters ``params`` at the point ``x``."""
    return predict(params, x)


# Neural-network ("arbitrary") discretization and two independently seeded fields.
arbitrary_discr = Arbitrary(domain, get_fun, init_params)

u_arb_params = arbitrary_discr.random_field(seeds[0])
u_arbitrary = Field(arbitrary_discr, params=u_arb_params, name='u')

v_arb_params = arbitrary_discr.random_field(seeds[1])
v_arbitrary = Field(arbitrary_discr, params=v_arb_params, name='v')

In [None]:
from jwave.discretization import RealFourierSeries
# Fourier-series discretization on the same domain. Its `u` and `v` fields use
# the same keys (seeds[0], seeds[1]) as the arbitrary (MLP) fields.
fourier_discr = RealFourierSeries(domain)

u_fourier_params = fourier_discr.random_field(seeds[0])
u_fourier = Field(fourier_discr, params=u_fourier_params, name='u')

v_fourier_params = fourier_discr.random_field(seeds[1])
v_fourier = Field(fourier_discr, params=v_fourier_params, name='v')

# `__call__` (non-jittable)

In [None]:
# Evaluate the field directly at `x` through `Field.__call__` (params held by the Field).
u_arbitrary(x)

In [None]:
# Same direct evaluation for the Fourier-series field.
u_fourier(x)

# `get_field` (jittable)

In [10]:
# `get_field()` returns a function of (params, x); parameters are passed explicitly.
u_arbitrary.get_field()(u_arb_params, x)

DeviceArray([-0.11763091], dtype=float32)

In [11]:
# Same explicit-parameter evaluation for the Fourier-series field.
u_fourier.get_field()(u_fourier_params, x)

DeviceArray([-0.2871864], dtype=float32)

# `add`

In [13]:
@operator()
def new_op(u, v):
    """Sum of the two input fields ``u`` and ``v``."""
    return u + v

In [None]:
# `new_op` takes two input fields, so both `u` and `v` must be bound. The
# second argument of the field function maps each input-field name to its params.
out_field = new_op(u=u_arbitrary, v=v_arbitrary)
global_params = out_field.get_global_params()
out_field.get_field(0)(global_params, {"u": u_arb_params, "v": v_arb_params}, x)

# `add_scalar` — adding a constant to a field (`add_two`)

In [11]:
from jwave.operators import Dx


@operator()
def add_two(u):
    """Add the constant 2.0 to the field ``u``.

    The operator summary printed further down shows this as an ``AddScalar``
    node with ``scalar = 2.0``.
    """
    return u + 2.0


@operator()
def get_gradient(u):
    """Apply jwave's ``Dx`` operator to ``u``.

    NOTE(review): despite the name this returns only ``Dx(u)`` (presumably the
    derivative along the first axis), not the full gradient — TODO confirm.
    """
    return Dx(u)

In [17]:
# Apply `add_two` to the MLP field and evaluate output field 0 at `x`;
# the dict maps input-field names to their parameters.
out_field = add_two(u=u_arbitrary)
global_params = out_field.get_global_params()
out_field.get_field(0)(global_params, {"u": u_arb_params}, x)

DeviceArray([1.882369], dtype=float32)

In [13]:
# NOTE(review): same call as two cells above; presumably redundant on a fresh run.
out_field = add_two(u=u_arbitrary)

In [14]:
# Text summary of the operator: input fields, global parameters and operations.
print(out_field)

DiscretizedOperator :: [Arbitrary], ['_k5'] 

 Input fields: ('u',)

Globals: {'shared': {}, 'independent': {'AddScalar_jR': {'scalar': 2.0}}}

Operations:
- _k5: Arbitrary <-- AddScalar ('u',) | (independent) AddScalar_jR



In [18]:
# Keep the output-field function in `f` so it can be traced with make_jaxpr below.
f = out_field.get_field(0)
global_params = out_field.get_global_params()

f(global_params, {"u": u_arb_params}, x)

DeviceArray([1.882369], dtype=float32)

In [19]:
from jax import make_jaxpr  # also used by the Fourier jaxpr cell further down

# Print the jaxpr traced from `f` to inspect the computation it performs.
print(make_jaxpr(f)(global_params, {"u": u_arb_params}, x))

{ lambda  ; a b c d e f g h.
  let i = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] h b
      j = add i c
      k = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f5528497170>
                                 num_consts=0 ] j
      l = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] k d
      m = add l e
      n = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f5528497290>
 

In [20]:
# Same `add_two` operator, now applied to the Fourier-series field.
out_field = add_two(u=u_fourier)

global_params = out_field.get_global_params()
out_field.get_field(0)(global_params, {"u": u_fourier_params}, x)

DeviceArray([1.712813], dtype=float32)

In [22]:
# Rebind `f` to the Fourier-based output-field function for the jaxpr cell below.
f = out_field.get_field(0)
global_params = out_field.get_global_params()

f(global_params, {"u": u_fourier_params}, x)

DeviceArray([1.712813], dtype=float32)

In [24]:
# Jaxpr of the Fourier-based field (relies on `make_jaxpr` imported earlier).
print(make_jaxpr(f)(global_params, {"u": u_fourier_params}, x))

{ lambda a b c d e ; f g h.
  let i = convert_element_type[ new_dtype=float32
                                weak_type=False ] f
      j = add g i
      k = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(32,) ] 0.0
      l = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      m = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      n = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
                   indices_are_sorted=False
                   unique_indices=False
                   update_consts=(  )
                   update_jaxpr={ lambda  ; a b.
                                  let 
                                  in (b,) } ] k m l
      o = convert_element_type[ new_dtype=float32
                                weak_type=False ] b
      p = broadcast_in_dim[ broadcast_dimen