# Figure

#### Author(s):
Sven Buder

#### History:
210410: Created


In [1]:
# Preamble for notebook 

# Compatibility with Python 3
from __future__ import (absolute_import, division, print_function)

# Ask IPython to render figures inline, at retina resolution.
# NOTE(review): the bare `except` silently ignores *any* failure of these magics;
# presumably this is meant to keep the cell usable when no inline backend is
# available -- confirm, and consider narrowing it to `except Exception`.
try:
    %matplotlib inline
    %config InlineBackend.figure_format='retina'
except:
    pass

# Basic packages
import warnings
warnings.filterwarnings("ignore")
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
from astropy.table import Table, join
from sklearn import mixture

# Matplotlib and associated packages for plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Global plot style: serif fonts, with all text rendered through LaTeX.
# `text.latex.preamble` has to be a single string: passing a list of strings
# was deprecated in Matplotlib 3.3 and is rejected by later releases.
params = {
    'font.family'        : 'serif',
    'text.usetex'        : True,
    'text.latex.preamble': r'\usepackage{upgreek}\usepackage{amsmath}',
    }
plt.rcParams.update(params)

In [2]:
def get_and_join_data(directory = '../data/'):
    """
    Read the GALAH DR3 main catalogue plus its value-added catalogues and join them.

    The five FITS files are joined one after the other on `sobject_id`. On top of
    the joined columns, the following columns are added:

    * `best_rv`, `best_e_rv`: radial velocity and its uncertainty, taken from the
      source indicated by `use_rv_flag_1` (0, 1 or 2).
    * `best_d`, `best_d_16`, `best_d_50`, `best_d_84`: distance and its 16th, 50th
      and 84th percentiles, taken from the source indicated by `use_dist_flag`
      (0, 1, 2 or 4).
    * `theta_spherical`, `rho_spherical`, `vrho_spherical`, `vtheta_spherical`:
      spherical position/velocity components computed from the cylindrical
      `*_Rzphi` columns.

    Rows whose flag has none of the handled values keep NaN in the `best_*` columns.

    Parameters
    ----------
    directory : str
        Directory (with trailing separator) that contains the FITS files.

    Returns
    -------
    astropy.table.Table
        The joined table including the additional columns.
    """
    # Join all data, always on the spectrum identifier and always in this order,
    # so that clashing column names receive the same suffixes (e.g. `use_rv_flag_1`)
    catalogue_files = [
        'GALAH_DR3_main_allspec_v2.fits',
        'GALAH_DR3_VAC_dynamics_v2.fits',
        'GALAH_DR3_VAC_GaiaEDR3_v2.fits',
        'GALAH_DR3_VAC_ages_v2.fits',
        'GALAH_DR3_VAC_rv_v2.fits',
    ]
    data = Table.read(directory + catalogue_files[0])
    for file_name in catalogue_files[1:]:
        data = join(data, Table.read(directory + file_name), keys='sobject_id')

    n_rows = len(data['sobject_id'])

    # get best rv and e_rv
    data['best_rv'] = np.full(n_rows, np.nan)
    data['best_e_rv'] = np.full(n_rows, np.nan)
    rv_sources = [
        # (flag value, rv column, rv uncertainty column)
        (0, 'rv_obst', 'e_rv_obst'),
        (1, 'rv_sme_v2', 'e_rv_sme'),
        (2, 'dr2_radial_velocity_1', 'dr2_radial_velocity_error_1'),
    ]
    for flag_value, rv_column, e_rv_column in rv_sources:
        selected = data['use_rv_flag_1'] == flag_value
        data['best_rv'][selected] = data[rv_column][selected]
        data['best_e_rv'][selected] = data[e_rv_column][selected]

    # get best distance and 16th + 84th percentile
    for column in ['best_d', 'best_d_16', 'best_d_50', 'best_d_84']:
        data[column] = np.full(n_rows, np.nan)

    distance_sources = [
        # (flag value, scale factor, best, 16th, 50th, 84th percentile column)
        # The factor 1000 puts the bstep distances on the scale of the other
        # sources -- presumably kpc -> pc, TODO confirm.
        (0, 1000., 'distance_bstep', 'e16_distance_bstep', 'e50_distance_bstep', 'e84_distance_bstep'),
        (1, 1., 'r_med_photogeo', 'r_lo_photogeo', 'r_med_photogeo', 'r_hi_photogeo'),
        (2, 1., 'r_med_geo', 'r_lo_geo', 'r_med_geo', 'r_hi_geo'),
    ]
    for flag_value, scale, d_column, d16_column, d50_column, d84_column in distance_sources:
        selected = data['use_dist_flag'] == flag_value
        data['best_d'][selected] = scale*data[d_column][selected]
        data['best_d_16'][selected] = scale*data[d16_column][selected]
        data['best_d_50'][selected] = scale*data[d50_column][selected]
        data['best_d_84'][selected] = scale*data[d84_column][selected]

    # Flag 4: distance from the inverted parallax. A larger parallax means a
    # smaller distance, so +error gives the 16th and -error the 84th percentile.
    selected = data['use_dist_flag'] == 4
    parallax = data['parallax_corr'][selected]
    parallax_error = data['parallax_error'][selected]
    data['best_d'][selected] = 1000./parallax
    data['best_d_16'][selected] = 1000./(parallax + parallax_error)
    data['best_d_50'][selected] = 1000./parallax
    data['best_d_84'][selected] = 1000./(parallax - parallax_error)

    # Spherical coordinates from the cylindrical ones; theta is measured from the z axis
    R, z = data['R_Rzphi'], data['z_Rzphi']
    vR, vz = data['vR_Rzphi'], data['vz_Rzphi']
    data['theta_spherical'] = np.arctan2(R, z)
    data['rho_spherical'] = np.sqrt((R)**2+(z)**2)
    data['vrho_spherical'] = (R*vR + z*vz)/data['rho_spherical']
    data['vtheta_spherical'] = (z*vR - R*vz)/data['rho_spherical']

    return data

