In [2]:
# Preamble 
try:
    %matplotlib inline
    %config InlineBackend.figure_format='retina'
except:
    pass

from astropy.table import Table,join,hstack,vstack
import glob
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import corner
import time
import pickle
import imageio as iio
from datetime import date

In [3]:
# Solar Abundances
marcs2014_a_x_sun = dict()
elements = [
 "H",  "He",  "Li",  "Be",   "B",   "C",   "N",   "O",   "F",  "Ne",
"Na",  "Mg",  "Al",  "Si",   "P",   "S",  "Cl",  "Ar",   "K",  "Ca",
"Sc",  "Ti",   "V",  "Cr",  "Mn",  "Fe",  "Co",  "Ni",  "Cu",  "Zn",
"Ga",  "Ge",  "As",  "Se",  "Br",  "Kr",  "Rb",  "Sr",   "Y",  "Zr",
"Nb",  "Mo",  "Tc",  "Ru",  "Rh",  "Pd",  "Ag",  "Cd",  "In",  "Sn",
"Sb",  "Te",   "I",  "Xe",  "Cs",  "Ba",  "La",  "Ce",  "Pr",  "Nd",
"Pm",  "Sm",  "Eu",  "Gd",  "Tb",  "Dy",  "Ho",  "Er",  "Tm",  "Yb",
"Lu",  "Hf",  "Ta",   "W",  "Re",  "Os",  "Ir",  "Pt",  "Au",  "Hg",
"Tl",  "Pb",  "Bi",  "Po",  "At",  "Rn",  "Fr",  "Ra",  "Ac",  "Th",
"Pa",   "U",  "Np",  "Pu",  "Am",  "Cm",  "Bk",  "Cs",  "Es"
]
zeropoints = [
12.00, 10.93,  1.05,  1.38,  2.70,  8.39,  7.78,  8.66,  4.56,  7.84,
 6.17,  7.53,  6.37,  7.51,  5.36,  7.14,  5.50,  6.18,  5.08,  6.31,
 3.17,  4.90,  4.00,  5.64,  5.39,  7.45,  4.92,  6.23,  4.21,  4.60,
 2.88,  3.58,  2.29,  3.33,  2.56,  3.25,  2.60,  2.92,  2.21,  2.58,
 1.42,  1.92, -8.00,  1.84,  1.12,  1.66,  0.94,  1.77,  1.60,  2.00,
 1.00,  2.19,  1.51,  2.24,  1.07,  2.17,  1.13,  1.70,  0.58,  1.45,
-8.00,  1.00,  0.52,  1.11,  0.28,  1.14,  0.51,  0.93,  0.00,  1.08,
 0.06,  0.88, -0.17,  1.11,  0.23,  1.25,  1.38,  1.64,  1.01,  1.13,
 0.90,  2.00,  0.65, -8.00, -8.00, -8.00, -8.00, -8.00, -8.00,  0.06,
-8.00, -0.52, -8.00, -8.00, -8.00, -8.00, -8.00, -8.00, -8.00]
for (element, zeropoint) in zip(elements, zeropoints):
    marcs2014_a_x_sun[element] = zeropoint

galah_zeropoints = Table.read('galah_dr4_zeropoints.fits')

# Parameter Biases
parameter_biases = dict()
parameter_biases['teff']  = 5772.0 - galah_zeropoints['teff'][0]
parameter_biases['logg']  = 4.438 - galah_zeropoints['logg'][0] # DR3: offset without non-spectroscopic information
parameter_biases['fe_h']  = marcs2014_a_x_sun['Fe'] - galah_zeropoints['A_Fe'][0]  # -0.017 VESTA, GAS07: 7.45, DR3: 7.38
parameter_biases['vmic']  = 0.
parameter_biases['vsini'] = 0.
for element in [
    'Li','C','N','O',
    'Na','Mg','Al','Si',
    'K','Ca','Sc','Ti','V','Cr','Mn','Co','Ni','Cu','Zn',
    'Rb','Sr','Y','Zr','Mo','Ru',
    'Ba','La','Ce','Nd','Sm','Eu'
]:
    parameter_biases[element.lower()+'_fe'] = marcs2014_a_x_sun[element] - galah_zeropoints['A_'+element][0]



In [4]:
def combine_allstar():
    
    dates = glob.glob('daily/*plxcom.fits')
    dates = [date[38:38+6] for date in dates]
    dates = np.unique(dates)
    dates.sort()
    
    print('REVERSING LOGG PARAMETER_BIAS')
    
    print(len(dates),list(dates))
    
    data_allstar  = Table.read('daily/galah_dr4_allspec_not_validated_'+str(dates[0])+'_plxcom.fits')

    print(1,dates[0])
    for ind,date in enumerate(dates[1:]):
        print(ind+2,date)
        try:
            data_next = Table.read('daily/galah_dr4_allspec_not_validated_'+str(date)+'_plxcom.fits')
            data_allstar = vstack([data_allstar, data_next])
        except:
            print('Could not read single: '+str(date))
        
    return(data_allstar)

data_allstar = combine_allstar()

for label in ['logg']:
    data_allstar[label] -= parameter_biases[label]

data_allstar.write('galah_dr4_allstar_not_validated.fits',overwrite=True)
print(len(data_allstar['sobject_id']))

