In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to the helper modules imported further down
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jax
from jax import jit

# Report which backend JAX selected. `jax.default_backend()` is the public API
# for this; the previously used `jax.lib.xla_bridge` is an internal module that
# newer JAX releases deprecate and remove, which makes the import itself fail.
print(f"Jax is using: {jax.default_backend()}")

import sys

# Put the parent directory on the import path; the helper modules imported in
# later cells (jwave_utils, beamforming_utils, solver_utils, imaging) are
# presumably located there.
sys.path.append('../')

# Dataset

In [None]:
import pyuff_ustb as pyuff

# Path of the UFF recording used throughout the notebook.
# Alternative file: filepath = "../data/picmus.uff"
filepath = "../data/PICMUS_experiment_resolution_distortion.uff"

# Open the file and list what it contains before reading anything.
uff = pyuff.Uff(filepath)
print(uff)  # <- print the keys of the UFF file

# Pull out the two entries the rest of the notebook works with.
scan = uff.read("scan")
channel_data = uff.read("channel_data")

In [None]:
# Reference image stored in the file: read the flat sample vector, then fold it
# into a 2-D image. The (387, 609) shape is hard-coded for this file.
# NOTE(review): the dimensions could presumably be taken from `scan` - confirm.
beamformed_flat = np.array(uff.read("beamformed_data").data)
beamformed_data = beamformed_flat.reshape(387, 609)

# Show the magnitude, transposed so the first array axis runs horizontally.
fig, ax = plt.subplots()
ax.imshow(np.abs(beamformed_data).T, cmap='gray')
plt.show()

In [None]:
# Display the first entry of the `sequence` attached to the stored beamformed data.
uff.read("beamformed_data").sequence[0]

In [None]:
# Read the channel data again (it was already read right after opening the
# file) and display its repr as the cell output.
channel_data = uff.read("channel_data")
channel_data

In [None]:
# Sampling frequency of the recording; its reciprocal is used further down as
# the simulation time step.
channel_data.sampling_frequency

In [None]:
# Display the PRF attribute stored with the channel data.
channel_data.PRF

In [None]:
# Probe description; `N_elements` and `pitch` are read from it later to define
# the simulated transducer.
channel_data.probe

In [None]:
# First entry of the transmit sequence; the entry at `angle_idx` later supplies
# `source.azimuth` for the simulated plane wave.
channel_data.sequence[0]

In [None]:
# Copy the raw channel samples into a NumPy array. Later cells index it as
# rf_data[:, :, k] and take the number of time samples from rf_data.shape[0].
rf_data = np.array(channel_data.data)
# Show the array dimensions as the cell output.
rf_data.shape

In [None]:
# RF data for index 5 along the third axis, on a symmetric diverging colour
# scale clipped to +/- 1e-2.
fig, ax = plt.subplots()
rf_image = ax.imshow(
    rf_data[:, :, 5],
    cmap='seismic',
    aspect='auto',
    vmin=-1e-2,
    vmax=1e-2,
)
fig.colorbar(rf_image, ax=ax)
plt.show()

# Simulate data

In [None]:
# Linear-array transducer: element count and pitch come from the probe stored
# with the recording; frequency and magnitude are set by hand.
# NOTE(review): the earlier comment labelled this probe "P4-1" - confirm the model.
nelements = channel_data.probe.N_elements
element_pitch = channel_data.probe.pitch
transducer_frequency = 5.208e6  # frequency of the transducer [Hz]
transducer_magnitude = 1  # magnitude of the transducer [MPa]

# Distance between the first and last element centres [m].
transducer_extent = (nelements - 1) * element_pitch
print(f"Transducer extent: {transducer_extent:.3f} m")

In [None]:
# define spatial parameters
factor = 2  # grid points per element pitch
dx = np.array([element_pitch/factor, element_pitch/factor]) # grid spacing [m]
# round() rather than int(): pitch / dx can come out a hair below the intended
# integer in floating point, and truncation would then lose a grid point.
element_pitch_gridpoints = int(round(element_pitch / dx[0]))
transducer_extent_gridpoints = element_pitch_gridpoints * (nelements - 1)
# Element indices run from 0 to transducer_extent_gridpoints *inclusive*, so the
# first grid axis needs one point more than the extent. Previously the last
# element index was equal to N[0], i.e. outside the grid; JAX does not raise on
# out-of-bounds indices, so that element was silently mishandled.
N = np.array([transducer_extent_gridpoints + 1, 550]).astype(int)
pml = np.array([20, 20]) # size of the perfectly matched layer [grid points]

