In [None]:
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import torch
from cmcrameri import cm
from matplotlib.gridspec import GridSpec

In [None]:
plt.style.use(Path("../meta/colorblind_friendly.mplstyle"))
matplotlib.rcParams.update({"font.family": "serif", "font.size": 8})


def _station_coordinates(stationlist: "pl.DataFrame") -> torch.Tensor:
    """Return the ``X``/``Y`` columns of a station list as an ``(n_stations, 2)`` tensor."""
    return torch.tensor(
        np.column_stack((stationlist["X"].to_numpy(), stationlist["Y"].to_numpy()))
    )


# station locations
stationlist_all = pl.read_csv("../meta/stations_all.csv")
stations_all = _station_coordinates(stationlist_all)

# receiver stations
stationlist_rcv = pl.read_csv("../meta/stations_receivers.csv")
names_receivers = stationlist_rcv["station"]
stations_receivers = _station_coordinates(stationlist_rcv)

# auxiliary stations
stationlist_aux = pl.read_csv("../meta/stations_auxiliary.csv")
stations_auxiliary = _station_coordinates(stationlist_aux)

# master receiver; list.index raises a ValueError naming the station if it is missing
master_idx = names_receivers.to_list().index("OMV.GDT")

_cm = 1 / 2.54  # centimetres -> inches, for figsize

fig, ax = plt.subplots(figsize=(1.45 * 9 * _cm, 9 * _cm))
ax.scatter(*stations_all.T, s=5, lw=0, c="#CCC")
ax.scatter(*stations_receivers.T, s=5, lw=0)
# A single station is a 1-D (x, y) tensor: unpack it directly. `.T` on
# non-2-D tensors is deprecated in PyTorch.
ax.scatter(*stations_receivers[master_idx], s=100, marker="v", lw=0.5, ec="k")
ax.scatter(*stations_auxiliary.T, s=5, lw=0, c="C4")

ax.set(xlim=(-11, 11), ylim=(-8, 7), aspect="equal")

# manually chosen auxiliary stations and their indices in the auxiliary station list
aux_names = ["OMV.AGF", "OMV.BNT", "OMV.BSE", "OMV.BUK"]
aux_station_names = stationlist_aux["station"].to_list()
station_idxs_auxilliary = [aux_station_names.index(name) for name in aux_names]

# rotate left by one, wrapping around (BNT, BSE, BUK, AGF); later cells
# iterate over the chosen stations in this order
station_idxs_in_subset = station_idxs_auxilliary[1:] + station_idxs_auxilliary[:1]

# highlight the chosen auxiliary stations
ax.scatter(
    *stations_auxiliary[station_idxs_in_subset].T,
    s=200,
    lw=0,
    c="C4",
)

In [None]:
def _load_c2(path):
    """Load one pickled C2 correlation object from ``path``.

    NOTE: ``weights_only=False`` unpickles arbitrary Python objects, which can
    execute code. Only use it on trusted files produced by this project.
    """
    return torch.load(path, weights_only=False)


# simulated ("synth") C2 correlations for the three source scenarios
c2_correlations_unstacked_filt_both = _load_c2(
    "../data/c2_correlations_unstacked_filt_synth_both.pt"
)
c2_correlations_unstacked_filt_boundary = _load_c2(
    "../data/c2_correlations_unstacked_filt_synth_boundary.pt"
)
c2_correlations_unstacked_filt_isolated = _load_c2(
    "../data/c2_correlations_unstacked_filt_synth_isolated.pt"
)
# C2 correlations computed from field data
c2_correlations_unstacked_filt_data = _load_c2(
    "../data/c2_correlations_unstacked_filt_data.pt"
)

# Lapse-time axis of the C2 correlations; these parameters must match
# compute_correlations.ipynb.
sampling_rate = 5
length_of_oneside = 300
window_length = length_of_oneside
lapse_times_c2 = torch.arange(
    -window_length / 2,
    window_length / 2,
    1 / sampling_rate,
)

In [None]:
# average causal and anti-causal C2 correlations
def _average_branches(correlations):
    """Mean over dim 0, taken to index the causal / anti-causal C2 branches."""
    return correlations.mean(dim=0)


# NOTE(review): the three synthetic tensors are rebound to their own averages,
# so this cell is not idempotent. Re-run the loading cell before running it again.
c2_correlations_unstacked_filt_both = _average_branches(
    c2_correlations_unstacked_filt_both
)
c2_correlations_unstacked_filt_boundary = _average_branches(
    c2_correlations_unstacked_filt_boundary
)
c2_correlations_unstacked_filt_isolated = _average_branches(
    c2_correlations_unstacked_filt_isolated
)

# the field-data average gets its own name, so the loaded tensor is kept
c2_correlations_filt_data = _average_branches(c2_correlations_unstacked_filt_data)

