In [1]:
import pandas as pd
import pylab as pl
import numpy as np
import seaborn as sns
%matplotlib inline

import yaml
from nnaps import predictors

Using TensorFlow backend.


In [2]:
# Load the three trained nnaps predictors chained by `predict_bps`: the
# pre-interaction model (assigns `stability`), the stable-RLOF model and the
# common-envelope model.
# NOTE(review): 'interacton' in the first file name looks like a typo, but the
# string must match the file on disk -- verify before renaming either.
pre_interaction_model, stable_model, ce_model = (
    predictors.FCPredictor(saved_model=model_file)
    for model_file in (
        'model_pre_interacton.h5',
        'model_stable_systems.h5',
        'model_ce_systems.h5',
    )
)

In [3]:
# Training set of the binary population synthesis run (CE applied).
training_set_path = '../data/sdBShortP_large_BPS_training_set_ce_applied.csv'
df = pd.read_csv(training_set_path)

In [4]:
# Keep only the six input columns that are passed on to the predictors;
# every other column of the CSV is dropped.
df = df.loc[:, ['M1_init', 'q_init', 'P_init', 'FeH_init', 'stability_limit', 'alpha_ce']]

In [32]:
# Evolution phases whose (P, q, M1) predictions are copied for each channel.
_STABLE_PHASES = ('HeCoreBurning', 'HeShellBurning', 'He-WD')
_CE_PHASES = ('CE', 'HeCoreBurning', 'He-WD')
_PHASE_PARAMETERS = ('P', 'q', 'M1')


def _predict_aligned(model, subset):
    """Run ``model.predict`` on ``subset`` and label the output rows with ``subset.index``.

    Assumes the model returns one row per input row, in input order (the same
    assumption the positional ``.values`` assignments in ``predict_bps`` make).
    A row-count mismatch raises ``ValueError`` instead of silently misaligning.
    """
    props = model.predict(subset)
    props.index = subset.index
    return props


def _fill_phase_parameters(results, props, rows, phases, parameters=_PHASE_PARAMETERS):
    """Copy ``<phase>_<parameter>`` columns from ``props`` into ``results`` in place.

    For each phase, only rows in ``rows`` whose ``results[phase]`` flag equals 1
    are copied. The same row labels are used to read ``props`` (whose index must
    be ``rows``), so the flag mask and the copied parameters cannot get out of step.
    """
    for phase in phases:
        columns = [phase + '_' + par for par in parameters]
        reached = rows[results.loc[rows, phase].values == 1]
        if len(reached) == 0:
            continue
        results.loc[reached, columns] = props.loc[reached, columns].values


