In [None]:
# Setup: library imports and global configuration for this notebook.
import zea

# Initialise the compute device and the zea plotting style before the remaining
# imports. NOTE(review): presumably init_device() has to run before JAX is
# imported (jax.numpy below) so device selection takes effect — TODO confirm.
zea.init_device()
zea.visualize.set_mpl_style()

# NOTE(review): apply_along_axis is not used in any cell of this notebook.
from ulsa.ops import LowPassFilter, apply_along_axis
from active_sampling_temporal import preload_data
import matplotlib.pyplot as plt
import keras
from tqdm import tqdm
import jax.numpy as jnp

In [None]:
def scan_sequence(data, pipeline, parameters, keep_keys=None, **kwargs):
    """
    Process a sequence of data frames through the pipeline.

    Every frame is passed to ``pipeline`` together with ``parameters`` and any
    extra keyword arguments. Output entries named in ``keep_keys`` are carried
    over as keyword arguments to the call for the next frame.

    Args:
        data: Iterable of raw data frames.
        pipeline: Callable invoked as ``pipeline(data=frame, **kw)`` that
            returns a dict containing at least a ``"data"`` entry.
        parameters (dict): Parameters passed to every pipeline call.
        keep_keys (list[str], optional): Output keys to carry over to the next
            frame. Keys absent from an output are skipped.
        **kwargs: Extra keyword arguments for the pipeline. These, and any
            carried-over values, take precedence over ``parameters``.

    Returns:
        The per-frame ``"data"`` outputs stacked along axis 0
        (via ``keras.ops.stack``).

    Raises:
        ValueError: If ``keep_keys`` contains ``"data"`` or ``data`` is empty.
    """
    keep_keys = list(keep_keys or [])
    if "data" in keep_keys:
        # ``data`` is always the current frame; carrying it over would clash
        # with the ``data=`` argument of the pipeline call.
        raise ValueError('"data" cannot be carried over between frames')

    # A single merged dict lets carried-over outputs override ``parameters``;
    # splatting both separately raises "got multiple values for keyword
    # argument" as soon as a kept key is also a parameter.
    call_kwargs = {**parameters, **kwargs}
    images = []
    for raw_data in tqdm(data):
        output = pipeline(data=raw_data, **call_kwargs)
        call_kwargs.update({key: output[key] for key in keep_keys if key in output})
        images.append(keras.ops.convert_to_numpy(output["data"]))

    if not images:
        raise ValueError("data must contain at least one frame")
    return keras.ops.stack(images, axis=0)

In [None]:
def imshow(X, *args, scan=None, ax=None, **kwargs):
    """Display an image using matplotlib's imshow.

    Will use zea's default style and set the extent (in mm), vmin, and vmax
    based on the provided scan object if available. Explicitly passed keyword
    arguments always take precedence over these scan-derived defaults.

    Args:
        X (np.ndarray): The image to display.
        *args: Positional arguments forwarded after ``X``. matplotlib's imshow
            only accepts ``cmap`` and ``norm`` positionally, so at most two.
        scan (zea.Scan, optional): A scan object containing metadata.
            If provided, it will be used to set the extent, vmin, and vmax of the image.
        ax (matplotlib.axes.Axes, optional): Axes to draw on. If omitted,
            ``plt.imshow`` is used (current axes, and the result becomes the
            current image for ``plt.colorbar``).
        **kwargs: Additional keyword arguments for imshow.

    Returns:
        matplotlib.image.AxesImage: The created image artist.

    Raises:
        TypeError: If more than two positional arguments are given, or if
            ``cmap``/``norm`` is given both positionally and by keyword.
    """
    # Fold positional cmap/norm into kwargs so they cannot clash with the
    # default ``cmap=`` we add below (previously any positional arg did).
    if len(args) > 2:
        raise TypeError(
            f"imshow() accepts at most 2 positional arguments after X (cmap, norm), got {len(args)}"
        )
    for name, value in zip(("cmap", "norm"), args):
        if name in kwargs:
            raise TypeError(f"imshow() got multiple values for argument '{name}'")
        kwargs[name] = value

    if scan is not None:
        # scan.extent is scaled by 1e3 to give mm (presumably stored in metres).
        kwargs.setdefault("extent", scan.extent * 1e3)
        if kwargs.get("norm") is None:
            # matplotlib rejects vmin/vmax combined with a Normalize instance.
            kwargs.setdefault("vmin", scan.dynamic_range[0])
            kwargs.setdefault("vmax", scan.dynamic_range[1])
    kwargs.setdefault("cmap", "gray" if X.ndim == 2 else plt.rcParams["image.cmap"])

    if ax is None:
        return plt.imshow(X, **kwargs)
    return ax.imshow(X, **kwargs)

