In [168]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from astropy.coordinates import SkyCoord, ICRS, Galactic
from astropy import units as u
import corner
import pymc as pm

In [2]:
def mode_reals(array, bins=50):
    """Estimate the mode of a sample of real numbers from a histogram.

    Parameters
    ----------
    array : array_like
        Sample values; histogrammed with ``np.histogram``.
    bins : int or sequence, optional
        Binning specification forwarded to ``np.histogram`` (default 50).

    Returns
    -------
    float
        Centre of the most populated bin. If several bins share the maximum
        count, the lowest one is used (``np.argmax`` returns the first hit).
    """
    counts, bin_edges = np.histogram(array, bins=bins)
    # Use the bin centre; the left edge would bias the estimate low by
    # half a bin width.
    bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    return bin_centres[np.argmax(counts)]

# Load SBI posterior samples

In [3]:
# Folder and file name of the pickled posterior samples
fin_folder = '/Users/ratzenboe/Documents/work/code/notebooks/SBI/posterior_samples_validation/'
fname = '500clExp_allInfos_Sagitta_like_unseen_clusters.pkl'

# NOTE: pickle executes arbitrary code on load; only open files you trust.
samples_path = fin_folder + fname
with open(samples_path, mode="rb") as pkl_file:
    samples = pickle.load(pkl_file)

In [4]:
# Validation table stored under the 'validation_set' key (used below as a pandas DataFrame)
df_val = samples['validation_set']

In [5]:
# Only the integer keys of `samples` are counted; one past the largest is the loop bound
int_keys = [key for key in samples if isinstance(key, int)]
max_int = np.max(int_keys) + 1

In [6]:
# Statistic that collapses each posterior sample set to one number
f_summary = np.median

sample_ids = np.arange(max_int)

# Distances are computed as 1000 / parallax (treated as pc further below)
dist_pred = np.array([f_summary(1000 / samples[k]['plx_post']) for k in sample_ids])
dist_true = np.array([1000 / samples[k]['plx_true'] for k in sample_ids])
dist_obs = 1000 / df_val['parallax_obs'].values

# Summaries of the remaining posterior parameters, one value per integer key
logAge_pred, feh_pred, A_V_pred = (
    np.array([f_summary(samples[k][post_key]) for k in sample_ids])
    for post_key in ('logAge_post', 'feh_post', 'A_V_post')
)
logAge_true = np.array([samples[k]['logAge_true'] for k in sample_ids])

# Load Sagitta estimates

In [7]:
# Sagitta age arrays (median and std, per the file names), read from the working directory
age_pred_sagitta, age_pred_sagitta_std = (
    np.load(npy_name) for npy_name in ('age_sagitta_median.npy', 'age_sagitta_std.npy')
)

# Prepare data

In [8]:
# Observed Galactic Cartesian positions (pc), converted to ICRS so that the
# sky direction (ra, dec) of every source can be read off afterwards.
obs_xyz_pc = {
    component: df_val[column].values * u.pc
    for component, column in (('u', 'X_obs'), ('v', 'Y_obs'), ('w', 'Z_obs'))
}
c_gal = SkyCoord(frame='galactic', representation_type='cartesian', **obs_xyz_pc)

c_icrs = c_gal.transform_to(ICRS())
# Spherical representation exposes .ra / .dec
c_icrs.representation_type = 'spherical'

In [9]:
# Keep the observed sky direction, place each source at its predicted distance,
# and convert back to Galactic Cartesian coordinates.
c = SkyCoord(
    ra=c_icrs.ra.value * u.deg,
    dec=c_icrs.dec.value * u.deg,
    distance=dist_pred * u.pc,
    frame='icrs',
)
c = c.transform_to(Galactic())
c.representation_type = 'cartesian'

# Store the predicted positions next to the observed/true ones
for column, component in (('X_pred', c.u), ('Y_pred', c.v), ('Z_pred', c.w)):
    df_val[column] = component.value

# Plot comparison in XY plane 

