diff --git a/combo/models/base.py b/combo/models/base.py index fd0b546..825ab48 100644 --- a/combo/models/base.py +++ b/combo/models/base.py @@ -29,8 +29,9 @@ class BaseAggregator(ABC): Parameters ---------- - base_estimators: list or numpy array (n_estimators,) - A list of base estimators. + base_estimators: list, length must be greater than 1 + A list of base estimators. Certain methods must be present, e.g., + `fit` and `predict`. pre_fitted: bool, optional (default=False) Whether the base estimators are trained. If True, `fit` diff --git a/combo/models/detector_comb.py b/combo/models/detector_comb.py index 72adc04..5184011 100644 --- a/combo/models/detector_comb.py +++ b/combo/models/detector_comb.py @@ -23,16 +23,14 @@ class SimpleDetectorAggregator(BaseAggregator): Parameters ---------- - base_estimators: list or numpy array (n_estimators,) - A list of base detectors. + base_estimators : list, length must be greater than 1 + Base unsupervised outlier detectors from PyOD. (Note: requires fit and + decision_function methods) method : str, optional (default='average') Combination method: {'average', 'maximization', 'median'}. Pass in weights of detector for weighted version. - threshold : float in (0, 1), optional (default=0.5) - Cut-off value to convert scores into binary labels. - contamination : float in (0., 0.5), optional (default=0.1) The amount of contamination of the data set, i.e. the proportion of outliers in the data set. Used when fitting to @@ -69,9 +67,8 @@ class SimpleDetectorAggregator(BaseAggregator): ``threshold_`` on ``decision_scores_``. """ - def __init__(self, base_estimators, method='average', threshold=0.5, - contamination=0.1, standardization=True, - weights=None, pre_fitted=False): + def __init__(self, base_estimators, method='average', contamination=0.1, + standardization=True, weights=None, pre_fitted=False): super(SimpleDetectorAggregator, self).__init__( base_estimators=base_estimators, pre_fitted=pre_fitted) @@ -89,10 +86,6 @@ def __init__(self, base_estimators, method='average', threshold=0.5, self.standardization = standardization - check_parameter(threshold, 0, 1, include_left=False, - include_right=False, param_name='threshold') - self.threshold = threshold - if weights is None: self.weights = np.ones([1, self.n_base_estimators_]) else: