In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax import random
from jwave.geometry import Domain
import jax
from jax import numpy as jnp

# Fixed root seed (42) split into 20 independent PRNG keys, so every random
# field below is reproducible.
seeds = random.split(random.PRNGKey(42), num=20)

# 2-D domain: the first tuple is presumably the grid size per axis and the
# second the spacing per axis — TODO confirm against jwave.geometry.Domain.
domain = Domain((32, 35), (0.5, 0.6))

# Single 2-D evaluation point reused by every cell below.
x = jnp.array([1.0, 2.0])

2021-07-28 09:13:49.492929: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2021-07-28 09:13:49.493002: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 465.31.0 does not match DSO version 470.57.2 -- cannot find working devices in this configuration


In [3]:
from jwave.core import operator, Field

In [4]:
from jwave.discretization import Arbitrary
from jax import numpy as jnp
from jax.experimental import stax

# Fully connected network: two hidden layers of width 32 with ReLU
# activations, ending in a single linear output unit.
# stax.serial returns an (init_fun, apply_fun) pair:
#   init_random_params(key, input_shape) -> (output_shape, params)
#   predict(params, inputs)              -> network output
init_random_params, predict = stax.serial(
    stax.Dense(32),
    stax.Relu,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(1),
)

def init_params(seed, domain):
    """Draw random network parameters sized for one input per domain axis.

    Args:
        seed: a JAX PRNG key.
        domain: object exposing ``N``; ``len(domain.N)`` is used as the input
            dimension of the network (presumably a ``jwave.geometry.Domain``).

    Returns:
        The parameter pytree produced by ``init_random_params``. The output
        shape it also returns is discarded.
    """
    _output_shape, params = init_random_params(seed, (len(domain.N),))
    return params

def get_fun(params, x):
    """Evaluate the network with parameters ``params`` at input ``x``.

    Thin adapter giving the (params, x) signature that ``Arbitrary`` is
    constructed with.
    """
    output = predict(params, x)
    return output

# Neural-network ("arbitrary") discretization over the domain, built from the
# evaluation function and parameter initializer defined above.
arbitrary_discr = Arbitrary(domain, get_fun, init_params)


def _random_arbitrary_field(key, name):
    """Draw random parameters with ``key``, then wrap them in a named Field."""
    field_params = arbitrary_discr.random_field(key)
    return field_params, Field(arbitrary_discr, params=field_params, name=name)


# Two fields on the same discretization, seeded from the first two keys.
u_arb_params, u_arbitrary = _random_arbitrary_field(seeds[0], 'u')
v_arb_params, v_arbitrary = _random_arbitrary_field(seeds[1], 'v')

In [5]:
from jwave.discretization import RealFourierSeries
# Real Fourier-series discretization over the same domain.
fourier_discr = RealFourierSeries(domain)


def _random_fourier_field(key, name):
    """Draw random Fourier parameters with ``key``, then wrap them in a named Field."""
    field_params = fourier_discr.random_field(key)
    return field_params, Field(fourier_discr, params=field_params, name=name)


# Same two seeds as the neural fields, so u and v are drawn from matching keys.
u_fourier_params, u_fourier = _random_fourier_field(seeds[0], 'u')
v_fourier_params, v_fourier = _random_fourier_field(seeds[1], 'v')

# `__call__` (non-jittable)

In [6]:
# Evaluate the neural-network field u at x through the Field object itself
# (the section header marks `__call__` as the non-jittable path).
u_arbitrary(x)

DeviceArray([-0.11763091], dtype=float32)

In [7]:
# Evaluate the Fourier-series field u at x through `__call__`.
u_fourier(x)

DeviceArray([-0.28718653], dtype=float32)

# `get_field` (jittable)

In [8]:
# `get_field()` returns a function of (params, x): parameters are passed
# explicitly instead of being read from the Field (the jittable path).
# It gives the same value as calling the field directly.
u_arbitrary.get_field()(u_arb_params, x)

DeviceArray([-0.11763091], dtype=float32)

In [9]:
# Jittable evaluation of the Fourier field: explicit parameters and point.
# It gives the same value as calling the field directly.
u_fourier.get_field()(u_fourier_params, x)

DeviceArray([-0.28718653], dtype=float32)

# `add`

In [10]:
# Value of v at x on its own, for checking the u + v result below.
v_arbitrary.get_field()(v_arb_params, x)

DeviceArray([-0.31997383], dtype=float32)

In [11]:
@operator()
def new_op(u, v):
    """Operator returning the pointwise sum of its two input fields."""
    summed = u + v
    return summed

In [12]:
# Compose u + v from the two neural fields and evaluate the result at x.
# The composite's callable takes the operator's global parameters plus a
# dict of per-input field parameters keyed by input name.
# NOTE(review): the `0` passed to get_field presumably selects the first
# output of the operator graph — TODO confirm.
out_field = new_op(u=u_arbitrary, v=v_arbitrary)
global_params = out_field.get_global_params()
out_field.get_field(0)(
    global_params, 
    {"u": u_arb_params, "v": v_arb_params}, 
    x
)

DeviceArray([-0.43760473], dtype=float32)

In [13]:
# Trace the composite evaluation to a jaxpr to inspect the computation that
# the u + v field lowers to.
f = out_field.get_field(0)
jax.make_jaxpr(f)(
    global_params, 
    {"u": u_arb_params, "v": v_arb_params}, 
    x
)

