**Dynamic (data-dependent) shapes can't be used under `jax.jit`**

In [5]:
import jax.numpy as jnp
import jax

def nansum(x: jax.Array) -> jax.Array:
  """Sum ``x`` ignoring NaNs by boolean-mask indexing (eager mode only).

  ``x[mask]`` produces an array whose length equals the number of non-NaN
  elements, i.e. a shape that depends on the *values* of ``x``. That works in
  eager execution but cannot be traced by ``jax.jit``; see ``jaxnansum`` for
  a jit-compatible version.

  Args:
    x: Array to reduce.

  Returns:
    0-d array with the sum of all non-NaN elements (0 if every element is NaN).

  Raises:
    Under ``jax.jit``, tracing fails with an error, because boolean indexing
    needs a concrete mask.
  """
  mask = ~jnp.isnan(x)
  # Boolean indexing: the result's length is data-dependent (dynamic shape).
  non_nan = x[mask]
  return non_nan.sum()

# Eager execution: the NaN is dropped by boolean indexing and this prints 10.0.
print(nansum(jnp.array([1,2,3,jnp.nan,4])))

# Can't work in jax.jit: x[mask] inside nansum has a data-dependent shape, so
# uncommenting the next line fails at trace time (NonConcreteBooleanIndexError).
# print(jax.jit(nansum)(jnp.array([1,2,3,jnp.nan,4])))

def jaxnansum(x: jax.Array) -> jax.Array:
  """Sum ``x`` ignoring NaNs, in a form that also works under ``jax.jit``.

  Instead of dropping NaN entries (which would create an array whose length
  depends on the data), NaNs are replaced by 0. The array keeps its static
  shape, and a plain sum then gives the NaN-ignoring total.

  Args:
    x: Array to reduce.

  Returns:
    0-d array with the sum of all non-NaN elements (0 if every element is NaN).
  """
  mask = ~jnp.isnan(x)
  # jnp.where keeps x's shape (NaN -> 0), so no dynamic shape is created.
  x = jnp.where(mask, x, 0)
  # NaNs are already zeroed, so a plain sum suffices (jnp.nansum is redundant).
  return jnp.sum(x)

# Static-shape variant on the same input; also prints 10.0.
print(jaxnansum(jnp.array([1,2,3,jnp.nan,4])))

10.0
10.0
