In [1]:
# Preamble 
try:
    %matplotlib inline
    %config InlineBackend.figure_format='retina'
except:
    pass

from astropy.table import Table,join,hstack,vstack
import glob
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import corner
import time
import pickle
import imageio as iio
from datetime import date

In [2]:
# Collect the date stamps of all daily FITS files: the 6 characters in front of
# the '.fits' extension of each file name, in ascending order.
dates = sorted(fits_path[-11:-5] for fits_path in glob.glob('daily/*.fits'))

print(dates)

['131216', '140305', '140307', '140314', '140315', '140316']


In [3]:
def combine_dates(dates):
    """
    Read the daily allspec tables of all given dates and stack them into one table.

    INPUT:
    dates : sequence of date identifiers; each one is inserted via str() into
        'daily/galah_dr4_allspec_not_validated_<date>.fits'

    OUTPUT:
    astropy Table with the rows of all daily files in the order of *dates*;
    its 'tmass_id' column is converted to a str array

    RAISES:
    IndexError if *dates* is empty
    """
    if len(dates) == 0:
        raise IndexError('combine_dates needs at least one date, but received none')

    daily_tables = [
        Table.read('daily/galah_dr4_allspec_not_validated_'+str(night)+'.fits')
        for night in dates
    ]

    # Stack all tables in one call instead of re-stacking (and thereby re-copying)
    # the growing table once per date
    data = vstack(daily_tables)
    data['tmass_id'] = np.array(data['tmass_id'],dtype=str)

    return(data)

data = combine_dates(dates)

In [4]:
# Reduce the abundance flags to the values 0/1/2 and blank out unusable values.
# For each quantity Q (all [X/Fe] columns, then [Fe/H]):
#   flag_Q == -1 : Q and e_Q are set to NaN and the flag is raised to 2
#   flag_Q ==  2 : Q is set to NaN (e_Q is left as it is)
abundance_columns = [
    element.lower()+'_fe' for element in [
        'Li','C','N','O',
        'Na','Mg','Al','Si',
        'K','Ca','Sc','Ti','V','Cr','Mn','Co','Ni','Cu','Zn',
        'Rb','Sr','Y','Zr','Mo','Ru',
        'Ba','La','Ce','Nd','Sm','Eu'
    ]
]
for column in abundance_columns + ['fe_h']:
    # Both selections are taken before the flag column is modified below
    flag_is_minus_1 = np.where(data['flag_'+column] == -1)[0]
    flag_is_2 = np.where(data['flag_'+column] == 2)[0]

    # np.nan instead of np.NaN: the upper-case alias was removed in NumPy 2.0
    data[column][flag_is_minus_1] = np.nan
    data['e_'+column][flag_is_minus_1] = np.nan
    data[column][flag_is_2] = np.nan
    data['flag_'+column][flag_is_minus_1] = 2

In [5]:
# Preview the first five and the last five rows of the combined table
data[[0,1,2,3,4,-5,-4,-3,-2,-1]]

