
survey filters option + docs
anaismoller committed Jun 17, 2019
1 parent 8d72951 commit dec0838
Showing 7 changed files with 69 additions and 41 deletions.
9 changes: 9 additions & 0 deletions docs/data/index.rst
@@ -91,6 +91,15 @@ Photometric measurements may span over a larger time range than the one desired
python run.py --dump_dir <path/to/save/database/> --data --raw_dir <path/to/raw/data/> --photo_window_files <path/to/csv/with/peakMJD> --photo_window_var <name/of/variable/in/csv/to/cut/on> --photo_window_min <negative/int/indicating/days/before/var> --photo_window_max <positive/int/indicating/days/after/var>
Creating a database with a different survey
---------------------------------------------
The default filter set is the one from the Dark Energy Survey Supernova Program: ``g,r,i,z``. To use your own survey, you need to specify your filters and the possible combinations of them that appear in observations.

.. code::
python run.py --dump_dir <path/to/save/database/> --data --raw_dir <path/to/raw/data/> --list_filters <your/filters> --list_filters_combination <your/filter/combination>
e.g. ``--list_filters g r --list_filters_combination g r gr``. Also, make sure that ``--sntypes`` is consistent with your data!
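Putting the template and the example together, a full invocation for such a hypothetical two-filter survey would look like the following (paths remain placeholders):

.. code::

    python run.py --dump_dir <path/to/save/database/> --data --raw_dir <path/to/raw/data/> --list_filters g r --list_filters_combination g r gr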

Under the hood
-------------------------------
4 changes: 2 additions & 2 deletions docs/validation/index.rst
@@ -59,7 +59,7 @@ In this case it will run the model provided in ``model_files`` with the normaliz

Predictions format
~~~~~~~~~~~~~~~~~~~~~
For a binary classification task, predictions files contain the follwing columns:
For a binary classification task, predictions files contain the following columns:

.. code::
@@ -86,7 +86,7 @@ For a binary classification task, predictions files contain the follwing columns
target int64 - Type of the supernova, simulated class.
SNID int64 - ID number of the light-curve
these columns rely on maximum light information and target (original type) from simulations. Out-of-distribution classifications are done on the fly.
these columns rely on maximum light information and target (original type) from simulations. Out-of-distribution classifications are done on the fly. Bayesian networks (variational and Bayes by Backprop) have an entry for each sampling of the probability distribution; to get the mean and std of the classification, read the ``_aggregated.pickle`` file.
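A minimal sketch of how such an aggregated file might be inspected, assuming it is a pickled pandas DataFrame (the path and the exact column names depend on your run):

.. code::

    import pandas as pd

    # hypothetical path: adjust to your dump_dir and model directory
    df_agg = pd.read_pickle("<path/to/predictions>_aggregated.pickle")
    print(df_agg.columns)  # lists the available mean/std columns
    print(df_agg.head())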


RNN speed
18 changes: 18 additions & 0 deletions supernnova/conf.py
@@ -3,6 +3,7 @@
import pickle
import argparse
from pathlib import Path
from natsort import natsorted
from collections import OrderedDict
from distutils.util import strtobool
from .utils import experiment_settings
@@ -200,6 +201,23 @@ def get_args():
        default=100,
        help="Window size after peak"
    )
    # Survey configuration
    parser.add_argument(
        "--list_filters",
        nargs='+',
        default=natsorted(["g", "i", "r", "z"]),
        help="Survey filters"
    )
    parser.add_argument(
        "--list_filters_combination",
        nargs='+',
        default=natsorted(['g', 'r', 'i', 'z',
                           'gr', 'gi', 'gz',
                           'ir', 'iz',
                           'rz',
                           'gir', 'giz', 'grz', 'irz', 'girz']),
        help="Possible combination of filters"
    )
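A minimal sketch, assuming a hypothetical two-filter survey, of how these nargs='+' options are parsed (not an invocation from the repository):

    args = parser.parse_args(
        ["--list_filters", "g", "r",
         "--list_filters_combination", "g", "r", "gr"]
    )
    # args.list_filters == ["g", "r"]
    # args.list_filters_combination == ["g", "r", "gr"]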

######################
# RNN parameters
41 changes: 26 additions & 15 deletions supernnova/data/make_dataset.py
@@ -40,33 +40,40 @@ def build_traintestval_splits(settings):

# Load photometry
# either in HEAD.FITS or csv format
list_files = natsorted(
list_files_tmp = natsorted(
glob.glob(os.path.join(settings.raw_dir, "*HEAD.FITS*"))
)
if len(list_files_tmp) > 0:
list_files = list_files_tmp
fmat = 'FITS'
else:
list_files = natsorted(
glob.glob(os.path.join(settings.raw_dir, "*HEAD.csv*"))
)
fmat = 'csv'
list_files = list_files[:]
if len(list_files) > 0:
print("List files", list_files)
print("List files", list_files)
# use parallelization to speed up processing
if fmat == 'FITS':
process_fn = partial(
data_utils.process_header_FITS,
settings=settings,
columns=photo_columns + ["SNTYPE"],
)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
list_df = executor.map(process_fn, list_files)
else:
list_files = natsorted(
glob.glob(os.path.join(settings.raw_dir, "*HEAD.csv*"))
)
print("List files", list_files)
process_fn = partial(
data_utils.process_header_csv,
settings=settings,
columns=photo_columns + ["SNTYPE"],
)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
list_df = executor.map(process_fn, list_files)
# only used for debugging (if above commented)
# list_df = data_utils.process_header_FITS(list_files[0],settings,columns=photo_columns + ["SNTYPE"])
with ProcessPoolExecutor(max_workers=max_workers) as executor:
list_df = executor.map(process_fn, list_files)
# only used for debugging (if parallelization above commented)
# if fmat == 'FITS':
# list_df = data_utils.process_header_FITS(list_files[0],settings,columns=photo_columns + ["SNTYPE"])
# else:
# list_df = data_utils.process_header_csv(list_files[0],settings,columns=photo_columns + ["SNTYPE"])

# Load df_photo
df_photo = pd.concat(list_df)
df_photo["SNID"] = df_photo["SNID"].astype(int)
@@ -515,6 +522,7 @@ def preprocess_data(settings):
max_workers = multiprocessing.cpu_count()

host_spe_tmp = []
# use parallelization to speed up processing
# Split list files in chunks of size 10 or less
# to get a progress bar and alleviate memory constraints
num_elem = len(list_files)
@@ -528,7 +536,7 @@
# Need to cast to list because executor returns an iterator
host_spe_tmp += list(executor.map(parallel_fn,
list_files[start:end]))
# for debugging only (if lines 520-531 are commented)
# for debugging only (parallelization needs to be commented)
# host_spe_tmp.append(process_single_FITS(list_files[0], settings))
# host_spe_tmp.append(process_single_csv(list_files[0], settings))
# Save host spe for plotting and performance tests
@@ -553,7 +561,7 @@ def pivot_dataframe_single(filename, settings):
settings (ExperimentSettings): controls experiment hyperparameters
"""
list_filters = data_utils.FILTERS
list_filters = settings.list_filters

assert len(list_filters) > 0

@@ -692,13 +700,15 @@ def pivot_dataframe_batch(list_files, settings):

# Parameters of multiprocessing below
max_workers = multiprocessing.cpu_count()
# use parallelization to speed up processing
# Loop over chunks of files
for chunk_idx in tqdm(list_chunks, desc="Pivoting dataframes", ncols=100):
parallel_fn = partial(pivot_dataframe_single, settings=settings)
# Process each file in the chunk in parallel
with ProcessPoolExecutor(max_workers=max_workers) as executor:
start, end = chunk_idx[0], chunk_idx[-1] + 1
executor.map(parallel_fn, list_files[start:end])
# for debugging only (if above is commented)
# pivot_dataframe_single(list_files[0], settings)

logging_utils.print_green("Finished pivot")
@@ -744,6 +754,7 @@ def make_dataset(settings):
)
)
logging_utils.print_green("Concatenating pivot")

df = pd.concat([pd.read_pickle(f) for f in list_files], axis=0)

# Save to HDF5
28 changes: 11 additions & 17 deletions supernnova/utils/data_utils.py
@@ -15,13 +15,7 @@
OFFSETS = [-2, -1, 0, 1, 2]
OOD_TYPES = ["random", "reverse", "shuffle", "sin"]
OFFSETS_STR = ["-2", "-1", "", "+1", "+2"]
FILTERS = natsorted(["g", "i", "r", "z"])
# non data dependent onehot encoding
FILTERS_COMBINATION = natsorted(['g', 'r', 'i', 'z',
'gr', 'gi', 'gz',
'ir', 'iz',
'rz',
'gir', 'giz', 'grz', 'irz', 'girz'])

PLASTICC_FILTERS = natsorted(["u", "g", "r", "i", "z", "y"])
DICT_PLASTICC_FILTERS = {0: "u", 1: "g", 2: "r", 3: "i", 4: "z", 5: "y"}
DICT_PLASTICC_CLASS = OrderedDict(
@@ -56,7 +50,6 @@ def load_pandas_from_fit(fit_file_path):
Returns:
(pandas.DataFrame) load dataframe from FIT file
"""

dat = Table.read(fit_file_path, format="fits")
df = dat.to_pandas()

@@ -181,8 +174,10 @@ def process_header_FITS(file_path, settings, columns=None):

# Data
df = load_pandas_from_fit(file_path)

df = tag_type(df, settings, type_column="SNTYPE")


if columns is not None:
df = df[columns]

@@ -208,7 +203,6 @@ def process_header_csv(file_path, settings, columns=None):

if columns is not None:
df = df[columns]

return df


@@ -378,8 +372,8 @@ def save_to_HDF5(settings, df):
"""
# One hot encode filter information and Normalize features
list_training_features = [f"FLUXCAL_{f}" for f in FILTERS]
list_training_features += [f"FLUXCALERR_{f}" for f in FILTERS]
list_training_features = [f"FLUXCAL_{f}" for f in settings.list_filters]
list_training_features += [f"FLUXCALERR_{f}" for f in settings.list_filters]
list_training_features += [
"delta_time",
"HOSTGAL_PHOTOZ",
@@ -482,7 +476,7 @@ def save_to_HDF5(settings, df):

logging_utils.print_green("Saving filter occurences")
# Compute how many occurences of a specific filter around PEAKMJD
for flt in FILTERS:
for flt in settings.list_filters:
# Check presence / absence of the filter at all time steps
df[f"has_{flt}"] = df.FLT.str.contains(flt).astype(np.uint8)
for offset, suffix in zip(OFFSETS, OFFSETS_STR):
@@ -538,7 +532,7 @@ def save_to_HDF5(settings, df):
################
# FLUX features
#################
flux_features = [f"FLUXCAL_{f}" for f in FILTERS]
flux_features = [f"FLUXCAL_{f}" for f in settings.list_filters]
flux_log_standardized = log_standardization(
df[flux_features].values)
# Store normalization parameters
@@ -549,7 +543,7 @@ def save_to_HDF5(settings, df):
###################
# FLUXERR features
###################
fluxerr_features = [f"FLUXCALERR_{f}" for f in FILTERS]
fluxerr_features = [f"FLUXCALERR_{f}" for f in settings.list_filters]
fluxerr_log_standardized = log_standardization(
df[fluxerr_features].values)
# Store normalization parameters
@@ -575,11 +569,11 @@ def save_to_HDF5(settings, df):
assert sorted(df.columns.values.tolist()) == sorted(
list_training_features + ["FLT"]
)
# cheating to have the same onehot for all datasets
tmp = pd.Series(FILTERS_COMBINATION).append(df["FLT"])
# to have the same onehot for all datasets
tmp = pd.Series(settings.list_filters_combination).append(df["FLT"])
tmp_onehot = pd.get_dummies(tmp)
# this is ok since it goes by length not by index (which I never reset)
FLT_onehot = tmp_onehot[len(FILTERS_COMBINATION):]
FLT_onehot = tmp_onehot[len(settings.list_filters_combination):]
df = pd.concat([df[list_training_features],
FLT_onehot], axis=1)
# store feature names
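For reference, a minimal self-contained sketch of the one-hot trick used in save_to_HDF5 above, with an assumed toy filter list and pd.concat standing in for the Series.append call in the diff: prepending every possible filter combination before get_dummies guarantees identical columns for any dataset, and the prepended rows are then dropped by position.

    import pandas as pd

    # assumed toy survey: two filters and one combination
    list_filters_combination = ["g", "r", "gr"]
    flt = pd.Series(["g", "gr", "g"], name="FLT")  # observed filter combinations

    # prepend every possible combination so get_dummies always yields the same columns
    tmp = pd.concat([pd.Series(list_filters_combination), flt], ignore_index=True)
    tmp_onehot = pd.get_dummies(tmp)
    # drop the prepended rows by position, keeping only the real observations
    FLT_onehot = tmp_onehot[len(list_filters_combination):].reset_index(drop=True)
    print(FLT_onehot)  # columns g, gr, r regardless of which combinations were observed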
5 changes: 1 addition & 4 deletions supernnova/utils/experiment_settings.py
@@ -4,7 +4,7 @@
import numpy as np
from pathlib import Path
from collections import OrderedDict
from .data_utils import FILTERS, PLASTICC_FILTERS
from .data_utils import PLASTICC_FILTERS


class ExperimentSettings:
@@ -47,9 +47,6 @@ def __init__(self, cli_args):

self.randomforest_features = self.get_randomforest_features()

# Set the filters used in the study
self.list_filters = FILTERS

# Set the database file names
self.set_database_file_names()

5 changes: 2 additions & 3 deletions supernnova/visualization/visualize.py
@@ -8,7 +8,6 @@

plt.switch_backend("agg")
import matplotlib.gridspec as gridspec
from ..utils.data_utils import FILTERS


def plot_lightcurves(df, SNIDs, settings):
@@ -31,7 +30,7 @@ def plot_lightcurves(df, SNIDs, settings):
df_temp = df.loc[SNID]

# Prepare plotting data in a dict
d = {flt: {"FLUXCAL": [], "FLUXCALERR": [], "MJD": []} for flt in FILTERS}
d = {flt: {"FLUXCAL": [], "FLUXCALERR": [], "MJD": []} for flt in settings.list_filters}

current_time = 0
for idx in range(len(df_temp)):
@@ -129,7 +128,7 @@ def plot_lightcurves_from_hdf5(settings, SNID_idxs):
max_y = -float("Inf")
min_y = float("Inf")

for FLT in FILTERS:
for FLT in settings.list_filters:
idxs = np.array(
[i for i in range(len(df)) if FLT in list_present_filters[i]]
)
