In [20]:
import os
import pickle

import numpy as np
import pandas as pd
import sbi.utils as utils
import torch
from sbi.inference import SNPE

In [21]:
# Load data.
# The CSV location can be overridden with the MOCK_EDR3_CSV environment variable so the
# notebook runs outside the author's machine; the default is the original path.
fname_in = os.environ.get(
    'MOCK_EDR3_CSV',
    '/Users/ratzenboe/Library/CloudStorage/Dropbox/work/data/mock_edr3/edr3_mock_field_newClusters_ageSpread_May2024_ALL_NEW.csv',
)
df = pd.read_csv(fname_in)

In [50]:
# Sanity check: number of rows assigned to either split.
# NOTE: this cell was executed out of order (In [50]); the split columns are (re)created
# by the cells below, so on a fresh kernel they exist here only if the CSV already
# contains them. The guard keeps Restart & Run All from failing (shows None instead).
(df.train_val_samples.sum() + df.test_samples.sum()
 if {'train_val_samples', 'test_samples'}.issubset(df.columns) else None)

In [23]:
# Reset the split flags on the full catalogue; the cells below mark which rows
# belong to the training/validation set and which to the held-out test set.
for _split_flag in ('train_val_samples', 'test_samples'):
    df[_split_flag] = False

In [32]:
# Define training and test set rows.
# A row is eligible only if every feature used by any model (features_X_max) and every
# target (features_y_max) is present, so all models share one common split.

features_X_max = [
    'parallax_obs', 'A_V_obs',
    'phot_g_mean_mag_obs', 'phot_bp_mean_mag_obs', 'phot_rp_mean_mag_obs',
    'j_obs', 'h_obs', 'k_obs',
    'w1_obs', 'w2_obs',  # 'w3_obs', 'w4_obs',
    'irac1_obs', 'irac2_obs', 'irac3_obs', 'irac4_obs',  # 'mips1_obs',
    # Errors (A_V error constant, so not needed)
    'parallax_error',
    'phot_g_mean_mag_error', 'phot_bp_mean_mag_error', 'phot_rp_mean_mag_error',
    'j_error', 'h_error', 'k_error',
    'w1_error', 'w2_error',  # 'w3_error', 'w4_error',
    'irac1_error', 'irac2_error', 'irac3_error', 'irac4_error',  # 'mips1_error'
]

features_y_max = ['parallax_true', 'logAge', 'A_V', 'feh']

# Boolean mask over df: True where none of the listed columns is NaN.
train_test_set = df[features_X_max + features_y_max].notna().all(axis=1)
df_X_y = df.loc[train_test_set]

# Size of the held-out TEST set (despite the "_val" name). The remaining rows are used
# for training; the sbi trainer later splits off its own validation fraction from them.
n_samples_val = 30_000
assert n_samples_val < df_X_y.shape[0], (
    f'Test set size ({n_samples_val}) must be smaller than the number of '
    f'complete rows ({df_X_y.shape[0]})'
)

# Seeded generator so the train/test split is identical after a kernel restart.
SPLIT_SEED = 42
rng_split = np.random.default_rng(SPLIT_SEED)
idx_rand_perm = rng_split.permutation(df_X_y.shape[0])
# Positions (not index labels) into df_X_y.
idx_train = idx_rand_perm[n_samples_val:]
idx_test = idx_rand_perm[:n_samples_val]
idx_test.shape

In [33]:
# Full catalogue size (rows, columns), for comparison with the NaN-free subset df_X_y.
df.shape

In [34]:
# Count rows of the NaN-free subset with label > -1 and with label == -1
# (presumably cluster members vs. field stars — confirm against the simulation).
_row_labels = df_X_y.labels
(_row_labels > -1).sum(), (_row_labels == -1).sum()

In [35]:
# Distribution of the logAge target over the NaN-free subset (log-scaled counts).
ax_age = df_X_y['logAge'].hist(bins=50, log=True)
ax_age.set_xlabel('logAge')
ax_age.set_ylabel('Count (log scale)')
ax_age.set_title('logAge distribution of rows eligible for train/test');

In [36]:
# Mark the train/validation and test rows on the full catalogue.
# df_X_y.index maps the positional indices idx_train/idx_test to df's index labels,
# which is what .loc expects (the previous np.arange trick was only correct for a
# default RangeIndex). The flags are reset first so re-running this cell after a new
# permutation cannot leave stale rows flagged in both splits.
df['train_val_samples'] = False
df['test_samples'] = False
df.loc[df_X_y.index[idx_train], 'train_val_samples'] = True
df.loc[df_X_y.index[idx_test], 'test_samples'] = True
assert not (df['train_val_samples'] & df['test_samples']).any(), 'train and test rows overlap'

In [37]:
# Shape of the held-out test set: (number of test rows, number of columns).
df.loc[df['test_samples'], :].shape

In [38]:
# Running id 0..n_test-1 for test rows (in DataFrame row order); every other row gets -1.
_is_test_row = df['test_samples']
df['id'] = -1
df.loc[_is_test_row, 'id'] = np.arange(int(_is_test_row.sum()))

In [39]:
# Optional (disabled): write the catalogue, including the split flags and test-set ids,
# to a new CSV so other notebooks can reuse exactly this split.
# fname_out = '/Users/ratzenboe/Library/CloudStorage/Dropbox/work/data/mock_edr3/sim_field_clusters_TrainTest_new.csv'
# df.to_csv(fname_out, index=False)

In [40]:
# ---------------- Feature groups (model inputs X) ----------------
# G/BP/RP and J/H/K magnitudes followed by their uncertainties.
features_X_shared = [
    'phot_g_mean_mag_obs', 'phot_bp_mean_mag_obs', 'phot_rp_mean_mag_obs',
    'j_obs', 'h_obs', 'k_obs',
    'phot_g_mean_mag_error', 'phot_bp_mean_mag_error', 'phot_rp_mean_mag_error',
    'j_error', 'h_error', 'k_error',
]

