In [None]:
from pathlib import Path

import dask.array as da
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pywt
import torch
from cmcrameri import cm
from scipy.signal import butter, filtfilt, hilbert
from shapely.geometry import Point, Polygon

In [None]:
# --- Simulation parameters ---------------------------------------------------
medium_velocity = 1.9  # homogeneous medium velocity (km/s; distances are in km)
length_of_oneside = 120  # one-sided length of the correlation lapse-time axis (s)
# 10 Hz for smooth velocity measurements
sampling_rate = 10

# --- Receiver array ----------------------------------------------------------
# Regular grid of array stations.
# Increase grid_spacing if you can't run the notebook due to system constraints.
grid_spacing = 0.5
x_coords = torch.arange(-35, 35 + grid_spacing, grid_spacing)
y_coords = torch.arange(-35, 35 + grid_spacing, grid_spacing)
array_stations = grid_points = torch.cartesian_prod(x_coords, y_coords).float()

# --- Auxiliary stations ------------------------------------------------------
# Ring of auxiliary stations.
auxiliary_radius = 75
n_aux = 180
auxiliary_angles = np.linspace(0, 2 * np.pi, n_aux, endpoint=False)
aux_stations = torch.tensor(
    np.stack(
        [
            auxiliary_radius * np.cos(auxiliary_angles),
            auxiliary_radius * np.sin(auxiliary_angles),
        ],
        axis=1,
    )
).float()

# --- Noise sources -----------------------------------------------------------
n_boundary_sources = 180
boundary_source_radius = 100
n_cluster_sources = 6
cluster_spread = 0
cluster_seed = 0  # makes the cluster jitter reproducible when cluster_spread > 0

# linspace(endpoint=False) guarantees exactly n_boundary_sources angles; a
# float-step np.arange(0, 2*pi, 2*pi/n) can return n + 1 samples by rounding.
boundary_source_angles = np.linspace(
    0, 2 * np.pi, n_boundary_sources, endpoint=False
)
boundary_sources = torch.tensor(
    np.stack(
        [
            boundary_source_radius * np.cos(boundary_source_angles),
            boundary_source_radius * np.sin(boundary_source_angles),
        ],
        axis=1,
    )
)

# Extra cluster of sources on the source ring at angle 0.8*pi, which makes the
# source distribution non-uniform.
angle = 0.8 * np.pi
x_center, y_center = (
    boundary_source_radius * np.cos(angle),
    boundary_source_radius * np.sin(angle),
)

cluster_rng = torch.Generator().manual_seed(cluster_seed)
cluster_sources = (
    torch.rand(n_cluster_sources, 2, generator=cluster_rng) * cluster_spread
    - cluster_spread / 2
) + torch.tensor([x_center, y_center])

sources = torch.cat([boundary_sources, cluster_sources], dim=0).float()

# --- Travel times ------------------------------------------------------------
# (n_sources, n_stations) distances and straight-ray travel times
distances_array = torch.cdist(sources, array_stations)
distances_aux = torch.cdist(sources, aux_stations)
travel_times_array = distances_array / medium_velocity
travel_times_aux = distances_aux / medium_velocity

# --- Time / frequency axes ---------------------------------------------------
# Built from integer sample counts so every axis has exactly n_samples entries
# (a float-step arange can gain or lose a sample through rounding).
n_samples = int(round(2 * length_of_oneside * sampling_rate)) + 1
times = torch.arange(n_samples) / sampling_rate
freqs = torch.fft.fftfreq(n_samples, d=1 / sampling_rate)
omega = 2 * np.pi * freqs
lapse_times_c1 = (
    torch.arange(n_samples) - int(round(length_of_oneside * sampling_rate))
) / sampling_rate

