In [3]:
import os
import pickle
import numpy as np
import pandas as pd
from astropy.io import fits
from scipy.stats import multivariate_normal
import plotly.graph_objects as go
import plotly.io as pio
from tqdm import tqdm

# Render Plotly figures in the system web browser.
pio.renderers.default = 'browser'

# Directory that receives the exported interactive HTML figures.
os.makedirs('../figures', exist_ok=True)


def _to_native_byteorder(column):
    """Return *column* in the machine's native byte order.

    Columns that are already native are returned untouched. Others are
    converted with ``astype``, which works on NumPy 1.x and 2.x alike
    (``ndarray.newbyteorder`` was removed in NumPy 2.0).
    """
    if column.dtype.isnative:
        return column
    return column.astype(column.dtype.newbyteorder('='))


# Load the first table extension of the FITS catalogue into a DataFrame.
vphi_data = '../data/df_v_final.fits'
with fits.open(vphi_data) as hdul:
    data = hdul[1].data
    # Fetch each column exactly once and hand pandas native-endian arrays;
    # pandas operations can fail on non-native-endian buffers.
    df_v = pd.DataFrame({
        col.name: _to_native_byteorder(data[col.name])
        for col in hdul[1].columns
    })

# Keep rows with |Z| < 2.5 (units not stated here; presumably kpc, confirm).
df_v = df_v[np.abs(df_v['Z']) < 2.5]
# Drop rows missing any velocity component, its uncertainty, or the
# metallicity used for binning.
df_v = df_v.dropna(subset=[
    "v_R", "v_phi", "v_R_uncertainty",
    "v_phi_uncertainty", "v_Z", "v_Z_uncertainty",
    "mh_xgboost"
])


def filter_metallicity_bin(df, mn, mx):
    """Return the rows of *df* whose ``mh_xgboost`` lies in ``[mn, mx)``.

    The lower edge is inclusive and the upper edge exclusive, so adjacent
    bins never share a row. Rows with a NaN metallicity match no bin.
    """
    metallicity = df['mh_xgboost']
    in_bin = metallicity.ge(mn) & metallicity.lt(mx)
    return df[in_bin]


# Baseline trace opacities. The dropdown buttons restore these values ("All")
# or scale them by _DIM_FACTOR for components that are not highlighted.
_SCATTER_OPACITY = 1.0
_MESH_OPACITY = 0.15
_DIM_FACTOR = 0.1

# Characters rewritten (-> "_") or removed (-> None) when a title becomes a filename.
_FILENAME_TABLE = str.maketrans({" ": "_", "/": "_", ":": None, "<": None, ">": None})


def _assign_components(X, gmm):
    """Return, for each row of *X*, the index of the most probable component."""
    weights = gmm.amp / gmm.amp.sum()
    # multivariate_normal.pdf squeezes its output, so a single-row X yields a
    # scalar; atleast_1d keeps every column 1-D so the stack is always (N, K).
    weighted_pdfs = np.stack([
        np.atleast_1d(
            weights[k] * multivariate_normal.pdf(X, mean=gmm.mean[k], cov=gmm.covar[k])
        )
        for k in range(gmm.K)
    ], axis=1)
    # Dividing each row by its sum (the responsibilities) is a positive
    # rescaling and cannot change which column is largest, so it is skipped.
    return weighted_pdfs.argmax(axis=1)


def _unit_sphere(n_u=40, n_v=20):
    """Return an (n_u, n_v, 3) grid of points on the unit sphere."""
    u = np.linspace(0, 2 * np.pi, n_u)
    v = np.linspace(0, np.pi, n_v)
    return np.stack([
        np.outer(np.cos(u), np.sin(v)),
        np.outer(np.sin(u), np.sin(v)),
        np.outer(np.ones_like(u), np.cos(v)),
    ], axis=-1)