# Use the cached joined catalogue if it exists, otherwise build and cache it.
joined_catalogue_file = '../../data/GALAH_DR3_all_joined_v2.fits'
try:
    data = Table.read(joined_catalogue_file)
except (IOError, OSError):
    # Only a missing/unreadable cache file triggers the (slow) rebuild;
    # any other problem is a real error and should not be hidden.
    data = get_and_join_data(directory = '../../data/')
    data.write(joined_catalogue_file)

In [10]:
# Selection basics

# Cuts shared by all selections: unflagged stellar parameters and [Fe/H],
# `best_d` of at most 10000, and finite L_Z, eccentricity and age.
basic_cuts = (
    (data['flag_sp'] == 0) &
    (data['flag_fe_h'] == 0) &
    (data['best_d'] <= 10000.) &
    np.isfinite(data['L_Z']) &
    np.isfinite(data['ecc']) &
    np.isfinite(data['age_bstep'])
)

# Basic cuts plus unflagged [alpha/Fe] and [Mg/Fe]
basic_cuts_mgalpha = (
    basic_cuts &
    (data['flag_alpha_fe'] == 0) &
    (data['flag_Mg_fe'] == 0)
)

# Basic cuts plus unflagged [Mg/Fe], [Na/Fe] and [Cu/Fe]. The figure cells that
# plot [Mg/Cu] vs. [Na/Fe] use this mask, but it was not defined anywhere in the
# notebook, so they failed on a fresh kernel.
# NOTE(review): definition inferred from the name and from the flags used in
# `chem1` -- confirm it matches the mask the figures were originally made with.
basic_cuts_mgnacu = (
    basic_cuts &
    (data['flag_Mg_fe'] == 0) &
    (data['flag_Na_fe'] == 0) &
    (data['flag_Cu_fe'] == 0)
)

