Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions imblearn/under_sampling/instance_hardness_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,30 @@
from __future__ import division, print_function

import warnings

from collections import Counter

import numpy as np

from six import string_types
import sklearn
from sklearn.base import ClassifierMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import StratifiedKFold

from six import string_types

from ..base import BaseBinarySampler


def _get_cv_splits(X, y, cv, random_state):
    """Build the stratified K-fold train/test splits for ``X`` and ``y``.

    Works with both the current ``sklearn.model_selection`` API and the
    legacy ``sklearn.cross_validation`` API.

    Parameters
    ----------
    X : array-like
        Samples to split. Only used by the ``sklearn.model_selection`` API.

    y : array-like
        Targets used to stratify the folds.

    cv : int
        Number of folds.

    random_state : int, RandomState instance or None
        Kept for interface compatibility. It is deliberately not forwarded
        to ``StratifiedKFold``: the folds are built with ``shuffle=False``,
        where the seed has no effect, and recent scikit-learn releases raise
        a ``ValueError`` when a seed is given without shuffling.

    Returns
    -------
    cv_iterator : iterable
        Iterable yielding ``(train_indices, test_indices)`` pairs.

    """
    # Detect the available API by attempting the import. Checking
    # ``hasattr(sklearn, 'model_selection')`` is unreliable because a
    # submodule is only exposed as a package attribute once something has
    # imported it.
    try:
        from sklearn.model_selection import StratifiedKFold
    except ImportError:
        # Legacy API: the targets are given to the constructor and the
        # resulting object is itself the iterable of splits.
        from sklearn.cross_validation import StratifiedKFold
        return StratifiedKFold(y, n_folds=cv, shuffle=False)

    # Current API: the splitter is configured first, then applied to the data.
    return StratifiedKFold(n_splits=cv, shuffle=False).split(X, y)


class InstanceHardnessThreshold(BaseBinarySampler):
"""Class to perform under-sampling based on the instance hardness
threshold.
Expand Down Expand Up @@ -225,8 +235,7 @@ def _sample(self, X, y):
"""

# Create the different folds
skf = StratifiedKFold(y, n_folds=self.cv, shuffle=False,
random_state=self.random_state)
skf = _get_cv_splits(X, y, self.cv, self.random_state)

probabilities = np.zeros(y.shape[0], dtype=float)

Expand Down