In [None]:
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import torch
from cmcrameri import cm
from matplotlib.gridspec import GridSpec
from shapely.geometry import Point, Polygon
from tqdm import tqdm

In [None]:
plt.style.use(Path("../meta/colorblind_friendly.mplstyle"))
matplotlib.rcParams.update({"font.family": "serif", "font.size": 8})

# name of the receiver station used as station A (master) throughout the notebook
MASTER_STATION = "OMV.GDT"


def _load_station_coordinates(csv_path):
    """Read a station list CSV and return ``(dataframe, coords)``.

    ``coords`` is a tensor of shape (n_stations, 2) holding the ``X`` and ``Y``
    columns row by row, in file order.
    """
    stationlist = pl.read_csv(csv_path)
    coords = torch.tensor(
        np.vstack((stationlist["X"].to_numpy(), stationlist["Y"].to_numpy())).T
    )
    return stationlist, coords


# station locations (all stations, plotted as grey background)
stationlist_all, stations_all = _load_station_coordinates("../meta/stations_all.csv")

# receiver stations
stationlist_rcv, stations_receivers = _load_station_coordinates(
    "../meta/stations_receivers.csv"
)
names_receivers = stationlist_rcv["station"]

# auxiliary stations
stationlist_aux, stations_auxiliary = _load_station_coordinates(
    "../meta/stations_auxiliary.csv"
)

# position of the master station among the receivers (ValueError if missing)
master_idx = names_receivers.to_list().index(MASTER_STATION)

_cm = 1 / 2.54  # centimetres -> inches for figure sizes

fig, ax = plt.subplots(figsize=(1.45 * 9 * _cm, 9 * _cm))
ax.scatter(*stations_all.T, s=5, lw=0, c="#CCC")
ax.scatter(*stations_receivers.T, s=5, lw=0)
# a single station is a 1-D (x, y) tensor: unpack it directly, since `.T` on a
# 1-D tensor is deprecated in PyTorch
ax.scatter(*stations_receivers[master_idx], s=100, marker="v", lw=0.5, ec="k")
ax.scatter(*stations_auxiliary.T, s=5, lw=0, c="C4")

ax.set(xlim=(-11, 11), ylim=(-8, 7), aspect="equal")

In [None]:
# selection based on azimuth
#
# An auxiliary station S is used for the pair (A, B) when
#     |d(S, A) - d(S, B)| >= (1 - alpha) * d(A, B),   alpha = 1 - cos(theta)
# (eq. 14 in the manuscript, eq. 3 of Zhang et al. 2020).

# Zhang et al. 2020 use a fixed angle of 8 degrees, or alpha~0.01
# theta = np.radians(8)
# Nouibat et al. 2022 use a fixed angle of 20 degrees, or alpha~0.06
theta = np.radians(20)

alpha = 1 - np.cos(theta)

# station A is the master station; every receiver in turn acts as station B
station_A = stations_receivers[master_idx]

# distances, computed by broadcasting instead of a Python double loop:
#   d_a_b[i]      = |A - B_i|          shape (n_receivers,)
#   d_aux_a[j]    = |S_j - A|          shape (n_auxiliary,)
#   d_aux_b[i, j] = |S_j - B_i|        shape (n_receivers, n_auxiliary)
d_a_b = torch.linalg.norm(stations_receivers - station_A, dim=-1)
d_aux_a = torch.linalg.norm(stations_auxiliary - station_A, dim=-1)
d_aux_b = torch.linalg.norm(
    stations_auxiliary[None, :, :] - stations_receivers[:, None, :], dim=-1
)

# masks_angle[i, j] is True when auxiliary station j is used for receiver i.
# float() keeps the threshold a Python scalar so the product stays a torch tensor
# (numpy-scalar * tensor can otherwise be converted to an ndarray).
masks_angle = (d_aux_a[None, :] - d_aux_b).abs() >= float(1 - alpha) * d_a_b[:, None]

In [None]:
# select endfire-lobe style
# Roux et al. 2004
f = 0.3  # frequency used in omega = 2 * pi * f (presumably Hz)
omega = 2 * torch.pi * f
deltaomega = omega * 0.1  # bandwidth taken as 10 % of omega
c = 1.9  # wave speed; presumably km/s given km station coordinates — TODO confirm


