Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,4 @@ Functions
:toctree: generated/

pipeline.make_pipeline

20 changes: 18 additions & 2 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,26 @@ Version 0.2
Changelog
---------

- Added support for bumpversion.
- Added doctest in the documentation.
New features
~~~~~~~~~~~~

- Added the AllKNN under-sampling technique.

API changes summary
~~~~~~~~~~~~~~~~~~~

- Two base classes, :class:`BaseBinarySampler` and :class:`BaseMulticlassSampler`, have been created to handle the target type and raise a warning in case of abnormality.

Enhancement
~~~~~~~~~~~

- Added support for bumpversion.
- Validate the type of target in binary samplers. A warning is raised for the moment.

Documentation changes
~~~~~~~~~~~~~~~~~~~~~

- Added doctest in the documentation.

.. _changes_0_1:

Expand Down
2 changes: 2 additions & 0 deletions imblearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
Module which provides methods to under-sample a dataset.
under-sampling
Module which provides methods to under-sample a dataset.
pipeline
Module which allows creating pipelines with scikit-learn estimators.
"""

from .version import _check_module_dependencies, __version__
Expand Down
74 changes: 73 additions & 1 deletion imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from sklearn.base import BaseEstimator
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import type_of_target
from sklearn.externals import six

from six import string_types
Expand All @@ -27,7 +28,7 @@ class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
instead.
"""

_estimator_type = "sampler"
_estimator_type = 'sampler'

def __init__(self, ratio='auto'):
"""Initialize this object and its instance variables.
Expand Down Expand Up @@ -226,3 +227,74 @@ def __setstate__(self, dict):
logger = logging.getLogger(__name__)
self.__dict__.update(dict)
self.logger = logger


class BaseBinarySampler(six.with_metaclass(ABCMeta, SamplerMixin)):
    """Abstract base for samplers which expect a binary target.

    Warning: This class should not be used directly. Use derived classes
    instead.

    """

    def fit(self, X, y):
        """Run the parent ``fit`` and warn when the target is not binary.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : ndarray, shape (n_samples, )
            Corresponding label for each sample in X.

        Returns
        -------
        self : object,
            Return self.

        """
        # Delegate the actual fitting to the parent class first.
        super(BaseBinarySampler, self).fit(X, y)

        # A non-binary target only triggers a warning; fitting is not
        # aborted.
        target_kind = type_of_target(y)
        if target_kind != 'binary':
            warnings.warn('The target type should be binary.')

        return self


class BaseMulticlassSampler(six.with_metaclass(ABCMeta, SamplerMixin)):
    """Base class for all multiclass samplers.

    Warning: This class should not be used directly. Use derived classes
    instead.

    """

    def fit(self, X, y):
        """Run the parent ``fit`` and warn on an unsupported target type.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : ndarray, shape (n_samples, )
            Corresponding label for each sample in X.

        Returns
        -------
        self : object,
            Return self.

        """

        super(BaseMulticlassSampler, self).fit(X, y)

        # Inspect the target only once: the result is the same for both
        # comparisons, so there is no need to analyse ``y`` twice.
        target_type = type_of_target(y)

        # Only binary and multiclass targets are expected; anything else
        # triggers a warning but does not abort the fit.
        if target_type not in ('binary', 'multiclass'):
            warnings.warn('The target type should be binary or multiclass.')

        return self
4 changes: 2 additions & 2 deletions imblearn/combine/smote_enn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from ..over_sampling import SMOTE
from ..under_sampling import EditedNearestNeighbours
from ..base import SamplerMixin
from ..base import BaseBinarySampler


class SMOTEENN(SamplerMixin):
class SMOTEENN(BaseBinarySampler):
"""Class to perform over-sampling using SMOTE and cleaning using ENN.

Combine over- and under-sampling using SMOTE and Edited Nearest Neighbours.
Expand Down
4 changes: 2 additions & 2 deletions imblearn/combine/smote_tomek.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from ..over_sampling import SMOTE
from ..under_sampling import TomekLinks
from ..base import SamplerMixin
from ..base import BaseBinarySampler


