In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
import sbi.utils as utils
from astropy.coordinates import SkyCoord, ICRS, Galactic
from astropy import units as u
import pickle
from sbi import analysis as analysis
from sbi.inference.base import infer

In [2]:
# Load data: mock catalogue of field stars + clusters (per the file name, a mock EDR3 sample)
csv_path = '/Users/ratzenboe/Library/CloudStorage/Dropbox/work/data/mock_edr3/edr3_mock_field+clusters.csv'
df = pd.read_csv(csv_path)

In [3]:
df['logAge'].hist(log=True, bins=100);  # log-scaled counts: age distribution is very uneven

In [4]:
# Every measured quantity contributes an observed value (`*_obs`) and its uncertainty
# (`*_error`); all observed values come first, then all errors, in the same band order.
measured_quantities = [
    'parallax',
    'phot_g_mean_mag', 'phot_bp_mean_mag', 'phot_rp_mean_mag',
    'j', 'h', 'k',
    'w1', 'w2', 'w3', 'w4',
    'irac1', 'irac2', 'irac3', 'irac4', 'mips1',
]
features_X = (
    [f'{quantity}_obs' for quantity in measured_quantities]
    + [f'{quantity}_error' for quantity in measured_quantities]
)

# Parameters the posterior is inferred over.
features_y = ['parallax', 'logAge', 'A_V', 'feh']

In [8]:
df.loc[:, features_y].min()  # per-target minimum (inspect before building the prior)

In [9]:
# Raw (un-normalized) float32 tensors: observables + errors as x, targets as theta.
x_orig, theta_orig = (
    torch.tensor(df[columns].to_numpy(dtype=np.float32))
    for columns in (features_X, features_y)
)

# Normalize the data

In [10]:
def _zscore_columns(values):
    """Z-score each column (dim 0 = samples); return (scaled, mean, std)."""
    center = values.mean(dim=0)
    scale = values.std(dim=0)
    return (values - center) / scale, center, scale


x_samples, x_mean, x_std = _zscore_columns(x_orig)
theta_samples, theta_mean, theta_std = _zscore_columns(theta_orig)

# Train and validation set

In [11]:
# Split the normalized samples into a training and a validation part.
val_fraction = 0.5
n_samples_val = int(val_fraction * len(x_samples))
n_samples_train = len(x_samples) - n_samples_val

dataset = TensorDataset(x_samples, theta_samples)
# Seeded generator so the split (and everything downstream) is reproducible
# after a kernel restart.
split_generator = torch.Generator().manual_seed(42)
dataset_train, dataset_val = random_split(
    dataset, [n_samples_train, n_samples_val], generator=split_generator
)

# `Subset.dataset` is the *full* underlying TensorDataset, so
# `dataset_train.dataset.tensors` would return every sample (train + validation)
# and leak the validation set into training. Index with each subset's own
# (positional) indices instead.
train_idx = torch.as_tensor(dataset_train.indices)
val_idx = torch.as_tensor(dataset_val.indices)
x_train, theta_train = x_samples[train_idx], theta_samples[train_idx]
x_val, theta_val = x_samples[val_idx], theta_samples[val_idx]

# Define priors

In [12]:
def _normalized_bound(raw_bound):
    """Cast a per-parameter pandas Series to float32 and z-score it with theta's stats."""
    bound = torch.tensor(raw_bound.values.astype(np.float32))
    return (bound - theta_mean) / theta_std


# Uniform box prior spanning the empirical range of every target, in normalized units.
theta_mins = _normalized_bound(df[features_y].min())
theta_maxs = _normalized_bound(df[features_y].max())
prior = utils.BoxUniform(low=theta_mins, high=theta_maxs)

In [14]:
from sbi.inference import SNPE

# Neural posterior estimation: the prior bounds the (normalized) parameter space and
# the training pairs come from the training split.
inference = SNPE(prior=prior)
inference.append_simulations(theta=theta_train, x=x_train)

In [16]:
%%time
# Train the SNPE density estimator on the (theta, x) pairs appended to `inference`;
# the returned estimator is passed to `inference.build_posterior`.
density_estimator = inference.train()

# Build the posterior

In [17]:
# Default sampling method; an MCMC-based posterior (sample_with='mcmc') was tried earlier.
posterior = inference.build_posterior(density_estimator=density_estimator)

