# The Payne Tutorial

Follow the instructions from https://github.com/tingyuansen/The_Payne/tree/master:


```bash
cd /path/where/you/want/to/install/The_Payne/
git clone https://github.com/tingyuansen/The_Payne.git
cd The_Payne
python setup.py install
```


In [1]:
import numpy as np
from astropy.table import Table, hstack
from astropy.io import fits
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split

In [2]:
from The_Payne import training
from The_Payne import utils
from The_Payne import spectral_model

# """
# Changes that need to be made to training.py in The_Payne if no CUDA is available
"""
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    dtype = torch.FloatTensor
    torch.set_default_tensor_type('torch.FloatTensor')
    
if torch.cuda.is_available():
    model.cuda()
        
if torch.cuda.is_available():
    perm = perm.cuda()
    
if torch.cuda.is_available():
    perm_valid = perm_valid.cuda()
"""

# I have also adjusted the output names to be a keyword argument.
# If you want that too, you would need to search->replace:
"""
def neural_net(training_labels, training_spectra, validation_labels, validation_spectra,\
         num_neurons = 300, num_steps=1e4, learning_rate=1e-4, batch_size=512,\
         num_features = 64*5, mask_size=11, num_pixel=7214,
         training_loss_name = "training_loss.npz",
         payne_model_name = "NN_normalized_spectra.npz"
         ):
"training_loss.npz" -> training_loss_name
"NN_normalized_spectra.npz" -> payne_model_name
""";

# Preparing example data

In [3]:
# Switch for the sample-selection cell below: when True, that cell re-reads the
# full GALAH DR3 catalogue and overwrites 'GALAH_DR3_100_labels.fits'.
# Leave it False to reuse the label file that is already on disk.
prepare_again = False

In [4]:
if prepare_again:
    # Rebuild the 100-star label table from the full GALAH DR3 catalogue.
    galah_dr3 = Table.read('/Users/buder/GALAH_DR3/catalogs/GALAH_DR3_main_allstar_v2.fits')

    # Every element listed here must have an unflagged [X/Fe] measurement;
    # the commented-out elements are not required.
    elements_to_fit = [
        # 'Li', 'C',
        'O', 'Na', 'Mg', 'Al', 'Si',
        'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Co', 'Ni', 'Cu', 'Zn',
        # 'Rb', 'Sr',
        'Y',
        # 'Zr', 'Mo', 'Ru',
        'Ba', 'La',
        # 'Ce', 'Nd', 'Sm',
        'Eu',
    ]

    # Cuts shared by all three samples: unflagged stellar parameters and
    # [Fe/H], and S/N > 25 in each of the four CCDs.
    quality_cuts = (
        (galah_dr3['flag_sp'] == 0) &
        (galah_dr3['flag_fe_h'] == 0) &
        (galah_dr3['snr_c1_iraf'] > 25) &
        (galah_dr3['snr_c2_iraf'] > 25) &
        (galah_dr3['snr_c3_iraf'] > 25) &
        (galah_dr3['snr_c4_iraf'] > 25)
    )
    # True for rows where every abundance flag of the selected elements is 0.
    abundances_unflagged = np.all(
        [galah_dr3['flag_' + element + '_fe'] == 0 for element in elements_to_fit],
        axis=0,
    )

    # The samples differ only in their logg window.
    # NOTE: the dwarf (logg > 3.5) and subgiant (3.0 < logg < 3.75) windows
    # overlap, so a star can belong to both subsets.
    galah_dr3_subset_giants = galah_dr3[
        quality_cuts & (galah_dr3['logg'] < 3.) & abundances_unflagged
    ]
    galah_dr3_subset_dwarfs = galah_dr3[
        quality_cuts & (galah_dr3['logg'] > 3.5) & abundances_unflagged
    ]
    galah_dr3_subset_subgiants = galah_dr3[
        quality_cuts
        & (galah_dr3['logg'] > 3.0)
        & (galah_dr3['logg'] < 3.75)
        & abundances_unflagged
    ]

    # Fixed seed so the same rows are drawn on every run.
    # NOTE: randint draws with replacement, so an index can repeat.
    np.random.seed(712)
    random_45_giants_subset = np.random.randint(len(galah_dr3_subset_giants), size=45)
    random_10_subgiants_subset = np.random.randint(len(galah_dr3_subset_subgiants), size=10)
    random_45_dwarfs_subset = np.random.randint(len(galah_dr3_subset_dwarfs), size=45)

    # Stack the drawn rows in the order giants, subgiants, dwarfs.
    samples = [
        (galah_dr3_subset_giants, random_45_giants_subset),
        (galah_dr3_subset_subgiants, random_10_subgiants_subset),
        (galah_dr3_subset_dwarfs, random_45_dwarfs_subset),
    ]
    training_label_table = Table()
    for key in ['sobject_id', 'teff', 'logg', 'fe_h', 'vbroad', 'vmic']:
        training_label_table[key] = np.concatenate(
            [subset[key][rows] for subset, rows in samples]
        )

    training_label_table.write('GALAH_DR3_100_labels.fits', overwrite=True)