In [None]:
# Index of the receiver nearest the origin, by squared Euclidean distance over
# both x and y; its sample is blanked in every focal-spot panel. The previous
# x-only argmin could select a station with x ~ 0 but far away in y.
idx_closest_to_0 = stations_receivers.pow(2).sum(dim=1).argmin()

# Focal spots: every correlation sampled at the lapse time closest to zero.
# The last axis is assumed to be lapse time (matches lapse_times_c2) -- confirm.
c2_zerolag_index = lapse_times_c2.abs().argmin()
c2_focal_spots_data = c2_correlations_filt_data[..., c2_zerolag_index]
c2_focal_spots_both = c2_correlations_unstacked_filt_both[..., c2_zerolag_index]
c2_focal_spots_boundary = c2_correlations_unstacked_filt_boundary[..., c2_zerolag_index]
c2_focal_spots_isolated = c2_correlations_unstacked_filt_isolated[..., c2_zerolag_index]

fig = plt.figure(figsize=(18 * _cm, 18 * _cm))
# 5x5 grid: a narrow first row and first column (ratio 0.33) for the layout
# sketches, followed by the 4x4 block of focal-spot panels
gs = GridSpec(
    5,
    5,
    figure=fig,
    width_ratios=[0.33, 1, 1, 1, 1],
    height_ratios=[0.33, 1, 1, 1, 1],
)
axs = np.empty((5, 5), dtype=object)
for grid_row in range(5):
    for grid_col in range(5):
        axs[grid_row, grid_col] = fig.add_subplot(gs[grid_row, grid_col])

labels = [f"{c})" for c in "abcdefghijklmnop"]
focal_spot_axs = axs[1:, 1:]  # focal spot axes

# Focal spots shown in each panel row, top to bottom. The last row is field
# data: it gets outlier removal, its own colormap and an "Observed" colorbar.
focal_spots_per_row = (
    c2_focal_spots_isolated,
    c2_focal_spots_boundary,
    c2_focal_spots_both,
    c2_focal_spots_data,
)
n_outliers_removed = 10  # largest-|amplitude| data samples zeroed for normalisation stability
n_focal_rows, n_focal_cols = focal_spot_axs.shape
data_row_idx = n_focal_rows - 1

for row_idx, (row_axs, focal_spots) in enumerate(
    zip(focal_spot_axs, focal_spots_per_row)
):
    is_data = row_idx == data_row_idx
    # each column shows one of the chosen auxiliary stations
    for col_idx, (ax, aux_idx) in enumerate(zip(row_axs, station_idxs_in_subset)):
        label = labels[row_idx * n_focal_cols + col_idx]

        # clone: the edits below must not modify the shared focal-spot tensor
        focal_spot = focal_spots[:, aux_idx].clone()
        focal_spot[idx_closest_to_0] = 0
        focal_spot[torch.isnan(focal_spot)] = 0

        # remove outliers for data only
        if is_data:
            n_remove = min(n_outliers_removed, focal_spot.numel())
            focal_spot[focal_spot.abs().topk(n_remove).indices] = 0

        # normalise amplitudes; an all-zero panel stays zero instead of becoming NaN
        peak = focal_spot.abs().max()
        if peak > 0:
            focal_spot /= peak

        sct = ax.scatter(
            *stations_receivers.T,
            c=focal_spot,
            cmap=cm.vik if is_data else cm.broc,
            s=5,
            vmin=-1,
            vmax=1,
            lw=0,
        )

        ax.set_title(label, loc="left", fontsize=10, pad=2)
        # master station marker at the origin
        ax.scatter([0], [0], marker="v", ec="k", s=80, lw=1, c="#FFA90E")

        # rightmost column: attach a colorbar just to the right of the panel
        if col_idx == n_focal_cols - 1:
            x0, y0, width, height = ax.get_position().bounds
            cax = fig.add_axes([x0 + width + 0.005, y0, 0.0075, height])
            cbar = fig.colorbar(sct, cax=cax)
            cbar.ax.tick_params(labelsize=8)
            prefix = "Observed" if is_data else "Simulated"
            cbar.set_label(
                f"{prefix}\n" + r"$C_2$" + " amplitudes", fontsize=8, labelpad=-2
            )
            cbar.set_ticks([-1, 0, 1])

# Shared limits and ticks for all focal-spot panels. Only the bottom row keeps
# x tick labels and an x label; only the left column keeps y tick labels and
# a y label.
tick_positions = [-4, 0, 4]
last_focal_row = focal_spot_axs.shape[0] - 1
for (row_idx, col_idx), ax in np.ndenumerate(focal_spot_axs):
    ax.set_xlim(-4.5, 4.5)
    ax.set_ylim(-4.5, 4.5)
    ax.set_aspect("equal")
    ax.set_yticks(tick_positions)
    ax.set_xticks(tick_positions)

    if row_idx == last_focal_row:
        ax.set_xlabel("Distance [km]", labelpad=0, fontsize=10)
        ax.set_xticklabels(tick_positions, fontsize=10)
    else:
        ax.set_xticklabels([])

    if col_idx == 0:
        ax.set_ylabel("Distance [km]", labelpad=0, fontsize=10)
        ax.set_yticklabels(tick_positions, fontsize=10)
    else:
        ax.set_yticklabels([])