# --- Source wavelet ----------------------------------------------------------
# Mexican-hat wavelet, zero-padded to the length of the time axis (same length
# is required for FFT-based convolution) and placed at the centre of the axis.
wavelet_length = 12 * sampling_rate
mexh = pywt.ContinuousWavelet("mexh")
wavelet_samples, _ = mexh.wavefun(length=wavelet_length)
wavelet = torch.tensor(wavelet_samples)
wavelet_long = torch.zeros_like(times)
wavelet_start = wavelet_long.shape[0] // 2 - wavelet_length // 2
wavelet_long[wavelet_start : wavelet_start + wavelet_length] = wavelet

# fftshift moves the wavelet centre to sample 0 before taking the spectrum.
wavelet_long_freqs = freqs  # same as fftfreq(n_samples, d=1 / sampling_rate)
wavelet_long_spectrum = torch.fft.fft(torch.fft.fftshift(wavelet_long))

# visualise final wavelet
fig, axs = plt.subplots(2)
axs[0].plot(times, wavelet_long)
axs[0].set_xlabel("Time (s)")
axs[1].plot(
    torch.fft.fftshift(wavelet_long_freqs),
    torch.fft.fftshift(torch.abs(wavelet_long_spectrum)),
)
axs[1].set_xlim(0, 1)
axs[1].set_xlabel("Frequency (Hz)")
fig.tight_layout()

# station / source geometry
fig, ax = plt.subplots()
ax.scatter(*array_stations.T, s=1, label="Array Stations")
ax.scatter(*sources.T, s=1, label="Sources", color="C3")
ax.scatter(*aux_stations.T, s=1, label="Auxiliary Stations", color="C4")
ax.set_aspect("equal")
ax.set_xlabel("Distance [km]")
ax.set_ylabel("Distance [km]")
ax.legend()  # display the scatter labels assigned above

In [None]:
# move stuff to dask
# Green's-function spectra for every (source, station) pair are built lazily
# with dask. Correlations are formed in the frequency domain and only the final
# inverse-FFT'd correlations are materialised as torch tensors.
# The angular-frequency axis is rebuilt from ``freqs`` so this cell does not
# depend on the global name ``omega``, which a later cell may reuse.
omega_da = da.from_array((2 * np.pi * freqs).numpy(), chunks=-1)
travel_times_array_da = da.from_array(travel_times_array.numpy(), chunks=25)
travel_times_aux_da = da.from_array(travel_times_aux.numpy(), chunks=25)
wavelet_long_freq_da = da.from_array(wavelet_long_spectrum.numpy(), chunks=-1)

# master (virtual-source) station: the array station closest to the origin
master_idx = np.argmin(np.linalg.norm(array_stations, axis=1))

# Green's functions in the frequency domain: the wavelet spectrum delayed by the
# travel time, exp(-i * omega * t). Index letters: s = source, r = array
# station, a = auxiliary station, w = frequency.
greens_array_da = (
    da.exp(-1j * omega_da[None, None, :] * travel_times_array_da[:, :, None])
    * wavelet_long_freq_da[None, None, :]
)
greens_aux_da = (
    da.exp(-1j * omega_da[None, None, :] * travel_times_aux_da[:, :, None])
    * wavelet_long_freq_da[None, None, :]
)

# Input for C2: every array station correlated with every auxiliary station,
# summed over sources.
covariances_for_c2 = da.einsum(
    "srw,saw->raw",
    greens_array_da,
    greens_aux_da.conj(),
)
# dask's ifft needs the transformed (last) axis in a single chunk, which holds
# because omega/wavelet were loaded with chunks=-1. The chunk size of 25 used
# for the source/station axes was tuned to the author's machine.
correlations_for_c2 = da.fft.fftshift(
    da.fft.ifft(covariances_for_c2, axis=-1).real, axes=-1
)

# C1: every array station correlated with the master station, summed over sources.
covariances_for_c1 = da.einsum(
    "srw,sw->rw",
    greens_array_da,
    greens_array_da[:, master_idx, :].conj(),
)
correlations_for_c1 = da.fft.fftshift(
    da.fft.ifft(covariances_for_c1, axis=-1).real, axes=-1
)

correlations_for_c2 = torch.tensor(correlations_for_c2.compute())
correlations_for_c1 = torch.tensor(correlations_for_c1.compute())