In [18]:
# Persist the trained posterior so it can be reloaded without retraining.
fname = "/Users/ratzenboe/Documents/work/code/notebooks/SBI/trained_models/model_new_withPlx.pkl"
with open(fname, "wb") as model_file:
    pickle.dump(posterior, model_file)

In [19]:
%%time

post_info = {}

i = 0
# i_max = 50_000
for x_val, theta_val in dataset_val:
    posterior_samples = posterior.sample((1_000,), x=x_val, show_progress_bars=False)
    plx_post, logAge_post, A_V_post, feh_post = (posterior_samples * theta_std + theta_mean).numpy().T
    plx_true, logAge_true, A_V_true, feh_true = (theta_val * theta_std + theta_mean).numpy().T

    post_info[i] = {
        'post_samples': posterior_samples * theta_std + theta_mean,
        'true': theta_val * theta_std + theta_mean,
        'plx_post': plx_post,
        'plx_true': plx_true,
        'logAge_post': logAge_post,
        'logAge_true': logAge_true,
        'A_V_post': A_V_post,
        'A_V_true': A_V_true,
        'feh_post': feh_post,
        'feh_true': feh_true
    }
    i+=1

    # if i >= i_max:
    #     break

In [20]:
# Store per-star posterior samples and true parameters of the validation set.
validation_fname = "/Users/ratzenboe/Documents/work/code/notebooks/SBI/trained_models/validation_est_NEW_with_plx_infos.pkl"
with open(validation_fname, "wb") as validation_file:
    pickle.dump(post_info, validation_file)

In [21]:
import corner

In [22]:
# # Plot posterior samples
# i = 50
# corner.corner(post_info[i]['post_samples'].numpy(), labels=features_y, truths=post_info[i]['true'].numpy());

In [23]:
# Per-star point estimates (posterior medians) and truths for the validation set.
# Distances are 1000 / parallax, i.e. parallax assumed in mas -> distance in pc.
dist_pred = np.array([np.median(1000 / pi['plx_post']) for pi in post_info.values()])
dist_true = np.array([1000 / pi['plx_true'] for pi in post_info.values()])

# `dataset_val.indices` are *positional* row indices into the tensors built from `df`,
# so select the observed parallaxes positionally; `df.loc` would treat them as index
# labels and silently pick the wrong rows for any non-default index.
val_positions = np.asarray(dataset_val.indices)
dist_obs = 1000 / df['parallax_obs'].to_numpy()[val_positions]

logAge_pred = np.array([np.median(pi['logAge_post']) for pi in post_info.values()])
logAge_true = np.array([pi['logAge_true'] for pi in post_info.values()])

In [24]:
# True distance vs. (left) naive 1/parallax distance and (right) posterior-median distance.
lim_x = [0, 1100]
lim_y = [0, 1500]
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

panel_specs = zip(
    axes,
    (dist_obs, dist_pred),
    ('Observed distance [pc]', 'Predicted distance [pc]'),
)
for ax, dist_y, y_label in panel_specs:
    ax.scatter(dist_true, dist_y, s=1, alpha=0.1)
    ax.set_xlabel('True distance [pc]')
    ax.set_ylabel(y_label)
    ax.set_xlim(lim_x)
    ax.set_ylim(lim_y)

plt.tight_layout()

In [25]:
# Compute predicted X,Y,Z positions
# Validation rows of the catalogue, in the same order as dataset_val / post_info.
# `.iloc` because random_split indices are positions, not index labels; `.copy()` so
# that adding prediction columns to df_val does not trigger SettingWithCopy issues.
df_val = df.iloc[dataset_val.indices].copy()

# Observed Galactic Cartesian positions -> ICRS, to recover each star's sky direction.
c_gal = SkyCoord(
    u=df_val.X_obs.values * u.pc,
    v=df_val.Y_obs.values * u.pc,
    w=df_val.Z_obs.values * u.pc,
    frame='galactic',
    representation_type='cartesian'
)

c_icrs = c_gal.transform_to(ICRS())
c_icrs.representation_type = 'spherical'

