# Training loss curves

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot style. sns.set() applies its own style/context rcParams (including
# font sizes), so the explicit overrides below must come after it.
sns.set(style="ticks")
plt.rcParams['font.size'] = 18
# Prefer Arial, but fall back to fonts that are commonly installed (same list as
# the mismatch section). A bare 'Arial' family triggers "findfont" warnings on
# machines where Arial is not available.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
plt.rcParams['axes.titlesize'] = 18
plt.rcParams['axes.labelsize'] = 20
plt.rcParams['legend.fontsize'] = 14

# Six-colour palette; the loss plot uses entries 0-5.
palette = sns.color_palette('husl')

In [None]:

# Training histories for the phase (Phi) and amplitude (A) autoencoders.
# `.item()` unwraps a 0-d object array, so each file presumably holds a pickled
# dict (keys such as "loss" and "lr" are read below). allow_pickle=True can run
# arbitrary code on load, so only point these paths at trusted files.
resnet_losses_phase = np.load('/path/to/your/data/AE_losses_phase.npy',
                              allow_pickle=True).item()
resnet_losses_amp = np.load('/path/to/your/data/AE_losses_amplitude.npy',
                            allow_pickle=True).item()

fig, ax1 = plt.subplots(figsize=(10, 6))

# Each model contributes three loss curves, coloured by consecutive palette
# entries starting at `first_color`.
for history, symbol, suffix, first_color in (
    (resnet_losses_phase, r"$\Phi$", "phase", 0),
    (resnet_losses_amp, r"$A$", "amplitude", 3),
):
    curves = (
        ("loss", "Total Loss", "-", 2),
        (f"reconstruction_loss_{suffix}", "Reconstruction Loss", "--", 2.5),
        (f"latent_loss_{suffix}", "Latent space Loss", ":", 3),
    )
    for offset, (key, name, style, width) in enumerate(curves):
        ax1.plot(history[key], label=f"{symbol}: {name}",
                 color=palette[first_color + offset],
                 linestyle=style, linewidth=width)

ax1.set_yscale('log')
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss (log scale)")
ax1.tick_params(axis='both', which='major', labelsize=18)
ax1.grid(True, which='both', linestyle='--', linewidth=0.6, alpha=0.5)

# Learning-rate schedules share the x-axis but use a secondary y-axis.
ax2 = ax1.twinx()
lr_style = dict(linestyle='-', linewidth=1, alpha=0.3)
ax2.plot(resnet_losses_phase["lr"], label=r"$\Phi$: Learning Rate", color="gray", **lr_style)
ax2.plot(resnet_losses_amp["lr"], label=r"$A$: Learning Rate", color="blue", **lr_style)
ax2.set_ylabel("Learning Rate", fontsize=18)
ax2.tick_params(axis='y', labelsize=16)

# A single legend that lists the loss curves followed by the learning rates.
legend_handles, legend_labels = [], []
for axis in (ax1, ax2):
    axis_handles, axis_labels = axis.get_legend_handles_labels()
    legend_handles += axis_handles
    legend_labels += axis_labels
ax1.legend(legend_handles, legend_labels, loc='upper right', frameon=False, ncol=2)

# Styling: keep the right spine, which carries the twin learning-rate axis.
sns.despine(top=True, right=False, left=False, bottom=False)
plt.tight_layout()
plt.show()


# Parameter distribution

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D

dataset_path = '/path/to/your/data/BNS_train_set.npz'
# Use the archive as a context manager so its file handle is released once the
# arrays have been copied out (.astype returns new arrays).
with np.load(dataset_path) as dataset_dict:
    X_train     = dataset_dict['X_train'].astype('float32')
    y_amplitude = dataset_dict['y_amplitude'].astype('float32')

# Each physical parameter is a *column* of X_train, so transpose before
# unpacking: iterating an (N, k) array yields its N rows, which would only
# unpack when N happens to equal k. The unpacking also checks that the spin
# slices hold exactly three components each.
m1, m2 = X_train[:, 0:2].T
λ1, λ2 = X_train[:, 2:4].T
s1x, s1y, s1z = X_train[:, 4:7].T
s2x, s2y, s2z = X_train[:, 7:].T
# Spin magnitudes |s1| and |s2|.
s1 = np.sqrt(s1x**2 + s1y**2 + s1z**2)
s2 = np.sqrt(s2x**2 + s2y**2 + s2z**2)

mass_colors  = ["#184CAC", "#184CAC"]
tidal_colors = ["#184CAC", "#184CAC"]
spin_colors  = ["#184CAC", "#184CAC"]

fig = plt.figure(figsize=(16, 6))
# Columns 0-2 hold the histograms below; column 3 is used for the 3-D spin clouds.
gs = fig.add_gridspec(2, 4)
n_bins = 30

# One grid column per parameter family: (values, title, colour) for rows 0 and 1.
column_data = [
    [(m1, r"$m_1$", mass_colors[0]), (m2, r"$m_2$", mass_colors[1])],
    [(λ1, r"$\Lambda_1$", tidal_colors[0]), (λ2, r"$\Lambda_2$", tidal_colors[1])],
    [(s1, r"$|\vec{s}_1|$", spin_colors[0]), (s2, r"$|\vec{s}_2|$", spin_colors[1])]
]

for col, column_group in enumerate(column_data):
    for row, (data, title, color) in enumerate(column_group):
        ax = fig.add_subplot(gs[row, col])
        # ax.hist already returns the density-normalised counts and bin edges;
        # reuse them for the outline curve instead of calling np.histogram again.
        counts, bins, bars = ax.hist(data, bins=n_bins, density=True, alpha=0.6,
                                     edgecolor='black', color=color)
        centers = 0.5 * (bins[:-1] + bins[1:])
        ax.plot(centers, counts, '-', color=color, lw=1.5)
        ax.set_title(title)

        # Only the bottom row carries x-axis unit labels.
        if row == 1:
            ax.set_xlabel(r"$M_\odot$" if col == 0 else "Dimensionless")

        ax.set_ylabel("Density")
        ax.grid(True)

def set_axes_equal(ax):
    """Give a 3-D axes equal scaling along x, y and z.

    The three axis limits are re-centred on their current midpoints with a
    common half-width equal to half of the largest current span, so one data
    unit covers the same range on every axis.

    Equal limits alone are not sufficient on Matplotlib >= 3.3, where the 3-D
    axes box defaults to a 4:4:3 aspect. When the axes provides
    ``set_box_aspect``, the box is therefore also made a cube.

    Parameters
    ----------
    ax : mpl_toolkits.mplot3d.axes3d.Axes3D
        Axes whose limits (and box aspect, when supported) are changed in place.
    """
    limits = np.array([
        ax.get_xlim3d(),
        ax.get_ylim3d(),
        ax.get_zlim3d(),
    ])
    centers = np.mean(limits, axis=1)
    radius = 0.5 * np.max(limits[:, 1] - limits[:, 0])
    for ctr, set_lim in zip(centers, [ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d]):
        set_lim([ctr - radius, ctr + radius])
    # Older Matplotlib versions lack set_box_aspect; limits alone are the best they offer.
    set_box_aspect = getattr(ax, "set_box_aspect", None)
    if callable(set_box_aspect):
        set_box_aspect((1, 1, 1))

def _draw_spin_cloud(slot, sx, sy, sz, magnitude, title):
    """Scatter one training-set spin-vector cloud into grid slot `slot`.

    Points are coloured by spin magnitude on a light-grey pane, the view and
    aspect are fixed, and a colour bar is attached. Draws into the global `fig`.
    Returns the 3-D axes and the scatter collection.
    """
    panel = fig.add_subplot(slot, projection='3d')
    panel.set_facecolor("lightgray")
    cloud = panel.scatter(sx, sy, sz, c=magnitude, cmap='ocean', s=0.4, alpha=0.5, linewidth=0)
    panel.set_title(title)
    panel.set_xlabel('$s_x$')
    panel.set_ylabel('$s_y$')
    panel.set_zlabel('$s_z$')
    panel.view_init(elev=18, azim=35)
    set_axes_equal(panel)
    fig.colorbar(cloud, ax=panel, fraction=0.045, pad=0.05)
    return panel, cloud


# Last grid column: spin vectors of body 1 (top) and body 2 (bottom).
ax3d_1, p1 = _draw_spin_cloud(gs[0, 3], s1x, s1y, s1z, s1, r"$\vec{s}_1$")
ax3d_2, p2 = _draw_spin_cloud(gs[1, 3], s2x, s2y, s2z, s2, r"$\vec{s}_2$")

sns.despine(top=True, right=True)
plt.subplots_adjust(wspace=0.5, hspace=0.5, top=0.92)
plt.show()


In [None]:

dataset_path = '/path/to/your/data/BNS_test_set.npz'
# Use the archive as a context manager so its file handle is released once the
# arrays have been copied out (.astype returns new arrays).
with np.load(dataset_path) as dataset_dict:
    X_test      = dataset_dict['X_test'].astype('float32')
    y_amplitude = dataset_dict['y_amplitude'].astype('float32')

# Each physical parameter is a *column* of X_test, so transpose before
# unpacking: iterating an (N, k) array yields its N rows, which would only
# unpack when N happens to equal k. The unpacking also checks that the spin
# slices hold exactly three components each.
m1, m2 = X_test[:, 0:2].T
λ1, λ2 = X_test[:, 2:4].T
s1x, s1y, s1z = X_test[:, 4:7].T
s2x, s2y, s2z = X_test[:, 7:].T
# Spin magnitudes |s1| and |s2|.
s1 = np.sqrt(s1x**2 + s1y**2 + s1z**2)
s2 = np.sqrt(s2x**2 + s2y**2 + s2z**2)

mass_colors  = ["#A6A6A6", "#A6A6A6"]
tidal_colors = ["#A6A6A6", "#A6A6A6"]
spin_colors  = ["#A6A6A6", "#A6A6A6"]

fig = plt.figure(figsize=(16, 6))
# Columns 0-2 hold the histograms below; column 3 is used for the 3-D spin clouds.
gs = fig.add_gridspec(2, 4)
n_bins = 30

# One grid column per parameter family: (values, title, colour) for rows 0 and 1.
column_data = [
    [(m1, r"$m_1$", mass_colors[0]), (m2, r"$m_2$", mass_colors[1])],
    [(λ1, r"$\Lambda_1$", tidal_colors[0]), (λ2, r"$\Lambda_2$", tidal_colors[1])],
    [(s1, r"$|\vec{s}_1|$", spin_colors[0]), (s2, r"$|\vec{s}_2|$", spin_colors[1])]
]

for col, column_group in enumerate(column_data):
    for row, (data, title, color) in enumerate(column_group):
        ax = fig.add_subplot(gs[row, col])
        # ax.hist already returns the density-normalised counts and bin edges;
        # reuse them for the outline curve instead of calling np.histogram again.
        counts, bins, bars = ax.hist(data, bins=n_bins, density=True, alpha=0.6,
                                     edgecolor='black', color=color)
        centers = 0.5 * (bins[:-1] + bins[1:])
        ax.plot(centers, counts, '-', color=color, lw=1.5)
        ax.set_title(title)

        # Only the bottom row carries x-axis unit labels.
        if row == 1:
            ax.set_xlabel(r"$M_\odot$" if col == 0 else "Dimensionless")

        ax.set_ylabel("Density")
        ax.grid(True)
def set_axes_equal(ax):
    """Give a 3-D axes equal scaling along x, y and z.

    The three axis limits are re-centred on their current midpoints with a
    common half-width equal to half of the largest current span, so one data
    unit covers the same range on every axis.

    Equal limits alone are not sufficient on Matplotlib >= 3.3, where the 3-D
    axes box defaults to a 4:4:3 aspect. When the axes provides
    ``set_box_aspect``, the box is therefore also made a cube.

    Parameters
    ----------
    ax : mpl_toolkits.mplot3d.axes3d.Axes3D
        Axes whose limits (and box aspect, when supported) are changed in place.
    """
    limits = np.array([
        ax.get_xlim3d(),
        ax.get_ylim3d(),
        ax.get_zlim3d(),
    ])
    centers = np.mean(limits, axis=1)
    radius = 0.5 * np.max(limits[:, 1] - limits[:, 0])
    for ctr, set_lim in zip(centers, [ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d]):
        set_lim([ctr - radius, ctr + radius])
    # Older Matplotlib versions lack set_box_aspect; limits alone are the best they offer.
    set_box_aspect = getattr(ax, "set_box_aspect", None)
    if callable(set_box_aspect):
        set_box_aspect((1, 1, 1))

# Last grid column: test-set spin vectors of body 1 (top) and body 2 (bottom),
# coloured by spin magnitude in greyscale to match the test-set histograms.
spin_panels = []
for grid_row, (sx, sy, sz, magnitude, title) in enumerate([
    (s1x, s1y, s1z, s1, r"$\vec{s}_1$"),
    (s2x, s2y, s2z, s2, r"$\vec{s}_2$"),
]):
    panel = fig.add_subplot(gs[grid_row, 3], projection='3d')
    cloud = panel.scatter(sx, sy, sz, c=magnitude, cmap='Greys', s=0.4, alpha=0.5, linewidth=0)
    panel.set_title(title)
    panel.set_xlabel('$s_x$')
    panel.set_ylabel('$s_y$')
    panel.set_zlabel('$s_z$')
    panel.view_init(elev=18, azim=35)
    set_axes_equal(panel)
    fig.colorbar(cloud, ax=panel, fraction=0.045, pad=0.05)
    spin_panels.append((panel, cloud))

(ax3d_1, p1), (ax3d_2, p2) = spin_panels

sns.despine(top=True, right=True)
plt.subplots_adjust(wspace=0.5, hspace=0.5, top=0.92)
plt.show()


# Mismatch

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import patches
from matplotlib.colors import LinearSegmentedColormap

# Evaluation metrics for the two surrogate models. The NpzFile objects stay open
# on purpose: the following cells read further keys from them on demand.
res_mlp_cnn_hp_hc_metrics = np.load('/path/to/your/data/res_mlp_cnn_hp_hc_metrics.npz')
cae_hp_hc_metrics = np.load('/path/to/your/data/cae_hp_hc_metrics.npz')

# Theme for the mismatch section: Arial with common fallbacks if it is missing.
sns.set(
    style="ticks",
    rc={
        "font.family":     "sans-serif",
        "font.sans-serif": ["Arial", "DejaVu Sans", "Liberation Sans"],
        "font.size":       18,
        "axes.titlesize":  18,
        "axes.labelsize":  18,
        "legend.fontsize": 16
    }
)

# Per-model base colour, legend text, white-to-colour hexbin ramp and legend patch.
COLORS = {"mlp": "#BB0D3B", "cae": "#090783"}
_MODEL_NAMES = {"mlp": "Residual-MLP-CNN ", "cae": "cAE "}
CMAPS = {
    model: LinearSegmentedColormap.from_list(f"{model}_cmap", ["#ffffff", color])
    for model, color in COLORS.items()
}
LEGENDS = {
    model: patches.Patch(color=COLORS[model], label=name)
    for model, name in _MODEL_NAMES.items()
}


def _cycles_and_log_mismatch(metrics):
    """Return (cycles_hp_true, log10(mismatch_hp)) restricted to positive mismatches.

    log10 of a non-positive mismatch is -inf/NaN. A single such value can
    collapse hexbin's automatically computed extent and hide every bin, so
    those samples are dropped. This matches the ``> 0`` filter used for the
    mismatch histograms.
    """
    mismatch = np.asarray(metrics["mismatch_hp"])
    keep = mismatch > 0
    return np.asarray(metrics["cycles_hp_true"])[keep], np.log10(mismatch[keep])


fig, axs = plt.subplots(3, 1, figsize=(7, 8), sharex=True)
gridsize = 40

cae_cycles, cae_log_mismatch = _cycles_and_log_mismatch(cae_hp_hc_metrics)
mlp_cycles, mlp_log_mismatch = _cycles_and_log_mismatch(res_mlp_cnn_hp_hc_metrics)

# Top panel: both models overlaid, Residual-MLP-CNN semi-transparent on top of cAE.
axs[0].set_facecolor("white")
axs[0].hexbin(
    cae_cycles, cae_log_mismatch,
    gridsize=gridsize, cmap=CMAPS["cae"],
    alpha=1, edgecolors="face", linewidths=1, bins="log"
)
axs[0].hexbin(
    mlp_cycles, mlp_log_mismatch,
    gridsize=gridsize, cmap=CMAPS["mlp"],
    alpha=0.5, edgecolors="face", linewidths=1, bins="log"
)
axs[0].legend(handles=[LEGENDS["cae"], LEGENDS["mlp"]], loc="lower right", frameon=False)
axs[0].set_ylabel(r"Mismatch ($\log_{10}$)")

# Lower panels: one model each, with its own (log-count) colour bar.
for i, (label, cycles_arr, log_mismatch, cmap) in enumerate([
    ("cae", cae_cycles, cae_log_mismatch, CMAPS["cae"]),
    ("mlp", mlp_cycles, mlp_log_mismatch, CMAPS["mlp"])
], start=1):
    axs[i].set_facecolor("white")
    im = axs[i].hexbin(
        cycles_arr, log_mismatch,
        gridsize=gridsize, cmap=cmap, bins="log",
        edgecolors="face", linewidths=1
    )
    axs[i].set_ylabel(r"Mismatch ($\log_{10}$)")
    axs[i].legend(handles=[LEGENDS[label]], loc="lower right", frameon=False)
    plt.colorbar(im, ax=axs[i], fraction=0.046, pad=0.04)
axs[-1].set_xlabel("Cycles")

plt.tight_layout()
sns.despine()
plt.show()


In [None]:

# log10 mismatch per model and polarisation. Non-positive mismatches are dropped
# because log10 is undefined there; curves with no positive values are skipped.
datasets = {}
for key, arr in [
    ("MLP hp", res_mlp_cnn_hp_hc_metrics["mismatch_hp"]),
    ("MLP hc", res_mlp_cnn_hp_hc_metrics["mismatch_hc"]),
    ("CAE hp", cae_hp_hc_metrics["mismatch_hp"]),
    ("CAE hc", cae_hp_hc_metrics["mismatch_hc"])
]:
    pos = arr[arr > 0]
    if pos.size:
        datasets[key] = np.log10(pos)

keys           = ["MLP hp", "MLP hc", "CAE hp", "CAE hc"]
labels_display = [
    "Residual-MLP-CNN $h_+$", "Residual-MLP-CNN $h_{\\times}$",
    "cAE $h_+$", "cAE $h_{\\times}$"
]
colors         = ["#2E2B2B", "#252424", "#1F77B4", "#579BF5"]
# The two Residual-MLP-CNN colours are nearly identical, so the polarisation
# is also encoded by line style (h_+ solid, h_x dashed) to keep curves distinguishable.
linestyles     = ["-", "--", "-", "--"]
n_bins         = 30

fig, ax = plt.subplots(figsize=(7, 4))
for k, lbl, c, ls in zip(keys, labels_display, colors, linestyles):
    data = datasets.get(k)
    if data is None:
        continue
    cnts, bins = np.histogram(data, bins=n_bins, density=True)
    ctrs = 0.5 * (bins[:-1] + bins[1:])
    # Faint bars show the histogram; the outline curve carries the legend entry.
    ax.bar(ctrs, cnts, width=bins[1] - bins[0], alpha=0.2, color=c, edgecolor="none")
    ax.plot(ctrs, cnts, linestyle=ls, label=lbl, linewidth=1.5, color=c)

ax.set_xlabel(r"Mismatch ($\log_{10}$)")
ax.set_ylabel("Probability Density")
ax.legend(loc="upper left", frameon=False)
ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)
sns.despine(ax=ax, top=True, right=True)
fig.tight_layout()
plt.show()


