In [None]:
from chromatic import *
import emcee
import corner
import speclite as speclite
import glob as glob
from speclite import filters
from tqdm import tqdm
from matplotlib import cm
from matplotlib.artist import Artist
from scipy.optimize import minimize
from scipy.optimize import curve_fit
from PyAstronomy import pyasl
from specutils.spectra import Spectrum1D, SpectralRegion
from specutils.fitting import fit_generic_continuum

# Global matplotlib styling shared by every figure in this notebook:
# a fixed default figure size and 'x-large' text for labels, ticks and legends.
params = {'legend.fontsize': 'x-large', 'figure.figsize': (6, 4)}
params.update({key: 'x-large' for key in ('axes.labelsize', 'axes.titlesize',
                                          'xtick.labelsize', 'ytick.labelsize')})
plt.rcParams.update(params)
plt.style.use('tableau-colorblind10')

# Molecular transmission table: column 0 is wavelength (tagged as nm);
# column 1 presumably holds the transmission values — confirm against the file.
transmission_data = np.loadtxt('../data/transmission_comp.txt')
mol_wave, mol_data = transmission_data[:, 0] * u.nm, transmission_data[:, 1]

In [None]:
def normalize_nres(data_flux,wavelength=None,**kwargs):
    """Continuum-normalize an NRES flux array.

    The flux is tagged with ``u.erg`` and divided by its NaN-ignoring median,
    so the unit cancels. The result is then divided by a continuum fitted with
    ``fit_generic_continuum`` and evaluated on ``wavelength``. Any warnings
    raised while fitting or evaluating the continuum are suppressed.

    Parameters
    ----------
    data_flux : array-like
        Flux values, presumably aligned one-to-one with ``wavelength``.
    wavelength : array-like
        Spectral axis of the ``Spectrum1D`` and the grid on which the continuum
        is evaluated. Required in practice despite the ``None`` default.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    The median-scaled flux divided by the fitted continuum.
    """
    flux_quantity = data_flux * u.erg
    median_scaled = flux_quantity / np.nanmedian(flux_quantity)
    spectrum = Spectrum1D(flux=median_scaled, spectral_axis=wavelength)

    # Silence any warnings from the continuum fit/evaluation only.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        continuum_model = fit_generic_continuum(spectrum)
        continuum = continuum_model(wavelength)

    return median_scaled / continuum

In [None]:
def normalized_1T_PHOENIX(Tspec=5000,wavelength=None,**kwargs):
    """Return a continuum-normalized single-temperature PHOENIX model.

    The model comes from ``get_phoenix_photons`` at temperature ``Tspec``
    (cast to float), logg=4.4 and metallicity=0.0. Its flux (element 1 of the
    returned sequence) is divided by its NaN-ignoring median. The result is
    then divided by a continuum fitted with ``fit_generic_continuum`` on the
    model's own wavelength axis (element 0). Fit warnings are suppressed.

    Parameters
    ----------
    Tspec : float
        Model temperature passed to ``get_phoenix_photons``.
    wavelength : array-like
        Wavelength grid forwarded to ``get_phoenix_photons``.
    **kwargs
        Accepted and ignored. A misspelled keyword (e.g. ``temp``) does NOT
        change the temperature, so always pass ``Tspec`` explicitly.
    """
    phoenix = get_phoenix_photons(temperature=float(Tspec), wavelength=wavelength,
                                  logg=4.4, metallicity=0.0)
    model_wave = phoenix[0]
    model_flux = phoenix[1]

    median_scaled = model_flux / np.nanmedian(model_flux)
    spectrum = Spectrum1D(flux=median_scaled, spectral_axis=model_wave)

    # Silence any warnings from the continuum fit/evaluation only.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        continuum = fit_generic_continuum(spectrum)(model_wave)

    return median_scaled / continuum

In [None]:
def normalized_2T_PHOENIX(spot_params=[0.4,0.1,3000,3800], spotspec=None,ambspec=None, **kwargs):
    """Blend spot and ambient spectra by filling factor, then continuum-normalize.

    The combined flux is ``f_spot * spotspec[1] + (1 - f_spot) * ambspec[1]``.
    It is divided by its NaN-ignoring median and then by a continuum fitted
    with ``fit_generic_continuum`` on ``spotspec[0]``. Fit warnings are
    suppressed.

    Parameters
    ----------
    spot_params : sequence of 4 numbers
        (f_spot, df_spot, T_spot, T_amb). Only f_spot is used here. The
        temperatures are already baked into ``spotspec``/``ambspec``.
    spotspec, ambspec : indexable
        Element 0 is the wavelength axis and element 1 the flux, as indexed
        below. Both are required despite the ``None`` defaults.
    **kwargs
        Accepted and ignored (callers pass e.g. ``wavelength``).
    """
    # Unpacking all four values also rejects a spot_params of the wrong length.
    f_spot, _df_spot, _t_spot, _t_amb = spot_params

    wave_axis = spotspec[0]
    blended = f_spot*spotspec[1] + (1.-f_spot)*ambspec[1]
    median_scaled = blended / np.nanmedian(blended)
    spectrum = Spectrum1D(flux=median_scaled, spectral_axis=wave_axis)

    # Silence any warnings from the continuum fit/evaluation only.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        continuum = fit_generic_continuum(spectrum)(wave_axis)

    return median_scaled / continuum

