In [1]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import re
from CenterOfMass2 import CenterOfMass

# -----------------------------
# Functions
# -----------------------------
def compute_inertia_tensor(positions, masses):
    """Return the 3x3 moment-of-inertia tensor of a set of point masses.

    I_ij = sum_k m_k (|r_k|^2 delta_ij - x_{k,i} x_{k,j}), evaluated about the
    coordinate origin, so positions should already be expressed relative to
    the desired centre.

    Parameters
    ----------
    positions : (N, 3) array_like
        Particle positions.
    masses : (N,) array_like
        Particle masses.

    Returns
    -------
    numpy.ndarray, shape (3, 3)
        Symmetric inertia tensor (mass * length^2 in the input units).
    """
    positions = np.asarray(positions, dtype=float)
    masses = np.asarray(masses, dtype=float)
    # Mass-weighted second-moment tensor S_ij = sum m x_i x_j, built in one matmul.
    second_moment = positions.T @ (masses[:, None] * positions)
    # trace(S) = sum m |r|^2, so I = trace(S) * 1 - S reproduces the definition above.
    return np.trace(second_moment) * np.eye(3) - second_moment

def compute_axis_ratios(I):
    """Return the axis ratios (b/a, c/a) of a mass distribution from its inertia tensor.

    ``I`` is the moment-of-inertia tensor I_ij = sum m (r^2 delta_ij - x_i x_j)
    (as built by ``compute_inertia_tensor``). Its eigenvalues are NOT
    proportional to the squared semi-axes: the largest principal moment
    belongs to the *shortest* axis. Writing I = tr(S) * 1 - S with
    S_ij = sum m x_i x_j (the shape / second-moment tensor), the shape-tensor
    eigenvalues are recovered as s_k = sum(I_eigvals) / 2 - I_k, and these are
    proportional to a^2 >= b^2 >= c^2.

    Parameters
    ----------
    I : (3, 3) array_like
        Symmetric moment-of-inertia tensor.

    Returns
    -------
    ba, ca : float
        Intermediate-to-major and minor-to-major axis ratios, each in [0, 1].
        Both are NaN when the distribution has no spatial extent.
    """
    inertia_eigvals = np.linalg.eigvalsh(np.asarray(I, dtype=float))
    # Convert principal moments into shape-tensor eigenvalues (∝ squared semi-axes).
    shape_eigvals = inertia_eigvals.sum() / 2.0 - inertia_eigvals
    # Sort major -> minor; clip round-off that can push a vanishing
    # eigenvalue (e.g. a planar distribution) slightly below zero.
    shape_eigvals = np.clip(np.sort(shape_eigvals)[::-1], 0.0, None)
    major = shape_eigvals[0]
    if major <= 0.0:
        return np.nan, np.nan
    return np.sqrt(shape_eigvals[1] / major), np.sqrt(shape_eigvals[2] / major)

def radial_shell_axis_ratios(positions, masses, r_bins, min_particles=20):
    """Measure axis ratios b/a and c/a in concentric spherical shells.

    Shell ``i`` contains the particles with ``r_bins[i] <= |r| < r_bins[i+1]``,
    where ``|r|`` is measured from the coordinate origin, so positions should
    already be centred on the object of interest.

    Parameters
    ----------
    positions : (N, 3) array_like
        Particle positions (same length unit as ``r_bins``).
    masses : (N,) array_like
        Particle masses.
    r_bins : 1-D sequence of float
        Increasing shell edges; ``len(r_bins) - 1`` shells are measured.
    min_particles : int, optional
        Shells holding fewer particles than this are skipped and left as NaN
        (presumably to avoid noisy tensors from a handful of particles).
        Default 20, the previously hard-coded cut.

    Returns
    -------
    ba_array, ca_array : numpy.ndarray, shape (len(r_bins) - 1,)
        Per-shell axis ratios from ``compute_axis_ratios``; NaN for skipped shells.

    Notes
    -----
    NOTE(review): shells are spherical and measured in a single pass (no
    iteration onto ellipsoidal shells); this is commonly reported to bias
    measured ratios toward 1 (rounder shapes).
    """
    positions = np.asarray(positions, dtype=float)
    masses = np.asarray(masses, dtype=float)
    edges = np.asarray(r_bins, dtype=float)
    num_shells = max(len(edges) - 1, 0)
    ba_array = np.full(num_shells, np.nan)
    ca_array = np.full(num_shells, np.nan)
    radii = np.linalg.norm(positions, axis=1)
    for i in range(num_shells):
        in_shell = (radii >= edges[i]) & (radii < edges[i + 1])
        if np.count_nonzero(in_shell) < min_particles:
            continue  # too few particles; leave this shell as NaN
        shell_inertia = compute_inertia_tensor(positions[in_shell], masses[in_shell])
        ba_array[i], ca_array[i] = compute_axis_ratios(shell_inertia)
    return ba_array, ca_array