sobject_id,tmass_id,gaiadr3_source_id,ra,dec,flag_sp,chi2_sp,model_name,rv,e_rv,teff,e_teff,logg,e_logg,fe_h,e_fe_h,flag_fe_h,vmic,e_vmic,vsini,e_vsini,li_fe,e_li_fe,flag_li_fe,c_fe,e_c_fe,flag_c_fe,n_fe,e_n_fe,flag_n_fe,o_fe,e_o_fe,flag_o_fe,na_fe,e_na_fe,flag_na_fe,mg_fe,e_mg_fe,flag_mg_fe,al_fe,e_al_fe,flag_al_fe,si_fe,e_si_fe,flag_si_fe,k_fe,e_k_fe,flag_k_fe,ca_fe,e_ca_fe,flag_ca_fe,sc_fe,e_sc_fe,flag_sc_fe,ti_fe,e_ti_fe,flag_ti_fe,v_fe,e_v_fe,flag_v_fe,cr_fe,e_cr_fe,flag_cr_fe,mn_fe,e_mn_fe,flag_mn_fe,co_fe,e_co_fe,flag_co_fe,ni_fe,e_ni_fe,flag_ni_fe,cu_fe,e_cu_fe,flag_cu_fe,zn_fe,e_zn_fe,flag_zn_fe,rb_fe,e_rb_fe,flag_rb_fe,sr_fe,e_sr_fe,flag_sr_fe,y_fe,e_y_fe,flag_y_fe,zr_fe,e_zr_fe,flag_zr_fe,mo_fe,e_mo_fe,flag_mo_fe,ru_fe,e_ru_fe,flag_ru_fe,ba_fe,e_ba_fe,flag_ba_fe,la_fe,e_la_fe,flag_la_fe,ce_fe,e_ce_fe,flag_ce_fe,nd_fe,e_nd_fe,flag_nd_fe,sm_fe,e_sm_fe,flag_sm_fe,eu_fe,e_eu_fe,flag_eu_fe,v_bary_eff,red_rv_com,red_e_rv_com,red_teff,red_logg,red_fe_h,red_alpha_fe,red_vmic,red_vbroad,red_flag,sb2_rv_16,sb2_rv_50,sb2_rv_84,ew_h_beta,ew_h_alpha,ew_k_is,sigma_k_is,rv_k_is,ew_dib5780,sigma_dib5780,rv_dib5780,ew_dib5797,sigma_dib5797,rv_dib5797,ew_dib6613,sigma_dib6613,rv_dib6613,snr_px_ccd1,snr_px_ccd2,snr_px_ccd3,snr_px_ccd4
int64,str16,int64,float64,float64,int64,float32,bytes16,float32,float32,float32,float32,float32,float32,float32,float32,int64,float32,float32,float32,float32,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float32,float32,int64,float64,float64,float64,float64,float64,float64,float64,float64,float64,int64,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32
131216001101002,05190449-5849304,4762794963745841536,79.76874542236328,-58.82512664794922,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,-2.427048444747925,6.073256492614746,0.1358347833156585,4855.169921875,2.953188419342041,-0.3962270617485046,0.2576413750648498,1.2092074155807495,6.407722473144531,0,,,,,,,,,,,,,,,,,,24.843557,36.874645,52.64503,46.820675
131216001101004,05194296-5852488,4762782766038731776,79.92901611328125,-58.88024139404297,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,-2.3811161518096924,54.54459381103516,0.1737816929817199,5057.93603515625,3.2822084426879883,-0.5430994033813477,0.2488781213760376,1.1383854150772097,5.95846700668335,0,,,,,,,,,,,,,,,,,,26.910444,39.116535,56.098625,49.812645
131216001101006,05242175-5855050,4762746688313325568,81.09065246582031,-58.91807556152344,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,-2.0605223178863525,20.26144218444824,0.4803440570831299,5876.05126953125,2.978644847869873,0.0700777769088745,0.6561186909675598,2.177366256713867,19.49560546875,0,,,,,,,,,,,,,,,,,,2.7553246,7.7672753,19.283684,26.94849
131216001101007,05235853-5855322,4762746963191266560,80.993896484375,-58.92562484741211,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,-2.086958408355713,100.09342193603516,0.1350113600492477,4738.349609375,2.2691359519958496,-0.3850209712982178,0.188379094004631,1.3707609176635742,7.659144878387451,0,,,,,,,,,,,,,,,,,,15.785746,32.32001,54.914253,55.59654
131216001101008,05250796-5856306,4762744661088764928,81.28317260742188,-58.94184112548828,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,-2.0070245265960693,54.40141296386719,0.285716563463211,5019.4443359375,3.933474540710449,-0.3283387720584869,0.2382774800062179,1.1003286838531494,6.664685249328613,0,,,,,,,,,,,,,,,,,,13.803916,22.49135,35.75572,33.920254
140316005601395,16153641-3245057,6035610407408704128,243.90174865722656,-32.751609802246094,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,28.019428253173828,-8.736614227294922,0.1311791092157364,4615.81591796875,2.527228355407715,-0.056187979876995,0.0021126791834831,1.2893657684326172,7.65748405456543,0,,,,,,,,,,,,,,,,,,17.588314,27.453379,41.523296,35.11009
140316005601396,16142539-3241175,6035625495643217408,243.6057891845703,-32.6882209777832,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,27.97958374023437,2.2102108001708984,0.1515788584947586,4060.4716796875,1.191803216934204,-0.4915099442005157,0.1860723942518234,1.842038869857788,9.494373321533203,0,,,,,,,,,,,,,,,,,,9.287596,18.213549,30.225048,28.382097
140316005601397,16150005-3243401,6035614122569635840,243.75021362304688,-32.727806091308594,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,27.99852180480957,-38.146484375,0.934076488018036,6596.67626953125,3.9957125186920166,0.3068985342979431,0.0830441713333129,1.7047388553619385,21.854345321655277,0,,,,,,,,,,,,,,,,,,9.4449215,13.380719,18.037834,13.098002
140316005601398,16170892-3243096,6035589967672418944,244.2872009277344,-32.71933364868164,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,28.078269958496094,-20.777942657470703,0.1568147838115692,4862.45068359375,2.6765761375427246,-0.2006494849920272,0.2175548821687698,1.2482497692108154,8.162784576416016,0,,,,,,,,,,,,,,,,,,13.017974,23.901575,38.160088,36.23345
140316005601399,16164809-3244336,6035587115814145408,244.20037841796875,-32.74266815185547,128,,teff_logg_fe_h,,,,,,,,,2,,,,,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,,,2,28.06399154663086,-24.45172119140625,0.1533706337213516,5156.1572265625,4.226296424865723,0.3157302141189575,0.0467085316777229,1.1338142156600952,7.242300510406494,0,,,,,,,,,,,,,,,,,,13.501531,23.032406,36.62643,31.874037


In [6]:
# Save the combined table as one FITS file, replacing an existing file of the same name
data.write('galah_dr4_allspec_not_validated.fits',overwrite=True)

In [None]:
# Load the flag_sp dictionary. Each value is indexed with [0] (used as bit mask
# below) and [1] (printed next to it; presumably a description - TODO confirm).
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
# The context manager closes the file even if unpickling fails.
with open("final_flag_sp_dictionary.pkl", "rb") as a_file:
    flag_sp_dictionary = pickle.load(a_file)
print(flag_sp_dictionary)

