In [None]:
import zea

# Device setup runs right after importing zea and before the remaining
# imports (presumably so the device/backend is selected before keras and jax
# are imported — TODO confirm against the zea documentation).
zea.init_device()
# Switch matplotlib to zea's plotting style (going by the function name).
zea.visualize.set_mpl_style()

# NOTE(review): `apply_along_axis` is not referenced in the visible cells.
from ulsa.ops import LowPassFilter, apply_along_axis
from active_sampling_temporal import preload_data
import matplotlib.pyplot as plt
import keras
from tqdm import tqdm
import jax.numpy as jnp
import numpy as np
import pywt
# NOTE(review): duplicate of the matplotlib import a few lines above.
import matplotlib.pyplot as plt

In [None]:
def scan_sequence(data, pipeline, parameters, keep_keys=None, **kwargs):
    """
    Process a sequence of data frames through the pipeline.

    Args:
        data: Iterable of frames. Each frame is passed to the pipeline as the
            ``data`` keyword argument.
        pipeline: Callable that accepts keyword arguments and returns a
            mapping containing at least the key ``"data"``.
        parameters (dict): Keyword arguments forwarded to every pipeline call.
            The dict itself is not modified.
        keep_keys (iterable of str, optional): Keys of the pipeline output
            that are fed back as keyword arguments for the next frame. A key
            that is missing from an output is simply not updated.
        **kwargs: Extra keyword arguments forwarded to every pipeline call.
            They take precedence over entries of ``parameters``.

    Returns:
        The ``output["data"]`` of every frame, converted to numpy and stacked
        along a new leading axis with ``keras.ops.stack``.
    """
    keep_keys = () if keep_keys is None else tuple(keep_keys)

    # Keep a single merged dict of call arguments. Passing `**parameters` and
    # the carried-over values as two separate `**` expansions raises
    # "got multiple values for keyword argument" as soon as a kept key is
    # also present in `parameters`; merging lets the carried value win.
    call_kwargs = {**parameters, **kwargs}

    images = []
    for raw_data in tqdm(data):
        output = pipeline(data=raw_data, **call_kwargs)
        # Feed the requested outputs of this frame into the next call.
        call_kwargs.update({key: output[key] for key in keep_keys if key in output})
        images.append(keras.ops.convert_to_numpy(output["data"]))
    return keras.ops.stack(images, axis=0)

In [None]:
def imshow(X, *args, scan=None, **kwargs):
    """Show an image with ``plt.imshow`` using scan-aware defaults.

    When a scan object is given, it supplies the default ``extent``
    (``scan.extent`` multiplied by 1e3, i.e. millimetres if the scan stores
    metres), ``vmin`` and ``vmax`` (the two entries of ``scan.dynamic_range``).
    Two-dimensional images default to the ``"gray"`` colormap; anything else
    falls back to matplotlib's configured ``image.cmap``. Any of these can be
    overridden through ``kwargs``.

    Args:
        X (np.ndarray): The image to display.
        *args: Additional positional arguments for plt.imshow.
        scan (zea.Scan, optional): Scan object providing ``extent`` and
            ``dynamic_range`` for the defaults described above.
        **kwargs: Additional keyword arguments for plt.imshow.

    Returns:
        Whatever ``plt.imshow`` returns.
    """
    has_scan = scan is not None

    # Defaults are computed up front, whether or not the caller overrides them.
    default_extent = scan.extent * 1e3 if has_scan else None
    default_cmap = plt.rcParams["image.cmap"] if X.ndim != 2 else "gray"
    default_vmin = scan.dynamic_range[0] if has_scan else None
    default_vmax = scan.dynamic_range[1] if has_scan else None

    # Explicit keyword arguments win over the scan-derived defaults.
    display_options = {
        "extent": kwargs.pop("extent", default_extent),
        "cmap": kwargs.pop("cmap", default_cmap),
        "vmin": kwargs.pop("vmin", default_vmin),
        "vmax": kwargs.pop("vmax", default_vmax),
    }
    return plt.imshow(X, *args, **kwargs, **display_options)

In [None]:
# Open the recording and load the frames, scan and probe used by every cell
# below.
file = zea.File(
    "/mnt/USBMD_datasets/2024_USBMD_cardiac_S51/HDF5/20240701_P3_PLAX_0000.hdf5"
)
# Display range; reused as vmin/vmax when the images are compared later on.
dynamic_range = (-65, -18)
# NOTE(review): the second argument (30) is presumably a frame count —
# confirm against preload_data.
validation_sample_frames, scan, probe = preload_data(
    file,
    30,
    "data/raw_data",
    cardiac=True,
    dynamic_range=dynamic_range,
)