# Draw the auxiliary station layouts in the top row, one per focal-spot column:
# all receivers, the master station and the auxiliary station for that column.
aux_layout_axs = axs[:1, 1:]
for aux_station_idx, ax in zip(station_idxs_in_subset, aux_layout_axs.flat):
    ax.axis("off")
    ax.scatter(*stations_receivers.T, s=0.25, lw=0, c="#3F90DA")
    # A single station is a 1-D (x, y) tensor: unpack it directly. `.T` on
    # non-2-D tensors is deprecated in PyTorch.
    ax.scatter(
        *stations_receivers[master_idx], marker="v", c="#FFA90E", ec="k", s=75, lw=0.5
    )
    ax.scatter(
        *stations_auxiliary[aux_station_idx],
        s=75,
        lw=0.5,
        ec="k",
        marker="v",
        c="#832DB6",
        clip_on=False,  # marker may lie outside this small axes
    )
    ax.set_aspect("equal")

# Source geometry. This must match compute_correlations.ipynb, including the
# seed, so that the cluster positions drawn with torch.rand are reproduced.
torch.manual_seed(42)

# ring of equally spaced boundary sources; the end point 2*pi is excluded so
# that the first and last sources do not coincide
n_boundary_sources = 100
boundary_source_radius = 50
angle_step = 2 * np.pi / n_boundary_sources
boundary_source_angles = np.linspace(0, 2 * np.pi - angle_step, n_boundary_sources)
boundary_sources = torch.tensor(
    np.column_stack(
        (
            boundary_source_radius * np.cos(boundary_source_angles),
            boundary_source_radius * np.sin(boundary_source_angles),
        )
    )
)

# additional cluster of sources centred on the ring at angle 0.8*pi (northwest)
n_cluster_sources = 25
cluster_spread = 25
cluster_center_angle = 0.8 * np.pi
x_center = boundary_source_radius * np.cos(cluster_center_angle)
y_center = boundary_source_radius * np.sin(cluster_center_angle)
# uniform offsets in [-spread/2, spread/2) around the cluster centre
cluster_offsets = torch.rand(n_cluster_sources, 2) * cluster_spread - cluster_spread / 2
cluster_sources = cluster_offsets + torch.tensor([x_center, y_center])

all_sources = torch.cat((boundary_sources, cluster_sources)).float()

# Draw the source layouts in the left column, one per focal-spot row. The
# field-data row has no known sources (None) and gets question marks instead.
source_layout_axs = axs[1:, 0]
n_question_marks = 7
for ax, sources, label in zip(
    source_layout_axs.flat,
    (cluster_sources, boundary_sources, all_sources, None),
    ("Isolated\nsources", "Boundary\nsources", "All\nsources", "Unknown\nsources"),
):
    ax.axis("off")
    ax.scatter(*stations_receivers.T, s=0.25, lw=0, c="#3F90DA")
    if sources is not None:
        ax.scatter(
            *sources.T,
            s=5,
            lw=0,
            marker="*",
            c="#717581",
            # artist label now names this row's scenario, not always "boundary sources"
            label=label.replace("\n", " ").lower(),
        )
    else:
        # field data: sources unknown, so place question marks at random
        # angles on the boundary-source circle (uses the seeded torch RNG)
        circle_angles = torch.rand(n_question_marks) * 2 * np.pi
        circle_x = boundary_source_radius * torch.cos(circle_angles)
        circle_y = boundary_source_radius * torch.sin(circle_angles)
        for x_mark, y_mark in zip(circle_x.tolist(), circle_y.tolist()):
            ax.text(
                x_mark,
                y_mark,
                "?",
                fontsize=10,
                ha="center",
                va="center",
                color="#717581",
                fontfamily="serif",
            )

    ax.set_title(label, fontsize=10, pad=7)
    ax.set_xlim(-55, 55)
    ax.set_ylim(-55, 55)
    ax.set_aspect("equal")

    # move to the left
    x0, y0, width, height = ax.get_position().bounds
    ax.set_position([x0 - 0.07, y0, width, height])

# the top-left grid cell has no content; hide its frame and ticks
axs[0, 0].set_axis_off()

# write the finished figure at 300 dpi, cropped to the drawn content
fig.savefig("../figures/figure6.png", bbox_inches="tight", dpi=300)