In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import torch
from astropy.coordinates import SkyCoord, ICRS, Galactic
from astropy import units as u

In [2]:
def mode_reals(array, bins=50):
    """Estimate the mode of real-valued data from a histogram.

    Parameters
    ----------
    array : array_like
        Values to histogram (passed straight to ``np.histogram``).
    bins : int or sequence, optional
        Bin specification forwarded to ``np.histogram`` (default 50).

    Returns
    -------
    The LEFT edge of the most populated bin. This is a deliberate
    approximation of the bin centre, so the estimate sits up to half a bin
    width below it. On ties the first (lowest) bin wins, as per ``np.argmax``.
    """
    hist, edges = np.histogram(array, bins=bins)
    fullest_bin = np.argmax(hist)
    # edges has one more entry than hist; dropping the last one leaves the left edges.
    left_edges = edges[:-1]
    return left_edges[fullest_bin]

In [3]:
# NOTE(review): absolute, machine-specific path; the notebook only runs where this folder exists.
fin_folder = '/Users/ratzenboe/Documents/work/code/notebooks/SBI/posterior_samples_validation/'
fname = '500clExp_allInfos_Sagitta_like'

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at files you produced yourself or otherwise trust.
samples_path = fin_folder + fname
with open(samples_path, "rb") as pickle_file:
    samples = pickle.load(pickle_file)

In [4]:
# Entry stored under the 'validation_set' key of the loaded pickle. Later cells
# read columns from it (labels, logAge, feh, X/Y, *_obs, ...) and also write
# the X_pred/Y_pred/Z_pred columns back into this same object.
df_val = samples['validation_set']

In [5]:
# The pickle mixes string keys (e.g. 'validation_set') with integer keys.
# max_int is one past the largest integer key, so range(max_int) visits every
# integer key provided they are contiguous from 0 (assumed - TODO confirm).
int_keys = [key for key in samples if isinstance(key, int)]
max_int = np.max(int_keys) + 1

In [11]:
# Statistic used to collapse each '*_post' array into a single point estimate.
f_summary = np.mean

# Per-index entries of the pickle, in integer-key order 0 .. max_int-1.
entries = [samples[idx] for idx in range(max_int)]

# Distances are computed as 1000 / parallax (pc if parallaxes are in mas -
# presumably the case, since later cells attach u.pc to dist_pred).
dist_true = np.array([1000 / entry['plx_true'] for entry in entries])
dist_pred = np.array([f_summary(1000 / entry['plx_post']) for entry in entries])
# NOTE(review): assumes the rows of df_val are in the same order as the integer keys.
dist_obs = 1000 / df_val['parallax_obs'].values

logAge_true = np.array([entry['logAge_true'] for entry in entries])
logAge_pred = np.array([f_summary(entry['logAge_post']) for entry in entries])

feh_pred = np.array([f_summary(entry['feh_post']) for entry in entries])
A_V_pred = np.array([f_summary(entry['A_V_post']) for entry in entries])

# Prepare plots

In [12]:
# Recover the sky direction of every star from its observed Galactic
# Cartesian position (X_obs, Y_obs, Z_obs, taken to be in pc), so the
# predicted distance can be applied along the same line of sight afterwards.
c_gal = SkyCoord(
    frame='galactic',
    representation_type='cartesian',
    u=df_val['X_obs'].values * u.pc,
    v=df_val['Y_obs'].values * u.pc,
    w=df_val['Z_obs'].values * u.pc,
)

# Switch to ICRS and expose the coordinates as spherical components (ra, dec, distance).
c_icrs = c_gal.transform_to(ICRS())
c_icrs.representation_type = 'spherical'

In [13]:
# Keep the observed sky direction (ra, dec) but replace the distance with the
# predicted one, then express the result in Galactic Cartesian coordinates.
c = SkyCoord(
    frame='icrs',
    ra=c_icrs.ra.value * u.deg,
    dec=c_icrs.dec.value * u.deg,
    distance=dist_pred * u.pc,
).transform_to(Galactic())
c.representation_type = 'cartesian'

# Store the predicted positions next to the observed/true ones (modifies df_val in place).
for column, component in (('X_pred', c.u), ('Y_pred', c.v), ('Z_pred', c.w)):
    df_val[column] = component.value

