In [4]:
import matplotlib.pyplot as plt
import requests
from pathlib import Path
# ======================================================================================
# Download the darkmode.mplstyle stylesheet and use it
# ======================================================================================
# Download the darkmode.mplstyle stylesheet from the website repository
url = (
    r"https://raw.githubusercontent.com/vincentvdschaft/quartz-website/v4/"
    r"figure-generation/darkmode.mplstyle")
# Without a timeout, requests may block indefinitely on a stalled connection.
r = requests.get(url, timeout=30)
# Fail loudly on HTTP errors (e.g. 404) instead of writing the error page to disk and
# handing it to matplotlib as if it were a stylesheet.
r.raise_for_status()
# Write the downloaded stylesheet to a file
stylesheet_path = Path('stylesheet.mplstyle')
stylesheet_path.write_bytes(r.content)
# Use the stylesheet
plt.style.use(str(stylesheet_path))

In [5]:
import jax.numpy as jnp
from jax import jit


@jit
def compute_xl(xe, ze, xs, zs, lens_thickness, c_lens, c_medium):
    """Computes the lateral point on the lens that the shortest (fastest) path from the
    element to the pixel goes through, based on Fermat's principle.

    The lens/medium interface is taken to lie at ``z = lens_thickness``. The search
    starts where the straight element-to-pixel line crosses that interface. It then
    refines the estimate with three Newton-Raphson steps computed by ``dxl``.

    Parameters
    ----------
    xe : float
        The x-coordinate of the element in meters.
    ze : float
        The z-coordinate of the element in meters.
    xs : float
        The x-coordinate of the pixel in meters.
    zs : float
        The z-coordinate of the pixel in meters.
    lens_thickness : float
        The thickness of the lens in meters. Also used as the z-coordinate of the
        lens/medium interface.
    c_lens : float
        The speed of sound in the lens in m/s.
    c_medium : float
        The speed of sound in the medium in m/s.

    Returns
    -------
    float
        The x-coordinate of the lateral point on the lens in meters.
    """
    # Straight-line (no refraction) guess: intersect the element->pixel line with the
    # plane z = lens_thickness. The vertical distance must be measured from the element
    # (lens_thickness - ze); otherwise the guess is wrong whenever ze != 0.
    xl_init = xe + (lens_thickness - ze) * (xs - xe) / (zs - ze)

    xl = xl_init
    # Fixed number of Newton steps (the Python loop is unrolled when traced by jit).
    for _ in range(3):
        xl = xl + dxl(xe, ze, xl, xs, zs, lens_thickness, c_lens, c_medium)

    # Final safeguard: keep the result within 3 lens thicknesses of the initial guess.
    xl = jnp.clip(xl, xl_init - 3 * lens_thickness, xl_init + 3 * lens_thickness)

    return xl


@jit
def dxl(xe, ze, xl, xs, zs, zl, c_lens, c_medium):
    """Computes one Newton-Raphson update step for the lateral point on the lens that
    the shortest (fastest) path from the element to the pixel goes through.

    Parameters
    ----------
    xe : float
        The x-coordinate of the element in meters.
    ze : float
        The z-coordinate of the element in meters.
    xl : float
        The current estimate of the x-coordinate of the lateral point on the lens.
    xs : float
        The x-coordinate of the pixel in meters.
    zs : float
        The z-coordinate of the pixel in meters.
    zl : float
        The z-coordinate of the lens/medium interface in meters.
    c_lens : float
        The speed of sound in the lens in m/s.
    c_medium : float
        The speed of sound in the medium in m/s.

    Returns
    -------
    float
        The step to add to ``xl``, in meters. NaNs are replaced by 0 and the step is
        clipped to [-10 mm, 10 mm].

    Notes
    -----
    The total travel time as a function of the lateral point x is

        T(x) = sqrt((x - xe)^2 + (zl - ze)^2) / c_lens
             + sqrt((x - xs)^2 + (zl - zs)^2) / c_medium.

    By Fermat's principle the path goes through the root of f(x) = dT/dx. Setting
    f(x) = 0 is equivalent to Snell's law. The root is approached with the
    Newton-Raphson update x_new = x - f(x) / f'(x), and this function returns
    -f(x) / f'(x).
    """
    # Pure division-by-zero guard. It is only ever added to denominators: adding it to
    # f(x) itself would shift the root of f and bias the refraction point.
    eps = 1e-12

    # Offsets of the lens segment (element -> interface) and the medium segment
    # (interface -> pixel).
    dx_lens = xl - xe
    dz_lens = zl - ze
    dx_medium = xl - xs
    dz_medium = zl - zs

    r_lens = jnp.sqrt(dx_lens**2 + dz_lens**2)
    r_medium = jnp.sqrt(dx_medium**2 + dz_medium**2)

    # f(x) = dT/dx
    f = dx_lens / (c_lens * r_lens + eps) + dx_medium / (c_medium * r_medium + eps)

    # f'(x) = d^2T/dx^2. This is written as dz^2 / (c r^3), which equals
    # 1 / (c r) - dx^2 / (c r^3) but avoids catastrophic cancellation when |dx| >> |dz|.
    # Both terms are non-negative, so f' >= 0.
    df = dz_lens**2 / (c_lens * r_lens**3 + eps) + dz_medium**2 / (
        c_medium * r_medium**3 + eps
    )

    result = -f / (df + eps)

    # Handle NaNs
    result = jnp.nan_to_num(result)

    # Clip the update step to prevent divergence.
    # This value is chosen to be small enough to prevent divergence but large enough to
    # cover the distance across a normal ultrasound aperture in a single step.
    result = jnp.clip(result, -10e-3, 10e-3)

    return result

