In [None]:
import numpy as np
from jax import jit
from jax import numpy as jnp
from matplotlib import pyplot as plt

from jwave import FourierSeries
from jwave.acoustics import acoustic_solver
from jwave.geometry import *
from jwave.geometry import Sensors, circ_mask, points_on_circle
from jwave.utils import show_field
from jwave.ode import TimeAxis
from jwave.ode import SemiImplicitEulerCorrected
from diffrax import RecursiveCheckpointAdjoint, BacksolveAdjoint, DirectAdjoint

# Computational grid: 256 x 256 points with a spacing of 1e-4 in each direction
# (0.1 mm, assuming SI units).
domain = Domain((256, 256), (1e-4, 1e-4))
# Homogeneous medium: a single constant sound speed over the whole grid.
medium = Medium(domain=domain, sound_speed=1500.0)
# Time axis chosen from a CFL number of 0.3 for this medium.
time_axis = TimeAxis.from_cfl_number(medium, cfl=0.3)

In [None]:
# Defining the initial pressure

# Initial pressure: a weighted sum of four disc masks, then wrapped as a field.
# circ_mask arguments are presumably (grid shape, radius, centre) in grid points — confirm.
N = domain.N
mask1 = circ_mask(N, 16, (100, 100))
mask2 = circ_mask(N, 10, (160, 120))
mask3 = circ_mask(N, 20, (128, 128))
mask4 = circ_mask(N, 60, (128, 128))

p0 = sum(
    weight * mask
    for weight, mask in ((5.0, mask1), (3.0, mask2), (4.0, mask3), (0.5, mask4))
)

# Append a trailing singleton axis before building the FourierSeries
# (presumably the field's channel dimension — confirm).
p0 = FourierSeries(1.0 * jnp.expand_dims(p0, -1), domain)

In [None]:
# Display the initial pressure. The title goes through show_field, like the other
# show_field calls in this notebook, and plt.show() closes the cell so no stray
# `Text(...)` repr from plt.title is printed.
show_field(p0, "Initial pressure")
plt.show()

In [None]:
# 48 point sensors placed by points_on_circle with radius 100 and centre (128, 128)
# (presumably in grid-point units — confirm against points_on_circle).
num_sensors = 48
x, y = points_on_circle(num_sensors, 100, (128, 128))
sensors_positions = x, y
sensors = Sensors(positions=sensors_positions)

# List the fields declared on the Sensors container.
print("Sensors parameters:")
Sensors.__annotations__

In [None]:
@jit
def compiled_simulator(medium, p0):
    """Jit-compiled forward simulation.

    Runs ``acoustic_solver`` with a ``SemiImplicitEulerCorrected`` integrator on the
    notebook-level ``time_axis`` and records the field at the notebook-level
    ``sensors``; returns whatever ``acoustic_solver`` returns.
    """
    integrator = SemiImplicitEulerCorrected()
    return acoustic_solver(
        medium, time_axis, integrator, p0=p0, sensors=sensors
    )

In [None]:
# First call of the jit-compiled simulator: this also triggers compilation.
# `[..., 0]` keeps index 0 of the last axis (presumably a singleton channel
# axis — confirm against acoustic_solver's output shape).
sensors_data = compiled_simulator(medium, p0)[..., 0]

In [None]:
%timeit _ = compiled_simulator(medium, p0)[..., 0].block_until_ready()

In [None]:
# Show the recorded traces as an image. The transposed sensor data is wrapped in
# a FourierSeries, presumably only so that show_field can render it.
_field = FourierSeries(sensors_data.T, domain)
show_field(_field, "Recorded acoustic signals")
plt.gca().set(xlabel="Time step", ylabel="Sensor position")
plt.axis("on")
plt.show()

In this notebook, we work with simulated measurements. To make things (a tiny bit) more realistic, we add some coloured noise to each sensor trace.

In [None]:
import numpy as np
from jax import grad, random

from jwave.signal_processing import smooth

