# Starter Code to use ORACLE

## Imports

In [1]:
import numpy as np
import networkx as nx
import pandas as pd

from tensorflow import keras

from taxonomy import get_taxonomy_tree, source_node_label
from interpret_results import get_conditional_probabilites

## Load taxonomy and get ordering

In [2]:
# Build the taxonomy graph, then walk it breadth-first from the source node
# so the node labels come out in level order.
tree = get_taxonomy_tree()
bfs_ordering = nx.bfs_tree(tree, source=source_node_label)
level_order_nodes = list(bfs_ordering.nodes())

# Show the resulting ordering of node labels.
print(level_order_nodes)

['Alert', 'Transient', 'Variable', 'SN', 'Fast', 'Long', 'Periodic', 'AGN', 'SNIa', 'SNIb/c', 'SNIax', 'SNI91bg', 'SNII', 'KN', 'Dwarf Novae', 'uLens', 'M-dwarf Flare', 'SLSN', 'TDE', 'ILOT', 'CART', 'PISN', 'Cepheid', 'RR Lyrae', 'Delta Scuti', 'EB']


## Creating some fake data to run inference

In [3]:
# Dummy inputs used only to exercise the inference path; every entry is 1.0.
# NOTE(review): the dimension names below are inferred from the variable
# names (x_static / x_ts) — confirm against the model's expected input shapes.
batch_size = 10
N_STATIC_FEATURES = 23  # width of each row of x_static
N_TIME_STEPS = 500      # second axis of x_ts
N_TS_FEATURES = 5       # third axis of x_ts

x_static = np.ones((batch_size, N_STATIC_FEATURES))
x_ts = np.ones((batch_size, N_TIME_STEPS, N_TS_FEATURES))

## Run Inference

In [4]:
# Path to the saved model checkpoint. A plain string is used because there is
# nothing to interpolate (the original f-string had no placeholders).
MODEL_PATH = "models/lsst_alpha_0.5/best_model.h5"

# compile=False: the model is only used for prediction in this notebook.
model = keras.models.load_model(MODEL_PATH, compile=False)

# The inputs are passed as a list: the time-series array first, then the
# static array.
logits = model.predict([x_ts, x_static])

# Processing step to interpret the NN output as a tree. The first returned
# value is not needed here and is discarded.
_, pseudo_probabilities = get_conditional_probabilites(logits, tree)



In [5]:
# Optional: tabulate the pseudo-probabilities, labelling each column with the
# corresponding entry of level_order_nodes.
results = pd.DataFrame(
    data=pseudo_probabilities,
    columns=level_order_nodes,
)
results

Unnamed: 0,Alert,Transient,Variable,SN,Fast,Long,Periodic,AGN,SNIa,SNIb/c,...,M-dwarf Flare,SLSN,TDE,ILOT,CART,PISN,Cepheid,RR Lyrae,Delta Scuti,EB
0,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
1,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
2,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
3,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
4,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
5,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
6,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
7,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
8,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
9,1.0,0.991744,0.008256,0.048972,0.942675,9.8e-05,1.5e-05,0.008241,0.004791,0.005385,...,9.811335e-09,9.7e-05,7.441076e-07,1.34957e-10,6.837318e-12,3.807362e-09,6.04695e-12,1.316973e-09,1.743679e-08,1.5e-05