# define transducer position in domain
transducer_depth = pml[1] # depth of the transducer [grid points]
# Row 0: first-axis index of every element (exact integer multiples of the pitch
# in grid points). Row 1: the common second-axis index, one PML thickness in
# from the end of the grid.
element_positions = np.array([
    np.arange(nelements) * element_pitch_gridpoints,
    np.full(nelements, N[1] - transducer_depth),
], dtype=int)
# Guard against elements falling off the grid if the parameters above change.
assert element_positions.min() >= 0
assert element_positions[0].max() < N[0] and element_positions[1].max() < N[1]
element_positions

In [None]:
from jwave_utils import get_domain, get_homogeneous_medium

# Parameters of the homogeneous reference medium handed to the helper below.
c0 = channel_data.sound_speed  # speed of sound [m/s]
medium_params = dict(
    c0=c0,               # speed of sound [m/s]
    rho0=1000,           # density [kg/m^3]
    background_mean=1,   # mean of the background noise
    pml_size=pml[0],     # size of the perfectly matched layer [grid points]
)

# Build the simulation domain and the speed/density maps; background_std=0 is
# passed together with a fixed seed.
domain = get_domain(N, dx)
speed_homogenous, density_homogenous = get_homogeneous_medium(
    domain,
    background_std=0,
    background_seed=29,
    **medium_params,
)

# Overlay the element positions (converted to metres) on the speed map.
ext = [0, N[0]*dx[0], N[1]*dx[1], 0]
fig, ax = plt.subplots()
ax.scatter(
    element_positions[0]*dx[0],
    element_positions[1]*dx[1],
    c='r',
    marker='o',
    s=5,
    label='transducer element',
)
speed_image = ax.imshow(speed_homogenous.T, cmap='gray', extent=ext)
fig.colorbar(speed_image, ax=ax, label='Speed of sound [m/s]')
ax.set_xlabel('[m]')
ax.set_ylabel('[m]')
ax.legend(prop={'size': 7})
ax.invert_yaxis()
plt.show()

In [None]:
from jwave.geometry import TimeAxis, Medium
from jwave import FourierSeries
from jwave_utils import get_plane_wave_excitation

# Pick the middle transmit event. The count is taken from the third axis of
# rf_data (the axis indexed with angle_idx everywhere else) instead of the
# hard-coded 75, so the cell follows the loaded file.
angle_idx = rf_data.shape[2] // 2
angle = channel_data.sequence[angle_idx].source.azimuth  # fed to np.sin() later
dt = 1/channel_data.sampling_frequency
Nt = rf_data.shape[0]
t_end = Nt * dt

# Alternative: derive the time axis from the medium's stability limit.
# medium = Medium(domain, FourierSeries(jnp.expand_dims(speed_homogenous, -1), domain), FourierSeries(jnp.expand_dims(density_homogenous, -1), domain), pml_size=20)
# time_axis = TimeAxis.from_medium(medium, cfl=0.3)

# Simulation step is the recording's sample period divided by time_factor;
# simulated data is later subsampled with [::time_factor].
time_factor = 1
time_axis = TimeAxis(dt/time_factor, t_end)
t = time_axis.to_array()

sources, signal, carrier_signal = get_plane_wave_excitation(
    domain, time_axis, transducer_magnitude, transducer_frequency,
    element_pitch, element_positions, angle=angle, hann_window=False, tone=True)

# Excitation at index 10 of sources.signals, used for both diagnostic plots.
excitation = sources.signals[10]
# Time stamps built from the sample index so the x-axis really is in seconds
# (the trace used to be drawn against the sample index under a "Time [s]" label).
excitation_time = np.arange(len(excitation)) * time_axis.dt

fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Plot the time-domain signal
axs[0].plot(excitation_time, excitation)
axs[0].set_xlabel('Time [s]')
axs[0].set_ylabel('Amplitude [Pa]')
axs[0].set_title('Time-Domain Signal')

# Compute and plot the Fourier spectrum
signal_fft = np.fft.fft(excitation)
freq = np.fft.fftfreq(len(signal_fft), d=time_axis.dt)
axs[1].plot(freq, np.abs(signal_fft))
axs[1].set_xlabel('Frequency [Hz]')
axs[1].set_xlim(0, 10e6)
axs[1].set_ylabel('Magnitude')
axs[1].set_title('Fourier Spectrum')
plt.show()


In [None]:
from jwave_utils import get_data