REVERSING LOGG PARAMETER_BIAS
141 ['131216', '131217', '131220', '140111', '140112', '140113', '140114', '140115', '140116', '140117', '140118', '140207', '140208', '140209', '140303', '140304', '140305', '140307', '140308', '140309', '140310', '140312', '140313', '140314', '140315', '140316', '140409', '140412', '140413', '140414', '140607', '140608', '140609', '140610', '140611', '140707', '140708', '140709', '140710', '140711', '140713', '140805', '140806', '140807', '140808', '140809', '140810', '140811', '140812', '140813', '140814', '140822', '140823', '140824', '141031', '141102', '141103', '141104', '141202', '141231', '150101', '150102', '150103', '150105', '150106', '150107', '150108', '150109', '150112', '150330', '150504', '150531', '150703', '150704', '150705', '150706', '150718', '150719', '150901', '150902', '150903', '151008', '151009', '151109', '151110', '151111', '160325', '160326', '160327', '160328', '160330', '160331', '160602', '160610', '160611', '160612', '1606

In [5]:
def combine_allspec():
    
    dates = glob.glob('daily/*single.fits')
    dates = [date[38:38+6] for date in dates]
    dates = np.unique(dates)
    dates.sort()
    try:
        dates = np.delete(dates,np.where('plxlog' == dates)[0][0])
    except:
        pass
    print(len(dates),list(dates))
    
#     data_single  = Table.read('daily/galah_dr4_allspec_not_validated_'+str(dates[0])+'_single.fits')
# #     data_binary  = Table.read('daily/galah_dr4_allspec_not_validated_'+str(dates[0])+'_binary.fits')
#     data_plxlogg = Table.read('daily/galah_dr4_allspec_not_validated_plxlogg_'+str(dates[0])+'.fits')

#     for ind,date in enumerate(dates[1:]):
#         print(ind,date)
#         try:
#             data_next = Table.read('daily/galah_dr4_allspec_not_validated_'+str(date)+'_single.fits')
#             data_single = vstack([data_single, data_next])
#         except:
#             print('Could not read single: '+str(date))
        
# #         try:
# #             data_next = Table.read('daily/galah_dr4_allspec_not_validated_'+str(date)+'_binary.fits')
# #             data_binary = vstack([data_binary, data_next])
# #         except:
# #             pass

#         try:
#             data_next = Table.read('daily/galah_dr4_allspec_not_validated_plxlogg_'+str(date)+'.fits')
#         except:
#             data_next = data_next['sobject_id']
#             print('Could not read plxcom: '+str(date))
#         data_plxlogg = vstack([data_plxlogg, data_next])
#     return(data_single, 1, data_plxlogg)

data, data_binary, data_plxlogg = combine_allspec()

574 ['131216', '131217', '131220', '140111', '140112', '140113', '140114', '140115', '140116', '140117', '140118', '140207', '140208', '140209', '140210', '140211', '140212', '140303', '140304', '140305', '140307', '140308', '140309', '140310', '140312', '140313', '140314', '140315', '140316', '140409', '140412', '140413', '140414', '140607', '140608', '140609', '140610', '140611', '140707', '140708', '140709', '140710', '140711', '140713', '140805', '140806', '140807', '140808', '140809', '140810', '140811', '140812', '140813', '140814', '140822', '140823', '140824', '141031', '141102', '141103', '141104', '141202', '141231', '150101', '150102', '150103', '150105', '150106', '150107', '150108', '150109', '150112', '150204', '150205', '150206', '150207', '150208', '150209', '150210', '150211', '150330', '150401', '150405', '150406', '150407', '150408', '150409', '150410', '150411', '150412', '150413', '150426', '150427', '150428', '150429', '150430', '150504', '150531', '150601', '1506

TypeError: cannot unpack non-iterable NoneType object

In [None]:
if len(data['sobject_id']) == len(data_plxlogg['sobject_id']):
    for key in data_plxlogg.keys():
        
        if key in ['sobject_id','gaiadr3_source_id','teff','fe_h','logg_plx']:
            pass
        elif key == 'logg':
            if 'logg_spec' not in data.keys():
                data['logg_spec'] = data['logg']
                data['e_logg_spec'] = data['e_logg']
                data['logg'] = data_plxlogg['logg_plx']
                data['e_logg'][:] = np.NaN
        elif key in ['r_med','r_lo','r_hi']:
            data_plxlogg[key][data_plxlogg[key] > 1000000.] = np.NaN
            data[key] = np.array(data_plxlogg[key],dtype=np.float32)
        else:
            data[key] = np.array(data_plxlogg[key],dtype=np.float32)
else:
    print('Not same length')

In [None]:
data_plxlogg

In [None]:
use = np.isfinite(data['logg_spec']) #& (~data['logg'].mask)

f, gs = plt.subplots(1,3,figsize=(10,3))
ax = gs[0]
ax.hist2d(
    data['teff'][use],
    data['logg_spec'][use],
    bins=(np.arange(2900,8100,25),np.arange(-0.5,5.5,0.05)), cmin = 1, norm=LogNorm()
)
ax.set_xlim(8100,2900)
ax.set_ylim(5.5,-0.5)
ax.set_xlabel('Teff(spec)',fontsize=15)
ax.set_ylabel('logg(spec)',fontsize=15)

ax = gs[1]
ax.hist2d(
    data['teff'][use],
    data['logg'][use],
    bins=(np.arange(2900,8100,25),np.arange(-0.5,5.5,0.05)), cmin = 1, norm=LogNorm()
)
ax.set_xlim(8100,2900)
ax.set_ylim(5.5,-0.5)
ax.set_xlabel('Teff(spec)',fontsize=15)
ax.set_ylabel('logg(plx)',fontsize=15)

ax = gs[2]
s = ax.hist2d(
    data['logg'][use],
    data['logg_spec'][use] - data['logg'][use],
    bins=(np.linspace(-0.5,5.5,100),np.linspace(-2,2,100)), cmin=1, norm=LogNorm(),label='100,000 stars'
)

steps = 0.5
loggs = np.arange(-0.5,5.51,steps)
dloggs = []
sloggs = []
for logg in loggs:
    in_bin = abs(data['logg'][use] - logg) < 0.5*steps
    if len(data['logg'][use][in_bin]) > 100:
        dloggs.append(np.nanmedian(data['logg_spec'][use][in_bin] - data['logg'][use][in_bin]))
        sloggs.append(np.nanstd(data['logg_spec'][use][in_bin] - data['logg'][use][in_bin]))
    else:
        dloggs.append(np.NaN)
        sloggs.append(np.NaN)
        
print(dloggs)
plt.errorbar(
    loggs,dloggs,yerr=sloggs,fmt='o',c='r',ms=3,capsize=5,lw=1,label='median/std'
)
plt.axhline(0,c='lightblue',lw=2,label='dlogg = 0')
ax.set_xlabel('logg(plx)',fontsize=15)
ax.set_ylabel('dlogg(plx-spec)',fontsize=15)
ax.legend(loc='lower left')

plt.tight_layout(w_pad=0)
plt.savefig('figures/dlogg_spec_plx.png',dpi=200,bbox_inches='tight')

In [None]:
use = np.isfinite(data['logg_spec']) & (data['logg'] < 3) #& (~data['logg'].mask)

f, gs = plt.subplots(1,3,figsize=(10,3))
ax = gs[0]
ax.hist2d(
    data['teff'][use],
    data['logg_spec'][use] - data['logg'][use],
    bins=(np.arange(2900,8100,25),np.linspace(-2,2,100)), cmin = 1, norm=LogNorm()
)
ax.set_xlim(8100,2900)
# ax.set_ylim(5.5,-0.5)
ax.set_xlabel('Teff(spec)',fontsize=15)
ax.set_ylabel('dlogg(plx-spec)',fontsize=15)

ax = gs[1]
s = ax.hist2d(
    data['logg'][use],
    data['logg_spec'][use] - data['logg'][use],
    bins=(np.linspace(-0.5,5.5,100),np.linspace(-2,2,100)), cmin=1, norm=LogNorm(),label='100,000 stars'
)

steps = 0.5
loggs = np.arange(-0.5,5.51,steps)
dloggs = []
sloggs = []
for logg in loggs:
    in_bin = abs(data['logg'][use] - logg) < 0.5*steps
    if len(data['logg'][use][in_bin]) > 100:
        dloggs.append(np.nanmedian(data['logg_spec'][use][in_bin] - data['logg'][use][in_bin]))
        sloggs.append(np.nanstd(data['logg_spec'][use][in_bin] - data['logg'][use][in_bin]))
    else:
        dloggs.append(np.NaN)
        sloggs.append(np.NaN)
        
#print(dloggs)
ax.errorbar(
    loggs,dloggs,yerr=sloggs,fmt='o',c='r',ms=3,capsize=5,lw=1,label='median/std'
)
ax.axhline(0,c='lightblue',lw=2,label='dlogg = 0')
ax.set_xlabel('logg(plx)',fontsize=15)
ax.set_ylabel('dlogg(plx-spec)',fontsize=15)
ax.legend(loc='lower left')

ax = gs[2]
ax.hist2d(
    data['fe_h'][use],
    data['logg_spec'][use] - data['logg'][use],
    bins=(np.linspace(-3.0,0.75,100),np.linspace(-2,2,100)), cmin = 1, norm=LogNorm()
)
ax.set_xlim(-3.0,0.75)
# ax.set_ylim(5.5,-0.5)
ax.set_xlabel('[Fe/H]',fontsize=15)
ax.set_ylabel('dlogg(plx-spec)',fontsize=15)


plt.tight_layout(w_pad=0)

use = np.isfinite(data['logg_spec']) & (data['logg'] >= 3) #& (~data['logg'].mask)

f, gs = plt.subplots(1,3,figsize=(10,3))
ax = gs[0]
ax.hist2d(
    data['teff'][use],
    data['logg_spec'][use] - data['logg'][use],
    bins=(np.arange(2900,8100,25),np.linspace(-2,2,100)), cmin = 1, norm=LogNorm()
)
ax.set_xlim(8100,2900)
# ax.set_ylim(5.5,-0.5)
ax.set_xlabel('Teff(spec)',fontsize=15)
ax.set_ylabel('dlogg(plx-spec)',fontsize=15)

ax = gs[1]
s = ax.hist2d(
    data['logg'][use],
    data['logg_spec'][use] - data['logg'][use],
    bins=(np.linspace(-0.5,5.5,100),np.linspace(-2,2,100)), cmin=1, norm=LogNorm(),label='100,000 stars'
)

steps = 0.5
loggs = np.arange(-0.5,5.51,steps)
dloggs = []
sloggs = []
for logg in loggs:
    in_bin = abs(data['logg'][use] - logg) < 0.5*steps
    if len(data['logg'][use][in_bin]) > 100:
        dloggs.append(np.nanmedian(data['logg_spec'][use][in_bin] - data['logg'][use][in_bin]))
        sloggs.append(np.nanstd(data['logg_spec'][use][in_bin] - data['logg'][use][in_bin]))
    else:
        dloggs.append(np.NaN)
        sloggs.append(np.NaN)
        
#print(dloggs)
ax.errorbar(
    loggs,dloggs,yerr=sloggs,fmt='o',c='r',ms=3,capsize=5,lw=1,label='median/std'
)
ax.axhline(0,c='lightblue',lw=2,label='dlogg = 0')
ax.set_xlabel('logg(plx)',fontsize=15)
ax.set_ylabel('dlogg(plx-spec)',fontsize=15)
ax.legend(loc='lower left')

ax = gs[2]
ax.hist2d(
    data['fe_h'][use],
    data['logg_spec'][use] - data['logg'][use],
    bins=(np.linspace(-3.0,0.75,100),np.linspace(-2,2,100)), cmin = 1, norm=LogNorm()
)
ax.set_xlim(-3.0,0.75)
# ax.set_ylim(5.5,-0.5)
ax.set_xlabel('[Fe/H]',fontsize=15)
ax.set_ylabel('dlogg(plx-spec)',fontsize=15)

plt.tight_layout(w_pad=0)

In [None]:
init = Table.read('../spectrum_analysis/galah_dr4_initial_parameters_230101_lite.fits')
init['date'] = np.array([str(x)[:6] for x in init['sobject_id']])

bad_ccd3 = init[init['cdelt_flag']==4]
bad_ccd3['date'] = np.array([str(x)[:6] for x in bad_ccd3['sobject_id']])

bad_ccd3_dates = np.unique(bad_ccd3['date'])

obs6p1 = np.unique(np.array([x[20:20+6] for x in glob.glob('../observations_6p1/*')]))

# rerun_input = []
# for unique_date in np.unique(init['date']):
#     print(len(init['date'][init['date'] == unique_date]))
#     if (unique_date in obs6p1) & (unique_date in bad_ccd3_dates):
#         print('X')
#         print(len(bad_ccd3['sobject_id'][bad_ccd3['date'] == unique_date]))
#     else:
#         print('')#Cannot rerun '+unique_date)
# print('done')

In [None]:
a_file = open("final_flag_sp_dictionary.pkl", "rb")
flag_sp_dictionary = pickle.load(a_file)
a_file.close()

binary_setup = data['setup']=='binary'

triple = binary_setup & ((data['flag_sp'] & flag_sp_dictionary['is_sb2'][0]) > 0)
double = binary_setup & ((data['flag_sp'] & flag_sp_dictionary['is_sb2'][0]) == 0)

data['flag_sp'][double] += flag_sp_dictionary['is_sb2'][0]

data['flag_sp'][triple] += flag_sp_dictionary['sb_triple_warn'][0]

In [None]:
# for element in [
#         'Li','C','N','O',
#         'Na','Mg','Al','Si',
#         'K','Ca','Sc','Ti','V','Cr','Mn','Co','Ni','Cu','Zn',
#         'Rb','Sr','Y','Zr','Mo','Ru',
#         'Ba','La','Ce','Nd','Sm','Eu'
# ]:
#     data[element.lower()+'_fe'][np.where(data['flag_'+element.lower()+'_fe'] == -1)[0]] = np.NaN
#     data['e_'+element.lower()+'_fe'][np.where(data['flag_'+element.lower()+'_fe'] == -1)[0]] = np.NaN
#     data[element.lower()+'_fe'][np.where(data['flag_'+element.lower()+'_fe'] == 2)[0]] = np.NaN
#     data['flag_'+element.lower()+'_fe'][np.where(data['flag_'+element.lower()+'_fe'] == -1)[0]] = 2
# data['fe_h'][np.where(data['flag_fe_h'] == -1)[0]] = np.NaN
# data['e_fe_h'][np.where(data['flag_fe_h'] == -1)[0]] = np.NaN
# data['fe_h'][np.where(data['flag_fe_h'] == 2)[0]] = np.NaN
# data['flag_fe_h'][np.where(data['flag_fe_h'] == -1)[0]] = 2

In [None]:
data[[0,1,2,3,4,-5,-4,-3,-2,-1]]

# Populate best_spec4star

In [None]:
# debug = False
# # debug = True

# for tmass_id in np.unique(data['tmass_id']):
    
#     if tmass_id != 'None':

#         # First let's see how many matches we actually have
#         find_matches = np.where(tmass_id == data['tmass_id'])[0]

#         # If only 1 entry found, we most likely can simply take that
#         if len(find_matches) == 1:

#             # Sanity check: is it really a single one?
#             if data['setup'][find_matches[0]] == 'single':
#                 data['best_spec4star'][find_matches[0]] = True
#             # If not: you better run the single setup as well!
#             else:
#                 raise ValueError('1 entry found, but not single?!')

#         # If we have multiple matches: check out which ones are single/binary/coadds
#         else:

#             single_matches = find_matches[np.where(data['setup'][find_matches]=='single')[0]]
#             binary_matches = find_matches[np.where(data['setup'][find_matches]=='binary')[0]]
#             coadds_matches = find_matches[np.where(data['setup'][find_matches]=='coadds')[0]]
#             if len(coadds_matches) > 1:
#                 raise ValueError('Multiple coadded measurements available?!')
#             elif len(coadds_matches) == 1:
#                 coadds_match = coadds_matches[0]
#             nonbin_matches = find_matches[np.where(data['setup'][find_matches]!='binary')[0]]

#             if debug:
#                 print(tmass_id,':',list(data['sobject_id'][find_matches]))
#                 print('    flag_sp:',list(find_matches),list(data['setup'][find_matches]))
#                 print('    flag_sp:',list(data['flag_sp'][find_matches]))

#             # The rundown of our preference:
#             # if it is clearly better fit with a binary analysis: prefer that!
#             #     Double check if there are mutiple binary analyses and take 
#             #       1) the highest unflagged one
#             #       2) the unflagged one
#             #       3) the highest snr one
#             # elif there is a coadds run, prefer that
#             # else: there should be a coadds run!

#             # When do we believe it is a binary?
#             it_is_a_binary = False

#             # Let's test that, if binary setup(s) avaliable
#             if len(binary_matches) > 0:

#                 if debug:
#                     print('--> binary found!')

#                 # If only 1 binary measurement is available
#                 if len(binary_matches) == 1:
#                     binary_match = binary_matches[0]

#                     # chi2 has to be better by at least 5% in more than 50% of the cases
#                     binary_criterion_a = (np.nanmedian(data['chi2_sp'][binary_match] / list(data['chi2_sp'][nonbin_matches])) < 0.95)
#                     # ΔRV of the 2 binary components has to be beyond 10 km/s
#                     binary_criterion_b = (np.abs(data['rv_comp_1'][binary_match] - data['rv_comp_2'][binary_match]) > 10)
#                     # ΔRV of the 2 binary components has to be within 300 km/s
#                     binary_criterion_c = (np.abs(data['rv_comp_1'][binary_match] - data['rv_comp_2'][binary_match]) < 300)

#                     if (
#                         binary_criterion_a &
#                         binary_criterion_b &
#                         binary_criterion_c
#                     ):
#                         it_is_a_binary = True

#                     if debug:
#                         print('    a) chi2 ratios:',binary_criterion_a,data['chi2_sp'][binary_match] / list(data['chi2_sp'][nonbin_matches]))
#                         print('    delta rv1/2:',binary_criterion_b,data['rv_comp_1'][binary_match],data['rv_comp_2'][binary_match],np.abs(data['rv_comp_1'][binary_match] - data['rv_comp_2'][binary_match]))
#                         print('    b) delta rv1/2  >  10?: ',binary_criterion_b)
#                         print('    c) delta rv1/2  < 300?: ',binary_criterion_c)
#                         print('    Vertict?:',it_is_a_binary)

#                     if it_is_a_binary:
#                         data['best_spec4star'][binary_match] = True

#                 # If multiple binary measurements are available
#                 else:

#                     single_matches_to_binary_matches = []
#                     for sobject_id in data['sobject_id'][binary_matches]:

#                         single_matches_to_binary_matches.append(single_matches[np.where(data['sobject_id'][single_matches] == sobject_id)[0][0]])

#                     if np.all(data['sobject_id'][single_matches_to_binary_matches] == data['sobject_id'][binary_matches]):

#                         # Let's first check if these binaries hold up the basic tests from criterions a,b,c
#                         valid_binary_matches = (
#                             (abs(data['chi2_sp'][binary_matches]/data['chi2_sp'][single_matches_to_binary_matches]) < 0.95) &
#                             (abs(data['rv_comp_1'][binary_matches]-data['rv_comp_2'][binary_matches]) > 10) &
#                             (abs(data['rv_comp_1'][binary_matches]-data['rv_comp_2'][binary_matches]) < 300)
#                         )
#                         # If there is a coadds match, let's also check agains that measurement
#                         if len(coadds_matches) == 1:
#                             valid_binary_matches = (
#                                 (abs(data['chi2_sp'][binary_matches]/data['chi2_sp'][single_matches_to_binary_matches]) < 0.95) &
#                                 (abs(data['chi2_sp'][binary_matches]/data['chi2_sp'][coadds_match]) < 0.95) &
#                                 (abs(data['rv_comp_1'][binary_matches]-data['rv_comp_2'][binary_matches]) > 10) &
#                                 (abs(data['rv_comp_1'][binary_matches]-data['rv_comp_2'][binary_matches]) < 300)
#                             )

#                         # If we have at least 1 good measurement
#                         if len(binary_matches[valid_binary_matches]) > 0:

#                             it_is_a_binary = True

#                             # Then find the one with the largest separation and use it as best_spec4star
#                             larger_rv_separation = np.argmax(abs(data['rv_comp_1'][binary_matches[valid_binary_matches]]-data['rv_comp_2'][binary_matches[valid_binary_matches]]))
#                             data['best_spec4star'][binary_matches[valid_binary_matches][larger_rv_separation]] = True
#                     else:

#                         raise ValueError('sobject_id of single and binary not in the same order...')

#             if not it_is_a_binary:

#                 # If it is not a binary, prefer the coadds
#                 if len(coadds_matches) == 1:

#                     if debug:
#                         print('--> coadds found (but not binary)!')

#                         print(tmass_id,':',list(data['sobject_id'][find_matches]))
#                         print('    flag_sp:',list(find_matches),list(data['setup'][find_matches]))
#                         print('    flag_sp:',list(data['flag_sp'][find_matches]))

#                         print('    SNRs:')
#                         print('         ',data['snr_px_ccd2'][coadds_match])
#                         print('         ',list(data['snr_px_ccd2'][single_matches]))

#                     data['best_spec4star'][coadds_match] = True

#                 # If there is no coadds: take the single measurement
#                 elif len(single_matches) == 1:
#                     data['best_spec4star'][single_matches[0]] = True

#         if len(np.where(data['best_spec4star'][find_matches] == True)[0]) != 1:

#             # This should only be activated, if we have not run coadds for all stars!

#             best_single_match = single_matches[np.argmin(data['flag_sp'][single_matches])]
#             data['best_spec4star'][best_single_match] = True

# no_tmass_id = data['tmass_id'] == 'None'
# entries = np.arange(len(data['sobject_id']))

# for object_index, sobject_id in enumerate(data['sobject_id'][no_tmass_id]):
        
#     index = entries[no_tmass_id][object_index]
#     same_coordinates = np.where((data['ra'][index]==data['ra']) & (data['dec'][index]==data['dec']))[0]

#     single_matches = np.where(data['setup'][same_coordinates] == 'single')[0]
#     binary_matches = np.where(data['setup'][same_coordinates] == 'binary')[0]
#     coadds_matches = np.where(data['setup'][same_coordinates] == 'coadds')[0]
#     nonbin_matches = np.where(data['setup'][same_coordinates] != 'binary')[0]
    
#     already_best = np.where(data['best_spec4star'][same_coordinates]==True)[0]

#     it_is_a_binary = False
    
#     if len(same_coordinates) == 1:
#         data['best_spec4star'][index] = True
        
#     elif len(already_best) > 0:
#         pass
    
#     elif len(binary_matches) > 0:
        
#         if len(binary_matches) == 1:
#             binary_match = binary_matches[0]
        
#             # chi2 has to be better by at least 5% in more than 50% of the cases
#             binary_criterion_a = (np.nanmedian(data['chi2_sp'][same_coordinates[binary_match]] / list(data['chi2_sp'][same_coordinates[nonbin_matches]])) < 0.95)
#             # ΔRV of the 2 binary components has to be beyond 10 km/s
#             binary_criterion_b = (np.abs(data['rv_comp_1'][same_coordinates[binary_match]] - data['rv_comp_2'][same_coordinates[binary_match]]) > 10)
#             # ΔRV of the 2 binary components has to be within 300 km/s
#             binary_criterion_c = (np.abs(data['rv_comp_1'][same_coordinates[binary_match]] - data['rv_comp_2'][same_coordinates[binary_match]]) < 300)
            
#             if debug:
#                 print('    a) chi2 ratios:',binary_criterion_a,data['chi2_sp'][same_coordinates[binary_match]] / list(data['chi2_sp'][same_coordinates[nonbin_matches]]))
#                 print('    delta rv1/2:',binary_criterion_b,data['rv_comp_1'][same_coordinates[binary_match]],data['rv_comp_2'][same_coordinates[binary_match]],np.abs(data['rv_comp_1'][same_coordinates[binary_match]] - data['rv_comp_2'][same_coordinates[binary_match]]))
#                 print('    b) delta rv1/2  >  10?: ',binary_criterion_b)
#                 print('    c) delta rv1/2  < 300?: ',binary_criterion_c)
#                 print('    Vertict?:',it_is_a_binary)

#             if (
#                 binary_criterion_a &
#                 binary_criterion_b &
#                 binary_criterion_c
#             ):
#                 it_is_a_binary = True
#         else:
            
#             single_matches_to_binary_matches = []
#             for sobject_id in data['sobject_id'][same_coordinates[binary_matches]]:

#                 single_matches_to_binary_matches.append(single_matches[np.where(data['sobject_id'][same_coordinates][single_matches] == sobject_id)[0][0]])
                
#             if np.all(data['sobject_id'][same_coordinates][single_matches_to_binary_matches] == data['sobject_id'][same_coordinates[binary_matches]]):

#                 # Let's first check if these binaries hold up the basic tests from criterions a,b,c
#                 valid_binary_matches = (
#                     (abs(data['chi2_sp'][same_coordinates[binary_matches]]/data['chi2_sp'][same_coordinates][single_matches_to_binary_matches]) < 0.95) &
#                     (abs(data['rv_comp_1'][same_coordinates[binary_matches]]-data['rv_comp_2'][same_coordinates[binary_matches]]) > 10) &
#                     (abs(data['rv_comp_1'][same_coordinates[binary_matches]]-data['rv_comp_2'][same_coordinates[binary_matches]]) < 300)
#                 )
#                 # If there is a coadds match, let's also check agains that measurement
#                 if len(coadds_matches) == 1:
#                     valid_binary_matches = (
#                         (abs(data['chi2_sp'][same_coordinates[binary_matches]]/data['chi2_sp'][same_coordinates][single_matches_to_binary_matches]) < 0.95) &
#                         (abs(data['chi2_sp'][same_coordinates[binary_matches]]/data['chi2_sp'][same_coordinates][coadds_match]) < 0.95) &
#                         (abs(data['rv_comp_1'][same_coordinates[binary_matches]]-data['rv_comp_2'][same_coordinates[binary_matches]]) > 10) &
#                         (abs(data['rv_comp_1'][same_coordinates[binary_matches]]-data['rv_comp_2'][same_coordinates[binary_matches]]) < 300)
#                     )
                    
#                 # If we have at least 1 good measurement
#                 if len(binary_matches[valid_binary_matches]) > 0:

#                     it_is_a_binary = True

#                     # Then find the one with the largest separation and use it as best_spec4star
#                     larger_rv_separation = np.argmax(abs(data['rv_comp_1'][same_coordinates[binary_matches[valid_binary_matches]]]-data['rv_comp_2'][same_coordinates[binary_matches[valid_binary_matches]]]))
#                     data['best_spec4star'][same_coordinates][binary_matches[valid_binary_matches][larger_rv_separation]] = True
#             else:

#                 raise ValueError('sobject_id of single and binary not in the same order...')

#     if not it_is_a_binary:

#         if len(coadds_matches) > 0:
            
#             if len(coadds_matches) == 1:
#                 data['setup'][same_coordinates][coadds_matches[0]] = True
            
#             else:
#                 raise ValueError('Multiple coadds measurements for this non-2MASS star')
#         elif len(single_matches) == 1:
#             data['best_spec4star'][same_coordinates][single_matches[0]] = True

#         else:

#             # This should only be activated, if we have not run coadds for all stars!

#             if debug:
#                 print('There should be a coadds spectrum here too!')

#             best_single_match = single_matches[np.argmin(data['flag_sp'][same_coordinates][single_matches])]
#             data['best_spec4star'][same_coordinates[best_single_match]] = True


# Save FITS Files

In [None]:
data.write('galah_dr4_allspec_not_validated.fits',overwrite=True)
print(len(data['sobject_id']))

# data_allstar = data[data['best_spec4star']==True]
# data_allstar.sort(keys='ra')
# data_allstar.write('galah_dr4_allstar_not_validated.fits',overwrite=True)
# data_allstar

# Flag dictionaries

In [None]:
a_file = open("final_flag_sp_dictionary.pkl", "rb")
flag_sp_dictionary = pickle.load(a_file)
a_file.close()
print(flag_sp_dictionary)

entries = []
for flag in np.unique(data['flag_sp']):
    
    flag_text = []
    for flag_key in flag_sp_dictionary.keys():
        if((flag & flag_sp_dictionary[flag_key][0]) == flag_sp_dictionary[flag_key][0]):
            flag_text.append(flag_key)
    entries.append([flag,len(data['flag_sp'][data['flag_sp']==flag]),", ".join(flag_text)])
entries = np.array(entries)

a = Table()
a['flag_sp'] = entries[:,0]
a['nr_spectra'] = entries[:,1]
a['flag_sp_keys'] = entries[:,2]
for s in flag_sp_dictionary.keys():
    print(flag_sp_dictionary[s][0],flag_sp_dictionary[s][1])

In [None]:
bad = data[(data['flag_sp'] >= 64) & (data['flag_sp'] < np.max(data['flag_sp'])) & (data['teff'] > 6500)]
bad

In [None]:
from scipy.spatial import cKDTree
grids = Table.read('../spectrum_grids/galah_dr4_model_trainingset_gridpoints.fits')
grid_index_tree = cKDTree(np.c_[grids['teff_subgrid'],grids['logg_subgrid'],grids['fe_h_subgrid']])

model_needed = []
for i in data[(data['flag_sp'] >= 64) & (data['flag_sp'] < np.max(data['flag_sp']))]:
    model_needed.append(grid_index_tree.query([i['teff'],i['logg'],i['fe_h']],k=1)[1])
model_needed = np.array(model_needed)
models_needed = Table()
models_needed['models'], models_needed['counts'] = np.unique(model_needed,return_counts=True)
models_needed.sort(keys = 'counts', reverse=True)
models_needed[:10]

In [None]:
number = Table()
number['closest_model'],number['count'] = np.unique(data['closest_model'][(data['flag_sp_fit'] == 1)],return_counts=True)
number.sort(keys='count',reverse=True)
number[:10]

In [None]:
# date = [str(x)[:6] for x in data['sobject_id'][(data['flag_sp_fit'] == 1) & (data['closest_model']=='5000_2.00_-0.75')]]
data['sobject_id'][(data['flag_sp_fit'] == 1) & (data['closest_model']=='5000_2.00_-0.75')]

In [None]:
data[(data['n_fe'] > 1) & (data['flag_n_fe'] == 0) & (data['flag_sp'] == 0)][['sobject_id','model_name','closest_model','flag_sp','c_fe','flag_c_fe','n_fe','flag_n_fe','o_fe','flag_o_fe']]

In [None]:
# Plot HRD & abundances 

flag_sp_0 = data['flag_sp'] == 0
flag_sp_above0_but_results = (data['flag_sp'] > 0) & (data['flag_sp'] < np.max(data['flag_sp']))
flag_sp_results = data['flag_sp'] < np.max(data['flag_sp'])

for label in [
    'Li',
    'C',
    'N',
    'O',
    'Na',
    'Mg',
    'Al',
    'Si',
    'K',
    'Ca',
    'Sc',
    'Ti',
    'V',
    'Cr',
    'Mn',
    'Co',
    'Ni',
    'Cu',
    'Zn',
    'Rb',
    'Sr',
    'Y',
    'Zr',
    'Mo',
    'Ru',
    'Ba',
    'La',
    'Ce',
    'Nd',
    'Sm',
    'Eu'
    ]:
    
    flag0 = flag_sp_0 & (data['flag_'+label.lower()+'_fe'] == 0) #& (data['fe_h'] > -1)
    flag1 = flag_sp_0 & (data['flag_'+label.lower()+'_fe'] == 1) #& (data['fe_h'] > -1)
    flag_rest = flag_sp_above0_but_results & (data['flag_'+label.lower()+'_fe'] <= 1)

    f, gs = plt.subplots(1,3,figsize=(10,3),sharey=True)

    xbins = np.linspace(-2.5,0.75,50)
    if label == 'Li':
        ybins = np.linspace(0,4,50)
    elif label in ['C','N','O','Y','Ba','La','Ce','Nd']:
        ybins = np.linspace(-1,2,50)
    elif label in ['Mg','Si','Ti']:
        ybins = np.linspace(-0.5,1,50)
    else:
        ybins = np.linspace(-1,1,50)
    
    if label == 'Li':
        ydata = data[label.lower()+'_fe'] + data['fe_h'] + 1.05
    else:
        ydata = data[label.lower()+'_fe']
    
    # First panel: Detections for GALAH DR4 [Fe/H] vs. [X/Fe]
    ax = gs[0]
    ax.text(0.05,0.9,'a) Detections for ['+label+'/Fe]',ha='left',transform=ax.transAxes,fontsize=12,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75))
    ax.text(0.05,0.785,str(len(data['fe_h'][flag0]))+' ('+"{:.0f}".format(100*len(data['fe_h'][flag0])/len(data['fe_h'][flag_sp_results]))+r'%)',ha='left',transform=ax.transAxes,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75,lw=0))
    ax.set_xlabel('[Fe/H] (GALAH DR4)')
    if label == 'Li':
        ax.set_ylabel('A('+label+') (GALAH DR4)')
    else:
        ax.set_ylabel('['+label+'/Fe] (GALAH DR4)')

    if len(data['fe_h'][flag0]) > 10:
        corner.hist2d(
            data['fe_h'][flag0],
            ydata[flag0],
            bins = (xbins,ybins),
            range=[(xbins[0],xbins[-1]),(ybins[0],ybins[-1])],
            ax = ax
        )
    ax.set_xlim(xbins[0],xbins[-1])
    ax.set_ylim(ybins[0],ybins[-1])
    ax.errorbar(
        0.9*xbins[0]+0.1*xbins[-1],
        0.9*ybins[0]+0.1*ybins[-1],
        xerr=np.ma.median(data['e_fe_h'][flag0]),
        yerr=np.ma.median(data['e_'+label.lower()+'_fe'][flag0]),
        capsize=2,color='k'
    )
    
    # Second panel: Upper Limits for GALAH DR4 [Fe/H] vs. [X/Fe]
    ax = gs[1]
    ax.text(0.05,0.9,'b) Upper limits',ha='left',transform=ax.transAxes,fontsize=12,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75))
    ax.text(0.05,0.785,str(len(data['fe_h'][flag1]))+' ('+"{:.0f}".format(100*len(data['fe_h'][flag1])/len(data['fe_h'][flag_sp_results]))+r'%)',ha='left',transform=ax.transAxes,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75,lw=0))
    ax.set_xlabel('[Fe/H] (GALAH DR4)')
    
    if len(data['fe_h'][flag1]) > 10:
        corner.hist2d(
            data['fe_h'][flag1],
            ydata[flag1],
            bins = (xbins,ybins),
            range=[(xbins[0],xbins[-1]),(ybins[0],ybins[-1])],
            ax = ax
        )
    ax.set_xlim(xbins[0],xbins[-1])
    ax.set_ylim(ybins[0],ybins[-1])
    ax.errorbar(
        0.9*xbins[0]+0.1*xbins[-1],
        0.9*ybins[0]+0.1*ybins[-1],
        xerr=np.ma.median(data['e_fe_h'][flag1]),
        yerr=np.ma.median(data['e_'+label.lower()+'_fe'][flag1]),
        capsize=2,color='k'
    )

    # Second panel: Upper Limits for GALAH DR4 [Fe/H] vs. [X/Fe]
    ax = gs[2]
    ax.text(0.05,0.9,'c) Flagged Spectra',ha='left',transform=ax.transAxes,fontsize=12,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75))
    ax.text(0.05,0.785,str(len(data['fe_h'][flag_rest]))+' ('+"{:.0f}".format(100*len(data['fe_h'][flag_rest])/len(data['fe_h'][flag_sp_results]))+r'%)',ha='left',transform=ax.transAxes,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75,lw=0))
    ax.set_xlabel('[Fe/H] (GALAH DR4)')

    if len(data['fe_h'][flag_rest]) > 10:
        corner.hist2d(
            data['fe_h'][flag_rest],
            ydata[flag_rest],
            bins = (xbins,ybins),
            range=[(xbins[0],xbins[-1]),(ybins[0],ybins[-1])],
            ax = ax
        )
    ax.set_xlim(xbins[0],xbins[-1])
    ax.set_ylim(ybins[0],ybins[-1])
    ax.errorbar(
        0.9*xbins[0]+0.1*xbins[-1],
        0.9*ybins[0]+0.1*ybins[-1],
        xerr=np.ma.median(data['e_fe_h'][flag_rest]),
        yerr=np.ma.median(data['e_'+label.lower()+'_fe'][flag_rest]),
        capsize=2,color='k'
    )
    
    plt.tight_layout()
    plt.show()
