In [None]:
# Reload edited Python modules automatically before executing each cell, so
# changes to imported project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

from jax import jit
from jax.lib import xla_bridge
# Report which XLA backend JAX picked ("cpu", "gpu", ...). jax.default_backend()
# is the public API for this; jax.lib.xla_bridge.get_backend() is deprecated.
# NOTE(review): the `from jax.lib import xla_bridge` import above is no longer needed here.
import jax
print(f"Jax is using: {jax.default_backend()}")

# Setup

## Transducer

In [None]:
# Linear-array ultrasound transducer (parameters referenced to the P4-1 probe).
nelements = 64  # number of transducer elements
element_pitch = 2.95e-4  # centre-to-centre distance between adjacent elements [m]
transducer_frequency = 1e6  # centre frequency of the emitted pulse [Hz]
transducer_magnitude = 1e6  # peak amplitude of the emitted pressure signal [Pa]

# Aperture: distance from the first to the last element centre [m].
n_gaps = nelements - 1
transducer_extent = n_gaps * element_pitch
print(f"Transducer extent: {transducer_extent:.3f} m")

## Domain

In [None]:
# Spatial grid. The spacing equals the element pitch, so neighbouring
# transducer elements can sit on neighbouring grid points.
N = np.array([128, 128]).astype(int)  # grid size [grid points]
dx = np.array([element_pitch, element_pitch])  # grid spacing [m]
extent = N * dx  # physical grid size [m]
pml = np.array([20, 20])  # thickness of the perfectly matched layer [grid points]
print(f"Number of grid points: {N}\nGrid size: {extent} m\nGrid spacing: {dx} m")

# The aperture (along axis 0) must fit inside the PML-free interior of the grid,
# i.e. the grid width minus one PML on each side.
interior_width = extent[0] - 2 * pml[0] * dx[0]  # [m]
assert transducer_extent < interior_width, (
    f"Transducer extent ({transducer_extent:.4f} m) does not fit inside the "
    f"PML-free interior of the grid ({interior_width:.4f} m)"
)

from jwave.geometry import Domain
domain = Domain((N[0],N[1]), (dx[0], dx[1]))

In [None]:
# Grid indices of the transducer elements, shape (2, nelements):
# row 0 = index along axis 0 (x), row 1 = index along axis 1 (depth).
transducer_depth = pml[1]  # offset of the array from the far axis-1 edge [grid points]
transducer_x_start = N[0]//2 - nelements//2  # first element index along x, centring the array [grid points]
x_indices = np.arange(transducer_x_start, transducer_x_start + nelements)
depth_indices = np.full(nelements, N[1] - transducer_depth)
element_positions = np.array([x_indices, depth_indices], dtype=int)

# Optional: record at the element positions with a Sensors object (currently unused).
# from jwave.geometry import Sensors
# sensors = Sensors(positions=(element_positions[0], element_positions[1]))

## Acoustic medium

In [None]:
np.random.seed(28)

# Reference (background) acoustic properties.
c0 = 1500  # reference speed of sound [m/s]
rho0 = 1000  # reference density [kg/m^3]

# Speckle: a Gaussian perturbation around 1, applied to BOTH sound speed and
# density so the two maps share the same random texture.
background_map_mean = 1
background_map_std = 0.008
background_map = background_map_mean + background_map_std * np.random.randn(N[0], N[1])
sound_speed = c0 * background_map
density = rho0 * background_map