In [None]:
def average_spectrum_model(parameters=[0.35,0.07,2500.,4000.],
                           data_flux=None, data_err=None,wavelength=None,
                           samples_exist=False,samples=None,
                           n_draws=500, rng=None, title=None, xlim=None,
                           savefile=None, **kwargs):
    """Plot a two-temperature spot model against a time-averaged NRES spectrum.

    The top panel shows the data with error bars and the combined model built
    from ``parameters``. The bottom panel shows residuals in units of
    ``data_err``. When ``samples_exist`` is True, ``n_draws`` model curves
    from randomly chosen posterior samples are overlaid on the top panel. The
    figure is saved, shown, and then closed.

    Parameters
    ----------
    parameters : sequence of 4 numbers
        (f_spot, df_spot, T_spot, T_amb). The temperatures are truncated with
        ``int`` before being passed to ``get_phoenix_photons``.
    data_flux, data_err : array-like
        Normalized flux and its 1-sigma uncertainty on ``wavelength``.
    wavelength : array-like
        Wavelength grid used for the models and the x-axis.
    samples_exist : bool
        Whether to overlay posterior draws from ``samples``.
    samples : sequence of 4 arrays
        Posterior samples (fspot, dfspot, Tspot, Tamb) of equal length.
    n_draws : int
        Number of posterior curves to overlay (default 500, as before).
    rng : numpy.random.Generator, optional
        Source of the random sample indices. When None, NumPy's legacy global
        state is used, so ``np.random.seed`` still controls the draws.
    title, xlim, savefile : optional
        Figure title, x-axis limits, and output path. When None they fall back
        to the notebook globals ``visit``/``order``, ``ref_r``, and ``label``
        respectively (the original behavior).
    **kwargs
        Accepted and ignored (callers pass e.g. ``plot=True``).
    """
    f_spot, df_spot, T_spot, T_amb = parameters

    S_spot = get_phoenix_photons(temperature=int(T_spot), wavelength=wavelength,
                                 logg=4.4, metallicity=0.0)
    S_amb = get_phoenix_photons(temperature=int(T_amb), wavelength=wavelength,
                                logg=4.4, metallicity=0.0)
    model = normalized_2T_PHOENIX(spot_params=parameters, wavelength=wavelength,
                                  spotspec=S_spot, ambspec=S_amb)

    # Fall back to the notebook globals the original implementation relied on.
    if title is None:
        title = f'{visit} {order}'
    if xlim is None:
        xlim = (ref_r.wavelength.value[0], ref_r.wavelength.value[-1])
    if savefile is None:
        savefile = f'../figs/{label}_specmodel.png'

    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(9, 6), sharex=True,
                                   gridspec_kw=dict(height_ratios=[1, 0.3]))
    fig.suptitle(title, fontsize=20)

    # Top panel: time-averaged data and the combined spot model.
    ax0.plot(wavelength, model, label='Spot model', color='k', zorder=100)
    ax0.errorbar(wavelength, data_flux, yerr=data_err, zorder=-100, color='teal',
                 label='NRES time-averaged spectrum', fmt='', alpha=1)
    ax0.set_ylim(0.25, 1.5)
    ax0.set_ylabel('Relative Flux', fontsize=18)
    ax0.legend(loc='upper right', fontsize=14)

    # Bottom panel: residuals in units of the per-pixel uncertainty.
    ax1.errorbar(wavelength, (data_flux - model) / data_err,
                 yerr=1, label='Residual', color='k', zorder=100)
    ax1.set_ylabel(r'Residual ($\sigma$)', fontsize=18)
    ax1.set_ylim(-5, 5)
    ax1.set_xlim(*xlim)

    if samples_exist and n_draws > 0:
        fspot_sam, dfspot_sam, Tspot_sam, Tamb_sam = samples
        n_samples = len(Tamb_sam)
        # randint/integers treat `high` as exclusive, so n_samples (not
        # n_samples - 1) makes every sample, including the last, drawable.
        draw_index = np.random.randint if rng is None else rng.integers
        for _ in range(n_draws if n_samples > 0 else 0):
            i = draw_index(0, n_samples)
            S_spot_i = get_phoenix_photons(temperature=Tspot_sam[i], wavelength=wavelength,
                                           logg=4.4, metallicity=0.0)
            S_amb_i = get_phoenix_photons(temperature=Tamb_sam[i], wavelength=wavelength,
                                          logg=4.4, metallicity=0.0)
            draw_params = [fspot_sam[i], dfspot_sam[i], Tspot_sam[i], Tamb_sam[i]]
            # Same blend + median + continuum normalization as the best-fit model.
            draw_model = normalized_2T_PHOENIX(spot_params=draw_params,
                                               spotspec=S_spot_i, ambspec=S_amb_i)
            ax0.plot(wavelength, draw_model, zorder=10, alpha=0.02, color='firebrick')

    fig.savefig(savefile, dpi=100)
    plt.show()
    plt.close(fig)

