In [None]:
import os
import sys
import specplot
import matplotlib.pyplot as plt
import numpy as np
import glob
from astropy.table import Table
from astropy.io import fits

In [None]:
# Multiplicative factor applied to EAZY template fluxes so they can be
# overplotted on SNANA FLAM spectra.
# TODO(review): an earlier note proposed 3.63078E-30 instead, marked with "?";
# confirm which value is correct.
EAZY_FLAM_SCALE = 1.55E-30

# Input files, both read from the local ``data`` directory.
fname_test = ('hlsp_candels_hst_wfc3_cos-tot-multiband_f160w_v1-1photom_'
              'sfr_mass_specbasiscoeffs_cat_20210430a.fits')
fname_templates = 'eazy_13_spectral_templates.dat'

# EAZY spectral templates (parsed by specplot) and the catalog holding each
# host galaxy's template coefficients.
template_data = specplot.load_eazypy_templates(
    os.path.join('data', fname_templates))
test_set_data = Table.read(os.path.join('data', fname_test), format='fits')

def get_eazy_template(hostgalID):
    """Build the EAZY model SED for one host galaxy of the test catalog.

    Finds the row of the module-level ``test_set_data`` catalog whose
    ``pind`` column equals ``hostgalID``. It collects that row's
    ``SPECBASIS*`` coefficients in catalog column order. These are passed,
    together with the same row's ``redshift``, to
    ``specplot.simulate_eazy_sed_from_coeffs`` along with the module-level
    ``template_data``.

    Parameters
    ----------
    hostgalID : int
        Host-galaxy identifier matched against the catalog ``pind`` column.

    Returns
    -------
    Whatever ``specplot.simulate_eazy_sed_from_coeffs`` returns. Callers in
    this notebook unpack it as ``(wave, flux)``.

    Raises
    ------
    ValueError
        If no catalog row has ``pind == hostgalID``.
    """
    basis_cols = [name for name in test_set_data.colnames if 'SPECBASIS' in name]
    matches = np.where(test_set_data['pind'] == hostgalID)[0]
    if len(matches) == 0:
        raise ValueError('Host galaxy ID %s not found in the test catalog'
                         % hostgalID)
    host_row = test_set_data[matches[0]]
    eazycoeffs = [host_row[name] for name in basis_cols]

    # Use the matched host's own redshift. The original took row 0's
    # redshift, which paired this host's coefficients with another galaxy's z.
    return specplot.simulate_eazy_sed_from_coeffs(eazycoeffs, template_data,
                                                  host_row['redshift'])

def get_snana_spec_ascii(snana_fname_ascii):
    """Parse an SNANA ASCII spectrum file into a host-galaxy ID and a table.

    Recognised lines:

    * a line containing ``SIM_HOSTLIB_GALID``: its last whitespace-separated
      token is parsed as the integer host-galaxy ID;
    * a line containing ``VARNAMES_SPEC``: the tokens after the first are
      the column names;
    * a line whose first token is ``SPEC:``: the remaining tokens are parsed
      as floats and form one row of the spectrum.

    A ``WAVE`` column is added as the mean of the ``LAMMIN`` and ``LAMMAX``
    columns (presumably the bin centre). Those two names must therefore
    appear in ``VARNAMES_SPEC``.

    Parameters
    ----------
    snana_fname_ascii : str
        Path to the SNANA ``.DAT`` file.

    Returns
    -------
    tuple of (int, astropy.table.Table) or None
        ``(galid, table)``, or ``None`` when the file holds fewer than two
        ``SPEC:`` rows.

    Raises
    ------
    ValueError
        If the file has at least two spectrum rows but no
        ``SIM_HOSTLIB_GALID`` or no ``VARNAMES_SPEC`` line.
    """
    galid = None
    colnames = None
    data_lines = []
    with open(snana_fname_ascii, 'r') as f:
        for line in f:
            tokens = line.split()
            if 'SIM_HOSTLIB_GALID' in line:
                galid = int(tokens[-1])
            elif 'VARNAMES_SPEC' in line:
                colnames = tokens[1:]
            elif tokens and tokens[0] == 'SPEC:':
                data_lines.append([float(x) for x in tokens[1:]])

    # Check the row count before building the Table. An empty row list cannot
    # be turned into a named Table, so checking afterwards never reached the
    # ``return None`` for files without spectrum rows.
    if len(data_lines) < 2:
        return None
    if galid is None:
        raise ValueError('No SIM_HOSTLIB_GALID line found in %s'
                         % snana_fname_ascii)
    if colnames is None:
        raise ValueError('No VARNAMES_SPEC line found in %s'
                         % snana_fname_ascii)

    snana_data_table = Table(np.array(data_lines, dtype=float), names=colnames)
    snana_data_table['WAVE'] = np.mean([snana_data_table['LAMMIN'],
                                        snana_data_table['LAMMAX']], axis=0)
    return galid, snana_data_table

