Skip to content

Commit

Permalink
Refactor data generation and examples
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhao062 authored and yuezhao@cs.toronto.edu committed Jun 2, 2018
1 parent 41e44f8 commit f9d5050
Show file tree
Hide file tree
Showing 24 changed files with 370 additions and 496 deletions.
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ The toolkit consists of three major groups of functionalities:
* Threshold Sum (Thresh) [6]

3. **Outlier detection utility functions**, see :mod:`pyod.utils`.
* :func:`pyod.utils.utility.score_to_lable`: converting raw outlier scores to binary labels
* :func:`pyod.utils.utility.score_to_label`: converting raw outlier scores to binary labels
* :func:`pyod.utils.utility.precision_n_scores`: one of the popular evaluation metrics for outlier mining (precision @ rank n)
* :func:`pyod.utils.load_data.generate_data`: generate pseudo data for outlier detection experiment
* :func:`pyod.utils.data.generate_data`: generate pseudo data for outlier detection experiment
* :func:`pyod.utils.stat_models.wpearsonr`:: weighted pearson is useful in pseudo ground truth generation

Contents
Expand Down
4 changes: 2 additions & 2 deletions docs/pyod.utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ pyod.utils package
Submodules
----------

pyod.utils.load\_data module
pyod.utils.data module
----------------------------