def _ellipsoid_surface(mean, cov, sphere, n_sigma):
    """Map unit-sphere points onto the *n_sigma* ellipsoid of N(mean, cov)."""
    vals, vecs = np.linalg.eigh(cov)
    # Clip tiny negative eigenvalues caused by round-off before the sqrt.
    radii = n_sigma * np.sqrt(np.clip(vals, 0.0, None))
    # Column form is x = V diag(r) s; for row-vector points that becomes
    # s @ (V diag(r)).T. Without the transpose the points are only projected
    # onto the eigenvectors and the ellipsoid comes out axis-aligned.
    return sphere @ (vecs * radii).T + np.asarray(mean)


def plot_gmm_3d(df, gmm, component_colors, component_labels, title="GMM (vR, vφ, vZ)",
                n_sigma=2.0, output_dir="../figures", show=True):
    """Plot stars in (v_R, v_phi, v_Z) coloured by their GMM component.

    Each row of *df* is hard-assigned to the component with the largest
    weighted density. One 3-D scatter trace and one translucent ellipsoid are
    drawn per component, a dropdown highlights a single component, and the
    figure is exported as a standalone HTML file named after *title*.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``v_R``, ``v_phi`` and ``v_Z``.
    gmm : object
        Fitted mixture exposing ``K`` (number of components), ``amp``
        (component amplitudes), ``mean`` and ``covar`` (indexable by
        component). Assumed to describe 3-D Gaussians in the same column
        order as above; confirm against the code that fitted the model.
    component_colors, component_labels : sequence of str
        Colour and legend label for component ``k`` at index ``k``; each needs
        at least ``gmm.K`` entries.
    title : str
        Figure title; also the basis of the output filename.
    n_sigma : float
        Ellipsoid half-axes are ``n_sigma * sqrt(eigenvalue)`` of each covariance.
    output_dir : str
        Directory for the HTML export; created if missing.
    show : bool
        If true, open the figure with the active Plotly renderer.

    Raises
    ------
    ValueError
        If fewer than ``gmm.K`` colours or labels are supplied.
    """
    n_comp = gmm.K
    if len(component_colors) < n_comp or len(component_labels) < n_comp:
        raise ValueError(
            f"need at least {n_comp} colors and labels, got "
            f"{len(component_colors)} colors and {len(component_labels)} labels"
        )

    X = df[["v_R", "v_phi", "v_Z"]].values
    comp = _assign_components(X, gmm)

    fig = go.Figure()

    # Traces 0..K-1: one scatter cloud per component.
    for k in range(n_comp):
        name = component_labels[k]
        sel = comp == k
        fig.add_trace(go.Scatter3d(
            x=X[sel, 0],
            y=X[sel, 1],
            z=X[sel, 2],
            mode="markers",
            marker=dict(size=3, color=component_colors[k]),
            name=name,
            legendgroup=name,
            showlegend=True,
            opacity=_SCATTER_OPACITY
        ))

    # Traces K..2K-1: one covariance ellipsoid per component, sharing the
    # scatter trace's legend group so a legend click toggles both together.
    sphere = _unit_sphere()
    for k in range(n_comp):
        name = component_labels[k]
        ell = _ellipsoid_surface(gmm.mean[k], gmm.covar[k], sphere, n_sigma)
        fig.add_trace(go.Mesh3d(
            x=ell[..., 0].flatten(),
            y=ell[..., 1].flatten(),
            z=ell[..., 2].flatten(),
            alphahull=0,  # convex hull of the sampled surface points
            opacity=_MESH_OPACITY,
            color=component_colors[k],
            name=name,
            legendgroup=name,
            showlegend=False
        ))

    def _opacities(highlight=None):
        """Per-trace opacities (scatters, then meshes) with one component emphasised."""
        scale = [1.0 if highlight in (None, k) else _DIM_FACTOR for k in range(n_comp)]
        return ([_SCATTER_OPACITY * s for s in scale]
                + [_MESH_OPACITY * s for s in scale])

    # "All" restores the initial look (meshes stay translucent); each
    # component button dims every other component's cloud and ellipsoid.
    buttons = [dict(label="All", method="restyle", args=[{"opacity": _opacities()}])]
    buttons.extend(
        dict(label=component_labels[j], method="restyle", args=[{"opacity": _opacities(j)}])
        for j in range(n_comp)
    )

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis=dict(title="v_R (km/s)"),
            yaxis=dict(title="v_φ (km/s)"),
            zaxis=dict(title="v_Z (km/s)")
        ),
        scene_camera=dict(
            up=dict(x=0, y=1, z=0),            # ensure v_phi is screen-up
            eye=dict(x=1.2, y=1.2, z=0.8)      # tweak viewpoint
        ),
        updatemenus=[dict(
            buttons=buttons,
            direction="down",
            x=1.02,
            y=1.15
        )],
        legend=dict(
            itemclick='toggleothers',
            itemdoubleclick='toggle'
        ),
        width=800,
        height=700,
        margin=dict(l=0, r=0, b=0, t=40)
    )

    if show:
        fig.show()

    # Export next to the other figures; make the function usable on its own by
    # creating the directory rather than assuming the caller did.
    os.makedirs(output_dir, exist_ok=True)
    fname = title.translate(_FILENAME_TABLE)
    html_path = os.path.join(output_dir, f"{fname}.html")
    fig.write_html(html_path, include_plotlyjs="cdn")
    print(f"Saved interactive plot to {html_path}")