In [None]:
# PREP FOR C2
# Cut every array/aux correlation into its anti-causal [-T, 0] and causal
# [0, T] branch (T = window_length) so each branch can be correlated again.
window_length = length_of_oneside
lapse_times_c1 = torch.arange(
    -length_of_oneside, length_of_oneside + 1 / sampling_rate, 1 / sampling_rate
)
c1_idx0 = torch.argmin(lapse_times_c1.abs())  # sample index of zero lag

# window 0 -> anti-causal branch, window 1 -> causal branch
start_idxs = torch.tensor([c1_idx0 - (window_length * sampling_rate), c1_idx0])
end_idxs = start_idxs + int(window_length * sampling_rate + 1)

# master-station windows: (n_windows, n_aux, n_samples_in_window)
correlations_for_c2_master_windows = torch.empty(
    (len(start_idxs), correlations_for_c2.shape[1], window_length * sampling_rate + 1)
)
# all array-station windows: (n_windows, n_array, n_aux, n_samples_in_window)
correlations_for_c2_array_windows = torch.empty(
    (len(start_idxs),)
    + tuple(correlations_for_c2.shape[:-1])
    + (window_length * sampling_rate + 1,)
)

# fill the pre-allocated buffers branch by branch (values cast to buffer dtype)
for win_idx in range(len(start_idxs)):
    start_idx, end_idx = start_idxs[win_idx], end_idxs[win_idx]
    correlations_for_c2_master_windows[win_idx] = correlations_for_c2[
        master_idx, :, start_idx:end_idx
    ]
    correlations_for_c2_array_windows[win_idx] = correlations_for_c2[
        :, :, start_idx:end_idx
    ]

In [None]:
# C2
# Index letters:  t = time window (anti-causal / causal),  r = array station,
#                 a = auxiliary station,  w = sample within the window.
correlations_for_c2_master_windows_da = da.from_array(
    correlations_for_c2_master_windows.float().numpy(), chunks=[25, 25, -1]
)
correlations_for_c2_array_windows_da = da.from_array(
    correlations_for_c2_array_windows.float().numpy(),
    chunks=[25, 25, 25, -1],
)

# Cross-spectrum for each (t, r, a): conj(FFT(master window)) * FFT(array
# window); the master spectrum (t, a, w) is broadcast along a new r axis.
c2_covariances_unstacked = (
    da.fft.fft(correlations_for_c2_master_windows_da, axis=-1).conj()[:, None, :, :]
    * da.fft.fft(correlations_for_c2_array_windows_da, axis=-1)
)

# back to lag time with zero lag moved to the centre of the window
c2_correlations_unstacked = torch.tensor(
    da.fft.fftshift(
        da.fft.ifft(c2_covariances_unstacked, axis=-1).real,
        axes=-1,
    ).compute()
)

c2_correlations_unstacked.shape

In [None]:
# Band-pass one octave either side of 0.3 Hz (0.15-0.6 Hz), 4th-order Butterworth.
fmin, fmax = 0.3 / 2, 0.3 * 2
b, a = butter(4, [fmin, fmax], btype="band", fs=sampling_rate)
taper = torch.hann_window(window_length * sampling_rate + 1)

# filtfilt runs the filter forward and backward (zero phase), so the filtering
# does not shift envelope peak times.
c1_correlations_filt, c2_correlations_unstacked_filt = (
    torch.tensor(filtfilt(b, a, correlations, axis=-1).copy())
    for correlations in (correlations_for_c1, c2_correlations_unstacked)
)

In [None]:
# Average over the window axis (dim 0: anti-causal, causal branch).
# NOTE(review): with ifft(conj(F_master) * F_array), the anti-causal branch
# presumably yields C2 lags of the opposite sign to the causal branch, so this
# mean may symmetrise rather than stack the two. Confirm this is intended. The
# later picking uses |lag|, which is consistent with a symmetric result.
c2_correlations_unstacked_filt_avg = torch.mean(c2_correlations_unstacked_filt, dim=0)

