In [2]:
from astropy.io import fits
import pickle
import numpy as np
from astropy.table import Table

In [4]:
# Read the pickled Iron spectra records into memory.
# Note: pickle.load can execute arbitrary code, so only point this at a trusted file.
with open("../DESI_LAE_dataset/original_dataset/iron_spectra.pkl", "rb") as spectra_file:
    spectra = pickle.load(spectra_file)

In [15]:
def _flux_width_pixels(wavelength, flux, center, sigma):
    """Count the pixels spanning the central 95% of the clipped flux around a line.

    Pixels strictly inside ``center +/- 4 * sigma`` are selected, negative flux
    is clipped to zero, and the width is the number of pixels between the
    points where the cumulative flux first exceeds 2.5% and 97.5% of its total.

    Returns 0 when no pixel falls inside the window (instead of raising
    IndexError on the empty cumulative sum).
    """
    in_window = (wavelength > center - 4 * sigma) & (wavelength < center + 4 * sigma)
    # Boolean indexing returns a copy, so the clipping below never modifies the caller's flux.
    signal = flux[in_window]
    if signal.size == 0:
        return 0
    signal[signal < 0] = 0
    cumulative_flux = np.cumsum(signal)
    total = cumulative_flux[-1]
    lower = np.argmax(cumulative_flux > 0.025 * total)
    higher = np.argmax(cumulative_flux > 0.975 * total)
    return higher - lower + 1


def iron_label_generate(spectra, interval = 25, width_mode = "sum", min_wave = 3600,
                        max_wave = 5550, output_dir = "../DESI_LAE_dataset"):
    """Build per-wavelength-bin labels for inspected spectra and pickle them.

    The range ``[min_wave, max_wave)`` is cut into bins of ``interval``. Records
    without a "VI" key are skipped. Every record with ``VI == 1`` yields three
    rows, each with one entry per bin:

    * row 0: 1 in the bin ``[start, start + interval)`` containing
      ``PARAMS[1]`` (used here as the line centre), 0 elsewhere;
    * row 1: in that bin, ``(centre - start + interval) / (3 * interval)``,
      i.e. the centre's fractional position inside the three-bin window that
      starts one bin earlier; 0 elsewhere;
    * row 2: in that bin, the line width; 0 elsewhere.

    Every record with ``VI == 0`` yields a ``(3, n_bins)`` array of zeros.

    Parameters
    ----------
    spectra : iterable of dict
        Records read via "VI", "PARAMS" and (in "sum" mode) "SPECTRUM", whose
        ``wave["brz"]`` and ``flux["brz"][0]`` are indexed as NumPy arrays.
    interval : int
        Bin size, in the same units as the wavelength grid.
    width_mode : {"sum", "gaussian"}
        "gaussian": ``2 * sqrt(2 * ln 2) * PARAMS[2] / 0.8``.
        "sum": pixel count from ``_flux_width_pixels`` on the part of the
        spectrum below ``max_wave``.
    min_wave, max_wave : int
        Bounds of the binned range; the defaults match the previously
        hard-coded 3600 and 5550.
    output_dir : str
        Existing directory receiving ``LAE_iron_<interval>.pkl`` and
        ``NLAE_iron_<interval>.pkl``.

    Returns
    -------
    (LAE_label, NLAE_label) : tuple of numpy.ndarray

    Raises
    ------
    ValueError
        If ``width_mode`` is neither "gaussian" nor "sum".
    """
    if width_mode not in ("gaussian", "sum"):
        # Previously an unknown mode silently produced ragged label rows.
        raise ValueError('width_mode must be "gaussian" or "sum", got {!r}'.format(width_mode))

    bin_starts = range(min_wave, max_wave, interval)
    # The position label is normalised over three bins (75 for the default interval of 25).
    window = 3 * interval

    LAE_label = []
    NLAE_label = []

    for record in spectra:
        if "VI" not in record:
            continue

        if record["VI"] == 1:
            center = record["PARAMS"][1]
            sigma = record["PARAMS"][2]
            line_loc = []
            line_pre = []
            line_std = []

            for start in bin_starts:
                # Half-open bins: a centre exactly on a bin edge belongs to the
                # bin that starts there instead of matching no bin at all.
                if not (start <= center < start + interval):
                    line_loc.append(0)
                    line_pre.append(0)
                    line_std.append(0)
                    continue

                position = (center - start + interval) / window
                if width_mode == "gaussian":
                    # 2*sqrt(2 ln 2) converts a Gaussian sigma to FWHM; the 0.8
                    # divisor is presumably the pixel size in wavelength units -- TODO confirm.
                    width = (2 * (2 * np.log(2)) ** (1 / 2) * sigma) / 0.8
                else:
                    wave = record["SPECTRUM"].wave["brz"]
                    blue = wave < max_wave
                    width = _flux_width_pixels(
                        wave[blue], record["SPECTRUM"].flux["brz"][0][blue], center, sigma
                    )

                line_loc.append(1)
                line_pre.append(position)
                line_std.append(width)

                # Sanity check: inside its own bin the position must lie in [1/3, 2/3].
                if position < 1 / 3 or position > 2 / 3:
                    print("warning")

            # Exactly one bin must hold the line; otherwise the centre is outside the binned range.
            if sum(line_loc) != 1:
                print("Wrong")

            LAE_label.append([line_loc, line_pre, line_std])

        elif record["VI"] == 0:
            NLAE_label.append(np.zeros((3, len(bin_starts))))

    LAE_label = np.array(LAE_label)
    NLAE_label = np.array(NLAE_label)
    with open("{}/LAE_iron_{}.pkl".format(output_dir, interval), "wb") as fh:
        pickle.dump(LAE_label, fh)
    # The non-LAE file must hold NLAE_label (it used to receive LAE_label by mistake).
    with open("{}/NLAE_iron_{}.pkl".format(output_dir, interval), "wb") as fh:
        pickle.dump(NLAE_label, fh)

    return LAE_label, NLAE_label

