Skip to content

Commit

Permalink
allowing extreme data values (asserts) & safe for unreal peak in metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
anaismoller committed Jul 1, 2019
1 parent 59cdeba commit f4100f3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 26 deletions.
10 changes: 5 additions & 5 deletions supernnova/utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def normalize_arr(arr, settings):
arr_normed = (arr_normed - arr_mean) / arr_std

arr[:, settings.idx_features_to_normalize] = arr_normed

return arr


Expand All @@ -68,6 +67,7 @@ def unnormalize_arr(arr, settings):
arr_mean = settings.arr_norm[:, 1]
arr_std = settings.arr_norm[:, 2]
arr_to_unnorm = arr[:, settings.idx_features_to_normalize]

arr_to_unnorm = arr_to_unnorm * arr_std + arr_mean
arr_unnormed = np.exp(arr_to_unnorm) + arr_min - 1E-5

Expand Down Expand Up @@ -116,14 +116,15 @@ def fill_data_list(

# check if normalization converges
# using clipping in case of min<model_min
X_tmp = unnormalize_arr(normalize_arr(
X_all.copy(), settings), settings)
X_clip = X_all.copy()
X_clip = np.clip(
X_clip[:, settings.idx_features_to_normalize], settings.arr_norm[:, 0], np.inf)
X_all[:, settings.idx_features_to_normalize] = X_clip

X_tmp = unnormalize_arr(normalize_arr(
X_all.copy(), settings), settings)
assert np.all(
np.all(np.isclose(np.ravel(X_all), np.ravel(X_tmp), atol=1e-2)))
np.all(np.isclose(np.ravel(X_all), np.ravel(X_tmp), atol=1e-1)))
# Normalize features that need to be normalized
X_normed = X_all.copy()
X_normed_tmp = normalize_arr(X_normed, settings)
Expand Down Expand Up @@ -369,7 +370,6 @@ def get_data_batch(list_data, idxs, settings, max_lengths=None, OOD=None):
idx_sort = np.argsort(list_len)[::-1]
idxs_rev_sort = np.argsort(idx_sort) # these indices revert the sort
max_len = list_len[idx_sort[0]]

X_tensor = torch.zeros((max_len, len(idxs), input_dim))
list_target = []
lengths = []
Expand Down
39 changes: 20 additions & 19 deletions supernnova/validation/validate_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,30 +232,31 @@ def get_predictions(settings, model_file=None):
oob_idxs = np.where(np.array(slice_idxs) < 1)[0]
inb_idxs = np.where(np.array(slice_idxs) >= 1)[0]

# We only carry out prediction for samples in ``inb_idxs``
offset_batch_idxs = [batch_idxs[b] for b in inb_idxs]
max_lengths = [slice_idxs[b] for b in inb_idxs]
if len(inb_idxs)>0:
# We only carry out prediction for samples in ``inb_idxs``
offset_batch_idxs = [batch_idxs[b] for b in inb_idxs]
max_lengths = [slice_idxs[b] for b in inb_idxs]
lu.print_red('val',len(offset_batch_idxs),max_lengths)
packed, _, target_tensor, idxs_rev_sort = tu.get_data_batch(
list_data_test, offset_batch_idxs, settings, max_lengths=max_lengths
)

packed, _, target_tensor, idxs_rev_sort = tu.get_data_batch(
list_data_test, offset_batch_idxs, settings, max_lengths=max_lengths
)
for iter_ in tqdm(range(settings.num_inference_samples), ncols=100):

for iter_ in tqdm(range(settings.num_inference_samples), ncols=100):
arr_preds, arr_target = get_batch_predictions(
rnn, packed, target_tensor
)

arr_preds, arr_target = get_batch_predictions(
rnn, packed, target_tensor
)

# Rever sorting that occurs in get_batch_predictions
arr_preds = arr_preds[idxs_rev_sort]
# Rever sorting that occurs in get_batch_predictions
arr_preds = arr_preds[idxs_rev_sort]

suffix = str(offset) if offset != 0 else ""
suffix = f"+{suffix}" if offset > 0 else suffix
col = f"PEAKMJD{suffix}"
suffix = str(offset) if offset != 0 else ""
suffix = f"+{suffix}" if offset > 0 else suffix
col = f"PEAKMJD{suffix}"

d_pred[col][start_idx + inb_idxs, iter_] = arr_preds
# For oob_idxs, no prediction can be made, fill with nan
d_pred[col][start_idx + oob_idxs, iter_] = np.nan
d_pred[col][start_idx + inb_idxs, iter_] = arr_preds
# For oob_idxs, no prediction can be made, fill with nan
d_pred[col][start_idx + oob_idxs, iter_] = np.nan

#############################
# OOD predictions
Expand Down
4 changes: 2 additions & 2 deletions supernnova/visualization/early_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def make_early_prediction(settings, nb_lcs=1, do_gifs=False):
list_entries = np.random.randint(0,high=len(list_data_test),size=nb_lcs)
subset_to_plot = [list_data_test[i] for i in list_entries]
for X, target, SNID, _, X_ori in tqdm(subset_to_plot, ncols=100):

try:
redshift = SNinfo_df[SNinfo_df["SNID"] == SNID]["SIM_REDSHIFT_CMB"].values[0]
peak_MJD = SNinfo_df[SNinfo_df["SNID"] == SNID]["PEAKMJDNORM"].values[0]
Expand All @@ -283,7 +283,7 @@ def make_early_prediction(settings, nb_lcs=1, do_gifs=False):
X_clip = X_ori.copy()
X_clip = np.clip(X_clip[:,settings.idx_features_to_normalize], settings.arr_norm[:, 0], np.inf)
X_ori[:,settings.idx_features_to_normalize] = X_clip
assert np.all(np.all(np.isclose(np.ravel(X_ori), np.ravel(X_unnormed), atol=1e-2)))
assert np.all(np.all(np.isclose(np.ravel(X_ori), np.ravel(X_unnormed), atol=1e-1)))

# TODO: IMPROVE
df_temp = pd.DataFrame(data=X_unnormed, columns=features)
Expand Down

0 comments on commit f4100f3

Please sign in to comment.