# all-direction selection: average over every auxiliary station (dim 1)
c2_correlations_filt = torch.mean(c2_correlations_unstacked_filt_avg, dim=1)

In [None]:
# instead, do endfire-lobe based selection
# Endfire-lobe selection after Roux et al. (2004):
# 1. For every (master A, receiver B) pair, build the endfire lobes of the
#    pair's directivity pattern B(dtheta).
# 2. Keep only the auxiliary stations inside those lobes.
# 3. Average C2 over the kept stations.

f = 0.3  # frequency (Hz) at which the lobes are evaluated (centre of the pass band)
# Distinct names so the notebook-level ``omega`` (angular-frequency axis) and
# ``a`` (band-pass filter coefficients) are not overwritten by this cell.
omega_c = 2 * torch.pi * f
delta_omega = omega_c * 0.1  # bandwidth term of the directivity pattern
c = medium_velocity

results = []  # per receiver: [polygon x, polygon y, mask]
stations_all = array_stations
station_A = torch.tensor([0.0, 0.0], dtype=torch.float32)  # master station
stations_aux = aux_stations
masks_endfire = []
# angular offsets from the pair axis; the same for every pair, so built once
dtheta_grid = torch.linspace(-torch.pi / 2, torch.pi / 2, 721)
for station_B in stations_all:
    # inter-station distance
    R = torch.norm(station_A - station_B)

    # "directivity pattern" (<= 1 by construction)
    B = 1 - dtheta_grid**4 / 8 * (R / c) ** 2 * (omega_c**2 + delta_omega**2 / 12)

    # keep only angles where the pattern is non-negative
    keep = B >= 0
    dtheta = dtheta_grid[keep]
    # scale the pattern so the polygon is large enough to reach the aux ring
    B = B[keep] * (2 * 50)

    # forward lobe along +x, mirrored into the backward lobe along -x
    x, y = B * torch.cos(dtheta), B * torch.sin(dtheta)
    x = torch.cat([-x, x])
    y = torch.cat([y, y])

    # rotate onto the A->B axis, then shift to the midpoint of the pair
    theta_0 = torch.atan2(station_B[1] - station_A[1], station_B[0] - station_A[0])
    cos_t, sin_t = torch.cos(theta_0), torch.sin(theta_0)
    _x = x * cos_t - y * sin_t + (station_A[0] + station_B[0]) / 2
    _y = x * sin_t + y * cos_t + (station_A[1] + station_B[1]) / 2

    # NOTE(review): concatenating the mirrored lobe like this appears to give a
    # ring that crosses itself near the pair midpoint (figure-eight), which
    # Shapely treats as an invalid polygon. Predicates on invalid geometries are
    # not guaranteed, so consider building the two lobes as separate polygons.
    polygon = Polygon(zip(_x, _y))
    mask = torch.tensor(
        [polygon.contains(Point(p)) for p in stations_aux], dtype=torch.bool
    )
    masks_endfire.append(mask)

    results.append([_x, _y, mask])

masks_endfire = torch.stack(masks_endfire, dim=0)

# Average C2 over the selected auxiliary stations for each receiver. A receiver
# whose mask selects no station would get NaN.
c2_correlations_endfire_filt = torch.zeros_like(c2_correlations_filt)
for i, mask in enumerate(masks_endfire):
    c2_correlations_endfire_filt[i] = c2_correlations_unstacked_filt_avg[i][mask].mean(
        dim=0
    )

In [None]:
# selection based on azimuth
# For every (master A, receiver B) pair, keep the auxiliary stations with
#     |d(aux, A) - d(aux, B)| >= (1 - alpha) * d(A, B),   where cos(theta) = 1 - alpha.
# In the far field these are the stations within an angle theta of the A-B axis.
# Zhang et al. 2020 uses a fixed angle of 8 degrees (alpha ~ 0.01);
# Nouibat et al. 2022 uses a fixed angle of 20 degrees (alpha ~ 0.06), used here.
theta = np.radians(20)
alpha = 1 - np.cos(theta)

