In [None]:
from collections.abc import Callable
from pathlib import Path

import jax
import matplotlib.pyplot as plt
import numpy as np
import zea
from tqdm import tqdm
from zea.data.convert.verasonics import VerasonicsFile

import ulsa.ops

# Device setup via zea before any processing runs (presumably selects the
# accelerator used by the jax/keras backend — see zea.init_device).
zea.init_device()

In [None]:
# Acquisition to process: `acq` is the file stem and `date` the acquisition day;
# both are used to build the zea (.hdf5) and Verasonics (.mat) paths below.
# Uncomment one (acq, date) pair and comment out the active one to switch datasets.
# acq = "20251219_lever_wessel3_0000"
# date = "2025-12-19"
# acq = "20251211_cirs_line_dw_hi_0001"
# date = "2025-12-11"
acq = "20251222_s3_a4ch_line_dw_0000"
date = "2025-12-22"

In [None]:
# Load the converted zea file for the selected acquisition and configure the
# scan (grid, transmits, frequencies, lens) used by the beamforming pipeline.
path = f"/mnt/z/usbmd/Wessel/Verasonics/{date}_zea/{acq}.hdf5"
path = Path(path)
# When True, processing is centred on the demodulation frequency instead of the
# stored center frequency (see the `if harmonic_imaging` branch below).
harmonic_imaging = True

# NOTE(review): this file handle is never closed; later cells keep using `file`.
file = zea.File(path)
# transmits = slice(None)
frame_nr = -1  # negative index: presumably the last frame in the file
scan = file.scan()
# Split transmits by focus: positive focus distance = focused (line) transmits,
# otherwise unfocused (e.g. diverging-wave) transmits.
focused_transmits = np.where(scan.focus_distances > 0)[0]
unfocused_transmits = np.where(scan.focus_distances <= 0)[0]
selected_transmits = focused_transmits
scan.grid_type = "polar"
scan.grid_size_x = 336
# scan.grid_size_z = 112
# Options for the pressure-field weighting used by zea.ops.PfieldWeighting.
scan.pfield_kwargs = {
    "downsample": 1,
    "downmix": 1,
    "percentile": 1,
    "alpha": 0.5,
    "norm": False,
}
# NOTE(review): computed BEFORE set_transmits below, so these limits may span the
# angles of all transmits (incl. unfocused ones), not only the selected ones — confirm.
scan.polar_limits = (scan.polar_angles.min(), scan.polar_angles.max())
scan.set_transmits(selected_transmits)
scan.pixels_per_wavelength = 2
if harmonic_imaging:
    print(
        f"Harmonic imaging: Setting transmit frequency to {scan.demodulation_frequency*1e-6:.2f} MHz"
    )
    # Downstream steps (band-pass filter, TGC, pipeline) read scan.center_frequency.
    scan.center_frequency = scan.demodulation_frequency
scan.f_number = 0.3
scan.zlims = (0, 0.15)  # presumably metres (0–15 cm depth)

# Avoid the first pixel being at r = 0
# NOTE: the pixel step needs parentheses around the range:
# dr = (scan.zlims[1] - scan.zlims[0]) / scan.grid_size_z
# scan.zlims = (scan.zlims[0] + dr, scan.zlims[1] + dr)

# scan.focus_distances = scan.focus_distances * scan.wavelength  # in meters
# NOTE(review): removes the stored `n_ch` attribute, presumably so it is derived
# elsewhere (e.g. from the data) — confirm why this is required.
delattr(scan, "n_ch")

# Band-pass filter of width `width` around the (possibly harmonic) center frequency;
# passed to the pipeline as "bpf" for zea.ops.FirFilter. 128 is presumably the tap count.
center = scan.center_frequency
width = 2e6  # Hz
f1 = center - width / 2
f2 = center + width / 2
bpf = zea.func.get_band_pass_filter(128, scan.sampling_frequency, f1, f2)

# Specific to Philips S5-1
# Lens parameters (presumably metres and m/s).
scan.apply_lens_correction = True
scan.lens_thickness = 1e-3
scan.lens_sound_speed = 1000

# Raw data for a single frame, restricted to the selected transmits.
raw_data = file.load_data("raw_data", (frame_nr, selected_transmits))
# Frame-rate estimate used for the exported videos. Assumes every one of the
# n_tx_total transmits takes the first time_to_next_transmit interval.
# NOTE(review): the factor 2 is not explained here (two firings per line? ) — confirm.
fps = 1/ (scan.time_to_next_transmit[0][0] * (scan.n_tx_total)*2)