In [26]:
# Keep each star's observed sky direction but use the posterior-median distance, then
# convert back to Galactic Cartesian coordinates for the X/Y/Z_pred columns.
icrs_with_pred_dist = SkyCoord(
    ra=c_icrs.ra.value * u.deg,
    dec=c_icrs.dec.value * u.deg,
    distance=dist_pred * u.pc,
    frame='icrs',
)
c = icrs_with_pred_dist.transform_to(Galactic())
c.representation_type = 'cartesian'
for axis_name, component in zip('XYZ', (c.u, c.v, c.w)):
    df_val[f'{axis_name}_pred'] = component.value

In [38]:
# Three-panel comparison of Galactic X/Y positions in the validation set: observed,
# true, and predicted (posterior-median distance); the predicted panel only shows stars
# whose predicted logAge passes the "young" cut. Field stars (labels == -1) are grey,
# cluster members are coloured by label.
lim = 1000
cmap = 'cool_r'
young_logage_cut = 7.8  # log10(age / yr); 10**7.8 yr ~ 63 Myr
kpc_ticks = [-1000, -500, 0, 500, 1000]
kpc_tick_labels = [-1, -0.5, 0, 0.5, 1]

df_plt = df_val.copy()
is_field = (df_plt.labels == -1).to_numpy()
is_young = logAge_pred < young_logage_cut
cut_bg = is_field & is_young
cut_sig = ~is_field & is_young


def _scatter_field_and_members(ax, frame, xcol, ycol, field_mask, member_mask):
    """Scatter field stars in grey and cluster members coloured by their label."""
    ax.scatter(frame.loc[field_mask, xcol], frame.loc[field_mask, ycol],
               s=1, alpha=0.05, c='tab:grey')
    ax.scatter(frame.loc[member_mask, xcol], frame.loc[member_mask, ycol],
               s=2, alpha=0.1, cmap=cmap, c=frame.loc[member_mask, 'labels'])


# Use the axes returned by subplots directly (no plt.subplot re-selection, whose
# reuse-vs-recreate behaviour depends on the matplotlib version and can drop sharey).
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=False, sharey=True)
panels = [
    ('X_obs', 'Y_obs', is_field, ~is_field, 'Observed positions'),
    ('X', 'Y', is_field, ~is_field, 'True positions'),
    ('X_pred', 'Y_pred', cut_bg, cut_sig,
     f'Predicted positions (age < {10 ** young_logage_cut / 1e6:.0f} Myr)'),
]
for panel_idx, (ax, (xcol, ycol, field_mask, member_mask, title)) in enumerate(zip(axes, panels)):
    _scatter_field_and_members(ax, df_plt, xcol, ycol, field_mask, member_mask)
    ax.set_xlabel('X [kpc]', fontsize=15)
    ax.set_xticks(kpc_ticks)
    ax.set_xticklabels(kpc_tick_labels, fontsize=13)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_title(title, fontsize=15)
    is_first = panel_idx == 0  # only the left-most panel shows y ticks/labels
    ax.tick_params(axis='both', which='both', bottom=True, top=False, left=is_first,
                   right=False, labelbottom=True, labelleft=is_first)

axes[0].set_ylabel('Y [kpc]', fontsize=15)
axes[0].set_yticks(kpc_ticks)
axes[0].set_yticklabels(kpc_tick_labels, fontsize=13)

fig.tight_layout(w_pad=0.1, h_pad=.5)
fig.savefig('/Users/ratzenboe/Desktop/figures/observed_vs_true_vs_predicted_new.png', dpi=300)

In [59]:
# Per-star ratio of the naive (1/parallax) distance error to the posterior-median
# distance error, plotted against true distance.
obs_dist_error = dist_true - dist_obs
pred_dist_error = dist_true - dist_pred
x_plt = obs_dist_error / pred_dist_error
plt.scatter(dist_true, x_plt, s=1, alpha=0.1)
plt.ylim(-5, 5)
plt.axhline(1.12, c='k', ls='--')  # NOTE(review): meaning of the 1.12 reference level is undocumented — TODO