def endfire_lobe_outline(station_A, station_B, c, omega, deltaomega, scale=2 * 10):
    """Vertices ``(x, y)`` of the endfire-lobe selection polygon for the pair (A, B).

    The lobe follows the "directivity pattern" of eq. 13 in the manuscript
    (eq. 3 of Roux et al. 2004) for the inter-station distance R = |A - B|.
    Angles where the pattern is negative are dropped. The pattern is multiplied
    by ``scale`` so the polygon reaches beyond the array, mirrored in x to cover
    both ends of the pair, rotated onto the A->B azimuth and shifted to the
    midpoint of the pair.

    Returns two 1-D tensors of equal length.
    """
    # inter-station distance
    R = torch.norm(station_A - station_B)
    dtheta = torch.linspace(-torch.pi / 2, torch.pi / 2, 721)

    # "directivity pattern"
    # eq. 13 in the manuscript, eq. 3 of Roux et al. 2004
    B = 1 - dtheta**4 / 8 * (R / c) ** 2 * (omega**2 + deltaomega**2 / 12)

    # keep angles with B >= 0 and enlarge the lobe so the polygon is large enough
    keep = B >= 0
    dtheta = dtheta[keep]
    B = B[keep] * scale
    x, y = B * torch.cos(dtheta), B * torch.sin(dtheta)
    # add an x-mirrored copy so both ends of the pair are covered
    x = torch.cat([-x, x])
    y = torch.cat([y, y])

    # rotate onto the azimuth from A to B (numpy trig on the scalar angle) ...
    theta_0 = torch.atan2(station_B[1] - station_A[1], station_B[0] - station_A[0])
    x_rot = x * np.cos(theta_0) - y * np.sin(theta_0)
    y_rot = x * np.sin(theta_0) + y * np.cos(theta_0)
    # ... and shift to the center of the two stations
    x_rot += (station_A[0] + station_B[0]) / 2
    y_rot += (station_A[1] + station_B[1]) / 2
    return x_rot, y_rot


results = []  # per receiver: [x outline, y outline, mask]
masks_endfire = []
for station_B in tqdm(stations_receivers):
    _x, _y = endfire_lobe_outline(station_A, station_B, c, omega, deltaomega)

    polygon = Polygon(zip(_x.tolist(), _y.tolist()))
    mask = torch.tensor(
        [
            polygon.contains(Point(float(p[0]), float(p[1])))
            for p in stations_auxiliary
        ],
        dtype=torch.bool,
    )
    masks_endfire.append(mask)

    results.append([_x, _y, mask])

# masks_endfire[i, j] is True when auxiliary station j lies inside receiver i's lobes
masks_endfire = torch.stack(masks_endfire, dim=0)

In [None]:
# plot a few examples to see if this works
# top row: fixed-angle masks, bottom row: endfire-lobe masks (both unbalanced)


def _plot_selection_example(panel, receiver, aux_mask):
    """Draw receiver B, master station A and the selected auxiliary stations."""
    panel.scatter(*receiver, marker="v")
    panel.scatter(*station_A, marker="v")
    panel.scatter(*stations_auxiliary[aux_mask].T, marker="o", c="C4", s=5, lw=0)
    panel.set(xlim=(-6, 6), ylim=(-6, 6), aspect="equal")


fig, axs = plt.subplots(
    2, 5, figsize=(2.5 * 9 * _cm, 9 * _cm), sharex=True, sharey=True
)
fig.subplots_adjust(wspace=0, hspace=0.2)

for receiver, angle_mask, top_panel in zip(stations_receivers, masks_angle, axs[0]):
    _plot_selection_example(top_panel, receiver, angle_mask)

for bottom_panel, receiver, lobe_result in zip(axs[1], stations_receivers, results):
    # lobe_result is [x outline, y outline, mask]
    _plot_selection_example(bottom_panel, receiver, lobe_result[-1])

In [None]:
# balance number of auxiliary stations on either side of the station pair to avoid asymmetry bias
# for each station pair, determine the "sidedness" of all auxiliary stations


def is_aux_station_left(station_A, station_B, station):
    """Return True if ``station`` lies on station A's side of the pair's bisector.

    Despite the name, this is not a left/right test with respect to the line
    A->B. The function takes the 2-D cross product of the vector perpendicular
    to A->B, ``(-AB_y, AB_x)``, with the offset of ``station`` from the midpoint
    of A and B. That cross product equals ``-(AB . (station - midpoint))``, so it
    is positive exactly when ``station`` is closer to A than to B. Stations on
    the perpendicular bisector, and every station when A == B, give False.

    Returns a Python bool.
    """
    ab = station_B - station_A
    offset = station - (station_A + station_B) / 2
    normal = torch.tensor([-ab[1], ab[0]])
    z_component = normal[0] * offset[1] - normal[1] * offset[0]
    return bool(z_component > 0)