def nissen_slope_mg_fe(fe_h):
    """
    [Mg/Fe] value of the dividing line at a given [Fe/H], as estimated
    from 2010A&A...511L..10N:

        [Mg/Fe] = -1/12 * [Fe/H] + 1/6

    Works for scalars and arrays alike.
    """
    gradient = -0.1/1.2
    offset = 0.3 - 1.6*0.1/1.2
    return gradient*(fe_h) + offset

def nissen_slope_alpha_fe(fe_h):
    """
    [alpha/Fe] value of the dividing line at a given [Fe/H], as estimated
    from 2010A&A...511L..10N:

        [alpha/Fe] = -0.2/1.2 * [Fe/H] + (0.325 - 1.6*0.2/1.2)
                   = -1/6 * [Fe/H] + 0.7/12

    Works for scalars and arrays alike.
    """
    gradient = -0.2/1.2
    offset = 0.325 - 1.6*0.2/1.2
    return gradient*(fe_h) + offset

def total_velocity(data):
    """
    Total space velocity from the radial velocity and the proper motions:

    sqrt(pow(best_rv,2) + (pow(4.7623*best_d/1000.,2)*(pow(pmra,2) + pow(pmdec,2))))

    `data` is anything indexable by the column names 'best_rv', 'best_d',
    'pmra' and 'pmdec'.

    NOTE(review): the conversion factor usually quoted for mas/yr * kpc -> km/s
    is 4.74047; confirm that 4.7623 is intended.
    """
    radial_squared = (data['best_rv'])**2
    distance_factor_squared = (4.7623*data['best_d']/1000.)**2
    proper_motion_squared = (data['pmra'])**2 + (data['pmdec'])**2
    return np.sqrt(radial_squared + distance_factor_squared * proper_motion_squared)

def tangential_velocity(data):
    """
    Tangential velocity from distance and proper motions:

    4.7623*best_d/1000. * sqrt(pow(pmra,2) + pow(pmdec,2))

    `data` is anything indexable by the column names 'best_d', 'pmra' and 'pmdec'.
    """
    total_proper_motion = np.sqrt((data['pmra'])**2 + (data['pmdec'])**2)
    distance_factor = (4.7623*data['best_d']/1000.)
    return distance_factor*total_proper_motion

# Cool stars with low surface gravity
rgb = (data['teff'] < 5500) & (data['logg'] < 3.25)

# Warm stars with high surface gravity
msto = (data['teff'] >= 5350) & (data['logg'] >= 3.5)

# Fast-moving stars with reliable [Mg/Fe] and [alpha/Fe]
stars_with_high_vtot = basic_cuts_mgalpha & (total_velocity(data) > 180)
stars_with_high_vtan = basic_cuts_mgalpha & (tangential_velocity(data) > 180)

# Fast stars with -2.0 <= [Fe/H] <= -0.4 that lie below both dividing lines ...
preliminary_low_alpha_halo = (
    stars_with_high_vtot &
    (data['fe_h'] >= -2.0) & (data['fe_h'] <= -0.4) &
    (data['Mg_fe'] < nissen_slope_mg_fe(data['fe_h'])) &
    (data['alpha_fe'] < nissen_slope_alpha_fe(data['fe_h']))
)

# ... and those that lie on or above both dividing lines
preliminary_high_alpha_halo = (
    stars_with_high_vtot &
    (data['fe_h'] >= -2.0) & (data['fe_h'] <= -0.4) &
    (data['Mg_fe'] >= nissen_slope_mg_fe(data['fe_h'])) &
    (data['alpha_fe'] >= nissen_slope_alpha_fe(data['fe_h']))
)

# Selection via [Mg/Cu] vs. [Na/Fe]: [Mg/Cu] > 0.5 and [Na/Fe] < -0.1
chem1 = (
    basic_cuts &
    (data['flag_Mg_fe'] == 0) & (data['flag_Na_fe'] == 0) & (data['flag_Cu_fe'] == 0) &
    ((data['Mg_fe'] - data['Cu_fe']) > 0.5) &
    (data['Na_fe'] < -0.1)
)

