From 5d387692b1edf56ed1783f15b81e7ca6fe9dd4d7 Mon Sep 17 00:00:00 2001 From: lucasplagwitz Date: Wed, 23 Sep 2020 11:03:03 +0200 Subject: [PATCH] allow estimator_type in [None, 'transformer'] for PipelineElement without predict method --- photonai/base/photon_elements.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/photonai/base/photon_elements.py b/photonai/base/photon_elements.py index 8baf6e68..60c91d66 100644 --- a/photonai/base/photon_elements.py +++ b/photonai/base/photon_elements.py @@ -182,19 +182,21 @@ def random_state(self, random_state): @property def _estimator_type(self): - if hasattr(self.base_element, '_estimator_type'): - est_type = getattr(self.base_element, '_estimator_type') - if est_type is not 'classifier' and est_type is not 'regressor': - raise NotImplementedError("Currently, we only support type classifier or regressor. Is {}.".format(est_type)) + # estimator_type obligation for estimators, is ignored if a transformer is given + # prevention of misuse through predict test (predict method available <=> Estimator). + est_type = getattr(self.base_element, '_estimator_type', None) + if est_type in [None, 'transformer']: + if hasattr(self.base_element, 'predict'): + raise NotImplementedError("Element has predict() method but does not specify whether it is a regressor" + " or classifier. Remember to inherit from ClassifierMixin or RegressorMixin.") + return None + else: + if est_type not in ['classifier', 'regressor']: + raise NotImplementedError("Currently, we only support type classifier or regressor." + " Is {}.".format(est_type)) if not hasattr(self.base_element, 'predict'): raise NotImplementedError("Estimator does not implement predict() method.") return est_type - else: - if hasattr(self.base_element, 'predict'): - raise NotImplementedError("Element has predict() method but does not specify whether it is a regressor " - "or classifier. Remember to inherit from ClassifierMixin or RegressorMixin.") - else: - return None # this is only here because everything inherits from PipelineElement. def __iadd__(self, pipe_element):