In [3]:
import os
import pickle
import numpy as np
import pandas as pd
from astropy.io import fits
from scipy.stats import multivariate_normal
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from tqdm import tqdm

# Send every plotly `fig.show()` to the system web browser rather than rendering
# inline in the notebook output area.
pio.renderers.default = 'browser'
# Ensure the figure output directory exists (no error if it is already there).
os.makedirs('../figures', exist_ok=True)


In [4]:
# Load & prepare data
vphi_data = '../data/df_v_final.fits'


def _to_native_endian(values):
    """Return `values` with the machine's native byte order (same numbers).

    Columns that are already native (or have no byte order, e.g. strings) are
    returned untouched. Non-native columns are cast at the dtype level, which
    converts the bytes and the dtype label together. This replaces the old
    `arr.byteswap().newbyteorder()` idiom, which no longer works on NumPy 2.0
    because `ndarray.newbyteorder` was removed.
    """
    dtype = values.dtype
    if dtype.isnative:
        return values
    return values.astype(dtype.newbyteorder('='))


with fits.open(vphi_data) as hdul:
    # The table is read from the first extension HDU (index 1).
    data = hdul[1].data
    # Build the frame column by column; each column is converted to native
    # byte order first so pandas receives native-endian buffers.
    df_v = pd.DataFrame({
        col.name: _to_native_endian(data[col.name])
        for col in hdul[1].columns
    })

# Keep rows with |Z| < 2.5 (units are not visible here — presumably kpc; TODO confirm).
df_v = df_v[np.abs(df_v['Z']) < 2.5]
# Drop rows missing any velocity component, its uncertainty, or the metallicity.
df_v = df_v.dropna(subset=[
    "v_R", "v_phi", "v_R_uncertainty",
    "v_phi_uncertainty", "v_Z", "v_Z_uncertainty",
    "mh_xgboost"
])


In [5]:
def filter_metallicity_bin(df, mn, mx):
    """Return the rows of `df` whose `mh_xgboost` falls in the half-open bin [mn, mx).

    The lower edge is inclusive and the upper edge exclusive; rows with a
    missing `mh_xgboost` never satisfy either comparison and are dropped.
    """
    metallicity = df['mh_xgboost']
    in_bin = metallicity.ge(mn) & metallicity.lt(mx)
    return df[in_bin]

def _hard_assign_components(X, gmm):
    """Normalise the mixture weights and hard-assign each row of X to a component.

    Returns `(w, comp_idx)`: `w` is `gmm.amp` scaled to sum to 1, and
    `comp_idx[i]` is the index of the component with the largest posterior
    for row i. The argmax is taken over log-posteriors (log weight + log pdf),
    which picks the same component as normalised responsibilities but does not
    break when every density underflows to zero.
    """
    amp = np.asarray(gmm.amp, dtype=float)
    w = amp / amp.sum()
    # A zero-weight component gives log(0) = -inf and simply never wins the argmax.
    with np.errstate(divide="ignore"):
        log_w = np.log(w)
    log_post = np.stack(
        [
            # logpdf returns a bare scalar for a single row; keep it 1-D so the
            # stack along axis=1 is always (n_rows, K).
            log_w[k] + np.atleast_1d(
                multivariate_normal.logpdf(X, mean=gmm.mean[k], cov=gmm.covar[k])
            )
            for k in range(gmm.K)
        ],
        axis=1,
    )
    return w, log_post.argmax(axis=1)


def _ellipsoid_surface(mu, cov, sphere, n_sigma):
    """Map unit-sphere points (..., 3) onto the n-sigma ellipsoid of N(mu, cov)."""
    vals, vecs = np.linalg.eigh(cov)
    # Round-off can leave a tiny negative eigenvalue on a near-singular
    # covariance; clip so the radius does not become NaN.
    radii = n_sigma * np.sqrt(np.clip(vals, 0.0, None))
    # A surface point is V @ diag(radii) @ s. With s stored as a row vector,
    # that is s @ (V * radii).T. Without the transpose the result is an
    # axis-aligned ellipsoid that ignores the covariance orientation.
    return sphere @ (vecs * radii).T + np.asarray(mu)