# Strongly scattering inclusions: discs of the given radius whose speed of sound
# and density are set to `scatterer_contrast` times the reference values.
scatterer_radius = 2  # radius of scatterers [grid points]
scatterer_contrast = 1.1  # contrast of scatterers
scatterer_positions = np.array([[N[0]//2, N[1]//2]], dtype=int)
scatterer_map = np.zeros(N)
x, y = np.ogrid[:N[0], :N[1]]
for cx, cy in scatterer_positions:
    inside_disc = (x - cx)**2 + (y - cy)**2 <= scatterer_radius**2
    scatterer_map[inside_disc] = 1
is_scatterer = scatterer_map == 1
sound_speed[is_scatterer] = c0*scatterer_contrast
density[is_scatterer] = rho0*scatterer_contrast

# Wrap the maps as jwave fields. A trailing singleton axis is added
# (presumably the field-component axis) and the medium is assembled.
from jwave import FourierSeries
from jwave.geometry import Medium
sound_speed = FourierSeries(sound_speed[..., np.newaxis], domain)
density = FourierSeries(density[..., np.newaxis], domain)
medium = Medium(domain=domain, sound_speed=sound_speed, density=density, pml_size=pml[0])
print(medium)


In [None]:
# Overlay the transducer elements on the sound-speed map.
# imshow lays array rows (axis 0) along the vertical axis and columns (axis 1)
# along the horizontal axis, so the horizontal extent must come from axis 1 and
# the vertical extent from axis 0 (the two only coincide for a square grid).
ext = [0, N[1]*dx[1], N[0]*dx[0], 0]
fig, ax = plt.subplots()
im = ax.imshow(sound_speed.params, cmap='gray', extent=ext)
# Scatter uses the same convention: horizontal = axis 1, vertical = axis 0.
ax.scatter(element_positions[1]*dx[1], element_positions[0]*dx[0],
           c='r', marker='o', s=5, label='transducer element')
fig.colorbar(im, ax=ax, label='Speed of sound [m/s]')
ax.set_xlabel('[m]')
ax.set_ylabel('[m]')
ax.set_title('Speed of sound and transducer layout')
ax.legend(prop={'size': 7})
ax.invert_yaxis()
plt.show()

## Time

In [None]:
from jwave.geometry import TimeAxis

# Courant (CFL) number used to pick the time step from the medium; a smaller
# value typically means a smaller dt and more time steps. t_end is not given
# here, so it is presumably derived by jwave from the medium/domain.
cfl_number = 0.3
time_axis = TimeAxis.from_medium(medium, cfl=cfl_number)

## Source

In [None]:
from jwave.signal_processing import gaussian_window

# Sample the source on the simulation's own time grid rather than re-deriving
# it by hand, so the signal stays aligned with the solver's time steps.
t = time_axis.to_array()
s = transducer_magnitude * jnp.sin(2 * jnp.pi * transducer_frequency * t)

# Gaussian envelope: width of two periods, centred three widths into the
# record (presumably so the windowed pulse starts close to zero).
# NOTE(review): `variance` is passed as the window-width argument of
# gaussian_window, not a statistical variance; confirm against jwave's signature.
variance = 2/transducer_frequency
mean = 3*variance
s = gaussian_window(s, t, mean, variance)

fig, ax = plt.subplots()
ax.plot(t * 1e6, s)
ax.set_xlabel('Time [µs]')
ax.set_ylabel('Amplitude [Pa]')
ax.set_title('Source signal (same for every element)')
plt.show()

In [None]:
from jwave.geometry import Sources

# Every element emits the same pulse with no relative delay (an unsteered,
# unfocused transmit): one copy of `s` per element, shape (n_elements, len(s)).
n_source_elements = element_positions.shape[1]
element_signals = jnp.tile(s, (n_source_elements, 1))

sources = Sources(
    positions=tuple(tuple(row) for row in element_positions),
    signals=element_signals,
    dt=time_axis.dt,
    domain=domain,
)

# Run simulation

In [None]:
from jwave.acoustics import simulate_wave_propagation

@jit
def compiled_simulator(sources):
    """Run the JIT-compiled wave simulation for the given sources.

    `medium` and `time_axis` are read from the notebook's global scope when the
    function is traced, so they are baked into the compiled program. Re-run this
    cell after changing either of them.

    Returns the result of `simulate_wave_propagation` (the pressure field).
    """
    return simulate_wave_propagation(medium, time_axis, sources=sources)

In [None]:
# The first call traces and compiles the whole simulation under jit, so this
# cell is slow. Later calls reuse the compiled program as long as the jit cache
# key (argument structure/shapes/dtypes) is unchanged.
pressure = compiled_simulator(sources)

In [None]:
from jwave.utils import show_field

# Snapshot of the full pressure field at a single time step.
t_idx = 100
t_samples = time_axis.to_array()
# JAX-array indexing clamps out-of-range indices instead of raising, so an
# index past the end would silently report the last sample's time. Check it
# explicitly.
assert 0 <= t_idx < t_samples.shape[0], (
    f"t_idx={t_idx} is outside the time axis [0, {t_samples.shape[0]})"
)
t_snapshot_us = float(t_samples[t_idx]) * 1e6  # snapshot time [µs]
show_field(pressure[t_idx])
plt.title(f"Pressure field at t={t_snapshot_us:.2f} µs")
plt.show()

In [None]:
# Received channel data: pressure at every element position, for every time step.
# Only the trailing field axis is squeezed. A bare squeeze() would also drop
# the element axis (single element) or the time axis (single step). With
# axis=-1 it fails loudly if that axis is not of size 1.
element_traces = pressure.params[:, element_positions[0], element_positions[1]]
data = jnp.squeeze(element_traces, axis=-1)  # (time steps, elements)

In [None]:
# Channel data image: one column per element, time running downwards.
# 'seismic' is a diverging colormap, so use symmetric limits to keep zero
# pressure at its neutral (white) centre.
rf_vmax = float(jnp.max(jnp.abs(data)))
fig, ax = plt.subplots()
im = ax.imshow(data, aspect='auto', cmap='seismic', vmin=-rf_vmax, vmax=rf_vmax)
fig.colorbar(im, ax=ax, label='Pressure')
ax.set_xlabel('Transducer elements')
ax.set_ylabel('Time point')
ax.set_title('Received channel data')
plt.show()

# Beamforming