{ lambda  ; a b c d e f g h i j k l m.
  let n = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] m a
      o = add n b
      p = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7fa9b07911f0>
                                 num_consts=0 ] o
      q = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] p c
      r = add q d
      s = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7fa9b

In [18]:
# Compose u + v from the two Fourier fields and evaluate the result at x.
# Each operator input is bound to its own field, so the field names line up
# with the keys of the per-field parameter dict passed below.
out_field = new_op(u=u_fourier, v=v_fourier)
global_params = out_field.get_global_params()
out_field.get_field(0)(
    global_params,
    {"u": u_fourier_params, "v": v_fourier_params},
    x
)

DeviceArray([0.25046518], dtype=float32)

In [20]:
# On-grid evaluation takes no coordinate argument. For the Fourier fields,
# the traced program reduces to a single elementwise `add` of the two
# parameter arrays.
f = out_field.get_field_on_grid(0)
jax.make_jaxpr(f)(
    global_params, 
    {"u": u_fourier_params, "v": v_fourier_params}
)

{ lambda  ; a b.
  let c = add a b
  in (c,) }

In [22]:
# Off-grid evaluation at x for comparison. The (truncated) jaxpr contains
# broadcast/scatter operations, presumably assembling the spectral
# evaluation at an arbitrary point — TODO confirm.
f = out_field.get_field(0)
jax.make_jaxpr(f)(
    global_params, 
    {"u": u_fourier_params, "v": v_fourier_params},
    x
)

{ lambda a b c d e ; f g h.
  let i = add f g
      j = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(32,) ] 0.0
      k = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      l = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      m = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
                   indices_are_sorted=True
                   unique_indices=True
                   update_consts=(  )
                   update_jaxpr={ lambda  ; a b.
                                  let 
                                  in (b,) } ] j l k
      n = convert_element_type[ new_dtype=float32
                                weak_type=False ] b
      o = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 16
      p = scatter[ dimension_numbers=ScatterDimen

# `add_scalar`

In [11]:
from jwave.operators import Dx

@operator()
def get_gradient(u):
    """Operator applying ``jwave.operators.Dx`` to field ``u``.

    Presumably the derivative along the first axis — TODO confirm in jwave.
    """
    return Dx(u)


@operator()
def add_two(u):
    """Operator adding the scalar 2 to field ``u``.

    Used by the scalar-addition cells below. In the printed operator graph it
    is expected to appear as an ``AddScalar`` node with global parameter
    ``scalar = 2.0``.
    """
    return u + 2.0

In [17]:
# Apply the scalar-addition operator to the neural field and evaluate at x.
out_field = add_two(u=u_arbitrary)
global_params = out_field.get_global_params()
out_field.get_field(0)(global_params, {"u": u_arb_params}, x)

DeviceArray([1.882369], dtype=float32)

In [13]:
# NOTE(review): this repeats the identical construction from the previous
# evaluation cell; it is redundant on a top-to-bottom run.
out_field = add_two(u=u_arbitrary)

In [14]:
# Show the operator graph: outputs, input fields, global parameters and the
# list of operations.
print(out_field)

DiscretizedOperator :: [Arbitrary], ['_k5'] 

 Input fields: ('u',)

Globals: {'shared': {}, 'independent': {'AddScalar_jR': {'scalar': 2.0}}}

Operations:
- _k5: Arbitrary <-- AddScalar ('u',) | (independent) AddScalar_jR



In [18]:
# Keep the composite's callable in `f` for tracing in the next cell, and
# refresh the global parameters from the current `out_field`.
f = out_field.get_field(0)
global_params = out_field.get_global_params()

f(global_params, {"u": u_arb_params}, x)

DeviceArray([1.882369], dtype=float32)

In [19]:
from jax import make_jaxpr

# Print the jaxpr of the scalar-addition composite for the neural field.
print(make_jaxpr(f)(global_params, {"u": u_arb_params}, x))

{ lambda  ; a b c d e f g h.
  let i = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] h b
      j = add i c
      k = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f5528497170>
                                 num_consts=0 ] j
      l = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] k d
      m = add l e
      n = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = max a 0.0
                                             in (b,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f5528497290>
 

In [20]:
# Same scalar-addition operator, now applied to the Fourier field.
out_field = add_two(u=u_fourier)

global_params = out_field.get_global_params()
out_field.get_field(0)(global_params, {"u": u_fourier_params}, x)

DeviceArray([1.712813], dtype=float32)

In [22]:
# Keep the Fourier composite's callable in `f` for tracing in the next cell.
f = out_field.get_field(0)
global_params = out_field.get_global_params()

f(global_params, {"u": u_fourier_params}, x)

DeviceArray([1.712813], dtype=float32)

In [24]:
# Print the jaxpr of the scalar-addition composite for the Fourier field.
print(make_jaxpr(f)(global_params, {"u": u_fourier_params}, x))

{ lambda a b c d e ; f g h.
  let i = convert_element_type[ new_dtype=float32
                                weak_type=False ] f
      j = add g i
      k = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(32,) ] 0.0
      l = convert_element_type[ new_dtype=float32
                                weak_type=False ] a
      m = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      n = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
                   indices_are_sorted=False
                   unique_indices=False
                   update_consts=(  )
                   update_jaxpr={ lambda  ; a b.
                                  let 
                                  in (b,) } ] k m l
      o = convert_element_type[ new_dtype=float32
                                weak_type=False ] b
      p = broadcast_in_dim[ broadcast_dimen