In [5]:
# for sobject_id in training_label_table['sobject_id'][10:20]:
#     try:
#         os.system('rsync -azu galah@galahobs.datacentral.org.au:/galah/DR3/data/galah/dr3/spectra/'+str(sobject_id)+'1.fits spectra/')
#         os.system('rsync -azu galah@galahobs.datacentral.org.au:/galah/DR3/data/galah/dr3/spectra/'+str(sobject_id)+'2.fits spectra/')
#         os.system('rsync -azu galah@galahobs.datacentral.org.au:/galah/DR3/data/galah/dr3/spectra/'+str(sobject_id)+'3.fits spectra/')
#         os.system('rsync -azu galah@galahobs.datacentral.org.au:/galah/DR3/data/galah/dr3/spectra/'+str(sobject_id)+'4.fits spectra/')
#     except:
#         print(sobject_id)

# Prepare necessary data

In [6]:
# Load the label table written by the sample-selection step.
training_set_label_table = Table.read('GALAH_DR3_100_labels.fits')

# Every column except the first one (sobject_id) is a label. Collect them
# column by column, then transpose to get one row per star.
label_names = list(training_set_label_table.keys())[1:]
label_columns = [list(training_set_label_table[name]) for name in label_names]
training_set_labels = np.transpose(np.array(label_columns))

In [7]:
# Wavelength grid for each CCD, keyed by the CCD number as a string.
# Each entry is np.arange(start, stop, step), so the stop value is excluded.
wavelengths_ccds = {
    '1': np.arange(4715.94, 4896.00, 0.046),
    '2': np.arange(5650.06, 5868.25, 0.055),
    '3': np.arange(6480.52, 6733.92, 0.064),
    '4': np.arange(7693.50, 7875.55, 0.074),
}

In [8]:
# Common wavelength grid: the four CCD grids joined in CCD order.
training_set_wavelength = np.concatenate([wavelengths_ccds[ccd] for ccd in ['1','2','3','4']])
training_set_flux = []
training_set_flux_uncertainty = []

# Only the first 20 stars of the label table are read here.
for sobject_id in training_set_label_table['sobject_id'][:20]:
    flux_per_spectrum = []
    flux_uncertainty_per_spectrum = []
    for ccd in ['1','2','3','4']:
        # The context manager closes the file even if a header lookup or the
        # interpolation raises.
        with fits.open('spectra/'+str(sobject_id)+ccd+'.fits') as fits_file:
            # Linear wavelength axis rebuilt from the header of extension 4.
            header = fits_file[4].header
            wavelength_raw = header['CRVAL1'] + header['CDELT1'] * np.arange(header['NAXIS1'])
            # NOTE(review): presumably extension 4 is the normalised flux and
            # extension 1 a relative uncertainty - confirm against the GALAH
            # DR3 spectrum format. Both are resampled on the extension-4 axis.
            flux_per_spectrum.append(
                # Resample onto the common CCD grid and limit flux to [0.01, 1.2].
                np.array(np.interp(wavelengths_ccds[ccd], wavelength_raw, fits_file[4].data).clip(min=0.01,max=1.2))
            )
            flux_uncertainty_per_spectrum.append(
                # Extension 1 times extension 4, with a floor of 0.001.
                np.array(np.interp(wavelengths_ccds[ccd], wavelength_raw, fits_file[1].data * fits_file[4].data).clip(min=0.001))
            )
    # Join the four CCD segments into one spectrum per star.
    training_set_flux.append(np.concatenate(flux_per_spectrum))
    training_set_flux_uncertainty.append(np.concatenate(flux_uncertainty_per_spectrum))

# One row per spectrum, one column per pixel of the common grid.
training_set_flux = np.array(training_set_flux)
training_set_flux_uncertainty = np.array(training_set_flux_uncertainty)

# Training

In [9]:
# Split the spectrum row indices (not the data) into 90 % training and
# 10 % validation; the fixed random_state keeps the split reproducible.
spectrum_indices = np.arange(len(training_set_flux))
train, test = train_test_split(spectrum_indices, test_size=0.10, random_state=8876)

In [None]:
# Base name for the two output files written by the training run.
model_file = 'galah_tutorial'

# Train on the `train` rows and validate on the `test` rows.
# The two *_name keyword arguments only exist if training.py has been patched
# as described in the note near the top of this notebook.
training.neural_net(
    training_labels = training_set_labels[train,:],
    training_spectra = training_set_flux[train,:],
    validation_labels = training_set_labels[test,:],
    validation_spectra = training_set_flux[test,:],
    num_neurons=20,
    learning_rate=1e-4,
    num_steps=1e4,
    batch_size=28,
    # Number of pixels per spectrum (a comma was missing after this argument).
    num_pixel=np.shape(training_set_flux[0])[0],
    training_loss_name = model_file+'_loss.npz',
    payne_model_name = model_file+'.npz',
    )