def plot_gmm_3d_with_weights(df, gmm, colors, labels, title="GMM (vR, vφ, vZ)",
                             n_sigma=2.0, out_dir="../figures", show=True):
    """Plot a 3-D GMM decomposition of (v_R, v_phi, v_Z) next to a weight bar chart.

    Left panel: each star is coloured by its hard-assigned component and every
    component gets a translucent `n_sigma` covariance ellipsoid. Right panel:
    one bar per component showing its normalised weight in percent. A dropdown
    highlights one component by dimming all other traces. The figure is written
    to `<out_dir>/<sanitised title>.html` and optionally shown.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns "v_R", "v_phi" and "v_Z".
    gmm : object
        Mixture model exposing `amp` (K,), `mean` (K, 3), `covar` (K, 3, 3)
        and `K`. NOTE(review): presumably a pygmmis GMM — confirm.
    colors, labels : sequence of str
        Colour and display name per component, matched by position; each must
        have at least `gmm.K` entries.
    title : str
        Title of the 3-D panel; also used (sanitised) as the output file name.
    n_sigma : float, optional
        Ellipsoid semi-axis length in units of sqrt(eigenvalue). Default 2.0.
    out_dir : str, optional
        Directory for the HTML file; created if missing. Default "../figures".
    show : bool, optional
        Call `fig.show()` before saving. Default True.

    Raises
    ------
    ValueError
        If `colors` or `labels` has fewer than `gmm.K` entries.
    """
    if len(labels) < gmm.K or len(colors) < gmm.K:
        raise ValueError(
            f"need at least {gmm.K} colors and labels, "
            f"got {len(colors)} colors and {len(labels)} labels"
        )

    mesh_opacity = 0.15   # base opacity of the covariance ellipsoids
    dim_factor = 0.1      # non-highlighted traces drop to this fraction of their base opacity

    # Hard assignments & weights
    X = df[["v_R", "v_phi", "v_Z"]].values
    w, comp_idx = _hard_assign_components(X, gmm)
    weights_pct = (w*100).round(1)

    # Subplots
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{"type":"scene"}, {"type":"xy"}]],
        column_widths=[0.7,0.3],
        subplot_titles=[title, "Weights (%)"]
    )

    # Unit sphere sampled on a (40, 20) longitude/latitude grid, shape (40, 20, 3)
    u = np.linspace(0,2*np.pi,40)
    v = np.linspace(0,np.pi,20)
    sphere = np.stack([
        np.outer(np.cos(u), np.sin(v)),
        np.outer(np.sin(u), np.sin(v)),
        np.outer(np.ones_like(u), np.cos(v))
    ], axis=-1)

    # Legend group and base opacity of every trace, in the order the traces are
    # added, so the dropdown's per-trace opacity lists line up with fig.data.
    trace_groups = []
    base_opacity = []

    # 3D scatter & ellipsoids
    for k in range(gmm.K):
        name = labels[k]
        sel = comp_idx == k
        # scatter
        fig.add_trace(go.Scatter3d(
            x=X[sel, 0], y=X[sel, 1], z=X[sel, 2],
            mode="markers",
            marker=dict(size=1,color=colors[k]),
            name=name, legendgroup=name, opacity=1.0,
            hovertemplate=(f"<b>{name}</b><br>"
                          "v_R: %{x:.2f} km/s<br>"
                          "v_φ: %{y:.2f} km/s<br>"
                          "v_Z: %{z:.2f} km/s<extra></extra>")
        ), row=1, col=1)
        trace_groups.append(name)
        base_opacity.append(1.0)
        # ellipsoid
        ell = _ellipsoid_surface(gmm.mean[k], gmm.covar[k], sphere, n_sigma)
        fig.add_trace(go.Mesh3d(
            x=ell[...,0].ravel(),
            y=ell[...,1].ravel(),
            z=ell[...,2].ravel(),
            alphahull=0, opacity=mesh_opacity, color=colors[k],
            name=name, legendgroup=name, showlegend=False
        ), row=1, col=1)
        trace_groups.append(name)
        base_opacity.append(mesh_opacity)

    # One bar-trace per component
    for k in range(gmm.K):
        name = labels[k]
        fig.add_trace(go.Bar(
            x=[name],
            y=[weights_pct[k]],
            marker_color=colors[k],
            name=name,
            legendgroup=name,
            showlegend=False,
            opacity=1.0,
            text=[f"{weights_pct[k]}%"],
            textposition="auto"
        ), row=1, col=2)
        trace_groups.append(name)
        base_opacity.append(1.0)
    # ensure category order
    fig.update_xaxes(categoryorder="array", categoryarray=labels, row=1, col=2)

    # Highlight dropdown via legendgroup. "All" restores each trace's own base
    # opacity (so ellipsoids stay translucent); picking a component keeps its
    # traces at base opacity and dims every other trace.
    buttons = [dict(
        label="All", method="restyle",
        args=[{"opacity": list(base_opacity)}]
    )]
    for lab in labels:
        op = [base if grp == lab else base * dim_factor
              for grp, base in zip(trace_groups, base_opacity)]
        buttons.append(dict(label=lab, method="restyle", args=[{"opacity":op}]))

    # Layout + camera + dropdown
    fig.update_layout(
        scene=dict(
            xaxis_title="v_R (km/s)",
            yaxis_title="v_φ (km/s)",
            zaxis_title="v_Z (km/s)"
        ),
        scene_camera=dict(up=dict(x=0,y=1,z=0), eye=dict(x=1.5,y=1.5,z=1.0)),
        updatemenus=[dict(buttons=buttons, direction="down",
                          x=0.5, y=1.10, xanchor='center', yanchor='bottom')],
        width=1400, height=750,
        margin=dict(l=0, r=0, b=100, t=80)
    )

    # Caption in a boxed annotation below the plots
    fig.add_annotation(
        text=(
            "Gaussian Mixture Model decompositions of the stellar velocity "
            "distribution in 3D (v_R–v_φ–v_Z). Use dropdown to highlight; "
            "right shows each component’s fractional weight."
        ),
        x=0.5, y=-0.17,
        xref='paper', yref='paper',
        showarrow=False,
        align='center',
        font=dict(size=12, color='black'),
        bgcolor='white',           # box background
        bordercolor='black',       # box border
        borderwidth=1,
        borderpad=10               # padding inside box
    )

    # Show & save
    if show:
        fig.show()
    # Strip or replace characters that are awkward in file names.
    safe = title.replace(" ", "_").replace(":", "").replace("<","").replace(">","").replace("/","_")
    # Create the target directory here so the function does not depend on
    # another cell having done it.
    os.makedirs(out_dir, exist_ok=True)
    out = os.path.join(out_dir, f"{safe}.html")
    fig.write_html(out, include_plotlyjs="cdn")
    print(f"Saved to {out}")