# Run the forward simulation in the homogeneous medium. The helper returns the
# pressure field over time and the data recorded at the element positions.
pressure_homogenous, data_homogenous = get_data(
    speed_homogenous,
    density_homogenous,
    domain,
    time_axis,
    sources,
    element_positions,
)

In [None]:
from jwave.utils import show_field

# Snapshot of the simulated pressure field at one time index.
t_idx = 400
snapshot_time = time_axis.to_array()[t_idx]

fig, ax = plt.subplots()
ax.imshow(pressure_homogenous.params[t_idx])
ax.set_title(f"Pressure field at t={snapshot_time} seconds")
plt.show()

In [None]:
# Dimensions of the simulated recording; the next cell plots it with time
# points along the rows and transducer elements along the columns.
data_homogenous.shape

In [None]:
# Simulated element recordings: one column per element, one row per time point.
fig, ax = plt.subplots()
simulated_image = ax.imshow(data_homogenous, aspect='auto', cmap='seismic')
ax.set_xlabel('Transducer elements')
ax.set_ylabel('Time point')
fig.colorbar(simulated_image, ax=ax)
plt.show()

In [None]:
# Measured RF data for the selected transmit event, for comparison with the
# simulated recording.
fig, ax = plt.subplots()
measured_image = ax.imshow(rf_data[:, :, angle_idx], cmap='seismic', aspect='auto')
fig.colorbar(measured_image, ax=ax)
plt.show()

In [None]:
# Difference between measurement and simulation after scaling each by its own
# maximum; the simulation is subsampled back to the recording's sample rate.
measured = rf_data[:, :, angle_idx]
simulated = data_homogenous[::time_factor]
output_data = measured/np.max(measured) - simulated/np.max(data_homogenous)

fig, ax = plt.subplots()
residual_image = ax.imshow(output_data, cmap='seismic', aspect='auto')
fig.colorbar(residual_image, ax=ax)
plt.show()

# Reconstruction

## Naive

In [None]:
from beamforming_utils import get_receive_beamforming

# Delay between neighbouring elements for the steered plane wave, first in
# seconds and then expressed in simulation time steps.
delay_seconds = element_pitch * np.sin(angle) / c0
signal_delay = delay_seconds / time_axis.dt

# Beamform the measured RF data of the selected transmit event.
res = get_receive_beamforming(
    domain,
    time_axis,
    element_positions,
    rf_data[:, :, angle_idx],
    signal,
    carrier_signal,
    signal_delay,
)

In [None]:
# Beamformed result, transposed for display and with the vertical axis flipped.
# (An explicit vmin=-1e-7 / vmax=1e-7 range can be passed to imshow if needed.)
fig, ax = plt.subplots()
naive_image = ax.imshow(res.T, cmap='gray')
fig.colorbar(naive_image, ax=ax)
ax.invert_yaxis()
plt.show()

## ntk

In [None]:
from imaging.demodulate import demodulate_rf_to_iq

# Sampling rate implied by the simulation time step.
freq_sampling = 1/time_axis.dt

# Demodulate the measured RF data of the selected event to IQ, using the
# transducer frequency as carrier; the helper also returns the carrier it used.
iq_signals, freq_carrier = demodulate_rf_to_iq(
    rf_data[:, :, angle_idx],
    freq_sampling,
    freq_carrier=transducer_frequency,
)

In [None]:
# Grid dimensions and spacing taken from the simulation domain.
N0, N1 = int(domain.N[0]), int(domain.N[1])
dx0 = domain.dx[0]
# Number of depth samples: second-axis size minus the transducer depth offset.
Nz = N1 - transducer_depth

