In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

import sys
sys.path.append('../')

import jax
from jax import jit
from jax.lib import xla_bridge
# Report the active JAX platform through the public API. jax.default_backend()
# returns the same platform string as xla_bridge.get_backend().platform, without
# relying on the internal (deprecated) jax.lib.xla_bridge module.
print(f"Jax is using: {jax.default_backend()}")

# Simulate data

In [None]:
# Linear ultrasound transducer definition (P4-1)
nelements = 64                # number of transducer elements
element_pitch = 2.95e-4       # distance between transducer elements [m]
transducer_frequency = 2e6    # frequency of the transducer [Hz]
transducer_magnitude = 1e6    # magnitude of the transducer [Pa]

# Length of the transducer [m]: there are (nelements - 1) gaps of one pitch each.
transducer_extent = element_pitch * (nelements - 1)
print(f"Transducer extent: {transducer_extent:.3f} m")

In [None]:
import pydicom

# define skull slice
# Read the DICOM file and take its pixel data as a signed 16-bit array.
skull_dicom = pydicom.dcmread("../data/skull_slice.dcm")
skull_array = skull_dicom.pixel_array.astype(np.int16)
# Crop: keep the first 500 rows and drop 600 columns from each side.
skull_array = skull_array[:500, 600:-600]
# Flip the row order (up/down).
skull_array = np.flipud(skull_array)
# Downsample both axes by an integer factor using nearest-neighbour resampling,
# then transpose so the former column axis becomes axis 0.
# downsampling_factor is also read by the grid-spacing cell.
downsampling_factor = 2
new_shape = (skull_array.shape[0] // downsampling_factor, skull_array.shape[1] // downsampling_factor)
skull_array = jax.image.resize(skull_array, new_shape, method='nearest').T
# Preview the processed slice with the vertical axis inverted.
plt.imshow(skull_array, cmap='gray')
plt.gca().invert_yaxis()
plt.show()

In [None]:
# Spatial discretisation of the simulation domain
# grid size [grid points], taken from the processed skull image
N = np.asarray(skull_array.shape).astype(int)
# grid spacing [m]: the base pixel size scaled by the downsampling factor, same on both axes
dx = np.full(2, 9.07935931401377e-5 * downsampling_factor)
# size of the perfectly matched layer [grid points], same on both axes
pml = np.full(2, 20)

In [None]:
# Place the transducer inside the simulation grid (all values in grid points)
transducer_depth = pml[1]  # offset of the element row from the end of axis 1
transducer_x_start = N[0] // 2 - nelements // 2  # first element index along axis 0

# Row 0: nelements consecutive axis-0 indices, centred on the middle of axis 0.
# Row 1: one constant axis-1 index shared by every element.
element_positions = np.array([
    np.arange(transducer_x_start, transducer_x_start + nelements),
    np.full(nelements, N[1] - transducer_depth),
], dtype=int)

# Virtual elements: same axis-0 indices, shifted by 80 grid points along axis 1.
virtual_positions = element_positions.copy()
virtual_positions[1] = virtual_positions[1] - 80
element_positions

In [None]:
from jwave_utils import get_domain, get_homogeneous_medium, get_skull_medium

# define jwave medium
# Build the simulation domain from the grid size and spacing, plus a
# homogeneous reference medium (background_std=0).
domain = get_domain(N, dx)
speed_homogeneous, density_homogeneous = get_homogeneous_medium(domain, background_std=0, background_seed=29)

c0 = 1500 # speed of sound in water [m/s]
# A single scatterer position at the centre grid point of the domain.
scatterer_positions = np.array([[domain.N[0]//2, domain.N[1]//2]], dtype=int)
# Two skull media: one built without scatterer positions (used as the reference)
# and one with the scatterer above.
# NOTE(review): the two calls use different background seeds (29 vs 28); with
# background_std=0 this presumably has no effect — confirm in get_skull_medium.
speed_skull, density_skull = get_skull_medium(domain, skull_array, background_std=0, background_seed=29)
speed, density = get_skull_medium(domain, skull_array, scatterer_positions, 
                          background_std = 0.000, scatterer_radius=1, scatterer_contrast=2, 
                          background_seed=28)

# Plot the speed map in metres; axis 1 of the grid is drawn horizontally and
# axis 0 vertically, so positions are plotted as (index along axis 1, index along axis 0).
ext = [0, N[1]*dx[1], N[0]*dx[0], 0]
plt.scatter(element_positions[1]*dx[1], element_positions[0]*dx[0],
            c='r', marker='o', s=5, label='transducer element')
plt.scatter(virtual_positions[1]*dx[1], virtual_positions[0]*dx[0],
            c='b', marker='o', s=5, label='virtual element')
plt.imshow(speed, cmap='gray', extent=ext)
# plt.colorbar(label='Speed of sound [m/s]')
plt.xlabel('[m]')
plt.ylabel('[m]')
plt.legend(prop={'size': 7})
plt.gca().invert_yaxis()
plt.show()

In [None]:
from jwave.geometry import TimeAxis
from jwave.geometry import Medium
from jwave import FourierSeries
from jwave_utils import get_plane_wave_excitation

# Steering angle handed to the excitation helper; 0 for the single-angle case.
angle = 0
# Wrap the speed and density maps (with a trailing singleton axis added) as
# FourierSeries fields on the domain; the PML size is taken from pml[0].
medium = Medium(domain, FourierSeries(jnp.expand_dims(speed, -1), domain), FourierSeries(jnp.expand_dims(density, -1), domain), pml_size=pml[0])
# Time axis derived from the medium with a CFL number of 0.3.
time_axis = TimeAxis.from_medium(medium, cfl=0.3)
# Alternative excitations kept for reference (older call signature / real element positions):
# sources = get_plane_wave_excitation(domain, time_axis, transducer_magnitude, transducer_frequency, element_positions)
# sources, signal, carrier_signal = get_plane_wave_excitation(domain, time_axis, transducer_magnitude, transducer_frequency, dx[0], element_positions, angle=angle)
# Active excitation: emitted from the virtual element positions.
sources, signal, carrier_signal = get_plane_wave_excitation(domain, time_axis, transducer_magnitude, transducer_frequency, dx[0], virtual_positions, angle=angle)

# Show the signal of the first source against its sample index.
plt.plot(sources.signals[0])
plt.xlabel('Time point')
plt.ylabel('Amplitude [Pa]')
plt.show()

In [None]:
from jwave_utils import get_data

# simulate data using jwave
# Three runs with the same sources and receiver positions, differing only in the medium:
#   skull + scatterer, homogeneous, and skull only.
pressure, data = get_data(speed, density, domain, time_axis, sources, element_positions)
# NOTE(review): data_homogenous is not read anywhere else in the visible notebook.
_, data_homogenous = get_data(speed_homogeneous, density_homogeneous, domain, time_axis, sources, element_positions)
pressure_skull, data_skull = get_data(speed_skull, density_skull, domain, time_axis, sources, element_positions)

In [None]:
from jwave.utils import show_field

# Show the simulated pressure field (skull + scatterer run) at one time index;
# the title reports the matching time taken from the time axis.
t_idx = 600
show_field(pressure[t_idx])
plt.title(f"Pressure field at t={time_axis.to_array()[t_idx]} seconds")
plt.show()

In [None]:
# Recorded data from the skull + scatterer run, labelled as time points (rows) by transducer elements (columns).
plt.imshow(data, aspect='auto', cmap='seismic')
plt.xlabel('Transducer elements')
plt.ylabel('Time point')
plt.show()

In [None]:
# Subtract the skull-only recording from the skull + scatterer recording, so
# what remains is the difference caused by the scatterer.
output_data = data-data_skull
plt.imshow(output_data, aspect='auto', cmap='seismic')
plt.xlabel('Transducer elements')
plt.ylabel('Time point')
plt.show()

## Time reversal

In [None]:
# Sample the skull-only pressure field at the element grid points for every
# index of the first axis (presumably time — TODO confirm), drop singleton axes,
# then reverse the first axis.
time_reversed_data = jnp.squeeze(pressure_skull.params[:, element_positions[0], element_positions[1]])
time_reversed_data = jnp.flip(time_reversed_data, axis=0)
plt.imshow(time_reversed_data, aspect='auto', cmap='seismic')

In [None]:
from jwave.geometry import Sources

# Build sources that re-emit the time-reversed recordings from the transducer
# element positions. The data is transposed so each row of `signals` holds the
# signal of one element.
time_reversed_sources = Sources(
    positions=tuple(map(tuple, element_positions)),
    signals=jnp.array(time_reversed_data.T),
    dt=time_axis.dt,
    domain=domain,
)

# Inspect one channel. The x data is time_axis.to_array() (seconds), so the
# axis is labelled as time rather than as a sample index.
plt.plot(time_axis.to_array(), time_reversed_sources.signals[10])
plt.xlabel('Time [s]')
plt.ylabel('Amplitude [Pa]')
plt.show()

In [None]:
# Re-run the skull-only simulation with the time-reversed sources; only the
# pressure field is kept, the recorded data is discarded.
pressure_skull_corrected, _ = get_data(speed_skull, density_skull, domain, time_axis, time_reversed_sources, element_positions)

In [None]:
from jwave.utils import show_field

# Show the time-reversal pressure field at one time index (this rebinds t_idx
# from the earlier snapshot cell).
t_idx = 2250
show_field(pressure_skull_corrected[t_idx])
plt.title(f"Pressure field at t={time_axis.to_array()[t_idx]} seconds")
plt.show()

# Reconstruction

In [None]:
from kwave.utils.filters import gaussian_filter
from kwave.reconstruction.beamform import envelope_detection

def postprocess_result(orig_res):
    """Band-pass filter and envelope-detect a beamformed result, row by row.

    The input is copied first, so ``orig_res`` itself is left untouched.
    Every row is replaced by its ``gaussian_filter`` output, then every row is
    replaced by its ``envelope_detection`` output.

    NOTE(review): the sampling rate given to ``gaussian_filter`` is ``1/dx[0]``
    (a spatial rate) while the centre frequency is ``transducer_frequency``
    (temporal) — confirm the units are intended.

    Args:
        orig_res: 2-D array-like beamforming result.

    Returns:
        The processed copy, flipped along its first axis and transposed.
    """
    processed = np.copy(orig_res)
    sampling_rate = 1/dx[0]
    # Iterating a NumPy array yields row views, so writing through `row[...]`
    # updates `processed` in place.
    for row in processed:
        row[...] = gaussian_filter(row, sampling_rate, transducer_frequency, 100.0)
    for row in processed:
        row[...] = envelope_detection(row)
    return np.flipud(processed).T

## Single angle

In [None]:
from beamforming_utils import get_receive_beamforming, get_receive_beamforming_medium_specific
# Element-to-element transmit delay, in time steps: (pitch * sin(angle) / c0) seconds divided by dt.
# NOTE(review): this uses element_pitch, whereas the excitation helper is given dx[0] and the
# elements sit on consecutive grid points — confirm which spacing the beamformer expects.
signal_delay = (element_pitch * np.sin(angle) / c0) / time_axis.dt 
# Receive beamforming of the scatterer-only data (skull + scatterer minus skull only).
res = get_receive_beamforming(domain, time_axis, element_positions, output_data, signal, carrier_signal, signal_delay)
# Medium-aware alternative, currently disabled:
# res = get_receive_beamforming_medium_specific(domain, medium, time_axis, element_positions, output_data, signal, carrier_signal, signal_delay)

In [None]:
# Show the raw beamformed result, transposed, with the vertical axis inverted.
# The postprocess_result step is currently disabled:
# bmode=postprocess_result(res)
plt.imshow(res.T, cmap='seismic', interpolation='nearest')
plt.colorbar()
plt.gca().invert_yaxis()
plt.show()

In [None]:
from imaging.demodulate import demodulate_rf_to_iq
# Sampling frequency is the inverse of the simulation time step.
freq_sampling = 1/time_axis.dt
# Demodulate the difference data to IQ, using the transducer frequency as the carrier.
iq_signals, freq_carrier = demodulate_rf_to_iq(output_data, freq_sampling, freq_carrier=transducer_frequency)

In [None]:
# Image grid for delay-and-sum beamforming.
# The grid sizes get their own names (Nx, Nz) so the notebook-wide array `N`
# (grid size per axis) is not rebound to a scalar; earlier cells index N[0]
# and N[1] and would fail on re-run otherwise.
Nx = domain.N[0]
# The transducer sits at a fixed axis-1 index, so depth runs along axis 1:
# the number of depth samples comes from domain.N[1], not domain.N[0].
Nz = domain.N[1] - transducer_depth
dx0 = domain.dx[0]

# Generate 1D arrays for x (lateral, centred on the middle of axis 0) and z (depth)
x = np.linspace(-(Nx//2)*dx0, (Nx//2)*dx0, Nx)
z = np.linspace(0, Nz*dx0, Nz)

# Create 2D meshgrid for x and z
X, Z = np.meshgrid(x, z)

In [None]:
from imaging.beamform import beamform_delay_and_sum

# Delay-and-sum beamforming of the IQ data on the (X, Z) grid. The pitch is one
# grid spacing and all transmit delays are zero.
beamformed_signal = beamform_delay_and_sum(iq_signals, X, Z, freq_sampling, freq_carrier, pitch=dx0, tx_delays=np.zeros(nelements))

In [None]:
# Work on a copy so the optional masking below (currently disabled) would not
# modify beamformed_signal itself.
beamformed_signal_2 = beamformed_signal.copy()
# beamformed_signal_2[:25, :] = 0

# Show the magnitude of the beamformed signal.
plt.imshow(np.abs(beamformed_signal_2), cmap='seismic')
plt.colorbar()
plt.show()


## Multiple angles

In [None]:
# Plane-wave compounding: repeat the simulation and receive beamforming for
# 10 steering angles evenly spaced from -10 to +10 degrees (given in radians).
angles = np.linspace(- 10*np.pi/180, 10*np.pi/180, 10)
results = []
# NOTE: the loop rebinds angle, sources, signal, carrier_signal, data, data_skull,
# output_data, signal_delay and res. After it finishes they hold the values for the
# last angle, which is what any later cell reading these names will see.
for angle in angles:
    print(f"Angle: {angle}")
    # Excitation is emitted from the real element positions here.
    sources, signal, carrier_signal = get_plane_wave_excitation(domain, time_axis, transducer_magnitude, transducer_frequency, dx[0], element_positions, angle=angle)
    # Simulate with and without the scatterer and keep the difference.
    _, data = get_data(speed, density, domain, time_axis, sources, element_positions)
    _, data_skull = get_data(speed_skull, density_skull, domain, time_axis, sources, element_positions)
    output_data = data-data_skull
    # Element-to-element transmit delay in time steps for this angle.
    signal_delay = (element_pitch * np.sin(angle) / c0) / time_axis.dt
    res = get_receive_beamforming(domain, time_axis, element_positions, output_data, signal, carrier_signal, signal_delay)
    # res = get_receive_beamforming_medium_specific(domain, medium, time_axis, element_positions, output_data, signal, carrier_signal, signal_delay)
    results.append(res)

In [None]:
# Compound by summing the per-angle beamformed results element-wise.
compounded_res = np.sum(results, axis=0)
# Post-processing variants, currently disabled:
# compounded_bmode=postprocess_result(compounded_res)
# bmodes = [postprocess_result(res) for res in results]
# compounded_bmode = np.sum(bmodes, axis=0)

# Show the compounded result, transposed, with the vertical axis inverted.
plt.imshow(compounded_res.T, cmap='seismic', interpolation='nearest')
plt.colorbar()
plt.gca().invert_yaxis()
plt.show()

## Gradient

In [None]:
from jax import value_and_grad
from jwave_utils import get_data_only
from solver_utils import linear_loss, nonlinear_loss

# Starting model for the inversion: the speed map built without scatterer positions.
params = speed_skull

# compute first linear gradient (disabled)
# J = jax.jacrev(get_data_only, argnums=0)(jnp.array(speed), density_homogeneous, domain, time_axis, sources, element_positions)
# linear_val_and_grad = value_and_grad(linear_loss, argnums=0)
# linear_loss_value, linear_gradient = linear_val_and_grad(params, J, output_data)

# compute first nonlinear gradient
nonlinear_val_and_grad = value_and_grad(nonlinear_loss, argnums=0)
# The scalar loss gets its own name so the imported `nonlinear_loss` function
# is not shadowed by a value.
nonlinear_loss_value, nonlinear_gradient = nonlinear_val_and_grad(params, data, density_skull, domain, time_axis, sources, element_positions)
nonlinear_gradient = nonlinear_gradient.at[:, 140:].set(0) # apply mask
print(f"Nonlinear loss: {nonlinear_loss_value}")

In [None]:
# Visualize the masked first gradient (transposed, vertical axis inverted)
plt.figure(figsize=(8, 6))
plt.imshow(nonlinear_gradient.T, cmap='seismic')
plt.title("First gradient")
plt.xlabel('x [gridpoints]')
plt.ylabel('y [gridpoints]')
plt.gca().invert_yaxis()
plt.colorbar(shrink=0.55)
plt.show()

In [None]:
from jax.example_libraries import optimizers
from tqdm import tqdm

# Per-iteration logs filled by the main loop, and the number of iterations.
losshistory = []
reconstructions = []
num_steps = 100

# Define optimizer
# Plain SGD with step size 1 (the Adam alternative is kept for reference),
# initialised from the starting model `params`.
# init_fun, update_fun, get_params = optimizers.adam(1)
init_fun, update_fun, get_params = optimizers.sgd(1)
opt_state = init_fun(params)

# Define and compile the update function
@jit
def update(opt_state, k):
    """Run one optimisation step on the sound-speed model.

    Evaluates the nonlinear loss and its gradient at the current parameters,
    masks the gradient, normalises it by its largest absolute value and applies
    the optimizer update.

    Args:
        opt_state: Current optimizer state (as returned by ``init_fun`` or a
            previous call to this function).
        k: Step index, forwarded to the optimizer's update function.

    Returns:
        Tuple ``(lossval, new_opt_state)``: the loss before the step and the
        updated optimizer state.
    """
    v = get_params(opt_state)
    lossval, gradient = nonlinear_val_and_grad(v, data, density_skull, domain, time_axis, sources, element_positions)
    # gradient = smooth_fun(gradient)
    # Apply the same mask as for the first gradient: zero everything from index 140 on along axis 1.
    gradient = gradient.at[:, 140:].set(0)
    # Normalise to unit peak magnitude. An all-zero gradient is divided by 1
    # instead of 0, so the parameters are left unchanged rather than becoming NaN.
    peak = jnp.max(jnp.abs(gradient))
    gradient = gradient / jnp.where(peak > 0, peak, 1.0)
    return lossval, update_fun(k, gradient, opt_state)

# Main loop
# Run num_steps optimisation steps; the progress bar description shows the latest loss.
pbar = tqdm(range(num_steps))
for k in pbar:
    lossval, opt_state = update(opt_state, k)

    ## For logging
    # Keep the parameters after every step and the loss returned by that step.
    new_params = get_params(opt_state)
    reconstructions.append(new_params)
    losshistory.append(lossval)
    pbar.set_description("Loss: {}".format(lossval))

In [None]:
# Show the final reconstruction restricted to the first 140 indices along axis 1
# (the part the gradient mask leaves free to change).
plt.imshow(reconstructions[-1][:,:140].T, cmap='seismic')
plt.gca().invert_yaxis()
plt.show()

In [None]:
# Visualize the full final reconstruction, with the colour scale capped at 1550
plt.figure(figsize=(8, 6))
plt.imshow(reconstructions[-1].T, cmap='seismic', vmax=1550)
plt.xlabel('x [gridpoints]')
plt.ylabel('y [gridpoints]')
plt.gca().invert_yaxis()
plt.colorbar(shrink=0.55)
plt.show()