# One entry per distinct flag_sp value: [value, number of spectra, names of all set bits]
entries = []
for flag in np.unique(data['flag_sp']):
    
    flag_text = []
    for flag_key in flag_sp_dictionary.keys():
        # The key applies if all bits of its mask are set in this flag value
        if((flag & flag_sp_dictionary[flag_key][0]) == flag_sp_dictionary[flag_key][0]):
            flag_text.append(flag_key)
    entries.append([flag,len(data['flag_sp'][data['flag_sp']==flag]),", ".join(flag_text)])
# Mixed numbers and text: the array (and therefore the table columns) holds strings
entries = np.array(entries)

a = Table()
a['flag_sp'] = entries[:,0]
a['nr_spectra'] = entries[:,1]
a['flag_sp_keys'] = entries[:,2]
for s in flag_sp_dictionary.keys():
    print(flag_sp_dictionary[s][0],flag_sp_dictionary[s][1])

In [None]:
# Select and display all spectra with 64 <= flag_sp < 128 and teff above 6500
bad = data[(data['flag_sp'] >= 64) & (data['flag_sp'] < 128) & (data['teff'] > 6500)]
bad

In [None]:
from scipy.spatial import cKDTree

# Nearest-neighbour tree over the (teff, logg, fe_h) grid points of the training set
grids = Table.read('../spectrum_grids/galah_dr4_model_trainingset_gridpoints.fits')
grid_points = np.c_[grids['teff_subgrid'],grids['logg_subgrid'],grids['fe_h_subgrid']]
grid_index_tree = cKDTree(grid_points)

# Index of the closest grid point for every spectrum with 64 <= flag_sp < 128
in_flag_range = (data['flag_sp'] >= 64) & (data['flag_sp'] < 128)
model_needed = np.array([
    grid_index_tree.query([spectrum['teff'],spectrum['logg'],spectrum['fe_h']],k=1)[1]
    for spectrum in data[in_flag_range]
])

# Tally the grid points and list the ten most frequent ones
unique_models, model_counts = np.unique(model_needed,return_counts=True)
models_needed = Table()
models_needed['models'] = unique_models
models_needed['counts'] = model_counts
models_needed.sort(keys = 'counts', reverse=True)
models_needed[:10]

In [None]:
# Tally the model_name entries of all spectra with 64 <= flag_sp < 128
# and list the ten most frequent ones
in_flag_range = (data['flag_sp'] >= 64) & (data['flag_sp'] < 128)
model_names, name_counts = np.unique(data['model_name'][in_flag_range],return_counts=True)

number = Table()
number['model_name'] = model_names
number['count'] = name_counts
number.sort(keys='count',reverse=True)
number[:10]

In [None]:
# Plot [Fe/H] vs. [X/Fe] densities per element in three panels:
# a) detections, b) upper limits, c) flagged spectra that still have results

flag_sp_0 = data['flag_sp'] == 0
flag_sp_above0_but_results = (data['flag_sp'] > 0) & (data['flag_sp'] < np.max(data['flag_sp']))
flag_sp_results = data['flag_sp'] < np.max(data['flag_sp'])

def _plot_abundance_density_panel(ax, title, selection, label, ydata, xbins, ybins):
    """
    Fill one [Fe/H] vs. abundance panel for the spectra in *selection*.

    INPUT:
    ax : matplotlib axis to draw on
    title : text of the panel title box
    selection : boolean mask over the rows of *data*
    label : element name; used to look up the uncertainty column 'e_<label>_fe'
    ydata : y-axis values for all rows of *data*
    xbins/ybins : bin edges, also used as axis limits

    The panel shows the number of selected spectra (and their percentage of all
    spectra with flag_sp below the maximum), a density map if more than 10 spectra
    are selected, and one error bar with the median uncertainties.
    """
    nr_selected = len(data['fe_h'][selection])
    nr_results = len(data['fe_h'][flag_sp_results])
    # Guard against a ZeroDivisionError if no spectrum has flag_sp below the maximum
    percentage = 100*nr_selected/nr_results if nr_results > 0 else np.nan

    ax.text(0.05,0.9,title,ha='left',transform=ax.transAxes,fontsize=12,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75))
    ax.text(0.05,0.785,str(nr_selected)+' ('+"{:.0f}".format(percentage)+r'%)',ha='left',transform=ax.transAxes,bbox=dict(boxstyle='round', facecolor='w', alpha=0.75,lw=0))
    ax.set_xlabel('[Fe/H] (GALAH DR4)')

    if nr_selected > 10:
        corner.hist2d(
            data['fe_h'][selection],
            ydata[selection],
            bins = (xbins,ybins),
            range=[(xbins[0],xbins[-1]),(ybins[0],ybins[-1])],
            ax = ax
        )
    ax.set_xlim(xbins[0],xbins[-1])
    ax.set_ylim(ybins[0],ybins[-1])
    # Median uncertainties, drawn at 10% of the axis range from the lower-left corner
    ax.errorbar(
        0.9*xbins[0]+0.1*xbins[-1],
        0.9*ybins[0]+0.1*ybins[-1],
        xerr=np.ma.median(data['e_fe_h'][selection]),
        yerr=np.ma.median(data['e_'+label.lower()+'_fe'][selection]),
        capsize=2,color='k'
    )