#     plt.savefig('figures/galah_dr4_validation_overview_'+label.lower()+'_fe_density.png',dpi=200,bbox_inches='tight')
    plt.close()

In [None]:
flag_table = Table()
flag_table['flag_sp'],flag_table['counts'] = np.unique(data['flag_sp'],return_counts=True)
flag_table

In [None]:
plt.scatter(flag_table['flag_sp'],flag_table['counts'])

In [None]:
# Plot CNO

f, gs = plt.subplots(1,3,sharex=True,sharey=True,figsize=(2.5*3,2.5))

x_low  = -2.50
x_high =  0.75
y_low  = -1.00
y_high =  1.75

panels = ['a)','b)','c)','d)']

for i,label in enumerate(['C','N','O']):
#for i,label in enumerate(['C','N','CN','O']):
    ax = gs[i]

    if label in ['C','N','O']:
        flag0 = (data['flag_'+label.lower()+'_fe'] == 0) #& (data['fe_h'] > -1)
        flag1 = (data['flag_'+label.lower()+'_fe'] == 1) #& (data['fe_h'] > -1)
        corner.hist2d(
            data['fe_h'][flag0],
            data[label.lower()+'_fe'][flag0],
            ax = ax,bins=(np.linspace(x_low,x_high,50),np.linspace(y_low,y_high,50))
        )
