Merge eef0ee9 into ad355b7
glemaitre committed Oct 30, 2016
2 parents ad355b7 + eef0ee9 commit f352e45
Showing 11 changed files with 579 additions and 82 deletions.
1 change: 1 addition & 0 deletions doc/whats_new.rst
@@ -53,6 +53,7 @@ API changes summary
- Two base classes :class:`BaseBinaryclassSampler` and :class:`BaseMulticlassSampler` have been created to handle the target type and raise a warning in case of abnormality. By `Guillaume Lemaitre`_ and `Christos Aridas`_.
- Move `random_state` to be assigned in the :class:`SamplerMixin` initialization. By `Guillaume Lemaitre`_.
- Provide estimators instead of parameters in :class:`combine.SMOTEENN` and :class:`combine.SMOTETomek`. Therefore, the list of parameters has been deprecated. By `Guillaume Lemaitre`_ and `Christos Aridas`_.
- `n_neighbors` accepts a `KNeighborsMixin`-based object for :class:`under_sampling.EditedNearestNeighbours`, :class:`under_sampling.CondensedNearestNeighbour`, :class:`under_sampling.NeighbourhoodCleaningRule`, :class:`under_sampling.RepeatedEditedNearestNeighbours`, and :class:`under_sampling.AllKNN`. By `Guillaume Lemaitre`_.

Documentation changes
~~~~~~~~~~~~~~~~~~~~~
120 changes: 89 additions & 31 deletions imblearn/under_sampling/edited_nearest_neighbours.py
@@ -8,6 +8,7 @@
import numpy as np
from scipy.stats import mode
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors.base import KNeighborsMixin

from ..base import BaseMulticlassSampler

@@ -37,9 +38,13 @@ class EditedNearestNeighbours(BaseMulticlassSampler):
NOTE: size_ngh is deprecated from 0.2 and will be replaced in 0.4
Use ``n_neighbors`` instead.
n_neighbors : int, optional (default=3)
n_neighbors : int or KNeighborsMixin object, optional (default=3)
Size of the neighbourhood to consider to compute the average
distance to the minority point samples.
distance to the minority point samples. Alternatively, a
KNeighborsMixin object implementing a `kneighbors` method can be
provided. Currently, `NearestNeighbors` (an exact NN algorithm) and
`LSHForest` (an approximate NN algorithm) are the two such estimators
provided by scikit-learn.
kind_sel : str, optional (default='all')
Strategy to use in order to exclude samples.
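
To make the new parameter concrete, here is a minimal usage sketch (not part of the diff). It assumes the 0.2-era ``fit_sample`` API and builds a toy dataset with scikit-learn's ``make_classification``. Passing ``NearestNeighbors(n_neighbors=4)`` behaves like the default ``n_neighbors=3`` because, as the new ``_validate_estimator`` below shows, the integer form is wrapped with one extra neighbour to account for each sample being its own nearest neighbour::

    from sklearn.datasets import make_classification
    from sklearn.neighbors import NearestNeighbors
    from imblearn.under_sampling import EditedNearestNeighbours

    # Toy imbalanced dataset: roughly 90% majority / 10% minority.
    X, y = make_classification(n_samples=500, weights=[0.9, 0.1],
                               random_state=0)

    # An estimator object can now be passed instead of an int.
    knn = NearestNeighbors(n_neighbors=4)
    enn = EditedNearestNeighbours(n_neighbors=knn, random_state=0)
    X_res, y_res = enn.fit_sample(X, y)
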
@@ -108,6 +113,42 @@ def __init__(self, return_indices=False, random_state=None,
self.kind_sel = kind_sel
self.n_jobs = n_jobs

def _validate_estimator(self):
"""Private function to create the NN estimator"""

if isinstance(self.n_neighbors, int):
self.nn_ = NearestNeighbors(n_neighbors=self.n_neighbors + 1,
n_jobs=self.n_jobs)
elif isinstance(self.n_neighbors, KNeighborsMixin):
self.nn_ = self.n_neighbors
else:
raise ValueError('`n_neighbors` has to be either int or a'
' subclass of KNeighborsMixin.')

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.
Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.
Returns
-------
self : object,
Return self.
"""

super(EditedNearestNeighbours, self).fit(X, y)

self._validate_estimator()

return self

def _sample(self, X, y):
"""Resample the dataset.
@@ -148,11 +189,8 @@ def _sample(self, X, y):
if self.return_indices:
idx_under = np.flatnonzero(y == self.min_c_)

# Create a k-NN to fit the whole data
nn_obj = NearestNeighbors(n_neighbors=self.n_neighbors + 1,
n_jobs=self.n_jobs)
# Fit the data
nn_obj.fit(X)
self.nn_.fit(X)

# Loop over the classes other than the minority class
for key in self.stats_c_.keys():
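
Because the nearest-neighbours estimator is now created once in ``fit`` and only (re)fitted here, any ``KNeighborsMixin`` object can be swapped in, including approximate ones. A hypothetical sketch using ``LSHForest`` (its availability and accuracy depend on the scikit-learn version; not part of the diff)::

    from sklearn.datasets import make_classification
    from sklearn.neighbors import LSHForest
    from imblearn.under_sampling import EditedNearestNeighbours

    X, y = make_classification(n_samples=2000, weights=[0.9, 0.1],
                               random_state=0)

    # Approximate nearest neighbours; remember the extra self-neighbour.
    lshf = LSHForest(n_estimators=20, n_candidates=100, n_neighbors=4,
                     random_state=0)
    enn = EditedNearestNeighbours(n_neighbors=lshf, random_state=0)
    X_res, y_res = enn.fit_sample(X, y)
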
@@ -166,8 +204,8 @@ def _sample(self, X, y):
sub_samples_y = y[y == key]

# Find the NN for the current class
nnhood_idx = nn_obj.kneighbors(sub_samples_x,
return_distance=False)[:, 1:]
nnhood_idx = self.nn_.kneighbors(sub_samples_x,
return_distance=False)[:, 1:]

# Get the labels of the neighbouring samples from their indices
nnhood_label = y[nnhood_idx]
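
The ``[:, 1:]`` slice in the ``kneighbors`` call above drops the first neighbour of every query point: ``self.nn_`` is fitted on the full ``X`` and then queried with samples drawn from that same ``X``, so each sample comes back as its own nearest neighbour at distance zero. A standalone illustration (not from the repository)::

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    X = np.array([[0.0], [0.1], [5.0], [5.1]])
    nn = NearestNeighbors(n_neighbors=3).fit(X)
    idx = nn.kneighbors(X, return_distance=False)
    print(idx[:, 0])    # [0 1 2 3] -> every sample finds itself first
    print(idx[:, 1:])   # the neighbours actually used by the selection rule
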
@@ -233,9 +271,13 @@ class RepeatedEditedNearestNeighbours(BaseMulticlassSampler):
NOTE: size_ngh is deprecated from 0.2 and will be replaced in 0.4
Use ``n_neighbors`` instead.
n_neighbors : int, optional (default=3)
n_neighbors : int or KNeighborsMixin object, optional (default=3)
Size of the neighbourhood to consider to compute the average
distance to the minority point samples.
distance to the minority point samples. Alternatively, a
KNeighborsMixin object implementing a `kneighbors` method can be
provided. Currently, `NearestNeighbors` (an exact NN algorithm) and
`LSHForest` (an approximate NN algorithm) are the two such estimators
provided by scikit-learn.
kind_sel : str, optional (default='all')
Strategy to use in order to exclude samples.
@@ -301,20 +343,23 @@ class RepeatedEditedNearestNeighbours(BaseMulticlassSampler):
def __init__(self, return_indices=False, random_state=None,
size_ngh=None, n_neighbors=3, max_iter=100, kind_sel='all',
n_jobs=-1):
super(RepeatedEditedNearestNeighbours, self).__init__()
super(RepeatedEditedNearestNeighbours, self).__init__(
random_state=random_state)
self.return_indices = return_indices
self.random_state = random_state
self.size_ngh = size_ngh
self.n_neighbors = n_neighbors
self.kind_sel = kind_sel
self.n_jobs = n_jobs
self.max_iter = max_iter
self.enn_ = EditedNearestNeighbours(
return_indices=self.return_indices,
random_state=self.random_state,
n_neighbors=self.n_neighbors,
kind_sel=self.kind_sel,
n_jobs=self.n_jobs)

def _validate_estimator(self):
"""Private function to create the NN estimator"""

self.enn_ = EditedNearestNeighbours(return_indices=self.return_indices,
random_state=self.random_state,
n_neighbors=self.n_neighbors,
kind_sel=self.kind_sel,
n_jobs=self.n_jobs)

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.
@@ -333,7 +378,11 @@ def fit(self, X, y):
Return self.
"""

super(RepeatedEditedNearestNeighbours, self).fit(X, y)

self._validate_estimator()

self.enn_.fit(X, y)

return self
@@ -466,9 +515,13 @@ class AllKNN(BaseMulticlassSampler):
NOTE: size_ngh is deprecated from 0.2 and will be replaced in 0.4
Use ``n_neighbors`` instead.
n_neighbors : int, optional (default=3)
n_neighbors : int or KNeighborsMixin object, optional (default=3)
Size of the neighbourhood to consider to compute the average
distance to the minority point samples.
distance to the minority point samples. Alternatively, a
KNeighborsMixin object implementing a `kneighbors` method can be
provided. Currently, `NearestNeighbors` (an exact NN algorithm) and
`LSHForest` (an approximate NN algorithm) are the two such estimators
provided by scikit-learn.
kind_sel : str, optional (default='all')
Strategy to use in order to exclude samples.
@@ -529,19 +582,21 @@ class AllKNN(BaseMulticlassSampler):

def __init__(self, return_indices=False, random_state=None,
size_ngh=None, n_neighbors=3, kind_sel='all', n_jobs=-1):
super(AllKNN, self).__init__()
super(AllKNN, self).__init__(random_state=random_state)
self.return_indices = return_indices
self.random_state = random_state
self.size_ngh = size_ngh
self.n_neighbors = n_neighbors
self.kind_sel = kind_sel
self.n_jobs = n_jobs
self.enn_ = EditedNearestNeighbours(
return_indices=self.return_indices,
random_state=self.random_state,
n_neighbors=self.n_neighbors,
kind_sel=self.kind_sel,
n_jobs=self.n_jobs)

def _validate_estimator(self):
"""Private function to create the NN estimator"""

self.enn_ = EditedNearestNeighbours(return_indices=self.return_indices,
random_state=self.random_state,
n_neighbors=self.n_neighbors,
kind_sel=self.kind_sel,
n_jobs=self.n_jobs)

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.
@@ -561,6 +616,9 @@ def fit(self, X, y):
"""
super(AllKNN, self).fit(X, y)

self._validate_estimator()

self.enn_.fit(X, y)

return self
@@ -598,10 +656,10 @@ def _sample(self, X, y):
if self.return_indices:
idx_under = np.arange(X.shape[0], dtype=int)

for curr_size_ngh in range(1, self.n_neighbors + 1):
self.logger.debug('Apply ENN size_ngh #%s', curr_size_ngh)
for curr_size_ngh in range(1, self.enn_.nn_.n_neighbors):
self.logger.debug('Apply ENN n_neighbors #%s', curr_size_ngh)
# updating the ENN neighbourhood size
self.enn_.size_ngh = curr_size_ngh
self.enn_.n_neighbors = curr_size_ngh

if self.return_indices:
X_enn, y_enn, idx_enn = self.enn_.fit_sample(X_, y_)
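
The rewritten loop takes its upper bound from the fitted estimator rather than from ``self.n_neighbors``: with the default ``n_neighbors=3`` the internal ``NearestNeighbors`` is built with ``n_neighbors=4``, so ENN is applied with neighbourhood sizes 1, 2 and 3 exactly as before, while a user-supplied estimator contributes its own ``n_neighbors``. A hypothetical end-to-end call (not part of the diff)::

    from sklearn.datasets import make_classification
    from imblearn.under_sampling import AllKNN

    X, y = make_classification(n_samples=500, weights=[0.9, 0.1],
                               random_state=0)

    allknn = AllKNN(n_neighbors=3, random_state=0)
    # Internally applies ENN with n_neighbors = 1, 2, 3 in turn.
    X_res, y_res = allknn.fit_sample(X, y)
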