# aux_station_classification[i, j]: result of is_aux_station_left for the pair
# (master station A, receiver i) and auxiliary station j
aux_station_classification = torch.tensor(
    [
        [is_aux_station_left(station_A, receiver, aux) for aux in stations_auxiliary]
        for receiver in tqdm(stations_receivers)
    ],
    dtype=torch.bool,
)

In [None]:
# balance masks to have the same number of stations on either side
# (side as given by aux_station_classification); surplus stations on the larger
# side are dropped at random, using a seeded generator so results are reproducible
BALANCE_SEED = 42


def balance_masks(masks, is_a_side, generator=None):
    """Return a copy of ``masks`` with equally many selected stations per side.

    Parameters
    ----------
    masks : bool tensor, shape (n_receivers, n_auxiliary)
        Auxiliary-station selection for each receiver.
    is_a_side : bool tensor, same shape as ``masks``
        Side classification of each auxiliary station for each receiver.
    generator : torch.Generator, optional
        RNG that picks which surplus stations are removed; None uses the
        global torch RNG.

    For each row, ``|n_side - n_other|`` randomly chosen selected stations are
    removed from the side with more selected stations. ``masks`` is not modified.
    """
    balanced = masks.clone()
    # iterating a tensor yields row views, so editing `row` edits `balanced`
    for row, side in zip(balanced, is_a_side):
        n_side = int(torch.sum(side & row))
        n_other = int(torch.sum(~side & row))
        if n_side == n_other:
            continue
        # selected stations on the over-represented side
        surplus_side = side if n_side > n_other else ~side
        candidates = torch.where(surplus_side & row)[0]
        drop = torch.randperm(len(candidates), generator=generator)[
            : abs(n_side - n_other)
        ]
        row[candidates[drop]] = False
    return balanced


_balance_rng = torch.Generator().manual_seed(BALANCE_SEED)
masks_angle_balanced = balance_masks(
    masks_angle, aux_station_classification, _balance_rng
)
masks_endfire_balanced = balance_masks(
    masks_endfire, aux_station_classification, _balance_rng
)

# plot a few examples to see if this works
# top row: balanced fixed-angle masks, bottom row: balanced endfire-lobe masks
fig, axs = plt.subplots(
    2, 5, figsize=(2.5 * 9 * _cm, 9 * _cm), sharex=True, sharey=True
)
fig.subplots_adjust(wspace=0, hspace=0.2)
for row_axs, balanced in zip(axs, (masks_angle_balanced, masks_endfire_balanced)):
    for ax, station_B, mask in zip(row_axs, stations_receivers, balanced):
        ax.scatter(*station_B, marker="v")
        ax.scatter(*station_A, marker="v")
        ax.scatter(*stations_auxiliary[mask].T, marker="o", c="C4", s=5, lw=0)
        ax.set(xlim=(-6, 6), ylim=(-6, 6), aspect="equal")

