Skip to content

Commit

Permalink
Merge pull request #8 from supernnova/onthefly
Browse files Browse the repository at this point in the history
Implementing on the fly classification
  • Loading branch information
anaismoller committed Apr 22, 2020
2 parents 00ff200 + 5d8f31d commit 6ba7947
Show file tree
Hide file tree
Showing 14 changed files with 958 additions and 106 deletions.
64 changes: 64 additions & 0 deletions planning_for_onthefly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Design sketch for on-the-fly classification. These are planning notes rather
# than a runnable script: `pd`, `np`, `df_temp` and `arr_time` are used below
# but never imported or defined in this file.

# INPUT
# Beware, there may be missing filters

# option 1: Input format = lists
mjd = [57433.4816, 57436.4815]
fluxes = [2.0, 3]
fluxerrs = [0.1, 0.2]
passbands = ["g", "r"]
SNID = "1"
redshift_zspe = 0.12
redshift_zpho = 0.1
redshift = 0.12
# redshift can be given either as zpho/zspe or as a global one, and use the zpho model

# option 2: small pandas dataframe
# Same quantities as option 1, one row per measurement, here for two SNIDs.
df = pd.DataFrame()
df["mjd"] = [57433.4816, 57436.4815, 33444, 33454]
df["fluxes"] = [2.0, 3, 200, 300]
df["fluxerrs"] = [0.1, 0.2, 0.1, 0.2]
df["passbands"] = ["g", "r", "g", "r"]
df["SNID"] = ["1", "1", "2", "2"]
df["redshift_zspe"] = [0.12, 0.12, 0.5, 0.5]
df["redshift_zpho"] = [0.1, 0.1, 0.5, 0.5]
df["redshift"] = [0.12, 0.12, 0.5, 0.5]

# VALIDATE
# Important: want classification output directly here, not in a file.

# USAGE
import supernnova.conf as conf
from supernnova.data import ontheflydata
from supernnova.validation import validate_rnn_onthefly

# read data
# Illustrates the two intended call forms (DataFrame input vs. list input).
# As written this is a single boolean `or` expression whose result is discarded.
ontheflydata(df) or ontheflydata(mjd, fluxes, fluxerrs, redshift)

# get config args
args = conf.get_args()
args.validate_rnn = False # conf: validate rnn
args.model_files = "model_file" # conf: model file to load
settings = conf.get_settings(args) # conf: set settings
preds = validate_rnn_onthefly.get_predictions(settings) # classify test set

# output format, list with predictions
# (bare expression with no effect; it only illustrates the expected shape)
[0.5, 0.6]


# to check predictions you can use early_predictions and save lcs as df
# Flattens per-filter columns of `df_temp` (FLUXCAL_<flt>, FLUXCALERR_<flt>)
# into long-format lists with one entry per (filter, epoch).
# NOTE(review): indentation was lost in this listing; presumably the four `+=`
# lines below form the body of the `for` loop -- confirm against the repository.
arr_flux = []
arr_fluxerr = []
arr_flt = []
arr_MJD = []
for flt in ["g", "r", "i", "z"]:
arr_flux += df_temp[f"FLUXCAL_{flt}"].values.tolist()
arr_fluxerr += df_temp[f"FLUXCALERR_{flt}"].values.tolist()
arr_MJD += arr_time.tolist()
# `flt * n` repeats the string and `list += str` extends one character at a
# time, so this yields n filter labels only for one-character filter names.
arr_flt += flt * len(df_temp[f"FLUXCAL_{flt}"].values.tolist())
aaa = pd.DataFrame()
aaa["FLUXCAL"] = arr_flux
aaa["FLUXCALERR"] = arr_fluxerr
aaa["FLT"] = arr_flt
aaa["MJD"] = arr_MJD
# every row gets the same SNID, the string "1"
aaa["SNID"] = np.ones(len(aaa)).astype(int).astype(str)
aaa.to_csv("tmp_cl.csv")
142 changes: 142 additions & 0 deletions run_onthefly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import supernnova.utils.logging_utils as lu
from supernnova.validation.validate_onthefly import classify_lcs

"""
Example code on how to run on the fly classifications
- Need to load a pre-trained model
- Either provide a list with data or a Pandas DataFrame
"""

# Columns expected in the input light-curve table.
# The HOSTGAL_* redshift columns are only required when classifying with
# redshift.
COLUMN_NAMES = [
    # light-curve identifier and photometry
    "SNID",
    "MJD",
    "FLUXCAL",
    "FLUXCALERR",
    "FLT",
    # host-galaxy redshifts and their errors
    "HOSTGAL_PHOTOZ",
    "HOSTGAL_SPECZ",
    "HOSTGAL_PHOTOZ_ERR",
    "HOSTGAL_SPECZ_ERR",
]


def manual_lc():
    """Build a small hand-written light-curve table.

    Serves as a template for the format in which light-curves can be
    provided: one row per photometric measurement, two measurements for
    each of two supernovae.

    Returns:
        pandas.DataFrame: example light-curves
    """
    example_points = {
        # supernova IDs
        "SNID": ["1", "1", "2", "2"],
        # time in MJD
        "MJD": [57433.4816, 57436.4815, 33444, 33454],
        # flux and its error
        "FLUXCAL": [2.0, 3, 200, 300],
        "FLUXCALERR": [0.1, 0.2, 0.1, 0.2],
        # bandpasses
        "FLT": ["g", "r", "g", "r"],
        # redshift columns are not required if classifying without redshift
        "HOSTGAL_SPECZ": [0.12, 0.12, 0.5, 0.5],
        "HOSTGAL_PHOTOZ": [0.1, 0.1, 0.5, 0.5],
        "HOSTGAL_SPECZ_ERR": [0.001, 0.001, 0.001, 0.001],
        "HOSTGAL_PHOTOZ_ERR": [0.01, 0.01, 0.01, 0.01],
    }
    return pd.DataFrame(example_points)