In [None]:
# Material and geometry parameters (SI units: m/s and meters)
c_lens = 1000
c_medium = 1500
lens_thickness = 3e-3

# Element position and lateral position of the pixel
xe = -10e-3
ze = 0e-3
xs = 10e-3

# Plot extent
xlims = (-20e-3, 20e-3)
ylims = (-1e-3, 25e-3)

# Place the pixel 2 mm inside the upper z-limit of the plot
zs = ylims[1] - 2e-3

# Point on the lens surface that the fastest element -> pixel path crosses
xl = compute_xl(xe, ze, xs, zs, lens_thickness, c_lens, c_medium)

fig, ax = plt.subplots(figsize=(6, 4))

# The lens is drawn as a band across the full plot width, from z = 0 to lens_thickness
lens = plt.Rectangle(
    (xlims[0], 0),
    xlims[1] - xlims[0],
    lens_thickness,
    color='#5588CC',
    alpha=0.5,
    label='Lens',
)
ax.add_patch(lens)

marker_size = 4

# Dashed ray segments: element -> lens surface -> pixel
ax.plot([xe, xl], [ze, lens_thickness], 'w--')
ax.plot([xl, xs], [lens_thickness, zs], 'w--')

# Travel-time labels (1 mm right of each segment midpoint), followed by the
# sound-speed labels below and above the lens surface.
annotations = (
    ((xe + xl) / 2 + 1e-3, (ze + lens_thickness) / 2, r'$\tau_{lens}$', 'left'),
    ((xl + xs) / 2 + 1e-3, (lens_thickness + zs) / 2, r'$\tau_{medium}$', 'left'),
    (xlims[1] * 0.8, lens_thickness / 2, r'$c_{lens}$', 'center'),
    (xlims[1] * 0.8, lens_thickness * 1.5, r'$c_{medium}$', 'center'),
)
for text_x, text_z, text_label, text_halign in annotations:
    ax.text(
        text_x,
        text_z,
        text_label,
        color='w',
        horizontalalignment=text_halign,
        verticalalignment='center',
    )

# Markers for the element, the pixel and the refraction point (in this order, so the
# color cycle assigns the same colors)
(p_element,) = ax.plot(xe, ze, 'o', markersize=marker_size)
(p_scatterer,) = ax.plot(xs, zs, 'o', markersize=marker_size)
(p_lateral,) = ax.plot(xl, lens_thickness, 'o', markersize=marker_size)

# Coordinates are stored in meters; tick labels are shown in millimeters
formatter = plt.FuncFormatter(lambda value, _pos: f'{value * 1e3:.0f}')
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_formatter(formatter)

ax.set(xlim=xlims, ylim=ylims, xlabel='x [mm]', ylabel='z [mm]')

# Legend placed outside the axes, to the right
ax.legend(
    handles=[p_element, p_lateral, p_scatterer],
    labels=['Element', 'Lateral point', 'Pixel'],
    loc='center left',
    bbox_to_anchor=(1, 0.5),
)
plt.tight_layout()

output_dir = Path("../content/assets")
output_dir.mkdir(parents=True, exist_ok=True)

plt.savefig(output_dir / "lens_correction.png", dpi=300)