def predict_bps(data, pre_model=None, stable_rlof_model=None, ce_phase_model=None):
    """Predict the binary-evolution outcome of every system in ``data``.

    Three predictors are chained:

    1. the pre-interaction model on all rows -> ``stability`` and its regressors;
    2. the stable-RLOF model on rows with ``stability == 'stable'``;
    3. the common-envelope model on rows with ``stability == 'CE'``.

    Rows with any other stability label (e.g. ``'contact'``) keep every phase
    flag at 0 and every downstream regressor at NaN.

    Parameters
    ----------
    data : pandas.DataFrame
        Input features passed unchanged to each model's ``predict``.
    pre_model, stable_rlof_model, ce_phase_model : optional
        Objects exposing ``predict(df)``, ``classifiers`` and ``regressors``.
        Default to the notebook-level ``pre_interaction_model``,
        ``stable_model`` and ``ce_model``.

    Returns
    -------
    pandas.DataFrame
        A copy of ``data`` (same index) extended with ``stability``, 0/1 phase
        flags (including ``CE``) and the regressor columns of all three models.
    """
    if pre_model is None:
        pre_model = pre_interaction_model
    if stable_rlof_model is None:
        stable_rlof_model = stable_model
    if ce_phase_model is None:
        ce_phase_model = ce_model

    results = data.copy()

    # --- Pre-interaction phase: evaluated for every system -------------------
    pre_props = pre_model.predict(data)
    # Positional assignment: the index returned by predict is not relied upon.
    results['stability'] = pre_props['stability'].values

    # Phase flags default to 0, for both channels' classifiers plus 'CE'.
    flag_columns = list(stable_rlof_model.classifiers)
    flag_columns += [key for key in ce_phase_model.classifiers if key not in flag_columns]
    for key in flag_columns + ['CE']:
        results[key] = 0

    # Regressors stay NaN for systems that never reach the corresponding phase.
    for model in (pre_model, stable_rlof_model, ce_phase_model):
        for key in model.regressors:
            results[key] = np.nan

    results.loc[:, pre_model.regressors] = \
        pre_props.loc[:, pre_model.regressors].values

    # --- Stable RLOF systems -------------------------------------------------
    # Rows are selected from `results` (whose index is data.index) rather than
    # from pre_props.index, which need not match data.index.
    stable_ind = results.index[(results['stability'] == 'stable').values]
    # Skip the call when the channel is empty; an empty batch is not
    # guaranteed to be accepted by the predictor.
    if len(stable_ind) > 0:
        stable_props = _predict_aligned(stable_rlof_model, data.loc[stable_ind])
        results.loc[stable_ind, stable_rlof_model.classifiers] = \
            stable_props.loc[:, stable_rlof_model.classifiers].values
        _fill_phase_parameters(results, stable_props, stable_ind, _STABLE_PHASES)

    # --- Common-envelope systems ---------------------------------------------
    ce_ind = results.index[(results['stability'] == 'CE').values]
    if len(ce_ind) > 0:
        ce_props = _predict_aligned(ce_phase_model, data.loc[ce_ind])
        results.loc[ce_ind, ce_phase_model.classifiers] = \
            ce_props.loc[:, ce_phase_model.classifiers].values
        results.loc[ce_ind, 'CE'] = 1
        # The He-WD flag of CE systems is derived as |HeCoreBurning - 1|. The
        # phase mask below reads it from `results`, so the derived flag (not a
        # possibly missing/different model column) decides which rows are copied.
        results.loc[ce_ind, 'He-WD'] = (results.loc[ce_ind, 'HeCoreBurning'] - 1).abs()
        _fill_phase_parameters(results, ce_props, ce_ind, _CE_PHASES)

    return results
    

In [33]:
# Run the full prediction chain on the systems with index labels 0-5 (inclusive).
predict_bps(df.loc[0:5])

Unnamed: 0,M1_init,q_init,P_init,FeH_init,stability_limit,alpha_ce,stability,HeCoreBurning,HeShellBurning,He-WD,...,HeCoreBurning_M1,HeShellBurning_P,HeShellBurning_q,HeShellBurning_M1,He-WD_P,He-WD_q,He-WD_M1,CE_P,CE_q,CE_M1
0,0.856,8.734693,354.120035,-1.04773,-3.0,0.974,CE,0,0,1,...,,,,,176.311264,7.612911,0.828416,182.353271,8.454698,0.837148
1,1.537999,1.028074,204.200076,-0.151481,0.0,0.18,stable,1,1,0,...,0.507186,1059.510986,0.255838,0.402485,,,,,,
2,1.262999,3.205582,213.940216,0.065559,-2.0,0.929,CE,0,0,1,...,,,,,141.320969,2.72504,1.102065,136.389297,2.672935,1.150143
3,1.617999,6.082705,131.93007,-0.098807,0.0,0.441,contact,0,0,0,...,,,,,,,,,,
4,1.259999,2.81879,653.090525,-0.196673,0.0,0.434,CE,1,0,0,...,0.79743,,,,,,,160.219543,2.11765,0.940438
5,1.192999,2.419877,655.560568,-0.290128,-2.0,0.517,CE,1,0,0,...,0.718213,,,,,,,593.1203,1.845839,0.954381


In [10]:
# Raw inputs of the same six systems (labels 0-5), for comparison with the predictions.
df.loc[0:5]

Unnamed: 0,M1_init,q_init,P_init,FeH_init,stability_limit,alpha_ce
0,0.856,8.734693,354.120035,-1.04773,-3.0,0.974
1,1.537999,1.028074,204.200076,-0.151481,0.0,0.18
2,1.262999,3.205582,213.940216,0.065559,-2.0,0.929
3,1.617999,6.082705,131.93007,-0.098807,0.0,0.441
4,1.259999,2.81879,653.090525,-0.196673,0.0,0.434
5,1.192999,2.419877,655.560568,-0.290128,-2.0,0.517
