In [1]:
from __future__ import absolute_import, division, print_function # Python2 compatibility
import numpy as np
import matplotlib.pyplot as plt

try:
    %matplotlib inline
    %config InlineBackend.figure_format='retina'
except:
    pass

import numpy as np

from The_Payne import utils
from The_Payne import spectral_model
from The_Payne import fitting
from The_Payne import training_v1 as training

In [2]:
# Generate the GALAH wavelength grid and pixel mask with The_Payne utilities.
# NOTE(review): presumably these write files that the next cell reads back via
# utils.load_wavelength_array / utils.load_galah_mask — confirm; if so, this cell only
# needs to run once and could be skipped when those files already exist.
utils.create_wavelength_array(survey='galah')
utils.create_galah_mask()

In [3]:
# Load the wavelength grid and pixel mask used by every fit below.
wavelength = utils.load_wavelength_array(survey='galah')
mask = utils.load_galah_mask()
num_pixel = len(wavelength)

# Boolean pixel selectors for four 1000-wide wavelength windows (presumably Angstrom —
# confirm), one per CCD: window k keeps pixels strictly between 1000*(4+k) and 1000*(5+k).
ccd = {}
for each_ccd in range(4):
    lower_edge = 1000 * (4 + each_ccd)
    upper_edge = 1000 * (5 + each_ccd)
    ccd[each_ccd] = (wavelength > lower_edge) & (wavelength < upper_edge)

In [4]:
# Either retrain the spectral emulator (slow) or load previously trained network weights.
train = False
nn_coeffs_path = "NN_normalized_spectra_191101_small.npz"

if train:
    training_labels, training_spectra, validation_labels, validation_spectra = utils.load_training_data(survey='galah')
    training_loss, validation_loss = training.neural_net(training_labels, training_spectra,
                                                         validation_labels, validation_spectra,
                                                         num_neurons=300, num_steps=1e3, learning_rate=0.001)
    # NOTE(review): this branch never defines NN_coeffs, so every fitting cell below fails
    # with NameError after retraining. Load the freshly trained weights here once it is
    # confirmed which file training.neural_net writes — TODO.
else:
    # The context manager closes the .npz archive even if a key lookup raises, which the
    # explicit close() at the end of the original block did not guarantee.
    with np.load(nn_coeffs_path) as nn_file:
        w_array_0 = nn_file["w_array_0"]
        w_array_1 = nn_file["w_array_1"]
        w_array_2 = nn_file["w_array_2"]
        b_array_0 = nn_file["b_array_0"]
        b_array_1 = nn_file["b_array_1"]
        b_array_2 = nn_file["b_array_2"]
        # presumably the label normalisation bounds used during training — confirm
        x_min = nn_file["x_min"]
        x_max = nn_file["x_max"]
    # Passed as NN_coeffs to the fitting routine; keep this exact order, as it is
    # presumably unpacked positionally there.
    NN_coeffs = (w_array_0, w_array_1, w_array_2, b_array_0, b_array_1, b_array_2, x_min, x_max)

In [5]:
# Observed GALAH spectra plus the GALAH DR3 reference labels for the comparison sample.
# The context manager closes the archive even if one of the key lookups raises.
with np.load('galah_selection_191030.npz') as observed_spectra:
    all_sobject_id = observed_spectra['sobject_id']
    all_labels = observed_spectra['labels']        # indexed [label, star] in the cells below
    all_flux = observed_spectra['flux']
    all_flux_error = observed_spectra['flux_error']

# Layout the fitting and plotting cells rely on: one spectrum, one error array and one
# label column per sobject_id.
assert all_flux.shape == all_flux_error.shape, 'flux and flux_error shapes differ'
assert all_flux.shape[0] == len(all_sobject_id), 'expected one spectrum per sobject_id'
assert all_labels.shape[1] == len(all_sobject_id), 'expected labels shaped (n_labels, n_stars)'

In [None]:
# Fit every observed spectrum with The Payne emulator, storing the results as
# payne_labels[label, star] to mirror all_labels. Pre-fill with NaN rather than 1 so a
# star that was never fitted (e.g. after an interrupted run) cannot pass for a result.
payne_labels = np.full(np.shape(all_labels), np.nan)

for each_index in range(len(all_sobject_id)):
    # fit the labels (pcov and best_fit_spec are not used in this cell)
    popt, pcov, best_fit_spec = fitting.fit_normalized_spectrum_single_star_model(
        norm_spec=all_flux[each_index], spec_err=all_flux_error[each_index],
        NN_coeffs=NN_coeffs, wavelength=wavelength, mask=mask, p0=None)

    # just a simple progress print
    if each_index % 50 == 0:
        print(each_index)

    # popt[:-1] holds the fitted labels, assumed to follow the row order of all_labels;
    # the final entry is dropped — presumably the radial velocity (confirm in fitting.py).
    payne_labels[:, each_index] = popt[:-1]

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700