#         ax.hist2d(
#             data['fe_h'][flag0],
#             data[label.lower()+'_fe'][flag0],
#             cmin = 1, bins=50
#             #s=1,label='Detection'
#         )
#         ax.scatter(
#             data['fe_h'][flag1],
#             data[label.lower()+'_fe'][flag1],
#             marker='v',label='Upper limit',
#             s=0.5
#         )
        
    if label == 'CN':
        flag0 = (data['flag_c_fe'] == 0) & (data['flag_n_fe'] == 0) #& (data['fe_h'] > -1)
        flag1 = ((data['flag_c_fe'] == 1) | (data['flag_n_fe'] == 1)) #& (data['fe_h'] > -1)
        corner.hist2d(
            data['fe_h'][flag0],
            data['c_fe'][flag0]-data['n_fe'][flag0],
            ax = ax,bins=(np.linspace(x_low,x_high,50),np.linspace(y_low,y_high,50))
        )
        #         ax.hist2d(
#             data['fe_h'][flag0],
#             data['c_fe'][flag0]-data['n_fe'][flag0],
#             cmin = 1, bins=50
# #             s=1,label='Detection'
#         )
#         ax.scatter(
#             data['fe_h'][flag1],
#             data['c_fe'][flag1]-data['n_fe'][flag1],
#             marker='v',label='Upper limit',
#             s=0.5
#         )
    
    ax.set_xlim(x_low,x_high)
    ax.set_ylim(y_low,y_high)
    