# Add colored noise: white Gaussian noise (fixed PRNG key, so reproducible),
# then `smooth` applied independently to each column (presumably one column
# per sensor trace — confirm against the simulator's output layout).
noise = random.normal(random.PRNGKey(42), sensors_data.shape)
# Smooth all columns and stack them once. The previous per-column
# `.at[:, i].set(...)` loop made a full copy of the array on every iteration
# outside jit. Every column is read from the unsmoothed noise, exactly as in
# the loop, and the cast keeps the original dtype that `.at[].set` enforced.
noise = jnp.stack(
    [smooth(noise[:, i]) for i in range(noise.shape[1])], axis=1
).astype(noise.dtype)

noisy_data = sensors_data + 0.2 * noise

# Display the noisy traces. The transposed data is wrapped in a FourierSeries,
# presumably only so that show_field can render it as an image.
_field = FourierSeries(noisy_data.T, domain)
show_field(_field, "Noisy acoustic signals")
plt.gca().set(xlabel="Time step", ylabel="Sensor position")
plt.axis("on")
plt.show()

## Automatic differentiation

In `jwave`, it is possible to take the gradient with respect to any scalar loss, as shown in the following example. The gradients have the same data types as the inputs to the function, so gradients with respect to `Field` inputs are returned as `Field` objects.

Here, we write a simple time reversal algorithm using autodiff. Time reversal is proportional to the negative gradient of the `MSE` loss with respect to the initial pressure, evaluated at zero initial pressure. For a linear forward model, this is the adjoint of the forward operator applied to the measurements.

In [None]:
# Forward model used by the reconstruction below; by reciprocity of the wave
# equation its gradient yields a time-reversal-style imaging algorithm.
def solver(p0):
    """Map an initial pressure field to the recorded sensor data.

    Same call as ``compiled_simulator`` plus diffrax's
    ``RecursiveCheckpointAdjoint`` for reverse-mode differentiation and an
    explicit ``max_steps`` limit.
    """
    # NOTE(review): 1250 must be at least the number of steps in `time_axis`,
    # otherwise the solve stops early — confirm when changing the grid or CFL.
    n_steps = 1250
    return acoustic_solver(
        medium,
        time_axis,
        SemiImplicitEulerCorrected(),
        p0=p0,
        sensors=sensors,
        adjoint=RecursiveCheckpointAdjoint(checkpoints=n_steps),
        max_steps=n_steps,
    )


@jit  # Compile the whole algorithm
def lazy_time_reversal(measurements):
    """Time-reversal-style reconstruction computed with automatic differentiation.

    Returns minus the gradient of
    ``0.5 * sum(|solver(p0)[..., 0] - measurements| ** 2)`` with respect to the
    initial pressure ``p0``, evaluated at ``p0 = FourierSeries.empty(domain)``.
    For a linear forward model this equals the adjoint of the forward operator
    applied to the measurements.

    Args:
        measurements: recorded sensor traces, expected to match the shape of
            ``solver(p0)[..., 0]``.

    Returns:
        A ``FourierSeries`` on ``domain`` holding the reconstructed initial pressure.
    """
    def mse_loss(p0, measurements):
        p_pred = solver(p0)[..., 0]
        return 0.5 * jnp.sum(jnp.abs(p_pred - measurements) ** 2)

    # Start from an empty field
    p0 = FourierSeries.empty(domain)

    # `grad` differentiates w.r.t. the FIRST argument, i.e. the initial
    # pressure p0 (not the measurements), evaluated at the empty field.
    p_grad = grad(mse_loss)(p0, measurements)

    return -p_grad

In [None]:
# Reconstruct initial pressure distribution
recon_image = lazy_time_reversal(noisy_data)

# Display the reconstruction returned by lazy_time_reversal (a field on `domain`).
show_field(recon_image, "Reconstructed initial pressure using autograd")

In [None]:
# Timing of the full reconstruction (forward simulation + gradient); expected to
# take roughly 2x the forward simulation. `.params.block_until_ready()` waits
# for JAX's asynchronous dispatch so the timing covers the actual computation.
%timeit lazy_time_reversal(noisy_data).params.block_until_ready()

In [None]:
# Timings for the reconstruction algorithm, should be
# ~ 2x the forward function.
%timeit lazy_time_reversal(noisy_data).params.block_until_ready()