# -----------------------------
# Main: Automatic Snapshot Detection
# -----------------------------
r_bins = np.linspace(0, 200, 21)  # Radial bins (kpc)


def _snapshot_id(path):
    """Return the snapshot number of a 'PREFIX_<digits>.txt' file as its digit string.

    The *last* digit run before '.txt' is used: taking the first run would
    return '31' for every 'M31_*.txt' file and scramble the M31 ordering.
    """
    match = re.search(r'(\d+)\.txt$', path)
    if match is None:
        raise ValueError(f"Cannot find a snapshot number in {path!r}")
    return match.group(1)


def _snapshot_key(path):
    """Numeric sort key for a snapshot file (e.g. 'MW_010.txt' -> 10)."""
    return int(_snapshot_id(path))


# Get all MW and M31 snapshot files in the working directory, in snapshot order.
MW_files = sorted(glob.glob("MW_*.txt"), key=_snapshot_key)
M31_files = sorted(glob.glob("M31_*.txt"), key=_snapshot_key)

assert MW_files, "No MW_*.txt snapshot files found in the working directory!"
# Pair by snapshot number, not just by count: equal counts alone do not
# guarantee the i-th MW file and the i-th M31 file are the same time step.
mw_keys = [_snapshot_key(f) for f in MW_files]
m31_keys = [_snapshot_key(f) for f in M31_files]
assert mw_keys == m31_keys, (
    "Mismatch between MW and M31 snapshots! "
    f"Only in MW: {sorted(set(mw_keys) - set(m31_keys))}, "
    f"only in M31: {sorted(set(m31_keys) - set(mw_keys))}"
)

ba_matrix = []
ca_matrix = []
snapshot_labels = []

for mw_file, m31_file in zip(MW_files, M31_files):
    # Particle type 1 — presumably the dark-matter halo in CenterOfMass's convention; TODO confirm.
    MW = CenterOfMass(mw_file, 1)
    M31 = CenterOfMass(m31_file, 1)

    # Combine particle data from both galaxies into one (N, 3) position array.
    x = np.concatenate((MW.x, M31.x))
    y = np.concatenate((MW.y, M31.y))
    z = np.concatenate((MW.z, M31.z))
    m = np.concatenate((MW.m, M31.m))
    positions = np.column_stack((x, y, z))

    # Shift to the combined COM frame. COMdefine is only borrowed from the MW
    # instance as a helper; it is fed the combined arrays (presumably returning
    # their mass-weighted mean position — confirm against CenterOfMass).
    # NOTE(review): before the galaxies merge, this centre lies between two
    # separate halos, so the shell shapes describe the pair, not one halo.
    x_com, y_com, z_com = MW.COMdefine(x, y, z, m)
    positions = positions - np.array([x_com, y_com, z_com])

    # Compute radial shell axis ratios
    ba_array, ca_array = radial_shell_axis_ratios(positions, m, r_bins)
    ba_matrix.append(ba_array)
    ca_matrix.append(ca_array)
    snapshot_labels.append(f"Snap {_snapshot_id(mw_file)}")

# Rows = snapshots, columns = radial shells.
ba_matrix = np.array(ba_matrix)
ca_matrix = np.array(ca_matrix)

# -----------------------------
# Plot Heatmaps
# -----------------------------
# Pixel edges at k ± 0.5 centre column k on x-tick k, and keep a non-zero
# width even when there is only a single snapshot.
n_snapshots = len(snapshot_labels)
heatmap_extent = [-0.5, n_snapshots - 0.5, r_bins[0], r_bins[-1]]

fig, ax = plt.subplots(1, 2, figsize=(14, 6))
panels = (
    (ax[0], ba_matrix, 'viridis', 'b/a'),
    (ax[1], ca_matrix, 'plasma', 'c/a'),
)
for panel_ax, ratio_matrix, cmap, ratio_label in panels:
    # Transpose so snapshots run along x and radial shells along y. imshow draws
    # equal-height rows, which matches the shells only because r_bins is evenly spaced.
    image = panel_ax.imshow(ratio_matrix.T, origin='lower', aspect='auto', cmap=cmap,
                            extent=heatmap_extent)
    panel_ax.set_title(f"{ratio_label} Heatmap")
    panel_ax.set_xlabel("Snapshot")
    panel_ax.set_ylabel("Radius [kpc]")
    panel_ax.set_xticks(range(n_snapshots))
    panel_ax.set_xticklabels(snapshot_labels, rotation=45)
    fig.colorbar(image, ax=panel_ax, label=ratio_label)

plt.tight_layout()
plt.show()


Matplotlib is building the font cache; this may take a moment.


Snapshot 445 — Inertia Tensor:
 [[3.55327985e+08 5.46408277e+06 3.52852561e+06]
 [5.46408277e+06 3.44289892e+08 2.20582841e+05]
 [3.52852561e+06 2.20582841e+05 3.57701657e+08]]
Global Axis Ratios: b/a = 0.991, c/a = 0.973
→ Global halo shape: SPHERICAL