#     if i==0:
#         ax.legend()
    ax.set_xlabel('[Fe/H]')
    if label in ['C','N','O']:
        ax.set_ylabel('['+label+'/Fe]')
    if label == 'CN':
        ax.set_ylabel('[C/N]')
plt.tight_layout()
#plt.savefig('figures/overview_CNO_incl_upper_limits.png',dpi=200,bbox_inches='tight')
plt.savefig('figures/overview_CNO.png',dpi=200,bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# # [X/Fe] = [X/H] - [M/H]
# # [X/H] + (A_X - 12) = log(N_X/N_H) 
# # [X/Fe] = log(N_X / N_H) - log(N_X / N_H)_Sun - [Fe/H]
# # [C + N / Fe] = log((N_C + N_N) / N_H) - log((N_C_Sun + (N_N_Sun / N_H)_Sun - [Fe/H]
# # [C + N / Fe] = log(N_C/N_H + N_N/N_H) - log(N_C_Sun/H_Sun + N_N_Sun/N_H_Sun) - [Fe/H]

# A_N_Sun = 7.78+0.15
# A_C_Sun = 8.39+0.037

# A_C = data['c_fe'] + data['fe_h'] + (A_C_Sun)
# A_N = data['n_fe'] + data['fe_h'] + (A_N_Sun)

# N_C_N_H = 10**(A_C + 12)
# N_N_N_H = 10**(A_N + 12)
# N_C_N_H_sun = 10**(A_C_Sun + 12)
# N_N_N_H_sun = 10**(A_N_Sun + 12)

# data['cn_fe'] = np.log10(N_C_N_H + N_N_N_H) - np.log10(N_C_N_H_sun + N_N_N_H_sun) - data['fe_h']

In [None]:
# def cn_masses(fe_h, c_fe, n_fe, cn_fe):
#     return(
#         1.08 - 0.18 * fe_h + 4.30 * c_fe +1.43 * n_fe - 7.55 * cn_fe
#         - 1.05 * (fe_h)**2 - 1.12 * (fe_h * c_fe) - 0.67 * (fe_h * n_fe) - 1.30 * (fe_h * cn_fe)
#         - 49.92 * (c_fe)**2 - 41.04 * (c_fe * n_fe) + 139.92 * (c_fe * cn_fe)
#         - 0.63 * (n_fe)**2 + 47.33 * (n_fe * cn_fe)
#         - 86.62 * (cn_fe)**2
#     )
# data['mass'] = cn_masses(fe_h=data['fe_h'], c_fe=data['c_fe'], n_fe=data['n_fe'], cn_fe=(data['cn_fe']-0.1)/2.)

In [None]:
# def cn_ages(fe_h, c_fe, n_fe, cn_fe, teff, logg):
#     return(
#         -54.35 + 6.53*fe_h -19.02 *c_fe -12.18*n_fe +37.22*cn_fe +59.58*teff +16.14*logg
#         +0.74*fe_h*fe_h +4.04*fe_h*c_fe +0.76*fe_h*n_fe -4.94*fe_h*cn_fe -1.46*fe_h*teff -1.56*fe_h*logg
#         +26.90*c_fe*c_fe +13.33*c_fe*n_fe -77.84*c_fe*cn_fe +48.29*c_fe*teff -13.12*c_fe*logg
#         -1.04*n_fe*n_fe -17.60*n_fe*cn_fe +13.99*n_fe*teff -1.77*n_fe*logg
#         +51.24*cn_fe*cn_fe -65.67*cn_fe*teff +14.24*cn_fe*logg
#         +15.54*teff*teff -34.68*teff*logg
#         +4.17*logg*logg
#     )
# data['age'] = 10**(cn_ages(fe_h=data['fe_h'], c_fe=data['c_fe'], n_fe=data['n_fe'], cn_fe=data['cn_fe'], teff=data['teff']/4000., logg=data['logg']))