In [None]:
# Transmit origins are read from the original Verasonics .mat file (TX.Origin) and
# scaled by vf.wavelength (presumably stored in wavelengths on the Verasonics side).
vera_path = f"/mnt/z/usbmd/Wessel/Verasonics/{date}/{acq}.mat"

with VerasonicsFile(vera_path) as vf:
    all_origins = np.asarray(vf.dereference_all(vf["TX"]["Origin"]))
    # One row per transmit; reshape instead of squeeze so a single transmit
    # still yields a 2-D array.
    all_origins = all_origins.reshape(len(all_origins), -1)  # (n_tx_total, 3)
    # Select the same transmits as `scan` / `raw_data`. Slicing the first
    # scan.n_tx entries is only correct when the selected (focused) transmits
    # happen to be the leading ones in the TX table.
    origins = all_origins[selected_transmits] * vf.wavelength  # (n_tx, 3)

In [None]:
class UndoTGC(zea.ops.Operation):
    """Divide the input along ``axis`` by a TGC gain curve.

    Presumably used to remove a time-gain compensation that was applied at
    acquisition time. The curve is supplied at call time as ``tgc_gain_curve``.

    Args:
        axis: Axis passed to ``zea.func.apply_along_axis``.
        **kwargs: Forwarded to ``zea.ops.Operation``.
    """

    def __init__(self, axis, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, tgc_gain_curve, **kwargs):
        def _remove_gain(trace):
            # Element-wise division of one slice by the gain curve.
            return trace / tgc_gain_curve

        compensated = zea.func.apply_along_axis(_remove_gain, self.axis, kwargs[self.key])
        return {self.output_key: compensated}


class ApplyAlongAxis(zea.ops.Operation):
    """Apply an arbitrary function to the input along ``axis``.

    The input is read from ``kwargs[self.key]`` and the result is returned
    under ``self.output_key``.

    Args:
        axis: Axis passed to ``zea.func.apply_along_axis``.
        fn: Function applied along ``axis``.
        **kwargs: Forwarded to ``zea.ops.Operation``.

    Raises:
        TypeError: If ``fn`` is not callable.
    """

    def __init__(self, axis, fn: Callable, **kwargs):
        super().__init__(**kwargs)
        # Fail at construction time instead of deep inside the pipeline call.
        if not callable(fn):
            raise TypeError(f"fn must be callable, got {type(fn).__name__}")
        self.axis = axis
        self.fn = fn

    def call(self, **kwargs):
        data = kwargs[self.key]
        data = zea.func.apply_along_axis(self.fn, self.axis, data)
        return {self.output_key: data}


In [None]:
# Attenuation-based TGC curve. Only consumed through `tgc_fn`, which is used by the
# commented-out ApplyAlongAxis step in the pipeline below.
attenuation_coef = 0.3  # dB/cm/MHz
tgc_curve = zea.func.make_tgc_curve(
    scan.n_ax, attenuation_coef, scan.sampling_frequency, scan.center_frequency, scan.sound_speed
)


def tgc_fn(x):
    """Multiply ``x`` by ``tgc_curve`` (intended for ``ApplyAlongAxis``)."""
    return x * tgc_curve


# scan.distance_to_apex = 0.0
# Receive apodization for the focused transmits; only used when the commented-out
# ulsa.ops.Multiply("rx_apo") step (and matching Map argnames) is enabled.
rx_apo = ulsa.ops.lines_rx_apo(
    n_tx=len(focused_transmits), grid_size_z=scan.grid_size_z, grid_size_x=scan.grid_size_x
)
pipeline = zea.Pipeline(
    [
        zea.ops.FirFilter(axis=-3, filter_key="bpf"),
        ulsa.ops.WaveletDenoise(),  # optional
        # UndoTGC(axis=-3),
        # ApplyAlongAxis(axis=-3, fn=tgc_fn),
        zea.ops.Demodulate(),
        zea.ops.LowPassFilter(complex_channels=True, axis=-2),  # optional
        zea.ops.Downsample(2),
        zea.ops.Map(
            [
                zea.ops.TOFCorrection(),
                zea.ops.PfieldWeighting(),
                # ulsa.ops.Multiply("rx_apo"),
                zea.ops.DelayAndSum(),
            ],
            # With rx_apo enabled use:
            #   argnames=["flatgrid", "flat_pfield", "rx_apo"], in_axes=(0, 0, 1)
            argnames=["flatgrid", "flat_pfield"],
            # One entry per argname; the previous (0, 0, 1) carried a stale
            # third entry from the rx_apo variant.
            in_axes=(0, 0),
            chunks=100,
        ),
        zea.ops.ReshapeGrid(),
        zea.ops.EnvelopeDetect(),
        zea.ops.Normalize(),
        zea.ops.LogCompress(clip=False),
        # zea.ops.keras_ops.ExpandDims(axis=-1),
        # zea.ops.keras_ops.Resize(size=(112, 112)),
        # zea.ops.keras_ops.Squeeze(axis=-1),
        zea.ops.ScanConvert(order=1),
    ],
    with_batch_dim=False,
    # jit_options=None,
)

