Skip to content

Commit

Permalink
plot_lcs exception 1 point + peak
Browse files Browse the repository at this point in the history
  • Loading branch information
anaismoller committed Oct 19, 2022
1 parent c3543ff commit 2248655
Showing 1 changed file with 59 additions and 52 deletions.
111 changes: 59 additions & 52 deletions supernnova/visualization/early_prediction.py
Expand Up @@ -159,7 +159,7 @@ def plot_predictions(
OOD is None
and not settings.data_testing
and arr_time.min() < peak_MJD
and peak_MJD > arr_time.max()
and peak_MJD < arr_time.max()
):
ax.plot([peak_MJD, peak_MJD], [0, 1], "k--", label="Peak MJD")
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
Expand Down Expand Up @@ -283,60 +283,67 @@ def make_early_prediction(settings, nb_lcs=1, do_gifs=False):
settings, dict_rnn, X, target, OOD=OOD
)
# X here has been normalized. We unnormalize X
X_unnormed = tu.unnormalize_arr(X_normed, settings)
# Check we do recover X_ori when OOD is None
if OOD is None and "cosmo" not in settings.norm:
# check if normalization converges
# using clipping in case of min<model_min
X_clip = X_ori.copy()
X_clip = np.clip(
X_clip[:, settings.idx_features_to_normalize],
settings.arr_norm[:, 0],
np.inf,
)
X_ori[:, settings.idx_features_to_normalize] = X_clip
assert np.all(
np.all(np.isclose(np.ravel(X_ori), np.ravel(X_unnormed), atol=1e-1))
try:
X_unnormed = tu.unnormalize_arr(X_normed, settings)

# Check we do recover X_ori when OOD is None
if OOD is None and "cosmo" not in settings.norm:
# check if normalization converges
# using clipping in case of min<model_min
X_clip = X_ori.copy()
X_clip = np.clip(
X_clip[:, settings.idx_features_to_normalize],
settings.arr_norm[:, 0],
np.inf,
)
X_ori[:, settings.idx_features_to_normalize] = X_clip
assert np.all(
np.all(
np.isclose(np.ravel(X_ori), np.ravel(X_unnormed), atol=1e-1)
)
)

# TODO: IMPROVE
df_temp = pd.DataFrame(data=X_unnormed, columns=features)
arr_time = np.cumsum(df_temp.delta_time.values)
df_temp["time"] = arr_time
for flt in settings.list_filters:
non_zero = np.where(
~np.isclose(df_temp[f"FLUXCAL_{flt}"].values, 0, atol=1e-2)
)[0]
d_plot[flt]["FLUXCAL"] = df_temp[f"FLUXCAL_{flt}"].values[non_zero]
d_plot[flt]["FLUXCALERR"] = df_temp[f"FLUXCALERR_{flt}"].values[
non_zero
]
d_plot[flt]["MJD"] = arr_time[non_zero]
plot_predictions(
settings,
d_plot,
SNID,
redshift,
peak_MJD,
target,
arr_time,
d_pred,
OOD,
)

# TODO: IMPROVE
df_temp = pd.DataFrame(data=X_unnormed, columns=features)
arr_time = np.cumsum(df_temp.delta_time.values)
df_temp["time"] = arr_time
for flt in settings.list_filters:
non_zero = np.where(
~np.isclose(df_temp[f"FLUXCAL_{flt}"].values, 0, atol=1e-2)
)[0]
d_plot[flt]["FLUXCAL"] = df_temp[f"FLUXCAL_{flt}"].values[non_zero]
d_plot[flt]["FLUXCALERR"] = df_temp[f"FLUXCALERR_{flt}"].values[
non_zero
]
d_plot[flt]["MJD"] = arr_time[non_zero]
plot_predictions(
settings,
d_plot,
SNID,
redshift,
peak_MJD,
target,
arr_time,
d_pred,
OOD,
)
# use to create GIFs
if not OOD:
if do_gifs:
plot_gif(
settings,
df_temp,
SNID,
redshift,
peak_MJD,
target,
arr_time,
d_pred,
)
except Exception:
lu.print_red(f"SNID {SNID} only has {len(X)} measurement, not plotting")

# use to create GIFs
if not OOD:
if do_gifs:
plot_gif(
settings,
df_temp,
SNID,
redshift,
peak_MJD,
target,
arr_time,
d_pred,
)
lu.print_green("Finished plotting lightcurves and predictions ")


Expand Down

0 comments on commit 2248655

Please sign in to comment.