[MRG] EHN handling sparse matrices whenever possible (#316)
* EHN POC sparse handling for RandomUnderSampler

* EHN support sparse ENN

* iter

* EHN sparse indexing IHT

* EHN sparse support nearmiss

* EHN support sparse matrices for NCR

* EHN support sparse Tomek and OSS

* EHN support sparsity for CNN

* EHN support sparse for SMOTE

* EHN support sparse adasyn

* EHN support sparsity for sombine methods

* EHN support sparsity BC

* DOC update docstring

* DOC fix example topic classification

* FIX fix test and class clustercentroids

* TST add common test

* TST add ensemble

* TST use allclose

* TST install conda with ubuntu container

* TST increase tolerance

* TST increase tolerance

* TST test all versions NearMiss and SMOTE

* TST set the algorithm of KMeans

* DOC add entry in user guide

* DOC add entry sparse for CC

* DOC whatsnew entry

* DOC fix api

* TST adapt pytest

* DOC update user guide

* address comments

* TST remove the last assert_regex
glemaitre committed Aug 24, 2017
1 parent 488a0e8 commit cddf39b
Showing 33 changed files with 682 additions and 550 deletions.
2 changes: 1 addition & 1 deletion appveyor.yml
@@ -36,7 +36,7 @@ install:
- "python -c \"import struct; print(struct.calcsize('P') * 8)\""

# Installed prebuilt dependencies from conda
- "conda install pip numpy scipy scikit-learn=0.19.0 nose wheel matplotlib -y -q"
- "conda install pip numpy scipy scikit-learn=0.19.0 pandas nose wheel matplotlib -y -q"

# Install other nilearn dependencies
- "pip install coverage nose-timer pytest pytest-cov"
4 changes: 2 additions & 2 deletions build_tools/travis/install.sh
@@ -38,7 +38,7 @@ if [[ "$DISTRIB" == "conda" ]]; then
# provided versions
conda create -n testenv --yes python=$PYTHON_VERSION pip
source activate testenv
conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION
conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION pandas

if [[ "$SKLEARN_VERSION" == "master" ]]; then
conda install --yes cython
@@ -59,7 +59,7 @@ elif [[ "$DISTRIB" == "ubuntu" ]]; then
# Create a new virtualenv using system site packages for python, numpy
virtualenv --system-site-packages testvenv
source testvenv/bin/activate
pip install scikit-learn nose nose-timer pytest pytest-cov codecov
pip install scikit-learn pandas nose nose-timer pytest pytest-cov codecov

fi

12 changes: 6 additions & 6 deletions doc/combine.rst
@@ -29,18 +29,18 @@ than their former samplers::
... n_clusters_per_class=1,
... weights=[0.01, 0.05, 0.94],
... class_sep=0.8, random_state=0)
>>> print(Counter(y))
Counter({2: 4674, 1: 262, 0: 64})
>>> print(sorted(Counter(y).items()))
[(0, 64), (1, 262), (2, 4674)]
>>> from imblearn.combine import SMOTEENN
>>> smote_enn = SMOTEENN(random_state=0)
>>> X_resampled, y_resampled = smote_enn.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({1: 4381, 0: 4060, 2: 3502})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 4060), (1, 4381), (2, 3502)]
>>> from imblearn.combine import SMOTETomek
>>> smote_tomek = SMOTETomek(random_state=0)
>>> X_resampled, y_resampled = smote_tomek.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({1: 4566, 0: 4499, 2: 4413})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 4499), (1, 4566), (2, 4413)]

We can also see in the example below that :class:`SMOTEENN` tends to clean more
noisy samples than :class:`SMOTETomek`.
17 changes: 8 additions & 9 deletions doc/datasets/index.rst
@@ -85,8 +85,8 @@ A specific data set can be selected as::
>>> ecoli = fetch_datasets()['ecoli']
>>> ecoli.data.shape
(336, 7)
>>> print(Counter((ecoli.target)))
Counter({-1: 301, 1: 35})
>>> print(sorted(Counter(ecoli.target).items()))
[(-1, 301), (1, 35)]

.. _make_imbalanced:

@@ -104,16 +104,16 @@ samples in the class::
>>> iris = load_iris()
>>> ratio = {0: 20, 1: 30, 2: 40}
>>> X_imb, y_imb = make_imbalance(iris.data, iris.target, ratio=ratio)
>>> Counter(y_imb)
Counter({2: 40, 1: 30, 0: 20})
>>> sorted(Counter(y_imb).items())
[(0, 20), (1, 30), (2, 40)]

Note that all samples of a class are passed through if the class is not mentioned
in the dictionary::

>>> ratio = {0: 10}
>>> X_imb, y_imb = make_imbalance(iris.data, iris.target, ratio=ratio)
>>> Counter(y_imb)
Counter({1: 50, 2: 50, 0: 10})
>>> sorted(Counter(y_imb).items())
[(0, 10), (1, 50), (2, 50)]

Instead of a dictionary, a function can be defined and passed directly to
``ratio``::
@@ -126,9 +126,8 @@ Instead of a dictionary, a function can be defined and directly pass to
... return target_stats
>>> X_imb, y_imb = make_imbalance(iris.data, iris.target,
... ratio=ratio_multiplier)
>>> Counter(y_imb)
Counter({2: 47, 1: 35, 0: 25})

>>> sorted(Counter(y_imb).items())
[(0, 25), (1, 35), (2, 47)]

See :ref:`sphx_glr_auto_examples_datasets_plot_make_imbalance.py` and
:ref:`sphx_glr_auto_examples_plot_ratio_usage.py`.
12 changes: 6 additions & 6 deletions doc/ensemble.rst
@@ -19,15 +19,15 @@ under-sampling the original set::
... n_clusters_per_class=1,
... weights=[0.01, 0.05, 0.94],
... class_sep=0.8, random_state=0)
>>> print(Counter(y))
Counter({2: 4674, 1: 262, 0: 64})
>>> print(sorted(Counter(y).items()))
[(0, 64), (1, 262), (2, 4674)]
>>> from imblearn.ensemble import EasyEnsemble
>>> ee = EasyEnsemble(random_state=0, n_subsets=10)
>>> X_resampled, y_resampled = ee.fit_sample(X, y)
>>> print(X_resampled.shape)
(10, 192, 2)
>>> print(Counter(y_resampled[0])) # doctest: +SKIP
Counter({0: 64, 1: 64, 2: 64})
>>> print(sorted(Counter(y_resampled[0]).items()))
[(0, 64), (1, 64), (2, 64)]
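
Each of the balanced subsets can then feed one member of an ensemble. A minimal
sketch of this idea (an illustrative addition, not part of the original guide,
assuming a simple ``DecisionTreeClassifier``)::

>>> from sklearn.tree import DecisionTreeClassifier
>>> classifiers = [DecisionTreeClassifier(random_state=0).fit(X_sub, y_sub)
...                for X_sub, y_sub in zip(X_resampled, y_resampled)]
>>> len(classifiers)
10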

:class:`EasyEnsemble` has two important parameters: (i) ``n_subsets`` controls the
number of subsets to return and (ii) ``replacement`` to randomly sample
@@ -48,8 +48,8 @@ parameter ``n_max_subset`` and an additional bootstraping can be activated with
>>> X_resampled, y_resampled = bc.fit_sample(X, y)
>>> print(X_resampled.shape)
(4, 192, 2)
>>> print(Counter(y_resampled[0])) # doctest: +SKIP
Counter({2: 64, 1: 64, 0: 64})
>>> print(sorted(Counter(y_resampled[0]).items()))
[(0, 64), (1, 64), (2, 64)]

See
:ref:`sphx_glr_auto_examples_ensemble_plot_easy_ensemble.py` and
61 changes: 61 additions & 0 deletions doc/introduction.rst
@@ -0,0 +1,61 @@
.. _introduction:

============
Introduction
============

.. _api_imblearn:

APIs of imbalanced-learn samplers
----------------------------------

The available samplers follow the scikit-learn API, using a base estimator and adding a sampling functionality through the ``sample`` method::

:Estimator:

The base object implements a ``fit`` method to learn from data::

estimator = obj.fit(data, targets)

:Sampler:

To resample a data set, each sampler implements::

data_resampled, targets_resampled = obj.sample(data, targets)

Fitting and sampling can also be done in one step::

data_resampled, targets_resampled = obj.fit_sample(data, targets)

Imbalanced-learn samplers accept the same inputs as scikit-learn estimators:

* ``data``: array-like (2-D list, pandas.DataFrame, numpy.array) or sparse
matrices;
* ``targets``: array-like (1-D list, pandas.Series, numpy.array).
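
For instance, a ``pandas.DataFrame`` and a ``pandas.Series`` can be passed directly
(an illustrative sketch, not part of the original guide; the resampled output is
returned as ``numpy`` arrays)::

>>> import pandas as pd
>>> from imblearn.under_sampling import RandomUnderSampler
>>> data = pd.DataFrame({'feat_1': [0., 1., 2., 3., 4., 5.],
...                      'feat_2': [1., 0., 1., 0., 1., 0.]})
>>> targets = pd.Series([0, 0, 0, 0, 1, 1])
>>> rus = RandomUnderSampler(random_state=0)
>>> data_resampled, targets_resampled = rus.fit_sample(data, targets)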

.. topic:: Sparse input

For sparse input the data is **converted to the Compressed Sparse Row (CSR)
representation** (see ``scipy.sparse.csr_matrix``) before being fed to the
sampler. To avoid unnecessary memory copies, it is recommended to choose the
CSR representation upstream.
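
A short sketch of the sparse path (an illustrative addition, not part of the
original guide; it assumes ``RandomUnderSampler`` and a small synthetic data set)::

>>> from scipy import sparse
>>> from sklearn.datasets import make_classification
>>> from imblearn.under_sampling import RandomUnderSampler
>>> X, y = make_classification(n_samples=100, weights=[0.9, 0.1], random_state=0)
>>> X_sparse = sparse.csr_matrix(X)  # providing CSR upstream avoids an extra conversion
>>> X_res, y_res = RandomUnderSampler(random_state=0).fit_sample(X_sparse, y)
>>> sparse.issparse(X_res)  # sparsity is preserved by the sampler
True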

.. _problem_statement:

Problem statement regarding imbalanced data sets
------------------------------------------------

The learning phase and the subsequent prediction of machine learning algorithms
can be affected by the problem of imbalanced data sets. The balancing issue
corresponds to the difference in the number of samples between the different
classes. We illustrate the effect of training a linear SVM classifier with
different levels of class balancing.

.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_001.png
:target: ./auto_examples/over-sampling/plot_comparison_over_sampling.html
:scale: 60
:align: center

As expected, the decision function of the linear SVM is highly impacted. With a
greater imbalance ratio, the decision function favors the class with the larger
number of samples, usually referred to as the majority class.
18 changes: 9 additions & 9 deletions doc/over_sampling.rst
@@ -29,15 +29,15 @@ randomly sampling with replacement the current available samples. The
>>> ros = RandomOverSampler(random_state=0)
>>> X_resampled, y_resampled = ros.fit_sample(X, y)
>>> from collections import Counter
>>> print(Counter(y_resampled)) # doctest: +SKIP
Counter({2: 4674, 1: 4674, 0: 4674})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 4674), (1, 4674), (2, 4674)]

The augmented data set should be used instead of the original data set to train
a classifier::

>>> from sklearn.svm import LinearSVC
>>> clf = LinearSVC()
>>> clf.fit(X_resampled, y_resampled) # doctest: +ELLIPSIS
>>> clf.fit(X_resampled, y_resampled) # doctest : +ELLIPSIS
LinearSVC(...)

In the figure below, we compare the decision functions of a classifier trained
@@ -67,12 +67,12 @@ can be used in the same manner::

>>> from imblearn.over_sampling import SMOTE, ADASYN
>>> X_resampled, y_resampled = SMOTE().fit_sample(X, y)
>>> print(Counter(y_resampled)) # doctest: +SKIP
Counter({2: 4674, 1: 4674, 0: 4674})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 4674), (1, 4674), (2, 4674)]
>>> clf_smote = LinearSVC().fit(X_resampled, y_resampled)
>>> X_resampled, y_resampled = ADASYN().fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({2: 4674, 0: 4673, 1: 4662})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 4673), (1, 4662), (2, 4674)]
>>> clf_adasyn = LinearSVC().fit(X_resampled, y_resampled)

The figure below illustrates the major differences between the different over-sampling
@@ -132,8 +132,8 @@ available: (i) ``'borderline1'``, (ii) ``'borderline2'``, and (iii) ``'svm'``::

>>> from imblearn.over_sampling import SMOTE, ADASYN
>>> X_resampled, y_resampled = SMOTE(kind='borderline1').fit_sample(X, y)
>>> print(Counter(y_resampled)) # doctest: +SKIP
Counter({2: 4674, 1: 4674, 0: 4674})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 4674), (1, 4674), (2, 4674)]

See :ref:`sphx_glr_auto_examples_over-sampling_plot_comparison_over_sampling.py`
to see a comparison between the different over-sampling methods.
20 changes: 0 additions & 20 deletions doc/problem_statement.rst

This file was deleted.

54 changes: 30 additions & 24 deletions doc/under_sampling.rst
@@ -28,13 +28,13 @@ K-means method instead of the original samples::
... n_clusters_per_class=1,
... weights=[0.01, 0.05, 0.94],
... class_sep=0.8, random_state=0)
>>> print(Counter(y))
Counter({2: 4674, 1: 262, 0: 64})
>>> print(sorted(Counter(y).items()))
[(0, 64), (1, 262), (2, 4674)]
>>> from imblearn.under_sampling import ClusterCentroids
>>> cc = ClusterCentroids(random_state=0)
>>> X_resampled, y_resampled = cc.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({0: 64, 1: 64, 2: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 64), (2, 64)]

The figure below illustrates such under-sampling.

@@ -49,6 +49,12 @@ your data are grouped into clusters. In addition, the number of centroids
should be set such that the under-sampled clusters are representative of the
original one.

.. warning::

:class:`ClusterCentroids` supports sparse matrices. However, the new samples
generated are not guaranteed to be sparse. Therefore, even if the resulting
matrix is stored in a sparse format, the algorithm will be inefficient in this regard.
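
A hypothetical continuation of the example above illustrating this point (not
part of the original guide)::

>>> from scipy import sparse
>>> X_sparse = sparse.csr_matrix(X)
>>> cc = ClusterCentroids(random_state=0)
>>> X_res, y_res = cc.fit_sample(X_sparse, y)
>>> sparse.issparse(X_res)  # returned in a sparse format...
True
>>> # ...but the centroids are dense averages, so few entries are actually zero.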

See :ref:`sphx_glr_auto_examples_under-sampling_plot_cluster_centroids.py` and
:ref:`sphx_glr_auto_examples_under-sampling_plot_comparison_under_sampling.py`.

@@ -77,8 +83,8 @@ randomly selecting a subset of data for the targeted classes::
>>> from imblearn.under_sampling import RandomUnderSampler
>>> rus = RandomUnderSampler(random_state=0)
>>> X_resampled, y_resampled = rus.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({0: 64, 1: 64, 2: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 64), (2, 64)]

.. image:: ./auto_examples/under-sampling/images/sphx_glr_plot_comparison_under_sampling_002.png
:target: ./auto_examples/under-sampling/plot_comparison_under_sampling.html
@@ -108,8 +114,8 @@ be selected with the parameter ``version``::
>>> from imblearn.under_sampling import NearMiss
>>> nm1 = NearMiss(random_state=0, version=1)
>>> X_resampled_nm1, y_resampled = nm1.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({0: 64, 1: 64, 2: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 64), (2, 64)]

As stated in the next section, the :class:`NearMiss` heuristic rules are
based on a nearest neighbors algorithm. Therefore, the parameters ``n_neighbors``
@@ -238,13 +244,13 @@ available: (i) the majority (i.e., ``kind_sel='mode'``) or (ii) all (i.e.,
``kind_sel='all'``) the nearest neighbors have to belong to the same class as
the sample inspected to keep it in the dataset::

>>> Counter(y)
Counter({2: 4674, 1: 262, 0: 64})
>>> sorted(Counter(y).items())
[(0, 64), (1, 262), (2, 4674)]
>>> from imblearn.under_sampling import EditedNearestNeighbours
>>> enn = EditedNearestNeighbours(random_state=0)
>>> X_resampled, y_resampled = enn.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({2: 4568, 1: 213, 0: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 213), (2, 4568)]

The parameter ``n_neighbors`` accepts a classifier subclassed from scikit-learn's
``KNeighborsMixin`` to find the nearest neighbors and make
@@ -257,8 +263,8 @@ Generally, repeating the algorithm will delete more data::
>>> from imblearn.under_sampling import RepeatedEditedNearestNeighbours
>>> renn = RepeatedEditedNearestNeighbours(random_state=0)
>>> X_resampled, y_resampled = renn.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({2: 4551, 1: 208, 0: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 208), (2, 4551)]

:class:`AllKNN` differs from the previous
:class:`RepeatedEditedNearestNeighbours` since the number of neighbors of the
@@ -267,8 +273,8 @@ internal nearest neighbors algorithm is increased at each iteration::
>>> from imblearn.under_sampling import AllKNN
>>> allknn = AllKNN(random_state=0)
>>> X_resampled, y_resampled = allknn.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({2: 4601, 1: 220, 0: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 220), (2, 4601)]

In the example below, it can be seen that the three algorithms have a similar
impact, cleaning noisy samples next to the boundaries of the classes.
@@ -305,8 +311,8 @@ The :class:`CondensedNearestNeighbour` can be used in the following manner::
>>> from imblearn.under_sampling import CondensedNearestNeighbour
>>> cnn = CondensedNearestNeighbour(random_state=0)
>>> X_resampled, y_resampled = cnn.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({2: 116, 0: 64, 1: 25})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 24), (2, 115)]

However, as illustrated in the figure below, :class:`CondensedNearestNeighbour`
is sensitive to noise and will add noisy samples.
@@ -320,8 +326,8 @@ used as::
>>> from imblearn.under_sampling import OneSidedSelection
>>> oss = OneSidedSelection(random_state=0)
>>> X_resampled, y_resampled = oss.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({2: 4403, 1: 174, 0: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 174), (2, 4403)]

Our implementation allows setting the number of seeds initially put in the set
:math:`C` through the parameter ``n_seeds_S``.
@@ -334,8 +340,8 @@ neighbors classifier. The class can be used as::
>>> from imblearn.under_sampling import NeighbourhoodCleaningRule
>>> ncr = NeighbourhoodCleaningRule(random_state=0)
>>> X_resampled, y_resampled = ncr.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({2: 4666, 1: 234, 0: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 234), (2, 4666)]

.. image:: ./auto_examples/under-sampling/images/sphx_glr_plot_comparison_under_sampling_005.png
:target: ./auto_examples/under-sampling/plot_comparison_under_sampling.html
@@ -362,8 +368,8 @@ removed. The class can be used as::
>>> iht = InstanceHardnessThreshold(random_state=0,
... estimator=LogisticRegression())
>>> X_resampled, y_resampled = iht.fit_sample(X, y)
>>> print(Counter(y_resampled))
Counter({0: 64, 1: 64, 2: 64})
>>> print(sorted(Counter(y_resampled).items()))
[(0, 64), (1, 64), (2, 64)]

This class has two important parameters. ``estimator`` will accept any
scikit-learn classifier which has a ``predict_proba`` method. The classifier