In [10]:
# Construct coolwarm color map with transition at given threshold
import matplotlib

def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero.

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax / (vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name given to the new colormap and under which it is
          registered in `matplotlib.colormaps`. An existing
          registration with the same name is overwritten.

    Returns
    -------
      The new `LinearSegmentedColormap`.
    '''
    channels = ('red', 'green', 'blue', 'alpha')
    cdict = {channel: [] for channel in channels}

    # regular index to compute the colors
    reg_index = np.linspace(start, stop, 257)

    # shifted index to match the data: 128 steps below the midpoint,
    # 129 from the midpoint upwards (257 in total, matching reg_index)
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        # cmap(ri) yields (r, g, b, a); each becomes one anchor point at si
        for channel, value in zip(channels, cmap(ri)):
            cdict[channel].append((si, value, value))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
    # force=True keeps the cell re-runnable: without it a second call with the
    # same name raises ValueError because the name is already registered.
    matplotlib.colormaps.register(cmap=newcmap, force=True)
    return newcmap

# Base colormap whose centre is moved to the transition age
orig_cmap = matplotlib.cm.coolwarm

# Colour scale limits and the logAge at which the colormap switches sides
min_age = 6.2
max_age = 9
transition_age = 7.78

# Position of the transition age as a fraction of the [min_age, max_age] range
age_span = max_age - min_age
midpoint = (transition_age - min_age) / age_span
shifted_cmap = shiftedColorMap(orig_cmap, midpoint=midpoint, name='shifted_coolwarm_60_6p2-9')

In [11]:
def _scatter_xy(ax, df, bg_mask, sig_mask, x_col, y_col):
    """Scatter two row selections of `df` on `ax`, coloured by the `logAge` column.

    Rows in `bg_mask` are drawn small and faint, rows in `sig_mask` larger and
    more opaque. Both use the shared shifted colormap and colour limits.
    Returns the collection of the `sig_mask` points (used for the colourbar).
    """
    colour_kw = dict(cmap=shifted_cmap, vmin=min_age, vmax=max_age)
    ax.scatter(
        df.loc[bg_mask, x_col], df.loc[bg_mask, y_col],
        s=1, alpha=0.1, c=df.loc[bg_mask, 'logAge'], **colour_kw
    )
    return ax.scatter(
        df.loc[sig_mask, x_col], df.loc[sig_mask, y_col],
        s=5, alpha=0.7, c=df.loc[sig_mask, 'logAge'], **colour_kw
    )


lim = 1000  # half-width of the plotted XY window in data units (pc); tick labels show kpc
th = 7.78   # logAge threshold used by both age filters
df_plt = df_val

# Row selections: labels == -1 vs. everything else, optionally restricted to
# sources whose predicted logAge is below the threshold.
is_bg = df_plt.labels == -1
is_member = df_plt.labels != -1
cut_bg_sbi = is_bg & (logAge_pred < th)
cut_sig = is_member & (logAge_pred < th)
cut_bg_sag = is_bg & (age_pred_sagitta < th)
cut_sig_sag = is_member & (age_pred_sagitta < th)

# (background mask, member mask, x column, y column, title) for each panel.
# The Sagitta panel uses the observed positions for BOTH selections; the
# background used to be drawn at the SBI-inferred X_pred/Y_pred, inconsistent
# with the members shown in the same panel.
panels = [
    (is_bg, is_member, 'X_obs', 'Y_obs', 'Observed positions'),
    (is_bg, is_member, 'X', 'Y', 'True positions'),
    (cut_bg_sbi, cut_sig, 'X_pred', 'Y_pred',
     f'YSO filter SBI + inferred plx (age < {10**th/1e6:.0f} Myr)'),
    (cut_bg_sag, cut_sig_sag, 'X_obs', 'Y_obs',
     f'YSO filter Sagitta (age < {10**th/1e6:.0f} Myr)'),
]

tick_positions = [-1000, -500, 0, 500, 1000]
tick_labels = [-1, -0.5, 0, 0.5, 1]

fig, axes = plt.subplots(1, 4, figsize=(20, 5), sharex=False, sharey=True)

for panel_no, (ax, (bg_mask, sig_mask, x_col, y_col, title)) in enumerate(zip(axes, panels)):
    scatter_plt = _scatter_xy(ax, df_plt, bg_mask, sig_mask, x_col, y_col)

    # Only the leftmost panel carries the y-axis label, ticks and tick labels
    leftmost = panel_no == 0
    ax.set_xlabel('X [kpc]', fontsize=15)
    if leftmost:
        ax.set_ylabel('Y [kpc]', fontsize=15)
    ax.tick_params(
        axis='both', which='both', bottom=True, top=False, left=leftmost, right=False,
        labelbottom=True, labelleft=leftmost
    )
    if leftmost:
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels, fontsize=13)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, fontsize=13)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_title(title, fontsize=15)

# Colourbar in its own axes to the right of the panels; the mappable is passed
# explicitly instead of relying on pyplot's "current image".
cbar_ax = fig.add_axes([1.005, 0.14, 0.01, 0.78])
cbar = fig.colorbar(scatter_plt, cax=cbar_ax, ticks=[6.5, 7, 7.5, 8., 8.5, 9])
cbar.ax.tick_params(labelsize=15)
cbar.set_label(label='log(Age)', size=15, labelpad=3)
# The scatter is semi-transparent; draw the colourbar fully opaque
cbar.solids.set(alpha=1)

fig.tight_layout(w_pad=0.1, h_pad=.5)
fig.savefig('/Users/ratzenboe/Desktop/figures/observed_vs_true_vs_predicted_noPlx_unseenClusters.png', dpi=300, bbox_inches='tight')

In [12]:
# # Plot Sagitta vs SBI ages
# plt.subplots(1, 2, figsize=(13, 6))
# s = 2
# plt.subplot(1, 2, 1)
# xlim = 5, 10.1
# plt.scatter(df_val.logAge, age_pred_sagitta, s=s, alpha=0.2, label='Sagitta', c='k')
# plt.plot(xlim, xlim, 'k--')
# plt.xlim(xlim)
# plt.ylim(xlim)
# plt.xlabel('True log(Age)', fontsize=15)
# plt.ylabel('Sagitta log(Age)', fontsize=15)
# plt.title('Sagitta vs True', fontsize=15)
# # Equal aspect ratio
# plt.gca().set_aspect('equal', adjustable='box')
# 
# plt.subplot(1, 2, 2)
# plt.scatter(df_val.logAge, logAge_pred, s=s, alpha=0.2, label='SBI', c='k')
# plt.plot(xlim, xlim, 'k--')
# plt.xlim(xlim)
# plt.ylim(xlim)
# plt.xlabel('True log(Age)', fontsize=15)
# plt.ylabel('SBI log(Age)', fontsize=15)
# plt.title('SBI vs True', fontsize=15)
# # Equal aspect ratio
# plt.gca().set_aspect('equal', adjustable='box')
# 
# # plt.tight_layout()
# 
# plt.savefig('/Users/ratzenboe/Desktop/figures/Sagitta_vs_SBI_vs_true_ages_unseenClusters.png', dpi=300, bbox_inches='tight')

In [265]:
# Positional indices of sources with 7.2 < true logAge < 7.78 and labels != -1
# whose observed distance is off by more than 50 while the predicted distance
# (within 20) and predicted logAge (within 0.2) are close to the true values.
true_dist = 1000 / df_val.parallax
age_in_range = (df_val.logAge < 7.78) & (df_val.logAge > 7.2)
is_labelled = df_val.labels != -1
obs_dist_off = np.abs(1000 / df_val.parallax_obs - true_dist) > 50
pred_dist_close = np.abs(dist_pred - true_dist) < 20
pred_age_close = np.abs(df_val.logAge - logAge_pred) < 0.2

selection = age_in_range & is_labelled & obs_dist_off & pred_dist_close & pred_age_close
int_idx = np.arange(df_val.shape[0])[selection]

In [266]:
# Number of sources that pass the selection
len(int_idx)

In [267]:
# Pick one of the selected sources and convert its posterior draws:
# logAge -> age in Myr, parallax -> distance (1000 / plx)
i = int_idx[23]
posterior_i = samples[i]
post_samples_age = 10**posterior_i['logAge_post'] / 1e6
post_samples_dist = 1000 / posterior_i['plx_post']

# Stack as (age, distance) columns for the corner plot
ps = np.transpose(np.vstack((post_samples_age, post_samples_dist)))

# Displayed as the cell output: value ranges of both quantities
(
    post_samples_age.min(), post_samples_age.max(),
    post_samples_dist.min(), post_samples_dist.max(),
)

In [286]:
# Compare posterior samples with true values for source `i`.
# `i` is a positional row number (the selection was built with np.arange),
# so rows are addressed with .iloc rather than by index label.
true_logAge = df_val['logAge'].iloc[i]
true_plx = df_val['parallax'].iloc[i]
obs_plx = df_val['parallax_obs'].iloc[i]
obs_plx_err = df_val['parallax_error'].iloc[i]

truths = [10**true_logAge/1e6, 1000/true_plx]

fig = corner.corner(
    ps, labels=['Age [Myr]', 'Distance [pc]'], truths=truths, color='tab:grey', truth_color='k',
    smooth1d=0.05, smooth=1,
    range=[(0, 100), (100, 1500)]
)

allaxes = fig.get_axes()
age_ax, dist_ax = allaxes[0], allaxes[-1]  # 1-D panels: first = age, last = distance

# Add observed values in distance (blue): propagate the parallax uncertainty by
# sampling. A seeded generator keeps the figure identical between runs.
rng = np.random.default_rng(0)
obs_dist_draws = 1000/rng.normal(loc=obs_plx, scale=obs_plx_err, size=1000)
lo, med, hi = np.percentile(obs_dist_draws, [16, 50, 84])
dist_ax.axvspan(lo, hi, alpha=0.3, color='tab:blue', )
dist_ax.axvline(med, color='tab:blue', ls='--')

# Posterior estimate in distance (orange): 16/50/84 percentiles.
# The median line is drawn at the median itself; it used to carry a
# hard-coded +20 offset that misplaced it relative to the shaded band.
lo_est, med_est, hi_est = np.percentile(post_samples_dist, [16, 50, 84])
dist_ax.axvspan(lo_est, hi_est, alpha=0.3, color='tab:orange', )
dist_ax.axvline(med_est, color='tab:orange', ls='--')

# Add Sagitta estimate in age (green): 10**(logAge +/- std), converted to Myr
med_sag = 10**age_pred_sagitta[i]/1e6
lo_sag = 10**(age_pred_sagitta[i]-age_pred_sagitta_std[i])/1e6
hi_sag = 10**(age_pred_sagitta[i]+age_pred_sagitta_std[i])/1e6
age_ax.axvspan(lo_sag, hi_sag, alpha=0.3, color='tab:green', )
age_ax.axvline(med_sag, color='tab:green', ls='--')

# Posterior estimate in age (orange): histogram mode and highest-density interval.
# NOTE(review): 0.64 differs from the 68% (16-84) bands used for the other
# estimates in this figure -- confirm whether 0.68 was intended.
med_est = mode_reals(post_samples_age, bins=150)
lo_est, hi_est = pm.hdi(post_samples_age, 0.64)
age_ax.axvspan(lo_est, hi_est, alpha=0.3, color='tab:orange')
age_ax.axvline(med_est, color='tab:orange', ls='--')

plt.savefig(f'/Users/ratzenboe/Desktop/figures/posterior_samples_age_dist_unseenClusters_{i}.png', dpi=300, bbox_inches='tight')
i

In [None]:
# Plot SNR vs accuracy
