Skip to content

Commit

Permalink
fixed bug sorting on the fly
Browse files Browse the repository at this point in the history
  • Loading branch information
anaismoller committed Jun 2, 2020
1 parent ef37257 commit 22db909
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
10 changes: 6 additions & 4 deletions run_onthefly.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,18 @@ def load_lc_csv(filename):


def reformat_to_df(pred_probs, ids=None):
"""
"""
""" Reformat SNN predictions to a DataFrame
# TO DO: suppport nb_inference != 1
"""
num_inference_samples = 1

d_series = {}
for i in range(pred_probs[0].shape[1]):
d_series["SNID"] = []
d_series[f"prob_class{i}"] = []
for idx, value in enumerate(pred_probs):
d_series["SNID"] += ids[idx] if len(ids) > 0 else idx
d_series["SNID"] += [ids[idx]] if len(ids) > 0 else idx
value = value.reshape((num_inference_samples, -1))
value_dim = value.shape[1]
for i in range(value_dim):
Expand Down Expand Up @@ -126,7 +127,8 @@ def reformat_to_df(pred_probs, ids=None):

# Obtain predictions for full light-curve
# Format: batch, nb_inference_samples, nb_classes
pred_probs = classify_lcs(df, args.model_file, args.device)
# Beware, ids are resorted while obtaining predictions!
ids_preds, pred_probs = classify_lcs(df, args.model_file, args.device)

# ________________________
# Optional
Expand Down
13 changes: 12 additions & 1 deletion supernnova/validation/validate_onthefly.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def format_data(df, settings):
# fill dummies
if "PEAKMJD" not in df.keys():
df["PEAKMJD"] = np.zeros(len(df))

# pivot
df = pivot_dataframe_single_from_df(df, settings)

Expand All @@ -71,6 +72,16 @@ def format_data(df, settings):


def classify_lcs(df, model_file, device):
""" Obtain predictions for light-curves
Args:
df (DataFrame): light-curves to classify
model_file (str): Path+name of model to use for predictions
device (str): wehter to use cuda or cpu
Returns:
idx (list): light-curve indices after classification (they are resorted)
preds (np.array): predictions for this model (shape= len(idx),model_nb_class)
"""
# init
settings = get_settings(model_file)
settings.use_cuda = True if "cuda" in str(device) else False
Expand Down Expand Up @@ -134,4 +145,4 @@ def classify_lcs(df, model_file, device):
# B, inf_samples, nb_classes
preds = np.stack(list_preds, axis=1)

return preds
return df.index.unique(), preds

0 comments on commit 22db909

Please sign in to comment.