-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from supernnova/onthefly
Implementing on the fly classification
- Loading branch information
Showing
14 changed files
with
958 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# INPUT | ||
# Beware, there may be missing filters | ||
|
||
# option 1: Input format = lists | ||
mjd = [57433.4816, 57436.4815] | ||
fluxes = [2.0, 3] | ||
fluxerrs = [0.1, 0.2] | ||
passbands = ["g", "r"] | ||
SNID = "1" | ||
redshift_zspe = 0.12 | ||
redshift_zpho = 0.1 | ||
redshift = 0.12 | ||
# redshift can be given either as zpho/zspe or as a single global one (the zpho model is then used) | ||
|
||
# option 2: small pandas dataframe | ||
df = pd.DataFrame() | ||
df["mjd"] = [57433.4816, 57436.4815, 33444, 33454] | ||
df["fluxes"] = [2.0, 3, 200, 300] | ||
df["fluxerrs"] = [0.1, 0.2, 0.1, 0.2] | ||
df["passbands"] = ["g", "r", "g", "r"] | ||
df["SNID"] = ["1", "1", "2", "2"] | ||
df["redshift_zspe"] = [0.12, 0.12, 0.5, 0.5] | ||
df["redshift_zpho"] = [0.1, 0.1, 0.5, 0.5] | ||
df["redshift"] = [0.12, 0.12, 0.5, 0.5] | ||
|
||
# VALIDATE | ||
# Important: want classification output directly here, not in a file. | ||
|
||
# USAGE | ||
import supernnova.conf as conf | ||
from supernnova.data import ontheflydata | ||
from supernnova.validation import validate_rnn_onthefly | ||
|
||
# read data | ||
ontheflydata(df) or ontheflydata(mjd, fluxes, fluxerrs, redshift) | ||
|
||
# get config args | ||
args = conf.get_args() | ||
args.validate_rnn = False # conf: validate rnn | ||
args.model_files = "model_file" # conf: model file to load | ||
settings = conf.get_settings(args) # conf: set settings | ||
preds = validate_rnn_onthefly.get_predictions(settings) # classify test set | ||
|
||
# output format, list with predictions | ||
[0.5, 0.6] | ||
|
||
|
||
# To check predictions you can use early_predictions and save the
# light-curves as a DataFrame: the per-filter columns of df_temp
# (FLUXCAL_<flt>, FLUXCALERR_<flt>) are flattened into one long table.
arr_flux = []
arr_fluxerr = []
arr_flt = []
arr_MJD = []
for flt in ["g", "r", "i", "z"]:
    flux_flt = df_temp[f"FLUXCAL_{flt}"].values.tolist()
    arr_flux += flux_flt
    arr_fluxerr += df_temp[f"FLUXCALERR_{flt}"].values.tolist()
    # the same time array is repeated for every filter
    arr_MJD += arr_time.tolist()
    # one filter label per flux point; a list (not a repeated string) so
    # that filter names longer than one character are kept whole
    arr_flt += [flt] * len(flux_flt)
aaa = pd.DataFrame()
aaa["FLUXCAL"] = arr_flux
aaa["FLUXCALERR"] = arr_fluxerr
aaa["FLT"] = arr_flt
aaa["MJD"] = arr_MJD
# every row gets the same identifier, the string "1"
aaa["SNID"] = np.ones(len(aaa)).astype(int).astype(str)
aaa.to_csv("tmp_cl.csv")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import argparse | ||
import numpy as np | ||
import pandas as pd | ||
from pathlib import Path | ||
import supernnova.utils.logging_utils as lu | ||
from supernnova.validation.validate_onthefly import classify_lcs | ||
|
||
""" | ||
Example code on how to run on the fly classifications | ||
- Need to load a pre-trained model | ||
- Either provide a list with data or a Pandas DataFrame | ||
""" | ||
|
||
# Data columns to provide for each light-curve measurement.
# HOSTGAL redshifts only required if classification with redshift is used.
# load_lc_csv adds any of these columns that is absent from the input file
# and fills it with zeros.
COLUMN_NAMES = [
    "SNID",
    "MJD",
    "FLUXCAL",
    "FLUXCALERR",
    "FLT",
    "HOSTGAL_PHOTOZ",
    "HOSTGAL_SPECZ",
    "HOSTGAL_PHOTOZ_ERR",
    "HOSTGAL_SPECZ_ERR",
]
|
||
|
||
def manual_lc():
    """Build a small example light-curve table by hand.

    This shows the format in which light-curves can be provided: one row
    per measurement, with the columns listed below.

    Returns:
        pandas.DataFrame: two light-curves (SNID "1" and "2") with two
        measurements each
    """
    light_curves = {
        # supernova IDs
        "SNID": ["1", "1", "2", "2"],
        # time in MJD
        "MJD": [57433.4816, 57436.4815, 33444, 33454],
        # flux and its error
        "FLUXCAL": [2.0, 3, 200, 300],
        "FLUXCALERR": [0.1, 0.2, 0.1, 0.2],
        # bandpasses
        "FLT": ["g", "r", "g", "r"],
        # redshifts, not required if classifying without them
        "HOSTGAL_SPECZ": [0.12, 0.12, 0.5, 0.5],
        "HOSTGAL_PHOTOZ": [0.1, 0.1, 0.5, 0.5],
        "HOSTGAL_SPECZ_ERR": [0.001, 0.001, 0.001, 0.001],
        "HOSTGAL_PHOTOZ_ERR": [0.01, 0.01, 0.01, 0.01],
    }
    return pd.DataFrame(light_curves)
