In [389]:
import os
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import dask.dataframe as dd
from tqdm.notebook import tqdm

In [390]:
os.chdir('/home/ricky/RNNAE/plasticc_csv')
print(os.getcwd())

/home/ricky/RNNAE/plasticc_csv


In [391]:
req_cols_meta = ['object_id', 'ddf_bool','target',
       'true_target', 'true_submodel', 'true_z', 'true_distmod',
       'true_lensdmu', 'true_peakmjd',
        ]

req_cols = ['object_id', 'mjd', 'passband', 'flux', 'flux_err']

data_meta = pd.read_csv('plasticc_test_metadata.csv', usecols=req_cols_meta, low_memory=False)
data = pd.read_csv('plasticc_test_lightcurves_02.csv', usecols=req_cols, low_memory=False)

In [5]:
pd.isnull(data).sum()

object_id    0
mjd          0
passband     0
flux         0
flux_err     0
dtype: int64

In [6]:
print(data_meta.columns)
print(data.columns)

Index(['object_id', 'ddf_bool', 'target', 'true_target', 'true_submodel',
       'true_z', 'true_distmod', 'true_lensdmu', 'true_peakmjd'],
      dtype='object')
Index(['object_id', 'mjd', 'passband', 'flux', 'flux_err'], dtype='object')


In [393]:
min = np.min(data.object_id)
#min_id = np.argmin(data.object_id)
#print(min, min_id)

max = np.max(data.object_id)
#max_id = np.argmax(data.object_id)
#print(max, max_id)

min_id = list(data_meta[data_meta.object_id == min].object_id.index)[0]
max_id = list(data_meta[data_meta.object_id == max].object_id.index)[0]

print(min_id, max_id)

32926 378922


In [41]:
time = [[[] for j in range(6)] for i in data_meta.object_id[min_id:max_id]]

value, counts = np.unique(data.object_id, return_counts=True)
print(counts)

for i, oid in tqdm(enumerate(data_meta.object_id[min_id:max_id])):
    time[i][0] = list(data.mjd[i:i+counts[i]])

print(time[0][0])

[140 130 126 ... 134 108 142]


0it [00:00, ?it/s]

