"""
===========================
Instance Hardness Threshold
===========================

An illustration of the instance hardness threshold method.

"""

print(__doc__)
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.svm import SVC

from unbalanced_dataset.under_sampling import InstanceHardnessThreshold

sns.set()

# Colours shared by every panel of the figure
almost_black = '#262626'
palette = sns.color_palette()


def scatter_classes(ax, points, labels):
    """Scatter the 2D ``points`` of class 0 and class 1 on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes on which the two classes are drawn.

    points : ndarray, shape (n_samples, 2)
        Coordinates of the samples to draw.

    labels : ndarray, shape (n_samples, )
        Class label (0 or 1) of each row of ``points``.

    """
    styles = ((0, "Class #0", palette[0]),
              (1, "Class #1", palette[2]))
    for cls, name, colour in styles:
        selected = labels == cls
        ax.scatter(points[selected, 0], points[selected, 1], label=name,
                   alpha=0.5, edgecolor=almost_black, facecolor=colour,
                   linewidth=0.15)


# Generate the dataset
X, y = make_classification(n_classes=2, class_sep=1., weights=[0.05, 0.95],
                           n_informative=3, n_redundant=1, flip_y=0,
                           n_features=20, n_clusters_per_class=1,
                           n_samples=5000, random_state=10)

# Project on two components, only for visualisation purposes
pca = PCA(n_components=2)
X_vis = pca.fit_transform(X)

# 2 x 2 grid of panels, flattened so that they can be walked in order
f, axs = plt.subplots(2, 2)
axs = axs.ravel()

# First panel: the original, unbalanced dataset
scatter_classes(axs[0], X_vis, y)
axs[0].set_title('Original set')

# Remaining panels: one under-sampling per ratio
for ax, ratio in zip(axs[1:], [0.1, 0.3, 0.5]):
    estimator = SVC(probability=True)
    iht = InstanceHardnessThreshold(estimator, ratio=ratio)
    X_res, y_res = iht.fit_transform(X, y)

    scatter_classes(ax, pca.transform(X_res), y_res)
    ax.set_title('Instance Hardness Threshold ({})'.format(ratio))

plt.show()
"""Class to perform under-sampling based on the instance hardness
threshold."""
from __future__ import print_function
from __future__ import division

import numpy as np

from collections import Counter

from sklearn.utils import check_X_y
from sklearn.cross_validation import StratifiedKFold

from .under_sampler import UnderSampler


