[MRG] Refactor ratio to pick up any class #290

Merged
merged 53 commits into from Jun 10, 2017
Changes from 46 commits
Commits
4f87df5
EHN enable multiclass ratio handling
glemaitre May 8, 2017
4e528ec
FIX simplify call to dictionary
glemaitre May 8, 2017
4996bbb
FIX RUS done
glemaitre May 8, 2017
3f5dffa
FIX Refactor ADASYN
glemaitre May 9, 2017
87a630d
FIX partial
glemaitre May 9, 2017
9825317
FIX refactor SMOTE
glemaitre May 11, 2017
b7021fc
FIX refactor SMOTE
glemaitre May 11, 2017
8a010fb
DOC add proper docstring
glemaitre May 11, 2017
f573af3
PEP8
glemaitre May 11, 2017
a85dcff
FIX ClusterCentroids
glemaitre May 11, 2017
cabf202
FIX refactor IHT
glemaitre May 11, 2017
d2539b1
FIX Nearmiss refactoring
glemaitre May 12, 2017
0ee50c1
FIX tomek links refactor
glemaitre May 12, 2017
118af0e
FIX refactor OSS
glemaitre May 12, 2017
96b102e
FIX NCR refactoring
glemaitre May 13, 2017
8ecfd88
FIX refactor combined methods with Pipeline
glemaitre May 13, 2017
d4b9c3e
FIX combine method targetting all classes when cleaning
glemaitre May 13, 2017
f5303ca
FIX balance cascade refactoring
glemaitre May 13, 2017
0e93429
EHN add the possibility to add a dict for ratio
glemaitre May 14, 2017
38fe8ca
TST add test for check_ratio
glemaitre May 14, 2017
7f076cf
TST add test for float
glemaitre May 14, 2017
d89c12d
FIX/TST adapt common test
glemaitre May 14, 2017
039420b
TST fix IHT tests
glemaitre May 14, 2017
a31c0e1
TST fix NCR
glemaitre May 14, 2017
02be5f5
FIX combine test
glemaitre May 14, 2017
6fba010
TST fix balance
glemaitre May 14, 2017
f2d541a
FIX doctest
glemaitre May 14, 2017
c4d74e2
FIX doctest
glemaitre May 14, 2017
a1ba5f7
FIX solve the pickle issue
glemaitre May 14, 2017
012d3db
FIX remove comments
glemaitre May 14, 2017
d7fb9ae
TST add test for NCR
glemaitre May 14, 2017
b691064
TST add knn balance cascade
glemaitre May 14, 2017
ecf241f
EHN add callable option for the ratio
glemaitre May 14, 2017
a048c54
DOC make doc cleaner
glemaitre May 14, 2017
acc98e8
FIX/DOC remove useless comments and clean doc
glemaitre May 14, 2017
18bc464
DEP deprecation of ratio as float
glemaitre May 14, 2017
b9d0e5a
EHN add base class for cleaning methods
glemaitre May 15, 2017
3f3fb16
TST add common test for multi class
glemaitre May 16, 2017
9bcec08
MAINT downgrade sphinx for the moment
glemaitre May 16, 2017
38e1806
TST/EHN add test for the ratio and specific ratio for cleaning sampling
glemaitre May 16, 2017
5aa541b
EHN remove redundant code
glemaitre May 18, 2017
7a45207
FIX warning
glemaitre May 18, 2017
15c158c
Remove useless base class
glemaitre May 19, 2017
834de2f
MAINT add christos back to some file
glemaitre May 19, 2017
db22871
EHN rename test and add a comment
glemaitre May 19, 2017
38708e6
DOC add hash_X_y in the API
glemaitre May 19, 2017
8e59009
[MRG] Incorporate chkoar remarks (#6)
massich May 22, 2017
18f726c
[MRG] Remove the init in base class (#7)
massich May 22, 2017
19f423a
EHN doc
glemaitre May 26, 2017
9f3cfbd
Merge branch 'is/121' of github.com:glemaitre/imbalanced-learn into i…
glemaitre May 26, 2017
e4892ac
FIX add extension for sphinx
glemaitre May 26, 2017
7aef770
EHN make deprecatin great again
glemaitre May 30, 2017
c5ab8b9
EHN Improve SMOTE and ADASYN
glemaitre May 30, 2017
2 changes: 1 addition & 1 deletion circle.yml
@@ -23,7 +23,7 @@ dependencies:
- sudo apt-get install build-essential python-dev python-setuptools
# install numpy first as it is a compile time dependency for other packages
- pip install --upgrade numpy
- pip install --upgrade scipy matplotlib setuptools nose coverage sphinx pillow sphinx-gallery sphinx_rtd_theme
- pip install --upgrade scipy matplotlib setuptools nose coverage pillow sphinx-gallery sphinx_rtd_theme sphinx==1.5.6
# Installing required packages for `make -C doc check command` to work.
- sudo -E apt-get -yq update
- sudo -E apt-get -yq --no-install-suggests --no-install-recommends --force-yes install dvipng texlive-latex-base texlive-latex-extra
3 changes: 3 additions & 0 deletions doc/api.rst
@@ -174,3 +174,6 @@ Utilities
:toctree: generated/

utils.estimator_checks.check_estimator
utils.check_neighbors_object
utils.check_ratio
utils.hash_X_y
10 changes: 10 additions & 0 deletions doc/whats_new.rst
@@ -14,6 +14,8 @@ Bug fixes

- Fixed a bug in :class:`under_sampling.NearMiss` version 3. The
indices returned were wrong. By `Guillaume Lemaitre`_.
- fixed bug for :class:`ensemble.BalanceCascade` and :class:`combine.SMOTEENN`
and :class:`SMOTETomek`. By `Guillaume Lemaitre`_.

New features
~~~~~~~~~~~~
@@ -32,6 +34,7 @@ Enhancement
`Guillaume Lemaitre`_
- Remove seaborn dependence and improve the examples. By `Guillaume
Lemaitre`_.
- adapt all classes to multi-class resampling. By `Guillaume Lemaitre`_

API changes summary
~~~~~~~~~~~~~~~~~~~
@@ -45,7 +48,14 @@ API changes summary
- move the under-sampling methods in `prototype_selection` and
`prototype_generation` submodule to make a clearer distinction. By
`Guillaume Lemaitre`_.
- change `ratio` such that it can adapt to multiple class problems. By
`Guillaume Lemaitre`_.

Deprecation
~~~~~~~~~~~

- deprecate the use of float as ratio in favor of dictionary, string, or
callable. By `Guillaume Lemaitre`_.
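
To make the changelog entry above concrete, here is a rough sketch of how a `ratio` given as a string, dictionary, or callable might be resolved into per-class sample counts for over-sampling. The helper name `resolve_ratio` and the semantics noted in its comments are assumptions for illustration only. This is not the `utils.check_ratio` implementation added in this PR.

```python
from collections import Counter


def resolve_ratio(ratio, y):
    """Hypothetical helper: return {class_label: n_samples_to_generate}."""
    counts = Counter(y)
    if callable(ratio):
        # Assumption: a callable receives the labels and returns the dict.
        return dict(ratio(y))
    if isinstance(ratio, dict):
        unknown = set(ratio) - set(counts)
        if unknown:
            raise ValueError(
                "Unknown classes in ratio: {}".format(sorted(unknown)))
        # Assumption: values are the desired final counts per class.
        return {label: target - counts[label]
                for label, target in ratio.items()}
    if ratio == 'auto':
        # Assumption: bring every smaller class up to the majority size.
        n_majority = max(counts.values())
        return {label: n_majority - n
                for label, n in counts.items() if n < n_majority}
    raise ValueError("ratio must be 'auto', a dict, or a callable.")


# Three classes with 10, 4 and 7 samples.
y = [0] * 10 + [1] * 4 + [2] * 7
auto_targets = resolve_ratio('auto', y)
dict_targets = resolve_ratio({1: 8}, y)
callable_targets = resolve_ratio(lambda labels: {2: 3}, y)
```

Unlike a single float, a dictionary or callable can name any class, which is what lets the same parameter describe multi-class problems.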

.. _changes_0_2:

298 changes: 36 additions & 262 deletions imblearn/base.py
@@ -8,97 +8,67 @@

import logging
import warnings
from numbers import Real
from abc import ABCMeta, abstractmethod
from collections import Counter

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.externals import six
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import check_is_fitted

from .utils import hash_X_y


class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
"""Mixin class for samplers with abstact method.
"""Mixin class for samplers with abstract method.

Warning: This class should not be used directly. Use the derived classes
instead.
"""

_estimator_type = 'sampler'

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.

Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.

y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.

Returns
-------
self : object,
Return self.
def _validate_size_ngh_deprecation(self):
Member

I think that we could make all these _vaildate_methods(estimator) as module level functions to keep the class clean. So, it will be clean for any newcomer.

Member Author

Yep I agree with that
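
One possible shape for the suggestion in this thread is sketched below: the `size_ngh` deprecation check written as a module-level function that takes the estimator as an argument. The function name `validate_size_ngh_deprecation` and the `DummySampler` class are hypothetical and exist only for this example. The PR as shown keeps these checks as methods.

```python
import warnings


def validate_size_ngh_deprecation(sampler):
    """Warn about `size_ngh` and copy its value to `n_neighbors`."""
    if getattr(sampler, 'size_ngh', None) is not None:
        warnings.warn('`size_ngh` will be replaced in version 0.4. Use'
                      ' `n_neighbors` instead.', DeprecationWarning)
        sampler.n_neighbors = sampler.size_ngh


class DummySampler(object):
    """Hypothetical stand-in for a sampler, used only in this example."""

    def __init__(self, size_ngh=None, n_neighbors=3):
        self.size_ngh = size_ngh
        self.n_neighbors = n_neighbors


sampler = DummySampler(size_ngh=5)
with warnings.catch_warnings(record=True) as caught:
    # Record every warning so the deprecation is visible in the example.
    warnings.simplefilter('always')
    validate_size_ngh_deprecation(sampler)
```

Written this way, the base class only needs to call the helper, and the helper can be tested without instantiating a real sampler.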

"Private function to warn about the deprecation about size_ngh."

"""
# Announce deprecation if necessary
if self.size_ngh is not None:
warnings.warn('`size_ngh` will be replaced in version 0.4. Use'
' `n_neighbors` instead.', DeprecationWarning)
self.n_neighbors = self.size_ngh

# Check the consistency of X and y
X, y = check_X_y(X, y)
def _validate_k_deprecation(self):
"""Private function to warn about deprecation of k in ADASYN"""
if self.k is not None:
warnings.warn('`k` will be replaced in version 0.4. Use'
' `n_neighbors` instead.', DeprecationWarning)
self.n_neighbors = self.k

self.min_c_ = None
self.maj_c_ = None
self.stats_c_ = {}
self.X_shape_ = None
def _validate_k_m_deprecation(self):
"""Private function to warn about deprecation of k in ADASYN"""
if self.k is not None:
warnings.warn('`k` will be replaced in version 0.4. Use'
' `k_neighbors` instead.', DeprecationWarning)
self.k_neighbors = self.k

if hasattr(self, 'ratio'):
self._validate_ratio()
if self.m is not None:
warnings.warn('`m` will be replaced in version 0.4. Use'
' `m_neighbors` instead.', DeprecationWarning)
self.m_neighbors = self.m

def _validate_deprecation(self):
if hasattr(self, 'size_ngh'):
self._validate_size_ngh_deprecation()
elif hasattr(self, 'k') and not hasattr(self, 'm'):
self._validate_k_deprecation()
elif hasattr(self, 'k') and hasattr(self, 'm'):
self._validate_k_m_deprecation()

self.logger.info('Compute classes statistics ...')

# Raise an error if there is only one class
if np.unique(y).size <= 1:
raise ValueError("Sampler can't balance when only one class is"
" present.")

# Store the size of X to check at sampling time if we have the
# same data
self.X_shape_ = X.shape

# Create a dictionary containing the class statistics
self.stats_c_ = Counter(y)

# Find the minority and majority classes
self.min_c_ = min(self.stats_c_, key=self.stats_c_.get)
self.maj_c_ = max(self.stats_c_, key=self.stats_c_.get)

self.logger.info('%s classes detected: %s',
np.unique(y).size, self.stats_c_)

# Check if the ratio provided at initialisation make sense
if isinstance(self.ratio, Real):
if self.ratio < (self.stats_c_[self.min_c_] /
self.stats_c_[self.maj_c_]):
raise RuntimeError('The ratio requested at initialisation'
' should be greater or equal than the'
' balancing ratio of the current data.'
' Got {} < {}.'.format(
self.ratio,
self.stats_c_[self.min_c_] /
self.stats_c_[self.maj_c_]))

return self
def _check_hash_X_y(self, X, y):
"""Private function to check that the X and y in fitting are the same
than in sampling."""
X_hash, y_hash = hash_X_y(X, y)
if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
raise RuntimeError("X and y need to be same array earlier fitted.")
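
The hash comparison above replaces the earlier check on `X.shape`. The sketch below shows why that matters: data with the same shape but different values is rejected. `hash_X_y_sketch` and `HashCheckingSampler` are hypothetical stand-ins built on `hashlib` and plain lists. The real `hash_X_y` utility is not shown in this diff and may work differently.

```python
import hashlib


def hash_X_y_sketch(X, y):
    """Hypothetical stand-in for `hash_X_y`: digest the contents of X and y."""
    x_hash = hashlib.sha256(repr(X).encode('utf-8')).hexdigest()
    y_hash = hashlib.sha256(repr(y).encode('utf-8')).hexdigest()
    return x_hash, y_hash


class HashCheckingSampler(object):
    """Hypothetical sampler that only remembers what it was fitted on."""

    def fit(self, X, y):
        # Store the digests computed at fit time.
        self.X_hash_, self.y_hash_ = hash_X_y_sketch(X, y)
        return self

    def check_same_data(self, X, y):
        # Recompute the digests and compare them with the stored ones.
        X_hash, y_hash = hash_X_y_sketch(X, y)
        if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
            raise RuntimeError(
                "X and y need to be same array earlier fitted.")


X = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
y = [0, 1, 1]
sampler = HashCheckingSampler().fit(X, y)
sampler.check_same_data(X, y)  # identical data passes silently

try:
    # Same shape as X, but the first row differs.
    sampler.check_same_data([[9.0, 9.0], [2.0, 3.0], [4.0, 5.0]], y)
    changed_data_rejected = False
except RuntimeError:
    changed_data_rejected = True
```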

def sample(self, X, y):
"""Resample the dataset.
@@ -124,25 +94,9 @@ def sample(self, X, y):
# Check the consistency of X and y
X, y = check_X_y(X, y)

# Check that the data have been fitted
check_is_fitted(self, 'stats_c_')

# Check if the size of the data is identical than at fitting
if X.shape != self.X_shape_:
raise RuntimeError('The data that you attempt to resample do not'
' seem to be the one earlier fitted. Use the'
' fitted data. Shape of data is {}, got {}'
' instead.'.format(X.shape, self.X_shape_))

if hasattr(self, 'ratio'):
self._validate_ratio()

if hasattr(self, 'size_ngh'):
self._validate_size_ngh_deprecation()
elif hasattr(self, 'k') and not hasattr(self, 'm'):
self._validate_k_deprecation()
elif hasattr(self, 'k') and hasattr(self, 'm'):
self._validate_k_m_deprecation()
self._validate_deprecation()
check_is_fitted(self, 'ratio_')
self._check_hash_X_y(X, y)

return self._sample(X, y)

@@ -169,56 +123,6 @@ def fit_sample(self, X, y):

return self.fit(X, y).sample(X, y)

def _validate_ratio(self):
# The ratio correspond to the number of samples in the minority class
# over the number of samples in the majority class. Thus, the ratio
# cannot be greater than 1.0
if isinstance(self.ratio, Real):
if self.ratio > 1:
raise ValueError('Ratio cannot be greater than one.'
' Got {}.'.format(self.ratio))
elif self.ratio <= 0:
raise ValueError('Ratio cannot be negative.'
' Got {}.'.format(self.ratio))

elif isinstance(self.ratio, six.string_types):
if self.ratio != 'auto':
raise ValueError("Unknown string for the parameter ratio."
" Got {} instead of 'auto'".format(
self.ratio))
else:
raise ValueError('Unknown parameter type for ratio.'
' Got {} instead of float or str'.format(
type(self.ratio)))

def _validate_size_ngh_deprecation(self):
"Private function to warn about the deprecation about size_ngh."

# Announce deprecation if necessary
if self.size_ngh is not None:
warnings.warn('`size_ngh` will be replaced in version 0.4. Use'
' `n_neighbors` instead.', DeprecationWarning)
self.n_neighbors = self.size_ngh

def _validate_k_deprecation(self):
"""Private function to warn about deprecation of k in ADASYN"""
if self.k is not None:
warnings.warn('`k` will be replaced in version 0.4. Use'
' `n_neighbors` instead.', DeprecationWarning)
self.n_neighbors = self.k

def _validate_k_m_deprecation(self):
"""Private function to warn about deprecation of k in ADASYN"""
if self.k is not None:
warnings.warn('`k` will be replaced in version 0.4. Use'
' `k_neighbors` instead.', DeprecationWarning)
self.k_neighbors = self.k

if self.m is not None:
warnings.warn('`m` will be replaced in version 0.4. Use'
' `m_neighbors` instead.', DeprecationWarning)
self.m_neighbors = self.m

@abstractmethod
def _sample(self, X, y):
"""Resample the dataset.
@@ -252,133 +156,3 @@ def __setstate__(self, dict):
logger = logging.getLogger(__name__)
self.__dict__.update(dict)
self.logger = logger


class BaseBinarySampler(six.with_metaclass(ABCMeta, SamplerMixin)):
"""Base class for all binary class sampler.

Warning: This class should not be used directly. Use derived classes
instead.

"""

def __init__(self, ratio='auto', random_state=None):
"""Initialize this object and its instance variables.

Parameters
----------
ratio : str or float, optional (default='auto')
If 'auto', the ratio will be defined automatically to balanced
the dataset. Otherwise, the ratio will corresponds to the number
of samples in the minority class over the the number of samples
in the majority class.

random_state : int, RandomState or None, optional (default=None)

- If int, random_state is the seed used by the random number
generator;
- If RandomState instance, random_state is the random number
generator;
- If None, the random number generator is the RandomState instance
used by np.random.

Returns
-------
None

"""
self.ratio = ratio
self.random_state = random_state
self.logger = logging.getLogger(__name__)

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.

Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.

y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.

Returns
-------
self : object,
Return self.

"""

super(BaseBinarySampler, self).fit(X, y)

# Check that the target type is binary
if not type_of_target(y) == 'binary':
warnings.simplefilter('always', UserWarning)
warnings.warn('The target type should be binary.')

return self


class BaseMulticlassSampler(six.with_metaclass(ABCMeta, SamplerMixin)):
"""Base class for all multiclass sampler.

Warning: This class should not be used directly. Use derived classes
instead.

"""
def __init__(self, ratio='auto', random_state=None):
"""Initialize this object and its instance variables.

Parameters
----------
ratio : str or float, optional (default='auto')
If 'auto', the ratio will be defined automatically to balanced
the dataset. Otherwise, the ratio will corresponds to the number
of samples in the minority class over the the number of samples
in the majority class.

random_state : int, RandomState or None, optional (default=None)

- If int, random_state is the seed used by the random number
generator;
- If RandomState instance, random_state is the random number
generator;
- If None, the random number generator is the RandomState instance
used by np.random.

Returns
-------
None

"""
self.ratio = ratio
self.random_state = random_state
self.logger = logging.getLogger(__name__)

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.

Parameters
----------
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.

y : ndarray, shape (n_samples, )
Corresponding label for each sample in X.

Returns
-------
self : object,
Return self.

"""

super(BaseMulticlassSampler, self).fit(X, y)

# Check that the target type is either binary or multiclass
if not (type_of_target(y) == 'binary' or
type_of_target(y) == 'multiclass'):
warnings.simplefilter('always', UserWarning)
warnings.warn('The target type should be binary or multiclass.')

return self