In [10]:
import os
import pickle
import numpy as np
import pandas as pd
from astropy.io import fits
from scipy.stats import multivariate_normal
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from tqdm import tqdm

pio.renderers.default = 'browser'
os.makedirs('../figures', exist_ok=True)

# Load the catalogue table from the first extension HDU.
with fits.open('../data/vis_main_filtered.fits') as hdul:
    data = hdul[1].data
    # Copy into a plain ndarray while the file is still open: astropy may
    # memory-map the table, so don't rely on it after the HDU list is closed.
    table = np.array(data)
# FITS stores binary-table numbers big-endian; flip the byte order so pandas
# gets native-endian columns on little-endian hosts.
# ndarray.newbyteorder() was removed in NumPy 2.0, so swap the bytes and then
# reinterpret them with the byte-flipped dtype (works on NumPy 1.x and 2.x).
table = table.byteswap().view(table.dtype.newbyteorder())
df = pd.DataFrame(table)

# Keep stars with |Z| < 2.5 (units as stored in the catalogue -- presumably kpc, confirm).
df = df[np.abs(df['Z']) < 2.5]
# Drop rows missing any velocity component, its uncertainty, the metallicity
# ('mh_xgboost') or the alpha abundance ('aom_xp') used downstream.
df = df.dropna(subset=[
    'v_R','v_phi','v_R_uncertainty','v_phi_uncertainty',
    'v_Z','v_Z_uncertainty','mh_xgboost','aom_xp'
])

def define_alpha_sequences(df, alpha_col='aom_xp', feh_col='mh_xgboost'):
    """Split a catalogue into its high-alpha and low-alpha sequences.

    Each sequence is selected with a three-piece boundary in the
    (metallicity, alpha) plane: flat at low metallicity, linear in between,
    and flat again at high metallicity. The low-alpha boundary lies below the
    high-alpha one everywhere, so the two selections are disjoint. Stars in
    the gap between them, or with NaN in either column, belong to neither.

    Parameters
    ----------
    df : pandas.DataFrame
        Input catalogue.
    alpha_col : str
        Column holding the alpha abundance.
    feh_col : str
        Column holding the metallicity.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Independent copies of the high-alpha rows and the low-alpha rows,
        in that order.
    """
    alpha = df[alpha_col]
    metal = df[feh_col]

    # Upper sequence: alpha > 0.28 below -0.6, then above a line of slope
    # -0.25 up to 0.125, then alpha > 0.1.
    hi_poor = (metal < -0.6) & (alpha > 0.28)
    hi_mid = metal.between(-0.6, 0.125) & (alpha > -0.25 * metal + 0.13)
    hi_rich = (metal > 0.125) & (alpha > 0.1)
    is_high = hi_poor | hi_mid | hi_rich

    # Lower sequence: alpha < 0.21 below -0.8, then below a line of slope
    # -0.21 up to 0.07, then alpha < 0.03.
    lo_poor = (metal < -0.8) & (alpha < 0.21)
    lo_mid = metal.between(-0.8, 0.07) & (alpha < -0.21 * metal + 0.045)
    lo_rich = (metal > 0.07) & (alpha < 0.03)
    is_low = lo_poor | lo_mid | lo_rich

    high_seq = df[is_high].copy()
    low_seq = df[is_low].copy()
    return high_seq, low_seq

# Split the cleaned sample into the high- and low-alpha sequences binned below.
df_high, df_low = define_alpha_sequences(df, alpha_col='aom_xp', feh_col='mh_xgboost')

def filter_metallicity_bin(df, mn, mx, feh_col='mh_xgboost'):
    """Return a copy of the rows of *df* whose metallicity lies in ``[mn, mx)``.

    The interval is half-open, so adjacent bins that share an edge
    (e.g. -2.0..-1.6 and -1.6..-1.3) never both contain the same star.

    Parameters
    ----------
    df : pandas.DataFrame
        Input catalogue.
    mn : float
        Inclusive lower bin edge.
    mx : float
        Exclusive upper bin edge.
    feh_col : str, default 'mh_xgboost'
        Name of the metallicity column to bin on.

    Returns
    -------
    pandas.DataFrame
        An independent copy of the selected rows.
    """
    feh = df[feh_col]
    return df[(feh >= mn) & (feh < mx)].copy()