In [None]:
def raw_to_fft(raw_data, fs):
    """Plot the mean magnitude spectrum of ``raw_data`` on the current axes.

    The FFT is taken along axis ``-2`` and shifted so that 0 Hz sits in the
    middle. Its magnitude is averaged over axes ``-3`` and ``-1``, converted
    to dB and plotted against frequency in MHz.

    Args:
        raw_data: Array with at least three dimensions; axis ``-2`` is the
            one that gets transformed.
        fs: Sampling frequency along axis ``-2`` in Hz.

    Returns:
        None. The curve and the axis labels are added to the current
        matplotlib axes.
    """
    spectrum = jnp.fft.fftshift(jnp.fft.fft(raw_data, axis=-2), axes=-2)
    spectrum = jnp.mean(jnp.abs(spectrum), axis=(-3, -1))
    spectrum_db = 20 * jnp.log10(spectrum + 1e-9)  # small offset avoids log(0)

    # Use the actual FFT bin centres. An evenly spaced linspace from -fs/2 to
    # +fs/2 includes both end points, which stretches the axis and moves the
    # DC bin (and every other bin) away from its true frequency.
    n_bins = spectrum_db.shape[-1]
    freqs_mhz = jnp.fft.fftshift(jnp.fft.fftfreq(n_bins)) * fs / 1e6

    plt.plot(freqs_mhz, spectrum_db)
    plt.xlabel("Frequency (MHz)")
    plt.ylabel("Amplitude (dB)")

In [None]:
# Visualize the frequency spectrum of the first frame before and after
# low-pass filtering.
# Inputs for the ops below: one frame plus the scan settings they read.
output = dict(
    # Alternative input: the wavelet-denoised frame (left disabled).
    # data=wavelet_denoise_full(validation_sample_frames[0], axis=-3),
    data=validation_sample_frames[0],
    center_frequency=scan.center_frequency,
    sampling_frequency=scan.sampling_frequency,
    n_ax=scan.n_ax,
)
# NOTE(review): each op is called with the previous result unpacked as keyword
# arguments and its return value is indexed and mutated like a dict below, so
# this cell relies on the ops passing keys such as "sampling_frequency" along.
output = zea.ops.Demodulate()(**output)
# Override the centre frequency for the following ops (presumably because the
# demodulated signal is at baseband — confirm).
output["center_frequency"] = 0
# output = zea.ops.Downsample(2)(**output)
output = zea.ops.ChannelsToComplex()(**output)
# output["center_frequency"] = -1e6
# Filtered copy of the data: 128-tap low-pass filter along axis -2, with the
# bandwidth given at call time (2e6, presumably Hz — confirm).
output_aa = LowPassFilter(num_taps=128,axis=-2)(**output, bandwidth=2e6)
raw_aa = output_aa["data"]
# Sampling frequency is read back from the op output rather than from `scan`.
fs = output["sampling_frequency"]

# Both spectra go into the same figure: unfiltered first, then filtered.
plt.figure()
raw_to_fft(output["data"], fs)
raw_to_fft(raw_aa, fs)

In [None]:
# Reference image: zea's default pipeline with scan conversion appended,
# applied to the first frame.
pipeline = zea.Pipeline.from_default(
    jit_options="ops",
    pfield=False,
    num_patches=40,
    with_batch_dim=False,
)
pipeline.append(zea.ops.ScanConvert(order=1))

parameters = pipeline.prepare_parameters(probe, scan)
output = pipeline(**parameters, data=validation_sample_frames[0])
image_orig = output["data"]  # kept for the side-by-side comparison later on
imshow(image_orig, scan=scan)

In [None]:
# Same default pipeline, but with a factor-2 Downsample op inserted at
# position 1 (i.e. after whatever the first default op is) before scan
# conversion is appended.
pipeline = zea.Pipeline.from_default(
    jit_options="ops",
    pfield=False,
    num_patches=40,
    with_batch_dim=False,
)
pipeline.insert(1, zea.ops.Downsample(2))
pipeline.append(zea.ops.ScanConvert(order=1))

parameters = pipeline.prepare_parameters(probe, scan)
output = pipeline(**parameters, data=validation_sample_frames[0])
image1 = output["data"]
imshow(image1, scan=scan)
# Optional: process the whole sequence and export it as a GIF (disabled).
# images = scan_sequence(validation_sample_frames, pipeline, parameters, bandwidth=2e6)
# images = zea.display.to_8bit(images, dynamic_range=scan.dynamic_range, pillow=False)
# zea.utils.save_to_gif(images, "test1.gif", 20)