# Defined here too, so this cell does not depend on the endfire cell.
stations_all = array_stations
stations_aux = aux_stations
station_A = torch.tensor([0.0, 0.0], dtype=torch.float32)  # master station

# Compute all distances at once instead of a Python double loop over
# n_array x n_aux pairs (which also recomputed d(A, B) in the inner loop).
d_a_b = torch.norm(stations_all - station_A, dim=-1)  # (n_array,)
d_aux_a = torch.norm(stations_aux - station_A, dim=-1)  # (n_aux,)
d_aux_b = torch.norm(
    stations_aux[None, :, :] - stations_all[:, None, :], dim=-1
)  # (n_array, n_aux)

# masks_angle[i, j] is True when auxiliary station j is used for array station i
masks_angle = (d_aux_a[None, :] - d_aux_b).abs() >= float(1 - alpha) * d_a_b[:, None]

c2_correlations_angle_filt = torch.zeros_like(c2_correlations_filt)
for i, mask in enumerate(masks_angle):
    c2_correlations_angle_filt[i] = c2_correlations_unstacked_filt_avg[i][mask].mean(
        dim=0
    )

In [None]:
### MEASURE GROUP VELOCITY FOR EACH CORRELATION WAVEFIELD AT EVERY STATION

# Lapse-time axis of the windowed C2 correlations: -T/2 ... +T/2 in 1/fs steps.
lapse_times_c3 = torch.arange(
    -window_length / 2, window_length / 2 + 1 / sampling_rate, 1 / sampling_rate
)
# Distance of every array station from the master station, in the same order
# as the rows of the correlation tensors.
distances_master = np.linalg.norm(
    array_stations - array_stations[master_idx], axis=1
)


def get_group_velocity(correlations, lapse_times, distances=None):
    """Estimate one group velocity per trace from its envelope peak.

    The envelope is the magnitude of the analytic signal (Hilbert transform
    along the last axis). It is multiplied by a Hann window spanning the whole
    trace, which suppresses maxima near the trace edges. The lag of the tapered
    envelope maximum is taken as the travel time, and the velocity is
    ``distance / |lag|``.

    Parameters
    ----------
    correlations : array-like, shape (..., n_samples)
        Correlation traces; the last axis is lapse time.
    lapse_times : torch.Tensor, shape (n_samples,)
        Lapse time (s) of each sample.
    distances : array-like, optional
        Distance belonging to each trace (broadcast against the picked times).
        Defaults to the notebook-level ``distances_master``, which is what
        this function always used before the parameter existed.

    Returns
    -------
    Group velocity per trace. A trace picked at zero lag gives NaN when its
    distance is zero (0 / 0) and inf otherwise.
    """
    if distances is None:
        distances = distances_master
    envelopes = np.abs(hilbert(correlations, axis=-1))
    # Hann taper over the full trace. It also slightly favours picks near the
    # middle of the trace (zero lag for a symmetric lapse-time axis).
    taper = torch.hann_window(envelopes.shape[-1])
    envelopes *= taper.numpy()
    picked_time_idx = np.argmax(envelopes, axis=-1)
    picked_time = abs(lapse_times[picked_time_idx])
    return distances / picked_time


# Group velocity of every correlation wavefield (C2 angle / endfire / all, C1).
(
    group_velocities_c2_angle,
    group_velocities_c2_endfire,
    group_velocities_c2_all,
    group_velocities_c1,
) = (
    get_group_velocity(correlations, lapse_times)
    for correlations, lapse_times in (
        (c2_correlations_angle_filt, lapse_times_c3),
        (c2_correlations_endfire_filt, lapse_times_c3),
        (c2_correlations_filt, lapse_times_c3),
        (c1_correlations_filt, lapse_times_c1),
    )
)

