In [1]:
from astropy.io import fits
import pickle
import numpy as np
from astropy.table import Table
from astropy.convolution import convolve, Gaussian1DKernel

In [2]:
# Load the four pickled Fuji sample sets (pre, similar, NLAE, random) into
# notebook-level variables.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
# NOTE(review): none of these four variables is referenced by the cells shown
# further down, and `random` has the same name as the stdlib `random` module,
# which it would shadow if that module were imported later -- confirm intent.
with open("/pscratch/sd/j/juikuan/DESI_LAE_dataset/original_dataset/fuji_pre_lite.pkl", "rb") as fh:
    pre = pickle.load(fh)
with open("/pscratch/sd/j/juikuan/DESI_LAE_dataset/original_dataset/fuji_similar_lite.pkl", "rb") as fh:
    similar = pickle.load(fh)
with open("/pscratch/sd/j/juikuan/DESI_LAE_dataset/original_dataset/fuji_NLAE_lite.pkl", "rb") as fh:
    NLAE = pickle.load(fh)
with open("/pscratch/sd/j/juikuan/DESI_LAE_dataset/original_dataset/fuji_random_lite.pkl", "rb") as fh:
    random = pickle.load(fh)

In [3]:
def label_generate_old(spectra, save, interval = 25, width_mode = "sum"):
    """Build per-bin line labels for spectra in the older record layout and pickle them.

    The wavelength range 3600-5550 is cut into bins of width ``interval``.
    Each labelled spectrum contributes three rows with one value per bin:

    * row 0 -- 1 in the bin that contains the line centre, 0 elsewhere;
    * row 1 -- ``(centre - bin_start + interval) / 75`` in that bin, 0 elsewhere;
    * row 2 -- the line width in that bin (see ``width_mode``), 0 elsewhere.

    Spectra whose final ``VI`` entry is ``"QSO"`` or ``"ELG"`` contribute three
    all-zero rows. Spectra with an empty ``VI`` or any other final entry are
    skipped and contribute nothing.

    Parameters
    ----------
    spectra : iterable of dict
        Each record provides ``"VI"`` (sequence of classifications, the last
        one is used), ``"PARAMS"`` (index 1 is used as the line centre and
        index 2 as the line sigma) and ``"SPECTRUM"``, an object whose
        ``wave["brz"]`` is the wavelength array and whose ``flux["brz"][0]``
        is the matching flux array.
    save : path-like
        Destination file; the label array is pickled there.
    interval : int
        Bin width, in the same units as the wavelength array.
    width_mode : {"sum", "gaussian"}
        ``"gaussian"``: ``2 * sqrt(2 ln 2) * sigma / 0.8`` (the 0.8 is
        presumably the pixel size -- TODO confirm).
        ``"sum"``: number of pixels between the 2.5 % and 97.5 % points of the
        cumulative non-negative flux inside ``centre +/- 4 * sigma``.
        Any other value appends no width for the matched bin, which leaves the
        rows with unequal lengths.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of the collected labels, one ``3 x n_bins`` entry per
        labelled spectrum. The same array is written to ``save``.

    Notes
    -----
    * The divisor 75 equals ``3 * interval`` only for the default interval of
      25; it is left hard-coded so existing outputs do not change.
    * Bin membership uses strict inequalities on both sides, so a centre lying
      exactly on a bin edge matches no bin and triggers the "Wrong" message.
    """
    min_wave = 3600
    max_wave = 5550
    bin_starts = range(min_wave, max_wave, interval)

    label = []
    for i in spectra:
        if len(i["VI"]) == 0:
            continue
        final_class = i["VI"][-1]

        if final_class == "LAE":
            wavelength = i["SPECTRUM"].wave["brz"][i["SPECTRUM"].wave["brz"] < max_wave]
            flux = i["SPECTRUM"].flux["brz"][0][i["SPECTRUM"].wave["brz"] < max_wave]
            center = i["PARAMS"][1]

            line_loc = []
            line_pre = []
            line_std = []
            for j in bin_starts:
                if j < center < j + interval:
                    position = (center - j + interval) / 75
                    line_loc.append(1)
                    line_pre.append(position)
                    if width_mode == "gaussian":
                        line_std.append((2 * (2 * np.log(2)) ** (1 / 2) * i["PARAMS"][2]) / 0.8)
                    elif width_mode == "sum":
                        min_wavelength = center - 4 * i["PARAMS"][2]
                        max_wavelength = center + 4 * i["PARAMS"][2]
                        window = (wavelength > min_wavelength) & (wavelength < max_wavelength)
                        # Boolean indexing returns a copy, so clipping the
                        # negative pixels does not touch the input spectrum.
                        signal = flux[window]
                        signal[signal < 0] = 0
                        cumulative_flux = np.cumsum(signal)
                        # argmax on a boolean array gives the first True index.
                        lower = np.argmax(cumulative_flux > 0.025 * cumulative_flux[-1])
                        higher = np.argmax(cumulative_flux > 0.975 * cumulative_flux[-1])
                        line_std.append(higher - lower + 1)
                    if position < 1 / 3 or position > 2 / 3:
                        print("warning")
                else:
                    line_loc.append(0)
                    line_pre.append(0)
                    line_std.append(0)

            # Exactly one bin is expected to contain the line centre.
            if line_loc.count(1) != 1:
                print("Wrong")

            label.append([line_loc, line_pre, line_std])

        # Fix: test the final classification for both classes. The previous
        # code compared the whole VI container to "ELG", so ELG spectra never
        # received a label row.
        elif final_class == "QSO" or final_class == "ELG":
            label.append(np.zeros((3, len(bin_starts))))

    label = np.array(label)
    with open(str(save), "wb") as fh:
        pickle.dump(label, fh)

    return label