# Title -> file-name sanitisation: spaces and '/' become '_', ':', '<', '>' are dropped.
_FILENAME_TABLE = str.maketrans({' ': '_', '/': '_', ':': None, '<': None, '>': None})


def _hard_assign(X, gmm, w):
    """Return the index of the most probable GMM component for each row of ``X``.

    The argmax of the responsibilities equals the argmax of
    ``log w_k + log N(x | mean_k, covar_k)``, so the normalisation is skipped and
    the work is done in log space. A point far from every component therefore
    no longer underflows to 0/0 = NaN, which ``argmax`` would silently map to
    component 0.
    """
    with np.errstate(divide='ignore'):  # a zero weight gives log(0) = -inf, which is fine
        log_w = np.log(w)
    log_post = np.column_stack([
        # atleast_1d: scipy squeezes the result to a 0-d scalar when X has one row.
        log_w[k] + np.atleast_1d(
            multivariate_normal.logpdf(X, mean=gmm.mean[k], cov=gmm.covar[k]))
        for k in range(gmm.K)
    ])
    return log_post.argmax(axis=1)


def _ellipsoid_xyz(mean, cov, n_sigma=2.0, n_u=40, n_v=20):
    """Return flattened x, y, z surface points of the n_sigma ellipsoid of N(mean, cov)."""
    u = np.linspace(0, 2 * np.pi, n_u)
    v = np.linspace(0, np.pi, n_v)
    sphere = np.stack([
        np.outer(np.cos(u), np.sin(v)),
        np.outer(np.sin(u), np.sin(v)),
        np.outer(np.ones_like(u), np.cos(v)),
    ], axis=-1).reshape(-1, 3)  # unit-sphere points, one per row
    vals, vecs = np.linalg.eigh(np.asarray(cov, dtype=float))
    # Scale each principal axis (column of vecs) by n_sigma * sqrt(eigenvalue);
    # clip tiny negative eigenvalues from round-off so the sqrt stays real.
    transform = vecs * (n_sigma * np.sqrt(np.clip(vals, 0.0, None)))
    pts = sphere @ transform.T + np.asarray(mean, dtype=float)
    return pts[:, 0], pts[:, 1], pts[:, 2]


def _highlight_buttons(traces, labels, dim=0.1):
    """Build 'restyle' dropdown buttons that highlight one legend group at a time.

    Each trace keeps its own base opacity (e.g. the translucent ellipsoids stay
    translucent). Traces outside the selected group are scaled down by ``dim``.
    'All' restores every trace to its base opacity.
    """
    groups = [t.legendgroup for t in traces]
    base = [1.0 if t.opacity is None else t.opacity for t in traces]
    buttons = [dict(label='All', method='restyle', args=[{'opacity': list(base)}])]
    for lab in labels:
        op = [b if g == lab else b * dim for g, b in zip(groups, base)]
        buttons.append(dict(label=lab, method='restyle', args=[{'opacity': op}]))
    return buttons