# Percentage error relative to the true (homogeneous) medium velocity.
(
    group_velocity_anomalies_c2_angle,
    group_velocity_anomalies_c2_endfire,
    group_velocity_anomalies_c2_all,
    group_velocity_anomalies_c1,
) = (
    100 * (velocities - medium_velocity) / medium_velocity
    for velocities in (
        group_velocities_c2_angle,
        group_velocities_c2_endfire,
        group_velocities_c2_all,
        group_velocities_c1,
    )
)

In [None]:
plt.style.use(Path("../meta/colorblind_friendly.mplstyle"))
matplotlib.rcParams.update({"font.family": "serif", "font.size": 8})
_cm = 1 / 2.54  # cm to inches

# lapse-time axis of the windowed C2 correlations
lapse_times_c2 = torch.arange(
    -window_length / 2, window_length / 2 + 1 / sampling_rate, 1 / sampling_rate
)

# --- Normalisation -----------------------------------------------------------
# Each row of snapshots is divided by the peak |amplitude| found in a box of
# array stations (x in [-20, -10] km, y in [10, 20] km) at lapse times
# 9.75 s <= tau < 10.25 s.
array_selection = torch.where(
    (array_stations[:, 0] >= -20)
    & (array_stations[:, 0] <= -10)
    & (array_stations[:, 1] >= 10)
    & (array_stations[:, 1] <= 20)
)
time_idxs_c2 = torch.where((lapse_times_c2 >= 9.75) & (lapse_times_c2 < 10.25))[0]
time_idxs_c1 = torch.where((lapse_times_c1 >= 9.75) & (lapse_times_c1 < 10.25))[0]
c2_all_peak = c2_correlations_filt[..., time_idxs_c2][array_selection].abs().max()
c2_endfire_peak = (
    c2_correlations_endfire_filt[..., time_idxs_c2][array_selection].abs().max()
)
c2_angle_peak = (
    c2_correlations_angle_filt[..., time_idxs_c2][array_selection].abs().max()
)
c1_peak = c1_correlations_filt[..., time_idxs_c1][array_selection].abs().max()

# One entry per figure row, in the same order as ``row_specs`` below
# (C1, C2-all, C2-endfire, C2-angle): each peak normalises its own wavefield.
maxima_for_norm = [
    c1_peak,
    c2_all_peak,
    c2_endfire_peak,
    c2_angle_peak,
]
# (wavefield, lapse-time axis, in-panel label) for each figure row
row_specs = [
    (c1_correlations_filt, lapse_times_c1, r"$C_1$"),
    (c2_correlations_filt, lapse_times_c2, r"$C^{all}_2$"),
    (c2_correlations_endfire_filt, lapse_times_c2, r"$C^{endfire}_2$"),
    (c2_correlations_angle_filt, lapse_times_c2, r"$C^{angle}_2$"),
]

fig, axs = plt.subplots(4, 4, figsize=(18 * _cm, 18 * _cm))
times_to_plot = [-10, 0, 10]
labels = ("a)", "b)", "c)", "e)", "f)", "g)", "i)", "j)", "k)", "m)", "n)", "o)")

# --- Wavefield snapshots (first three columns) --------------------------------
for row, (correlations, lapse_times, clbl) in enumerate(row_specs):
    for col, time_to_plot in enumerate(times_to_plot):
        ax = axs[row, col]
        label = labels[row * len(times_to_plot) + col]
        time_idx = np.abs(lapse_times - time_to_plot).argmin()
        focal_spot = correlations[:, time_idx] / maxima_for_norm[row]

        pcm_fs = ax.pcolormesh(
            x_coords,
            y_coords,
            focal_spot.reshape(len(x_coords), len(y_coords)).T,
            cmap=cm.broc,
            vmin=-1,
            vmax=1,
        )

        ax.set_xlim(-35, 35)
        ax.set_ylim(-35, 35)
        ax.set_yticks([-30, 0, 30])
        ax.set_xticks([-30, 0, 30])
        ax.set_xticklabels([-30, 0, 30], fontsize=10)
        ax.set_yticklabels([-30, 0, 30], fontsize=10)
        ax.set_aspect("equal")
        ax.set_title(f"{label}", loc="left", fontsize=10, pad=4)
        if row == 0:
            ax.set_title(rf"$\tau$ = {time_to_plot} s", fontsize=10, pad=4)

        ax.text(
            0.05,
            0.95,
            clbl,
            ha="left",
            va="top",
            transform=ax.transAxes,
            fontsize=10,
            bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
        )
        # master station marker
        ax.scatter([0], [0], marker="v", ec="k", s=100, lw=1, c="#FFA90E")