[59583.2493, 59585.3547, 59586.3595, 59590.3573, 59597.1711, 59614.2014, 59615.3783, 59629.3124, 59650.0262, 59658.0486, 59660.1741, 59663.1842, 59682.0658, 59684.1371, 59686.0893, 59698.9682, 59712.0934, 59713.0721, 59769.9638, 59899.3524, 59902.3542, 59903.3445, 59914.3118, 59914.3486, 59917.3393, 59918.3164, 59922.3132, 59924.3146, 59927.3007, 59930.3271, 59931.3482, 59932.3233, 59934.3444, 59935.2752, 59938.3063, 59941.269, 59953.3005, 59957.3304, 59964.1966, 59965.2271, 59973.2882, 59976.1232, 59976.2992, 59979.2119, 59980.2042, 59986.3184, 59994.1876, 59995.0852, 59999.1247, 60002.1365, 60008.197, 60011.1804, 60012.3005, 60016.0589, 60017.1047, 60017.2486, 60018.0503, 60019.2211, 60026.0514, 60027.0839, 60032.0596, 60039.184, 60040.1296, 60041.1914, 60045.2092, 60053.039, 60066.1289, 60068.0589, 60070.1226, 60072.0973, 60073.0361, 60095.0505, 60096.0488, 60101.0262, 60102.0136, 60264.3572, 60270.3333, 60280.3506, 60285.3417, 60290.2995, 60296.3406, 60297.3181, 60309.3449, 60312.3

In [394]:
def avoid_non_SNIa(ii, csv_data_meta):
    
    SN_type = csv_data_meta.true_target[ii]

    if SN_type == 90:
        return True
    else:
        return False

def avoid_empty_SN(ii, oid, oid_count_elem, filters, 
    csv_data, csv_data_meta,
    num=10, lc_length_prepeak=-250, lc_length_postpeak=250):

    t = [ [] for i in filters]
    m_app = [ [] for i in filters]

    band = csv_data.passband[ii:ii+oid_count_elem]
    
    N_specified_bands = 0

    #print('index', ii)
    
    for i, f in enumerate(filters):

        band_id = np.where(band == f)

        t[i] = np.array(csv_data.mjd.iloc[band_id[0]])
        m_app[i] = np.array(csv_data.flux.iloc[band_id[0]])

    t_max = csv_data_meta.true_peakmjd[ii]
    #print(t_max)

    for i in range(len(filters)):
        t_duration = np.where((t[i] > (t_max + lc_length_prepeak)) & (t[i] < (t_max + lc_length_postpeak)))
        if t_duration[0].shape[0] == 0:
                return False
        else:
            N_specified_bands += t_duration[0].shape[0]

    #print(N_specified_bands)
    if N_specified_bands > num:
        try:
            csv_data_meta.true_distmod[ii]
            csv_data_meta.true_z[ii]
            return True
        except Exception:
            return False
    else:
        #print('tiny', ii)
        return False

In [420]:
class LC_Preprocess:

    def __init__(self, ii, oid, oid_count_elem, filters, 
        csv_data, csv_data_meta):
        
        self.ii = ii
        self.oid = oid
        self.oid_count_elem = oid_count_elem

        self.filters = filters

        self.data = csv_data
        self.data_meta = csv_data_meta

        self.band = self.data.passband[ii:ii+oid_count_elem]
        self.claimedtype = 0

        self.t = [ [] for filter in self.filters]
        self.flux = [ [] for filter in self.filters]
        self.flux_err = [ [] for filter in self.filters]
        self.m = [ [] for filter in self.filters]
        self.m_err = [ [] for filter in self.filters]

    def peak_alignment(self, lc_length_prepeak=-200, lc_length_postpeak=200):

        self.t_peak = self.data_meta.true_peakmjd[self.ii]

        for i, f in enumerate(self.filters):

            self.t[i] = np.array(self.t[i]) - self.t_peak

            self.t[i]     = np.delete(self.t[i], np.where(self.t[i] > lc_length_postpeak))
            self.m[i]     = self.m[i][0:len(self.t[i])]
            self.m_err[i] = self.m_err[i][0:len(self.t[i])]

            self.t[i]     = np.delete(self.t[i], np.where(self.t[i] < lc_length_prepeak))
            self.m[i]     = self.m[i][len(self.m[i]) - len(self.t[i]):]
            self.m_err[i] = self.m_err[i][len(self.m_err[i]) - len(self.t[i]):]

            if (len(self.t[i]) - len(self.m[i])) != 0:
                print('bruh')

        return self.t, self.m, self.m_err, self.claimedtype

    def lc_graph(self, colors = ['darkcyan', 'limegreen', 'crimson'], lc_length_prepeak=-200, lc_length_postpeak=200):
        
        plt.plot(figsize=(16,12))

        for i, filter in enumerate(self.filters):
            plt.errorbar(self.t[i], self.m[i], self.m_err[i], label=filter, color=colors[i], fmt='.')
        
        plt.title(f'{self.oid}, {self.claimedtype}')
        plt.xlim(lc_length_prepeak, lc_length_postpeak)
        plt.ylim(-23, -14)
        plt.xlabel('time (day)')
        plt.ylabel('absolute magnitude')
        plt.legend()
        plt.grid()
        plt.gca().invert_yaxis()
        plt.savefig(f'/home/ricky/RNNAE/import_graph/{self.oid}.pdf')
        #plt.savefig(fr'C:\\Users\\ricky\\FYP\\RNNAE_public\\import_graph\\{self.SN_name}.pdf')
        plt.clf()

    def lc_extractor(self, **kwargs):
        
        dist_mod = float(self.data_meta.true_distmod[self.ii])

        z = float(self.data_meta.true_z[self.ii])

        self.claimedtype = self.data_meta.true_target[self.ii]

        f_min = 0

        for i, f in enumerate(self.filters):

            band_id = np.where(self.band == f)

            self.t[i] = np.array(self.data.mjd.iloc[band_id[0]])
            self.flux[i] = np.array(self.data.flux.iloc[band_id[0]])
            self.flux_err[i] = np.array(self.data.flux_err.iloc[band_id[0]])
            
            if np.min(self.flux[i]) < f_min:
                f_min = np.min(self.flux[i])

        for i, f in enumerate(self.filters):

            self.flux[i] -= f_min
            self.flux[i] = np.where(self.flux[i] == 0, 1e-6, self.flux[i])
            self.m[i] = self.flux[i]
            self.m_err[i] = self.flux_err[i]
            '''self.m[i] = -2.5*np.log10(self.flux[i]) - dist_mod + 2.5*np.log10(1+z) + 27.5
            self.flux[i] += f_min
            self.m_err[i] = 2.5*0.434*(np.absolute(self.flux_err[i]/(self.flux[i])))'''
        
        print(np.array(self.flux_err)/np.array(self.flux))

        if kwargs['peak_alignment']:
            LC_Preprocess.peak_alignment(self)

        if kwargs['LC_graph']:
            LC_Preprocess.lc_graph(self)

        return self.t, self.m, self.m_err, self.claimedtype, self.oid

In [421]:
def main():

    min = np.min(data.object_id)
    max = np.max(data.object_id)

    min_id = list(data_meta[data_meta.object_id == min].object_id.index)[0]
    max_id = list(data_meta[data_meta.object_id == max].object_id.index)[0]

    value, counts = np.unique(data.object_id, return_counts=True)

    t_all = []
    m_all = []
    m_err_all = []
    claimedtype_all = []
    SN_name_all = []

    filters_all = [1, 2, 3]
    num_extracted_SN = 0

    print('Screening and extracting SNe ...')
    
    for ii, oid in tqdm(enumerate(data_meta.object_id[min_id:min_id+4])):

        csv_QC1 = avoid_non_SNIa(ii, data_meta)
        #print(csv_QC1)
        csv_QC2 = avoid_empty_SN(ii, oid, counts[ii], filters_all, 
                    data, data_meta,
                    num=10, lc_length_prepeak=-200, lc_length_postpeak=200)
        
        '''if csv_QC2 is True:
            print(csv_QC2)'''

        '''if csv_QC1:
            num_extracted_SN += 1'''
            
        if (csv_QC1 and csv_QC2):

            LC_result = LC_Preprocess(ii, oid, counts[ii], filters_all, data, data_meta).lc_extractor(peak_alignment=True, LC_graph=True)
            
            t_all.append(LC_result[0])
            m_all.append(LC_result[1])
            m_err_all.append(LC_result[2])
            claimedtype_all.append(LC_result[3])
            SN_name_all.append(LC_result[4])
            
            num_extracted_SN += 1

    os.chdir('/home/ricky/RNNAE/import_npy')
    #os.chdir(r'C:\\Users\\ricky\\FYP\\RNNAE_public\\import_npy')
    print('The current working directory is', os.getcwd())

    np.save('Time_all.npy', np.array(t_all, dtype=object))
    np.save('Magnitude_Abs_all.npy', np.array(m_all, dtype=object))
    np.save('Magnitude_Abs_err_all.npy', np.array(m_err_all, dtype=object))
    np.save('Type_all.npy', np.array(claimedtype_all))
    np.save('SN_name.npy', np.array(SN_name_all))

    print(f'There are {num_extracted_SN} extracted SNe')
    print('End of import.py')

In [422]:
if __name__ == '__main__':
    main()

Screening and extracting SNe ...


0it [00:00, ?it/s]

[array([0.21246681, 0.60527556, 2.19857904, 0.37115121, 0.1021621 ,
       0.08307151, 0.19051656, 0.03123038, 0.54651137, 0.09574328,
       0.18155912, 0.06959933, 0.10904542])
 array([0.29591049, 0.20219412, 0.05937335, 0.26996102, 0.05885232,
       0.47409071, 0.60241718, 0.08017243, 0.64183649, 0.0472214 ,
       0.56763024, 0.32103765, 0.16346734, 0.07327796, 1.10648186,
       0.28794751, 0.14623407, 0.49281993, 0.19628323, 0.40619443,
       0.07347621, 0.49464915])
 array([6.35933887e-01, 5.54827840e-01, 1.68292760e-01, 2.02157989e-01,
       5.28017376e-02, 6.77579755e-02, 4.17795109e-01, 2.94591880e+07,
       3.42530897e-01, 1.12129196e-01, 1.70566569e-01, 3.80394742e-02,
       1.48391236e-01, 5.75731281e-01, 1.52735775e-01, 8.45346029e-02,
       2.47085463e-01, 1.67990172e+00, 5.42424786e-01, 8.53819057e-02,
       1.04674754e-01, 1.47087452e-01, 6.99953506e-02])]
The current working directory is /home/ricky/RNNAE/import_npy
There are 1 extracted SNe
End of import.py


  print(np.array(self.flux_err)/np.array(self.flux))


<Figure size 432x288 with 0 Axes>