|
||
|
||
def load_lc_csv(filename):
    """Read light-curve(s) in csv format.

    Columns listed in COLUMN_NAMES that are absent from the file are added
    and filled with zeros. Rows are then sorted by MJD and SNID is converted
    to a string.

    Args:
        filename (str): data file

    Returns:
        pandas.DataFrame: the light-curve table with all COLUMN_NAMES present
    """
    df = pd.read_csv(filename)

    missing_cols = [k for k in COLUMN_NAMES if k not in df.columns]
    # Only warn (and fill) when something is actually missing, instead of
    # always printing a red "Missing 0 columns" message.
    if missing_cols:
        lu.print_red(f"Missing {len(missing_cols)} columns", missing_cols)
        lu.print_yellow("filling with zeros")
        lu.print_yellow("HOSTGAL required only for classification with redshift")
        for k in missing_cols:
            df[k] = np.zeros(len(df))

    df = df.sort_values(by=["MJD"])
    # Cast through int before str; presumably so that an ID parsed as a
    # float (e.g. 1.0) becomes "1" rather than "1.0" - TODO confirm.
    df["SNID"] = df["SNID"].astype(int).astype(str)

    return df
|
||
|
||
def reformat_to_df(pred_probs, ids=None):
    """Reformat classification probabilities into a DataFrame.

    Args:
        pred_probs (array-like): probabilities indexed as
            (light-curve, inference sample, class)
        ids (sequence, optional): one identifier per light-curve, in the
            same order as pred_probs. When None or empty, the position of
            the light-curve in pred_probs is used instead.

    Returns:
        pandas.DataFrame: one row per light-curve with columns SNID,
        prob_class<i> for every class i, and pred_class (index of the
        most probable class)
    """
    probs = np.asarray(pred_probs)
    if probs.size == 0:
        # nothing was classified: the number of classes is unknown
        return pd.DataFrame(columns=["SNID", "pred_class"])

    # TODO: support nb_inference != 1 (only the first inference sample is used)
    probs = probs[:, 0, :]
    n_lcs, n_classes = probs.shape

    if ids is None or len(ids) == 0:
        snids = list(range(n_lcs))
    else:
        # index one by one so that each identifier stays whole, whatever
        # its type or length
        snids = [ids[idx] for idx in range(n_lcs)]

    d_series = {"SNID": snids}
    for i in range(n_classes):
        d_series[f"prob_class{i}"] = probs[:, i]
    preds_df = pd.DataFrame(d_series)

    # predicted class: argmax over the class axis, one value per light-curve
    preds_df["pred_class"] = np.argmax(probs, axis=1)

    return preds_df
|
||
|
||
if __name__ == "__main__":
    # Wrapper to get predictions on the fly with SNN

    parser = argparse.ArgumentParser(
        description="Classification using pre-trained model"
    )
    parser.add_argument(
        "--model_file",
        type=str,
        default="tests/onthefly_model/vanilla_S_0_CLF_2_R_none_photometry_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean.pt",
        help="path to pre-trained SuperNNova model",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="device to be used [cuda,cpu]"
    )
    parser.add_argument(
        "--filename",
        type=str,
        default="tests/onthefly_lc/example_lc.csv",
        help="path to csv file with the light-curve(s) to classify",
    )

    args = parser.parse_args()

    # Input data
    # options: csv or manual data, choose one
    df = load_lc_csv(args.filename)
    # df = manual_lc()

    # Obtain predictions for full light-curve
    # Format: batch, nb_inference_samples, nb_classes
    pred_probs = classify_lcs(df, args.model_file, args.device)

    # ________________________
    # Optional
    #
    # reformat to df
    # NOTE(review): assumes classify_lcs returns predictions in the same
    # order as df.SNID.unique() - confirm against classify_lcs.
    preds_df = reformat_to_df(pred_probs, ids=df.SNID.unique())
    # written to the current working directory, named after the input file
    preds_df.to_csv(f"Predictions_{Path(args.filename).name}")

    # To implement
    # Early prediction visualization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.