# WISE (W1, W2) and IRAC (1-4) magnitudes followed by their uncertainties.
# W3, W4 and MIPS1 are intentionally left out.
features_X_wise_irac = [
    'w1_obs', 'w2_obs',
    'irac1_obs', 'irac2_obs', 'irac3_obs', 'irac4_obs',
    'w1_error', 'w2_error',
    'irac1_error', 'irac2_error', 'irac3_error', 'irac4_error',
]

# Astrometry and extinction inputs (only used by the currently disabled models).
features_plx = ['parallax_obs', 'parallax_error']
features_Av = ['A_V_obs']

# ---------------- Models to train ----------------
# Each entry: input columns, target columns, and the string used in saved file names.
# Disabled alternatives, kept for reference:
#   'all'     -> X = shared + WISE/IRAC + parallax, model_str 'X_allFeatures__y_parallax_logAge'
#   'Sagitta' -> X = shared + parallax + A_V,       model_str 'X_Sagitta__y_parallax_logAge'
# (both with targets ['parallax_true', 'logAge']).
model_specifics = {
    'SED_only': {
        'features_X': features_X_shared + features_X_wise_irac,
        'features_y': ['parallax_true', 'logAge'],
        'model_str': 'X_SEDonly__y_parallax_logAge',
    },
}

In [41]:
# Disabled dry run: uncomment to list the models the training loop below would process.
# for model_abbr, m_infos in model_specifics.items():
#     features_X = m_infos['features_X']
#     features_y = m_infos['features_y']
#     model_str = m_infos['model_str']
#     print(model_str)

In [42]:
# Output directory for the trained posteriors, scaling constants and feature lists.
fpath = '/Users/ratzenboe/Documents/work/code/notebooks/SBI/trained_models/'
# Create it up front so a missing directory does not make the pickle.dump calls fail
# only after a (long) training run has finished.
os.makedirs(fpath, exist_ok=True)

# Seed torch before each model so network initialisation / batch shuffling are reproducible.
TRAIN_SEED = 42

# Training loop: all models
for model_abbr, m_infos in model_specifics.items():
    features_X = m_infos['features_X']
    features_y = m_infos['features_y']
    model_str = m_infos['model_str']

    # Training data: only rows flagged train_val_samples (the test rows stay held out).
    df_train_val = df.loc[df.train_val_samples]
    x_samples = torch.tensor(df_train_val[features_X].values.astype(np.float32))
    theta_samples = torch.tensor(df_train_val[features_y].values.astype(np.float32))

    # Normalize the data (z-score with the training-set mean/std).
    # -- X --
    x_mean = x_samples.mean(dim=0)
    x_std = x_samples.std(dim=0)
    # A constant input column gives std == 0 and would become NaN/inf; divide it by 1
    # instead (it becomes all zeros). This is a no-op when every std is positive.
    x_std = torch.where(x_std > 0, x_std, torch.ones_like(x_std))
    x_train = (x_samples - x_mean) / x_std
    # -- theta --
    theta_mean = theta_samples.mean(dim=0)
    theta_std = theta_samples.std(dim=0)
    theta_train = (theta_samples - theta_mean) / theta_std

    # ----- Define priors ------
    # Box prior spanning each target's range in the training data. Taking min/max of the
    # float32 tensor equals casting the DataFrame min/max to float32 (the cast is monotonic),
    # provided the targets are NaN-free — true for targets listed in features_y_max.
    theta_mins = theta_samples.min(dim=0).values
    theta_maxs = theta_samples.max(dim=0).values
    print(theta_mins, theta_maxs)
    # Normalize the mins and maxs with the same constants as theta_train
    theta_mins = (theta_mins - theta_mean) / theta_std
    theta_maxs = (theta_maxs - theta_mean) / theta_std
    # Define prior
    prior = utils.BoxUniform(
        low=theta_mins,
        high=theta_maxs
    )
    print('Training model:', model_str)
    torch.manual_seed(TRAIN_SEED)
    # Neural posterior estimation on the (normalized) training pairs (theta, x).
    inference = SNPE(prior=prior)
    inference.append_simulations(x=x_train, theta=theta_train)
    density_estimator = inference.train(
        num_atoms=50,
        training_batch_size=512,
        validation_fraction=0.1,
        learning_rate=5e-4,
        show_train_summary=True,
    )
    posterior = inference.build_posterior(density_estimator)

    # ----------- Save model -----------
    # NOTE: these are pickle files — only load them from a trusted location,
    # since unpickling can execute arbitrary code.
    with open(os.path.join(fpath, f"posterior_{model_str}.pkl"), "wb") as handle:
        pickle.dump(posterior, handle)

    # The posterior lives in normalized space: these constants standardize new X and map
    # theta samples back to the original units (theta * theta_std + theta_mean).
    scale_factors = {
        'theta_mean': theta_mean,
        'theta_std': theta_std,
        'x_mean': x_mean,
        'x_std': x_std
    }
    with open(os.path.join(fpath, f"scale_factors_{model_str}.pkl"), "wb") as handle:
        pickle.dump(scale_factors, handle)

    # Column lists, so inference code can select the same inputs in the same order.
    features_X_y = {
        'X': features_X,
        'y': features_y
    }
    with open(os.path.join(fpath, f'features_{model_str}.pkl'), 'wb') as handle:
        pickle.dump(features_X_y, handle)
    print('Done!')

In [48]:
# NOTE(review): leftover scratch calculation; its purpose is not visible in this notebook —
# remove before sharing if it is no longer needed.
1000/500