# ORACLE Tutorial

**References:**

arXiv: https://arxiv.org/abs/2501.01496 \
GitHub: https://github.com/uiucsn/ELAsTiCC-Classification

## Imports

In [1]:
import tqdm

from astropy.table import Table

from astroOracle.pretrained_models import ORACLE, ORACLE_lite

## Read in the data

First, we need to read in the data. Here, we read in an astropy table that holds both the light curve and its metadata.

In [2]:
table = Table.read('AGN_17032813.ecsv')
table

MJD,BAND,FLUXCAL,FLUXCALERR,PHOTFLAG
float64,str1,float32,float32,int32
0.0,z,29.873867,14.653912,0
0.024899999996705443,i,-129.364,11.491417,6144
10.001400000001013,u,32.23296,12.735434,0
10.025099999998929,r,770.2578,9.1456175,4096
11.997699999999895,u,8.7568035,5.6472173,0
12.020400000001246,r,758.1768,8.05634,4096
13.921799999996438,i,-131.09286,7.7958207,4096
13.933599999996659,i,-127.39846,4.635382,4096
13.935399999994843,r,572.88477,5.6019444,4096
...,...,...,...,...


The time-series data **MUST** contain all of the columns shown above (MJD, BAND, FLUXCAL, FLUXCALERR, PHOTFLAG) for the classifier to work correctly. Please read the SNANA documentation and the ORACLE paper for more details.
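Since a missing column is an easy mistake to make, it can help to check the column names before running inference. The following is a minimal sketch, not part of astroOracle: `REQUIRED_COLUMNS` and `missing_columns` are hypothetical names, and the column list is simply copied from the table printed above.

```python
# Column names copied from the light-curve table shown above.
REQUIRED_COLUMNS = ("MJD", "BAND", "FLUXCAL", "FLUXCALERR", "PHOTFLAG")


def missing_columns(column_names):
    """Return the required light-curve columns absent from column_names."""
    present = set(column_names)
    return [name for name in REQUIRED_COLUMNS if name not in present]


# For an astropy table this would be called as missing_columns(table.colnames).
print(missing_columns(["MJD", "BAND", "FLUXCAL"]))  # ['FLUXCALERR', 'PHOTFLAG']
```

An empty list means every required column is present.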

In [3]:
table.meta

OrderedDict([('RA', 8.254658693564833),
             ('DEC', -19.237007751492673),
             ('MWEBV', 0.019017742946743965),
             ('MWEBV_ERR', 0.0009508871589787304),
             ('REDSHIFT_HELIO', 0.2852136790752411),
             ('REDSHIFT_HELIO_ERR', 0.29243001341819763),
             ('VPEC', 0.0),
             ('VPEC_ERR', 300.0),
             ('HOSTGAL_PHOTOZ', 0.2852136790752411),
             ('HOSTGAL_PHOTOZ_ERR', 0.29243001341819763),
             ('HOSTGAL_SPECZ', -9.0),
             ('HOSTGAL_SPECZ_ERR', -9.0),
             ('HOSTGAL_RA', 8.25465074942761),
             ('HOSTGAL_DEC', -19.2370160461322),
             ('HOSTGAL_SNSEP', 0.040309689939022064),
             ('HOSTGAL_DDLR', 0.06043460965156555),
             ('HOSTGAL_LOGMASS', 7.819200038909912),
             ('HOSTGAL_LOGMASS_ERR', -9999.0),
             ('HOSTGAL_LOGSFR', -9999.0),
             ('HOSTGAL_LOGSFR_ERR', -9999.0),
             ('HOSTGAL_LOGsSFR', -9999.0),
             ('HOSTGAL_

The metadata is useful, but not required, for classification. You can also read in only the light curve data and use ORACLE-lite instead (details below).

If you are interested in classifying a large batch of sources (in SNANA FITS format), astroOracle.dataloader has tools for this purpose. Please look through astroOracle.trainRNN for example usage.

## Load the model

Next, we load the model.

In [4]:
model = ORACLE('../models/lsst_alpha_0.5/best_model.h5')

Model loaded from ../models/lsst_alpha_0.5/best_model.h5


## Run Inference

In [5]:
model.predict([table.to_pandas()], [table.meta])

TS Processing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 243.32it/s]
Static Processing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15420.24it/s]


Unnamed: 0,Alert,Transient,Variable,SN,Fast,Long,Periodic,AGN,SNIa,SNIb/c,...,M-dwarf Flare,SLSN,TDE,ILOT,CART,PISN,Cepheid,RR Lyrae,Delta Scuti,EB
0,1.0,1.649743e-11,1.0,1.6994e-14,1.448357e-25,1.648044e-11,5.459642e-18,1.0,1.227276e-19,2.582632e-21,...,1.093013e-43,1.647708e-11,1.948153e-24,1.017798e-25,2.693979e-30,3.359413e-15,1.046392e-35,5.476987e-20,1.4737530000000001e-27,5.404873e-18


This gives us the probability of each class in the taxonomy. We can also get the predicted class label directly.

In [6]:
model.predict_classes([table.to_pandas()], [table.meta])

TS Processing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 599.87it/s]
Static Processing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 30393.51it/s]


array(['AGN'], dtype='<U13')
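To see how a label relates to the probability table, the values can be compared by hand within one level of the taxonomy. The sketch below is only an illustration: the probabilities are copied from the output above, but the grouping of columns into levels is an assumption made here for the example, and `most_probable` is a hypothetical helper rather than the logic `predict_classes` actually uses.

```python
# Probabilities copied from the first columns of the predict() output above.
row = {
    "Transient": 1.649743e-11,
    "Variable": 1.0,
    "SN": 1.6994e-14,
    "Fast": 1.448357e-25,
    "Long": 1.648044e-11,
    "Periodic": 5.459642e-18,
    "AGN": 1.0,
}


def most_probable(probabilities, labels):
    """Return the label with the highest probability among labels."""
    return max(labels, key=lambda label: probabilities[label])


# Assumed grouping of columns into two levels of the hierarchy.
top_level = most_probable(row, ["Transient", "Variable"])
second_level = most_probable(row, ["SN", "Fast", "Long", "Periodic", "AGN"])
print(top_level, second_level)  # Variable AGN
```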

We can also classify several sources in a single call by passing a list of light curves and a matching list of metadata.

In [7]:
model.predict([table.to_pandas(), table.to_pandas()], [table.meta, table.meta])

TS Processing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 729.38it/s]
Static Processing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62137.84it/s]


Unnamed: 0,Alert,Transient,Variable,SN,Fast,Long,Periodic,AGN,SNIa,SNIb/c,...,M-dwarf Flare,SLSN,TDE,ILOT,CART,PISN,Cepheid,RR Lyrae,Delta Scuti,EB
0,1.0,1.649743e-11,1.0,1.6994e-14,1.448357e-25,1.648044e-11,5.459642e-18,1.0,1.227276e-19,2.582632e-21,...,1.093013e-43,1.647708e-11,1.948153e-24,1.017798e-25,2.693979e-30,3.359413e-15,1.046392e-35,5.476987e-20,1.4737530000000001e-27,5.404873e-18
1,1.0,1.649743e-11,1.0,1.6994e-14,1.448357e-25,1.648044e-11,5.459642e-18,1.0,1.227276e-19,2.582632e-21,...,1.093013e-43,1.647708e-11,1.948153e-24,1.017798e-25,2.693979e-30,3.359413e-15,1.046392e-35,5.476987e-20,1.4737530000000001e-27,5.404873e-18
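For more than a couple of sources, the two parallel lists can be built in a loop. This is a minimal sketch, assuming each source is an astropy table like the one loaded earlier (anything with a `to_pandas()` method and a `meta` attribute works); `build_inputs` is a hypothetical helper and not part of astroOracle.

```python
def build_inputs(tables):
    """Return (light_curves, metadata) as two lists in matching order."""
    light_curves = []
    metadata = []
    for source in tables:
        # Keep the two lists aligned: entry i of each list describes source i.
        light_curves.append(source.to_pandas())
        metadata.append(source.meta)
    return light_curves, metadata
```

The result could then be passed on as `model.predict(*build_inputs(tables))`, mirroring the call shown above.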


## Try ORACLE-lite

Sometimes, we may not have access to metadata for classification. For this purpose, we trained another version of the model called ORACLE-lite. ORACLE-lite performs the same real-time, hierarchical classification as the full ORACLE model, although the full model maintains superior performance.

In [8]:
model = ORACLE_lite('../models/lsst_alpha_0.5_no_md/best_model.h5')

Model loaded from ../models/lsst_alpha_0.5_no_md/best_model.h5


The API for the two models is the same, with the obvious exception of the metadata argument.

In [9]:
model.predict([table.to_pandas()])

TS Processing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 652.71it/s]


Unnamed: 0,Alert,Transient,Variable,SN,Fast,Long,Periodic,AGN,SNIa,SNIb/c,...,M-dwarf Flare,SLSN,TDE,ILOT,CART,PISN,Cepheid,RR Lyrae,Delta Scuti,EB
0,1.0,1.94686e-08,1.0,6.914028e-16,1.932542e-08,1.431834e-10,0.004392,0.995608,2.925498e-21,1.472778e-17,...,6.633189000000001e-23,3.174507e-16,1.838745e-19,1.43177e-10,4.029926e-17,6.007277e-15,0.002556,3.2e-05,0.000664,0.00114


In [10]:
model.predict_classes([table.to_pandas()])

TS Processing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 625.46it/s]


array(['AGN'], dtype='<U13')

In [11]:
model.predict_classes([table.to_pandas(), table.to_pandas()])

TS Processing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 611.59it/s]


array(['AGN', 'AGN'], dtype='<U13')
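Because the two models differ only in the metadata argument, a small wrapper can choose between them depending on what is available. This is a sketch under assumptions: `classify` and its parameters are hypothetical names, and it relies only on the `predict_classes` calls demonstrated in this tutorial.

```python
def classify(light_curves, metadata=None, full_model=None, lite_model=None):
    """Use the full model when metadata is given, otherwise the lite model."""
    if metadata is not None:
        # The full model expects one metadata entry per light curve.
        if len(metadata) != len(light_curves):
            raise ValueError("need one metadata entry per light curve")
        return full_model.predict_classes(light_curves, metadata)
    return lite_model.predict_classes(light_curves)
```

Here `full_model` would be an `ORACLE` instance and `lite_model` an `ORACLE_lite` instance, loaded as shown earlier.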