# Lateral axis centred on zero, depth axis starting at zero; both use dx0.
half_width = (N0//2)*dx0
x = np.linspace(-half_width, half_width, N0)
z = np.linspace(0, Nz*dx0, Nz)

# Pixel coordinates for the beamformer.
X, Z = np.meshgrid(x, z)

In [None]:
from imaging.beamform import beamform_delay_and_sum

# Delay-and-sum on the IQ data over the (X, Z) pixel grid, with all transmit
# delays set to zero.
beamformed_signal = beamform_delay_and_sum(
    iq_signals,
    X,
    Z,
    freq_sampling,
    freq_carrier,
    pitch=element_pitch,
    tx_delays=np.zeros(nelements),
)

In [None]:
# Work on a copy so the beamformed result itself stays untouched.
beamformed_signal_2 = beamformed_signal.copy()
# Optional: blank the first rows -> beamformed_signal_2[:25, :] = 0

# Magnitude of the beamformed IQ image.
fig, ax = plt.subplots()
das_image = ax.imshow(np.abs(beamformed_signal_2), cmap='gray')
fig.colorbar(das_image, ax=ax)
plt.show()


## Gradient

In [None]:
from jax import value_and_grad
from jwave_utils import get_data_only
from solver_utils import linear_loss, nonlinear_loss

# Starting point of the inversion: the homogeneous speed-of-sound map.
params = speed_homogenous

# compute first linear gradient (disabled)
# J = jax.jacrev(get_data_only, argnums=0)(jnp.array(speed), density_homogenous, domain, time_axis, sources, element_positions)
# linear_val_and_grad = value_and_grad(linear_loss, argnums=0)
# linear_lossval, linear_gradient = linear_val_and_grad(params, J, output_data)

# compute first nonlinear gradient
# argnums=0 -> differentiate with respect to the speed map passed as `params`.
nonlinear_val_and_grad = value_and_grad(nonlinear_loss, argnums=0)
# The scalar goes into `nonlinear_lossval`, not `nonlinear_loss`: reusing the
# function's name would rebind the imported loss function to a number and break
# any later call to it in this kernel.
nonlinear_lossval, nonlinear_gradient = nonlinear_val_and_grad(
    params, rf_data[:, :, angle_idx], density_homogenous, domain, time_axis,
    sources, element_positions)
# apply mask: zero the gradient for second-axis indices >= 440
nonlinear_gradient = nonlinear_gradient.at[:, 440:].set(0)
# nonlinear_gradient = nonlinear_gradient.at[:, :20].set(0) # apply mask
print(f"Nonlinear loss: {nonlinear_lossval}")

In [None]:
# Dimensions of the gradient taken with respect to the speed map (argnums=0).
nonlinear_gradient.shape

In [None]:
# Visualize the first gradient: transpose for display, drop the first 200 rows
# of the transposed array, and flip the vertical axis. Title, axis labels and
# a colorbar (shrink=0.55) are intentionally left off here.
fig, ax = plt.subplots()
ax.imshow(nonlinear_gradient.T[200:], cmap='gray')
ax.invert_yaxis()

In [None]:
from jax.example_libraries import optimizers
from tqdm import tqdm
from jwave.signal_processing import smooth

losshistory = []       # loss value after every step
reconstructions = []   # parameter estimate after every step
num_steps = 100

# Define optimizer
# Plain SGD with step size 1. Alternatives that were tried:
#   optimizers.adam(1), optimizers.momentum(1, 0.9)
init_fun, update_fun, get_params = optimizers.sgd(1)
opt_state = init_fun(params)

# Define and compile the update function
@jit
def update(opt_state, k):
    """Run one optimisation step.

    Args:
        opt_state: current optimizer state wrapping the speed-of-sound estimate.
        k: iteration counter forwarded to the optimizer's update function.

    Returns:
        Tuple ``(lossval, new_opt_state)`` with the loss at the current
        estimate and the optimizer state after the step.
    """
    v = get_params(opt_state)
    lossval, gradient = nonlinear_val_and_grad(v, rf_data[:, :, angle_idx], density_homogenous, domain, time_axis, sources, element_positions)
    # gradient = smooth(gradient)
    # Same mask as for the first gradient: no updates for second-axis indices >= 440.
    gradient = gradient.at[:, 440:].set(0)
    # Scale to unit peak magnitude so that, with step size 1, no grid point
    # changes by more than 1 per iteration. An all-zero gradient is left as is
    # rather than divided by zero, which would turn the parameters into NaN.
    peak = jnp.max(jnp.abs(gradient))
    gradient = gradient / jnp.where(peak > 0, peak, 1.0)
    return lossval, update_fun(k, gradient, opt_state)

# Main loop
pbar = tqdm(range(num_steps))
for k in pbar:
    lossval, opt_state = update(opt_state, k)

    ## For logging
    new_params = get_params(opt_state)
    reconstructions.append(new_params)
    losshistory.append(lossval)
    pbar.set_description("Loss: {}".format(lossval))

In [None]:
# Visualize the last iterate of the optimisation, transposed for display and
# with the vertical axis flipped.
fig, ax = plt.subplots(figsize=(8, 6))
reconstruction_image = ax.imshow(reconstructions[-1].T, cmap='gray')
ax.set_xlabel('x [gridpoints]')
ax.set_ylabel('y [gridpoints]')
ax.invert_yaxis()
fig.colorbar(reconstruction_image, ax=ax, shrink=0.55)
plt.show()