In [11]:
import os
import pickle
import numpy as np
import pandas as pd
from astropy.io import fits
from scipy.stats import multivariate_normal
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from tqdm import tqdm

# Open figures in the system web browser rather than inline.
pio.renderers.default = 'browser'
# Output directory used when the interactive HTML figures are written.
os.makedirs('../figures', exist_ok=True)


# 1) Load & prepare data
def _to_native_byteorder(arr):
    """Return *arr* with native byte order; non-big-endian input is returned as is.

    FITS stores numbers big-endian, which pandas does not accept on
    little-endian hosts.  ``ndarray.newbyteorder()`` was removed in NumPy 2.0,
    so the conversion goes through ``astype`` with a native-order dtype, which
    works on both NumPy 1.x and 2.x and preserves the values.
    """
    if arr.dtype.byteorder == '>':
        return arr.astype(arr.dtype.newbyteorder('='))
    return arr


vphi_data = '../data/df_v_final.fits'
with fits.open(vphi_data) as hdul:
    # The table is read from the first extension HDU.
    data = hdul[1].data
    df_v = pd.DataFrame({
        col.name: _to_native_byteorder(data[col.name])
        for col in hdul[1].columns
    })

# Keep rows with |Z| < 2.5 (units not shown here; presumably kpc -- confirm).
df_v = df_v[np.abs(df_v['Z']) < 2.5]
# Drop rows missing any velocity component, its uncertainty, or the metallicity.
df_v = df_v.dropna(subset=[
    "v_R", "v_phi", "v_R_uncertainty",
    "v_phi_uncertainty", "v_Z", "v_Z_uncertainty",
    "mh_xgboost"
])


def filter_metallicity_bin(df, mn, mx):
    """Return the rows of *df* whose ``mh_xgboost`` lies in ``[mn, mx)``.

    The lower edge is inclusive and the upper edge exclusive, so adjacent
    bins never share a row.  Rows with a NaN metallicity are never selected.
    """
    metallicity = df['mh_xgboost']
    in_bin = (metallicity >= mn) & (metallicity < mx)
    return df.loc[in_bin]


# Trace opacities as (normal, dimmed) pairs.  The "normal" values are also the
# initial trace opacities, so the "All" button restores the figure exactly.
_SCATTER_OPACITY = (1.0, 0.1)
_MESH_OPACITY = (0.15, 0.03)


def _hard_assignments(X, gmm, weights):
    """Return, for each row of X, the index of the most probable component."""
    with np.errstate(divide="ignore"):  # a zero-weight component maps to -inf
        log_w = np.log(weights)
    # Work in log space: far outliers would underflow every pdf to 0, turning
    # the normalised responsibilities into NaN and the label into component 0.
    # atleast_1d is needed because scipy squeezes a single-row result to a scalar.
    log_post = np.stack(
        [
            log_w[k] + np.atleast_1d(
                multivariate_normal.logpdf(X, mean=gmm.mean[k], cov=gmm.covar[k])
            )
            for k in range(gmm.K)
        ],
        axis=1,
    )
    return log_post.argmax(axis=1)


def _ellipsoid_surface(mean, cov, unit_sphere, n_sigma=2.0):
    """Map unit-sphere points (..., 3) onto the n-sigma ellipsoid of (mean, cov)."""
    vals, vecs = np.linalg.eigh(cov)
    # Clip tiny negative eigenvalues caused by round-off before the sqrt.
    radii = n_sigma * np.sqrt(np.clip(vals, 0.0, None))
    # eigh returns eigenvectors as columns, so column j of `axes` is the j-th
    # semi-axis.  Points are row vectors, hence the transpose: s @ axes.T is
    # (V D s)^T.  Without it the result is D V^T s, an axis-aligned ellipsoid.
    axes = vecs * radii
    return unit_sphere @ axes.T + mean


