In [1]:
import time
import numpy as np

import jax
import jax.numpy as jnp

from pointscat.hankel import h0
from pointscat.forward_problem import compute_foldy_matrix

# `from jax.config import config` was deprecated and later removed from JAX;
# the same config object is importable directly from the top-level package.
from jax import config
# Use float64/complex128 so results can be compared tightly with the reference code.
config.update("jax_enable_x64", True)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# setting problem
amplitudes = 1 * np.array([1, 2, 1, 0.5, 3, 1, 0.6, 2])
locations = 0.4 * np.array([[-4.3, -4.7], [-4.0, 4.5], [4.2, 3.6], [0, 0], [2.5, 2.1], [-1.2, 3.4],
                            [-1/0.4, 1/0.4], [1/0.4, -1/0.4]])
wave_number = 1

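Green's function $G(x, y) = \frac{i}{4}\, h_0(k\,|x - y|)$, evaluated for a single pair of points or for all pairs of two point sets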

In [3]:
def g(x, y):
    """Green's function ``1j/4 * h0(k * |x - y|)`` with ``k`` the global ``wave_number``.

    Parameters
    ----------
    x, y : array
        Points whose trailing axis has length 2; both must have the same number
        of dimensions (checked with asserts). For 1-D inputs (a single point
        each) the result is a scalar. Otherwise they are presumably ``(n, 2)``
        and ``(m, 2)`` point sets, and the result is the ``(n, m)`` matrix over
        all pairs.

    Returns
    -------
    Complex scalar or array of Green's function values.

    Note: ``wave_number`` is a module global; under ``jax.jit`` it is read once
    at trace time.
    """
    assert jnp.ndim(x) == jnp.ndim(y)
    assert x.shape[-1] == y.shape[-1] == 2

    single_pair = jnp.ndim(x) == 1
    if single_pair:
        distance = jnp.linalg.norm(x - y)
    else:
        # (n, 1, 2) - (1, m, 2) -> (n, m, 2) differences -> (n, m) distances.
        distance = jnp.linalg.norm(x[:, jnp.newaxis] - y[jnp.newaxis, :], axis=-1)
    return 1j / 4 * h0(wave_number * distance)

In [4]:
def h(x, y):
    """Pairwise Green's function matrix ``1j/4 * h0(k * |x_i - y_j|)``, ``k`` = global ``wave_number``.

    No input validation; x and y are presumably ``(n, 2)`` and ``(m, 2)`` point
    sets, giving an ``(n, m)`` complex result.
    """
    x_col = x[:, jnp.newaxis]   # presumably (n, 1, 2)
    y_row = y[jnp.newaxis, :]   # presumably (1, m, 2)
    pair_distances = jnp.linalg.norm(x_col - y_row, axis=-1)
    return 1j / 4 * h0(wave_number * pair_distances)

In [5]:
# Jitted variants. Compilation is deferred to the first call with concrete input
# shapes; the global `wave_number` is read once at trace time and baked in.
jit_g, jit_h = jax.jit(g), jax.jit(h)

In [6]:
# Seeded generator so the benchmark inputs are identical after a kernel restart.
rng = np.random.default_rng(0)
x, y = rng.random((1000, 2)), rng.random((1000, 2))

In [7]:
%timeit g(x, y)

252 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%time jit_g(x, y);

CPU times: user 454 ms, sys: 49.5 ms, total: 503 ms
Wall time: 365 ms


In [9]:
%timeit jit_g(x, y)

13.3 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
%time jit_h(x, y);

CPU times: user 365 ms, sys: 7.02 ms, total: 372 ms
Wall time: 282 ms


In [11]:
%timeit jit_h(x, y)

13.5 ms ± 850 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


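The Green's function computation seems OK: both `g` and `h` run, and their jitted versions are much faster than op-by-op execution (compare the timings above).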

Compute the Foldy matrix and check it against `pointscat.forward_problem.compute_foldy_matrix`

In [50]:
def f(x, a, k=None):
    """Assemble the Foldy matrix for point scatterers.

    ``M[i, j] = delta_ij - k**2 * (1j/4) * h0(k * |x_i - x_j|) * a[j]``, where the
    interaction term is set to 0 wherever ``x_i == x_j`` (the diagonal, and any
    coincident scatterers).

    Parameters
    ----------
    x : array_like, shape (n, d)
        Scatterer locations (d = 2 in this notebook).
    a : array_like, shape (n,)
        Scatterer amplitudes.
    k : scalar, optional
        Wavenumber. Defaults to the module-level ``wave_number``; under
        ``jax.jit`` that global is read once at trace time.

    Returns
    -------
    jax.Array, shape (n, n), complex
    """
    if k is None:
        k = wave_number
    x = jnp.asarray(x)
    a = jnp.asarray(a)

    diff = x[:, jnp.newaxis, :] - x[jnp.newaxis, :, :]
    sq_dist = jnp.sum(diff * diff, axis=-1)
    coincident = sq_dist == 0
    # "Double where": give sqrt and h0 a dummy distance of 1 where points coincide,
    # so neither is evaluated at 0 (where h0 is presumably singular). Otherwise,
    # reverse-mode differentiation (jax.grad / jax.jacrev) pushes a zero cotangent
    # through an infinite derivative in the discarded branch, giving NaN.
    dist = jnp.sqrt(jnp.where(coincident, 1.0, sq_dist))
    interaction = jnp.where(coincident, 0, -k**2 * 1j / 4 * h0(k * dist))
    # Column-scaling by a[j] equals interaction @ diag(a), without the O(n^3) matmul.
    return interaction * a[jnp.newaxis, :] + jnp.eye(a.shape[0])

In [57]:
a = f(locations, amplitudes)
b = compute_foldy_matrix(locations, amplitudes, wave_number)
# Fail loudly on a mismatch instead of only printing False and carrying on.
matches = np.allclose(a, b)
assert matches, "f(locations, amplitudes) disagrees with compute_foldy_matrix"
print(matches)

True


In [58]:
# Jitted Foldy-matrix assembly: JAX compiles f on the first call for a given
# input shape/dtype and reuses the compiled version on later calls.
jit_f = jax.jit(f)

In [59]:
# First call of jit_f: on a fresh kernel this includes tracing and XLA compilation.
# JAX dispatches asynchronously, so block on the result before stopping the clock;
# perf_counter is the high-resolution clock intended for interval timing.
start = time.perf_counter()
jax.block_until_ready(jit_f(locations, amplitudes))
stop = time.perf_counter()
print(stop - start)

0.31208205223083496


In [60]:
%timeit f(locations, amplitudes)

1.01 ms ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [61]:
%timeit jit_f(locations, amplitudes)

9.3 µs ± 23.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [62]:
# First call of the jitted reference implementation: on a fresh kernel this
# includes tracing and compilation. Block on the (asynchronous) result before
# stopping the clock, and use perf_counter for interval timing.
start = time.perf_counter()
jax.block_until_ready(jax.jit(compute_foldy_matrix)(locations, amplitudes, wave_number))
stop = time.perf_counter()
print(stop - start)

14.683596134185791


In [63]:
%timeit jax.jit(compute_foldy_matrix)(locations, amplitudes, wave_number)

99.3 µs ± 743 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Jacobian of the Foldy matrix with respect to the scatterer locations and amplitudes

In [64]:
# Forward-mode Jacobian of f w.r.t. both the locations (argument 0) and the
# amplitudes (argument 1); calling it returns a tuple of two Jacobians.
jac_f = jax.jacfwd(f, argnums=(0, 1))

In [65]:
%timeit jac_f(locations, amplitudes)

75.7 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [67]:
# Same forward-mode Jacobian for the reference implementation; wave_number
# (argument 2) is not differentiated. Returns a tuple of two Jacobians.
jac_compute_foldy_matrix = jax.jacfwd(compute_foldy_matrix, argnums=(0, 1))

In [68]:
%timeit jac_compute_foldy_matrix(locations, amplitudes, wave_number)

3.86 s ± 75.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [69]:
# First call of the jitted Jacobian: on a fresh kernel this includes tracing and
# compilation. Block on the (tuple) result before stopping the clock.
start = time.perf_counter()
jax.block_until_ready(jax.jit(jac_f)(locations, amplitudes))
stop = time.perf_counter()
print(stop - start)

0.8016571998596191


In [70]:
%timeit jax.jit(jac_f)(locations, amplitudes)

205 µs ± 17.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [71]:
# First call of the jitted reference Jacobian: on a fresh kernel this includes
# tracing and compilation. Block on the (tuple) result before stopping the clock.
start = time.perf_counter()
jax.block_until_ready(jax.jit(jac_compute_foldy_matrix)(locations, amplitudes, wave_number))
stop = time.perf_counter()
print(stop - start)

103.04824924468994
