diff --git a/src/train.py b/src/train.py index 437c494..653c7a6 100644 --- a/src/train.py +++ b/src/train.py @@ -152,7 +152,6 @@ 0.5, shuffle=False, k_iwae=num_samples, - model_name=args.net, ) if itr % 100 == 0 and args.save: torch.save({ diff --git a/src/utils.py b/src/utils.py index 740ba78..2533e95 100644 --- a/src/utils.py +++ b/src/utils.py @@ -64,7 +64,6 @@ def evaluate_hetvae( sample_tp=0.5, shuffle=False, k_iwae=1, - model_name=None, device='cuda', ): torch.manual_seed(seed=0) @@ -157,12 +156,6 @@ def get_mimiciii_data(batch_size, test_batch_size=5, filter_anomalies=True): data_mean.append(0) data_std.append(1) - # filtering - observed_vals[observed_vals > hth] = 0.0 - observed_vals[observed_vals < lth] = 0.0 - observed_mask[observed_vals > hth] = 0 - observed_mask[observed_vals < lth] = 0 - # normalizing observed_vals = (observed_vals - data_mean) / data_std observed_vals[observed_mask == 0] = 0 @@ -189,8 +182,8 @@ def get_mimiciii_data(batch_size, test_batch_size=5, filter_anomalies=True): total_dataset, train_size=0.8, random_state=42, shuffle=True ) # for interpolation, we dont need a non-overlapping validation set as - # we can condition on different set of time points from same dataset to - # create a distinct example + # we can condition on different set of time points from same set to + # create distinct examples _, val_data = model_selection.train_test_split( train_data, train_size=0.8, random_state=11, shuffle=True )