In [None]:
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import torch
from cmcrameri import cm

In [None]:
plt.style.use(Path("../meta/colorblind_friendly.mplstyle"))
matplotlib.rcParams.update({"font.family": "serif", "font.size": 8})

MASTER_STATION = "OMV.GDT"  # receiver used as the reference ("master") station


def _load_station_coords(csv_path):
    """Read a station table and return it together with its X/Y coordinates.

    Parameters
    ----------
    csv_path : str or Path
        CSV file containing at least the columns ``X`` and ``Y``.

    Returns
    -------
    stationlist : polars.DataFrame
        The table exactly as read from disk.
    coords : torch.Tensor
        ``(n_stations, 2)`` tensor; column 0 is ``X``, column 1 is ``Y``.
    """
    stationlist = pl.read_csv(csv_path)
    coords = torch.tensor(
        np.column_stack((stationlist["X"].to_numpy(), stationlist["Y"].to_numpy()))
    )
    return stationlist, coords


# station locations
stationlist_all, stations_all = _load_station_coords("../meta/stations_all.csv")

# receiver stations
stationlist_rcv, stations_receivers = _load_station_coords(
    "../meta/stations_receivers.csv"
)
names_receivers = stationlist_rcv["station"]

# auxiliary stations
stationlist_aux, stations_auxiliary = _load_station_coords(
    "../meta/stations_auxiliary.csv"
)

master_idx = names_receivers.to_list().index(MASTER_STATION)

_cm = 1 / 2.54  # cm to inches

fig, ax = plt.subplots(figsize=(1.45 * 9 * _cm, 9 * _cm))
ax.scatter(*stations_all.T, s=5, lw=0, c="#CCC")
ax.scatter(*stations_receivers.T, s=5, lw=0)
# stations_receivers[master_idx] is a 1-D (x, y) pair: unpack it directly,
# because `.T` on a non-2-D tensor is deprecated in PyTorch (UserWarning).
ax.scatter(*stations_receivers[master_idx], s=100, marker="v", lw=0.5, ec="k")
ax.scatter(*stations_auxiliary.T, s=5, lw=0, c="C4")

ax.set(
    xlim=(-11, 11),
    ylim=(-8, 7),
    aspect="equal",
    xlabel="Distance [km]",
    ylabel="Distance [km]",
);

In [None]:
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects,
# which can execute code. Only load files produced by this project's own pipeline.
c2_correlations_unstacked_filt_data = torch.load(
    "../data/c2_correlations_unstacked_filt_data.pt", weights_only=False
)
c2_correlations_unstacked_filt_synth = torch.load(
    "../data/c2_correlations_unstacked_filt_synth_both.pt", weights_only=False
)

# must match compute_correlations.ipynb
sampling_rate = 5  # samples per second
length_of_oneside = 300  # seconds
window_length = length_of_oneside  # seconds

# Build the time axes from integer sample indices. A float-step torch.arange
# (e.g. step 0.2 with the end nudged by one step) derives its length from the
# rounded value of (end - start) / step, so it can silently gain or lose a sample.
n_oneside = round(length_of_oneside * sampling_rate)  # samples on each side of 0
n_window = round(window_length * sampling_rate)  # samples in the C3 window

# [0, 2 * length_of_oneside], both ends included
times = torch.arange(2 * n_oneside + 1) / sampling_rate
# [-length_of_oneside, length_of_oneside], both ends included
lapse_times_c1 = (torch.arange(2 * n_oneside + 1) - n_oneside) / sampling_rate
# [-window_length / 2, window_length / 2), end excluded
lapse_times_c3 = (torch.arange(n_window) - n_window / 2) / sampling_rate

(
    c2_correlations_unstacked_filt_data.shape,
    c2_correlations_unstacked_filt_synth.shape,
)

In [None]:
# Collapse the unstacked C2 correlations in two reductions:
#   1. dim 0            -> mean of the causal and anti-causal branches
#   2. dim 1 (of step 1) -> mean over all auxiliary stations
# torch.mean propagates NaN, so any NaN entry along a reduced axis yields NaN.
c2_correlations_unstacked_filt_data_avg = torch.mean(
    c2_correlations_unstacked_filt_data, dim=0
)
c2_correlations_unstacked_filt_synth_avg = torch.mean(
    c2_correlations_unstacked_filt_synth, dim=0
)

c2_correlations_filt_data = torch.mean(c2_correlations_unstacked_filt_data_avg, dim=1)
c2_correlations_filt_synth = torch.mean(c2_correlations_unstacked_filt_synth_avg, dim=1)

# display the resulting shapes
(c2_correlations_filt_data.shape, c2_correlations_filt_synth.shape)

In [None]:
# NOTE(review): "colorblind_white" must be a style installed in matplotlib's
# style library; the station-map cell loads ../meta/colorblind_friendly.mplstyle
# by path instead — confirm both styles are intended.
plt.style.use("colorblind_white")