In [49]:
# Same distance comparison restricted to cluster members (labels > -1), coloured by
# label, with a faint y = x reference line.
lim_x = [0, 1100]
lim_y = [0, 1300]
cut_labels = df_val.labels > -1
member_labels = df_val.loc[cut_labels, 'labels']

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, dist_y, y_label in zip(
    axes,
    (dist_obs[cut_labels], dist_pred[cut_labels]),
    ('Observed distance [pc]', 'Predicted distance [pc]'),
):
    ax.plot(lim_y, lim_y, c='k', alpha=0.1)
    ax.scatter(dist_true[cut_labels], dist_y, s=10, alpha=0.1, c=member_labels, cmap=cmap)
    ax.set_xlabel('True distance [pc]')
    ax.set_ylabel(y_label)
    ax.set_xlim(lim_x)
    ax.set_ylim(lim_y)

plt.tight_layout()
plt.savefig('/Users/ratzenboe/Desktop/figures/distance_comparison.png', dpi=300)

In [33]:
# Per-cluster view: predicted (posterior-median) age vs. predicted distance for the
# members of each cluster label 0-9 in the validation set. Dashed lines mark the
# cluster's median true age and the distance implied by its median true parallax.
for label_i in np.arange(0, 10):
    cut_l_i = (df_val.labels == label_i).to_numpy()
    if not cut_l_i.any():
        continue  # label absent from this validation split: nothing to plot
    members = df_val.loc[cut_l_i]

    age_true_myr = 10 ** members['logAge'].median() / 1e6
    dist_true_pc = 1000 / members['parallax'].median()
    A_V_i = members['A_V'].median()
    feh_i = members['feh'].median()

    fig, ax = plt.subplots()
    ax.scatter(10 ** logAge_pred[cut_l_i] / 1e6, dist_pred[cut_l_i],
               label=f'A_V: {A_V_i:.1f} | FeH: {feh_i:.1f}', s=10, alpha=0.5)
    ax.axvline(age_true_myr, c='k', ls='--', lw=1)
    ax.axhline(dist_true_pc, c='k', ls='--', lw=1)
    # The x-axis is linear age (10**logAge / 1e6), not logAge.
    ax.set_xlabel('Predicted age [Myr]')
    ax.set_ylabel('Predicted distance [pc]')
    ax.set_title(f'Label {label_i}')
    ax.legend()
    ax.set_xlim(0, 80)
    ax.set_ylim(0, 1000)
    plt.show()

In [32]:
# NOTE(review): `A_V_true` is whatever the posterior-sampling loop assigned last, i.e.
# the true A_V of the final validation star only (a scalar), not an array over the
# validation set. Likely a leftover inspection cell.
A_V_true

In [160]:
# Corner plot of one validation star's posterior: the member with 0-based rank 3 of
# cluster label 3. post_info keys are positions within the validation set, which
# follow df_val's row order.
label_to_show, member_rank = 3, 3
i = np.flatnonzero((df_val.labels == label_to_show).to_numpy())[member_rank]
samples_i, truth_i = post_info[i]['post_samples'], post_info[i]['true']
corner.corner(samples_i.numpy(), labels=features_y, truths=truth_i.numpy());

In [143]:
# For each label (-1 = field, 0-9 = clusters), the fraction of its validation members
# whose predicted logAge passes the young cut (< 8, i.e. < 100 Myr).
# NOTE: despite the name, `avg_purity` holds a completeness/recall per label, not a
# purity. Entry 0 is the field (label -1).
young_logage_cut_recall = 8
labels_young = df_val.loc[logAge_pred < young_logage_cut_recall, 'labels']  # hoisted: loop-invariant

avg_purity = []
for label_i in np.arange(-1, 10):
    nb_identified = np.sum(labels_young == label_i)
    nb_total = np.sum(df_val['labels'] == label_i)
    # Labels absent from the validation split give NaN instead of a 0/0 warning.
    frac_identified = nb_identified / nb_total if nb_total > 0 else np.nan
    avg_purity.append(frac_identified)
    print(f'Label {label_i}: {nb_identified} / {nb_total} ({frac_identified*100:.2f}%)')

In [144]:
# Median recovery fraction over the clusters only (entry 0 is the field, label -1);
# nanmedian so a label absent from the validation split does not turn the result NaN.
np.nanmedian(avg_purity[1:])

In [51]:
# Posterior-median vs. true logAge for the validation set; dashed line is y = x.
ax = plt.gca()
ax.scatter(logAge_true, logAge_pred, s=1, alpha=0.1)
ax.set_xlabel('True logAge')
ax.set_ylabel('Predicted logAge')
ax.plot([6, 10], [6, 10], 'k--')