In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

# Load the data

In [2]:
# Mock Gaia EDR3 catalogue (field stars + unseen clusters) used for evaluation.
# Alternative input file (the "_sagitta_sagitta" variant of the same catalogue):
# fname_in = '/Users/ratzenboe/Library/CloudStorage/Dropbox/work/data/mock_edr3/edr3_mock_field_UNSEEN_clusters_TEST_sagitta_sagitta.csv'
fname_in = '/Users/ratzenboe/Library/CloudStorage/Dropbox/work/data/mock_edr3/edr3_mock_field_UNSEEN_clusters_TEST.csv'
df = pd.read_csv(filepath_or_buffer=fname_in)

In [3]:
# Number of rows flagged for training/validation (presumably a boolean mask column).
df['train_val_samples'].sum()

In [34]:
# Rows of the test split; every downstream array follows this row order.
df_test = df.loc[df['test_samples']]

# Sample from the posterior of each trained model

In [20]:
# Folder holding the pickled posteriors, scale factors and feature lists per model.
fin_folder = '/Users/ratzenboe/Documents/work/code/notebooks/SBI/trained_models/'
# Folder the per-model test-set posterior samples are written to.
fout_folder = '/Users/ratzenboe/Library/CloudStorage/Dropbox/work/data/posterior_samples/test_posterior_samples/'

# Models to evaluate. Other selectable variants:
#   'X_allFeatures__y_parallax_logAge'
#   'X_Sagitta__y_parallax_logAge'
model_names = ['X_SEDonly__y_parallax_logAge']

In [21]:
n_samples = 500      # posterior draws per test star
sampling_seed = 42   # seeds the torch RNG so posterior draws are reproducible on re-run

# Create the output folder up front so a missing directory cannot crash the
# cell only after the (slow) sampling loop has finished.
os.makedirs(fout_folder, exist_ok=True)