# Run for each bin
# Each entry is a tuple:
#   (figure title, [M/H] lower bound (inclusive), [M/H] upper bound (exclusive),
#    path to the pickled GMM, per-component colors, per-component labels)
# NOTE(review): colors and labels are matched to mixture components purely by
# position, so they presumably mirror the component order stored in each
# pickle — re-check them whenever a model is refit.
bin_info = [
    ("VMP : -3<[M/H]<-2",  -3.0, -2.0, "../models/gmm_vmp.pkl",
     ['red','blue'], ["Stationary halo","Prograde halo"]),
    ("IMP : -2<[M/H]<-1.6", -2.0, -1.6, "../models/gmm_imp.pkl",
     ['aqua','red','gold','blue'], ["GS/E(1)","Stationary halo","GS/E(2)","Prograde halo"]),
    ("MP1 : -1.6<[M/H]<-1.3", -1.6, -1.3, "../models/gmm_mp1.pkl",
     ['red','green','blue','gold','aqua'], ["Stationary halo","Thick Disc","Prograde halo","GS/E(2)","GS/E(1)"]),
    ("MP2 : -1.3<[M/H]<-1.0", -1.3, -1.0, "../models/gmm_mp2.pkl",
     ['blue','aqua','green','red','gold'], ["Prograde halo","GS/E(1)","Thick Disc","Stationary halo","GS/E(2)"]),
]

# For every metallicity bin: select the stars in the bin, load the fitted
# mixture model for that bin, and render/save the interactive figure.
# NOTE(review): pickle.load can execute arbitrary code from the file — only
# load model files that were produced locally / come from a trusted source.
for title, mn, mx, pkl, cols, labs in tqdm(bin_info, desc="3D+weights bins"):
    dfb = filter_metallicity_bin(df_v, mn, mx)
    with open(pkl, "rb") as f:
        g = pickle.load(f)
    plot_gmm_3d_with_weights(dfb, g, cols, labs, title=title)



IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

3D+weights bins:  50%|█████     | 2/4 [00:00<00:00,  4.50it/s]

Saved to ../figures/VMP__-3[M_H]-2.html
Saved to ../figures/IMP__-2[M_H]-1.6.html


3D+weights bins: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]

Saved to ../figures/MP1__-1.6[M_H]-1.3.html
Saved to ../figures/MP2__-1.3[M_H]-1.0.html



