
Commit

Merge b209ce2 into c4aecd1
KulikDM committed Mar 25, 2023
2 parents c4aecd1 + b209ce2 commit f04a886
Showing 3 changed files with 110 additions and 108 deletions.
8 changes: 4 additions & 4 deletions examples/cd_example.py
@@ -30,22 +30,22 @@
X_train, X_test, y_train, y_test = \
generate_data(n_train=n_train,
n_test=n_test,
n_features=2,
n_features=5,
contamination=contamination,
random_state=42)

# train HBOS detector
clf_name = 'CD'
clf = CD()
clf.fit(X_train, y_train)
clf.fit(X_train)

# get the prediction labels and outlier scores of the training data
y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers)
y_train_scores = clf.decision_scores_ # raw outlier scores

# get the prediction on the test data
y_test_pred = clf.predict(np.append(X_test, y_test.reshape(-1,1), axis=1)) # outlier labels (0 or 1)
y_test_scores = clf.decision_function(np.append(X_test, y_test.reshape(-1,1), axis=1)) # outlier scores
y_test_pred = clf.predict(X_test) # outlier labels (0 or 1)
y_test_scores = clf.decision_function(X_test) # outlier scores

# evaluate and print the results
print("\nOn Training Data:")
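For orientation (not part of the diff), a minimal sketch of the updated, unsupervised workflow shown above. Import paths follow the repository layout (pyod.models.cd, pyod.utils.data) and the sample sizes are illustrative:

from pyod.models.cd import CD
from pyod.utils.data import generate_data

# Synthetic data with 5 features, mirroring the updated example
X_train, X_test, y_train, y_test = generate_data(
    n_train=200, n_test=100, n_features=5,
    contamination=0.1, random_state=42)

clf = CD()
clf.fit(X_train)  # the target y is no longer required

y_train_scores = clf.decision_scores_          # raw outlier scores (train)
y_test_pred = clf.predict(X_test)              # binary labels (0: inlier, 1: outlier)
y_test_scores = clf.decision_function(X_test)  # raw outlier scores (test)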
175 changes: 97 additions & 78 deletions pyod/models/cd.py
@@ -9,21 +9,32 @@
from __future__ import print_function

import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

from .base import BaseDetector

def whiten_data(X, pca):
    X = pca.transform(X)

    return X


def Cooks_dist(X, y, model):
def _Cooks_dist(X, y, model):
    """Calculates the Cook's distance

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
        The training dataset.

    y : numpy array of shape (n_samples)
        The training dataset.

    model : object
        Regression model used to calculate the Cook's distance

    Returns
    -------
    distances_ : numpy array of shape (n_samples)
        Cook's distance
    """
# Leverage is computed as the diagonal of the projection matrix of X
leverage = (X * np.linalg.pinv(X).T).sum(1)

@@ -34,21 +45,70 @@ def Cooks_dist(X, y, model):
# Compute the MSE from the residuals
residuals = y - model.predict(X)
mse = np.dot(residuals, residuals) / df

# Compute Cook's distance
residuals_studentized = residuals / np.sqrt(mse) / np.sqrt(1 - leverage)
distance_ = residuals_studentized ** 2 / X.shape[1]
distance_ *= leverage / (1 - leverage)
    if mse != 0 and not np.isnan(mse):
        residuals_studentized = residuals / np.sqrt(mse) / np.sqrt(1 - leverage)
        distance_ = residuals_studentized ** 2 / X.shape[1]
        distance_ *= leverage / (1 - leverage)
        distance_ = ((distance_ - distance_.min())
                     / (distance_.max() - distance_.min()))

    else:
        distance_ = np.ones(len(y)) * np.nan

return distance_

def _process_distances(X, model):
    """Calculates the mean Cook's distance of each sample across
    the per-feature regressions

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
        The training dataset.

    model : object
        Regression model used to calculate the Cook's distance

    Returns
    -------
    distances_ : numpy array of shape (n_samples)
        Mean Cook's distance
    """

distances_ = []
for i in range(X.shape[1]):

mod = model

# Extract new X and y inputs
exp = np.delete(X.copy(), i, axis=1)
resp = X[:,i]

exp = exp.reshape(-1,1) if exp.ndim == 1 else exp

# Fit the model
mod.fit(exp, resp)

# Get Cook's Distance
distance_ = _Cooks_dist(exp, resp, mod)

distances_.append(distance_)

distances_ = np.nanmean(distances_, axis=0)

return distances_


class CD(BaseDetector):
"""Cook's distance can be used to identify points that negatively
affect a regression model. A combination of each observation’s
leverage and residual values is used in the measurement. Higher
leverage and residuals relate to higher Cook’s distances.
Read more in the :cite:`cook1977detection`.
leverage and residuals relate to higher Cook’s distances. Note
that this method is unsupervised and requires at least two
features for X with which to calculate the mean Cook's distance
for each datapoint. Read more in the :cite:`cook1977detection`.
Parameters
----------
@@ -57,16 +117,8 @@ class CD(BaseDetector):
the proportion of outliers in the data set. Used when fitting to
define the threshold on the decision function.
whiten : bool, optional (default=True)
transform X to have a covariance matrix that is the identity matrix
of 1 in the diagonal and 0 for the other cells using PCA
rule_of_thumb : bool, optional (default=False)
to apply the rule of thumb prediction (4 / n) as the influence
threshold; where n is the number of samples. This has been know to
be a good estimate for values over this point as being outliers.
** Note the contamination level is reset when rule_of_thumb is
set to True
model : object, optional (default=LinearRegression())
Regression model used to calculate the Cook's distance
Attributes
@@ -88,74 +140,50 @@
``threshold_`` on ``decision_scores_``.
"""

def __init__(self, whitening=True, contamination=0.1, rule_of_thumb=False):

def __init__(self, contamination=0.1, model=LinearRegression()):
super(CD, self).__init__(contamination=contamination)
self.whitening = whitening
self.rule_of_thumb = rule_of_thumb
self.model = model

def fit(self, X, y):
"""Fit detector. y is necessary for supervised method.
def fit(self, X, y=None):
""""Fit detector. y is ignored in unsupervised methods.
Parameters
----------
X : numpy array of shape (n_samples, n_features)
The input samples.
    y : numpy array of shape (n_samples,), optional (default=None)
        The ground truth of the input samples (labels).
    """
    y : Ignored
        Not used, present for API consistency by convention.

    Returns
    -------
    self : object
        Fitted estimator.
    """

    # Define OLS model
    self.model = LinearRegression()

# Validate inputs X and y
try:
X = check_array(X)
except ValueError:
X = X.reshape(-1, 1)
X = check_array(X)

y = np.squeeze(check_array(y, ensure_2d=False))
self._set_n_classes(y)

# Apply whitening
if self.whitening:
self.pca = PCA(whiten=True)
self.pca.fit(X)
X = whiten_data(X, self.pca)

# Fit a linear model to X and y
self.model.fit(X, y)

# Get Cook's Distance
distance_ = Cooks_dist(X, y, self.model)

# Compute the influence threshold
if self.rule_of_thumb:
influence_threshold_ = 4 / X.shape[0]
self.contamination = sum(distance_ > influence_threshold_) / \
X.shape[0]

self.decision_scores_ = distance_
# Get Cook's distance
distances_ = _process_distances(X, self.model)

self.decision_scores_ = distances_

self._process_decision_scores()

return self

def decision_function(self, X):
"""Predict raw anomaly score of X using the fitted detector.
The anomaly score of an input sample is computed based on different
detector algorithms. For consistency, outliers are assigned with
larger anomaly scores.
For consistency, outliers are assigned with larger anomaly scores.
Parameters
----------
X : numpy array of shape (n_samples, n_features)
The independent and dependent/target samples with the target
samples being the last column of the numpy array such that
eg: X = np.append(x, y.reshape(-1,1), axis=1). Sparse matrices are
accepted only if they are supported by the base estimator.
The training input samples. Sparse matrices are accepted only
if they are supported by the base estimator.
Returns
-------
@@ -165,19 +193,10 @@ def decision_function(self, X):

check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

try:
X = check_array(X)
except ValueError:
X = X.reshape(-1, 1)

y = X[:, -1]
X = X[:, :-1]

# Apply whitening
if self.whitening:
X = whiten_data(X, self.pca)
# Validate input X
X = check_array(X)

# Get Cook's Distance
distance_ = Cooks_dist(X, y, self.model)
# Get Cook's distance
distances_ = _process_distances(X, self.model)

return distance_
return distances_
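For reference, the quantity computed in `_Cooks_dist` above is the standard Cook's distance. With residual e_i, leverage h_ii (the diagonal of the hat matrix), mean squared error MSE, and p predictors:

r_i = \frac{e_i}{\sqrt{\mathrm{MSE}\,(1 - h_{ii})}}, \qquad D_i = \frac{r_i^2}{p} \cdot \frac{h_{ii}}{1 - h_{ii}}

`_Cooks_dist` then min-max scales D_i to [0, 1], and `_process_distances` averages these scores over the per-feature regressions (each feature regressed on the remaining ones), yielding the per-sample scores that `fit` stores in `decision_scores_`.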
35 changes: 9 additions & 26 deletions pyod/test/test_cd.py
@@ -32,8 +32,7 @@ def setUp(self):
self.n_test = 200
self.n_features = 2
self.contamination = 0.1
# GAN may yield unstable results; turning performance check off
# self.roc_floor = 0.8

self.X_train, self.X_test, self.y_train, self.y_test = generate_data(
n_train=self.n_train, n_test=self.n_test,
n_features=self.n_features, contamination=self.contamination,
@@ -54,9 +53,7 @@ def test_train_scores(self):
assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])

def test_prediction_scores(self):
pred_scores = self.clf.decision_function(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1))
pred_scores = self.clf.decision_function(self.X_test)

# check score shapes
assert_equal(pred_scores.shape[0], self.X_test.shape[0])
@@ -65,52 +62,38 @@ def test_prediction_labels(self):
# assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor)

def test_prediction_labels(self):
pred_labels = self.clf.predict(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1))
pred_labels = self.clf.predict(self.X_test)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_proba(self):
pred_proba = self.clf.predict_proba(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1))
pred_proba = self.clf.predict_proba(self.X_test)
assert (pred_proba.min() >= 0)
assert (pred_proba.max() <= 1)

def test_prediction_proba_linear(self):
pred_proba = self.clf.predict_proba(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1), method='linear')
pred_proba = self.clf.predict_proba(self.X_test, method='linear')
assert (pred_proba.min() >= 0)
assert (pred_proba.max() <= 1)

def test_prediction_proba_unify(self):
pred_proba = self.clf.predict_proba(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1), method='unify')
pred_proba = self.clf.predict_proba(self.X_test, method='unify')
assert (pred_proba.min() >= 0)
assert (pred_proba.max() <= 1)

def test_prediction_proba_parameter(self):
with assert_raises(ValueError):
self.clf.predict_proba(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1), method='something')
self.clf.predict_proba(self.X_test, method='something')

def test_prediction_labels_confidence(self):
pred_labels, confidence = self.clf.predict(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1),
pred_labels, confidence = self.clf.predict(self.X_test,
return_confidence=True)
assert_equal(pred_labels.shape, self.y_test.shape)
assert_equal(confidence.shape, self.y_test.shape)
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_proba_linear_confidence(self):
pred_proba, confidence = self.clf.predict_proba(np.append(self.X_test,
self.y_test.reshape(-1,1),
axis=1),
pred_proba, confidence = self.clf.predict_proba(self.X_test,
method='linear',
return_confidence=True)
assert (pred_proba.min() >= 0)


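Finally, a compact, self-contained sketch of the unsupervised scheme the new code and tests exercise; the helper name mean_cooks_distance is illustrative and not part of pyod:

# Illustrative reimplementation of the new scoring scheme: regress each
# feature on the remaining features, compute Cook's distance per regression,
# and average across regressions to get one score per sample.
import numpy as np
from sklearn.linear_model import LinearRegression


def mean_cooks_distance(X):
    scores = []
    for i in range(X.shape[1]):
        exp = np.delete(X, i, axis=1)   # predictors: all other features
        resp = X[:, i]                  # response: the held-out feature
        model = LinearRegression().fit(exp, resp)

        # Leverage is the diagonal of the projection (hat) matrix of exp
        leverage = (exp * np.linalg.pinv(exp).T).sum(1)
        residuals = resp - model.predict(exp)
        mse = residuals @ residuals / (exp.shape[0] - exp.shape[1])

        # Studentized residuals -> Cook's distance, min-max scaled
        r = residuals / np.sqrt(mse) / np.sqrt(1 - leverage)
        d = r ** 2 / exp.shape[1] * leverage / (1 - leverage)
        d = (d - d.min()) / (d.max() - d.min())
        scores.append(d)
    return np.nanmean(scores, axis=0)


rng = np.random.default_rng(42)
X = rng.normal(size=(200, 3))
scores = mean_cooks_distance(X)
labels = (scores > np.quantile(scores, 0.9)).astype(int)  # ~10% flagged as outliers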