class InstanceHardnessThreshold(UnderSampler):
    """Class to perform under-sampling based on the instance hardness
    threshold.

    Parameters
    ----------
    estimator : sklearn classifier
        Classifier used to estimate the instance hardness of the samples.
        It must expose a ``predict_proba`` method.

    ratio : str or float, optional (default='auto')
        If 'auto', the ratio will be defined automatically to balance
        the dataset. Otherwise, the ratio will correspond to the number
        of samples in the minority class over the number of samples
        in the majority class.

    return_indices : bool, optional (default=False)
        Whether or not to return the indices of the samples which have
        been selected.

    cv : int, optional (default=5)
        Number of folds to be used when estimating samples' instance hardness.

    random_state : int or None, optional (default=None)
        Seed for random number generation.

    verbose : bool, optional (default=True)
        Whether or not to print information about the processing.

    n_jobs : int, optional (default=-1)
        The number of threads to open when it is possible. The value is
        stored on the instance; the code of this class does not use it.

    Attributes
    ----------
    ratio_ : str or float, optional (default='auto')
        If 'auto', the ratio will be defined automatically to balance
        the dataset. Otherwise, the ratio will correspond to the number
        of samples in the minority class over the number of samples
        in the majority class.

    rs_ : int or None, optional (default=None)
        Seed for random number generation.

    min_c_ : str or int
        The identifier of the minority class.

    maj_c_ : str or int
        The identifier of the majority class.

    stats_c_ : dict of str/int : int
        A dictionary in which the number of occurrences of each class is
        reported.

    estimator_ : sklearn classifier
        Classifier used to estimate the instance hardness of the samples.

    cv : int, optional (default=5)
        Number of folds used when estimating samples' instance hardness.

    n_jobs : int, optional (default=-1)
        Value given at construction; not used by the code of this class.

    Notes
    -----
    The method is based on [1]_.

    This class does not support multi-class.

    References
    ----------
    .. [1] D. Smith, Michael R., Tony Martinez, and Christophe Giraud-Carrier.
       "An instance level analysis of data complexity." Machine learning
       95.2 (2014): 225-256.

    """

    def __init__(self, estimator, ratio='auto', return_indices=False, cv=5,
                 random_state=None, verbose=True, n_jobs=-1):
        """Initialisation of Instance Hardness Threshold object.

        Parameters
        ----------
        estimator : sklearn classifier
            Classifier used to estimate the instance hardness of the
            samples. It must expose a ``predict_proba`` method.

        ratio : str or float, optional (default='auto')
            If 'auto', the ratio will be defined automatically to balance
            the dataset. Otherwise, the ratio will correspond to the number
            of samples in the minority class over the number of samples
            in the majority class.

        return_indices : bool, optional (default=False)
            Whether or not to return the indices of the samples which have
            been selected.

        cv : int, optional (default=5)
            Number of folds to be used when estimating samples' instance
            hardness.

        random_state : int or None, optional (default=None)
            Seed for random number generation.

        verbose : bool, optional (default=True)
            Whether or not to print information about the processing.

        n_jobs : int, optional (default=-1)
            The number of threads to open when it is possible. Only stored
            on the instance here.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``estimator`` does not have a ``predict_proba`` attribute.

        """
        super(InstanceHardnessThreshold, self).__init__(
            ratio=ratio,
            return_indices=return_indices,
            random_state=random_state,
            verbose=verbose)

        # The hardness of a sample is read from the predicted probability
        # of its true class, hence the hard requirement on predict_proba.
        if not hasattr(estimator, 'predict_proba'):
            raise ValueError('Estimator does not have predict_proba method.')
        self.estimator_ = estimator

        self.cv = cv
        self.n_jobs = n_jobs

    def fit(self, X, y):
        """Find the classes statistics before performing the sampling.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : ndarray, shape (n_samples, )
            Corresponding label for each sample in X.

        Returns
        -------
        self : object,
            Return self.

        """
        # Check the consistency of X and y
        X, y = check_X_y(X, y)

        super(InstanceHardnessThreshold, self).fit(X, y)

        return self

    def transform(self, X, y):
        """Resample the dataset.

        The probability that the estimator assigns to the true class of
        each sample is estimated by stratified cross-validation. Every
        sample of the minority class is kept; samples of the other class
        are kept when this probability reaches a percentile-based
        threshold. The estimator is refitted in place on the training
        split of each fold.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : ndarray, shape (n_samples, )
            Corresponding label for each sample in X.

        Returns
        -------
        X_resampled : ndarray, shape (n_samples_new, n_features)
            The array containing the resampled data.

        y_resampled : ndarray, shape (n_samples_new)
            The corresponding label of `X_resampled`

        idx_under : ndarray, shape (n_samples_new, )
            If `return_indices` is `True`, an array containing the integer
            indices (rows of `X`) of the samples which have been selected.

        """
        # Check the consistency of X and y
        X, y = check_X_y(X, y)

        super(InstanceHardnessThreshold, self).transform(X, y)

        skf = StratifiedKFold(y, n_folds=self.cv, shuffle=False,
                              random_state=self.rs_)

        probabilities = np.zeros(y.shape[0], dtype=float)

        for train_index, test_index in skf:
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]

            self.estimator_.fit(X_train, y_train)

            probs = self.estimator_.predict_proba(X_test)
            classes = self.estimator_.classes_
            # Each row of the mask flags the column of the true class of
            # the sample, so the boolean indexing yields one probability
            # per test sample, in the order of the test samples.
            true_class = y_test[:, np.newaxis] == classes
            probabilities[test_index] = probs[true_class]

        n_minority = self.stats_c_[self.min_c_]
        n_majority = self.stats_c_[self.maj_c_]

        # Compute the number of samples to keep in the majority class
        if self.ratio_ == 'auto':
            num_samples = n_minority
        else:
            num_samples = int(n_minority / self.ratio_)

        # A ratio below the current class ratio asks for more majority
        # samples than available, which would lead to a negative
        # percentile; keep every majority sample in that case.
        num_samples = min(num_samples, n_majority)

        # Find the percentile corresponding to the top num_samples
        threshold = np.percentile(
            probabilities[y != self.min_c_],
            (1. - (num_samples / n_majority)) * 100.)

        mask = np.logical_or(probabilities >= threshold, y == self.min_c_)

        # Sample the data
        X_resampled = X[mask]
        y_resampled = y[mask]

        if self.verbose:
            print("Under-sampling performed: {}".format(Counter(y_resampled)))

        # If we need to offer support for the indices
        if self.return_indices:
            idx_under = np.nonzero(mask)[0]
            return X_resampled, y_resampled, idx_under
        else:
            return X_resampled, y_resampled
"""Test the module instance hardness threshold."""
from __future__ import print_function

import os

import numpy as np
from numpy.testing import assert_raises
from numpy.testing import assert_equal
from numpy.testing import assert_array_equal

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.svm import SVC

from unbalanced_dataset.under_sampling import InstanceHardnessThreshold