for features_str in model_names:
    # NOTE: pickle.load can execute arbitrary code -- only load trusted model files.
    # Trained posterior; its .sample(shape, x=..., show_progress_bars=...) is used below.
    with open(fin_folder + f"posterior_{features_str}.pkl", "rb") as handle:
        posterior = pickle.load(handle)
    # Scale factors: x_mean/x_std and theta_mean/theta_std used for standardisation.
    with open(fin_folder + f"scale_factors_{features_str}.pkl", "rb") as handle:
        scale_factors = pickle.load(handle)
    # Feature names: 'X' lists the input columns, 'y' the target (theta) columns.
    with open(fin_folder + f'features_{features_str}.pkl', 'rb') as handle:
        features_X_y = pickle.load(handle)

    # ----- Prepare the data -----
    x_samples = torch.tensor(df.loc[df.test_samples, features_X_y['X']].values.astype(np.float32))
    theta_samples = torch.tensor(df.loc[df.test_samples, features_X_y['y']].values.astype(np.float32))
    # Standardise the observations with the training-time scale factors.
    x_test = (x_samples - scale_factors['x_mean']) / scale_factors['x_std']
    # True parameters straight from the catalogue; standardising and then
    # un-standardising them would only add float round-off.
    theta_true = theta_samples.numpy()

    # ----- Sample from the posterior -----
    print(f"Sampling from posterior for model: {features_str}")
    torch.manual_seed(sampling_seed)
    idx_subset = np.arange(x_test.shape[0])  # [:n_samples]
    # Keys are positional indices into df.loc[df.test_samples], NOT the `id` column.
    post_info = {}
    for idx_i in tqdm(idx_subset):
        posterior_samples = posterior.sample((n_samples,), x=x_test[idx_i], show_progress_bars=False)
        # Undo the standardisation; the transpose gives one row per parameter
        # (presumably samples are (n_samples, n_theta) -- TODO confirm).
        # Assumes features_X_y['y'] is ordered (parallax, logAge), as the model name suggests.
        plx_post, logAge_post = (
            posterior_samples * scale_factors['theta_std'] + scale_factors['theta_mean']
        ).numpy().T
        plx_true, logAge_true = theta_true[idx_i]

        post_info[idx_i] = {
            'plx_post': plx_post,
            'plx_true': plx_true,
            'logAge_post': logAge_post,
            'logAge_true': logAge_true,
        }

    with open(fout_folder + f"posterior_samples_test_{features_str}.pkl", "wb") as handle:
        pickle.dump(post_info, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Quick results plot

In [22]:
# Reload the saved test-set posterior samples of the first evaluated model.
features_str = model_names[0]
samples_path = fout_folder + f"posterior_samples_test_{features_str}.pkl"
with open(samples_path, "rb") as handle:
    post_info = pickle.load(handle)

print(features_str)

In [23]:
# Show the file the samples were loaded from.
f"{fout_folder}posterior_samples_test_{features_str}.pkl"

In [24]:
def mode_reals(array, bins=150):
    """Estimate the mode of a sample of real numbers from its histogram.

    Parameters
    ----------
    array : array_like
        Sample values; flattened before binning. Non-finite entries (NaN,
        +/-inf, e.g. from 1000/parallax with a zero parallax) are ignored,
        since np.histogram cannot auto-detect a finite range when they are
        present and would raise.
    bins : int or sequence, optional
        Passed straight to np.histogram (default 150).

    Returns
    -------
    float
        Midpoint of the most populated histogram bin (the first one on ties),
        or NaN if `array` contains no finite values.
    """
    values = np.asarray(array).ravel()
    values = values[np.isfinite(values)]
    if values.size == 0:
        return np.nan
    counts, bin_edges = np.histogram(values, bins=bins)
    i_max = np.argmax(counts)
    # Centre of the winning bin; the left edge would bias the mode low by half a bin.
    return 0.5 * (bin_edges[i_max] + bin_edges[i_max + 1])

In [46]:
# Point summaries of each test star's posterior (f_summary) alongside the true values.
f_summary = mode_reals

# The sampling cell keys `post_info` by each star's *position* in
# df.loc[df.test_samples] (idx_i from np.arange), not by the catalogue `id`
# column. Look entries up positionally, which also keeps every array below
# aligned row-for-row with `dist_obs`.
test_rows = df.loc[df.test_samples]
test_keys = range(len(test_rows))
assert len(post_info) == len(test_rows), (
    f"post_info has {len(post_info)} entries but the test split has {len(test_rows)} rows"
)

dist_pred = np.array([f_summary(1000/post_info[i]['plx_post']) for i in test_keys])
dist_true = np.array([1000/post_info[i]['plx_true'] for i in test_keys])
dist_obs = 1000/test_rows['parallax_obs'].values

logAge_pred = np.array([f_summary(post_info[i]['logAge_post']) for i in test_keys])
logAge_true = np.array([post_info[i]['logAge_true'] for i in test_keys])

In [47]:
# plt.hist(df.loc[df.train_val_samples, 'logAge'], bins=50, alpha=0.5, log=True)
# plt.hist(df.loc[df.test_samples, 'logAge'], bins=50, alpha=0.5, log=True)

In [48]:
# plt.scatter(10**logAge_true/1e6, 10**logAge_pred/1e6, s=1, alpha=0.1) #, c=df_test.is_binary_true) #, c=df_val.labels==21)
# Predicted (posterior summary) vs. true logAge; the black diagonal marks perfect recovery.
plt.scatter(logAge_true, logAge_pred, s=1, alpha=0.1)
ax = plt.gca()
ax.set_xlabel('True logAge')
ax.set_ylabel('Predicted logAge')

min_age, max_age = 6.0, 10
age_span = [min_age, max_age]
ax.plot(age_span, age_span, color='k')
ax.set_xlim(*age_span)
ax.set_ylim(*age_span)
plt.show()

# plt.hist(logAge_true - logAge_pred, bins=np.linspace(-1, 1, 100))
# plt.xlabel('True - predicted logAge')
# plt.ylabel('Predicted logAge')
# plt.show()

In [49]:
delta_logAge_lim = -5, 5  # use (-4, 4) for the test cases to compare directly with Sagitta
cut_age = logAge_true < 8.

# Residual distribution (true - predicted) for stars with true logAge < 8.
residuals_young = logAge_true[cut_age] - logAge_pred[cut_age]
residual_edges = np.linspace(*delta_logAge_lim, 100)
plt.hist(residuals_young, bins=residual_edges, log=False)
ax = plt.gca()
ax.set_xlabel('True - predicted logAge')
ax.set_ylabel('Counts')
ax.axvline(0, color='k', ls='--', lw=0.5)
plt.show()

In [50]:
# Predicted (posterior summary) vs. true distance, with the 1:1 line drawn over 0-1000.
# Distances are 1000/parallax; if parallaxes are in mas (Gaia convention) they are in pc -- confirm.
plt.scatter(dist_true, dist_pred, s=1, alpha=0.05)
plt.plot([0, 1000], [0, 1000], color='k', lw=0.5)
plt.xlabel('True distance (1000 / parallax)')
plt.ylabel('Predicted distance (1000 / parallax)')
plt.show()

In [53]:
# Binned absolute logAge error versus Gaia G-band "SNR" for the test stars:
# per-bin median (red line) with the central 68% (dark band) and 95.4%
# (light band) percentile ranges.
err_band_3s_kwargs = dict(alpha=0.1, color='k')  # style of the 95.4% (~2-sigma) band; name kept for compatibility
err_band_1s_kwargs = dict(alpha=0.2, color='k')  # style of the 68% (~1-sigma) band

# NOTE(review): `snr` must be defined here for the cell to run on a fresh kernel.
# The G-band variant is used to match the x-axis label; the parallax variant is:
#   snr = df_test['parallax_obs'] / df_test['parallax_error']
# Confirm the intended definition and that these columns exist.
snr = (df_test['phot_g_mean_mag_obs'] / df_test['phot_g_mean_mag_error']).to_numpy()
abs_err_logAge = np.abs(logAge_pred - logAge_true)
# Both arrays must describe the same stars in the same (test-split) order.
assert snr.shape == abs_err_logAge.shape, "SNR and logAge errors are not aligned"

n_bins = 25  # number of bin edges
# Upper edge at half the maximum SNR; stars above it are not plotted.
bins = np.linspace(np.nanmin(snr), np.nanmax(snr) / 2, n_bins)

bin_x, bin_medians = [], []
la_diffs_lo_1s, la_diffs_hi_1s = [], []
la_diffs_lo_2s, la_diffs_hi_2s = [], []
for i_bin, (lo_edge, hi_edge) in enumerate(zip(bins[:-1], bins[1:])):
    # Bins are half-open (lo, hi]; the first bin also includes its lower edge so
    # the minimum-SNR star is not dropped.
    above_lo = snr >= lo_edge if i_bin == 0 else snr > lo_edge
    cut_bin = above_lo & (snr <= hi_edge)
    if not cut_bin.any():
        continue
    # nanpercentile: stars whose summary is NaN (no finite posterior draws) are skipped.
    lo_2s, lo_1s, med, hi_1s, hi_2s = np.nanpercentile(
        abs_err_logAge[cut_bin],
        [50 - 95.4/2, 16, 50, 84, 50 + 95.4/2],
    )
    bin_medians.append(med)
    la_diffs_lo_1s.append(lo_1s)
    la_diffs_hi_1s.append(hi_1s)
    la_diffs_lo_2s.append(lo_2s)
    la_diffs_hi_2s.append(hi_2s)
    # Plot each statistic at the centre of its bin.
    bin_x.append(0.5 * (lo_edge + hi_edge))

bin_medians = np.array(bin_medians)
plt.fill_between(bin_x, la_diffs_lo_1s, la_diffs_hi_1s, **err_band_1s_kwargs)
plt.fill_between(bin_x, la_diffs_lo_2s, la_diffs_hi_2s, **err_band_3s_kwargs)
plt.plot(bin_x, bin_medians, color='tab:red', lw=1)

plt.xlabel('SNR Gaia G-band')
plt.ylabel('logAge abs error')
# plt.savefig('/Users/ratzenboe/Desktop/figures/Fig-SNR-vs-Sagitta.png', dpi=300, bbox_inches='tight')
plt.show()

In [20]:
# plt.scatter(df_val.feh, feh_pred, s=5, alpha=0.1)
# plt.xlim(-1, 0.5)
# plt.ylim(-1, 0.5)
# plt.plot([-1, 0.5], [-1, 0.5], color='k')

In [83]:
# plt.scatter(df_val.A_V, A_V_pred, s=5, alpha=0.1)
# plt.xlim(-0.1, 4)
# plt.ylim(-0.1, 4)
# plt.plot([-0.1, 4], [-0.1, 4], color='k')

In [84]:
# plt.hist(df_val.A_V - A_V_pred, bins=100, range=(-0.3, 0.3), log=True)