[MRG] Add the option to pass a KNeighborsMixin instead of simple n_neigbors #182

Merged
merged 7 commits on Nov 2, 2016
2 changes: 1 addition & 1 deletion doc/whats_new.rst
@@ -56,7 +56,7 @@ API changes summary
- Provide estimators instead of parameters in :class:`combine.SMOTEENN` and :class:`combine.SMOTETomek`. Therefore, the list of parameters have been deprecated. By `Guillaume Lemaitre`_ and `Christos Aridas`_.
- `k` has been deprecated in :class:`over_sampling.ADASYN`. Use `n_neighbors` instead. By `Guillaume Lemaitre`_.
- `k` and `m` have been deprecated in :class:`over_sampling.SMOTE`. Use `k_neighbors` and `m_neighbors` instead. By `Guillaume Lemaitre`_.

- `n_neighbors` accepts a `KNeighborsMixin`-based object in :class:`under_sampling.EditedNearestNeighbours`, :class:`under_sampling.CondensedNearestNeighbour`, :class:`under_sampling.NeighbourhoodCleaningRule`, :class:`under_sampling.RepeatedEditedNearestNeighbours`, and :class:`under_sampling.AllKNN`. By `Guillaume Lemaitre`_.

Documentation changes
~~~~~~~~~~~~~~~~~~~~~
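The changelog entry above describes accepting either an int or a `KNeighborsMixin`-based estimator for `n_neighbors`. A minimal stdlib-only sketch of that int-or-estimator dispatch (note: `SimpleNN` and `resolve_nn` are hypothetical stand-ins for illustration, not imbalanced-learn's actual code):

```python
# Illustrative sketch of the "int or NN estimator" dispatch described in
# the changelog. SimpleNN stands in for sklearn.neighbors.NearestNeighbors.
class SimpleNN:
    def __init__(self, n_neighbors=5):
        self.n_neighbors = n_neighbors

    def kneighbors(self, X):
        raise NotImplementedError  # a real estimator returns (dist, ind)


def resolve_nn(n_neighbors, additional_neighbor=0):
    """Build an NN estimator from an int, or pass one through untouched."""
    if isinstance(n_neighbors, int):
        # +1 is typically needed because each sample is its own nearest
        # neighbour and must be excluded later.
        return SimpleNN(n_neighbors=n_neighbors + additional_neighbor)
    if hasattr(n_neighbors, 'kneighbors'):
        return n_neighbors
    raise ValueError('`n_neighbors` has to be either an int or an '
                     'object with a `kneighbors` method.')


print(resolve_nn(5, additional_neighbor=1).n_neighbors)  # 6
custom = SimpleNN(n_neighbors=3)
print(resolve_nn(custom) is custom)  # True
```

Passing the estimator through untouched is what lets users substitute an approximate search such as `LSHForest` without the sampler knowing the difference.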
54 changes: 46 additions & 8 deletions imblearn/over_sampling/adasyn.py
@@ -5,6 +5,7 @@

import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors.base import KNeighborsMixin
from sklearn.utils import check_random_state

from ..base import BaseBinarySampler
@@ -39,6 +40,10 @@ class ADASYN(BaseBinarySampler):

n_neighbors : int or object, optional (default=5)
Number of nearest neighbours to use to construct synthetic samples.
A KNeighborsMixin object implementing a `kneighbors` method can also
be provided. Currently, `NearestNeighbors` (i.e., exact NN algorithm)
and `LSHForest` (i.e., approximate NN algorithm) are the two such
estimators provided in scikit-learn.

n_jobs : int, optional (default=1)
Number of threads to run the algorithm when it is possible.
@@ -96,9 +101,42 @@ def __init__(self, ratio='auto', random_state=None, k=None, n_neighbors=5,
self.k = k
self.n_neighbors = n_neighbors
self.n_jobs = n_jobs
self.nearest_neighbour = NearestNeighbors(
n_neighbors=self.n_neighbors + 1,
n_jobs=self.n_jobs)

def _validate_estimator(self):
"""Private function to create the NN estimator"""

if isinstance(self.n_neighbors, int):
self.nn_ = NearestNeighbors(n_neighbors=self.n_neighbors + 1,
n_jobs=self.n_jobs)
elif isinstance(self.n_neighbors, KNeighborsMixin):
self.nn_ = self.n_neighbors
else:
raise ValueError('`n_neighbors` has to be either an int or an'
' instance of KNeighborsMixin.')

def fit(self, X, y):
"""Find the classes statistics before performing sampling.

Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.

y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.

Returns
-------
self : object
Return self.

"""

super(ADASYN, self).fit(X, y)

self._validate_estimator()

return self
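The diff moves the `NearestNeighbors` construction out of `__init__` into `_validate_estimator`, which `fit` calls. This matches scikit-learn's convention that `__init__` only stores parameters verbatim, so estimators stay cloneable and `set_params` keeps working. A minimal sketch of that deferred-validation pattern (`ToySampler` is a hypothetical name, not the library's class):

```python
# Sketch of the deferred-validation pattern used in the diff: __init__
# stores parameters untouched; fit() builds derived state, conventionally
# named with a trailing underscore (here nn_).
class ToySampler:
    def __init__(self, n_neighbors=5):
        self.n_neighbors = n_neighbors  # no validation, no derived objects

    def _validate_estimator(self):
        if not isinstance(self.n_neighbors, int):
            raise ValueError('`n_neighbors` has to be an int in this sketch.')
        # Stand-in for NearestNeighbors(n_neighbors=self.n_neighbors + 1)
        self.nn_ = {'n_neighbors': self.n_neighbors + 1}

    def fit(self, X, y):
        self._validate_estimator()  # validation happens at fit time
        return self


sampler = ToySampler(n_neighbors=5)
sampler.n_neighbors = 3             # params can still be changed freely
sampler.fit([[0.0]], [0])
print(sampler.nn_['n_neighbors'])   # 4
```

Because nothing is derived in `__init__`, changing `n_neighbors` before `fit` takes effect without re-creating the object.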

def _sample(self, X, y):
"""Resample the dataset.
@@ -140,18 +178,18 @@ def _sample(self, X, y):

# Print if verbose is true
self.logger.debug('Finding the %s nearest neighbours ...',
self.n_neighbors)
self.nn_.n_neighbors - 1)

# Look for k-th nearest neighbours, excluding, of course, the
# point itself.
self.nearest_neighbour.fit(X)
self.nn_.fit(X)

# Get the distance to the NN
_, ind_nn = self.nearest_neighbour.kneighbors(X_min)
_, ind_nn = self.nn_.kneighbors(X_min)

# Compute the ratio of majority samples next to minority samples
ratio_nn = (np.sum(y[ind_nn[:, 1:]] == self.maj_c_, axis=1) /
self.n_neighbors)
(self.nn_.n_neighbors - 1))
# Check that we found at least some neighbours belonging to the
# majority class
if not np.sum(ratio_nn):
@@ -169,7 +207,7 @@
for x_i, x_i_nn, num_sample_i in zip(X_min, ind_nn, num_samples_nn):

# Pick-up the neighbors wanted
nn_zs = random_state.randint(1, high=self.n_neighbors + 1,
nn_zs = random_state.randint(1, high=self.nn_.n_neighbors,
size=num_sample_i)

# Create a new sample
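The loop at the end of `_sample` picks a random neighbour `x_zi` for each minority sample `x_i` and interpolates between them, the SMOTE/ADASYN formula `x_new = x_i + step * (x_zi - x_i)` with `step` drawn uniformly from [0, 1). A stdlib-only sketch of that interpolation step (the diff itself does this with numpy arrays; `synthetic_sample` is an illustrative name):

```python
import random


def synthetic_sample(x_i, x_zi, rng):
    """ADASYN/SMOTE-style interpolation between a minority sample and one
    of its neighbours: x_new = x_i + step * (x_zi - x_i), step in [0, 1)."""
    step = rng.random()
    return [a + step * (b - a) for a, b in zip(x_i, x_zi)]


rng = random.Random(42)
x_new = synthetic_sample([0.0, 0.0], [1.0, 2.0], rng)
# every coordinate lies on the segment joining the two input points
print(all(0.0 <= v <= 2.0 for v in x_new))  # True
```

Since `step` never reaches 1, the synthetic point is always strictly inside the segment from `x_i` toward `x_zi`, which is why ADASYN densifies the minority region rather than duplicating neighbours.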