1 change: 1 addition & 0 deletions README.rst
@@ -106,6 +106,7 @@ Below is a list of the methods currently implemented in this module.
8. Edited Nearest Neighbours [6]_
9. Instance Hardness Threshold [7]_
10. Repeated Edited Nearest Neighbours [14]_
11. AllKNN [14]_

* Over-sampling
1. Random minority over-sampling with replacement
4 changes: 2 additions & 2 deletions doc/todo.rst
@@ -12,9 +12,9 @@ Version 0.2
New methods
~~~~~~~~~~~

* AllKNN_: Garcia, Salvador, et al. "Prototype selection for nearest neighbor classification: Taxonomy and empirical study." IEEE Transactions on Pattern Analysis and Machine Intelligence 34.3 (2012): 417-435.
* SMOTEBagging_: Wang, Shuo, and Xin Yao. "Diversity analysis on imbalanced data sets by using ensemble models." Computational Intelligence and Data Mining, 2009. CIDM'09. IEEE Symposium on. IEEE, 2009.

.. _AllKNN: https://www.semanticscholar.org/paper/Prototype-Selection-for-Nearest-Neighbor-Garc%C3%ADa-Derrac/fbca1824c49e02da37e5e780eaf0ab6ddfaf5614/pdf
.. _SMOTEBagging: http://pages.bangor.ac.uk/~mas00a/papers/jpjrcolkis15.pdf

API improvements
~~~~~~~~~~~~~~~~
1 change: 1 addition & 0 deletions doc/whats_new.rst
@@ -14,6 +14,7 @@ Changelog

- Added support for bumpversion.
- Added doctest in the documentation.
- Added the AllKNN under-sampling technique.


.. _changes_0_1:
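The changelog entry above only names the new technique. As a minimal usage sketch of the class this PR adds (the dataset below is illustrative; only the import path and the `fit_sample` signature come from the implementation later in this diff):

# Minimal usage sketch of the new AllKNN under-sampler; the dataset is
# illustrative, only the import path and fit_sample API are from this PR.
from collections import Counter

from sklearn.datasets import make_classification

from imblearn.under_sampling import AllKNN

X, y = make_classification(n_classes=2, weights=[0.1, 0.9],
                           n_features=20, n_samples=1000, random_state=10)

allknn = AllKNN(return_indices=True)
X_res, y_res, idx = allknn.fit_sample(X, y)

# Only majority samples are edited away; the minority class is untouched,
# and idx holds the positions in the original X of the rows that were kept.
print('Resampled class counts: {}'.format(Counter(y_res)))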
92 changes: 92 additions & 0 deletions examples/under-sampling/plot_allknn.py
@@ -0,0 +1,92 @@
"""
==================================
AllKNN
==================================

An illustration of the AllKNN method.

"""

print(__doc__)

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# Define some color for the plotting
almost_black = '#262626'
palette = sns.color_palette()

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA

from imblearn.under_sampling import EditedNearestNeighbours
from imblearn.under_sampling import RepeatedEditedNearestNeighbours
from imblearn.under_sampling import AllKNN

# Generate the dataset
X, y = make_classification(n_classes=2, class_sep=1.25, weights=[0.3, 0.7],
n_informative=3, n_redundant=1, flip_y=0,
n_features=5, n_clusters_per_class=1,
n_samples=5000, random_state=10)

# Instantiate a PCA object for the sake of easy visualisation
pca = PCA(n_components=2)
# Fit and transform X to visualise inside a 2D feature space
X_vis = pca.fit_transform(X)

# Four subplots, unpack the axes array immediately
f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)

ax1.scatter(X_vis[y == 0, 0], X_vis[y == 0, 1], label="Class #0", alpha=.5,
edgecolor=almost_black, facecolor=palette[0], linewidth=0.15)
ax1.scatter(X_vis[y == 1, 0], X_vis[y == 1, 1], label="Class #1", alpha=.5,
edgecolor=almost_black, facecolor=palette[2], linewidth=0.15)
ax1.set_title('Original set')

# Apply the ENN
print('ENN')
enn = EditedNearestNeighbours()
X_resampled, y_resampled = enn.fit_sample(X, y)
X_res_vis = pca.transform(X_resampled)
print('Reduced {:.2f}%'.format(100 * (1 - float(len(X_resampled)) / len(X))))

ax2.scatter(X_res_vis[y_resampled == 0, 0], X_res_vis[y_resampled == 0, 1],
label="Class #0", alpha=.5, edgecolor=almost_black,
facecolor=palette[0], linewidth=0.15)
ax2.scatter(X_res_vis[y_resampled == 1, 0], X_res_vis[y_resampled == 1, 1],
label="Class #1", alpha=.5, edgecolor=almost_black,
facecolor=palette[2], linewidth=0.15)
ax2.set_title('Edited nearest neighbours')

# Apply the RENN
print('RENN')
renn = RepeatedEditedNearestNeighbours()
X_resampled, y_resampled = renn.fit_sample(X, y)
X_res_vis = pca.transform(X_resampled)
print('Reduced {:.2f}%'.format(100 * (1 - float(len(X_resampled)) / len(X))))

ax3.scatter(X_res_vis[y_resampled == 0, 0], X_res_vis[y_resampled == 0, 1],
label="Class #0", alpha=.5, edgecolor=almost_black,
facecolor=palette[0], linewidth=0.15)
ax3.scatter(X_res_vis[y_resampled == 1, 0], X_res_vis[y_resampled == 1, 1],
label="Class #1", alpha=.5, edgecolor=almost_black,
facecolor=palette[2], linewidth=0.15)
ax3.set_title('Repeated edited nearest neighbours')

# Apply the AllKNN
print('AllKNN')
allknn = AllKNN()
X_resampled, y_resampled = allknn.fit_sample(X, y)
X_res_vis = pca.transform(X_resampled)
print('Reduced {:.2f}%'.format(100 * (1 - float(len(X_resampled)) / len(X))))

ax4.scatter(X_res_vis[y_resampled == 0, 0], X_res_vis[y_resampled == 0, 1],
label="Class #0", alpha=.5, edgecolor=almost_black,
facecolor=palette[0], linewidth=0.15)
ax4.scatter(X_res_vis[y_resampled == 1, 0], X_res_vis[y_resampled == 1, 1],
label="Class #1", alpha=.5, edgecolor=almost_black,
facecolor=palette[2], linewidth=0.15)
ax4.set_title('AllKNN')

plt.show()
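The reduction printout above is repeated verbatim for ENN, RENN and AllKNN. A small helper, sketched here as a suggestion rather than as part of the PR, would keep the three calls consistent:

# Sketch of a helper factoring out the repeated reduction printout;
# not part of this PR, only an illustration.
def print_reduction(name, n_resampled, n_total):
    """Report the percentage of samples a sampler removed."""
    print('{}: reduced {:.2f}%'.format(
        name, 100 * (1 - float(n_resampled) / n_total)))

print_reduction('AllKNN', len(X_resampled), len(X))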
2 changes: 2 additions & 0 deletions imblearn/under_sampling/__init__.py
@@ -12,6 +12,7 @@
from .neighbourhood_cleaning_rule import NeighbourhoodCleaningRule
from .edited_nearest_neighbours import EditedNearestNeighbours
from .edited_nearest_neighbours import RepeatedEditedNearestNeighbours
from .edited_nearest_neighbours import AllKNN
from .instance_hardness_threshold import InstanceHardnessThreshold

__all__ = ['RandomUnderSampler',
@@ -23,4 +24,5 @@
'NeighbourhoodCleaningRule',
'EditedNearestNeighbours',
'RepeatedEditedNearestNeighbours',
'AllKNN',
'InstanceHardnessThreshold']
170 changes: 170 additions & 0 deletions imblearn/under_sampling/edited_nearest_neighbours.py
@@ -382,3 +382,173 @@ def _sample(self, X, y):
return X_resampled, y_resampled, idx_under
else:
return X_resampled, y_resampled


class AllKNN(SamplerMixin):
"""Class to perform under-sampling based on the AllKNN method.

Parameters
----------
return_indices : bool, optional (default=False)
Whether or not to return the indices of the samples kept after
under-sampling.

random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by np.random.

size_ngh : int, optional (default=3)
Maximum size of the neighbourhood to consider when deciding whether
a sample should be excluded; ENN is applied for every neighbourhood
size from 1 to `size_ngh`.

kind_sel : str, optional (default='all')
Strategy to use in order to exclude samples.

- If 'all', all neighbours must agree with the label of the sample
of interest for it to be kept.
- If 'mode', the majority vote of the neighbours decides whether a
sample is excluded.

n_jobs : int, optional (default=-1)
The number of threads to open when possible.

Attributes
----------
min_c_ : str or int
The identifier of the minority class.

max_c_ : str or int
The identifier of the majority class.

stats_c_ : dict of str/int : int
A dictionary in which the number of occurrences of each class is
reported.

X_shape_ : tuple of int
Shape of the data `X` during fitting.

Notes
-----
The method is based on [1]_.

This class supports multi-class.

Examples
--------

>>> from collections import Counter
>>> from sklearn.datasets import make_classification
>>> from imblearn.under_sampling import AllKNN
>>> X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
... n_informative=3, n_redundant=1, flip_y=0,
... n_features=20, n_clusters_per_class=1,
... n_samples=1000, random_state=10)
>>> print('Original dataset shape {}'.format(Counter(y)))
Original dataset shape Counter({1: 900, 0: 100})
>>> allknn = AllKNN(random_state=42)
>>> X_res, y_res = allknn.fit_sample(X, y)
>>> print('Resampled dataset shape {}'.format(Counter(y_res)))
Resampled dataset shape Counter({1: 883, 0: 100})

References
----------
.. [1] I. Tomek, "An Experiment with the Edited Nearest-Neighbor
Rule," IEEE Transactions on Systems, Man, and Cybernetics, vol. 6(6),
pp. 448-452, June 1976.

"""

def __init__(self, return_indices=False, random_state=None,
size_ngh=3, kind_sel='all', n_jobs=-1):
super(AllKNN, self).__init__()
self.return_indices = return_indices
self.random_state = random_state
self.size_ngh = size_ngh
self.kind_sel = kind_sel
self.n_jobs = n_jobs
self.enn_ = EditedNearestNeighbours(
return_indices=self.return_indices,
random_state=self.random_state,
size_ngh=self.size_ngh,
kind_sel=self.kind_sel,
n_jobs=self.n_jobs)

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.

Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.

y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.

Returns
-------
self : object,
Return self.

"""
super(AllKNN, self).fit(X, y)
self.enn_.fit(X, y)

return self

def _sample(self, X, y):
"""Resample the dataset.

Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.

y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.

Returns
-------
X_resampled : ndarray, shape (n_samples_new, n_features)
The array containing the resampled data.

y_resampled : ndarray, shape (n_samples_new)
The corresponding label of `X_resampled`

idx_under : ndarray, shape (n_samples_new, )
If `return_indices` is `True`, an array containing the indices of
the samples which have been selected will be returned.

"""

if self.kind_sel not in SEL_KIND:
raise NotImplementedError

X_, y_ = X, y

if self.return_indices:
idx_under = np.arange(X.shape[0], dtype=int)


for curr_size_ngh in range(1, self.size_ngh + 1):
self.logger.debug('Apply ENN size_ngh #%s', curr_size_ngh)
# Update the neighbourhood size used by the inner ENN
self.enn_.size_ngh = curr_size_ngh
if self.return_indices:
X_, y_, idx_ = self.enn_.fit_sample(X_, y_)
idx_under = idx_under[idx_]
else:
X_, y_ = self.enn_.fit_sample(X_, y_)

self.logger.info('Under-sampling performed: %s', Counter(y_))

X_resampled, y_resampled = X_, y_

# Check if the indices of the samples selected should be returned too
if self.return_indices:
# Return the indices of interest
return X_resampled, y_resampled, idx_under
else:
return X_resampled, y_resampled
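The loop in `_sample` above amounts to chaining ENN passes with a growing neighbourhood size. The following self-contained sketch restates that idea in plain scikit-learn; it is not the library code, covers only the `kind_sel='all'` criterion, and assumes binary labels with class 0 as the minority:

# Self-contained sketch of the AllKNN idea (kind_sel='all' only), assuming
# binary labels with class 0 as the minority; not the library code above.
import numpy as np
from sklearn.neighbors import NearestNeighbors


def allknn_sketch(X, y, max_k=3, minority=0):
    keep = np.arange(len(y))
    for k in range(1, max_k + 1):
        # Fit on the currently kept samples and query them against
        # themselves; column 0 of the result is each sample itself.
        nn = NearestNeighbors(n_neighbors=k + 1).fit(X[keep])
        neigh = nn.kneighbors(X[keep], return_distance=False)[:, 1:]
        # 'all' criterion: every neighbour must share the sample's label.
        agree = (y[keep][neigh] == y[keep][:, None]).all(axis=1)
        # Only majority samples are edited out; the minority class stays.
        keep = keep[agree | (y[keep] == minority)]
    return X[keep], y[keep], keep

The returned `keep` plays the role of `idx_under` above: integer positions into the original `X` for the rows that survived every pass.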
Binary file added imblearn/under_sampling/tests/data/allknn_idx.npy
Binary file not shown.
Binary file added imblearn/under_sampling/tests/data/allknn_x.npy
Binary file not shown.
Binary file added imblearn/under_sampling/tests/data/allknn_y.npy
Binary file not shown.