for ax in axs[-1, :].flat:
    ax.set_xlabel("Distance [km]", labelpad=0, fontsize=10)
for ax in axs[:, 0]:
    ax.set_ylabel("Distance [km]", labelpad=0, fontsize=10)

for ax in axs[:3, :-1].flat:
    ax.set_xticklabels([])
for ax in axs[:3, 1:-1].flat:
    ax.set_yticklabels([])

# --- Group-velocity errors (last column) -------------------------------------
for i, (group_velocity_anomalies, lbl) in enumerate(
    zip(
        [
            group_velocity_anomalies_c1,
            group_velocity_anomalies_c2_all,
            group_velocity_anomalies_c2_endfire,
            group_velocity_anomalies_c2_angle,
        ],
        ["d)", "h)", "l)", "p)"],
    )
):
    ax = axs[i, -1]
    pcm = ax.pcolormesh(
        x_coords,
        y_coords,
        group_velocity_anomalies.reshape(x_coords.shape[0], y_coords.shape[0]).T,
        vmin=-15,
        vmax=15,
        cmap=cm.bam_r,
    )
    ax.set_xlim(-35, 35)
    ax.set_ylim(-35, 35)
    ax.set_aspect("equal")
    ax.set_yticks([-30, 0, 30])
    ax.set_xticks([-30, 0, 30])
    ax.set_xticklabels([-30, 0, 30], fontsize=10)
    ax.set_yticklabels([-30, 0, 30], fontsize=10)
    ax.set_title(lbl, loc="left", fontsize=10, pad=4)

    ax.scatter([0], [0], marker="v", ec="k", s=100, lw=1, c="#FFA90E")

# --- Colorbars, placed to the right of the top-right panel ---------------------
x0, y0, w0, h0 = axs[0, -1].get_position().bounds
cbar_ax = fig.add_axes([x0 + w0 + 0.03, y0 + h0 / 2 - 0.02, w0 / 1.5, 0.01])
cbar = fig.colorbar(
    pcm,
    cax=cbar_ax,
    orientation="horizontal",
    extend="both",
    ticks=[-15, 0, 15],
)
cbar.ax.tick_params(labelsize=8)
cbar.ax.set_xticklabels(["-15", "0", "15"], fontsize=8)
cbar.ax.set_xlabel("Group velocity\nerror [%]", fontsize=8, labelpad=4)

# second colorbar for the normalised wavefield snapshots
cbar_ax2 = fig.add_axes([x0 + w0 + 0.03, y0 + h0 - 0.02, w0 / 1.5, 0.01])
cbar2 = fig.colorbar(
    pcm_fs,
    cax=cbar_ax2,
    orientation="horizontal",
    extend="both",
    ticks=[-1, 0, 1],
)
cbar2.ax.tick_params(labelsize=8)
cbar2.ax.set_xticklabels(["-1", "0", "1"], fontsize=8)
cbar2.ax.set_xlabel("Simulated\namplitudes", fontsize=8, labelpad=4)

# cleanup
for ax in axs[:, -1]:
    ax.set_yticklabels([])
    if ax in axs[:-1, -1]:
        ax.set_xticklabels([])
    if ax == axs[-1, -1]:
        ax.set_xlabel("Distance [km]", labelpad=0, fontsize=10)
for ax in axs[-1, 1:]:
    ax.set_yticklabels([])

fig.savefig("../figures/figure9.png", dpi=300, bbox_inches="tight")