In [13]:
def iron_spectra_generate(spectra):
    """Collect the sub-5550 part of each inspected spectrum and pickle it by class.

    Records lacking a "VI" key are ignored. For every other record a dict with
    the keys '_dr', "Z", "TARGETID", "WAVE", "IVAR" and "FLUX" is built, where
    the last three are the "brz" wavelength, inverse-variance (row 0) and flux
    (row 0) arrays restricted to wavelengths below 5550. Dicts go to the LAE
    list when ``VI == 1`` and to the non-LAE list when ``VI == 0``; any other
    VI value is dropped.

    Both lists are written to ``../DESI_LAE_dataset/train_spectra/`` and
    returned as ``(LAE_b_band, NLAE_b_band)``.
    """
    lae_records, nlae_records = [], []

    for record in spectra:
        if "VI" in record:
            entry = {
                '_dr': record['_dr'],
                "Z": record["Z"],
                "TARGETID": record["TARGETID"],
            }

            spectrum = record["SPECTRUM"]
            wave = spectrum.wave["brz"]
            below_cut = wave < 5550  # same mask is applied to all three arrays
            entry["WAVE"] = wave[below_cut]
            entry["IVAR"] = spectrum.ivar["brz"][0][below_cut]
            entry["FLUX"] = spectrum.flux["brz"][0][below_cut]

            label = record["VI"]
            if label == 1:
                lae_records.append(entry)
            elif label == 0:
                nlae_records.append(entry)

    outputs = (
        ("../DESI_LAE_dataset/train_spectra/LAE_iron_b.pkl", lae_records),
        ("../DESI_LAE_dataset/train_spectra/NLAE_iron_b.pkl", nlae_records),
    )
    for path, payload in outputs:
        with open(path, "wb") as out_file:
            pickle.dump(payload, out_file)

    return lae_records, nlae_records

In [16]:
# Build and save the per-bin labels. The function defined in this notebook is
# iron_label_generate; the previous name LAE_label_generate is not defined and
# raises NameError on a fresh kernel.
LAE_label, NLAE_label = iron_label_generate(spectra, interval = 25, width_mode = "sum")

In [14]:
# Extract and save the sub-5550 spectra, split into LAE and non-LAE lists.
LAE_b_band, NLAE_b_band = iron_spectra_generate(spectra)

Process: 100.0 %