From d472a6a79c45af11f393440f0e2bd116ebb5a415 Mon Sep 17 00:00:00 2001 From: jarfa Date: Fri, 29 Jan 2016 13:35:30 -0500 Subject: [PATCH 1/2] much faster isotonic regression prediction (involved re-setting interpolation to linear) --- doc/whats_new.rst | 8 +++ sklearn/isotonic.py | 100 ++++++++++++++++++++++++--------- sklearn/tests/test_isotonic.py | 38 +++++++++++++ 3 files changed, 120 insertions(+), 26 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index fd991744b6919..047c7d8a4d791 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -99,6 +99,9 @@ Enhancements now accept score functions that take X, y as input and return only the scores. By `Nikolay Mayorov`_. + - Prediction of out-of-sample events with Isotonic Regression is now much + faster (over 1000x in tests with synthetic data). By `Jonathan Arfa`_. + Bug fixes ......... @@ -140,6 +143,9 @@ API changes summary - ``residual_metric`` has been deprecated in :class:`linear_model.RANSACRegressor`. Use ``loss`` instead. By `Manoj Kumar`_. + - Access to public attributes ``.X_`` and ``.y_`` has been deprecated in + :class:`isotonic.IsotonicRegression`. By `Jonathan Arfa`_. + .. _changes_0_17_1: @@ -4071,3 +4077,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson. .. _Andrea Bravi: https://github.com/AndreaBravi .. _Devashish Deshpande: https://github.com/dsquareindia + +.. 
_Jonathan Arfa: https://github.com/jarfa diff --git a/sklearn/isotonic.py b/sklearn/isotonic.py index 1a805d1625b63..01fd8cb1c1ce3 100644 --- a/sklearn/isotonic.py +++ b/sklearn/isotonic.py @@ -8,6 +8,7 @@ from scipy.stats import spearmanr from .base import BaseEstimator, TransformerMixin, RegressorMixin from .utils import as_float_array, check_array, check_consistent_length +from .utils import deprecated from .utils.fixes import astype from ._isotonic import _isotonic_regression, _make_unique import warnings @@ -193,12 +194,6 @@ class IsotonicRegression(BaseEstimator, TransformerMixin, RegressorMixin): Attributes ---------- - X_ : ndarray (n_samples, ) - A copy of the input X. - - y_ : ndarray (n_samples, ) - Isotonic fit of y. - X_min_ : float Minimum value of input array `X_` for left bound. @@ -234,6 +229,34 @@ def __init__(self, y_min=None, y_max=None, increasing=True, self.increasing = increasing self.out_of_bounds = out_of_bounds + @property + @deprecated("Attribute ``X_`` is deprecated in version 0.18 and will be" + " removed in version 0.20.") + def X_(self): + return self._X_ + + @X_.setter + def X_(self, value): + self._X_ = value + + @X_.deleter + def X_(self): + del self._X_ + + @property + @deprecated("Attribute ``y_`` is deprecated in version 0.18 and will" + " be removed in version 0.20.") + def y_(self): + return self._y_ + + @y_.setter + def y_(self, value): + self._y_ = value + + @y_.deleter + def y_(self): + del self._y_ + def _check_fit_data(self, X, y, sample_weight=None): if len(X.shape) != 1: raise ValueError("X should be a 1d array") @@ -252,10 +275,10 @@ def _build_f(self, X, y): # single y, constant prediction self.f_ = lambda x: y.repeat(x.shape) else: - self.f_ = interpolate.interp1d(X, y, kind='slinear', + self.f_ = interpolate.interp1d(X, y, kind='linear', bounds_error=bounds_error) - def _build_y(self, X, y, sample_weight): + def _build_y(self, X, y, sample_weight, trim_duplicates=True): """Build the y_ IsotonicRegression.""" 
check_consistent_length(X, y, sample_weight) X, y = [check_array(x, ensure_2d=False) for x in [X, y]] @@ -269,7 +292,8 @@ def _build_y(self, X, y, sample_weight): else: self.increasing_ = self.increasing - # If sample_weights is passed, removed zero-weight values and clean order + # If sample_weight is passed, remove zero-weight values and clean + # order if sample_weight is not None: sample_weight = check_array(sample_weight, ensure_2d=False) mask = sample_weight > 0 @@ -278,15 +302,37 @@ def _build_y(self, X, y, sample_weight): sample_weight = np.ones(len(y)) order = np.lexsort((y, X)) - order_inv = np.argsort(order) X, y, sample_weight = [astype(array[order], np.float64, copy=False) for array in [X, y, sample_weight]] - unique_X, unique_y, unique_sample_weight = _make_unique(X, y, sample_weight) - self.X_ = unique_X - self.y_ = isotonic_regression(unique_y, unique_sample_weight, self.y_min, - self.y_max, increasing=self.increasing_) + unique_X, unique_y, unique_sample_weight = _make_unique( + X, y, sample_weight) + + # Store _X_ and _y_ to maintain backward compat during the deprecation + # period of X_ and y_ + self._X_ = X = unique_X + self._y_ = y = isotonic_regression(unique_y, unique_sample_weight, + self.y_min, self.y_max, + increasing=self.increasing_) - return order_inv + # Handle the left and right bounds on X + self.X_min_, self.X_max_ = np.min(X), np.max(X) + + if trim_duplicates: + # Remove unnecessary points for faster prediction + keep_data = np.ones((len(y),), dtype=bool) + # Aside from the 1st and last point, remove points whose y values + # are equal to both the point before and the point after it. 
+ keep_data[1:-1] = np.logical_or( + np.not_equal(y[1:-1], y[:-2]), + np.not_equal(y[1:-1], y[2:]) + ) + return X[keep_data], y[keep_data] + else: + # The ability to turn off trim_duplicates is only used to make it + # easier to unit test that removing duplicates in y does not have + # any impact on the resulting interpolation function (besides + # prediction speed). + return X, y def fit(self, X, y, sample_weight=None): """Fit the model using X, y as training data. @@ -313,16 +359,18 @@ def fit(self, X, y, sample_weight=None): X is stored for future use, as `transform` needs X to interpolate new input data. """ - # Build y_ - self._build_y(X, y, sample_weight) - - # Handle the left and right bounds on X - self.X_min_ = np.min(self.X_) - self.X_max_ = np.max(self.X_) - - # Build f_ - self._build_f(self.X_, self.y_) - + # Transform y by running the isotonic regression algorithm and + # transform X accordingly. + X, y = self._build_y(X, y, sample_weight) + + # It is necessary to store the non-redundant part of the training set + # on the model to make it possible to support model persistence via + # the pickle module as the object built by scipy.interp1d is not + # picklable directly. + self._necessary_X_, self._necessary_y_ = X, y + + # Build the interpolation function + self._build_f(X, y) return self def transform(self, T): @@ -381,4 +429,4 @@ def __setstate__(self, state): We need to rebuild the interpolation function. """ self.__dict__.update(state) - self._build_f(self.X_, self.y_) + self._build_f(self._necessary_X_, self._necessary_y_) diff --git a/sklearn/tests/test_isotonic.py b/sklearn/tests/test_isotonic.py index c5ff5cb54e5ab..3a317c3ba99a5 100644 --- a/sklearn/tests/test_isotonic.py +++ b/sklearn/tests/test_isotonic.py @@ -346,3 +346,41 @@ def test_isotonic_zero_weight_loop(): # This will hang in failure case. 
regression.fit(x, y, sample_weight=w) + + +def test_fast_predict(): + # test that the faster prediction change doesn't + # affect out-of-sample predictions: + # https://github.com/scikit-learn/scikit-learn/pull/6206 + rng = np.random.RandomState(123) + n_samples = 10 ** 3 + # X values over the -10,10 range + X_train = 20.0 * rng.rand(n_samples) - 10 + y_train = np.less( + rng.rand(n_samples), + 1.0 / (1.0 + np.exp(-X_train)) + ).astype('int64') + + weights = rng.rand(n_samples) + # we also want to test that everything still works when some weights are 0 + weights[rng.rand(n_samples) < 0.1] = 0 + + slow_model = IsotonicRegression(y_min=0, y_max=1, out_of_bounds="clip") + fast_model = IsotonicRegression(y_min=0, y_max=1, out_of_bounds="clip") + + # Build interpolation function with ALL input data, not just the + # non-redundant subset. The following 2 lines are taken from the + # .fit() method, without removing unnecessary points + X_train_fit, y_train_fit = slow_model._build_y(X_train, y_train, + sample_weight=weights, + trim_duplicates=False) + slow_model._build_f(X_train_fit, y_train_fit) + + # fit with just the necessary data + fast_model.fit(X_train, y_train, sample_weight=weights) + + X_test = 20.0 * rng.rand(n_samples) - 10 + y_pred_slow = slow_model.predict(X_test) + y_pred_fast = fast_model.predict(X_test) + + assert_array_equal(y_pred_slow, y_pred_fast) From c694b9b004ff0f5adfeffacf888c42eb0717d2f6 Mon Sep 17 00:00:00 2001 From: jarfa Date: Thu, 4 Feb 2016 16:06:46 -0500 Subject: [PATCH 2/2] change to test_isotonic_regression_ties_min --- sklearn/tests/test_isotonic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tests/test_isotonic.py b/sklearn/tests/test_isotonic.py index 3a317c3ba99a5..42fa58253c58f 100644 --- a/sklearn/tests/test_isotonic.py +++ b/sklearn/tests/test_isotonic.py @@ -98,9 +98,9 @@ def test_isotonic_regression(): def test_isotonic_regression_ties_min(): # Setup examples with ties on minimum - x = [0, 1, 
1, 2, 3, 4, 5] - y = [0, 1, 2, 3, 4, 5, 6] - y_true = [0, 1.5, 1.5, 3, 4, 5, 6] + x = [1, 1, 2, 3, 4, 5] + y = [1, 2, 3, 4, 5, 6] + y_true = [1.5, 1.5, 3, 4, 5, 6] # Check that we get identical results for fit/transform and fit_transform ir = IsotonicRegression()