In [None]:
def wavelet_denoise_rf(rf_signal, wavelet="db4", level=4, threshold_factor=0.5):
    """
    Denoise ultrasound RF signal using wavelet soft-thresholding.

    Parameters:
    - rf_signal: 1D array of RF data
    - wavelet: Wavelet type (e.g., 'db4', 'sym8')
    - level: Decomposition level
    - threshold_factor: Scaling for the universal threshold
      sigma * sqrt(2 * ln(N)), with N the number of samples

    Returns:
    - Denoised RF signal with the same number of samples as `rf_signal`
    """
    n_samples = len(rf_signal)

    # Decompose into [approximation, coarsest detail, ..., finest detail]
    coeffs = pywt.wavedec(rf_signal, wavelet, level=level)

    # Robust noise estimate (median absolute value / 0.6745) taken from the
    # last entry, i.e. the finest-scale detail coefficients
    sigma = np.median(np.abs(coeffs[-1])) / 0.6745
    threshold = threshold_factor * sigma * np.sqrt(2 * np.log(n_samples))

    # Keep the approximation unaltered, soft-threshold every detail band
    new_coeffs = [coeffs[0]]
    new_coeffs.extend(pywt.threshold(c, threshold, mode="soft") for c in coeffs[1:])

    # Reconstruct signal. The reconstruction can come back one sample longer
    # than the input (odd-length signals), so trim it to the input length to
    # keep the shape stable for downstream processing.
    return pywt.waverec(new_coeffs, wavelet)[:n_samples]


def wavelet_denoise_full(data, axis, **kwargs):
    """
    Wavelet-denoise every 1D slice of `data` taken along `axis`.

    Parameters:
    - data: Input array (e.g., RF data)
    - axis: Axis whose 1D slices are denoised one at a time
    - kwargs: Forwarded unchanged to `wavelet_denoise_rf`
      (e.g., wavelet, level, threshold_factor)

    Returns:
    - Array assembled by `np.apply_along_axis` from the denoised slices
    """

    def _denoise_trace(trace):
        # Bind the keyword options so apply_along_axis only has to pass the slice.
        return wavelet_denoise_rf(trace, **kwargs)

    return np.apply_along_axis(_denoise_trace, axis, data)

In [None]:
# Hand-built pipeline: wavelet denoising of the raw data, then demodulation,
# low-pass filtering and factor-2 downsampling ahead of beamforming, envelope
# detection, normalisation, log compression and scan conversion.
pipeline = zea.Pipeline(
    [
        # wavelet_denoise_full runs numpy/pywt code, which is presumably why
        # the op is marked as not jittable. Denoising is applied along axis -3.
        zea.ops.Lambda(
            wavelet_denoise_full,
            func_kwargs=dict(axis=-3, wavelet="db4", threshold_factor=0.1),
            jittable=False,
        ),
        zea.ops.Demodulate(),
        # NOTE(review): the spectrum cell converts channels to complex before
        # filtering along axis -2; complex_channels=True presumably makes the
        # filter handle that conversion itself — confirm in ulsa.ops.
        LowPassFilter(num_taps=128, complex_channels=True, axis=-2),
        zea.ops.Downsample(2),
        # Beamforming steps are grouped inside a PatchedGrid op.
        zea.ops.PatchedGrid(
            [
                zea.ops.TOFCorrection(),
                zea.ops.DelayAndSum(),
            ]
        ),
        zea.ops.EnvelopeDetect(),
        zea.ops.Normalize(),
        zea.ops.LogCompress(),
        zea.ops.ScanConvert(order=1),
    ],
    with_batch_dim=False,
)

parameters = pipeline.prepare_parameters(probe, scan)
# `bandwidth` is supplied at call time on top of the prepared parameters
# (presumably consumed by the LowPassFilter op — confirm).
output = pipeline(data=validation_sample_frames[0], **parameters, bandwidth=2e6)
image2 = output["data"]  # kept for the side-by-side comparison
plt.figure()
imshow(image2, scan=scan)

In [None]:
# Side-by-side comparison of the reference image and the denoised, low-pass
# filtered image, both drawn with the same grayscale range.
fig, axs = plt.subplots(1, 2, figsize=(20, 20))
display_kwargs = dict(cmap="gray", vmin=dynamic_range[0], vmax=dynamic_range[1])
axs[0].imshow(image_orig, **display_kwargs)
axs[1].imshow(image2, **display_kwargs)
axs[0].set_title("Original")
axs[1].set_title("With lpf and denoising in rf")

In [None]:
# Single-trace demo: take the first 500 samples along axis 2 of the first
# frame (index 0 on every other axis) and denoise them.
rf_signal = validation_sample_frames[0, 0, :500, 0, 0]
denoised_rf = wavelet_denoise_rf(
    rf_signal,
    wavelet="sym8",
    level=4,
    threshold_factor=0.5,
)

# Overlay the raw and the denoised trace.
plt.figure(figsize=(12, 5))
plt.plot(rf_signal, alpha=0.5, label="Noisy RF")
plt.plot(denoised_rf, linewidth=2, label="Denoised RF")
plt.title("Wavelet Denoising of RF Ultrasound Signal")
plt.legend()
plt.show()