In [None]:
def fit_and_plot(sobject_id, show_error=False):
    """Fit one GALAH spectrum with The Payne and plot data, model and residuals.

    Prints the GALAH DR3 labels next to the Payne best-fit labels, then draws an
    8-row figure: for each of the four CCD windows in ``ccd``, the observed and
    best-fit spectra on top and the residual (observed - model) directly below.

    Parameters
    ----------
    sobject_id
        GALAH identifier; must match exactly one entry of ``all_sobject_id``.
    show_error : bool, optional
        If True, shade the +/- 1 sigma flux-error band around the observed spectrum.
        Defaults to False (the original behaviour).

    Raises
    ------
    RuntimeError
        If ``sobject_id`` matches zero or several entries of ``all_sobject_id``.

    Notes
    -----
    Reads the notebook globals all_sobject_id, all_labels, all_flux, all_flux_error,
    NN_coeffs, wavelength, mask and ccd.
    """
    matches = np.where(all_sobject_id == sobject_id)[0]
    if len(matches) != 1:
        raise RuntimeError('Could not find a unique sobject_id {} ({} matches)'.format(
            sobject_id, len(matches)))
    each_index = matches[0]

    labels = all_labels[:, each_index]
    spec = all_flux[each_index]
    spec_err = all_flux_error[each_index]

    popt, pcov, best_fit_spec = fitting.fit_normalized_spectrum_single_star_model(
        norm_spec=spec, spec_err=spec_err, NN_coeffs=NN_coeffs,
        wavelength=wavelength, mask=mask, p0=None)

    print('GALAH DR3:')
    print(labels)
    print('The Payne')
    print(popt[:-1])

    fig, axes = plt.subplots(8, 1, figsize=(15, 20))
    for it in range(4):
        in_ccd = ccd[it]
        wl = wavelength[in_ccd]
        ax = axes[2 * it]
        ax_res = axes[2 * it + 1]

        ax.plot(wl, spec[in_ccd], 'k', lw=0.5, label='GALAH spectrum')
        if show_error:
            ax.fill_between(wl, spec[in_ccd] - spec_err[in_ccd], spec[in_ccd] + spec_err[in_ccd],
                            facecolor='grey', label='flux error')
        ax.plot(wl, best_fit_spec[in_ccd], 'r', lw=0.5, label='Best-fit model')
        ax.set_ylim(0., 1.05)
        ax.set_ylabel('Normalised flux')
        # The lines carry labels; without a legend they were never shown.
        ax.legend(loc='lower left')

        ax_res.plot(wl, spec[in_ccd] - best_fit_spec[in_ccd], 'r', lw=0.5, label='Obs - Model')
        ax_res.set_ylim(-0.1, 0.1)
        ax_res.set_ylabel('Obs - Model')

    axes[0].set_title('sobject_id {}'.format(sobject_id))
    axes[-1].set_xlabel('Wavelength')
    fig.tight_layout()

In [None]:
# Stars whose GALAH DR3 labels are close to solar: Teff within 100 of 5777,
# logg within 0.1 of 4.43, and label rows 2 and 3 ([Fe/H], [alpha/Fe]) within 0.1 of 0.
is_solar_twin = (np.abs(all_labels[0] - 5777) < 100) \
    & (np.abs(all_labels[1] - 4.43) < 0.1) \
    & (np.abs(all_labels[2] - 0.00) < 0.1) \
    & (np.abs(all_labels[3] - 0.00) < 0.1)
solar_twins = np.flatnonzero(is_solar_twin)

In [None]:
# Visual check of the fit quality on the first two solar twins.
for twin_id in all_sobject_id[solar_twins[:2]]:
    fit_and_plot(twin_id)

In [None]:
# Stars where The Payne and GALAH DR3 disagree on logg (label row 1) by more than 1.5.
logg_disagreement = np.abs(all_labels[1] - payne_labels[1])
bad_logg = np.flatnonzero(logg_disagreement > 1.5)
print(bad_logg)

In [None]:
# Inspect the first two large logg disagreements by eye.
for outlier_id in all_sobject_id[bad_logg[:2]]:
    fit_and_plot(outlier_id)

In [None]:
# Side-by-side comparison of GALAH DR3 (left column) and The Payne (right column).
# Top row: Teff vs logg coloured by [Fe/H]; bottom row: [alpha/Fe] vs [Fe/H].
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

kwargs = dict(s=2)

for ax_kiel, ax_abund, labels, source in ((ax1, ax3, all_labels, 'DR3'),
                                          (ax2, ax4, payne_labels, 'Payne')):
    points = ax_kiel.scatter(labels[0], labels[1], c=labels[2], vmin=-2, vmax=0.5, **kwargs)
    # Without a colourbar the [Fe/H] colour coding cannot be read off the figure.
    f.colorbar(points, ax=ax_kiel, label='[Fe/H] ' + source)
    # Both axes reversed: hot stars on the left, low-gravity stars at the top.
    ax_kiel.set_xlim(8000, 3000)
    ax_kiel.set_ylim(5, 0)
    ax_kiel.set_xlabel('Teff ' + source)
    ax_kiel.set_ylabel('logg ' + source)

    ax_abund.scatter(labels[2], labels[3], **kwargs)
    ax_abund.set_xlim(-2.5, 0.5)
    ax_abund.set_ylim(-0.25, 0.5)
    ax_abund.set_xlabel('[Fe/H] ' + source)
    ax_abund.set_ylabel('[alpha/Fe] ' + source)

plt.tight_layout()

In [None]:
# Per-label differences (The Payne minus GALAH DR3), plotted against the DR3 value.
f, gs = plt.subplots(4, 1, figsize=(15, 10))
label_names = ['Teff', 'logg', '[Fe/H]', '[alpha/Fe]']
for each_index, name in enumerate(label_names):
    ax = gs[each_index]
    ax.scatter(all_labels[each_index], payne_labels[each_index] - all_labels[each_index], s=1)
    ax.axhline(0, color='k', lw=0.5)  # reference line: perfect agreement
    ax.set_xlabel(name + ' DR3')
    # Both quantities are fitted labels (neither is an observation), so name them explicitly.
    ax.set_ylabel('Payne - DR3 ' + name)
plt.tight_layout()