class SMOTETomek(SamplerMixin):
class SMOTETomek(BaseBinarySampler):
"""Class to perform over-sampling using SMOTE and cleaning using
Tomek links.

Expand Down
15 changes: 15 additions & 0 deletions imblearn/combine/tests/test_smote_enn.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,18 @@ def test_sample_wrong_X():
sm.fit(X, Y)
assert_raises(RuntimeError, sm.sample, np.random.random((100, 40)),
np.array([0] * 50 + [1] * 50))


def test_senn_multiclass_error():
    """Check that fitting SMOTEENN on a non-binary target raises a
    warning."""

    non_binary_targets = (
        # continuous case
        np.linspace(0, 1, 5000),
        # multiclass case
        np.array([0] * 2000 + [1] * 2000 + [2] * 1000),
    )

    # A fresh sampler is built for each target, as each case is independent.
    for target in non_binary_targets:
        sampler = SMOTEENN(random_state=RND_SEED)
        assert_warns(UserWarning, sampler.fit, X, target)
15 changes: 15 additions & 0 deletions imblearn/combine/tests/test_smote_tomek.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,18 @@ def test_sample_wrong_X():
sm.fit(X, Y)
assert_raises(RuntimeError, sm.sample, np.random.random((100, 40)),
np.array([0] * 50 + [1] * 50))


def test_multiclass_error():
    """Check that fitting SMOTETomek on a non-binary target raises a
    warning."""

    non_binary_targets = (
        # continuous case
        np.linspace(0, 1, 5000),
        # multiclass case
        np.array([0] * 2000 + [1] * 2000 + [2] * 1000),
    )

    # A fresh sampler is built for each target, as each case is independent.
    for target in non_binary_targets:
        sampler = SMOTETomek(random_state=RND_SEED)
        assert_warns(UserWarning, sampler.fit, X, target)
5 changes: 3 additions & 2 deletions imblearn/ensemble/balance_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from sklearn.utils import check_random_state

from ..base import SamplerMixin
from ..base import BaseBinarySampler


# Names of the estimator kinds known to this module; presumably the accepted
# values of BalanceCascade's `classifier` parameter -- TODO confirm.
ESTIMATOR_KIND = ('knn', 'decision-tree', 'random-forest', 'adaboost',
'gradient-boosting', 'linear-svm')


class BalanceCascade(SamplerMixin):
class BalanceCascade(BaseBinarySampler):
"""Create an ensemble of balanced sets by iteratively under-sampling the
imbalanced dataset using an estimator.

Expand Down Expand Up @@ -100,6 +100,7 @@ class BalanceCascade(SamplerMixin):
April 2009.

"""

def __init__(self, ratio='auto', return_indices=False, random_state=None,
n_max_subset=None, classifier='knn', bootstrap=True,
**kwargs):
Expand Down
6 changes: 4 additions & 2 deletions imblearn/ensemble/easy_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import numpy as np

from ..base import SamplerMixin
from ..base import BaseMulticlassSampler
from ..under_sampling import RandomUnderSampler


class EasyEnsemble(SamplerMixin):
class EasyEnsemble(BaseMulticlassSampler):
"""Create an ensemble sets by iteratively applying random under-sampling.

This method iteratively select a random subset and make an ensemble of the
Expand Down Expand Up @@ -56,6 +56,8 @@ class EasyEnsemble(SamplerMixin):
-----
The method is described in [1]_.

This method supports multiclass target type.

Examples
--------

Expand Down
15 changes: 15 additions & 0 deletions imblearn/ensemble/tests/test_balance_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,18 @@ def test_sample_wrong_X():
bc.fit(X, Y)
assert_raises(RuntimeError, bc.sample, np.random.random((100, 40)),
np.array([0] * 50 + [1] * 50))