In [None]:

# cAE h_+ mismatch against the test-set parameters.
cae_x_test = cae_hp_hc_metrics['x_test']
m1       = cae_x_test[:, 0]
m2       = cae_x_test[:, 1]
lam1     = cae_x_test[:, 2]
lam2     = cae_x_test[:, 3]
cycles   = cae_hp_hc_metrics['cycles_hp_true']
spin1 = cae_x_test[:, 4:7]
spin2 = cae_x_test[:, 7:10]
spin1_mag = np.linalg.norm(spin1, axis=1)
spin2_mag = np.linalg.norm(spin2, axis=1)

# log10 is undefined for non-positive mismatches. Such points would otherwise be
# silently masked by the colormap (with a RuntimeWarning), so they are dropped
# explicitly from *every* plotted array, which keeps x, y and colour aligned.
mismatch_hp = np.asarray(cae_hp_hc_metrics['mismatch_hp'])
positive = mismatch_hp > 0
if not positive.all():
    print(f"Dropping {int((~positive).sum())} non-positive mismatch value(s) before log10.")
mismatch_log = np.log10(mismatch_hp[positive])

cc = "viridis"


def _mismatch_scatter(x, y, log_mismatch, xlabel, ylabel, cmap=cc):
    """Scatter y against x, coloured by log10 mismatch, with a labelled colour bar."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sc = ax.scatter(x, y, c=log_mismatch, cmap=cmap, alpha=0.7)
    cb = fig.colorbar(sc, ax=ax)
    cb.ax.tick_params(labelsize=16)
    cb.set_label(r"Mismatch($\log_{10}$)")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True, linestyle="--", linewidth=0.6, alpha=0.5)
    sns.despine(fig=fig)
    fig.tight_layout()
    plt.show()


_mismatch_scatter(lam1[positive], lam2[positive], mismatch_log,
                  r"$\Lambda_1$", r"$\Lambda_2$")
_mismatch_scatter(spin1_mag[positive], spin2_mag[positive], mismatch_log,
                  r"$|\vec{s}_1|$", r"$|\vec{s}_2|$")