def plot_gmm_3d_with_weights(df, gmm, colors, labels, title,
                             out_dir='../figures', show=True):
    """Plot a 3-D GMM decomposition of (v_R, v_phi, v_Z) next to its weights.

    Each star is hard-assigned to its most probable component and drawn in a
    3-D scatter. Each component also gets its 2-sigma ellipsoid, and a bar
    chart shows the mixture weights in percent. A dropdown highlights one
    component (its points, ellipsoid and bar) by dimming the other traces.
    The figure is written to ``<out_dir>/<sanitised title>.html``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``v_R``, ``v_phi`` and ``v_Z``.
    gmm : object
        Fitted mixture exposing ``K`` (number of components), ``amp``
        (component amplitudes), ``mean[k]`` (length-3) and ``covar[k]`` (3x3).
    colors, labels : sequence of str
        One colour / display name per component, in component order.
    title : str
        Scene subplot title; also used (sanitised) as the output file name.
    out_dir : str, default '../figures'
        Output directory; created if missing.
    show : bool, default True
        Whether to call ``fig.show()`` before saving.
    """
    # 1) hard assignments + normalised component weights
    X = df[['v_R', 'v_phi', 'v_Z']].to_numpy()
    amp = np.asarray(gmm.amp, dtype=float)
    w = amp / amp.sum()
    df = df.assign(comp=_hard_assign(X, gmm, w))
    weights = (w * 100).round(1)

    # 2) layout: 3-D scene on the left, weight bar chart on the right
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'xy'}]],
        column_widths=[0.7, 0.3],
        subplot_titles=[title, 'Weights (%)'],
    )

    # 3) per component: member stars + 2-sigma ellipsoid, sharing one legend group
    for k in range(gmm.K):
        name = labels[k]
        sel = df['comp'] == k

        fig.add_trace(go.Scatter3d(
            x=df.loc[sel, 'v_R'], y=df.loc[sel, 'v_phi'], z=df.loc[sel, 'v_Z'],
            mode='markers',
            marker=dict(size=1, color=colors[k]),
            name=name, legendgroup=name, opacity=1.0,
            hovertemplate=(
                f"<b>{name}</b><br>"
                "v_R: %{x:.1f} km/s<br>"
                "v_φ: %{y:.1f} km/s<br>"
                "v_Z: %{z:.1f} km/s<extra></extra>"
            ),
        ), row=1, col=1)

        ex, ey, ez = _ellipsoid_xyz(gmm.mean[k], gmm.covar[k], n_sigma=2.0)
        fig.add_trace(go.Mesh3d(
            x=ex, y=ey, z=ez,
            alphahull=0, opacity=0.15, color=colors[k],
            name=name, legendgroup=name, showlegend=False,
        ), row=1, col=1)

    # 4) bar chart of weights; legendgroup ties each bar to its component so the
    #    dropdown highlights it together with the 3-D traces
    for k in range(gmm.K):
        fig.add_trace(go.Bar(
            x=[labels[k]], y=[weights[k]],
            marker_color=colors[k],
            legendgroup=labels[k], showlegend=False, opacity=1.0,
            text=[f"{weights[k]}%"], textposition='auto',
        ), row=1, col=2)
    fig.update_xaxes(categoryorder='array', categoryarray=list(labels), row=1, col=2)

    # 5) dropdown + layout
    buttons = _highlight_buttons(fig.data, labels)
    fig.update_layout(
        title=None,
        scene=dict(
            xaxis_title='v_R (km/s)',
            yaxis_title='v_φ (km/s)',
            zaxis_title='v_Z (km/s)',
            camera=dict(
                up=dict(x=0, y=1, z=0),    # draw v_φ (y-axis) vertically "up"
                eye=dict(x=1.5, y=1.5, z=1.0)
            )
        ),
        updatemenus=[dict(
            buttons=buttons,
            direction='down',
            x=0.5, y=1.10,
            xanchor='center',
            yanchor='bottom'
        )],
        width=1400,
        height=750,
        margin=dict(l=0, r=0, b=120, t=40)
    )

    fig.add_annotation(
        text=("GMM decompositions of the 3D velocity distribution "
              "(v_R–v_φ–v_Z). Use the dropdown to highlight a component; "
              "its corresponding bar will also light up."),
        x=0.5, y=-0.18, xref='paper', yref='paper',
        showarrow=False, align='center',
        font=dict(size=12), bgcolor='white',
        bordercolor='black', borderwidth=1, borderpad=8
    )

    if show:
        fig.show()

    # 6) save as standalone HTML (plotly.js loaded from the CDN)
    fn = title.translate(_FILENAME_TABLE)
    os.makedirs(out_dir, exist_ok=True)
    out = f"{out_dir}/{fn}.html"
    fig.write_html(out, include_plotlyjs='cdn')
    print("Saved →", out)