In [4]:
def label_generate_new(spectra, save, interval = 25, width_mode = "sum"):
    """Create per-bin line labels for records holding flat ``WAVE``/``FLUX`` arrays.

    The range 3600-5550 is divided into bins of width ``interval``. For a
    record whose last ``VI`` entry is ``"LAE"`` three equally long rows are
    produced: a 0/1 flag marking the bin that holds the line centre
    (``PARAMS[1]``), the value ``(centre - bin_start + interval) / 75`` in
    that bin, and the line width in that bin. Records ending in ``"QSO"`` or
    ``"ELG"`` get a ``3 x n_bins`` block of zeros. Records with an empty
    ``VI`` or any other final entry are ignored.

    Parameters
    ----------
    spectra : iterable of dict
        Records with keys ``"VI"``, ``"PARAMS"``, ``"WAVE"`` and ``"FLUX"``.
    save : path-like
        File that receives the pickled label array.
    interval : int
        Bin width.
    width_mode : {"sum", "gaussian"}
        ``"gaussian"`` stores ``2 * sqrt(2 ln 2) * PARAMS[2] / 0.8``.
        ``"sum"`` stores the pixel count between the 2.5 % and 97.5 % points
        of the cumulative, negative-clipped flux within
        ``PARAMS[1] +/- 4 * PARAMS[2]``.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of all collected label blocks (also pickled to ``save``).

    Notes
    -----
    NOTE(review): the divisor 75 matches ``3 * interval`` only when
    ``interval`` is 25. A centre exactly on a bin edge matches no bin and
    prints "Wrong". An unrecognised ``width_mode`` stores no width for the
    matched bin.
    """
    collected = []

    for entry in spectra:
        if len(entry["VI"]) == 0:
            continue

        if entry["VI"][-1] == "LAE":
            wavelength = entry["WAVE"][entry["WAVE"] < 5550]
            flux = entry["FLUX"][entry["WAVE"] < 5550]

            hit_flags, offsets, widths = [], [], []
            for start in range(3600, 5550, interval):
                inside_bin = entry["PARAMS"][1] > start and entry["PARAMS"][1] < start + interval
                if not inside_bin:
                    hit_flags.append(0)
                    offsets.append(0)
                    widths.append(0)
                    continue

                offset = (entry["PARAMS"][1] - start + interval) / 75
                hit_flags.append(1)
                offsets.append(offset)

                if width_mode == "gaussian":
                    widths.append((2 * (2 * np.log(2)) ** (1 / 2) * entry["PARAMS"][2]) / 0.8)
                elif width_mode == "sum":
                    low_edge = entry["PARAMS"][1] - 4 * entry["PARAMS"][2]
                    high_edge = entry["PARAMS"][1] + 4 * entry["PARAMS"][2]
                    in_window = (wavelength > low_edge) & (wavelength < high_edge)
                    # Indexing with a boolean mask yields a fresh array, so
                    # zeroing the negative pixels leaves the record untouched.
                    clipped = flux[in_window]
                    clipped[clipped < 0] = 0
                    running = np.cumsum(clipped)
                    total = running[-1]
                    first_idx = np.argmax(running > 0.025 * total)
                    last_idx = np.argmax(running > 0.975 * total)
                    widths.append(last_idx - first_idx + 1)

                if offset < 1 / 3 or offset > 2 / 3:
                    print("warning")

            # A well-formed LAE record flags exactly one bin.
            if sum(1 for flag in hit_flags if flag == 1) != 1:
                print("Wrong")

            collected.append([hit_flags, offsets, widths])

        elif entry["VI"][-1] == "QSO" or entry["VI"][-1] == "ELG":
            collected.append(np.zeros((3, len(range(3600, 5550, interval)))))

    stacked = np.array(collected)
    with open(str(save), "wb") as out_file:
        pickle.dump(stacked, out_file)

    return stacked