# Selection via [Mg/Mn] vs. [Na/Fe]: [Mg/Mn] within 0.25 of 0.5 and [Na/Fe] < -0.1
chem2 = (
    basic_cuts &
    (data['flag_Mg_fe'] == 0) & (data['flag_Mn_fe'] == 0) & (data['flag_Na_fe'] == 0) &
    (np.abs(data['Mg_fe'] - data['Mn_fe'] - 0.5) <= 0.25) &
    (data['Na_fe'] < -0.1)
)

# Selection via [Mg/Mn] vs. [Na/Fe] + [Fe/H]: same as chem2, restricted to [Fe/H] <= -0.5
chem2_mp = chem2 & (data['fe_h'] <= -0.5)


# Selection via [Mg/Mn] vs. [Al/Fe], like Das+2020 'blob' selection:
# [Al/Fe] within 0.25 of -0.35 and [Mg/Mn] within 0.25 of 0.5
chem3 = (
    basic_cuts &
    (data['flag_Mg_fe'] == 0) & (data['flag_Al_fe'] == 0) & (data['flag_Mn_fe'] == 0) &
    (np.abs(data['Al_fe'] + 0.35) <= 0.25) &
    (np.abs(data['Mg_fe'] - data['Mn_fe'] - 0.5) <= 0.25)
)

# Selection via [Mg/Mn] vs. [Al/Fe] + [Fe/H] like Das+2020 'metal-poor blob' selection:
# same as chem3, restricted to [Fe/H] < -0.5
chem3_mp = chem3 & (data['fe_h'] < -0.5)

# Selection via L_Z and J_R like Feuillet+2020:
# |L_Z| <= 500 and J_R between 30**2 = 900 and 50**2 = 2500
feuillet2020 = (
    basic_cuts &
    (np.abs(data['L_Z']) <= 500) &
    (data['J_R'] >= 900) & (data['J_R'] <= 2500)
)


# Selection via L_Z and E like Helmi+2018:
# -1500 < L_Z < 150 and Energy > -1.8e5
helmi2018 = (
    basic_cuts &
    (data['L_Z'] > -1500) & (data['L_Z'] < 150) &
    (data['Energy'] > -1.8e5)
)

# Selection like Naidu+2020: exclude Sagittarius, Aleph and the high-alpha disk,
# then select via eccentricity
#naidu2020_sagitarius = (data['L_Y'] < -0.3*data['L_Z'] - 2.5*10**3)

# Aleph: fast rotating, small radial velocity, metal-rich, low [alpha/Fe].
# The previous cut (vT < -175) & (vT > 300) could never be true, so Aleph was
# never excluded. The rotation window is 175..300 in absolute value.
# NOTE(review): written for positive vT_Rzphi = prograde (consistent with the
# L_Z range of the Helmi+2018 cut in this notebook) -- confirm the sign convention.
naidu2020_aleph = (
    (data['vT_Rzphi'] > 175) &
    (data['vT_Rzphi'] < 300) &
    (np.abs(data['vR_Rzphi']) < 75) &
    (data['fe_h'] > -0.8) &
    (data['alpha_fe'] < 0.27)
)

# High-alpha disk: above the line [alpha/Fe] = 0.25 - 0.5*([Fe/H] + 0.7)
naidu2020_highalphadisk = (
    (data['alpha_fe'] > 0.25 - 0.5*(data['fe_h'] + 0.7))
)

naidu2020 = (
    basic_cuts &
    (data['flag_fe_h'] == 0) &
    (data['flag_alpha_fe'] == 0) &
    # Sagittarius is not excluded here (its cut above is commented out)
    (~naidu2020_aleph) &
    (~naidu2020_highalphadisk) &
    (data['ecc'] > 0.7)
)