In [14]:
def _scatter_xy_panel(ax, frame, x_col, y_col, bg_mask, sig_mask, title, lim,
                      leftmost=False, **color_range):
    """Draw one X-Y panel: grey background points underneath, label-coloured points on top.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    frame : DataFrame holding ``x_col``, ``y_col`` and a ``labels`` column.
    x_col, y_col : names of the columns to plot.
    bg_mask, sig_mask : boolean masks over ``frame`` selecting the background
        (grey) and signal (coloured by ``labels``) rows.
    title : panel title.
    lim : half-width of the square plotting window, in the data's units.
    leftmost : if True the panel also gets the y label, y ticks and y tick labels.
    **color_range : optional ``vmin``/``vmax`` forwarded to the coloured scatter
        so that every panel maps a given label to the same colour.
    """
    ax.scatter(frame.loc[bg_mask, x_col], frame.loc[bg_mask, y_col],
               s=1, alpha=0.05, c='tab:grey')
    ax.scatter(frame.loc[sig_mask, x_col], frame.loc[sig_mask, y_col],
               s=2, alpha=0.1, cmap='tab10', c=frame.loc[sig_mask, 'labels'],
               **color_range)

    # Tick positions are in the data's units; the labels show them divided by 1000 ("kpc").
    ticks = [-1000, -500, 0, 500, 1000]
    tick_labels = [-1, -0.5, 0, 0.5, 1]

    ax.set_xlabel('X [kpc]', fontsize=15)
    # Only the leftmost panel carries y tick marks and labels; the y axis is shared.
    ax.tick_params(axis='both', which='both', bottom=True, top=False,
                   left=leftmost, right=False, labelbottom=True, labelleft=leftmost)
    if leftmost:
        ax.set_ylabel('Y [kpc]', fontsize=15)
        ax.set_yticks(ticks)
        ax.set_yticklabels(tick_labels, fontsize=13)
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels, fontsize=13)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_title(title, fontsize=15)


fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=False, sharey=True)
lim = 1000

# NOTE(review): label 8 is excluded from every panel; the reason is not recorded here.
cut_labels = df_val.labels != 8
df_plt = df_val.loc[cut_labels]

# labels == -1 is drawn as grey background, everything else as coloured signal.
is_bg = df_plt.labels == -1
is_sig = df_plt.labels != -1

# Pin the colour normalisation to the full signal label range. Without this each
# scatter call autoscales to its own subset, so the age-filtered third panel
# could give the same label a different colour than the first two panels.
sig_labels = df_plt.loc[is_sig, 'labels']
color_range = {} if sig_labels.empty else {'vmin': sig_labels.min(), 'vmax': sig_labels.max()}

_scatter_xy_panel(axes[0], df_plt, 'X_obs', 'Y_obs', is_bg, is_sig,
                  'Observed positions', lim, leftmost=True, **color_range)
_scatter_xy_panel(axes[1], df_plt, 'X', 'Y', is_bg, is_sig,
                  'True positions', lim, **color_range)

# Third panel: keep only stars whose predicted AND catalogued logAge are below th.
th = 8.
# Both operands are restricted to the rows of df_plt: logAge_pred positionally via
# the boolean array, logAge via df_plt itself. The previous version AND-ed with
# df_val.logAge (the full, unfiltered index) and only worked through implicit
# pandas index alignment.
is_young = (df_plt.logAge < th) & (logAge_pred[cut_labels.values] < th)
cut_bg = is_bg & is_young
cut_sig = is_sig & is_young

_scatter_xy_panel(axes[2], df_plt, 'X_pred', 'Y_pred', cut_bg, cut_sig,
                  f'Predicted positions (age < {10**th/1e6:.0f} Myr)', lim, **color_range)

fig.tight_layout(w_pad=0.1, h_pad=.5)
# plt.savefig('/Users/ratzenboe/Desktop/figures/observed_vs_true_vs_predicted_noPlx_100clusters.png', dpi=300)

In [15]:
# True vs observed (left) and true vs predicted (right) distance, restricted
# to rows with logAge < 7.8 and a label above -1.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
lim_x = [0, 1100]
lim_y = [0, 1200]

# Also reused by the next cell.
cut_labels = (df_val.labels > -1) & (df_val.logAge < 7.8)

panels = (
    (dist_obs, 'Observed distance [pc]'),
    (dist_pred, 'Predicted distance [pc]'),
)
for ax, (dist_estimate, y_label) in zip(axes, panels):
    ax.scatter(dist_true[cut_labels], dist_estimate[cut_labels], s=10, alpha=0.1)
    ax.set_xlabel('True distance [pc]')
    ax.set_ylabel(y_label)
    ax.set_xlim(lim_x)
    ax.set_ylim(lim_y)

plt.tight_layout()

In [16]:
# Residual (true - observed distance) against the 'feh' column for the rows
# selected by cut_labels in the previous cell; the horizontal line marks zero residual.
feh_selected = df_val.loc[cut_labels, 'feh'].values
dist_residual = dist_true[cut_labels] - dist_obs[cut_labels]

plt.scatter(feh_selected, dist_residual, s=5, alpha=0.01)
plt.ylim(-300, 300)
plt.axhline(0)

In [17]:
# Same comparison as the two-panel figure above, but for every star in the
# validation set (no age/label cut) and with a taller y range.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
lim_x = [0, 1100]
lim_y = [0, 1500]

panels = (
    (dist_obs, 'Observed distance [pc]'),
    (dist_pred, 'Predicted distance [pc]'),
)
for ax, (dist_estimate, y_label) in zip(axes, panels):
    ax.scatter(dist_true, dist_estimate, s=1, alpha=0.1)
    ax.set_xlabel('True distance [pc]')
    ax.set_ylabel(y_label)
    ax.set_xlim(lim_x)
    ax.set_ylim(lim_y)

plt.tight_layout()