In [None]:
# Automatically reload imported modules (e.g. the local jwave_utils helpers)
# before each cell runs, so edits to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jax
from jax import jit
from jax.lib import xla_bridge
# Report which backend JAX will run on (e.g. "cpu", "gpu", "tpu").
# `jax.default_backend()` is the public replacement for the deprecated
# `jax.lib.xla_bridge.get_backend().platform`.
print(f"Jax is using: {jax.default_backend()}")

In [None]:
import sys
# Make the parent directory importable (resolved relative to the kernel's current
# working directory); jwave_utils is presumably located there — confirm.
sys.path.append('../')

from jwave_utils import get_domain, get_skull_point_medium, get_plane_wave_excitation, get_data

# Simulate data

In [None]:
# Linear ultrasound transducer parameters (P4-1).
nelements = 64               # number of transducer elements
element_pitch = 2.95e-4      # centre-to-centre distance between neighbouring elements [m]
transducer_frequency = 1e6   # frequency of the transducer excitation [Hz]
transducer_magnitude = 1e6   # magnitude of the transducer excitation [Pa]

# Distance from the first to the last element centre [m].
transducer_extent = (nelements - 1) * element_pitch
print(f"Transducer extent: {transducer_extent:.3f} m")

In [None]:
# Load the CT skull slice, crop the region of interest and downsample it to the
# simulation grid resolution.

import pydicom
skull_dicom = pydicom.dcmread("../data/skull_slice.dcm")
# NOTE(review): raw stored pixel values are used; RescaleSlope/RescaleIntercept are
# not applied — confirm this is what get_skull_point_medium expects.
skull_array = skull_dicom.pixel_array.astype(np.int16)
skull_array = skull_array[:500, 200:-200]  # crop: first 500 rows, drop 200 columns on each side
skull_array = np.flipud(skull_array)       # reverse the row order

downsampling_factor = 2
# Crop to a multiple of the downsampling factor so that each output pixel covers
# exactly `downsampling_factor` input pixels per axis. Otherwise an odd dimension is
# resized by slightly less than 1/factor, and a grid spacing of
# (pixel size * downsampling_factor) would no longer match the image.
rows, cols = (dim - dim % downsampling_factor for dim in skull_array.shape)
skull_array = skull_array[:rows, :cols]
new_shape = (rows // downsampling_factor, cols // downsampling_factor)
skull_array = jax.image.resize(skull_array, new_shape, method='nearest')

fig, ax = plt.subplots()
ax.imshow(skull_array, cmap='gray')
ax.invert_yaxis()
ax.set_title(f"Skull slice downsampled by {downsampling_factor}")
ax.set_xlabel('[grid points]')
ax.set_ylabel('[grid points]')
plt.show()

In [None]:
# Spatial discretisation of the simulation grid.
N = np.asarray(skull_array.shape, dtype=int)  # grid size [grid points]

# Pixel spacing of the source image before downsampling [m]
# (presumably the DICOM pixel size — confirm).
source_pixel_spacing = 9.07935931401377e-5
# Grid spacing [m], identical along both axes.
dx = np.full(2, source_pixel_spacing * downsampling_factor)
pml = np.full(2, 20)  # size of the perfectly matched layer [grid points]

In [None]:
# Place the transducer elements on the simulation grid.
# All elements share one row, `transducer_depth` rows before the last grid row.
# Columns are centred on the grid, keeping the physical element pitch.
transducer_depth = pml[0]  # depth of the transducer along axis 0 [grid points]
element_row = N[0] - transducer_depth
# NOTE(review): if the PML occupies the last `pml[0]` rows, this row is the innermost
# PML row — confirm whether the elements should sit one row further inside the domain.

pitch_in_cells = element_pitch / dx[1]  # element pitch in grid points (generally non-integer)
# Element centres relative to the grid centre: symmetric about it, spaced by the exact pitch.
element_offsets = (np.arange(nelements) - (nelements - 1) / 2) * pitch_in_cells
# Round (rather than truncate) each centre to the nearest column, so the placement error
# stays within half a cell and is not biased towards lower indices.
element_cols = np.rint(N[1] // 2 + element_offsets).astype(int)
transducer_x_start = element_cols[0]  # start index of the transducer in the x-direction [grid points]

element_positions = np.stack([np.full(nelements, element_row), element_cols]).astype(int)
assert element_positions.shape == (2, nelements)
assert 0 <= element_cols.min() and element_cols.max() < N[1], "transducer does not fit in the grid"
element_positions

In [None]:
# Build the jwave domain and skull medium, then show the sound-speed map with the
# transducer element positions overlaid (both in metres).
domain = get_domain(N, dx)
medium = get_skull_point_medium(domain, skull_array, rho0=1000, c0=1500, scatterer_radius=10, pml_size=pml[0])

ext = [0, N[1] * dx[1], N[0] * dx[0], 0]  # imshow extent: [left, right, bottom, top] [m]

fig, ax = plt.subplots()
ax.scatter(element_positions[1] * dx[1], element_positions[0] * dx[0],
           c='r', marker='o', s=5, label='transducer element')
speed_image = ax.imshow(medium.sound_speed.params, cmap='gray', extent=ext)
fig.colorbar(speed_image, ax=ax, label='Speed of sound [m/s]')
ax.set_xlabel('[m]')
ax.set_ylabel('[m]')
ax.legend(prop={'size': 7})
ax.invert_yaxis()
plt.show()

In [None]:
from jwave.geometry import TimeAxis

# Time axis derived from the medium (CFL number 0.3) and the plane-wave excitation
# applied at the transducer element positions.
time_axis = TimeAxis.from_medium(medium, cfl=0.3)
sources = get_plane_wave_excitation(domain, time_axis, transducer_magnitude, transducer_frequency, element_positions)

# Excitation signal of the first source.
fig, ax = plt.subplots()
ax.plot(sources.signals[0])
ax.set_xlabel('Time point')
ax.set_ylabel('Amplitude [Pa]')
plt.show()

In [None]:
# simulate data using jwave
# NOTE(review): later cells index `pressure` by time step and display `data` with rows
# as time points and columns as transducer elements — confirm against jwave_utils.get_data.
pressure, data = get_data(medium, time_axis, sources, element_positions)

In [None]:
from jwave.utils import show_field

# Snapshot of the simulated pressure field at a single time step.
t_idx = 900  # time step to display
time_points = time_axis.to_array()
# If these are JAX arrays, an out-of-range index is clamped rather than raising. That
# would silently show the last frame with a wrong time label, so fail loudly instead.
assert 0 <= t_idx < len(time_points), (
    f"t_idx={t_idx} is outside the time axis ({len(time_points)} time steps)"
)
show_field(pressure[t_idx])
plt.title(f"Pressure field at t={float(time_points[t_idx]):.3e} seconds")
plt.show()

In [None]:
# Traces recorded by the transducer elements.
# 'seismic' is a diverging colormap: use limits symmetric about zero so that white
# means zero amplitude instead of the midpoint of an asymmetric autoscaled range.
data_absmax = float(np.abs(data).max())
fig, ax = plt.subplots()
trace_image = ax.imshow(data, aspect='auto', cmap='seismic', vmin=-data_absmax, vmax=data_absmax)
fig.colorbar(trace_image, ax=ax, label='Amplitude')
ax.set_xlabel('Transducer elements')
ax.set_ylabel('Time point')
plt.show()