In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import jit
from jax.lib import xla_bridge

# Make modules from the parent directory importable (presumably where the
# project-local helpers such as jwave_utils live).
sys.path.append('../')

# Report which XLA backend (cpu / gpu / tpu) JAX picked up.
# NOTE(review): jax.lib.xla_bridge is deprecated in recent JAX releases;
# jax.default_backend() is the supported replacement — confirm the pinned JAX version.
print(f"Jax is using: {xla_bridge.get_backend().platform}")

# Simulate data

In [None]:
# Transducer model: 64 elements using the P4-1 element pitch.
# NOTE(review): the P4-1 is commonly listed as a phased-array probe — confirm
# whether the "linear" description in this notebook is intended.
nelements = 64                                        # number of simulated elements
element_pitch = 0.295e-3                              # centre-to-centre element spacing [m]
transducer_extent = element_pitch * (nelements - 1)   # first-to-last element distance [m]
transducer_frequency = 1.0e6                          # transducer/excitation frequency [Hz]
transducer_magnitude = 1.0e6                          # excitation amplitude [Pa]
print(f"Transducer extent: {transducer_extent:.3f} m")

In [None]:
# Simulation grid: one grid point per element pitch in both directions.
N = np.full(2, 128, dtype=int)     # grid size [grid points]
dx = np.full(2, element_pitch)     # grid spacing [m]
pml = np.full(2, 20)               # thickness of the perfectly matched layer [grid points]

# Element placement (grid indices): centred along axis 0 (x), and sitting
# `transducer_depth` points in from the high-index edge of axis 1 (y).
transducer_depth = pml[1]
transducer_x_start = N[0] // 2 - nelements // 2
element_positions = np.array([
    np.arange(transducer_x_start, transducer_x_start + nelements),   # x index of each element
    np.full(nelements, N[1] - transducer_depth),                      # common y index
], dtype=int)

In [None]:
from jwave_utils import get_domain, get_point_medium, get_homogeneous_medium

# Two media on the same grid: one containing a point scatterer, and a homogeneous
# reference whose recording is later subtracted from the scatterer recording.
domain = get_domain(N, dx)
medium = get_point_medium(domain, c0=1500, rho0=1000, scatterer_radius=2, pml_size=pml[0])
medium_homogenous = get_homogeneous_medium(domain, c0=1500, rho0=1000, pml_size=pml[0])

# imshow puts array axis 1 on the horizontal axis and axis 0 on the vertical axis,
# so the extent is [0, N1*dx1] horizontally and [N0*dx0 (bottom), 0 (top)] vertically.
# This keeps the image consistent with the scatter below (y on x-axis, x on y-axis)
# even for non-square grids.
ext = [0, N[1] * dx[1], N[0] * dx[0], 0]

fig, ax = plt.subplots()
ax.scatter(element_positions[1] * dx[1], element_positions[0] * dx[0],
           c='r', marker='o', s=5, label='transducer element')
im = ax.imshow(medium.sound_speed.params, cmap='gray', extent=ext)
fig.colorbar(im, ax=ax, label='Speed of sound [m/s]')
ax.set_xlabel('y [m]')
ax.set_ylabel('x [m]')
ax.set_title('Sound speed and transducer layout')
ax.legend(prop={'size': 7})
ax.invert_yaxis()
plt.show()

In [None]:
from jwave.geometry import TimeAxis
from jwave_utils import get_plane_wave_excitation

# Delay parameter forwarded to the excitation builder; presumably 0 fires all
# elements together (it also makes the beamforming slope 0).
delay_s = 0
time_axis = TimeAxis.from_medium(medium, cfl=0.3)
sources, ss, s = get_plane_wave_excitation(
    domain,
    time_axis,
    transducer_magnitude,
    transducer_frequency,
    element_positions,
    delay_s=delay_s,
)

# Excitation signal of the first source.
fig, ax = plt.subplots()
ax.plot(sources.signals[0])
ax.set_xlabel('Time point')
ax.set_ylabel('Amplitude [Pa]')
plt.show()

In [None]:
from jwave_utils import get_data

