In [None]:
#import logging
#logger = logging.getLogger()
#logger.setLevel(logging.INFO)
#logging.debug("test")

# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell is executed, so edits to local source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os

# Configure XLA/CUDA through environment variables. This happens before JAX is
# imported further down, so the settings are already in place when it starts.
os.environ.update(
    {
        # Flag handed to XLA: force GPU compilation parallelism to 1.
        "XLA_FLAGS": "--xla_gpu_force_compilation_parallelism=1",
        # Expose only the GPU with index 1 to this process.
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

# Beamforming With Optimized Scans
vbeam is built on top of high-level abstractions — we write NumPy-like code, wrap it with `jax.jit`, and hope for the best. We therefore do not have access to the same low-level optimizations that we would have if the beamformer were written in a language like C++ and CUDA. We are limited to high-level optimizations. One such optimization is simply doing less work. By filtering out redundant points from the scan, _before beamforming_, we can get a much faster beamformer. In this notebook we will explore how to perform such an optimization/pre-processing.

The dataset used is a phased array, focused transmit, cardiac dataset, consisting of 64 elements, and 101 transmit events [1]. Since it is a phased array, we use a sector scan with 256 azimuths and 512 depth samples.

Let's (download and) import it!

_[1] A. Rodriguez-Molares, O. M. H. Rindal, O. Bernard, et al., “The UltraSound ToolBox,” in 2017 IEEE International Ultrasonics Symposium (IUS), ISSN: 1948-5727, Sep. 2017, pp. 1–4. DOI: 10.1109/ULTSYM.2017.8092389._


In [None]:
import jax.numpy as np
import pyuff_ustb as pyuff

from vbeam.data_importers import import_pyuff
from vbeam.scan import sector_scan
from vbeam.util.download import cached_download

# Fetch the dataset through vbeam's cached_download helper and open it as a UFF file
data_url = "http://www.ustb.no/datasets/Verasonics_P2-4_parasternal_long_small.uff"
uff_path = cached_download(data_url)
uff = pyuff.Uff(uff_path)
channel_data = uff.read("/channel_data")

# Turn the channel data into a vbeam setup (frames=0 is forwarded to the importer)
setup = import_pyuff(channel_data, frames=0)

# Build the sector scan (note: `np` is jax.numpy in this notebook):
# one azimuth per wave in the transmit sequence and 512 depths from 0 to 110e-3,
# then resized to 256 azimuths before being attached to the setup.
scan_depths = np.linspace(0, 110e-3, 512)
scan_angles = np.array([wave.source.azimuth for wave in channel_data.sequence])
full_azimuth_scan = sector_scan(scan_angles, scan_depths)
scan = full_azimuth_scan.resize(azimuths=256)
setup.scan = scan

## Beamforming the Full Sector Scan
Let's create a basic DAS beamformer from our setup, run it on our imported data, and plot the result. We also time the beamformer. By calling `.block_until_ready()` on the result, we ensure that not more than one beamforming job is running at the same time.

In [None]:
import jax
import matplotlib.pyplot as plt

from vbeam.beamformers import get_das_beamformer

beamformer = jax.jit(get_das_beamformer(setup))  # jax.jit makes it run fast
result = beamformer(**setup.data)
plt.imshow(result.T, aspect="auto", cmap="gray")

data = setup.data
%timeit beamformer(**data).block_until_ready()

Running on an Nvidia A100 40GB GPU, the average elapsed time came out to be `463 ms ± 1.17 ms`.

Let's explore how much work is being done. There are 101 transmits, 64 receiving elements, and 256x512 pixels. This amounts to having to process 101x64x256x512 ≈ 0.85 billion points. In 463 milliseconds, that approximately amounts to 1.8 billion points per second. This is fairly low for vbeam, and is likely due to a memory bottleneck.

In [None]:
from math import prod

# Average runtime of the full-scan beamformer quoted in the text above (463 ms)
elapsed_seconds = 0.463

# Sizes of the transmits, receivers and points dimensions; their product is the
# number of (transmit, receiver, point) combinations that are processed.
dimension_sizes = setup.size(["transmits", "receivers", "points"])
total_points = prod(dimension_sizes)
points_per_second = total_points / elapsed_seconds

print(f"Total number of processed points: {total_points:.1e}")
print(f"Points processed per second: {points_per_second:.1e}")

## Filtering out Points by Apodization
We are beamforming using RTB (retrospective transmit beamforming), which means that, for a given transmit, most pixels are weighted by 0. By using an `ApodizationFilteredScan` (which wraps our original scan) we only beamform the pixels that are included by RTB.

In [None]:
from vbeam.scan import ApodizationFilteredScan

setup.scan = ApodizationFilteredScan.from_setup(setup, ["transmits", "points"])

beamformer = jax.jit(get_das_beamformer(setup))
result = beamformer(**setup.data)
plt.imshow(result.T, aspect="auto", cmap="gray")

data = setup.data
%timeit beamformer(**data).block_until_ready()

Now, the average elapsed time goes down to `16.8 ms ± 248 µs`; more than 27 times faster! The beamformer now only has to process around a quarter of the pixels per transmit, for a total of 0.22 billion points. This beamformer processes around 13 billion points per second, which is actually 7 times more efficient — again, likely due to needing less memory now.

In [None]:
# Average runtime of the apodization-filtered beamformer as quoted in the text
# above (16.8 ms). Using the measured value instead of a rounded 0.017 keeps the
# printed throughput consistent with the numbers discussed in the text.
# Update this if you re-run the timing on different hardware.
elapsed_seconds = 16.8e-3

# Sizes of the transmits, receivers and points dimensions; their product is the
# number of (transmit, receiver, point) combinations that are processed.
dimension_sizes = setup.size(["transmits", "receivers", "points"])
total_points = prod(dimension_sizes)
print(f"Total number of processed points: {total_points:.1e}")
print(f"Points processed per second: {total_points / elapsed_seconds:.1e}")

## Scan Converting the Sector Scan Before Beamforming
A final optimization that we will explore in this notebook is to scan convert the sector scan _before beamforming_. There is a lot of resolution redundancy close to the aperture of a sector scan, as the points are so close together. If we scan convert the grid before beamforming these redundant pixels are merged together. Additionally, as the RTB apodization is quite wide close to the aperture, we get even more (relative) redundancy. We can combine the scan conversion and apodization filtering optimizations for an effect that is greater than the sum of its parts.

In [None]:
from vbeam.scan import ScanConvertedSectorScan

# Use both apodization filtering and scan conversion
setup.scan.base_scan = ScanConvertedSectorScan(scan)

beamformer = jax.jit(get_das_beamformer(setup))
result = beamformer(**setup.data)
plt.imshow(result.T, aspect="auto", cmap="gray")

data = setup.data
%timeit beamformer(**data).block_until_ready()

Using both the apodization filtering and pre-scan-conversion optimizations, the beamformer now takes `5.88 ms ± 155 µs`, on average. There are far fewer points as well, only about 7% of the original scan, or 26% of the apodization-filtered scan. This beamformer processes around 9.7 billion points per second; slightly lower than the previous one, almost certainly due to under-utilizing the GPU cores. This beamformer has around 70% GPU utilization, while the previous one had around 90%.

In [None]:
# Average runtime of the filtered + scan-converted beamformer as quoted in the text
# above (5.88 ms). The previously hard-coded 0.006 underestimates the throughput by
# about 2% and does not reproduce the ~9.7 billion points per second stated in the
# text. Update this if you re-run the timing on different hardware.
elapsed_seconds = 5.88e-3

# Sizes of the transmits, receivers and points dimensions; their product is the
# number of (transmit, receiver, point) combinations that are processed.
dimension_sizes = setup.size(["transmits", "receivers", "points"])
total_points = prod(dimension_sizes)
print(f"Total number of processed points: {total_points:.1e}")
print(f"Points processed per second: {total_points / elapsed_seconds:.1e}")

In [None]:
import jax
import jax.numpy as jnp

from vbeam.core import signal_for_point
from vbeam.postprocess import *
from vbeam.scan.advanced.apodization_filtered_scan import recombine1
from vbeam.util.transformations import *


def get_reduce_1(scan: ApodizationFilteredScan):
    """Create a reduction step that merges one item's imaged points into the carry.

    The returned function takes ``(carry, x, points_axis)``, where ``x`` is an
    ``(i, imaged_points)`` pair. It looks up row ``i`` of the scan's private
    ``_indices`` and ``_indices_mask`` arrays and hands everything to
    ``recombine1``, returning whatever that call returns.

    NOTE(review): it is used below with ``Reduce("transmits", ..., enumerate=True)``,
    so ``i`` is presumably the transmit index — confirm against ``Reduce``.
    """

    def reduce_1(carry, x, points_axis):
        index, imaged_points = x
        point_indices = scan._indices[index]
        point_mask = scan._indices_mask[index]
        return recombine1(carry, imaged_points, point_indices, point_mask, points_axis)

    return reduce_1


# Reduction over the "transmits" dimension: starting from a complex64 zero vector
# with one entry per point of the base scan, every transmit's result is merged into
# the carry by reduce_1, and the flat points axis is finally unflattened by the base scan.
base_scan = setup.scan.base_scan
initial_carry = jnp.zeros((base_scan.num_points,), dtype="complex64")
merge_transmit = get_reduce_1(setup.scan)

reduce_transmits = compose(
    Reduce(
        "transmits",
        merge_transmit,
        initial_carry,
        enumerate=True,
        extra_kwargs={"points_axis": Axis("points", keep=True)},
    ),
    Apply(base_scan.unflatten, Axis("points")),
)


beamformer = compose(
    signal_for_point,
    ForAll("receivers", "points"),
    #ForAll("transmits", "receivers", "points"),
    Apply(jnp.sum, Axis("receivers")),
    reduce_transmits,
    #Apply(setup.scan.unflatten, Axis("transmits"), Axis("points")),
    Apply(normalized_decibels),
    Wrap(jax.jit),
).build(setup.spec.replace({"point_pos": ["transmits", "points"]}))
result = beamformer(**setup.data).block_until_ready()

data = setup.data
%timeit beamformer(**data).block_until_ready()