In [None]:
# Load a single cardiac (PLAX) acquisition and preload its frames, scan and
# probe for all experiments below.
DATA_PATH = "/mnt/USBMD_datasets/2024_USBMD_cardiac_S51/HDF5/20240701_P3_PLAX_0000.hdf5"

file = zea.File(DATA_PATH)
dynamic_range = (-65, -18)  # display range for every image (presumably dB)
validation_sample_frames, scan, probe = preload_data(
    file,
    30,  # NOTE(review): meaning defined by preload_data — TODO name this argument
    "data/raw_data",
    cardiac=True,
    dynamic_range=dynamic_range,
)

In [None]:
def raw_to_fft(raw_data, fs, ax=None, **plot_kwargs):
    """Plot the mean magnitude spectrum of ``raw_data`` along axis -2.

    The FFT is taken along axis -2, shifted so 0 Hz is centred, and its
    magnitude is averaged over every other axis before being converted to dB
    and plotted against frequency in MHz.

    Args:
        raw_data: Array with at least 2 dimensions; axis -2 is transformed
            (presumably the axial/fast-time sample axis — confirm).
        fs (float): Sampling frequency in Hz along axis -2.
        ax (matplotlib.axes.Axes, optional): Axes to draw on. Defaults to the
            current pyplot axes.
        **plot_kwargs: Forwarded to ``ax.plot`` (e.g. ``label``).
    """
    n_samples = raw_data.shape[-2]
    spectrum = jnp.abs(jnp.fft.fftshift(jnp.fft.fft(raw_data, axis=-2), axes=-2))

    # Average over every axis except the frequency axis so inputs with extra
    # leading (e.g. batch) dimensions still reduce to a 1-D spectrum.
    freq_axis = spectrum.ndim - 2
    other_axes = tuple(i for i in range(spectrum.ndim) if i != freq_axis)
    # Small offset avoids log10(0).
    spectrum_db = 20 * jnp.log10(jnp.mean(spectrum, axis=other_axes) + 1e-9)

    # Exact bin frequencies of the shifted FFT ([-fs/2, fs/2) for even n);
    # linspace(-fs/2, fs/2, n) would misplace every bin slightly.
    freqs_mhz = jnp.fft.fftshift(jnp.fft.fftfreq(n_samples, d=1 / fs)) / 1e6

    if ax is None:
        ax = plt.gca()
    ax.plot(freqs_mhz, spectrum_db, **plot_kwargs)
    ax.set_xlabel("Frequency (MHz)")
    ax.set_ylabel("Amplitude (dB)")

In [None]:
import numpy as np
from scipy.signal import butter, filtfilt

def low_pass_filter(data, axis, cutoff, fs, order=5):
    """
    Apply a zero-phase Butterworth low-pass filter to the data.

    The filter is designed once and applied forward and backward along
    ``axis`` with ``scipy.signal.filtfilt``.

    Parameters:
        data (np.ndarray): Input signal.
        axis (int): Axis along which to filter.
        cutoff (float): Cutoff frequency in Hz.
        fs (float): Sampling frequency in Hz.
        order (int): Filter order.

    Returns:
        np.ndarray: Filtered signal with the same shape as ``data``.
    """
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    # filtfilt filters every 1-D slice along ``axis`` in one vectorised call,
    # instead of redesigning the filter and looping per slice with
    # np.apply_along_axis.
    return filtfilt(b, a, data, axis=axis)

In [None]:
# Visualize the frequency spectrum of one frame before and after low-pass filtering.
output = dict(
    data=validation_sample_frames[0],
    center_frequency=scan.center_frequency,
    sampling_frequency=scan.sampling_frequency,
    n_ax=scan.n_ax,
)
output = zea.ops.Demodulate()(**output)
# After demodulation the centre frequency is reset to 0 (presumably so later
# ops treat the data as baseband — confirm).
output["center_frequency"] = 0
# Optional: downsample before the low-pass filter.
# output = zea.ops.Downsample(2)(**output)
output = zea.ops.ChannelsToComplex()(**output)
output_aa = LowPassFilter(num_taps=128, axis=-2)(**output, bandwidth=2e6)
raw_aa = output_aa["data"]
# Optional: compare against the SciPy Butterworth implementation.
# raw_aa2 = low_pass_filter(output["data"], axis=-2, cutoff=0.5e5, fs=output["sampling_frequency"])
fs = output["sampling_frequency"]