In [4]:
def plot_density(x, y, bins=100, range=None, normed=False, weights=None,
                 scaling=None, reduce_fn='sum', smooth=0, ax=None, cmin=0, **kwargs):
    """
    Draw a two-dimensional histogram of the points (x, y) as an image.

    Points with a non-finite x or y are dropped (a warning line is printed).
    Bins without any point, and bins with fewer than `cmin` points, are set
    to NaN so that they are left blank in the image.

    Parameters
    ----------
    x, y : array_like, shape (N,)
        Coordinates of the points to be histogrammed.
    bins : int or array_like or [int, int] or [array, array], optional
        Bin specification, passed on to `numpy.histogram2d`.
    range : array_like, shape(2,2), optional
        ``[[xmin, xmax], [ymin, ymax]]``, passed on to `numpy.histogram2d`.
        Points outside of this range are not tallied.
    normed : bool, optional
        If True and `reduce_fn` is 'sum', show the bin density instead of the
        bin sum (the `density` option of `numpy.histogram2d`). Ignored when
        `reduce_fn` is a mean.
    weights : array_like, shape(N,), optional
        Weight ``w_i`` of each point ``(x_i, y_i)``.
    scaling : str or None, optional
        None, 'none' or 'count' shows the values as they are; 'log', 'log(n)',
        'log10' or 'log(count)' shows log10 of them; 'log(n+1)' or
        'log(count+1)' shows log10(value + 1). Case and spaces are ignored.
    reduce_fn : str, optional
        'sum' (default) shows the sum of weights per bin (the number of points
        if `weights` is None). 'mean', 'average' or 'avg' shows the sum of
        weights divided by the number of points, i.e. the mean weight per bin.
    smooth : float, optional
        If > 0, standard deviation (in bins) of a Gaussian filter applied to
        the histogram.
    ax : matplotlib axis, optional
        Axis to draw on. Defaults to the current axis.
    cmin : int, optional
        Minimum number of points a bin needs in order to be shown.
    **kwargs
        Passed on to `ax.imshow`, overriding the defaults set below.

    Returns
    -------
    The image returned by `ax.imshow`.

    Raises
    ------
    ValueError
        If `scaling` is not one of the supported names.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if weights is not None:
        weights = np.asarray(weights)

    # Drop non-finite points once, so that every histogram below is built
    # from exactly the same sample.
    finite = (np.isfinite(x) & np.isfinite(y))
    if not finite.all():
        print("Warning: Not all values are finite.")
        x, y = x[finite], y[finite]
        if weights is not None:
            weights = weights[finite]

    # Resolve the scaling up front: an unknown name should fail loudly
    # instead of silently drawing nothing.
    s = 'none' if scaling is None else scaling.lower().replace(' ', '')
    if s in ('none', 'count'):
        rescale = lambda values: values
    elif s in ('log', 'log(n)', 'log10', 'log(count)'):
        rescale = np.log10
    elif s in ('log(n+1)', 'log(count+1)'):
        rescale = lambda values: np.log10(values + 1)
    else:
        raise ValueError("Unknown scaling: {!r}".format(scaling))

    # Plain number of points per bin; decides which bins are shown at all
    counts, bx, by = np.histogram2d(x, y, bins=bins, range=range)
    populated = counts > 0

    if reduce_fn.lower() in ('average', 'mean', 'avg'):
        sums, _, _ = np.histogram2d(x, y, bins=bins, range=range, weights=weights)
        n = np.full(counts.shape, np.nan)
        n[populated] = sums[populated] / counts[populated]
    else:
        # `density` replaces the `normed` keyword that newer numpy no longer accepts
        n, _, _ = np.histogram2d(x, y, bins=bins, range=range, density=normed,
                                 weights=weights)
        n = n.astype(float)
        n[~populated] = np.nan
    n[counts < cmin] = np.nan

    defaults = dict(zorder = 2, cmap='RdYlBu_r', origin='lower', aspect='auto', rasterized=True,
                    interpolation='nearest')
    defaults.update(**kwargs)
    extent = (bx[0], bx[-1], by[0], by[-1])

    if smooth > 0:
        from scipy.ndimage import gaussian_filter
        # NOTE(review): blank (NaN) bins spread into their neighbours here
        n = gaussian_filter(n, smooth)

    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()

    # histogram2d puts x along the first axis, imshow expects it along the second
    return ax.imshow(rescale(n).T, extent=extent, **defaults)

In [5]:
# for each in [
#     'fe_h','alpha_fe',
#     'C_fe','O_fe','Al_fe','Si_fe','K_fe','Ca_fe','Sc_fe','Ti_fe','V_fe',
#     'Cr_fe','Mn_fe','Co_fe','Ni_fe','Zn_fe','Rb_fe','Sr_fe','Y_fe','Zr_fe',
#     'Mo_fe','Ru_fe','Ba_fe','La_fe','Ce_fe','Nd_fe','Sm_fe','Eu_fe',
#     'X_XYZ','Y_XYZ','Z_XYZ',
#     'R_Rzphi','phi_Rzphi','z_Rzphi',
#     'vR_Rzphi','vT_Rzphi','vz_Rzphi',
#     'J_R','L_Z','J_Z',
#     'omega_R','omega_phi','omega_z',
#     'angle_R','angle_phi','angle_z',
#     'ecc','zmax','R_peri','R_ap','Energy',
#     'best_rv','best_d',
#     'age_bstep'
#     ]:
    
#     f, ax = plt.subplots(constrained_layout=True)

#     hist_kwargs_mean_value = dict(
#         reduce_fn='mean',
#         bins = 50,
#         cmin = 1,
#         rasterized = True,
#         zorder=2
#     )
    
#     if each in ['fe_h','alpha_fe',
#     'C_fe','O_fe','Al_fe','Si_fe','K_fe','Ca_fe','Sc_fe','Ti_fe','V_fe',
#     'Cr_fe','Mn_fe','Co_fe','Ni_fe','Zn_fe','Rb_fe','Sr_fe','Y_fe','Zr_fe',
#     'Mo_fe','Ru_fe','Ba_fe','La_fe','Ce_fe','Nd_fe','Sm_fe','Eu_fe']:
#         use = basic_cuts_mgnacu & (data['flag_'+each]==0)
#     else:
#         use = basic_cuts_mgnacu & np.isfinite(data[each])
    
#     s1 = plot_density(
#         data['Na_fe'][use],
#         data['Mg_fe'][use] - data['Cu_fe'][use],
#         weights = data[each][use],
#         ax = ax,
#         **hist_kwargs_mean_value
#     )

#     c = plt.colorbar(s1,ax=ax,location='top')
#     c.set_label('Mean '+each.replace('_','\_'))
#     ax.set_xlabel('[Na/Fe]')
#     ax.set_ylabel('[Mg/Cu]')

In [7]:
# Mean abundances across the [Na/Fe] vs. [Mg/Cu] plane, one panel per quantity
n_rows, n_cols = 2, 5
f, gs = plt.subplots(n_rows,n_cols,figsize=(15,5),sharex=True,sharey=True,constrained_layout=True)

panels = [
    'fe_h',
    'C_fe',
    'Al_fe',
    'Cr_fe',
    'Mn_fe',
    'Ni_fe',
    'Y_fe',
    'Ba_fe',
    'La_fe',
    'Ce_fe',
]
assert len(panels) <= n_rows*n_cols, "more panels than axes in the grid"

# Quantities that come with a quality flag column `flag_<name>`;
# everything else only has to be finite.
flagged_quantities = [
    'fe_h','alpha_fe',
    'C_fe','O_fe','Al_fe','Si_fe','K_fe','Ca_fe','Sc_fe','Ti_fe','V_fe',
    'Cr_fe','Mn_fe','Co_fe','Ni_fe','Zn_fe','Rb_fe','Sr_fe','Y_fe','Zr_fe',
    'Mo_fe','Ru_fe','Ba_fe','La_fe','Ce_fe','Nd_fe','Sm_fe','Eu_fe',
]

hist_kwargs_mean_value = dict(
    reduce_fn = 'mean',
    bins = 50,
    cmin = 1,
    rasterized = True,
    zorder=2
)

for each_index, each in enumerate(panels):

    # Panels fill the grid row by row
    row, column = divmod(each_index, n_cols)
    ax = gs[row, column]

    # 'density' is a placeholder name without a data column: leave its panel empty
    if each != 'density':
        if each in flagged_quantities:
            use = basic_cuts_mgnacu & (data['flag_'+each]==0)
        else:
            use = basic_cuts_mgnacu & np.isfinite(data[each])

        s1 = plot_density(
            data['Na_fe'][use],
            data['Mg_fe'][use] - data['Cu_fe'][use],
            weights = data[each][use],
            ax = ax,
            **hist_kwargs_mean_value
        )

        c = plt.colorbar(s1,ax=ax,location='top')
        # underscores have to be escaped because labels are rendered by LaTeX
        c.set_label('Mean '+each.replace('_',r'\_'))

    # Axes are shared: label only the left column and the bottom row
    if column == 0:
        ax.set_ylabel('[Mg/Cu]')
    if row == n_rows - 1:
        ax.set_xlabel('[Na/Fe]')

f.savefig('mgcunafe1.png',dpi=300,bbox_inches='tight')
plt.close(f)

In [9]:
# Mean kinematic/dynamic quantities, distance and age across the
# [Na/Fe] vs. [Mg/Cu] plane, one panel per quantity
n_rows, n_cols = 3, 5
f, gs = plt.subplots(n_rows,n_cols,figsize=(15,7.5),sharex=True,sharey=True,constrained_layout=True)

# 14 quantities on a 3x5 grid: the last axis stays empty
panels = [
    'vR_Rzphi','vT_Rzphi','vz_Rzphi',
    'J_R','L_Z','J_Z',
    'ecc','zmax','R_peri','R_ap','Energy',
    'best_rv','best_d',
    'age_bstep'
]
assert len(panels) <= n_rows*n_cols, "more panels than axes in the grid"

# Quantities that come with a quality flag column `flag_<name>`;
# everything else only has to be finite.
flagged_quantities = [
    'fe_h','alpha_fe',
    'C_fe','O_fe','Al_fe','Si_fe','K_fe','Ca_fe','Sc_fe','Ti_fe','V_fe',
    'Cr_fe','Mn_fe','Co_fe','Ni_fe','Zn_fe','Rb_fe','Sr_fe','Y_fe','Zr_fe',
    'Mo_fe','Ru_fe','Ba_fe','La_fe','Ce_fe','Nd_fe','Sm_fe','Eu_fe',
]

hist_kwargs_mean_value = dict(
    reduce_fn = 'mean',
    bins = 50,
    cmin = 1,
    rasterized = True,
    zorder=2
)

for each_index, each in enumerate(panels):

    # Panels fill the grid row by row
    row, column = divmod(each_index, n_cols)
    ax = gs[row, column]

    # 'density' is a placeholder name without a data column: leave its panel empty
    if each != 'density':
        if each in flagged_quantities:
            use = basic_cuts_mgnacu & (data['flag_'+each]==0)
        else:
            use = basic_cuts_mgnacu & np.isfinite(data[each])

        s1 = plot_density(
            data['Na_fe'][use],
            data['Mg_fe'][use] - data['Cu_fe'][use],
            weights = data[each][use],
            ax = ax,
            **hist_kwargs_mean_value
        )

        c = plt.colorbar(s1,ax=ax,location='top')
        # underscores have to be escaped because labels are rendered by LaTeX
        c.set_label('Mean '+each.replace('_',r'\_'))

    # Axes are shared: label only the left column and the bottom row
    if column == 0:
        ax.set_ylabel('[Mg/Cu]')
    if row == n_rows - 1:
        ax.set_xlabel('[Na/Fe]')

f.savefig('mgcunafe2.png',dpi=300,bbox_inches='tight')
plt.close(f)