# Forward-simulate the same excitation through both media. Only the recorded data
# of the homogeneous run is kept; its pressure field is discarded.
sim_args = (time_axis, sources, element_positions)
pressure, data = get_data(medium, *sim_args)
_, data_homogenous = get_data(medium_homogenous, *sim_args)
del sim_args

In [None]:
from jwave.utils import show_field

t_idx = 200  # time step whose pressure field is displayed
show_field(pressure[t_idx])
snapshot_time = time_axis.to_array()[t_idx]
plt.title(f"Pressure field at t={snapshot_time} seconds")
plt.show()

In [None]:
# 'seismic' is a diverging colormap: use symmetric limits so white means zero.
# (`or 1.0` avoids degenerate limits if the data is identically zero.)
data_vlim = float(np.max(np.abs(np.asarray(data)))) or 1.0
fig, ax = plt.subplots()
im = ax.imshow(data, aspect='auto', cmap='seismic', vmin=-data_vlim, vmax=data_vlim)
fig.colorbar(im, ax=ax, label='Amplitude')
ax.set_xlabel('Transducer elements')
ax.set_ylabel('Time point')
ax.set_title('Recorded data (medium with scatterer)')
plt.show()

In [None]:
# Subtracting the homogeneous-medium recording removes everything common to both
# runs, presumably leaving mainly the echoes from the scatterer.
output_data = data - data_homogenous

# Diverging colormap centred on zero (see symmetric limits below).
scatter_vlim = float(np.max(np.abs(np.asarray(output_data)))) or 1.0
fig, ax = plt.subplots()
im = ax.imshow(output_data, aspect='auto', cmap='seismic', vmin=-scatter_vlim, vmax=scatter_vlim)
fig.colorbar(im, ax=ax, label='Amplitude')
ax.set_xlabel('Transducer elements')
ax.set_ylabel('Time point')
ax.set_title('Scatterer minus homogeneous recording')
plt.show()

# Beamforming

In [None]:
import torch

# Use the first GPU when available; otherwise run on the CPU instead of crashing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c0 = 1500  # speed of sound assumed by the beamformer [m/s]
# Lateral shift (grid points) of the transmit wavefront per grid point of depth;
# 0 when delay_s == 0.
slope = delay_s * time_axis.dt * c0 / dx[0].item()
# Transmit path length per unit depth for the tilted wavefront: sqrt(1 + slope^2).
oblique_factor = (slope * slope + 1) ** 0.5


def _to_torch(array) -> torch.Tensor:
    """Copy a numpy/jax array into a new tensor on `device`.

    np.array makes a private, writable copy. torch.from_numpy shares memory, and
    np.asarray of a jax array may be a read-only view, so without the copy the
    in-place normalisations below could write into jax-owned memory on CPU.
    """
    return torch.from_numpy(np.array(array)).to(device)


# Transmit waveform of every element (one row per element), divided by its maximum.
s_stack_t = _to_torch(jnp.vstack([ss[i] for i in range(element_positions.shape[1])]))
s_stack_t /= torch.max(s_stack_t).item()


def compute_time_delays_for_point(x1: np.ndarray, x: int, delta_y: int, c: float = c0):
    """Round-trip delay (in time steps) from the array to a point and back to each element.

    Args:
        x1: x grid indices of the receiving elements.
        x: x grid index of the image point.
        delta_y: depth of the point below the array [grid points].
        c: speed of sound [m/s].

    Returns:
        List of integer delays, one per entry of ``x1``: transmit path
        (depth * oblique_factor) plus the straight receive path, divided by c * dt.
    """
    scaled_y = delta_y * dx[1]
    scaled_dx = (x1 - x) * dx[0]
    transmit_path = scaled_y * oblique_factor
    receive_path = np.sqrt(scaled_y * scaled_y + scaled_dx * scaled_dx)
    return np.round((transmit_path + receive_path) / (c * time_axis.dt)).astype(int).tolist()


element_y = (N[1] - transducer_depth)  # y grid index shared by all elements

s_t = _to_torch(s)
s_t /= torch.max(s_t).item()
output_data_t = _to_torch(output_data)  # (time, element) scattered-field recordings
element_positions_t = torch.from_numpy(element_positions[0]).to(device)