plt.figure()
raw_to_fft(output["data"], fs)
raw_to_fft(raw_aa, fs)
# raw_to_fft(raw_aa2, fs)
plt.title("Mean spectrum of one frame")
plt.legend(["Demodulated", "Low-pass filtered (bandwidth=2e6)"])
plt.show()

In [None]:
# Reference image: default zea pipeline without extra filtering, scan-converted for display.
pipeline = zea.Pipeline.from_default(
    with_batch_dim=False,
    num_patches=40,
    pfield=False,
    jit_options="ops"
)
pipeline.append(zea.ops.ScanConvert(order=1))

parameters = pipeline.prepare_parameters(probe, scan)
output = pipeline(data=validation_sample_frames[0], **parameters)
image_orig = output["data"]
imshow(image_orig, scan=scan)
plt.title("Original")
# Render the figure instead of echoing the AxesImage repr returned by imshow.
plt.show()

In [None]:
# Same default pipeline with a 2x Downsample op inserted at position 1.
pipeline = zea.Pipeline.from_default(
    with_batch_dim=False,
    num_patches=40,
    pfield=False,
    jit_options="ops"
)
pipeline.insert(1, zea.ops.Downsample(2))
pipeline.append(zea.ops.ScanConvert(order=1))

parameters = pipeline.prepare_parameters(probe, scan)
output = pipeline(data=validation_sample_frames[0], **parameters)
image1 = output["data"]
imshow(image1, scan=scan)
plt.title("With Downsample")
# Render the figure instead of echoing the AxesImage repr returned by imshow.
plt.show()

# Optional (slow): run every frame through this pipeline and save a GIF.
# images = scan_sequence(validation_sample_frames, pipeline, parameters, bandwidth=2e6)
# images = zea.display.to_8bit(images, dynamic_range=scan.dynamic_range, pillow=False)
# zea.utils.save_to_gif(images, "test1.gif", 20)

In [None]:
import warnings

# Default pipeline with a LowPassFilter inserted at position 1.
# ComplexWarning (emitted when complex values are cast to real) is turned into
# an error only inside this block; a global filterwarnings() call would change
# warning behaviour for the rest of the kernel session and stack on re-runs.
# NOTE(review): pfield=True here, whereas the reference and Downsample cells
# use pfield=False, so the final comparison also reflects that setting —
# confirm this is intended.
with warnings.catch_warnings():
    warnings.simplefilter("error", category=jnp.ComplexWarning)

    pipeline = zea.Pipeline.from_default(
        with_batch_dim=False,
        num_patches=40,
        pfield=True,
        jit_options="ops"
    )
    pipeline.insert(1, LowPassFilter(num_taps=128, complex_channels=True, axis=-2))
    # pipeline.insert(2, zea.ops.Downsample(2))
    # pipeline.append(zea.ops.Lambda(lambda x: x[..., None]))
    # pipeline.append(zea.ops.LeeFilter(sigma=1))
    # pipeline.append(zea.ops.Lambda(lambda x: keras.ops.squeeze(x, axis=-1)))
    pipeline.append(zea.ops.ScanConvert(order=1))

    parameters = pipeline.prepare_parameters(probe, scan)
    output = pipeline(data=validation_sample_frames[0], **parameters, bandwidth=2e6)
    image2 = output["data"]
    imshow(image2, scan=scan)

plt.title("With LowPassFilter")
# Render the figure instead of echoing the AxesImage repr returned by imshow.
plt.show()

In [None]:
# Optional (slow): run every frame through the low-pass pipeline above and save the result as a GIF.
# images = scan_sequence(validation_sample_frames, pipeline, parameters, bandwidth=2e6)
# images = zea.display.to_8bit(images, dynamic_range=scan.dynamic_range, pillow=False)
# zea.utils.save_to_gif(images, "test2.gif", 20)

In [None]:
# Side-by-side comparison of the three reconstructions on a shared display range.
# NOTE(review): the LowPassFilter image was reconstructed with pfield=True while
# the other two used pfield=False — confirm this is intended.
extent_mm = scan.extent * 1e3  # same mm extent the imshow helper applies
comparison = [
    (image_orig, "Original"),
    (image1, "With Downsample"),
    (image2, "With LowPassFilter"),
]

# A single row of panels: a wide, short figure avoids large empty margins.
fig, axs = plt.subplots(1, len(comparison), figsize=(20, 7))
for ax, (image, title) in zip(axs, comparison):
    ax.imshow(
        image,
        cmap="gray",
        vmin=dynamic_range[0],
        vmax=dynamic_range[1],
        extent=extent_mm,
    )
    ax.set_title(title)
plt.show()