diff --git a/photonai/processing/photon_folds.py b/photonai/processing/photon_folds.py index c751cfc3..4bad2d33 100644 --- a/photonai/processing/photon_folds.py +++ b/photonai/processing/photon_folds.py @@ -26,7 +26,8 @@ def data_overview(y): return {} else: unique, counts = np.unique(y, return_counts=True) - unique = [str(u) for u in unique] + # replacing is necessary for float inputs, because mongoDB does not allow '.'-char + unique = [str(u).replace(".", "_") for u in unique] counts = [int(c) for c in counts] return dict(zip(unique, counts)) @@ -64,7 +65,7 @@ def generate_folds(cv_strategy, X, y, kwargs, eval_final_performance=True, test_ data_test_cases = cv_strategy.split(X, groups) except: logger.error("Could not stratify data for outer cross validation according to " - "group variable") + "group variable") else: data_test_cases = cv_strategy.split(X, y) diff --git a/test/processing_tests/test_results_handler.py b/test/processing_tests/test_results_handler.py index bf44b97d..8c2793bb 100644 --- a/test/processing_tests/test_results_handler.py +++ b/test/processing_tests/test_results_handler.py @@ -305,3 +305,11 @@ def test_get_performance_table(self): def test_get_methods(self): self.hyperpipe.results_handler.get_methods() + + def test_float_labels_with_mongo(self): + """ + This test was added for a bug with float labels and saving to mongoDB. + """ + local_y = self.__y.astype(float) + self.hyperpipe.output_settings.mongodb_connect_url = self.mongodb_path + self.hyperpipe.fit(self.__X, local_y) \ No newline at end of file