def compute_torch_signal(pt_x, pt_y):
    """Delay-and-sum value of grid point (pt_x, pt_y), correlating with the transmit waveform.

    For every element within the receive aperture (|element_x - source_x| < depth,
    so the aperture widens with depth), the element's trace is advanced by its
    round-trip delay and dot-multiplied with the transmit waveform of the element
    the wavefront through this point started from. The sum is scaled by dt.

    Returns:
        A 1-element tensor on ``device``; zero when that starting element lies
        outside the array.
    """
    delta_y = abs(element_y - pt_y)
    delays = compute_time_delays_for_point(element_positions[0], pt_x, delta_y)
    signal = torch.tensor([0.0], device=device)

    # x position (grid points) where the tilted wavefront through this point left the array.
    source_x = pt_x - slope * delta_y
    slanted_x_coord = int(source_x)
    if not transducer_x_start <= slanted_x_coord <= transducer_x_start + nelements - 1:
        return signal

    tx_waveform = s_stack_t[slanted_x_coord - transducer_x_start]
    n_tx = tx_waveform.shape[0]
    n_rx = output_data_t.shape[0]
    for i, delta in enumerate(delays):
        if abs(element_positions[0][i] - source_x) >= delta_y:
            continue  # outside the receive aperture
        # Align tx[t] with rx[t + delta] using an explicit length; a `[:-delta]`
        # slice would be empty for delta == 0 and mismatch the receive slice.
        n = min(n_tx, n_rx - delta)
        if n <= 0:
            continue
        signal += torch.dot(tx_waveform[:n], output_data_t[delta:delta + n, i])
    return signal * time_axis.dt

In [None]:
# Sanity display of the jwave domain's grid size (presumably equal to N, since the
# domain was built with get_domain(N, dx)).
domain.N

In [None]:
def get_results():
    """Beamform every grid point and return the image as a float array.

    Entry [i, j] is ``compute_torch_signal(i, j)``: axis 0 is x and axis 1 is y,
    covering the full N[0] x N[1] simulation grid rather than a hard-coded size.
    """
    n_x, n_y = int(N[0]), int(N[1])
    return np.array([[compute_torch_signal(i, j).item() for j in range(n_y)] for i in range(n_x)])


res = get_results()

In [None]:
from kwave.utils.filters import gaussian_filter
from kwave.reconstruction.beamform import envelope_detection


def _filter_and_envelope_row(row, fs, center_frequency, bandwidth):
    """Gaussian band-pass one image row around `center_frequency`, then return its envelope."""
    filtered = gaussian_filter(row, fs, center_frequency, bandwidth)
    return envelope_detection(filtered)


def postprocess_result(orig_res, bandwidth=100.0):
    """Band-pass filter and envelope-detect each row of a beamformed image, then plot it.

    Each row (fixed x) is a profile along y sampled every dx[1] metres, so the
    filter works in spatial-frequency units. The sampling rate is 1 / dx[1] [1/m],
    and the carrier sits at 2 * transducer_frequency / c0 [1/m], because moving the
    focal point by d changes the round-trip path by about 2 d. Sampling rate and
    centre frequency must share units, otherwise the Gaussian lies far outside the
    ±fs/2 band and degenerates into a constant gain.

    Args:
        orig_res: 2-D beamformed image (axis 0 = x, axis 1 = y); not modified.
        bandwidth: bandwidth argument forwarded to kwave's gaussian_filter
            (a percentage of the centre frequency in k-Wave).

    Returns:
        The envelope image, same shape as ``orig_res``.
    """
    result = np.copy(orig_res)
    fs = 1 / dx[1]                                      # spatial sampling rate [1/m]
    center_frequency = 2 * transducer_frequency / c0    # round-trip spatial carrier [1/m]
    for i in range(result.shape[0]):
        result[i, :] = _filter_and_envelope_row(result[i, :], fs, center_frequency, bandwidth)

    fig, ax = plt.subplots()
    im = ax.imshow(result, cmap='viridis', interpolation='nearest')
    fig.colorbar(im, ax=ax)
    ax.set_xlabel('y [grid points]')
    ax.set_ylabel('x [grid points]')
    ax.set_title('Envelope of beamformed image')
    plt.show()
    return result


a = postprocess_result(res)