def get_snana_spec_marz(snana_fname_marz):
    """Read an SNANA spectrum written in MARZ FITS format (not yet supported).

    Opens the file and reads:

    * the ``HIERARCH FLAM_SCALE`` header card of HDU 0;
    * the data of HDU 2 (wavelengths);
    * the data of HDU 0 (fluxes).

    The host-galaxy ID cannot be recovered from this format yet, so the
    function always raises after reading.

    Parameters
    ----------
    snana_fname_marz : str
        Path to the MARZ FITS file.

    Raises
    ------
    NotImplementedError
        Always. This replaces ``sys.exit()``, which would try to stop the
        notebook kernel instead of just this call.
    """
    # The original read the header and flux from an undefined
    # ``shallow_marz_file`` (NameError) and never closed the FITS handle.
    with fits.open(snana_fname_marz) as marz_file:
        # Placeholders recording the expected file layout, for when the
        # reader is finished.
        FLAM_SCALE = marz_file[0].header['HIERARCH FLAM_SCALE']
        marz_wave = marz_file[2].data
        marz_flux = marz_file[0].data
    raise NotImplementedError(
        'Cannot yet get hostgal id from marz file, stopping...')
    
def plot_snana_eazy_comp(snana_file_list, num_plots='all', saveplot_root=None):
    """Overplot SNANA simulated spectra and the matching EAZY host templates.

    For each of the first ``num_plots`` files, the steps are:

    1. Parse the file with ``get_snana_spec_ascii``.
    2. Look up the host's EAZY SED with ``get_eazy_template``.
    3. Plot the SNANA ``FLAM`` in black and the EAZY flux scaled by
       ``EAZY_FLAM_SCALE`` in red, with the y-range 0 to 5x the median
       SNANA flux.

    Parameters
    ----------
    snana_file_list : sequence of str
        Paths to SNANA ASCII spectrum files.
    num_plots : int or 'all', optional
        How many files from the start of the list to plot. Values larger than
        the list length are clipped instead of raising IndexError.
    saveplot_root : str or None, optional
        If given, each figure is also saved as ``<root>_<hostgal_id>.pdf``.
    """
    if num_plots == 'all':
        num_plots = len(snana_file_list)
    for snana_file in snana_file_list[:num_plots]:
        parsed = get_snana_spec_ascii(snana_file)
        if parsed is None:
            # get_snana_spec_ascii returns None for files with fewer than two
            # spectrum rows; unpacking it would raise TypeError.
            print('Skipping %s: fewer than two spectrum rows' % snana_file)
            continue
        hostgal_id, snana_spec = parsed
        snana_wave, snana_flux = snana_spec['WAVE'], snana_spec['FLAM']
        eazy_wave, eazy_flux = get_eazy_template(hostgal_id)

        fig, ax = plt.subplots(figsize=(4, 4))
        ax.plot(snana_wave, snana_flux, linewidth=1, color='k', label='SNANA SIM')
        ax.plot(eazy_wave, eazy_flux * EAZY_FLAM_SCALE, linewidth=1, color='r',
                label='EAZY Template')
        ax.set_ylim((0, 5 * np.median(snana_flux)))
        ax.legend()
        # Raw string: '\A' in a normal literal is an invalid escape sequence.
        ax.set_xlabel(r'Wavelength ($\AA$)')
        ax.set_ylabel(r'$\rm{Flux (erg/s/}cm^2/\AA$)')
        if saveplot_root is not None:
            fig.savefig(saveplot_root + '_%i.pdf' % hostgal_id, format='pdf')
        plt.show()
        # Release the figure so looping over many files doesn't accumulate them.
        plt.close(fig)

In [None]:
# Root folder holding the SNANA Roman spectroscopic simulations. Set the
# ROMAN_SPEC_SIMS_DIR environment variable to run on another machine; the
# default is the original author's external drive.
PATH_TO_SNANA_SIM_FOLDER = os.environ.get(
    'ROMAN_SPEC_SIMS_DIR',
    os.path.join('/Volumes', 'Justin_Pierel_SSD', 'Data', 'roman_spec_sims'))
SIM_FOLDER_NAME = 'Prism_shallow_hostIa'
sim_dir = os.path.join(PATH_TO_SNANA_SIM_FOLDER, SIM_FOLDER_NAME)
# glob returns paths in arbitrary, filesystem-dependent order. Sort them so
# the "first num_plots" files are the same on every run.
shallow_spec_files = sorted(glob.glob(os.path.join(sim_dir, '*.DAT')))
if not shallow_spec_files:
    print('No *.DAT files found in %s' % sim_dir)
plot_snana_eazy_comp(shallow_spec_files, num_plots=2, saveplot_root=SIM_FOLDER_NAME)
    