# One entry per metallicity bin, unpacked by the loop below as
# (title, mn, mx, pkl_path, colors, labels):
#   title       figure title (also used to build the output filename)
#   mn, mx      bin edges passed to filter_metallicity_bin
#   pkl_path    pickled mixture model for the bin
#   colors      plot colour for component k at index k
#   labels      legend label for component k at index k
bin_info = [
    (
        "VMP : -3<[M/H]<-2", -3.0, -2.0, "../models/gmm_vmp.pkl",
        ['red', 'blue'],
        ["Stationary halo", "Prograde halo"],
    ),
    (
        "IMP : -2<[M/H]<-1.6", -2.0, -1.6, "../models/gmm_imp.pkl",
        ['aqua', 'red', 'gold', 'blue'],
        ["GS/E(1)", "Stationary halo", "GS/E(2)", "Prograde halo"],
    ),
    (
        "MP1 : -1.6<[M/H]<-1.3", -1.6, -1.3, "../models/gmm_mp1.pkl",
        ['red', 'green', 'blue', 'gold', 'aqua'],
        ["Stationary halo", "Thick Disc", "Prograde halo", "GS/E(2)", "GS/E(1)"],
    ),
    (
        "MP2 : -1.3<[M/H]<-1.0", -1.3, -1.0, "../models/gmm_mp2.pkl",
        ['blue', 'aqua', 'green', 'red', 'gold'],
        ["Prograde halo", "GS/E(1)", "Thick Disc", "Stationary halo", "GS/E(2)"],
    ),
]

# Render and export one interactive figure per metallicity bin.
for title, mn, mx, pkl_path, colors, labels in tqdm(bin_info, desc="3D GMM bins"):
    df_bin = filter_metallicity_bin(df_v, mn, mx)
    # NOTE: unpickling can execute arbitrary code; only load model files
    # from a trusted source.
    with open(pkl_path, "rb") as f:
        gmm = pickle.load(f)
    plot_gmm_3d(df_bin, gmm, colors, labels, title=title)


3D GMM bins:  50%|█████     | 2/4 [00:00<00:00,  5.03it/s]

Saved interactive plot to ../figures/VMP__-3[M_H]-2.html
Saved interactive plot to ../figures/IMP__-2[M_H]-1.6.html


3D GMM bins:  75%|███████▌  | 3/4 [00:00<00:00,  5.47it/s]

Saved interactive plot to ../figures/MP1__-1.6[M_H]-1.3.html


3D GMM bins: 100%|██████████| 4/4 [00:00<00:00,  4.99it/s]

Saved interactive plot to ../figures/MP2__-1.3[M_H]-1.0.html



