This is a mini notebook that presents some computations of interest, roughly following the outline of my BHI presentation. 

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

import ehtim as eh

Welcome to eht-imaging! v 1.2.4 



Let's do some microbenchmarking of JAX functions!

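Here is `sample_uv` for a circular Gaussian model, i.e. the model's Fourier transform (visibility) evaluated at given (u, v) points. We compute it at every (u, v) point of a simulated observation with the EHT2025 array, and time three versions: NumPy, JAX without `jit`, and JAX with `jit`: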
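For reference, the functions below evaluate the following visibility for a circular Gaussian of total flux $F_0$, full width at half maximum $\theta$ (the `FWHM` parameter) and centre $(x_0, y_0)$, with the phase sign convention used in the code. The width term is the familiar Gaussian factor $2\pi^2\sigma^2$ rewritten in terms of the FWHM:

```latex
V(u, v) = F_0 \exp\!\left[-\frac{\pi^2 \theta^2}{4 \ln 2}\,(u^2 + v^2)\right] \exp\!\left[2\pi i\,(u x_0 + v y_0)\right],
\qquad \frac{\pi^2 \theta^2}{4 \ln 2} = 2\pi^2 \sigma^2 \quad \text{with} \quad \sigma = \frac{\theta}{2\sqrt{2 \ln 2}}.
```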

In [4]:
params = {'F0':1.3, 'x0':0, 'y0':0, 'FWHM':50*eh.RADPERUAS}

eht = eh.array.load_txt('EHT2025.txt')
model = eh.model.Model()
model = model.add_circ_gauss(**params)
tint_sec = 5
tadv_sec = 3600
tstart_hr = 0
tstop_hr = 24
bw_hz = 1e9
obs = model.observe(eht, tint_sec, tadv_sec, tstart_hr, tstop_hr, bw_hz, ampcal=True, phasecal=True,seed=4)
u = obs.data['u']
v = obs.data['v']

def np_circ_gauss_sample_uv(u, v):
    val = (params['F0'] 
            * np.exp(-np.pi**2/(4.*np.log(2.)) * (u**2 + v**2) * params['FWHM']**2)
            * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) 
    return val

def jnp_circ_gauss_sample_uv(u, v):
    val = (params['F0'] 
            * jnp.exp(-jnp.pi**2/(4.*jnp.log(2.)) * (u**2 + v**2) * params['FWHM']**2)
            * jnp.exp(1j * 2.0 * jnp.pi * (u * params['x0'] + v * params['y0']))) 
    return val

jit_jnp_circ_gauss_sample_uv = jax.jit(jnp_circ_gauss_sample_uv)

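# Run the jitted function once with the same shapes and dtypes as in the benchmark (a call with
# scalars like u[0], v[0] would compile a different version), so JIT compilation time isn't
# mistakenly added to our benchmarks:
jit_jnp_circ_gauss_sample_uv(u, v).block_until_ready()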

%timeit np_circ_gauss_sample_uv(u,v)
%timeit jnp_circ_gauss_sample_uv(u,v).block_until_ready()
%timeit jit_jnp_circ_gauss_sample_uv(u,v).block_until_ready()

Generating empty observation file . . . 
Adding gain + phase errors to data and applying a priori calibration . . . 
Adding thermal noise to data . . . 
41.3 µs ± 508 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
114 µs ± 695 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
14.6 µs ± 91.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
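One caveat about the cell above: `jnp_circ_gauss_sample_uv` reads `params` from the enclosing scope, so `jax.jit` bakes the values present at trace time into the compiled function as constants, and editing `params` afterwards would not be picked up by `jit_jnp_circ_gauss_sample_uv`. A minimal sketch of passing the parameters as a dict (a JAX pytree) argument instead, so they are traced inputs; `RADPERUAS` is a local stand-in for `eh.RADPERUAS`, and the `u`, `v` values are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for eh.RADPERUAS: radians per microarcsecond.
RADPERUAS = np.pi / 180 / 3600 / 1e6

@jax.jit
def circ_gauss_sample_uv(params, u, v):
    # params is a dict pytree, so its values are traced inputs rather than
    # constants captured at trace time.
    amp = params['F0'] * jnp.exp(-jnp.pi**2 / (4. * jnp.log(2.)) * (u**2 + v**2) * params['FWHM']**2)
    phase = jnp.exp(1j * 2.0 * jnp.pi * (u * params['x0'] + v * params['y0']))
    return amp * phase

params = {'F0': 1.3, 'x0': 0.0, 'y0': 0.0, 'FWHM': 50 * RADPERUAS}
u = jnp.array([0.0, 1e9, 5e9])  # hypothetical baselines, in wavelengths
v = jnp.zeros(3)

vis1 = circ_gauss_sample_uv(params, u, v)
# Same dict structure and dtypes, so the compiled function is reused with the new FWHM.
vis2 = circ_gauss_sample_uv({**params, 'FWHM': 20 * RADPERUAS}, u, v)
print(jnp.abs(vis1), jnp.abs(vis2))
```

A narrower source (smaller FWHM) falls off more slowly in the (u, v) plane, so `vis2` keeps more amplitude on the long baselines.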


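Automatic differentiation of an arbitrary mathematical function, checked against the derivative computed by hand. `jax.grad` works on scalar-valued functions, so `jax.vmap` maps it over every point of `x`: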

In [4]:
def f(x):
    return jnp.exp( (-1)*(x**3 + x + jnp.sin(jnp.pi*x)) )


def manual_grad_f(x):
    return ((-1)*(3*x**2 + 1 + jnp.cos(jnp.pi*x)*jnp.pi) 
           * jnp.exp( (-1)*(x**3 + x + jnp.sin(jnp.pi*x)) ))

jax_grad_f = jax.grad(f)
jax_vectorized_grad_f = jax.vmap(jax_grad_f)

x = jnp.linspace(0, 5, 100)
y = manual_grad_f(x)
yy = jax_vectorized_grad_f(x)
jnp.allclose(y, yy)

DeviceArray(True, dtype=bool)
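Two related points, sketched below with the same `f`: because `f` acts elementwise, the gradient of `jnp.sum(f(x))` gives the same pointwise derivatives as `jax.vmap(jax.grad(f))`, and JAX transformations compose, so `jax.grad(jax.grad(f))` gives the second derivative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp((-1)*(x**3 + x + jnp.sin(jnp.pi*x)))

x = jnp.linspace(0, 5, 100)

# Pointwise first derivative by mapping a scalar gradient over x
d1_vmap = jax.vmap(jax.grad(f))(x)
# Same values from the gradient of a scalar sum, since f is elementwise
d1_sum = jax.grad(lambda xs: jnp.sum(f(xs)))(x)
# Transformations compose: second derivative at every point
d2 = jax.vmap(jax.grad(jax.grad(f)))(x)

print(jnp.allclose(d1_vmap, d1_sum))
```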

`jax.scipy` doesn't implement all of SciPy yet; for example, it has no counterpart to `scipy.special.jv`. One might try to just call regular SciPy inside a jitted function...

In [None]:
import scipy.special as sps

@jax.jit
def f(n, x):
    return sps.jv(n, x)

print(f(2, 3.5)) #Error! Can't trace through scipy functions with JAX

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
While tracing the function f at /tmp/ipykernel_62806/90201900.py:3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
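If calling SciPy is acceptable, one workaround is `jax.pure_callback`, which hands concrete NumPy arrays to a host-side Python function from inside traced code. A minimal sketch, assuming a JAX version that provides `jax.pure_callback`; the wrapper name `jv` is hypothetical. The callback runs on the host CPU and has no autodiff rule, so it cannot be used with `jax.grad`:

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.special as sps

def _jv_host(n, x):
    # Runs on the host with concrete NumPy arrays, so SciPy works here.
    # The result must match the declared shape and dtype exactly.
    return np.asarray(sps.jv(n, x), dtype=np.float32)

@jax.jit
def jv(n, x):
    n = jnp.asarray(n, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)
    # Declare the output shape/dtype, since JAX cannot infer it from the callback.
    out = jax.ShapeDtypeStruct(jnp.broadcast_shapes(n.shape, x.shape), jnp.float32)
    return jax.pure_callback(_jv_host, out, n, x)

print(jv(2, 3.5))
```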

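Some `vmap` performance characteristics. Since `scipy.special.jv` cannot be used under JAX, the cell below defines a replacement based on Bessel's integral for integer n, $J_n(z) = \frac{1}{\pi}\int_0^\pi \cos(n\tau - z\sin\tau)\,d\tau$, evaluated with the trapezoidal rule, and times a `np.vectorize` version against `jax.vmap` with and without `jax.jit`: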

In [12]:
# Replacement for scipy.special.jv, which is not available in JAX.
# Uses Bessel's integral for integer n, J_n(z) = (1/pi) * int_0^pi cos(n*tau - z*sin(tau)) dtau,
# approximated with the trapezoidal rule on num_samples points.
def bessel_j(n, z, num_samples=100):
    z = jnp.asarray(z)
    scalar = z.ndim == 0
    if scalar:
        z = z[np.newaxis]
    z = z[:, np.newaxis]  # shape (N, 1), broadcasts against tau of shape (num_samples,)
    tau = np.linspace(0, jnp.pi, num_samples)
    integrals = jnp.trapz(jnp.cos(n*tau - z*jnp.sin(tau)), x=tau)  # integrate along tau (last axis)
    if scalar:
        return (1./jnp.pi)*integrals.squeeze()
    return (1./jnp.pi)*integrals

# Scalar version (one n, one z), to be vectorized with jax.vmap
def bessel_j_vtest(n, z, num_samples=100):
    tau = np.linspace(0, jnp.pi, num_samples)
    integral = jnp.trapz(jnp.cos(n*tau - z*jnp.sin(tau)), x=tau)
    return (1./jnp.pi)*integral

# Map over the leading axis of both n and z
jnp_bessel_j_vtest = jax.vmap(bessel_j_vtest, in_axes=(0, 0), out_axes=0)

# NumPy equivalents, for comparison
def np_bessel_j(n, z, num_samples=100):
    z = np.asarray(z)
    scalar = z.ndim == 0
    if scalar:
        z = z[np.newaxis]
    z = z[:, np.newaxis]
    tau = np.linspace(0, np.pi, num_samples)
    integrals = np.trapz(np.cos(n*tau - z*np.sin(tau)), x=tau)
    if scalar:
        return (1./np.pi)*integrals.squeeze()
    return (1./np.pi)*integrals

def np_bessel_j_vtest(n, z, num_samples=100):
    tau = np.linspace(0, np.pi, num_samples)
    integral = np.trapz(np.cos(n*tau - z*np.sin(tau)), x=tau)
    return (1./np.pi)*integral

# np.vectorize loops over the elements in Python
np_bessel_j_vtest = np.vectorize(np_bessel_j_vtest)

# NumPy, element by element
%timeit np_bessel_j_vtest(np.arange(100), np.arange(100))
# jax.vmap without jit (no block_until_ready, so asynchronous dispatch may hide part of the compute time)
%timeit jnp_bessel_j_vtest(jnp.arange(100), jnp.arange(100))
# jax.vmap under jax.jit (note that the jit wrapper is re-created on every loop iteration here)
%timeit jax.jit(jnp_bessel_j_vtest)(jnp.arange(100), jnp.arange(100))


9.47 ms ± 83.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.06 ms ± 454 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
523 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
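These timings are indicative only: the un-jitted `vmap` line does not wait for the result with `block_until_ready()`, and the jitted line rebuilds the `jax.jit` wrapper on every iteration. A sketch of a more careful measurement in plain Python with the `timeit` module: jit once, warm up with the benchmark's shapes, and block on the result. The function `bessel_j_scalar` is a self-contained re-implementation of `bessel_j_vtest` that writes the trapezoidal rule out by hand, an assumption made to avoid depending on `jnp.trapz`, which may be missing in newer JAX releases:

```python
import timeit

import jax
import jax.numpy as jnp

def bessel_j_scalar(n, z, num_samples=100):
    # Bessel's integral for integer n, with the trapezoidal rule written out by hand
    # (stands in for jnp.trapz).
    tau = jnp.linspace(0.0, jnp.pi, num_samples)
    y = jnp.cos(n * tau - z * jnp.sin(tau))
    integral = jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(tau))
    return integral / jnp.pi

# Build the vectorized, compiled function once, outside the timed region.
bessel_j_batched = jax.jit(jax.vmap(bessel_j_scalar, in_axes=(0, 0)))

n = jnp.arange(100)
z = jnp.arange(100, dtype=jnp.float32)

# Warm-up call with the benchmark's shapes and dtypes triggers compilation now.
out = bessel_j_batched(n, z).block_until_ready()

def run():
    # block_until_ready waits for the asynchronous computation to finish.
    return bessel_j_batched(n, z).block_until_ready()

times = timeit.repeat(run, number=100, repeat=7)
print(f"best of 7: {min(times) / 100 * 1e6:.1f} us per call")
```

Reporting the minimum over repeats is a common choice for microbenchmarks, since it is least affected by other activity on the machine.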


Making and printing jaxprs, the intermediate representation that JAX produces by tracing a function:

In [6]:
def f(x):
    return jnp.exp( (-1)*(x**3 + x + jnp.sin(jnp.pi*x)) )

print(jax.make_jaxpr(f)(1.))

{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=3] a
    c:f32[] = add b a
    d:f32[] = mul a 3.141592653589793
    e:f32[] = sin d
    f:f32[] = add c e
    g:f32[] = mul f -1.0
    h:f32[] = exp g
  in (h,) }
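`jax.make_jaxpr` returns a `ClosedJaxpr` object rather than just text, so the equations can also be inspected programmatically. A minimal sketch that lists the primitive of each equation and prints the jaxpr of `jax.grad(f)`, showing that differentiation is itself expressed as a (longer) jaxpr; the exact primitive names can vary between JAX versions:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp((-1)*(x**3 + x + jnp.sin(jnp.pi*x)))

# One equation per primitive operation recorded while tracing f
closed = jax.make_jaxpr(f)(1.)
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prims)

# The gradient is traced into its own jaxpr (it needs e.g. cos for d/dx sin)
grad_closed = jax.make_jaxpr(jax.grad(f))(1.)
print(grad_closed)
```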


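Tracing JAX code. Under `jax.jit`, the Python body of `f` runs once with the inputs `x` and `y` replaced by JAX tracers, abstract values that carry only shape and dtype, so the `print` calls show tracers rather than numbers: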

In [14]:
@jax.jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([2.1702843 , 2.0629063 , 0.35211653], dtype=float32)
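The `print` calls above run only while JAX traces `f`. A sketch of the consequences, assuming a JAX version with `jax.debug.print`: a plain Python side effect (here a hypothetical `trace_count` counter) fires once per new combination of input shapes and dtypes, while `jax.debug.print` is staged into the compiled function and runs on every call.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # hypothetical counter, only incremented while tracing

@jax.jit
def g(x, y):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    jax.debug.print("runtime x[0, 0] = {v}", v=x[0, 0])  # runs on every call
    return jnp.dot(x + 1, y + 1)

x = np.random.randn(3, 4).astype(np.float32)
y = np.random.randn(4).astype(np.float32)
g(x, y)       # first call: traced and compiled
g(x, y)       # same shapes and dtypes: cached, no retrace
g(x[:2], y)   # new shape (2, 4): traced again
print(trace_count)  # 2
```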

JAX tracers can also escape a traced function through side effects, such as appending to a global list; using a leaked tracer afterwards raises an `UnexpectedTracerError`:

In [14]:
y = []

@jax.jit
def f(x):
    x+=1
    y.append(x)

f(3)

for elt in y:
    print(elt + 3)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape () and dtype int32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was f at /tmp/ipykernel_62806/3481876261.py:3 traced for jit.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_62806/3481876261.py:5 (f). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/nova/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3080 (run_cell_async)
/home/nova/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3277 (run_ast_nodes)
/home/nova/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3366 (run_code)
/tmp/ipykernel_62806/3481876261.py:8 (<cell line: 8>)
/tmp/ipykernel_62806/3481876261.py:5 (f)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
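The usual fix is to return values from the jitted function explicitly and do any bookkeeping in ordinary Python outside it. A minimal sketch of the cell above rewritten that way:

```python
import jax

@jax.jit
def f(x):
    # Return the intermediate instead of storing it in global state.
    return x + 1

y = []
y.append(f(3))  # append the concrete result outside the jitted function

for elt in y:
    print(elt + 3)  # prints 7
```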

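Control flow in JAX can be a little tricky! A Python `if` on a traced value fails under `jax.jit`, because the tracer has no concrete boolean value to branch on: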

In [21]:
@jax.jit
def relu(x):
    if x<0:
        return 0
    return x

relu(-3)


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function relu at /tmp/ipykernel_62806/2780043258.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
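Two jit-compatible ways to write `relu`, sketched below: `jnp.where` selects elementwise (both branches are evaluated), and `jax.lax.cond` takes the predicate plus one function per branch and chooses between them at run time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_where(x):
    # Elementwise select, no Python `if` on a traced value
    return jnp.where(x < 0, 0, x)

@jax.jit
def relu_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand)
    return jax.lax.cond(x < 0, lambda v: jnp.zeros_like(v), lambda v: v, x)

print(relu_where(-3), relu_cond(-3), relu_where(2.5))
```

Alternatively, `jax.jit(relu, static_argnums=0)` makes `x` a compile-time constant so the Python `if` works, at the cost of recompiling for every distinct value of `x`.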