# One entry per [M/H] bin:
#   (title, [M/H] lower edge, [M/H] upper edge, pickled GMM path, colours, labels)
# colours and labels are indexed by GMM component, so list them in component order.
bin_info_high = [
    ("VMP_high  : -3<[M/H]<-2", -3.0, -2.0, "../models/gmm_vmp_high.pkl",
     ['red'],
     ["Stationary halo"]),
    ("IMP_high  : -2<[M/H]<-1.6", -2.0, -1.6, "../models/gmm_imp_high.pkl",
     ['red', 'blue', 'aqua', 'gold'],
     ["Stationary halo", "Prograde halo", "GS/E(1)", "GS/E(2)"]),
    ("MP1_high  : -1.6<[M/H]<-1.3", -1.6, -1.3, "../models/gmm_mp1_high.pkl",
     ['aqua', 'gold', 'blue', 'green', 'red'],
     ["GS/E(1)", "GS/E(2)", "Prograde halo", "Thick Disc", "Stationary halo"]),
    ("MP2_high  : -1.3<[M/H]<-1.0", -1.3, -1.0, "../models/gmm_mp2_high.pkl",
     ['red', 'purple', 'green'],
     ["Stationary halo", "GS/E", "Thick Disc"]),
]

bin_info_low = [
    ("VMP_low   : -3<[M/H]<-2", -3.0, -2.0, "../models/gmm_vmp_low.pkl",
     ['red'],
     ["Stationary halo"]),
    ("IMP_low   : -2<[M/H]<-1.6", -2.0, -1.6, "../models/gmm_imp_low.pkl",
     ['red', 'blue'],
     ["Stationary halo", "Prograde halo"]),
    ("MP1_low   : -1.6<[M/H]<-1.3", -1.6, -1.3, "../models/gmm_mp1_low.pkl",
     ['purple', 'red', 'blue'],
     ["GS/E", "Stationary halo", "Thick Disc/Prograde halo"]),
    ("MP2_low   : -1.3<[M/H]<-1.0", -1.3, -1.0, "../models/gmm_mp2_low.pkl",
     ['blue', 'green', 'aqua', 'gold', 'red'],
     ["Prograde halo", "Thick Disc", "GS/E(1)", "GS/E(2)", "Stationary halo"]),
]

# Render every bin of the high-alpha sequence, then every bin of the low-alpha one.
# NOTE: pickle.load can execute arbitrary code -- only load trusted model files.
for seq_df, seq_desc, seq_bins in (
    (df_high, "High-α bins", bin_info_high),
    (df_low, "Low-α bins", bin_info_low),
):
    for title, mn, mx, pkl_path, colors, labels in tqdm(seq_bins, desc=seq_desc):
        df_bin = filter_metallicity_bin(seq_df, mn, mx)
        with open(pkl_path, 'rb') as f:
            gmm = pickle.load(f)
        plot_gmm_3d_with_weights(df_bin, gmm, colors, labels, title)


High-α bins:  50%|█████     | 2/4 [00:00<00:00,  5.45it/s]

Saved → ../figures/VMP_high___-3[M_H]-2.html
Saved → ../figures/IMP_high___-2[M_H]-1.6.html


High-α bins:  75%|███████▌  | 3/4 [00:00<00:00,  5.52it/s]

Saved → ../figures/MP1_high___-1.6[M_H]-1.3.html


High-α bins: 100%|██████████| 4/4 [00:00<00:00,  4.34it/s]


Saved → ../figures/MP2_high___-1.3[M_H]-1.0.html


Low-α bins:  50%|█████     | 2/4 [00:00<00:00,  5.02it/s]

Saved → ../figures/VMP_low____-3[M_H]-2.html
Saved → ../figures/IMP_low____-2[M_H]-1.6.html


Low-α bins: 100%|██████████| 4/4 [00:00<00:00,  5.12it/s]

Saved → ../figures/MP1_low____-1.6[M_H]-1.3.html
Saved → ../figures/MP2_low____-1.3[M_H]-1.0.html



