In [None]:
import numpy as np
import pandas as pd

def subset_df_columns(df, subset_cols):
    """Return ``df`` restricted to the requested columns that it actually has.

    Columns are returned in the order given by ``subset_cols``; requested names
    that are not columns of ``df`` are silently skipped.
    """
    present = list(df.columns)
    kept = []
    for name in subset_cols:
        if name in present:
            kept.append(name)
    return df[kept]

def set_index(df, index_name):
    """Return ``df`` indexed by the column ``index_name``.

    If the current index is already named ``index_name`` the very same frame is
    returned. Otherwise ``index_name`` must be one of the columns (an
    AssertionError is raised if it is not) and a new frame indexed by that
    column is returned.
    """
    current_name = df.index.name
    already_indexed = current_name is not None and current_name == index_name
    if already_indexed:
        return df
    assert index_name in list(df.columns)
    return df.set_index([index_name])

def df_to_float32(df):
    """Downcast every float64 column of ``df`` to float32, in place.

    Columns of any other dtype are left untouched. Nothing is returned (None);
    the caller's frame is modified directly.
    """
    float64_cols = [name for name in df.columns if df[name].dtype == 'float64']
    for name in float64_cols:
        df[name] = df[name].astype(np.float32)

In [None]:
import pandas as pd
import numpy as np

load_root_dir = '../../surveys_data'
survey_name = 'PLAsTiCCv1'
# Generic column role -> column name used in the PLAsTiCC CSV files.
df_index_names = {
    'oid':'object_id', # object id
    'oid_det':'object_id', # object id
    'label':'target', # object class code (cast to str below)
    'ra':'ra',
    'dec':'decl',
    'band':'passband', # band
    'obs_day':'mjd', # days
    'obs':'flux', # observations
    'obs_error':'flux_err', # observation errors
}
# Columns kept from each CSV file; everything else is discarded after loading.
subset_columns_names = {
    'labels':['object_id', 'target', 'ra', 'decl'],
    'detections':['object_id', 'passband', 'mjd', 'flux', 'flux_err'],
}
uses_ddf = False # True: keep only objects with ddf!=0 (similar to ZTF), False: all objects (more obs)
uses_detected = False # True: keep only detection rows with detected!=0

### load files and processing
labels_df = pd.read_csv(f'{load_root_dir}/{survey_name}/training_set_metadata.csv')
if uses_ddf:
    # Boolean mask rather than drop(df[cond].index): selects exactly the matching
    # rows even if the index were not unique, and avoids a second label lookup.
    labels_df = labels_df[labels_df['ddf'] != 0]
print(f'labels - columns: {list(labels_df.columns)} - id: {labels_df.index.name}')
labels_df = subset_df_columns(labels_df, subset_columns_names['labels']) # sub sample columns
labels_df = set_index(labels_df, df_index_names['oid']) # set index
# class codes as str so they match the string keys of the class-name dictionary
labels_df = labels_df.astype({df_index_names['label']:str})

detections_df = pd.read_csv(f'{load_root_dir}/{survey_name}/training_set.csv')
#detections_df = pd.read_csv(f'{load_root_dir}/{survey_name}/test_set_sample.csv') # alternative input file
if uses_detected:
    detections_df = detections_df[detections_df['detected'] != 0]
print(f'detections_df - columns: {list(detections_df.columns)} - id: {detections_df.index.name}')
detections_df = subset_df_columns(detections_df, subset_columns_names['detections']) # sub sample columns
detections_df = set_index(detections_df, df_index_names['oid_det']) # set index
# Name the index after the 'oid' role (a no-op while 'oid' and 'oid_det' are both
# 'object_id'). Plain attribute assignment instead of Index.rename(inplace=True).
detections_df.index.name = df_index_names['oid']
df_to_float32(detections_df) # float64 -> float32, in place

### print info
classes = np.unique(labels_df[df_index_names['label']].values)
print('classes:', classes)

In [None]:
# dtype / memory summary, then the first ten rows (rich display)
labels_df.info()
labels_df.head(10)

In [None]:
# dtype / memory summary, then the first ten rows (rich display)
detections_df.info()
detections_df.head(10)

In [None]:
import sys
# make the project packages (e.g. `src`) importable from this notebook's folder
sys.path.extend(['../', '../../'])

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from src.curve_dictionary_utils import LightCurveDictionaryCreator

# Band name -> integer passband code, in u, g, r, i, z, y order (u=0 ... y=5).
band_dictionary = {band: code for code, band in enumerate('ugrizy')}

