# Patient-Adaptive Focused Transmit Beamforming using Cognitive Ultrasound

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tue-bmd/ulsa/blob/main/docs/source/notebooks/agent/agent_example.ipynb)
&nbsp;
[![View on GitHub](https://img.shields.io/badge/GitHub-View%20Source-blue?logo=github)](https://github.com/tue-bmd/ulsa/blob/main/docs/source/notebooks/agent/agent_example.ipynb)
&nbsp;
[![Hugging Face model](https://img.shields.io/badge/Hugging%20Face-Model-yellow?logo=huggingface)](https://huggingface.co/zeahub/ulsa)

In [None]:
%%capture
%pip install zea

In [None]:
import os

# Select the JAX backend for Keras and silence TensorFlow's C++ logging.
# These are set here, before keras is imported in the next cell.
os.environ.update(KERAS_BACKEND="jax", TF_CPP_MIN_LOG_LEVEL="3")

In [None]:
import zea
import jax
import keras
import numpy as np
import scipy
from matplotlib import animation
import matplotlib.pyplot as plt
from tqdm import tqdm
from keras.ops import convert_to_numpy
from IPython.display import HTML

In [None]:
from ulsa.agent import setup_agent, Recover, AgentMask, reset_agent_state
from ulsa.pipeline import make_pipeline
from ulsa.ops import lines_rx_apo, Squeeze, Copy
from ulsa.utils import (
    update_scan_for_polar_grid,
    FOCUSED_TRANSMITS,
    load_subsampled_data,
    copy_transmits_from_scan,
    get_subsampled_parameters,
    precompute_dynamic_range,
    scan_sequence,
)
from ulsa.io_utils import postprocess_agent_results

We use the GPU if one is available: `init_device` picks the best available device. Optionally, we also set the matplotlib style used for plotting.

In [None]:
# Pick the compute device automatically, without printing device information.
zea.init_device(verbose=False)
# Optional: apply zea's matplotlib style to the figures below.
zea.visualize.set_mpl_style()

## Prepare data

In [None]:
# Which data stored in the file to process; "data/raw_data" enables the
# beamforming setup in the "Targets" section.
data_type = "data/raw_data"
# The default is a hard-coded absolute path that only exists on the original
# machine. Set the ULSA_DATA_FILE environment variable to point at your own
# copy of the HDF5 file (e.g. when running on Colab).
file_path = os.environ.get(
    "ULSA_DATA_FILE",
    "/mnt/USBMD_datasets/2024_USBMD_cardiac_S51/HDF5/20240701_P1_A4CH_0001.hdf5",
)
file = zea.File(file_path)

## Agent

In [None]:
# Fixed PRNG key, so random draws made with it are repeatable across runs.
seed = jax.random.PRNGKey(42)
# Config path is resolved relative to the current working directory.
agent_config = zea.Config.from_yaml("configs/cardiac_112_3_frames.yaml")
# The agent is not compiled on its own; it runs inside the pipeline, which is jitted.
agent, _ = setup_agent(agent_config, seed, jit_mode="off")  # the pipeline will jit

# Presumably the number of lines the agent selects per frame; it is passed to the
# data loader and to `get_subsampled_parameters` in the acquisition loop.
n_actions = agent_config.action_selection.n_actions
# Settings used by `postprocess_agent_results` when preparing frames for display.
io_config = agent_config.io_config

In [None]:
# Load the scan parameters stored in the file.
scan = file.scan()

# For cardiac recordings, keep the FOCUSED_TRANSMITS subset of transmits and
# switch the scan to a polar imaging grid.
if "cardiac" in str(file.path):
    scan = copy_transmits_from_scan(scan, FOCUSED_TRANSMITS)
    # Return value unused: presumably modifies `scan` in place.
    update_scan_for_polar_grid(scan)

# Take the scan's dynamic range from the agent config (`data.image_range`).
scan.dynamic_range = agent_config.data.image_range

## Targets

In [None]:
# Pipeline used to reconstruct the reference ("target") frames. It outputs
# images in the agent's input range at the action-selection shape.
pipeline = make_pipeline(
    data_type=data_type,
    output_range=agent.input_range,
    output_shape=agent_config.action_selection.shape,
    action_selection_shape=agent_config.action_selection.shape,
)
# Remove the trailing size-1 axis from the pipeline output.
pipeline.append(Squeeze(axis=-1))

# For raw data we need to prepare some beamforming settings
if data_type == "data/raw_data":
    # FIR band-pass filter with a passband of 0.5x-1.5x the center frequency,
    # designed at the sampling frequency of the RF data.
    bandpass_rf = scipy.signal.firwin(
        numtaps=128,
        cutoff=np.array([0.5, 1.5]) * scan.center_frequency,
        pass_zero="bandpass",
        fs=scan.sampling_frequency,
    )
    # Per-line receive apodization weights (presumably) on the imaging grid.
    rx_apo = lines_rx_apo(scan.n_tx_total, scan.grid_size_z, scan.grid_size_x)
    bandwidth = 2e6  # presumably in Hz (2 MHz)

    params = pipeline.prepare_parameters(
        scan=scan, bandpass_rf=bandpass_rf, rx_apo=rx_apo, bandwidth=bandwidth, minval=0
    )
    params |= precompute_dynamic_range(file, scan, params)
else:
    # No beamforming settings are needed. `rx_apo` must still be bound, because
    # the active-perception loop passes it to `get_subsampled_parameters`;
    # without this, any other data_type fails there with a NameError.
    rx_apo = None
    params = {}

In [None]:
# Load the first `n_frames` frames with an all-ones selection mask (every line
# acquired); these fully sampled frames serve as the reference targets.
n_frames = 30
data = load_subsampled_data(
    file,
    data_type,
    slice(0, n_frames),
    np.ones(agent_config.action_selection.n_possible_actions),
    agent_config.action_selection.n_possible_actions,
)
# Run the target pipeline over the loaded sequence.
targets = scan_sequence(data, pipeline, params)

# Convert the pipeline output (range [-1, 1]) into frames for display.
targets = postprocess_agent_results(
    targets, io_config, scan_convert_order=0, image_range=[-1, 1]
)

# Show every second target frame, each titled with its frame index.
fig, _ = zea.visualize.plot_image_grid(
    targets[::2],
    titles=[f"t={t}" for t in range(0, len(targets), 2)],
    remove_axis=True,
    vmin=0,
    vmax=255,
)

## Ultrasound pipeline + active perception

In [None]:
# Active-perception pipeline. It processes the subsampled data at the agent's
# input shape, masks it, keeps a copy of the masked measurement, lets the agent
# recover the full frame, and finally crops to the action-selection shape.
pipeline = make_pipeline(
    data_type=data_type,
    output_range=agent.input_range,
    output_shape=agent.input_shape,
    action_selection_shape=agent_config.action_selection.shape,
)

# Make sure the subsampled measurements are masked in the right way
pipeline.append(AgentMask())

# Copy the measurement to another key
# (this happens before the crop below, so "measurement" stays uncropped)
pipeline.append(Copy(output_key="measurement"))

# Recover the subsampled data
pipeline.append(Recover(agent, hard_projection=True))

# Crop to the right shape
# `post_process` is also reused after the loop to crop measurements and variance.
post_process = keras.layers.CenterCrop(*agent_config.action_selection.shape)
pipeline.append(zea.ops.Lambda(post_process))

## Active perception loop

In [None]:
# Initialize the agent state
agent_state = reset_agent_state(agent, seed)
params["agent_state"] = agent_state
# Lines to acquire for the first frame come from the freshly reset state.
selected_lines = agent_state.selected_lines

# Initialize lists
reconstructions = []
measurements = []
belief_distributions = []

# Closed loop: acquire frame i using the lines chosen at the previous step,
# reconstruct it, then take the agent's selection for the next frame.
for i in tqdm(range(n_frames)):
    # Load only the currently selected lines of frame i.
    selected_data = load_subsampled_data(file, data_type, i, selected_lines, n_actions)
    # Parameters matching the selected lines, prepared for the pipeline.
    subsampled_params = get_subsampled_parameters(
        data_type, scan, selected_lines, rx_apo, n_actions
    )
    subsampled_params = pipeline.prepare_parameters(**subsampled_params)

    # Skip some parameters from the scan class
    # (dropping `dynamic_range` keeps it from overriding `params` in the merge below)
    subsampled_params.pop("dynamic_range", None)

    # Run pipeline; on key clashes, the per-frame parameters override `params`.
    output = pipeline(data=selected_data, **{**params, **subsampled_params})

    # Load data from the output
    reconstruction = output["data"]
    selected_lines = output["agent_state"].selected_lines

    # Keep some keys for the next iteration
    params["agent_state"] = output["agent_state"]

    # Store the reconstruction
    # (plus the masked measurement and belief distribution, for visualization)
    reconstructions.append(convert_to_numpy(reconstruction))
    measurements.append(convert_to_numpy(output["measurement"]))
    belief_distributions.append(convert_to_numpy(output["agent_state"].belief_distribution))

# Stack the per-frame results along a new leading frame axis.
reconstructions = np.stack(reconstructions)
measurements = np.stack(measurements)
belief_distributions = np.stack(belief_distributions)

In [None]:
# Convert the recorded sequences into frames for display.
# The reconstructions were already cropped inside the pipeline.
reconstructions = postprocess_agent_results(
    np.squeeze(reconstructions, -1),
    io_config,
    scan_convert_order=0,
    image_range=[-1, 1],
    reconstruction_sharpness_std=io_config.get("reconstruction_sharpness_std", 0.0),
)
# The "measurement" copy was taken before the pipeline's crop, so crop it here.
# Convert to numpy so every input to postprocess_agent_results has the same
# array type (reconstructions and variance are numpy as well).
measurements = postprocess_agent_results(
    convert_to_numpy(keras.ops.squeeze(post_process(measurements), -1)),
    io_config,
    scan_convert_order=0,
    image_range=[-1, 1],
)
# Per-pixel variance over axis 1 of the stacked belief distributions (axis 0 is
# the frame axis; axis 1 is presumably the posterior-sample axis), then crop.
variance = keras.ops.var(belief_distributions, axis=1)
variance = convert_to_numpy(keras.ops.squeeze(post_process(variance), -1))
# Cap the display range at the 99.5th percentile so a few extreme values do
# not compress the rest of the colormap.
variance = postprocess_agent_results(
    variance,
    io_config,
    scan_convert_order=0,
    image_range=[0, np.percentile(variance, 99.5)],
)

In [None]:
# Side-by-side view of the first frame; `ims` presumably holds one image
# artist per panel, in the order given here.
fig, ims = zea.visualize.plot_image_grid(
    [sequence[0] for sequence in (targets, measurements, reconstructions, variance)],
    titles=["Target", "Measurements", "Reconstruction", "Variance"],
    ncols=4,
    vmin=0,
    vmax=255,
    cmap=["gray", "gray", "gray", "inferno"],
    figsize=(11, 4),
)


def update(frame):
    """Show frame ``frame`` in every panel and return the artists to redraw."""
    panels = (targets, measurements, reconstructions, variance)
    for idx, sequence in enumerate(panels):
        ims[idx].set_array(sequence[frame])
    return ims


# One animation frame per target frame, 100 ms apart.
ani = animation.FuncAnimation(
    fig, update, frames=len(targets), blit=True, interval=100
)
# Close the static figure so that only the embedded animation is shown.
plt.close(fig)
HTML(ani.to_jshtml(embed_frames=True, default_mode="loop"))