def load_lc_csv(filename):
    """Read light-curve(s) in csv format.

    Any column listed in COLUMN_NAMES that is absent from the file is added
    and filled with zeros, with a warning printed only when that happens.
    Rows are then sorted by MJD and SNID is converted to the string form of
    its integer value.

    Args:
        filename (str): data file

    Returns:
        pandas.DataFrame: light-curve table sorted by MJD
    """
    df = pd.read_csv(filename)

    missing_cols = [k for k in COLUMN_NAMES if k not in df.columns]
    # Only warn when something is actually missing; a complete file used to
    # trigger a misleading "Missing 0 columns / filling with zeros" message.
    if missing_cols:
        lu.print_red(f"Missing {len(missing_cols)} columns", missing_cols)
        lu.print_yellow("filling with zeros")
        lu.print_yellow("HOSTGAL required only for classification with redshift")
        for k in missing_cols:
            df[k] = np.zeros(len(df))

    df = df.sort_values(by=["MJD"])
    # SNID goes through int first, so the column must hold integer-like values
    df["SNID"] = df["SNID"].astype(int).astype(str)

    return df


def reformat_to_df(pred_probs, ids=None):
    """Reformat classification probabilities into a DataFrame.

    Only a single inference sample per light-curve is supported for now.

    Args:
        pred_probs (sequence of arrays): one entry per light-curve, each of
            shape (nb_inference_samples, nb_classes) with
            nb_inference_samples == 1
        ids (sequence, optional): light-curve identifiers aligned with
            ``pred_probs``. When None or empty, the position of the
            light-curve in ``pred_probs`` is used instead.

    Returns:
        pandas.DataFrame: one row per light-curve with columns ``SNID``,
        ``prob_class<i>`` for each class, and ``pred_class`` (index of the
        most probable class). With empty ``pred_probs`` only the ``SNID``
        and ``pred_class`` columns are present.

    Raises:
        ValueError: if an entry does not hold exactly one inference sample
            of nb_classes probabilities.
    """
    # TODO: support nb_inference != 1
    num_inference_samples = 1

    if len(pred_probs) == 0:
        # the number of classes cannot be inferred without any prediction
        return pd.DataFrame(columns=["SNID", "pred_class"])

    use_ids = ids is not None and len(ids) > 0
    n_classes = pred_probs[0].shape[1]

    snids = []
    rows = []
    for idx, value in enumerate(pred_probs):
        # append (not +=): `list += "12"` would add "1" and "2" separately
        snids.append(ids[idx] if use_ids else idx)
        row = np.asarray(value.reshape((num_inference_samples, -1))[0])
        if row.shape[0] != n_classes:
            raise ValueError(
                f"Expected {n_classes} class probabilities for entry {idx}, "
                f"got {row.shape[0]} (nb_inference != 1 is not supported)"
            )
        rows.append(row)
    probs = np.stack(rows)  # (batch, nb_classes)

    d_series = {"SNID": snids}
    for i in range(n_classes):
        d_series[f"prob_class{i}"] = probs[:, i]
    preds_df = pd.DataFrame.from_dict(d_series)

    # predicted class = most probable class of each light-curve (class axis,
    # not batch axis)
    preds_df["pred_class"] = probs.argmax(axis=1)

    return preds_df


if __name__ == "__main__":
    # Wrapper to get predictions on the fly with SNN:
    # load light-curves, classify them with a pre-trained model and save the
    # predictions as csv in the current working directory.

    parser = argparse.ArgumentParser(
        description="Classification using pre-trained model"
    )
    parser.add_argument(
        "--model_file",
        type=str,
        default="tests/onthefly_model/vanilla_S_0_CLF_2_R_none_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean.pt",
        help="path to pre-trained SuperNNova model",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="device to be used [cuda,cpu]"
    )
    parser.add_argument(
        "--filename",
        type=str,
        default="tests/onthefly_lc/example_lc.csv",
        # the help text used to be a copy of the --device one
        help="path to csv file with light-curve(s) to classify",
    )

    args = parser.parse_args()

    # Input data
    # options: csv or manual data, choose one
    df = load_lc_csv(args.filename)
    # df = manual_lc()

    # Obtain predictions for full light-curve
    # Format: batch, nb_inference_samples, nb_classes
    pred_probs = classify_lcs(df, args.model_file, args.device)

    # ________________________
    # Optional
    #
    # reformat to df
    preds_df = reformat_to_df(pred_probs, ids=df.SNID.unique())
    # output file is named after the input file and prefixed with "Predictions_"
    preds_df.to_csv(f"Predictions_{Path(args.filename).name}")

    # To implement
    # Early prediction visualization
9 changes: 9 additions & 0 deletions supernnova/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def get_args():
"--done_file", default=None, type=str, help="Done or failure file name"
)

parser.add_argument(
"--no_dump", action="store_true", help="No dump database nor preds"
)
parser.add_argument(
"--debug",
action="store_true",
help="Debug database creation: one file processed only",
)

#######################
# PLASTICC parameters
#######################
Expand Down

0 comments on commit 6ba7947

Please sign in to comment.