def _highlight_opacities(n_comp, highlight=None):
    """Opacity per trace (scatters, then meshes, then bar) for one dropdown entry."""
    def pick(pair, k):
        normal, dimmed = pair
        return normal if highlight is None or highlight == k else dimmed

    scatter = [pick(_SCATTER_OPACITY, k) for k in range(n_comp)]
    mesh = [pick(_MESH_OPACITY, k) for k in range(n_comp)]
    # The bar chart is the last trace and always stays fully opaque.
    return scatter + mesh + [1.0]


def plot_gmm_3d_with_weights(df, gmm, colors, labels, title="GMM (vR, vφ, vZ)"):
    """Plot a 3-component-velocity GMM as a 3D scatter plus a weight bar chart.

    Each row of *df* is hard-assigned to the mixture component with the
    highest posterior probability and drawn in that component's colour,
    together with a translucent 2-sigma ellipsoid per component.  A dropdown
    lets the viewer highlight one component.  The figure is shown with the
    default plotly renderer and written to ``../figures/<sanitised title>.html``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``v_R``, ``v_phi`` and ``v_Z``.
    gmm : object
        Mixture model exposing ``K`` (number of components), ``amp``
        (amplitudes, normalised here), ``mean`` and ``covar`` indexed by
        component.  NOTE(review): looks like a pygmmis ``GMM`` -- confirm.
    colors, labels : sequence of str
        One colour / legend label per component, in the model's component order.
    title : str
        Title of the 3D panel; also used to derive the output file name.

    Returns
    -------
    None
    """
    # data matrix
    X = df[["v_R", "v_phi", "v_Z"]].values

    # normalised mixture weights and hard labels
    w = gmm.amp / gmm.amp.sum()
    comp_idx = _hard_assignments(X, gmm, w)

    # component weights (%) for the bar chart
    weights_pct = (w * 100).round(1)

    # build a 1×2 subplot: scene + bar chart
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'xy'}]],
        column_widths=[0.7, 0.3],
        subplot_titles=[title, "Weights (%)"]
    )

    # 3D scatter in col=1: traces 0 .. K-1
    for k in range(gmm.K):
        name = labels[k]
        sel = comp_idx == k
        fig.add_trace(
            go.Scatter3d(
                x=X[sel, 0],
                y=X[sel, 1],
                z=X[sel, 2],
                mode="markers",
                marker=dict(size=3, color=colors[k]),
                name=name,
                legendgroup=name,
                showlegend=True,
                opacity=_SCATTER_OPACITY[0]
            ),
            row=1, col=1
        )

    # ellipsoid meshes in col=1: traces K .. 2K-1
    u = np.linspace(0, 2*np.pi, 40)
    v = np.linspace(0, np.pi, 20)
    sphere = np.stack(
        [
            np.outer(np.cos(u), np.sin(v)),
            np.outer(np.sin(u), np.sin(v)),
            np.outer(np.ones_like(u), np.cos(v)),
        ],
        axis=-1,
    )
    for k in range(gmm.K):
        name = labels[k]
        ell = _ellipsoid_surface(gmm.mean[k], gmm.covar[k], sphere)
        fig.add_trace(
            go.Mesh3d(
                x=ell[..., 0].flatten(),
                y=ell[..., 1].flatten(),
                z=ell[..., 2].flatten(),
                alphahull=0, opacity=_MESH_OPACITY[0], color=colors[k],
                name=name, legendgroup=name, showlegend=False
            ),
            row=1, col=1
        )

    # bar chart of weights in col=2: last trace
    fig.add_trace(
        go.Bar(
            x=labels,
            y=weights_pct,
            marker_color=colors,
            showlegend=False,
            text=[f"{wt}%" for wt in weights_pct],
            textposition="auto"
        ),
        row=1, col=2
    )

    # highlight dropdown: "All" restores the initial opacities, each label
    # dims every other component's scatter and mesh
    buttons = [dict(
        label="All", method="restyle",
        args=[{"opacity": _highlight_opacities(gmm.K)}]
    )]
    for j, lab in enumerate(labels):
        buttons.append(dict(
            label=lab, method="restyle",
            args=[{"opacity": _highlight_opacities(gmm.K, j)}]
        ))

    # layout
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="v_R (km/s)"),
            yaxis=dict(title="v_φ (km/s)"),
            zaxis=dict(title="v_Z (km/s)")
        ),
        # pull the camera back for a wider view
        scene_camera=dict(
            up=dict(x=0, y=1, z=0),
            eye=dict(x=1.5, y=1.5, z=1.0)
        ),
        # place the dropdown centered above the plot title
        updatemenus=[dict(
            buttons=buttons,
            direction="down",
            x=0.5,       # center horizontally
            y=1.10,      # just above the title
            xanchor='center',
            yanchor='bottom'
        )],
        width=1400,
        height=700,
        margin=dict(l=0, r=0, b=0, t=80)  # more top margin to fit dropdown
    )

    # show & export
    fig.show()
    # Spaces and "/" become "_"; ":", "<" and ">" are dropped from the file name.
    fname = title.translate(
        str.maketrans({" ": "_", "/": "_", ":": None, "<": None, ">": None})
    )
    html_out = os.path.join("../figures", f"{fname}.html")
    fig.write_html(html_out, include_plotlyjs="cdn")
    print(f"Saved interactive plot to {html_out}")