# PLAsTiCC class code (as str, matching the label column cast to str) -> class name.
label_to_class_dict = {
    '90': 'SNIa',
    '67': 'SNIa-91bg',
    '52': 'SNIax',
    '42': 'SNII',
    '62': 'SNIbc',
    '95': 'SLSN-I',
    '15': 'TDE',
    '64': 'KN',
    '88': 'AGN',
    '92': 'RRL',
    '65': 'M-dwarf',
    '16': 'EB',
    '53': 'Mira',
    '6': 'Lens-Single',
    '991': 'Lens-Binary',
    '992': 'ILOT',
    '993': 'CaRT',
    '994': 'PISN',
    '995': 'Lens-String',
}

args = [survey_name, detections_df, labels_df, band_dictionary, df_index_names]
kwargs = dict(
    obs_is_flux=True,
    remove_negative_fluxes=True,
    #maximum_samples_per_class=5000,
    label_to_class_dict=label_to_class_dict,
)
lcDictionaryCreator = LightCurveDictionaryCreator(*args, **kwargs)
lcDictionaryCreator.plot_class_distribution(figsize=(12, 4), uses_log_scale=True)

In [None]:
# Class groups used to filter / merge PLAsTiCC classes before export.
DF_INVALID_CLASSES = ['Lens-Single', 'KN'] # passed as invalid_classes
DF_SN_LIST = ['SNII', 'SNIax', 'SNIbc', 'SNIa-91bg', 'SNIa', 'SLSN-I'] # supernova classes

# Dataset mode: one of 'raw', 'simple', 'transients', 'RRCeph', 'onlySN', 'onlySNIa'.
# Only 'onlySN' is configured for PLAsTiCC; the others stop with an explicit error.
mode = 'onlySN'

if mode == 'onlySN':
    # keep only supernova classes and merge the three SNIa-like subtypes into one
    invalid_classes = DF_INVALID_CLASSES
    query_classes = DF_SN_LIST
    to_merge_classes_dic = {'merSNIa':['SNIax', 'SNIa-91bg', 'SNIa'],}
elif mode in ('raw', 'simple', 'transients', 'RRCeph', 'onlySNIa'):
    # Not configured for this dataset. 'raw'/'transients' never had a config;
    # 'simple' (merge all of DF_SN_LIST into one 'SN' class) was disabled; the
    # 'onlySNIa'/'RRCeph' configs named classes absent from label_to_class_dict
    # ('SLSN', 'SNIIb', 'SNIIn', 'Ceph'). Raising here also prevents silently
    # reusing invalid_classes/query_classes/to_merge_classes_dic from an earlier run.
    raise NotImplementedError(f'mode={mode!r} is not configured for PLAsTiCC')
else:
    raise ValueError(f'unknown mode={mode!r}')

lcDictionaryCreator.update_labels_df(invalid_classes, query_classes, to_merge_classes_dic)
lcDictionaryCreator.plot_class_distribution(uses_log_scale=True)

In [None]:
%load_ext autoreload
%autoreload 2
from src import C_

description = 'PLAsTiCC'
save_folder = f'../save/{survey_name}'
filename_extra_parameters = {
    'mode':mode,
}
# Bands to export. A dict literal that lists 'to_export_bands' twice silently
# keeps only the last value, so the choice is made explicit here instead.
export_all_bands = False # True: every band in band_dictionary, False: g and r only
to_export_bands = list(band_dictionary.keys()) if export_all_bands else ['g','r']
kwargs = {
    'to_export_bands':to_export_bands,
    #'SCPD_probs':C_.DEFAULT_SCPD_PS,
    'filename_extra_parameters':filename_extra_parameters,
    'saves_every':1e5,
}
raw_lcdataset = lcDictionaryCreator.export_dictionary(description, save_folder, **kwargs)

In [None]:
import fuzzytools.myUtils.lists as lists
import matplotlib.pyplot as plt
from src.plots import plot_lightcurve

# Draw one randomly selected light curve, all bands on a single axis.
lcobj, key = raw_lcdataset.raw.get_random_lcobj()
fig, ax = plt.subplots(1,1)
for b in raw_lcdataset.raw.band_names:
    plot_lightcurve(ax, lcobj, b)
# Identify the random draw in the figure; `key` is returned alongside lcobj
# (presumably the object id -- confirm in get_random_lcobj).
fig.suptitle(f'{survey_name} - {key}')
plt.show()