Skip to content

Commit

Permalink
fix pred class bug
Browse files Browse the repository at this point in the history
  • Loading branch information
anaismoller committed Apr 23, 2020
1 parent 274061b commit f9f647f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
9 changes: 3 additions & 6 deletions run_onthefly.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@ def reformat_to_df(pred_probs, ids=None):
preds_df = pd.DataFrame.from_dict(d_series)

# get predicted class
try:
preds_df["pred_class"] = np.argmax(pred_probs, axis=0)[0]
except Exception:
preds_df["pred_class"] = np.argmax(pred_probs[0], axis=0)[0]
preds_df["pred_class"] = np.argmax(pred_probs, axis=-1).reshape(-1)

return preds_df

Expand Down Expand Up @@ -124,8 +121,8 @@ def reformat_to_df(pred_probs, ids=None):

# Input data
# options: csv or manual data, choose one
df = load_lc_csv(args.filename)
# df = manual_lc()
# df = load_lc_csv(args.filename)
df = manual_lc()

# Obtain predictions for full light-curve
# Format: batch, nb_inference_samples, nb_classes
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name="supernnova",
version="1.1",
version="1.2",
author="Anais Moller and Thibault de Boissiere",
author_email="anais.moller@clermont.in2p3.fr",
description="framework for Bayesian, Neural Network based supernova light-curve classification",
Expand Down

0 comments on commit f9f647f

Please sign in to comment.