In [5]:
def spectra_generate(save, spectra):
    """Collect the sub-5550 part of every LAE/ELG/QSO spectrum and pickle the list.

    For each record with a non-empty ``VI`` a new dict is built holding ``Z``,
    ``TARGETID`` and the ``WAVE``, ``IVAR`` and ``FLUX`` samples whose
    wavelength is below 5550, plus ``CONV_FLUX``: the same flux samples
    convolved with ``Gaussian1DKernel(1)``. The dict is kept only when the
    last ``VI`` entry is ``"LAE"``, ``"ELG"`` or ``"QSO"``.

    Parameters
    ----------
    save : path-like
        File that receives the pickled list.
    spectra : iterable of dict
        Records with keys ``"VI"``, ``"Z"``, ``"TARGETID"``, ``"WAVE"``,
        ``"IVAR"`` and ``"FLUX"``.

    Returns
    -------
    list of dict
        The kept dicts, in input order (also pickled to ``save``).
    """
    kept = []
    for record in spectra:
        if len(record["VI"]) == 0:
            continue

        trimmed = {"Z": record["Z"], "TARGETID": record["TARGETID"]}
        below_cut = record["WAVE"] < 5550
        trimmed["WAVE"] = record["WAVE"][below_cut]
        trimmed["IVAR"] = record["IVAR"][below_cut]
        trimmed["FLUX"] = record["FLUX"][below_cut]
        trimmed["CONV_FLUX"] = convolve(record["FLUX"][below_cut], Gaussian1DKernel(1))

        # The dict is built for every record; the class filter is applied afterwards.
        final_class = record["VI"][-1]
        if final_class == "LAE" or final_class == "ELG" or final_class == "QSO":
            kept.append(trimmed)

    with open(str(save), "wb") as out_file:
        pickle.dump(kept, out_file)

    return kept

In [9]:
# Load the pickled Iron "pre" sample; the following cells pass it to
# label_generate_new and spectra_generate.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("/pscratch/sd/j/juikuan/DESI_LAE_dataset/original_dataset/iron_pre_lite.pkl", "rb") as fh:
    iron_pre = pickle.load(fh)

In [10]:
# Build the labels for iron_pre with bins of width 25 and the flux-sum width
# measure; the function also pickles the result to the given path.
label = label_generate_new(spectra = iron_pre, save = "/pscratch/sd/j/juikuan/DESI_LAE_dataset/train_label/iron_pre_25.pkl", interval = 25, width_mode = "sum")

In [13]:
# Extract the sub-5550 portion of each LAE/ELG/QSO spectrum in iron_pre and
# pickle the resulting list to the given path.
b_band = spectra_generate(save = "/pscratch/sd/j/juikuan/DESI_LAE_dataset/train_spectra/iron_pre_b.pkl", spectra = iron_pre)

In [4]:
# Reload previously saved labels for the Fuji NLAE set. This rebinds `label`,
# replacing the iron_pre labels assigned by the label_generate_new call above.
# NOTE(review): this cell's execution count (In [4]) is out of sequence with
# its neighbours, so its position in the notebook does not match the order in
# which it was run -- confirm before relying on `label` downstream.
with open("/pscratch/sd/j/juikuan/DESI_LAE_dataset/train_label/fuji_NLAE_25.pkl", "rb") as fh:
    label = pickle.load(fh)

In [14]:
# Number of spectra kept by spectra_generate (LAE, ELG and QSO records only).
len(b_band)

1753

In [None]:
# Display the first iron_pre record to inspect its fields.
# NOTE(review): this prints the whole record, presumably including its full
# WAVE/FLUX arrays; the cell has no execution count, so it was not run in the
# saved session.
iron_pre[0]