# Elements to plot. Further elements with the same column layout:
# Al, Si, K, Ca, Sc, Ti, V, Cr, Mn, Co, Ni, Cu, Zn, Rb, Sr, Y, Zr, Mo, Ru, Ba, La, Ce, Nd, Sm, Eu
for label in ['Li','C','N','O','Na','Mg']:
    
    flag0 = flag_sp_0 & (data['flag_'+label.lower()+'_fe'] == 0)
    flag1 = flag_sp_0 & (data['flag_'+label.lower()+'_fe'] == 1)
    flag_rest = flag_sp_above0_but_results & (data['flag_'+label.lower()+'_fe'] <= 1)

    f, gs = plt.subplots(1,3,figsize=(10,3),sharey=True)

    xbins = np.linspace(-2.5,0.75,100)
    if label == 'Li':
        ybins = np.linspace(0,4,100)
    elif label in ['O','Ba','Y']:
        ybins = np.linspace(-2,2,100)
    else:
        ybins = np.linspace(-1,1.25,100)
    
    # Li is shown as A(Li) = [Li/Fe] + [Fe/H] + 1.05, all other elements as [X/Fe]
    if label == 'Li':
        ydata = data[label.lower()+'_fe'] + data['fe_h'] + 1.05
    else:
        ydata = data[label.lower()+'_fe']
    
    # First panel: detections (flag_sp == 0 and element flag == 0)
    _plot_abundance_density_panel(gs[0], 'a) Detections for ['+label+'/Fe]', flag0, label, ydata, xbins, ybins)
    if label == 'Li':
        gs[0].set_ylabel('A('+label+') (GALAH DR4)')
    else:
        gs[0].set_ylabel('['+label+'/Fe] (GALAH DR4)')

    # Second panel: upper limits (flag_sp == 0 and element flag == 1)
    _plot_abundance_density_panel(gs[1], 'b) Upper limits', flag1, label, ydata, xbins, ybins)

    # Third panel: flagged spectra (0 < flag_sp < maximum) with element flag <= 1
    _plot_abundance_density_panel(gs[2], 'c) Flagged Spectra', flag_rest, label, ydata, xbins, ybins)
    
    plt.tight_layout()
    plt.show()
#     plt.savefig('figures/galah_dr4_validation_overview_'+label.lower()+'_fe_density.png',dpi=200,bbox_inches='tight')
    plt.close()