In [None]:
from scipy.interpolate import interp1d

# Three-panel summary figure saved for the paper:
# initial pressure + sensors | noisy traces | reconstruction.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3), dpi=100)

# Panel 0: ground-truth initial pressure with the sensor positions overlaid.
im1 = ax[0].imshow(p0.on_grid, cmap="RdBu_r", vmin=-6, vmax=6)
cbar = fig.colorbar(im1, ax=ax[0])
cbar.set_label('A.U.', rotation=270, labelpad=5)
ax[0].axis('off')
ax[0].set_title('Initial pressure')
ax[0].scatter(x, y, label="sensors", marker='.')
ax[0].legend(loc="lower right")

# Panel 1: wiggle plot of every `skip`-th noisy trace.
# Adapted from this gist: https://gist.github.com/kwinkunks/f594b243e582666b5a808520e9add262
data = np.asarray(noisy_data).T          # transposed so that each row is one trace
time = np.asarray(time_axis.to_array())  # presumably seconds (scaled to microseconds below) — confirm

# Wiggle-plot parameters
skip = 2            # plot every `skip`-th trace
perc = 99.0         # percentile of the data used as the normalisation scale
gain = 1.3          # amplitude gain applied to each normalised trace
oversampling = 100  # upsampling factor for the cubic interpolation
rgb = (0, 0, 0)     # fill colour
alpha = 1.0         # fill opacity
lw = 0.5            # trace line width

ntraces, nt = data.shape
rgba = list(rgb) + [alpha]
sc = np.percentile(data, perc)  # Normalization factor
wigdata = data[::skip]
xpos = np.arange(ntraces)[::skip]

# The time axis (in microseconds) and its upsampled version are the same for
# every trace, so build them once instead of on every loop iteration.
t = 1e6 * time
hypertime = np.linspace(t[0], t[-1], oversampling * t.size)

for y_trace, trace in zip(xpos, wigdata):
    # Normalise, apply gain and shift the trace to its row position.
    amp = gain * trace / sc + y_trace
    interp = interp1d(t, amp, kind='cubic')
    hyperamp = interp(hypertime)

    # Plot the line, then fill the excursions above the baseline.
    ax[1].plot(hypertime, hyperamp, 'k', lw=lw)
    ax[1].fill_between(
        hypertime, hyperamp, y_trace,
        where=hyperamp > y_trace,
        facecolor=rgba,
        interpolate=True,
        lw=0,
    )

ax[1].yaxis.tick_right()
ax[1].set_title('Noisy traces')
ax[1].set_ylabel("Sensor number")
# Raw string: "\m" is an invalid escape sequence in a regular string literal
# (SyntaxWarning on Python 3.12+); the label text is unchanged.
ax[1].set_xlabel(r"Time $\mu s$")

# These names were used below but never imported, so the cell raised NameError
# on a fresh kernel.
import matplotlib.font_manager as fm
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

# Panel 2: initial pressure recovered by the autodiff time reversal.
im1 = ax[2].imshow(recon_image.on_grid, cmap="RdBu_r", vmin=-0.3, vmax=0.3)
cbar = fig.colorbar(im1, ax=ax[2])
cbar.ax.get_yaxis().labelpad = 5
cbar.ax.set_ylabel('A.U.', rotation=270)
ax[2].axis('off')
ax[2].set_title('Recovered initial pressure')

# 1 cm scale bar. Its length in pixels follows from the grid spacing along the
# image's horizontal axis (array axis 1); domain.dx is in metres, consistent
# with 100 px == 1 cm for the 0.1 mm grid used here.
scalebar_px = round(1e-2 / domain.dx[1])
fontprops = fm.FontProperties(size=12)
scalebar = AnchoredSizeBar(
    ax[2].transData,
    scalebar_px, '1 cm', 'lower right',
    pad=0.3,
    color='black',
    frameon=False,
    size_vertical=2,
    fontproperties=fontprops)
ax[2].add_artist(scalebar)

fig.tight_layout()

# Save this specific figure rather than whatever pyplot considers "current".
fig.savefig("initial_pressure_recon.pdf")