# Outlier handling for the observed data.
N_OUTLIER_STATIONS = 5  # largest |C2| at zero lag, removed before normalising
N_EXTRA_OUTLIERS_PER_PANEL = 2  # further largest-|C2| values zeroed per observed panel

c3_zerolag_idx = lapse_times_c3.abs().argmin()

# Clean a copy so re-running this cell does not keep removing stations from
# c2_correlations_filt_data (computed in the averaging cell) and shifting the
# normalisation on every run.
c2_correlations_filt_data_clean = c2_correlations_filt_data.clone()
# set auto-correlation (master station) in data to nan
c2_correlations_filt_data_clean[master_idx] = torch.nan
# remove outliers for normalisation stability; np.nanargmax skips NaNs set so far
for _ in range(N_OUTLIER_STATIONS):
    idx_max = np.nanargmax(c2_correlations_filt_data_clean[..., c3_zerolag_idx].abs())
    c2_correlations_filt_data_clean[idx_max, :] = torch.nan

# normalisation constants: largest |C2| at zero lag, [observed, simulated]
maxs_for_norm = [
    np.nanmax(c2_correlations_filt_data_clean[..., c3_zerolag_idx].abs()),
    np.nanmax(c2_correlations_filt_synth[..., c3_zerolag_idx].abs()),
]

# per-row settings: (correlations, colourbar prefix, normaliser, colormap, extra outliers)
row_settings = [
    (
        c2_correlations_filt_data_clean,
        "Observed",
        maxs_for_norm[0],
        cm.vik,
        N_EXTRA_OUTLIERS_PER_PANEL,
    ),
    (c2_correlations_filt_synth, "Simulated", maxs_for_norm[1], cm.broc, 0),
]

_cm = 1 / 2.54  # cm to inches
fig, axs = plt.subplots(
    2, 3, figsize=(18 * _cm, 2 / 3 * 18 * _cm), sharex=True, sharey=True
)
n_cols = axs.shape[1]
# lag times (s) per panel in row-major order; both rows show the same three lags
times_to_plot = [-1, 0, 1] * 2
labels = ("a)", "b)", "c)", "d)", "e)", "f)")
for panel_idx, (ax, time_to_plot, label) in enumerate(
    zip(axs.flat, times_to_plot, labels)
):
    row, col = divmod(panel_idx, n_cols)
    correlations, prefix, max_for_norm, _cmap, n_extra_outliers = row_settings[row]

    ax.set_title(f"{label}", loc="left", fontsize=10, pad=5)
    ax.set_xlim(-4.5, 4.5)
    ax.set_ylim(-4.5, 4.5)
    ax.set_aspect("equal")
    ax.set_yticks([-4, 0, 4])
    ax.set_xticks([-4, 0, 4])
    ax.set_xticklabels([-4, 0, 4], fontsize=10)
    ax.set_yticklabels([-4, 0, 4], fontsize=10)

    # C2 value at every receiver for the lag sample closest to time_to_plot
    lag_idx = (lapse_times_c3 - time_to_plot).abs().argmin()
    focal_spot = correlations[:, lag_idx].clone()

    # some cleanup: removed (NaN) stations are drawn with zero amplitude
    focal_spot[torch.isnan(focal_spot)] = 0
    for _ in range(n_extra_outliers):
        focal_spot[focal_spot.abs().argmax()] = 0

    # normalise amplitudes
    focal_spot /= max_for_norm

    sct = ax.scatter(
        *stations_receivers.T,
        c=focal_spot,
        cmap=_cmap,
        s=10,
        vmin=-1,
        vmax=1,
        lw=0,
    )

    # one colourbar per row, attached to the right-most panel
    if col == n_cols - 1:
        x0, y0, w, h = ax.get_position().bounds
        cbar_ax = fig.add_axes([x0 + w + 0.01, y0, 0.01, h])
        cbar = fig.colorbar(sct, cax=cbar_ax, orientation="vertical")
        cbar.set_label(f"{prefix} " + r"$C_2$" + " amplitudes", fontsize=10)
        cbar.ax.tick_params(labelsize=10)
        # set ticks to -1, 0, 1
        cbar.set_ticks([-1, 0, 1])
        cbar.set_ticklabels(["-1", "0", "1"])

    ax.text(
        0.05,
        0.95,
        rf"$\tau$ = {time_to_plot} s",
        ha="left",
        va="top",
        transform=ax.transAxes,
        fontsize=10,
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
    )

    # NOTE(review): marker is drawn at the origin, presumably the master station
    # (coordinates centred on it) — TODO confirm against stations_receivers[master_idx].
    ax.scatter([0], [0], marker="v", ec="k", s=100, lw=1, c="#FFA90E")

for ax in axs[1]:
    ax.set_xlabel("Distance [km]", labelpad=0, fontsize=10)
for ax in axs[:, 0]:
    ax.set_ylabel("Distance [km]", labelpad=0, fontsize=10)

figure_path = Path("../figures/figure4.png")
figure_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create dirs
fig.savefig(figure_path, dpi=300, bbox_inches="tight")