In [None]:
# load correlations
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code. Only load trusted files here (presumably the output
# of compute_correlations.ipynb).


def _load_trusted_pt(path):
    """Load ``path`` with torch.load and full unpickling (weights_only=False)."""
    return torch.load(path, weights_only=False)


c1_correlations_filt_both = _load_trusted_pt(
    "../data/correlations_for_c1_filt_synth_both.pt"
)
c1_correlations_filt_data = _load_trusted_pt(
    "../data/correlations_for_c1_filt_data.pt"
)
c2_correlations_unstacked_filt_both = _load_trusted_pt(
    "../data/c2_correlations_unstacked_filt_synth_both.pt"
)
c2_correlations_unstacked_filt_data = _load_trusted_pt(
    "../data/c2_correlations_unstacked_filt_data.pt"
)

# must match compute_correlations.ipynb
sampling_rate = 5  # samples per time unit (presumably Hz)
length_of_oneside = 300  # length of one correlation side (presumably seconds)
window_length = length_of_oneside

# Build the axes from integer sample indices and divide afterwards:
# torch.arange with a fractional step can gain or lose an element through
# floating-point rounding, which would shift lengths and zero-lag indices.
_n_oneside = round(length_of_oneside * sampling_rate)
_n_half_window = round(window_length * sampling_rate / 2)

# 0 ... 2 * length_of_oneside, inclusive
times = torch.arange(0, 2 * _n_oneside + 1) / sampling_rate
# -length_of_oneside ... +length_of_oneside, inclusive
lapse_times_c1 = torch.arange(-_n_oneside, _n_oneside + 1) / sampling_rate
# -window_length / 2 (inclusive) ... +window_length / 2 (exclusive)
lapse_times_c2 = torch.arange(-_n_half_window, _n_half_window) / sampling_rate

In [None]:
# average causal and anti-causal C2 correlations (dimension 0)
c2_correlations_unstacked_filt_both_avg = c2_correlations_unstacked_filt_both.mean(
    dim=0
)
c2_correlations_unstacked_filt_data_avg = c2_correlations_unstacked_filt_data.mean(
    dim=0
)

# auxiliary station averaging
# all-direction selection: plain mean over every auxiliary station (dimension 1)
c2_correlations_filt_both_allstack = c2_correlations_unstacked_filt_both_avg.mean(dim=1)
c2_correlations_filt_data_allstack = c2_correlations_unstacked_filt_data_avg.mean(dim=1)


def _stack_selected(per_pair, masks, template):
    """Mean of ``per_pair[row]`` over the auxiliary stations picked by ``masks[row]``.

    Results are written into a zero tensor shaped like ``template``. A row whose
    mask selects nothing becomes NaN (the mean of an empty slice).
    Assumes ``per_pair`` is indexed (receiver, auxiliary, ...) — TODO confirm.
    """
    stacked = torch.zeros_like(template)
    for row, selection in enumerate(masks):
        stacked[row] = per_pair[row][selection].mean(dim=0)
    return stacked


# apply angle selection
c2_correlations_filt_both_anglestack = _stack_selected(
    c2_correlations_unstacked_filt_both_avg,
    masks_angle_balanced,
    c2_correlations_filt_both_allstack,
)
c2_correlations_filt_data_anglestack = _stack_selected(
    c2_correlations_unstacked_filt_data_avg,
    masks_angle_balanced,
    c2_correlations_filt_data_allstack,
)

# apply endfire-lobe selection
c2_correlations_filt_both_endfirestack = _stack_selected(
    c2_correlations_unstacked_filt_both_avg,
    masks_endfire_balanced,
    c2_correlations_filt_both_allstack,
)
c2_correlations_filt_data_endfirestack = _stack_selected(
    c2_correlations_unstacked_filt_data_avg,
    masks_endfire_balanced,
    c2_correlations_filt_data_allstack,
)

In [None]:
# zero-lag amplitude per receiver ("focal spot")
c2_zerolag_index = lapse_times_c2.abs().argmin()
# rows: 0-2 observed data (all, endfire, angle), 3-5 synthetics (all, endfire, angle)
c2_focal_spots = torch.stack(
    [
        c2_correlations_filt_data_allstack[..., c2_zerolag_index],
        c2_correlations_filt_data_endfirestack[..., c2_zerolag_index],
        c2_correlations_filt_data_anglestack[..., c2_zerolag_index],
        c2_correlations_filt_both_allstack[..., c2_zerolag_index],
        c2_correlations_filt_both_endfirestack[..., c2_zerolag_index],
        c2_correlations_filt_both_anglestack[..., c2_zerolag_index],
    ]
)

# replace all nans with 0s (e.g. pairs whose balanced mask selected no station)
c2_focal_spots[torch.isnan(c2_focal_spots)] = 0.0

# remove some outliers in data to stabilise normalisation:
# the largest absolute amplitudes of each observed-data row are set to NaN
n_outliers_to_remove = 10
for fs in c2_focal_spots[:3]:
    for _outlier in range(n_outliers_to_remove):
        fs[np.nanargmax(fs.abs().numpy())] = torch.nan

# normalise each row by its largest finite absolute value
for fs in c2_focal_spots:
    fs /= np.nanmax(fs.abs())

c1_zerolag_index = lapse_times_c1.abs().argmin()
# .clone(): indexing with a 0-d integer tensor can return a view, and the
# in-place edits below would otherwise overwrite the loaded C1 correlations
c1_focal_spot_synth = c1_correlations_filt_both[..., c1_zerolag_index].clone()
c1_focal_spot_synth[torch.isnan(c1_focal_spot_synth)] = 0.0
c1_focal_spot_synth /= c1_focal_spot_synth.abs().max()

c1_focal_spot_data = c1_correlations_filt_data[..., c1_zerolag_index].clone()
# exclude the master station itself (presumably its autocorrelation) from the
# normalisation; NaN entries are not drawn in the scatter plots
c1_focal_spot_data[master_idx] = torch.nan
c1_focal_spot_data /= np.nanmax(c1_focal_spot_data.abs())

# compute differences to show source effect patterns
focal_spot_diffs_data = c2_focal_spots[:3] - c1_focal_spot_data
focal_spot_diffs_synth = c2_focal_spots[3:] - c1_focal_spot_synth


In [None]:
# ---- Figure 7 ----
# Columns, left to right: all directions | endfire lobes | fixed angle. Each
# strategy has a C2 focal-spot panel followed by a "C2 - C1" difference panel.
# Row 0: sketches of the selections, row 1: observed data, row 2: synthetics.
fig = plt.figure(figsize=(18 * _cm, 1 / 2.5 * 18 * _cm))
gs = GridSpec(
    3,
    18,
    figure=fig,
    height_ratios=[0.5, 1, 1],
    width_ratios=[1] * 18,
    hspace=0.2,
)

sketch_axs = [fig.add_subplot(gs[0, 2 * k : 2 * (k + 1)]) for k in range(9)]
focal_spot_axs = [fig.add_subplot(gs[1, 3 * k : 3 * (k + 1)]) for k in range(6)]
focal_spot_synth_axs = [
    fig.add_subplot(gs[2, 3 * k : 3 * (k + 1)]) for k in range(6)
]

# formatting
for ax in sketch_axs:
    ax.set(xticks=[], yticks=[], xticklabels=[], yticklabels=[])
for ax in focal_spot_axs + focal_spot_synth_axs:
    ax.set(
        xlim=(-4.5, 4.5),
        ylim=(-4.5, 4.5),
        aspect="equal",
        xticks=(-4, 0, 4),
        yticks=(-4, 0, 4),
        xticklabels=[],
        yticklabels=[],
    )
for ax in focal_spot_synth_axs:
    ax.set_xlabel("Distance [km]", labelpad=0, fontsize=7)
    ax.set_xticklabels([-4, 0, 4], fontsize=7)

# even columns hold the C2 focal spots and odd columns the differences; order is
# (all, endfire, angle) for data followed by the same for synthetics
fs_axs = focal_spot_axs[0::2] + focal_spot_synth_axs[0::2]
diff_axs = focal_spot_axs[1::2] + focal_spot_synth_axs[1::2]

# ---- C2 focal spots ----
labels = ["c)", "g)", "k)", "e)", "i)", "m)"]
c2_lbls = [r"$C^{all}_2$", r"$C^{endfire}_2$", r"$C^{angle}_2$"] * 2

for fs_idx, (ax, fs, label, c2_lbl) in enumerate(
    zip(fs_axs, c2_focal_spots, labels, c2_lbls)
):
    # observed data use vik, synthetics broc
    _cmap = cm.vik if fs_idx < 3 else cm.broc

    # NaN-aware normalisation into a new tensor. The builtin max() returns NaN
    # whenever the first entry is NaN (outliers are NaN-ed), and an in-place
    # `/=` would mutate c2_focal_spots on every re-run of this cell.
    fs_norm = fs / np.nanmax(fs.abs())
    ax.scatter(
        *stations_receivers.T,
        c=fs_norm,
        cmap=_cmap,
        s=2.5,
        vmin=-1,
        vmax=1,
        lw=0,
    )
    ax.set_title(f"{label}", loc="left", fontsize=8, pad=1)
    ax.scatter([0], [0], marker="v", ec="k", s=50, lw=1, c="#FFA90E")

    ax.text(
        0.05,
        0.95,
        c2_lbl,
        ha="left",
        va="top",
        transform=ax.transAxes,
        fontsize=7,
        fontweight="bold",
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
    )

# ---- C2 - C1 differences (source-effect patterns) ----
c2_diff_lbls = [
    "c) – a)",
    "g) – a)",
    "k) – a)",
    "e) – b)",
    "i) – b)",
    "m) – b)",
]
diff_labels = ["d", "h", "l", "f", "j", "n"]
focal_spot_diffs = torch.cat([focal_spot_diffs_data, focal_spot_diffs_synth], dim=0)

cbars = {}  # created once; keyed by panel label ("l": observed, "n": simulated)
for idx, (ax, diff, label, c2_diff_lbl) in enumerate(
    zip(diff_axs, focal_spot_diffs, diff_labels, c2_diff_lbls)
):
    _cmap = cm.vik if idx < 3 else cm.broc

    sct = ax.scatter(
        *stations_receivers.T,
        c=diff,
        cmap=_cmap,
        s=2.5,
        vmin=-1,
        vmax=1,
        lw=0,
    )
    ax.scatter([0], [0], marker="v", ec="k", s=50, lw=1, c="#FFA90E")
    ax.set_title(f"{label})", loc="left", fontsize=8, pad=1)

    ax.text(
        0.05,
        0.95,
        c2_diff_lbl,
        ha="left",
        va="top",
        transform=ax.transAxes,
        fontsize=7,
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
    )

    # move this ax a bit to the left
    x0, y0, w, h = ax.get_position().bounds
    ax.set_position([x0 - 0.01, y0, w, h])

    if label in ("l", "n"):
        # colorbar to the right of the last difference panel of each row
        x0, y0, w, h = ax.get_position().bounds
        cbar_ax = fig.add_axes([x0 + w + 0.01, y0, 0.0075, h])
        cbar = fig.colorbar(sct, cax=cbar_ax)
        cbar.ax.tick_params(labelsize=6)
        cbar.set_ticks([-1, 0, 1])
        cbar.set_ticklabels(["-1", "0", "1"])
        cbar.set_label(
            "Observed amplitudes" if label == "l" else "Simulated amplitudes",
            fontsize=5,
            labelpad=0,
        )
        cbars[label] = cbar


# ---- Sketches of auxiliary station selections ----
# the three sketches of a strategy together span its focal-spot panel, its
# difference panel and the gap between them
space_between = 0.005
new_width = (
    fs_axs[0].get_position().bounds[2]
    + diff_axs[0].get_position().bounds[2]
    + space_between
) / 3
x_shift = 0.005
sketch_gap = 0.005

# manually picked receiver station indices: 1168 nearby, 593 mid, 1437 far
chosen_station_idxs = [1168, 593, 1437]

# helper grid for shading the fixed-angle selection region
sketch_grid = np.linspace(-6, 6, 100)
grid_x, grid_y = np.meshgrid(sketch_grid, sketch_grid)

fullstack_mask = torch.ones_like(masks_endfire_balanced, dtype=torch.bool)
for idx, _masks in enumerate(
    [fullstack_mask, masks_endfire_balanced, masks_angle_balanced]
):
    current_sketch_axs = sketch_axs[idx * 3 : (idx + 1) * 3]
    # place the three sketches side by side above their strategy's panels
    x0, y0, w, h = current_sketch_axs[0].get_position().bounds
    for k, sketch_ax in enumerate(current_sketch_axs):
        sketch_ax.set_position(
            [x0 + x_shift + k * (new_width + sketch_gap), y0, new_width, h]
        )

    for sketch_ax, chosen_station_idx in zip(current_sketch_axs, chosen_station_idxs):
        station_B = stations_receivers[chosen_station_idx]

        if idx == 0:
            # all directions: every auxiliary station is used, shade everything
            sketch_ax.add_patch(
                matplotlib.patches.Rectangle(
                    (-6, -6), 12, 12, fill=True, color="k", alpha=0.05
                )
            )
        elif idx == 1:
            # endfire lobes: reuse the exact selection polygon stored in
            # `results` ([x outline, y outline, mask]) by the endfire-lobe cell
            lobe_x, lobe_y = results[chosen_station_idx][:2]
            sketch_ax.add_patch(
                matplotlib.patches.Polygon(
                    torch.stack([lobe_x, lobe_y], dim=1).numpy(),
                    closed=True,
                    fill=True,
                    facecolor="k",
                    edgecolor="None",
                    alpha=0.1,
                )
            )
        elif idx == 2:
            # fixed angle: shade the grid region where the selection criterion
            # holds, |d(S, A) - d(S, B)| >= (1 - alpha) * d(A, B), with the same
            # alpha as the fixed-angle selection cell
            a_x, a_y = float(station_A[0]), float(station_A[1])
            b_x, b_y = float(station_B[0]), float(station_B[1])
            d_grid_a = np.hypot(grid_x - a_x, grid_y - a_y)
            d_grid_b = np.hypot(grid_x - b_x, grid_y - b_y)
            d_pair = np.hypot(b_x - a_x, b_y - a_y)
            selected = np.abs(d_grid_a - d_grid_b) >= (1 - alpha) * d_pair
            sketch_ax.contourf(
                grid_x,
                grid_y,
                selected.astype(float),
                levels=[0.5, 1.5],
                colors="k",
                alpha=0.1,
            )

        sketch_ax.scatter(*station_A, s=30, marker="v", ec="k", lw=0.5, c="#FFA90E")
        sketch_ax.scatter(
            *stations_auxiliary[_masks[chosen_station_idx]].T,
            marker="o",
            s=1,
            lw=0,
            c="#832DB6",
            clip_on=False,
        )
        sketch_ax.scatter(
            *station_B,
            s=20,
            marker="v",
            ec="k",
            lw=0.5,
            c="#3F90DA",
        )

        sketch_ax.set(xlim=(-6, 6), ylim=(-6, 6), aspect="equal")
        sketch_ax.axis("off")  # turn off the axes

sketch_axs[1].set_title("All directions", fontsize=8, pad=4)
sketch_axs[4].set_title("Endfire lobes", fontsize=8, pad=4)
sketch_axs[7].set_title("Fixed angle", fontsize=8, pad=4)

# move all sketch_axs a bit up to reduce vertical space
for sketch_ax in sketch_axs:
    x0, y0, w, h = sketch_ax.get_position().bounds
    sketch_ax.set_position([x0, y0 + 0.03, w, h])

# move the bottom row of axs a tiny bit down
for ax in focal_spot_synth_axs:
    x0, y0, w, h = ax.get_position().bounds
    ax.set_position([x0, y0 - 0.01, w, h])

# move the simulated-amplitude colorbar down together with the bottom row
x0, y0, w, h = cbars["n"].ax.get_position().bounds
cbars["n"].ax.set_position([x0, y0 - 0.01, w, h])

# add two axes to the very left of the left-most axes showing the C1 focal spots
x0, y0, w, h = focal_spot_axs[0].get_position().bounds
C1data_ax = fig.add_axes([x0 - (w + 0.03), y0, w, h])
C1data_ax.scatter(
    *stations_receivers.T,
    c=c1_focal_spot_data,
    cmap=cm.vik,
    s=2.5,
    vmin=-1,
    vmax=1,
    lw=0,
)
x0, y0, w, h = focal_spot_synth_axs[0].get_position().bounds
C1synth_ax = fig.add_axes([x0 - (w + 0.03), y0, w, h])
C1synth_ax.scatter(
    *stations_receivers.T,
    c=c1_focal_spot_synth,
    cmap=cm.broc,
    s=2.5,
    vmin=-1,
    vmax=1,
    lw=0,
)

# formatting
for ax in (C1data_ax, C1synth_ax):
    ax.scatter(*station_A, s=30, marker="v", ec="k", lw=0.5, c="#FFA90E")
    ax.set(
        xlim=(-4.5, 4.5),
        ylim=(-4.5, 4.5),
        aspect="equal",
        xticks=(-4, 0, 4),
        yticks=(-4, 0, 4),
        yticklabels=[-4, 0, 4],
    )
    ax.text(
        0.05,
        0.95,
        r"$C_1$",
        ha="left",
        va="top",
        transform=ax.transAxes,
        fontsize=7,
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
    )
C1data_ax.set_title("a)", loc="left", fontsize=8, pad=1)
C1data_ax.set_ylabel("Distance [km]", labelpad=0, fontsize=7)
C1data_ax.set_yticklabels([-4, 0, 4], fontsize=7)
C1data_ax.set_xticklabels([])

C1synth_ax.set_title("b)", loc="left", fontsize=8, pad=1)
C1synth_ax.set_ylabel("Distance [km]", labelpad=0, fontsize=7)
C1synth_ax.set_xlabel("Distance [km]", labelpad=0, fontsize=7)
C1synth_ax.set_xticklabels([-4, 0, 4], fontsize=7)

# make sure the output directory exists before saving
Path("../figures").mkdir(parents=True, exist_ok=True)
fig.savefig("../figures/figure7.png", dpi=300, bbox_inches="tight")