Skip to content

Commit

Permalink
issue1: type error,fix bug on sntypes
Browse files Browse the repository at this point in the history
  • Loading branch information
anaismoller committed Jun 11, 2019
1 parent a12f17a commit 8d72951
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions supernnova/data/make_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def build_traintestval_splits(settings):
# Save a dataframe to record train/test/val split for
# binary, ternary and all-classes classification
for dataset in ["saltfit", "photometry"]:
for nb_classes in list(set([2, len(settings.sntypes.keys())])):
for nb_classes in list([2, len(settings.sntypes.keys())]):
logging_utils.print_green(
f"Computing {dataset} splits for {nb_classes}-way classification"
)
Expand Down Expand Up @@ -470,7 +470,7 @@ def process_single_csv(file_path, settings):
# Merge left on df: len(df) will not change and will now include
# relevant columns from df_SNID
merge_columns = ["SNID"]
for c_ in [2, list(set(len(settings.sntypes.keys())))]:
for c_ in [2, len(settings.sntypes.keys())]:
merge_columns += [f"target_{c_}classes"]
for dataset in ["photometry", "saltfit"]:
merge_columns += [f"dataset_{dataset}_{c_}classes"]
Expand Down Expand Up @@ -528,8 +528,9 @@ def preprocess_data(settings):
# Need to cast to list because executor returns an iterator
host_spe_tmp += list(executor.map(parallel_fn,
list_files[start:end]))
# for debugging
# for debugging only (if lines 520-531 are commented)
# host_spe_tmp.append(process_single_FITS(list_files[0], settings))
# host_spe_tmp.append(process_single_csv(list_files[0], settings))
# Save host spe for plotting and performance tests
host_spe = [item for sublist in host_spe_tmp for item in sublist]
pd.DataFrame(host_spe, columns=["SNID"]).to_pickle(
Expand Down Expand Up @@ -597,7 +598,7 @@ def pivot_dataframe_single(filename, settings):
# drop columns that won"t be used onwards
df = df.drop(["MJD", "delta_time"], 1)
class_columns = []
for c_ in list(set([2, len(settings.sntypes.keys())])):
for c_ in list([2, len(settings.sntypes.keys())]):
class_columns += [f"target_{c_}classes"]
for dataset in ["photometry", "saltfit"]:
class_columns += [f"dataset_{dataset}_{c_}classes"]
Expand Down

0 comments on commit 8d72951

Please sign in to comment.