diff --git a/.github/workflows/testing-cron.yml b/.github/workflows/testing-cron.yml index 79ddf7938..760da6138 100644 --- a/.github/workflows/testing-cron.yml +++ b/.github/workflows/testing-cron.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest] - python-version: [3.6, 3.9] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 855630098..04d2d01ba 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -9,7 +9,9 @@ on: - master - development pull_request: - branches: [ master ] + branches: + - master + - development jobs: build: @@ -19,7 +21,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest] - python-version: [3.6, 3.9] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 diff --git a/CHANGES.txt b/CHANGES.txt index bfc7e5422..a85ce2116 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -150,6 +150,10 @@ v<0.9.6>, <12/24/2021> -- Model persistence doc improvement. v<0.9.7>, <01/03/2022> -- Add ECOD. v<0.9.8>, <02/23/2022> -- Add Feature Importance for iForest. v<0.9.8>, <03/05/2022> -- Update ECOD (TKDE 2022). +v<0.9.9>, <03/20/2022> -- Renovate documentation. +v<0.9.9>, <03/23/2022> -- Add example for COPOD interpretability. +v<0.9.9>, <03/23/2022> -- Add outlier detection by Cook’s distances. +v<0.9.9>, <04/04/2022> -- Various community fix. 
diff --git a/README.rst b/README.rst index de3e8fd4a..541a96be6 100644 --- a/README.rst +++ b/README.rst @@ -315,6 +315,7 @@ Probabilistic MAD Median Absolute Deviation (MAD) Probabilistic SOS Stochastic Outlier Selection 2012 [#Janssens2012Stochastic]_ Linear Model PCA Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003 [#Shyu2003A]_ Linear Model MCD Minimum Covariance Determinant (use the mahalanobis distances as the outlier scores) 1999 [#Hardin2004Outlier]_ [#Rousseeuw1999A]_ +Linear Model CD Use Cook's distance for outlier detection 1977 [#Cook1977Detection]_ Linear Model OCSVM One-Class Support Vector Machines 2001 [#Scholkopf2001Estimating]_ Linear Model LMDD Deviation-based Outlier Detection (LMDD) 1996 [#Arning1996A]_ Proximity-Based LOF Local Outlier Factor 2000 [#Breunig2000LOF]_ @@ -548,6 +549,8 @@ Reference .. [#Burgess2018Understanding] Burgess, Christopher P., et al. "Understanding disentangling in beta-VAE." arXiv preprint arXiv:1804.03599 (2018). +.. [#Cook1977Detection] Cook, R.D., 1977. Detection of influential observation in linear regression. Technometrics, 19(1), pp.15-18. + .. [#Goldstein2012Histogram] Goldstein, M. and Dengel, A., 2012. Histogram-based outlier score (hbos): A fast unsupervised anomaly detection algorithm. In *KI-2012: Poster and Demo Track*\ , pp.59-63. .. [#Gopalan2019PIDForest] Gopalan, P., Sharan, V. and Wieder, U., 2019. PIDForest: Anomaly Detection via Partial Identification. In Advances in Neural Information Processing Systems, pp. 15783-15793. diff --git a/TODO.txt b/TODO.txt new file mode 100644 index 000000000..73d03a838 --- /dev/null +++ b/TODO.txt @@ -0,0 +1,3 @@ +1. ECOD parallelization and interpretability +2. Add latest deep learning algorithms. +3. 
finish the wrapping for cook distance detector \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 39a1d776d..7513eb57d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,7 @@ # -- Project information ----------------------------------------------------- project = 'pyod' -copyright = '2021, Yue Zhao' +copyright = '2022, Yue Zhao' author = 'Yue Zhao' # The short X.Y version @@ -50,8 +50,8 @@ 'sphinx.ext.imgmath', 'sphinx.ext.viewcode', 'sphinxcontrib.bibtex', - 'sphinx.ext.napoleon', - 'sphinx_rtd_theme', + # 'sphinx.ext.napoleon', + # 'sphinx_rtd_theme', ] bibtex_bibfiles = ['zreferences.bib'] @@ -90,7 +90,7 @@ # # html_theme = 'default' -html_theme = "sphinx_rtd_theme" +html_theme = "furo" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -112,8 +112,8 @@ # 'searchbox.html']``. # # html_sidebars = {} -html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'sourcelink.html', - 'searchbox.html']} +# html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']} # -- Options for HTMLHelp output --------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 5db5e110c..e001f6533 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -154,6 +154,7 @@ Probabilistic MAD Median Absolute Deviation (MAD) Probabilistic SOS Stochastic Outlier Selection 2012 :class:`pyod.models.sos.SOS` :cite:`a-janssens2012stochastic` Linear Model PCA Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003 :class:`pyod.models.pca.PCA` :cite:`a-shyu2003novel` Linear Model MCD Minimum Covariance Determinant (use the mahalanobis distances as the outlier scores) 1999 :class:`pyod.models.mcd.MCD` :cite:`a-rousseeuw1999fast,a-hardin2004outlier` +Linear Model CD Use Cook's distance for outlier detection 1977 :class:`pyod.models.cd.CD` 
:cite:`a-cook1977detection` Linear Model OCSVM One-Class Support Vector Machines 2001 :class:`pyod.models.ocsvm.OCSVM` :cite:`a-scholkopf2001estimating` Linear Model LMDD Deviation-based Outlier Detection (LMDD) 1996 :class:`pyod.models.lmdd.LMDD` :cite:`a-arning1996linear` Proximity-Based LOF Local Outlier Factor 2000 :class:`pyod.models.lof.LOF` :cite:`a-breunig2000lof` diff --git a/docs/pyod.models.rst b/docs/pyod.models.rst index 123d36d4c..40af77470 100644 --- a/docs/pyod.models.rst +++ b/docs/pyod.models.rst @@ -57,6 +57,16 @@ pyod.models.combination module :show-inheritance: :inherited-members: +pyod.models.cd module +--------------------- + +.. automodule:: pyod.models.cd + :members: + :exclude-members: + :undoc-members: + :show-inheritance: + :inherited-members: + pyod.models.copod module ------------------------ diff --git a/docs/requirements.txt b/docs/requirements.txt index ef7bad27b..2d6cef684 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,5 @@ combo +furo joblib keras matplotlib diff --git a/docs/zreferences.bib b/docs/zreferences.bib index 26d15916c..27b006ac2 100644 --- a/docs/zreferences.bib +++ b/docs/zreferences.bib @@ -316,7 +316,7 @@ @article{pevny2016loda @article{burgess2018understanding, - title={Understanding disentangling in beta-VAE}, + title={Understanding disentangling in betVAE}, author={Burgess, Christopher P and Higgins, Irina and Pal, Arka and Matthey, Loic and Watters, Nick and Desjardins, Guillaume and Lerchner, Alexander}, journal={arXiv preprint arXiv:1804.03599}, year={2018} @@ -379,10 +379,21 @@ @inproceedings{perini2020quantifying publisher={Springer} } -@article{Li2021ecod, +@article{li2021ecod, title={ECOD: Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions}, author={Li, Zheng and Zhao, Yue and Hu, Xiyang and Botta, Nicola and Ionescu, Cezar and Chen, H. 
George}, journal={IEEE Transactions on Knowledge and Data Engineering}, year={2022}, publisher={IEEE} +} + +@article{cook1977detection, + title={Detection of influential observation in linear regression}, + author={Cook, R Dennis}, + journal={Technometrics}, + volume={19}, + number={1}, + pages={15--18}, + year={1977}, + publisher={Taylor \& Francis} } \ No newline at end of file diff --git a/examples/cd_example.py b/examples/cd_example.py new file mode 100644 index 000000000..8112e3f31 --- /dev/null +++ b/examples/cd_example.py @@ -0,0 +1,58 @@ +"""Example of using Cook's distance (CD) for +outlier detection +""" +# Author: D Kulik +# License: BSD 2 clause + +from __future__ import division +from __future__ import print_function + +import os +import sys + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) + +import numpy as np +from pyod.models.cd import CD +from pyod.utils.data import generate_data +from pyod.utils.data import evaluate_print +from pyod.utils.example import visualize + +if __name__ == "__main__": + contamination = 0.1 # percentage of outliers + n_train = 200 # number of training points + n_test = 100 # number of testing points + + # Generate sample data + X_train, y_train, X_test, y_test = \ + generate_data(n_train=n_train, + n_test=n_test, + n_features=2, + contamination=contamination, + random_state=42) + + # train CD detector + clf_name = 'CD' + clf = CD() + clf.fit(X_train, y_train) + + # get the prediction labels and outlier scores of the training data + y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) + y_train_scores = clf.decision_scores_ # raw outlier scores + + # get the prediction on the test data + y_test_pred = clf.predict(np.append(X_test, y_test.reshape(-1,1), axis=1)) # outlier labels (0 or 1) + y_test_scores = 
clf.decision_function(np.append(X_test, y_test.reshape(-1,1), axis=1)) # outlier scores + + # evaluate and print the results + print("\nOn Training Data:") + evaluate_print(clf_name, y_train, y_train_scores) + print("\nOn Test Data:") + evaluate_print(clf_name, y_test, y_test_scores) + + # visualize the results + visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred, + y_test_pred, show_figure=True, save_figure=False) diff --git a/examples/copod_interpretability.py b/examples/copod_interpretability.py index 633f86615..513151919 100644 --- a/examples/copod_interpretability.py +++ b/examples/copod_interpretability.py @@ -1,2 +1,56 @@ # -*- coding: utf-8 -*- +"""Example of using Copula Based Outlier Detector (COPOD) for outlier detection +Sample wise interpretation is provided here. +""" +# Author: Winston Li +# License: BSD 2 clause +from __future__ import division +from __future__ import print_function + +import os +import sys + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) + +from scipy.io import loadmat +from sklearn.model_selection import train_test_split + +from pyod.models.copod import COPOD +from pyod.utils.utility import standardizer + +if __name__ == "__main__": + # Define data file and read X and y + # Generate some data if the source data is missing + mat_file = 'cardio.mat' + + mat = loadmat(os.path.join('data', mat_file)) + X = mat['X'] + y = mat['y'].ravel() + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, + random_state=1) + + # standardizing data for processing + X_train_norm, X_test_norm = standardizer(X_train, X_test) + + # train COPOD detector + clf_name = 'COPOD' + clf = COPOD() + + # you could try parallel version as well. 
+ # clf = COPOD(n_jobs=2) + clf.fit(X_train) + + # get the prediction labels and outlier scores of the training data + y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) + y_train_scores = clf.decision_scores_ # raw outlier scores + + print('The first sample is an outlier', y_train[0]) + clf.explain_outlier(0) + + # we could see feature 7, 16, and 20 is above the 0.99 cutoff + # and play a more important role in deciding it is an outlier. diff --git a/pyod/models/cd.py b/pyod/models/cd.py new file mode 100644 index 000000000..c6e2cfb97 --- /dev/null +++ b/pyod/models/cd.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +"""Cook's distance outlier detection (CD) +""" + +# Author: D Kulik +# License: BSD 2 clause + +from __future__ import division +from __future__ import print_function + +import numpy as np +from sklearn.linear_model import LinearRegression +from sklearn.decomposition import PCA +from sklearn.utils import check_array +from sklearn.utils.validation import check_is_fitted + +from .base import BaseDetector +from ..utils.utility import check_parameter + +def whiten_data(X, pca): + + X = pca.transform(X) + + return X + + +def Cooks_dist(X, y, model): + + # Leverage is computed as the diagonal of the projection matrix of X + leverage = (X * np.linalg.pinv(X).T).sum(1) + + # Compute the rank and the degrees of freedom of the model + rank = np.linalg.matrix_rank(X) + df = X.shape[0] - rank + + # Compute the MSE from the residuals + residuals = y - model.predict(X) + mse = np.dot(residuals, residuals) / df + + # Compute Cook's distance + residuals_studentized = residuals / np.sqrt(mse) / np.sqrt(1 - leverage) + distance_ = residuals_studentized ** 2 / X.shape[1] + distance_ *= leverage / (1 - leverage) + + return distance_ + + + +class CD(BaseDetector): + """Cook's distance can be used to identify points that negatively + affect a regression model. A combination of each observation’s + leverage and residual values are used in the measurement. 
Higher + leverage and residuals relate to higher Cook’s distances. + Read more in the :cite:`cook1977detection`. + + Parameters + ---------- + contamination : float in (0., 0.5), optional (default=0.1) + The amount of contamination of the data set, i.e. + the proportion of outliers in the data set. Used when fitting to + define the threshold on the decision function. + + whitening : bool, optional (default=True) + transform X to have a covariance matrix that is the identity matrix  + of 1 in the diagonal and 0 for the other cells using PCA + + rule_of_thumb : bool, optional (default=False) + to apply the rule of thumb prediction (4 / n) as the influence + threshold; where n is the number of samples. This has been known to + be a good estimate for values over this point as being outliers. + ** Note the contamination level is reset when rule_of_thumb is + set to True + + + Attributes + ---------- + decision_scores_ : numpy array of shape (n_samples,) + The outlier scores of the training data. + The higher, the more abnormal. Outliers tend to have higher + scores. This value is available once the detector is + fitted. + + threshold_ : float + The threshold applied to ``decision_scores_``. Observations with + a score (Cook's distance) greater + than this value will be classified as outliers. + + labels_ : int, either 0 or 1 + The binary labels of the training data. 0 stands for inliers + and 1 for outliers/anomalies. It is generated by applying + ``threshold_`` on ``decision_scores_``. + """ + + + def __init__(self, whitening=True, contamination=0.1, rule_of_thumb=False): + + super(CD, self).__init__(contamination=contamination) + self.whitening = whitening + self.rule_of_thumb = rule_of_thumb + + + def fit(self, X, y): + """Fit detector. y is necessary for supervised method. + + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The input samples. 
+ + y : numpy array of shape (n_samples,), optional (default=None) + The ground truth of the input samples (labels). + """ + + # Define OLS model + self.model = LinearRegression() + + # Validate inputs X and y + try: + X = check_array(X) + except ValueError: + X = X.reshape(-1,1) + + y = np.squeeze(check_array(y, ensure_2d=False)) + self._set_n_classes(y) + + # Apply whitening + if self.whitening: + self.pca = PCA(whiten=True) + self.pca.fit(X) + X = whiten_data(X, self.pca) + + # Fit a linear model to X and y + self.model.fit(X, y) + + # Get Cook's Distance + distance_ = Cooks_dist(X, y, self.model) + + # Compute the influence threshold + if self.rule_of_thumb: + influence_threshold_ = 4 / X.shape[0] + self.contamination = sum(distance_ > influence_threshold_) / X.shape[0] + + self.decision_scores_ = distance_ + + self._process_decision_scores() + + return self + + + def decision_function(self, X): + """Predict raw anomaly score of X using the fitted detector. + + The anomaly score of an input sample is computed based on different + detector algorithms. For consistency, outliers are assigned with + larger anomaly scores. + + Parameters + ---------- + X : numpy array of shape (n_samples, n_features) + The independent and dependent/target samples with the target + samples being the last column of the numpy array such that + eg: X = np.append(x, y.reshape(-1,1), axis=1). Sparse matrices are + accepted only if they are supported by the base estimator. + + Returns + ------- + anomaly_scores : numpy array of shape (n_samples,) + The anomaly score of the input samples. 
+ """ + + check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_']) + + try: + X = check_array(X) + except ValueError: + X = X.reshape(-1,1) + + y = X[:,-1] + X = X[:,:-1] + + + # Apply whitening + if self.whitening: + X = whiten_data(X, self.pca) + + # Get Cook's Distance + distance_ = Cooks_dist(X, y, self.model) + + return distance_ diff --git a/pyod/models/ecod.py b/pyod/models/ecod.py index 1593c07d3..649edc0b2 100644 --- a/pyod/models/ecod.py +++ b/pyod/models/ecod.py @@ -149,7 +149,10 @@ def decision_function(self, X): skewness = np.sign(skew(X, axis=0)) self.U_skew = self.U_l * -1 * np.sign( skewness - 1) + self.U_r * np.sign(skewness + 1) - self.O = np.maximum(self.U_skew, self.U_l, self.U_r) + + self.O = np.maximum(self.U_l, self.U_r) + self.O = np.maximum(self.U_skew, self.O) + if hasattr(self, 'X_train'): decision_scores_ = self.O.sum(axis=1)[-original_size:] else: @@ -208,7 +211,10 @@ def _decision_function_parallel(self, X): skewness = np.sign(skew(X, axis=0)) self.U_skew = self.U_l * -1 * np.sign( skewness - 1) + self.U_r * np.sign(skewness + 1) - self.O = np.maximum(self.U_skew, self.U_l, self.U_r) + + self.O = np.maximum(self.U_l, self.U_r) + self.O = np.maximum(self.U_skew, self.O) + if hasattr(self, 'X_train'): decision_scores_ = self.O.sum(axis=1)[-original_size:] else: diff --git a/pyod/models/mad.py b/pyod/models/mad.py index f334968ba..90c50d851 100644 --- a/pyod/models/mad.py +++ b/pyod/models/mad.py @@ -83,7 +83,7 @@ def fit(self, X, y=None): self : object Fitted estimator. """ - X = check_array(X, ensure_2d=False) + X = check_array(X, ensure_2d=False, force_all_finite=False) _check_dim(X) self._set_n_classes(y) self.median = None # reset median after each call @@ -111,7 +111,7 @@ def decision_function(self, X): anomaly_scores : numpy array of shape (n_samples,) The anomaly score of the input samples. 
""" - X = check_array(X, ensure_2d=False) + X = check_array(X, ensure_2d=False, force_all_finite=False) _check_dim(X) return self._mad(X) @@ -129,7 +129,7 @@ def _mad(self, X): # `self.median` will be None only before `fit()` is called self.median = np.nanmedian(obs) if self.median is None else self.median diff = np.abs(obs - self.median) - self.median_diff = np.median(diff) if self.median_diff is None else self.median_diff + self.median_diff = np.nanmedian(diff) if self.median_diff is None else self.median_diff return np.nan_to_num(np.ravel(0.6745 * diff / self.median_diff)) def _process_decision_scores(self): @@ -147,7 +147,7 @@ def _process_decision_scores(self): self.labels_ = (self.decision_scores_ > self.threshold_).astype('int').ravel() # calculate for predict_proba() - self._mu = np.mean(self.decision_scores_) - self._sigma = np.std(self.decision_scores_) + self._mu = np.nanmean(self.decision_scores_) + self._sigma = np.nanstd(self.decision_scores_) return self diff --git a/pyod/test/test_cd.py b/pyod/test/test_cd.py new file mode 100644 index 000000000..9129d5bcc --- /dev/null +++ b/pyod/test/test_cd.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +from __future__ import division +from __future__ import print_function + +import os +import sys + +import unittest +# noinspection PyProtectedMember +from numpy.testing import assert_allclose +from numpy.testing import assert_array_less +from numpy.testing import assert_equal +from numpy.testing import assert_raises +import numpy as np +from sklearn.base import clone + +# temporary solution for relative imports in case pyod is not installed +# if pyod is installed, no need to use the following line +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) + +from pyod.models.cd import CD +from pyod.utils.data import generate_data + + +class TestCD(unittest.TestCase): + """ + Notes: GAN may yield unstable results, so the test is design for running + models only, without any performance 
check. + """ + + def setUp(self): + self.n_train = 1000 + self.n_test = 200 + self.n_features = 2 + self.contamination = 0.1 + # GAN may yield unstable results; turning performance check off + # self.roc_floor = 0.8 + self.X_train, self.y_train, self.X_test, self.y_test = generate_data( + n_train=self.n_train, n_test=self.n_test, + n_features=self.n_features, contamination=self.contamination, + random_state=42) + + self.clf = CD(contamination=self.contamination) + self.clf.fit(self.X_train, self.y_train) + + def test_parameters(self): + assert (hasattr(self.clf, 'decision_scores_') and + self.clf.decision_scores_ is not None) + assert (hasattr(self.clf, 'labels_') and + self.clf.labels_ is not None) + assert (hasattr(self.clf, 'threshold_') and + self.clf.threshold_ is not None) + + def test_train_scores(self): + assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) + + def test_prediction_scores(self): + pred_scores = self.clf.decision_function(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1)) + + # check score shapes + assert_equal(pred_scores.shape[0], self.X_test.shape[0]) + + # check performance + # assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) + + def test_prediction_labels(self): + pred_labels = self.clf.predict(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1)) + assert_equal(pred_labels.shape, self.y_test.shape) + + def test_prediction_proba(self): + pred_proba = self.clf.predict_proba(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1)) + assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + def test_prediction_proba_linear(self): + pred_proba = self.clf.predict_proba(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1), method='linear') + assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + def test_prediction_proba_unify(self): + pred_proba = self.clf.predict_proba(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1), method='unify') + 
assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + def test_prediction_proba_parameter(self): + with assert_raises(ValueError): + self.clf.predict_proba(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1), method='something') + + def test_prediction_labels_confidence(self): + pred_labels, confidence = self.clf.predict(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1), + return_confidence=True) + assert_equal(pred_labels.shape, self.y_test.shape) + assert_equal(confidence.shape, self.y_test.shape) + assert (confidence.min() >= 0) + assert (confidence.max() <= 1) + + def test_prediction_proba_linear_confidence(self): + pred_proba, confidence = self.clf.predict_proba(np.append(self.X_test, + self.y_test.reshape(-1,1), + axis=1), + method='linear', + return_confidence=True) + assert (pred_proba.min() >= 0) + assert (pred_proba.max() <= 1) + + assert_equal(confidence.shape, self.y_test.shape) + assert (confidence.min() >= 0) + assert (confidence.max() <= 1) + + def test_model_clone(self): + clone_clf = clone(self.clf) + + def tearDown(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/pyod/test/test_mad.py b/pyod/test/test_mad.py index cefb03c73..d05b7e405 100644 --- a/pyod/test/test_mad.py +++ b/pyod/test/test_mad.py @@ -30,12 +30,26 @@ def setUp(self): self.n_test = 50 self.contamination = 0.1 self.roc_floor = 0.8 + # generate data and fit model without missing or infinite values: self.X_train, self.y_train, self.X_test, self.y_test = generate_data( n_train=self.n_train, n_test=self.n_test, n_features=1, contamination=self.contamination, random_state=42) - self.clf = MAD() self.clf.fit(self.X_train) + # generate data and fit model with missing value: + self.X_train_nan, self.y_train_nan, self.X_test_nan, self.y_test_nan = generate_data( + n_train=self.n_train, n_test=self.n_test, n_features=1, + contamination=self.contamination, random_state=42, + n_nan=1) + self.clf_nan = MAD() + 
self.clf_nan.fit(self.X_train_nan) + # generate data and fit model with infinite value: + self.X_train_inf, self.y_train_inf, self.X_test_inf, self.y_test_inf = generate_data( + n_train=self.n_train, n_test=self.n_test, n_features=1, + contamination=self.contamination, random_state=42, + n_inf=1) + self.clf_inf = MAD() + self.clf_inf.fit(self.X_train_inf) def test_parameters(self): assert (hasattr(self.clf, 'decision_scores_') and @@ -105,6 +119,14 @@ def test_fit_predict(self): pred_labels = self.clf.fit_predict(self.X_train) assert_equal(pred_labels.shape, self.y_train.shape) + def test_fit_predict_with_nan(self): + pred_labels = self.clf_nan.fit_predict(self.X_train_nan) + assert_equal(pred_labels.shape, self.y_train_nan.shape) + + def test_fit_predict_with_inf(self): + pred_labels = self.clf_inf.fit_predict(self.X_train_inf) + assert_equal(pred_labels.shape, self.y_train_inf.shape) + def test_fit_predict_score(self): self.clf.fit_predict_score(self.X_test, self.y_test) self.clf.fit_predict_score(self.X_test, self.y_test, @@ -125,6 +147,26 @@ def test_predict_rank(self): assert_array_less(pred_ranks, self.X_train.shape[0] + 1) assert_array_less(-0.1, pred_ranks) + def test_predict_rank_with_nan(self): + pred_scores = self.clf_nan.decision_function(self.X_test_nan) + pred_ranks = self.clf_nan._predict_rank(self.X_test_nan) + print(pred_ranks) + + # assert the order is reserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2) + assert_array_less(pred_ranks, self.X_train_nan.shape[0] + 1) + assert_array_less(-0.1, pred_ranks) + + def test_predict_rank_with_inf(self): + pred_scores = self.clf_inf.decision_function(self.X_test_inf) + pred_ranks = self.clf_inf._predict_rank(self.X_test_inf) + print(pred_ranks) + + # assert the order is reserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2) + assert_array_less(pred_ranks, self.X_train_inf.shape[0] + 1) + assert_array_less(-0.1, pred_ranks) + def 
test_predict_rank_normalized(self): pred_scores = self.clf.decision_function(self.X_test) pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) @@ -134,6 +176,24 @@ def test_predict_rank_normalized(self): assert_array_less(pred_ranks, 1.01) assert_array_less(-0.1, pred_ranks) + def test_predict_rank_normalized_with_nan(self): + pred_scores = self.clf_nan.decision_function(self.X_test_nan) + pred_ranks = self.clf_nan._predict_rank(self.X_test_nan, normalized=True) + + # assert the order is reserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2) + assert_array_less(pred_ranks, 1.01) + assert_array_less(-0.1, pred_ranks) + + def test_predict_rank_normalized_with_inf(self): + pred_scores = self.clf_inf.decision_function(self.X_test_inf) + pred_ranks = self.clf_inf._predict_rank(self.X_test_inf, normalized=True) + + # assert the order is reserved + assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2) + assert_array_less(pred_ranks, 1.01) + assert_array_less(-0.1, pred_ranks) + def test_check_univariate(self): with assert_raises(ValueError): MAD().fit(X=[[0.0, 0.0], @@ -149,6 +209,20 @@ def test_detect_anomaly(self): self.assertGreaterEqual(score[0], self.clf.threshold_) self.assertEqual(anomaly[0], 1) + def test_detect_anomaly_with_nan(self): + X_test = [[10000]] + score = self.clf_nan.decision_function(X_test) + anomaly = self.clf_nan.predict(X_test) + self.assertGreaterEqual(score[0], self.clf_nan.threshold_) + self.assertEqual(anomaly[0], 1) + + def test_detect_anomaly_with_inf(self): + X_test = [[10000]] + score = self.clf_inf.decision_function(X_test) + anomaly = self.clf_inf.predict(X_test) + self.assertGreaterEqual(score[0], self.clf_inf.threshold_) + self.assertEqual(anomaly[0], 1) + # todo: fix clone issue def test_model_clone(self): pass diff --git a/pyod/utils/data.py b/pyod/utils/data.py index 061d0bba9..6d495137c 100644 --- a/pyod/utils/data.py +++ b/pyod/utils/data.py @@ -25,7 +25,7 @@ def 
_generate_data(n_inliers, n_outliers, n_features, coef, offset, - random_state): + random_state, n_nan=0, n_inf=0): """Internal function to generate data samples. Parameters @@ -51,6 +51,12 @@ def _generate_data(n_inliers, n_outliers, n_features, coef, offset, If None, the random number generator is the RandomState instance used by `np.random`. + n_nan : int + The number of values that are missing (np.NaN). Defaults to zero. + + n_inf : int + The number of values that are infinite. (np.infty). Defaults to zero. + Returns ------- X : numpy array of shape (n_train, n_features) @@ -67,6 +73,14 @@ def _generate_data(n_inliers, n_outliers, n_features, coef, offset, y = np.r_[np.zeros((n_inliers,)), np.ones((n_outliers,))] + if n_nan > 0: + X = np.r_[X, np.full((n_nan, n_features), np.NaN)] + y = np.r_[y, np.full((n_nan), np.NaN)] + + if n_inf > 0: + X = np.r_[X, np.full((n_inf, n_features), np.infty)] + y = np.r_[y, np.full((n_inf), np.infty)] + return X, y @@ -97,7 +111,7 @@ def get_outliers_inliers(X, y): def generate_data(n_train=1000, n_test=500, n_features=2, contamination=0.1, train_only=False, offset=10, behaviour='old', - random_state=None): + random_state=None, n_nan=0, n_inf=0): """Utility function to generate synthesized data. Normal data is generated by a multivariate Gaussian distribution and outliers are generated by a uniform distribution. @@ -146,6 +160,12 @@ def generate_data(n_train=1000, n_test=500, n_features=2, contamination=0.1, If None, the random number generator is the RandomState instance used by `np.random`. + n_nan : int + The number of values that are missing (np.NaN). Defaults to zero. + + n_inf : int + The number of values that are infinite. (np.infty). Defaults to zero. 
+ Returns ------- X_train : numpy array of shape (n_train, n_features) @@ -171,7 +191,8 @@ def generate_data(n_train=1000, n_test=500, n_features=2, contamination=0.1, n_inliers_train = int(n_train - n_outliers_train) X_train, y_train = _generate_data(n_inliers_train, n_outliers_train, - n_features, coef_, offset_, random_state) + n_features, coef_, offset_, random_state, + n_nan, n_inf) if train_only: return X_train, y_train @@ -180,7 +201,8 @@ def generate_data(n_train=1000, n_test=500, n_features=2, contamination=0.1, n_inliers_test = int(n_test - n_outliers_test) X_test, y_test = _generate_data(n_inliers_test, n_outliers_test, - n_features, coef_, offset_, random_state) + n_features, coef_, offset_, random_state, + n_nan, n_inf) if behaviour == 'old': warn('behaviour="old" is deprecated and will be removed ' diff --git a/pyod/version.py b/pyod/version.py index bcd44c318..8aefbca12 100644 --- a/pyod/version.py +++ b/pyod/version.py @@ -20,4 +20,4 @@ # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' # -__version__ = '0.9.8' # pragma: no cover +__version__ = '0.9.9' # pragma: no cover diff --git a/setup.py b/setup.py index 3218995f5..c2a908b4a 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def readme(): setup( name='pyod', version=__version__, - description='A Python Toolbox for Scalable Outlier Detection (Anomaly Detection)', + description='A Comprehensive and Scalable Python Library for Outlier Detection (Anomaly Detection)', long_description=readme(), long_description_content_type='text/x-rst', author='Yue Zhao',