In [None]:
# Spectral order numbers and their quality classification.
# NOTE(review): orders 60, 69, 75, 80 and 83 appear in BOTH bad_orders and
# good_orders; only good_orders is used below — confirm which is intended.
ordernumbers = np.arange(53,84)
bad_orders = np.array([56,57,59,60,62,63,64,65,66,67,68,69,
                       72,73,74,75,78,79,80,83])
good_orders = np.array([53,54,55,60,61,69,70,71,75,76,77,80,81,82,83])

# nsteps = 1000
# modeltype = 'Final_Teff_Spec'
# modelname = 'Teff + Spec'
# visits = ['F21','S22','combined']
# orders_I_care_about = good_orders

# nsteps = 2000
# modeltype = f'Final_Teff_Phot_Spec'
# modelname = 'Teff + Phot + Spec'
# suptitle = r'T$_{\rm{eff}}$ & Photometric & S($\lambda$) Model Results'
# visits = ['F21','S22']
# orders_I_care_about = good_orders

nsteps = 3000
modeltype = 'Final_Teff_Phot_Spec'
modelname = 'Teff + Phot + Spec'
suptitle = r'T$_{\rm{eff}}$ & Photometric & S($\lambda$) Model Results'
visits = ['F21','S22']
orders_I_care_about = good_orders

for visit in tqdm(visits):

    N_data = 0

    for order in orders_I_care_about:
        label = f'{visit}_{order}_{modeltype}'
        samples_file_label = label+f'_{nsteps}steps'

        # Read in the MCMC samples. This must run here: best_params and samples
        # are used by average_spectrum_model below, and are otherwise undefined
        # on a fresh kernel. The first 25% of steps are discarded.
        reader = emcee.backends.HDFBackend(f'../data/samples/{samples_file_label}.h5',
                                           read_only=True)
        sampler = reader.get_chain(discard=int(0.25*nsteps), flat=True)
        samples = sampler.reshape((-1, 4)).T  # rows: fspot, dfspot, Tspot, Tamb
        fspot_sam, dfspot_sam, Tspot_sam, Tamb_sam = samples
        sig1_fspot = np.percentile(fspot_sam, [15.9, 50., 84.1]) # central 1-sigma values
        sig1_dfspot = np.percentile(dfspot_sam, [15.9, 50., 84.1])
        sig1_Tspot = np.percentile(Tspot_sam, [15.9, 50., 84.1])
        sig1_Tamb = np.percentile(Tamb_sam, [15.9, 50., 84.1])
        # Define the 50th percentile and 1-sigma interval parameter values from the samples
        best_params = np.array([sig1_fspot[1], sig1_dfspot[1], sig1_Tspot[1], sig1_Tamb[1]])
        best_params_err_lower = best_params-[sig1_fspot[0], sig1_dfspot[0], sig1_Tspot[0], sig1_Tamb[0]]
        best_params_err_higher = [sig1_fspot[2], sig1_dfspot[2], sig1_Tspot[2], sig1_Tamb[2]]-best_params
        variable_names = [r'f$_{\rm{spot}}$',r'$\Delta$f$_{\rm{spot}}$',r'T$_{\rm{spot}}$',r'T$_{\rm{amb}}$']
        normie_names = ['fspot','dfspot','Tspot','Tamb']

        _r = read_rainbow(f"../data/rainbows/{visit}_{order}_original.rainbow.npy")
        ref_r = _r.trim()
        nres_rainbow = read_rainbow(f"../data/rainbows/{visit}_{order}_clipped.rainbow.npy")
        data_wave = nres_rainbow.wavelength
        nres_avg_1dspec = nres_rainbow.get_average_spectrum()
        nres_avg_1derr = np.nanmedian(nres_rainbow.uncertainty,axis=1)/np.sqrt(len(nres_rainbow.timelike['time']))

        # Pass Tspec explicitly: an unknown keyword such as ``temp`` is swallowed
        # by **kwargs and would silently fall back to the 5000 K default.
        template_spec = normalized_1T_PHOENIX(Tspec=3600,wavelength=data_wave)

        # Reduced chi-square of the template; inflate the errors if it exceeds 1.
        # Only finite terms contribute to nansum, so only they count toward
        # the degrees of freedom.
        chi_terms = (template_spec-nres_avg_1dspec)**2/(nres_avg_1derr)**2
        n_valid = np.count_nonzero(np.isfinite(chi_terms))
        rchisq = np.nansum(chi_terms) / (n_valid - 1) if n_valid > 1 else np.nan
        if rchisq > 1.0:
            nres_avg_1derr = nres_avg_1derr * np.sqrt(rchisq)

        N_data += len(data_wave)

        average_spectrum_model(parameters=best_params, wavelength=data_wave,
                               data_flux=nres_avg_1dspec, data_err=nres_avg_1derr,
                               plot=True,samples_exist=True,samples=samples)

    # print(N_data)