# Generate a global dataset to use
RND_SEED = 0
X, Y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
                           n_informative=3, n_redundant=1, flip_y=0,
                           n_features=20, n_clusters_per_class=1,
                           n_samples=5000, random_state=RND_SEED)
ESTIMATOR = GradientBoostingClassifier()


def _load_expected(*filenames):
    """Load the given ``.npy`` files from the ``data`` directory located
    next to this test module and return them as a list of arrays."""
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'data')
    return [np.load(os.path.join(data_dir, name)) for name in filenames]


def test_iht_bad_ratio():
    """Check that an unsupported ratio raises a ValueError"""

    # In order: a negative ratio, a ratio greater than 1, an unknown
    # string and a list, which is not supported
    for ratio in (-1.0, 100.0, 'rnd', [.5, .5]):
        assert_raises(ValueError, InstanceHardnessThreshold, ESTIMATOR,
                      ratio=ratio)


def test_iht_estimator_no_proba():
    """Check that an estimator without predict_proba raises a
    ValueError"""

    # SVC is built without probability estimates
    assert_raises(ValueError, InstanceHardnessThreshold, SVC(), ratio=0.5,
                  random_state=RND_SEED)


def test_iht_init():
    """Check the state of the object right after its creation"""

    verbose = True
    iht = InstanceHardnessThreshold(ESTIMATOR, ratio='auto',
                                    random_state=RND_SEED,
                                    verbose=verbose)

    assert_equal(iht.rs_, RND_SEED)
    assert_equal(iht.verbose, verbose)
    # No class statistics are available before fitting
    assert_equal(iht.min_c_, None)
    assert_equal(iht.maj_c_, None)
    assert_equal(iht.stats_c_, {})


def test_iht_fit_single_class():
    """Check that fitting a single-class target raises a RuntimeError"""

    iht = InstanceHardnessThreshold(ESTIMATOR, random_state=RND_SEED)
    # A target in which every sample belongs to the same class
    y_single_class = np.zeros((X.shape[0], ))
    assert_raises(RuntimeError, iht.fit, X, y_single_class)


def test_iht_fit():
    """Check the class statistics computed by fit"""

    iht = InstanceHardnessThreshold(ESTIMATOR, random_state=RND_SEED)
    iht.fit(X, Y)

    # Class 0 is the minority class, class 1 the majority class
    assert_equal(iht.min_c_, 0)
    assert_equal(iht.maj_c_, 1)
    assert_equal(iht.stats_c_[0], 500)
    assert_equal(iht.stats_c_[1], 4500)


def test_iht_transform_wt_fit():
    """Check that calling transform before fit raises a RuntimeError"""

    iht = InstanceHardnessThreshold(ESTIMATOR, random_state=RND_SEED)
    assert_raises(RuntimeError, iht.transform, X, Y)


def test_iht_fit_transform():
    """Check fit_transform against the stored resampled data"""

    iht = InstanceHardnessThreshold(ESTIMATOR, random_state=RND_SEED)
    X_resampled, y_resampled = iht.fit_transform(X, Y)

    X_gt, y_gt = _load_expected('iht_x.npy', 'iht_y.npy')
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)


def test_iht_fit_transform_with_indices():
    """Check fit_transform when the selected indices are returned"""

    iht = InstanceHardnessThreshold(ESTIMATOR, return_indices=True,
                                    random_state=RND_SEED)
    X_resampled, y_resampled, idx_under = iht.fit_transform(X, Y)

    X_gt, y_gt, idx_gt = _load_expected('iht_x.npy', 'iht_y.npy',
                                        'iht_idx.npy')
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)
    assert_array_equal(idx_under, idx_gt)


def test_iht_fit_transform_half():
    """Check fit_transform with a 0.5 ratio against the stored data"""

    iht = InstanceHardnessThreshold(ESTIMATOR, ratio=0.5,
                                    random_state=RND_SEED)
    X_resampled, y_resampled = iht.fit_transform(X, Y)

    X_gt, y_gt = _load_expected('iht_x_05.npy', 'iht_y_05.npy')
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)
{1}').format( module_name, - install_info or 'Please install it properly to use unbalanced_dataset.') + install_info or 'Please install it properly to use' + ' unbalanced_dataset.') exc.args += (user_friendly_info,) raise @@ -73,8 +74,8 @@ def _import_module_with_version_check( if version_too_old: message = ( 'A {module_name} version of at least {minimum_version} ' - 'is required to use unbalanced_dataset. {module_version} was found. ' - 'Please upgrade {module_name}').format( + 'is required to use unbalanced_dataset. {module_version} was ' + 'found. Please upgrade {module_name}').format( module_name=module_name, minimum_version=minimum_version, module_version=module_version)