params = pipeline.prepare_parameters(scan=scan, bpf=bpf, origins=origins, bandwidth=2e6, rx_apo=rx_apo)
output = pipeline(data=raw_data, **params)
data = output["data"]

# Display the log-compressed result with a 60 dB dynamic range and save it.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(data, cmap="gray", vmin=-60, vmax=0, interpolation="nearest")
ax.set_title(f"zea reconstruction (frame {frame_nr})")
fig.savefig("processed_image.png", dpi=300, bbox_inches="tight")

In [None]:
def raw_to_mp4(
    file: zea.File,
    video_path,
    pipeline: zea.Pipeline,
    scan: zea.Scan,
    transmits=None,
    frames=None,
    fps=None,
    dynamic_range=(-60, 0),
    **kwargs,
):
    """Run ``pipeline`` on every selected frame of ``file`` and save a video.

    Args:
        file: zea file whose "raw_data" dataset is processed.
        video_path: Output path passed to ``zea.io_lib.save_video``.
        pipeline: Processing pipeline; must be built with ``with_batch_dim=False``
            because frames are processed one at a time.
        scan: Scan passed to ``pipeline.prepare_parameters``.
        transmits: Transmit selection (default: all transmits).
        frames: Frame selection (default: all frames).
        fps: Frame rate forwarded to ``save_video``; when None, the
            ``save_video`` default is used.
        dynamic_range: (min, max) range passed to ``zea.display.to_8bit``.
        **kwargs: Extra parameters forwarded to ``pipeline.prepare_parameters``.

    Raises:
        ValueError: If the selection contains no frames.
    """
    assert pipeline.with_batch_dim is False, "Pipeline must be without batch dimension"
    transmits = slice(None) if transmits is None else transmits
    frames = slice(None) if frames is None else frames

    raw_data_frames = file.load_data("raw_data", (frames, transmits))
    # Parameters do not change between frames, so prepare them once.
    params = pipeline.prepare_parameters(scan=scan, **kwargs)

    frames_8bit = []
    for raw_frame in tqdm(raw_data_frames, desc="processing frames"):
        image = pipeline(data=raw_frame, **params)["data"]
        frames_8bit.append(zea.display.to_8bit(image, dynamic_range, pillow=False))

    if not frames_8bit:
        raise ValueError("No frames selected; nothing to write.")

    save_kwargs = {} if fps is None else {"fps": fps}
    zea.io_lib.save_video(jax.numpy.stack(frames_8bit), video_path, **save_kwargs)


name = path.stem + "_zea.mp4"
raw_to_mp4(
    file,
    path.parent / name,
    pipeline,
    scan,
    transmits=selected_transmits,
    fps=fps,
    bpf=bpf,
    origins=origins,
    bandwidth=2e6,
)

In [None]:
# Intentional stop so that "Run All" does not continue into the cells below;
# run them manually. (A bare top-level `break` is a SyntaxError and breaks
# exporting this notebook as a script.)
raise SystemExit("Stopping here: run the cells below manually.")

In [None]:
# Show the image stored alongside the raw data for the same frame: prefer the
# Verasonics image buffer when present, otherwise fall back to the "image" dataset.
VERA_IMAGE_KEY = "non_standard_elements/verasonics_image_buffer"

fig, ax = plt.subplots(figsize=(10, 5))

if VERA_IMAGE_KEY in file:
    # NOTE(review): index 0 of the second axis and the transpose are assumed to
    # match the layout of this buffer — confirm against the converter.
    vera_image = file[VERA_IMAGE_KEY][frame_nr][0].T
else:
    vera_image = file.load_data("image", frame_nr)
ax.imshow(vera_image, cmap="gray", vmin=-60, vmax=0)
ax.set_title(f"Verasonics image (frame {frame_nr})")
# fig.savefig("vera_image.png", dpi=300)

In [None]:
# Export the stored "image" frames as a video next to the zea file, using the
# same 60 dB dynamic range and the same `fps` as the zea video so both play
# back at the same rate.
vera_images = file.load_data("image")
vera_images = np.clip(vera_images, a_min=-60, a_max=0)
vera_images = zea.display.to_8bit(vera_images, (-60, 0), pillow=False)
name = path.stem + "_vera.mp4"
zea.io_lib.save_video(vera_images, path.parent / name, fps=fps)