.. automodule:: pyod.utils.load_data
.. automodule:: pyod.utils.data
:members:
:undoc-members:
:show-inheritance:
Expand Down
Binary file added examples/KNN.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
98 changes: 31 additions & 67 deletions examples/abod_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
Example of using ABOD for outlier detection
Example of using Angle-base outlier detection (ABOD) for outlier detection
"""
from __future__ import division
from __future__ import print_function
Expand All @@ -12,83 +12,47 @@
# if pyod is installed, no need to use the following line
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

try:
from pathlib import Path
except ImportError:
from pathlib2 import Path # python 2 backport

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
from sklearn.metrics import roc_auc_score

from pyod.models.abod import ABOD
from pyod.utils.load_data import generate_data

from pyod.utils.data import generate_data
from pyod.utils.data import visualize
from pyod.utils.utility import precision_n_scores

if __name__ == "__main__":
contamination = 0.1 # percentage of outliers
n_train = 1000
n_test = 500
n_train = 200 # number of training points
n_test = 100 # number of testing points

X_train, y_train, c_train, X_test, y_test, c_test = generate_data(
X_train, y_train, X_test, y_test = generate_data(
n_train=n_train, n_test=n_test, contamination=contamination)

# train a ABOD detector (default version)
clf = ABOD(contamination=contamination)
# train ABOD detector
clf_name = 'ABOD'
clf = ABOD()
clf.fit(X_train)

# get the prediction on the training data
y_train_pred = clf.labels_
y_train_score = clf.decision_scores_
# get the prediction label and decision_scores_ on the training data
y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers)
y_train_scores = clf.decision_scores_ # raw outlier scores

# get the prediction on the test data
y_test_pred = clf.predict(X_test)
y_test_score = clf.decision_function(X_test)

print('Train ROC:{roc}, precision@n_train_:{prn}'.format(
roc=roc_auc_score(y_train, y_train_score),
prn=precision_n_scores(y_train, y_train_score)))

print('Test ROC:{roc}, precision@n_train_:{prn}'.format(
roc=roc_auc_score(y_test, y_test_score),
prn=precision_n_scores(y_test, y_test_score)))

#######################################################################
# Visualizations
# initialize the log directory if it does not exist
Path('example_figs').mkdir(parents=True, exist_ok=True)

# plot the results
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(221)
plt.scatter(X_train[:, 0], X_train[:, 1], c=c_train)
plt.title('Train ground truth')
legend_elements = [Line2D([0], [0], marker='o', color='w', label='normal',
markerfacecolor='b', markersize=8),
Line2D([0], [0], marker='o', color='w', label='outlier',
markerfacecolor='r', markersize=8)]

plt.legend(handles=legend_elements, loc=4)

ax = fig.add_subplot(222)
plt.scatter(X_test[:, 0], X_test[:, 1], c=c_test)
plt.title('Test ground truth')
plt.legend(handles=legend_elements, loc=4)

ax = fig.add_subplot(223)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train_pred)
plt.title('Train prediction by ABOD')
legend_elements = [Line2D([0], [0], marker='o', color='w', label='normal',
markerfacecolor='0', markersize=8),
Line2D([0], [0], marker='o', color='w', label='outlier',
markerfacecolor='yellow', markersize=8)]
plt.legend(handles=legend_elements, loc=4)

ax = fig.add_subplot(224)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test_pred)
plt.title('Test prediction by ABOD')
plt.legend(handles=legend_elements, loc=4)

plt.savefig(os.path.join('example_figs', 'abod.png'), dpi=300)

plt.show()
y_test_pred = clf.predict(X_test) # outlier labels (0 or 1)
y_test_scores = clf.decision_function(X_test) # outlier scores

# evaluate and print the results
print('{clf_name} Train ROC:{roc}, precision @ rank n:{prn}'.format(
clf_name=clf_name,
roc=np.round(roc_auc_score(y_train, y_train_scores), decimals=4),
prn=np.round(precision_n_scores(y_train, y_train_scores), decimals=4)))

print('{clf_name} Test ROC:{roc}, precision @ rank n:{prn}'.format(
clf_name=clf_name,
roc=np.round(roc_auc_score(y_test, y_test_scores), decimals=4),
prn=np.round(precision_n_scores(y_test, y_test_scores), decimals=4)))

# visualize the results
visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred,
y_test_pred, save_figure=False)
6 changes: 3 additions & 3 deletions examples/comb_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pyod.models.combination import aom, moa, average, maximization
from pyod.utils.utility import precision_n_scores
from pyod.utils.utility import standardizer
from pyod.utils.load_data import generate_data
from pyod.utils.data import generate_data

if __name__ == "__main__":

Expand All @@ -43,11 +43,11 @@
except TypeError:
print('{data_file} does not exist. Use generated data'.format(
data_file=mat_file))
X, y, _ = generate_data(train_only=True) # load data
X, y = generate_data(train_only=True) # load data
except IOError:
print('{data_file} does not exist. Use generated data'.format(
data_file=mat_file))
X, y, _ = generate_data(train_only=True) # load data
X, y = generate_data(train_only=True) # load data
else:
X = mat['X']
y = mat['y'].ravel()
Expand Down
Binary file removed examples/example_figs/knn.png
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/feat_bagging_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pyod.models.lof import LOF
from pyod.models.iforest import IForest
from pyod.models.base import clone
from pyod.utils.load_data import generate_data
from pyod.utils.data import generate_data
from pyod.utils.utility import precision_n_scores
from sklearn.utils.estimator_checks import check_estimator
from sklearn.linear_model import LogisticRegression
Expand Down
93 changes: 29 additions & 64 deletions examples/hbos_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,83 +11,48 @@
# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
try:
from pathlib import Path
except ImportError:
from pathlib2 import Path # python 2 backport

import numpy as np
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

from pyod.models.hbos import HBOS
from pyod.utils.load_data import generate_data

from pyod.utils.data import generate_data
from pyod.utils.data import visualize
from pyod.utils.utility import precision_n_scores

if __name__ == "__main__":
contamination = 0.1 # percentage of outliers
n_train = 1000
n_test = 500
n_train = 200 # number of training points
n_test = 100 # number of testing points

X_train, y_train, c_train, X_test, y_test, c_test = generate_data(
X_train, y_train, X_test, y_test = generate_data(
n_train=n_train, n_test=n_test, contamination=contamination)

# train a HBOS detector (default version)
# train HBOS detector
clf_name = 'HBOS'
clf = HBOS()
clf.fit(X_train)

# get the prediction on the training data
y_train_pred = clf.labels_
y_train_score = clf.decision_scores_
# get the prediction label and decision_scores_ on the training data
y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers)
y_train_scores = clf.decision_scores_ # raw outlier scores

# get the prediction on the test data
y_test_pred = clf.predict(X_test)
y_test_score = clf.decision_function(X_test)

print('Train ROC:{roc}, precision@n_train_:{prn}'.format(
roc=roc_auc_score(y_train, y_train_score),
prn=precision_n_scores(y_train, y_train_score)))

print('Test ROC:{roc}, precision@n_train_:{prn}'.format(
roc=roc_auc_score(y_test, y_test_score),
prn=precision_n_scores(y_test, y_test_score)))

#######################################################################
# Visualizations
# initialize the log directory if it does not exist
Path('example_figs').mkdir(parents=True, exist_ok=True)

# plot the results
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(221)
plt.scatter(X_train[:, 0], X_train[:, 1], c=c_train)
plt.title('Train ground truth')
legend_elements = [Line2D([0], [0], marker='o', color='w', label='normal',
markerfacecolor='b', markersize=8),
Line2D([0], [0], marker='o', color='w', label='outlier',
markerfacecolor='r', markersize=8)]

plt.legend(handles=legend_elements, loc=4)

ax = fig.add_subplot(222)
plt.scatter(X_test[:, 0], X_test[:, 1], c=c_test)
plt.title('Test ground truth')
plt.legend(handles=legend_elements, loc=4)

ax = fig.add_subplot(223)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train_pred)
plt.title('Train prediction by HBOS')
legend_elements = [Line2D([0], [0], marker='o', color='w', label='normal',
markerfacecolor='0', markersize=8),
Line2D([0], [0], marker='o', color='w', label='outlier',
markerfacecolor='yellow', markersize=8)]
plt.legend(handles=legend_elements, loc=4)

ax = fig.add_subplot(224)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test_pred)
plt.title('Test prediction by HBOS')
plt.legend(handles=legend_elements, loc=4)

plt.savefig(os.path.join('example_figs', 'hbos.png'), dpi=300)

plt.show()
y_test_pred = clf.predict(X_test) # outlier labels (0 or 1)
y_test_scores = clf.decision_function(X_test) # outlier scores

# evaluate and print the results
print('{clf_name} Train ROC:{roc}, precision @ rank n:{prn}'.format(
clf_name=clf_name,
roc=np.round(roc_auc_score(y_train, y_train_scores), decimals=4),
prn=np.round(precision_n_scores(y_train, y_train_scores), decimals=4)))

print('{clf_name} Test ROC:{roc}, precision @ rank n:{prn}'.format(
clf_name=clf_name,
roc=np.round(roc_auc_score(y_test, y_test_scores), decimals=4),
prn=np.round(precision_n_scores(y_test, y_test_scores), decimals=4)))

# visualize the results
visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred,
y_test_pred, save_figure=False)
Loading

0 comments on commit f9d5050

Please sign in to comment.