In [None]:
def plot_galah_dr4_overview_A4():
    """
    Create, save and show the one-page GALAH DR4 overview figure.

    The figure (11.75 x 8.25 inch) is a 6x6 panel grid with
    - a Teff-logg density map of all spectra with flag_sp == 0 in the upper-left 2x2 panels,
    - one [Fe/H] vs. [X/Fe] density panel per element, using spectra with
      flag_sp == 0 and flag_X_fe == 0, and
    - the logo image with a date stamp in the bottom-right panel.

    INPUT:
    None; reads the module-level table *data* and 'figures/logo_desktop.png'

    OUTPUT:
    None; writes 'figures/galah_dr4_overview.png'
    """
    elements = [
        'Li','C','N','O',
        'Na','Mg','Al','Si',
        'K','Ca','Sc','Ti','V','Cr',
        'Mn','Co','Ni','Cu','Zn','Rb',
        'Sr','Y','Zr','Mo','Ru','Ba',
        'La','Ce','Nd','Sm','Eu'
        ]

    # Background colour of the label box for each group of elements.
    # An explicit lookup fails loudly for an element without colour instead of
    # silently reusing the colour of the previous panel.
    label_box_colors = {}
    for box_color, element_group in [
        ('#2B292C', ['Li']),
        ('#9AA2C9', ['C','O','Na','Mg','Al','Si','K','Ca','Sc','Ti','Cu','Zn','Rb']),
        ('#E69774', ['V','Cr','Mn','Fe','Co','Ni']),
        ('#C4DAAB', ['N','Sr','Y','Zr','Mo','Ru','Ba','La','Ce','Nd','Sm']),
        ('#D0A8C5', ['Eu']),
    ]:
        for element in element_group:
            label_box_colors[element] = box_color

    fig, axs = plt.subplots(ncols=6, nrows=6, figsize=(11.75,8.25))
    gs = axs[0, 0].get_gridspec()
    # Free the upper-left 2x2 panels for the large Teff-logg panel
    for ax in axs[0, :2]:
        ax.remove()
    for ax in axs[1, :2]:
        ax.remove()

    # HRD 

    flag_sp_0 = data['flag_sp'] == 0

    teff_limits = [3750,7250]
    logg_limits = [-0.25,5.5]

    axbig = fig.add_subplot(gs[:2, :2])
    corner.hist2d(
        data['teff'][flag_sp_0],
        data['logg'][flag_sp_0],
        bins=(np.linspace(teff_limits[0],teff_limits[1],100),np.linspace(logg_limits[0],logg_limits[1],100)),
        ax = axbig
    )
    # Both axes are inverted (limits given as upper, lower)
    axbig.set_xlim(teff_limits[1],teff_limits[0])
    axbig.set_ylim(logg_limits[1],logg_limits[0])
    axbig.text(0.5,0.93,str(len(data['teff'][flag_sp_0]))+r' unflagged spectra (of '+str(len(data['teff']))+')',transform=axbig.transAxes,ha='center',bbox=dict(boxstyle='square,pad=0',lw=0, facecolor='w', alpha=0.75))
    axbig.text(0.5,0.035,r'$T_\mathrm{eff}~/~\mathrm{K}$',transform=axbig.transAxes,ha='center',bbox=dict(boxstyle='square,pad=0',lw=0, facecolor='w', alpha=0.75))
    axbig.text(0.02,0.5,r'$\log (g~/~\mathrm{cm\,s^{-2}})$',transform=axbig.transAxes,va='center',rotation=90,bbox=dict(boxstyle='square,pad=0',lw=0, facecolor='w', alpha=0.75))

    # Logo with title and date stamp in the bottom-right panel
    img = iio.imread("figures/logo_desktop.png")

    ax = axs[-1,-1]
    ax.imshow(img)
    ax.axis('off')
    ax.text(0.5,+1.2,'DR4 Preview',transform=ax.transAxes,ha='center')
    ax.text(0.6,-0.6,'Buder et al. \n stand '+date.today().strftime("%y%m%d"),transform=ax.transAxes,ha='center')

    xbins = np.linspace(-2.5,0.75,100)

    for ind, label in enumerate(elements):    
        # Rows 0-1 only offer columns 2-5 (4 panels each), rows 2-5 offer all 6 columns;
        # the very last panel (index 31) is taken by the logo
        if ind <= 7:
            ax = axs[ind // 4, 2 + ind % 4]
        elif ind <= 30:
            ax = axs[2 + (ind - 8) // 6, (ind - 8) % 6]
        else:
            raise ValueError('No free panel left for element '+label)

        color = label_box_colors[label]

        if label == 'Li':
            ybins = np.linspace(-3,4,100)
        elif label in ['O','Ba','Y']:
            ybins = np.linspace(-2,2,100)
        else:
            ybins = np.linspace(-1,1.25,100)

        # White text on the dark Li box, black text otherwise
        if label == 'Li':
            textcolor = 'w'
            ax.set_yticks([-2,0,2])
        else:
            textcolor = 'k'

        # Detections in unflagged spectra
        flag0 = flag_sp_0 & (data['flag_'+label.lower()+'_fe'] == 0)

        ax.text(0.05,0.83,'['+label+'/Fe]',ha='left',color=textcolor, transform=ax.transAxes,fontsize=10,bbox=dict(boxstyle='round', facecolor=color, alpha=0.75))
        # Percentage of unflagged spectra with a detection of this element
        ax.text(0.97,0.05,"{:.0f}".format(100*len(data['e_fe_h'][flag0])/len(data['teff'][flag_sp_0]))+'%',ha='right', transform=ax.transAxes,fontsize=10,bbox=dict(boxstyle='square,pad=0',lw=0,facecolor='w', alpha=0.75))

        # Contour levels; fewer levels are drawn for two groups of elements
        if label in ['La','Ce','Sm','Eu']:
            levels = (0.9,0.68,0.5,0.3)
        elif label in ['Rb','Sr','Mo','Ru']:
            levels = (0.5,)
        else:
            levels = (0.97,0.9,0.68,0.5,0.3)
            
        if len(data['fe_h'][flag0]) > 0:
            corner.hist2d(
                data['fe_h'][flag0],
                data[label.lower()+'_fe'][flag0],
                bins = (xbins,ybins),
                levels = levels,
                contour_kwargs=dict(colors=['k']),
                color=color,
                range=[(xbins[0],xbins[-1]),(ybins[0],ybins[-1])],
                ax = ax
            )
        ax.set_xlim(xbins[0],xbins[-1])
        ax.set_ylim(ybins[0],ybins[-1])
        # Median uncertainties, drawn at 10% of the axis range from the lower-left corner
        ax.errorbar(
            0.9*xbins[0]+0.1*xbins[-1],
            0.9*ybins[0]+0.1*ybins[-1],
            xerr=np.ma.median(data['e_fe_h'][flag0]),
            yerr=np.ma.median(data['e_'+label.lower()+'_fe'][flag0]),
            capsize=2,color='k'
        )
        # Only the bottom row gets an x-axis label
        if ind >= 26:
            ax.set_xlabel('[Fe/H]')
        ax.set_xticks([-2,-1,0])
    plt.tight_layout(h_pad=0,w_pad=0)
    plt.savefig('figures/galah_dr4_overview.png',dpi=150,bbox_inches='tight')
    plt.show()
    plt.close()
plot_galah_dr4_overview_A4()

In [None]:
# Plot CNO: [Fe/H] vs. [X/Fe] density of all detections (element flag == 0)

# One panel per entry. Use 'CN' for the ratio [C/N] = [C/Fe] - [N/Fe] of stars
# with both C and N detected, e.g. cno_labels = ['C','N','CN','O']
cno_labels = ['C','N','O']

# The number of panels follows the list above; squeeze=False always returns a
# 2D array of axes, of which the single row is used
f, gs = plt.subplots(1,len(cno_labels),sharex=True,sharey=True,figsize=(2.5*len(cno_labels),2.5),squeeze=False)
gs = gs[0]

x_low  = -2.50
x_high =  0.75
y_low  = -1.00
y_high =  1.75
density_bins = (np.linspace(x_low,x_high,50),np.linspace(y_low,y_high,50))

for i,label in enumerate(cno_labels):
    ax = gs[i]

    if label == 'CN':
        flag0 = (data['flag_c_fe'] == 0) & (data['flag_n_fe'] == 0)
        ydata = data['c_fe'] - data['n_fe']
        ylabel = '[C/N]'
    else:
        flag0 = (data['flag_'+label.lower()+'_fe'] == 0)
        ydata = data[label.lower()+'_fe']
        ylabel = '['+label+'/Fe]'

    corner.hist2d(
        data['fe_h'][flag0],
        ydata[flag0],
        ax = ax,bins=density_bins
    )
    
    ax.set_xlim(x_low,x_high)
    ax.set_ylim(y_low,y_high)
    
    ax.set_xlabel('[Fe/H]')
    ax.set_ylabel(ylabel)
plt.tight_layout()
plt.savefig('figures/overview_CNO.png',dpi=200,bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Combined carbon+nitrogen abundance [(C+N)/Fe]
#
# A(X)   = log10(N_X/N_H) + 12
# [X/H]  = A(X) - A(X)_Sun
# [X/Fe] = [X/H] - [Fe/H]
# =>  A(X) = [X/Fe] + [Fe/H] + A(X)_Sun
# [(C+N)/Fe] = log10(N_C/N_H + N_N/N_H) - log10((N_C/N_H)_Sun + (N_N/N_H)_Sun) - [Fe/H]

# NOTE(review): each solar value is written as a sum of two numbers; the second
# one looks like a zero-point offset - confirm its origin
A_N_Sun = 7.78+0.15
A_C_Sun = 8.39+0.037

A_C = data['c_fe'] + data['fe_h'] + (A_C_Sun)
A_N = data['n_fe'] + data['fe_h'] + (A_N_Sun)

# Number ratios relative to hydrogen. A(X) contains +12, so the inverse needs -12.
# (The sign only matters for these intermediate values: a common factor in all
# four ratios cancels in the difference of the logarithms below.)
N_C_N_H = 10**(A_C - 12)
N_N_N_H = 10**(A_N - 12)
N_C_N_H_sun = 10**(A_C_Sun - 12)
N_N_N_H_sun = 10**(A_N_Sun - 12)

data['cn_fe'] = np.log10(N_C_N_H + N_N_N_H) - np.log10(N_C_N_H_sun + N_N_N_H_sun) - data['fe_h']

In [None]:
def cn_masses(fe_h, c_fe, n_fe, cn_fe):
    """
    Evaluate a second-order polynomial in fe_h, c_fe, n_fe and cn_fe.

    INPUT:
    fe_h, c_fe, n_fe, cn_fe : numbers or arrays of matching shape

    OUTPUT:
    value of the polynomial (constant term 1.08); elsewhere in this notebook the
    result is stored as 'mass' and plotted in units of solar masses
    """
    # Signed terms in the order in which they are added to the constant
    signed_terms = (
        # linear terms
        -0.18 * fe_h,
        4.30 * c_fe,
        1.43 * n_fe,
        -7.55 * cn_fe,
        # terms with fe_h
        -1.05 * (fe_h)**2,
        -1.12 * (fe_h * c_fe),
        -0.67 * (fe_h * n_fe),
        -1.30 * (fe_h * cn_fe),
        # terms with c_fe
        -49.92 * (c_fe)**2,
        -41.04 * (c_fe * n_fe),
        139.92 * (c_fe * cn_fe),
        # terms with n_fe
        -0.63 * (n_fe)**2,
        47.33 * (n_fe * cn_fe),
        # term with cn_fe only
        -86.62 * (cn_fe)**2,
    )

    polynomial = 1.08
    for term in signed_terms:
        polynomial = polynomial + term
    return(polynomial)
# Store the cn_masses polynomial for every spectrum as column 'mass'.
# NOTE(review): cn_fe is passed as (cn_fe - 0.1)/2, whereas the cn_ages call in
# this notebook passes cn_fe unchanged - the reason is not visible here; confirm.
data['mass'] = cn_masses(fe_h=data['fe_h'], c_fe=data['c_fe'], n_fe=data['n_fe'], cn_fe=(data['cn_fe']-0.1)/2.)

In [None]:
def cn_ages(fe_h, c_fe, n_fe, cn_fe, teff, logg):
    """
    Evaluate a second-order polynomial in fe_h, c_fe, n_fe, cn_fe, teff and logg.

    INPUT:
    fe_h, c_fe, n_fe, cn_fe, teff, logg : numbers or arrays of matching shape
        (the caller in this notebook passes teff divided by 4000)

    OUTPUT:
    value of the polynomial (constant term -54.35); the caller in this notebook
    uses it as the exponent of 10 to obtain the 'age' column
    """
    # Signed terms in the order in which they are added to the constant
    signed_terms = (
        # linear terms
        6.53*fe_h, -19.02 *c_fe, -12.18*n_fe, 37.22*cn_fe, 59.58*teff, 16.14*logg,
        # terms with fe_h
        0.74*fe_h*fe_h, 4.04*fe_h*c_fe, 0.76*fe_h*n_fe, -4.94*fe_h*cn_fe, -1.46*fe_h*teff, -1.56*fe_h*logg,
        # terms with c_fe
        26.90*c_fe*c_fe, 13.33*c_fe*n_fe, -77.84*c_fe*cn_fe, 48.29*c_fe*teff, -13.12*c_fe*logg,
        # terms with n_fe
        -1.04*n_fe*n_fe, -17.60*n_fe*cn_fe, 13.99*n_fe*teff, -1.77*n_fe*logg,
        # terms with cn_fe
        51.24*cn_fe*cn_fe, -65.67*cn_fe*teff, 14.24*cn_fe*logg,
        # terms with teff
        15.54*teff*teff, -34.68*teff*logg,
        # term with logg only
        4.17*logg*logg,
    )

    polynomial = -54.35
    for term in signed_terms:
        polynomial = polynomial + term
    return(polynomial)
# Store 10**cn_ages(...) for every spectrum as column 'age'; teff is passed in
# units of 4000 (presumably K - TODO confirm), all other inputs unchanged.
data['age'] = 10**(cn_ages(fe_h=data['fe_h'], c_fe=data['c_fe'], n_fe=data['n_fe'], cn_fe=data['cn_fe'], teff=data['teff']/4000., logg=data['logg']))

In [None]:
def hist2d_bin_colored(X,Y,Z,X_label=r'X\_label',Y_label=r'Y\_label',Z_label=r'Z\_label',bins=30,bin_function='median',ax=None,cmap='seismic_r',minimum_bin_entries = 5,**kwargs):
    """
    Plot a 2D histogram of X and Y whose bins are colored by a statistic of Z.

    INPUT:
    X : x-axis parameter
    Y : y-axis parameter
    Z : parameter that will be used for coloring the bins
    X/Y/Z_label : label names
    bins = 30, but you can also give it bins = (np.linspace(x_min,x_max,30),np.linspace(y_min,y_max,30))
    bin_function : median/average/sum/std
    ax : if you plot it as part of an f,ax = plt.subplots()
    cmap : colormap handed to imshow
    minimum_bin_entries : how many entries a bin needs before it is colored at all
    **kwargs : further keyword arguments for imshow (e.g. vmin, vmax); they
        override the defaults cmap / aspect='auto' / origin='lower'

    OUTPUT:
    None; draws the image and its colorbar onto *ax* (or the current axis)

    RAISES:
    NameError if bin_function is not one of the available functions
    """
    
    # Check the requested statistic up front, so that a typo is reported even
    # if no bin happens to reach *minimum_bin_entries*
    available_bin_functions = dict(median=np.median, average=np.average, sum=np.sum, std=np.std)
    if bin_function not in available_bin_functions:
        # NameError is kept for backward compatibility with existing callers
        raise NameError('Only bin_function = median/average/sum/std available')
    reduce_bin = available_bin_functions[bin_function]

    # First make sure we only work with finite values
    finite = np.isfinite(X) & np.isfinite(Y) & np.isfinite(Z)
    if len(X[finite])!=len(X):
        print('Not all values were finite! Continuing with only finite ones')
    X=X[finite];Y=Y[finite];Z=Z[finite]
    
    # Now create the matrix of bins and its bin-edges
    H,xedges,yedges = np.histogram2d(X,Y,bins=bins)
    nr_x_bins = len(xedges)-1
    nr_y_bins = len(yedges)-1

    # Create the matrix that we want to store color-values in (NaN = bin not colored)
    color_matrix = np.full(H.shape, np.nan)
    
    # Bins are half-open [lower, upper), except for the last one, which also
    # includes its upper edge (same convention as np.histogram2d). Otherwise the
    # largest X and Y values would be lost whenever the edges are derived from the data.
    in_y_bins = []
    for y_bin in range(nr_y_bins):
        if y_bin == nr_y_bins-1:
            in_y_bins.append((Y>=yedges[y_bin])&(Y<=yedges[y_bin+1]))
        else:
            in_y_bins.append((Y>=yedges[y_bin])&(Y<yedges[y_bin+1]))

    # Loop through the x- and y-bins
    for x_bin in range(nr_x_bins):
        if x_bin == nr_x_bins-1:
            in_x_bin = (X>=xedges[x_bin])&(X<=xedges[x_bin+1])
        else:
            in_x_bin = (X>=xedges[x_bin])&(X<xedges[x_bin+1])

        for y_bin in range(nr_y_bins):
            Z_in_bin = Z[in_x_bin & in_y_bins[y_bin]]
            
            # We only add a value if there are at least *minimum_bin_entries* in the bin
            if len(Z_in_bin) >= minimum_bin_entries:
                color_matrix[x_bin,y_bin] = reduce_bin(Z_in_bin)

    # Use the current axis if none is given
    if ax is None:
        ax = plt.gca()
    ax.set_xlabel(X_label)
    ax.set_ylabel(Y_label)

    # Populate the keyword arguments for the imshow
    imshow_kwargs = dict(
        cmap = cmap,aspect='auto',origin='lower'
    )
    # Update by any arguments given through **kwargs
    imshow_kwargs.update(kwargs)

    # Plot! (transposed, because imshow expects rows to run along the y-axis)
    s = ax.imshow(color_matrix.T,extent=(xedges[0],xedges[-1],yedges[0],yedges[-1]),**imshow_kwargs)
    c = plt.colorbar(s, ax=ax)
    c.set_label(Z_label)

In [None]:
# Mass maps for giants (logg < 3.5) with detected C and N (both flags == 0)
giants = (data['logg'] < 3.5) & (data['flag_c_fe'] == 0) & (data['flag_n_fe'] == 0)

fe_h_giants = data['fe_h'][giants]
c_n_giants = data['c_fe'][giants]-data['n_fe'][giants]
mass_giants = data['mass'][giants]
mass_label = r'Mass / $\mathrm{M}_\odot$'

f, gs = plt.subplots(2,2,figsize=(7,5))

# One entry per panel:
# (axis, X, Y, Z, X_label, Y_label, Z_label, (x bin edges, y bin edges), (vmin, vmax))
mass_panels = [
    (gs[0,0], fe_h_giants, data['mg_fe'][giants], mass_giants,
     '[Fe/H]', '[Mg/Fe]', mass_label,
     (np.linspace(-0.8,0.3,30),np.linspace(-0.1,0.4,30)), (0.8, 2.0)),
    (gs[0,1], fe_h_giants, c_n_giants, mass_giants,
     '[Fe/H]', '[C/N]', mass_label,
     (np.linspace(-0.8,0.3,30),np.linspace(-0.8,0.4,30)), (0.8, 2.0)),
    (gs[1,0], data['teff'][giants], data['logg'][giants], mass_giants,
     'Teff / K', 'logg', mass_label,
     (np.linspace(3000,5500,30),np.linspace(0,4,30)), (0.8, 2.0)),
    (gs[1,1], c_n_giants, mass_giants, fe_h_giants,
     '[C/N]', mass_label, '[Fe/H]',
     (np.linspace(-0.8,0.4,30),np.linspace(0.5,3.0,30)), (-0.8, 0.3)),
]
for panel_ax, panel_x, panel_y, panel_z, x_label, y_label, z_label, panel_bins, (z_min, z_max) in mass_panels:
    hist2d_bin_colored(
        panel_x,
        panel_y,
        panel_z,
        X_label=x_label,
        Y_label=y_label,
        Z_label=z_label,
        ax = panel_ax,
        bins = panel_bins,
        vmin = z_min, vmax = z_max,
        cmap = 'RdYlBu_r',
        minimum_bin_entries = 1
    )

# Teff-logg panel: inverted axes (limits given as upper, lower)
gs[1,0].set_xlim(5500,3750)
gs[1,0].set_ylim(4,0)
plt.tight_layout()

In [None]:
# Age maps for giants with 1.8 < logg < 3.3 and detected C and N (both flags == 0)
giants = (data['logg'] > 1.8) & (data['logg'] < 3.3) & (data['flag_c_fe'] == 0) & (data['flag_n_fe'] == 0)

fe_h_giants = data['fe_h'][giants]
c_n_giants = data['c_fe'][giants]-data['n_fe'][giants]
age_giants = data['age'][giants]
age_label = r'Age / Gyr'

f, gs = plt.subplots(2,2,figsize=(7,5))

# One entry per panel:
# (axis, X, Y, Z, X_label, Y_label, Z_label, (x bin edges, y bin edges), (vmin, vmax))
age_panels = [
    (gs[0,0], fe_h_giants, data['mg_fe'][giants], age_giants,
     '[Fe/H]', '[Mg/Fe]', age_label,
     (np.linspace(-0.8,0.3,30),np.linspace(-0.1,0.4,30)), (0, 14)),
    (gs[0,1], fe_h_giants, c_n_giants, age_giants,
     '[Fe/H]', '[C/N]', age_label,
     (np.linspace(-0.8,0.3,30),np.linspace(-0.8,0.4,30)), (0, 14)),
    (gs[1,0], data['teff'][giants], data['logg'][giants], age_giants,
     'Teff / K', 'logg', age_label,
     (np.linspace(3000,5500,30),np.linspace(0,4,30)), (0, 14)),
    (gs[1,1], c_n_giants, age_giants, fe_h_giants,
     '[C/N]', age_label, '[Fe/H]',
     (np.linspace(-0.8,0.4,30),np.linspace(0,14,30)), (-0.8, 0.3)),
]
for panel_ax, panel_x, panel_y, panel_z, x_label, y_label, z_label, panel_bins, (z_min, z_max) in age_panels:
    hist2d_bin_colored(
        panel_x,
        panel_y,
        panel_z,
        X_label=x_label,
        Y_label=y_label,
        Z_label=z_label,
        ax = panel_ax,
        bins = panel_bins,
        vmin = z_min, vmax = z_max,
        cmap = 'RdYlBu_r',
        minimum_bin_entries = 1
    )

# Teff-logg panel: inverted axes (limits given as upper, lower)
gs[1,0].set_xlim(5500,3750)
gs[1,0].set_ylim(4,0)
plt.tight_layout()