def test_multiclass_error():
    """Check that fitting BalanceCascade on a non-binary target raises a
    warning."""

    non_binary_targets = (
        # continuous case
        np.linspace(0, 1, 5000),
        # multiclass case
        np.array([0] * 2000 + [1] * 2000 + [2] * 1000),
    )

    # A fresh sampler is built for each target, as each case is independent.
    for target in non_binary_targets:
        sampler = BalanceCascade(random_state=RND_SEED)
        assert_warns(UserWarning, sampler.fit, X, target)
30 changes: 30 additions & 0 deletions imblearn/ensemble/tests/test_easy_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from sklearn.datasets import make_classification
from sklearn.utils.estimator_checks import check_estimator

from collections import Counter

from imblearn.ensemble import EasyEnsemble

# Generate a global dataset to use
Expand Down Expand Up @@ -170,3 +172,31 @@ def test_sample_wrong_X():
ee.fit(X, Y)
assert_raises(RuntimeError, ee.sample, np.random.random((100, 40)),
np.array([0] * 50 + [1] * 50))


def test_continuous_error():
    """Check that fitting EasyEnsemble on a continuous target raises a
    warning."""

    continuous_target = np.linspace(0, 1, 5000)
    sampler = EasyEnsemble(random_state=RND_SEED)

    # The fit is expected to emit a UserWarning rather than to fail.
    assert_warns(UserWarning, sampler.fit, X, continuous_target)


def test_multiclass_fit_sample():
    """Check ``fit_sample`` when the target holds three classes."""

    # Relabel the first 1000 samples of the global target to obtain a
    # third class.
    multiclass_y = Y.copy()
    multiclass_y[:1000] = 2

    # Resample the data
    sampler = EasyEnsemble(random_state=RND_SEED)
    X_resampled, y_resampled = sampler.fit_sample(X, multiclass_y)

    # Each of the three labels is expected 400 times in the first element
    # of the resampled target.
    class_counts = Counter(y_resampled[0])
    for label in (0, 1, 2):
        assert_equal(class_counts[label], 400)
4 changes: 2 additions & 2 deletions imblearn/over_sampling/adasyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_random_state

from ..base import SamplerMixin
from ..base import BaseBinarySampler


class ADASYN(SamplerMixin):
class ADASYN(BaseBinarySampler):

"""Perform over-sampling using ADASYN.

Expand Down
4 changes: 2 additions & 2 deletions imblearn/over_sampling/random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

from sklearn.utils import check_random_state

from ..base import SamplerMixin
from ..base import BaseMulticlassSampler


class RandomOverSampler(SamplerMixin):
class RandomOverSampler(BaseMulticlassSampler):

"""Class to perform random over-sampling.

Expand Down
7 changes: 5 additions & 2 deletions imblearn/over_sampling/smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
from __future__ import print_function
from __future__ import division

import warnings

import numpy as np

from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.utils.multiclass import type_of_target
from sklearn.neighbors import NearestNeighbors
from sklearn.svm import SVC

from ..base import SamplerMixin
from ..base import BaseBinarySampler


# Names of the SMOTE variants; presumably the accepted values of SMOTE's
# `kind` parameter -- TODO confirm against the class.
SMOTE_KIND = ('regular', 'borderline1', 'borderline2', 'svm')


class SMOTE(SamplerMixin):
class SMOTE(BaseBinarySampler):

"""Class to perform over-sampling using SMOTE.

Expand Down
15 changes: 15 additions & 0 deletions imblearn/over_sampling/tests/test_adasyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,18 @@ def test_sample_wrong_X():
ada.fit(X, Y)
assert_raises(RuntimeError, ada.sample, np.random.random((100, 40)),
np.array([0] * 50 + [1] * 50))


def test_multiclass_error():
    """Check that fitting ADASYN on a non-binary target raises a
    warning."""

    non_binary_targets = (
        # continuous case
        np.linspace(0, 1, 5000),
        # multiclass case
        np.array([0] * 2000 + [1] * 2000 + [2] * 1000),
    )

    # A fresh sampler is built for each target, as each case is independent.
    for target in non_binary_targets:
        sampler = ADASYN(random_state=RND_SEED)
        assert_warns(UserWarning, sampler.fit, X, target)
Loading