# 5) define bins & run
# Each entry is (plot title, lower [M/H] edge, upper [M/H] edge,
# path to the pickled GMM, per-component colours, per-component labels).
# The plotting function indexes colours and labels by component number, so
# both lists must follow the component order stored in the pickled model.
# A bin includes its lower edge and excludes its upper edge.
bin_info = [
    ("VMP : -3<[M/H]<-2",  -3.0, -2.0, "../models/gmm_vmp.pkl",
        ['red','blue'], ["Stationary halo","Prograde halo"]
    ),
    ("IMP : -2<[M/H]<-1.6",-2.0, -1.6, "../models/gmm_imp.pkl",
        ['aqua','red','gold','blue'], ["GS/E(1)","Stationary halo","GS/E(2)","Prograde halo"]
    ),
    ("MP1 : -1.6<[M/H]<-1.3",-1.6, -1.3, "../models/gmm_mp1.pkl",
        ['red','green','blue','gold','aqua'], ["Stationary halo","Thick Disc","Prograde halo","GS/E(2)","GS/E(1)"]
    ),
    ("MP2 : -1.3<[M/H]<-1.0",-1.3, -1.0, "../models/gmm_mp2.pkl",
        ['blue','aqua','green','red','gold'], ["Prograde halo","GS/E(1)","Thick Disc","Stationary halo","GS/E(2)"]
    ),
]

# For every bin: select its stars, load the fitted model, then show and save
# one interactive figure.
for title, mn, mx, pkl, cols, labs in tqdm(bin_info, desc="3D+weights bins"):
    dfb = filter_metallicity_bin(df_v, mn, mx)
    # NOTE(security): pickle.load can execute arbitrary code; only load model
    # files that come from a trusted source.
    with open(pkl, "rb") as f:
        g = pickle.load(f)
    plot_gmm_3d_with_weights(dfb, g, cols, labs, title=title)


3D+weights bins:  50%|█████     | 2/4 [00:00<00:00,  5.29it/s]

Saved interactive plot to ../figures/VMP__-3[M_H]-2.html
Saved interactive plot to ../figures/IMP__-2[M_H]-1.6.html


3D+weights bins:  75%|███████▌  | 3/4 [00:00<00:00,  5.55it/s]

Saved interactive plot to ../figures/MP1__-1.6[M_H]-1.3.html


3D+weights bins: 100%|██████████| 4/4 [00:00<00:00,  4.89it/s]

Saved interactive plot to ../figures/MP2__-1.3[M_H]-1.0.html



