[MRG+1] Add predict_proba(X) and outlier handler for RadiusNeighborsClassifier #9597

Merged
merged 63 commits into from
Aug 7, 2019
Changes from 52 commits
236c766
add predict_proba method for RadiusNeighborsClassifier
webber26232 Aug 21, 2017
67a59cd
add warning
webber26232 Aug 21, 2017
b69eef1
DOC Add predict_proba in class description
webber26232 Aug 21, 2017
9ac27e1
Finish formats
webber26232 Aug 21, 2017
9b46e2a
Finish formats
webber26232 Aug 21, 2017
dc9069c
Add test, improve warning
webber26232 Aug 22, 2017
f30ed4f
Add test
webber26232 Aug 22, 2017
0982ee3
add outlier handler
webber26232 Aug 24, 2017
9565eac
format, 2.7 float divide
webber26232 Aug 24, 2017
5baa070
modify code length
webber26232 Aug 24, 2017
1ceab5a
indent
webber26232 Aug 24, 2017
cd1c25f
add _check_outlier_handler, prepare for regressor
webber26232 Aug 25, 2017
be801b6
format
webber26232 Aug 25, 2017
d6a2ff8
bug
webber26232 Aug 25, 2017
96cb19e
outlier
webber26232 Aug 29, 2017
4b480a5
random int -> randomly choose class labels from y
webber26232 Aug 31, 2017
6f6c96c
fix weights index and vector scalar bug
webber26232 Sep 17, 2017
32e7207
fix weights inlier index, change inlier addition to assign, get index…
webber26232 Sep 17, 2017
8bb0ecf
resolve regression conflict
webber26232 Jan 15, 2018
0740152
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
webber26232 Jan 15, 2018
326fc22
formatting regression, remove prior/uniform, add most_frequent, move …
webber26232 Feb 25, 2018
0804952
fix inheritance
webber26232 Feb 25, 2018
3964fa5
fix typo
webber26232 Feb 25, 2018
6a53c9f
move oultier handler into fit(), add outlier label lists for multi ou…
webber26232 Mar 11, 2018
610a722
_init_param to super().__init__()
webber26232 Mar 11, 2018
db5a29c
change None to "raise" in testing
webber26232 Mar 11, 2018
7f491db
format
webber26232 Mar 11, 2018
9880d51
format
webber26232 Mar 11, 2018
8dcb4f3
change back from "raise" to None, fix some typos
webber26232 Apr 24, 2018
e4e28fa
fix indent conflicts
webber26232 Apr 24, 2018
135f0c4
fix documentation, indices, add predict_proba tests
webber26232 May 5, 2018
d86e1dd
add a space in doc
webber26232 May 5, 2018
7f7a840
improve documentation
webber26232 May 8, 2018
40c41b4
add warning for zero probas, simplify outlier indexing, merge testing…
webber26232 May 12, 2018
e8e432e
Merge branch 'master' into RadNeiClfPredProb
webber26232 May 12, 2018
f7e5891
fix assert_warns
webber26232 May 12, 2018
c9dd3c2
Merge branch 'RadNeiClfPredProb' of github.com:webber26232/scikit-lea…
webber26232 May 12, 2018
9d37e4f
Merge branch 'master' into RadNeiClfPredProb
webber26232 Aug 3, 2018
d365c0a
Merge branch 'master' into RadNeiClfPredProb
webber26232 Sep 27, 2018
291d1d7
Merge branch 'master' into RadNeiClfPredProb
webber26232 Jul 3, 2019
9ff7130
Merge pull request #1 from scikit-learn/master
webber26232 Jul 6, 2019
ad96981
Merge branch 'RadNeiClfPredProb' of github.com:webber26232/scikit-learn
webber26232 Jul 14, 2019
b32b45c
change predict() implementaion, add docs in whats_new
webber26232 Jul 21, 2019
1a57da0
Merge pull request #2 from scikit-learn/master
webber26232 Jul 21, 2019
6b63d38
Merge branch 'master' of github.com:webber26232/scikit-learn
webber26232 Jul 21, 2019
4b1c9cf
Merge branch 'master' into RadNeiClfPredProb
webber26232 Jul 21, 2019
19d95e5
remove changes in 0.20 docs
webber26232 Jul 21, 2019
69b88e6
change external.six to six
webber26232 Jul 21, 2019
8c8163b
format doc strings
webber26232 Jul 21, 2019
543e1ed
fix grammer in docs
webber26232 Jul 21, 2019
e73a20a
Move predict function for cleaner diff
TomDLT Jul 22, 2019
0208a60
Merge branch 'master' into RadNeiClfPredProb
webber26232 Aug 4, 2019
1638ee5
Update doc/whats_new/v0.22.rst
webber26232 Aug 5, 2019
39180ac
Update sklearn/neighbors/tests/test_neighbors.py
webber26232 Aug 6, 2019
644a9f1
Update doc/whats_new/v0.22.rst
webber26232 Aug 6, 2019
a21eb5c
Update doc/whats_new/v0.22.rst
webber26232 Aug 6, 2019
87b5288
add doc for iterations over multi-outputs, use np.flatnonzero instead…
webber26232 Aug 6, 2019
13712f2
Merge branch 'RadNeiClfPredProb' of github.com:webber26232/scikit-lea…
webber26232 Aug 6, 2019
ab73440
add outlier_label scalar verification and unit test, fix format, add …
webber26232 Aug 6, 2019
83cba2d
Merge pull request #3 from scikit-learn/master
webber26232 Aug 6, 2019
3783c9b
Merge branch 'RadNeiClfPredProb' of github.com:webber26232/scikit-lea…
webber26232 Aug 6, 2019
5f2d33e
better docs
webber26232 Aug 6, 2019
6ac2ffb
Update v0.22.rst
TomDLT Aug 7, 2019
14 changes: 13 additions & 1 deletion doc/whats_new/v0.22.rst
@@ -261,6 +261,7 @@ Changelog

:mod:`sklearn.feature_selection`
................................

- |Fix| Fixed a bug where :class:`VarianceThreshold` with `threshold=0` did not
remove constant features due to numerical instability, by using range
rather than variance in this case.
@@ -275,7 +276,18 @@ Changelog
:pr:`14035` by `Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.neighbors`
.............................
....................

- |Feature| :class:`neighbors.RadiusNeighborsClassifier` now supports
predicting probabilities by using predict_proba() and supports more
outlier_label options: 'most_frequent', different outlier_labels
for multi-outputs.
:issue:`9629` by :user:`Wenbo Zhao <webber26232>`.

- |Efficiency| Efficiency improvements for
:func:`neighbors.RadiusNeighborsClassifier.predict` by changing
implementation from scipy.stats.mode to numpy.bincount.

I don't think this is the right description anymore.


Isn't it ?


I think @jnothman meant something along the lines of,

- |Efficiency| Efficiency improvements for
  :func:`neighbors.RadiusNeighborsClassifier.predict_proba` by changing
  implementation from scipy.stats.mode to numpy.bincount, and for 
  :func:`neighbors.RadiusNeighborsClassifier.predict` that is now computed as
  an `argmax` of `predict_proba`.

?


We don't have neighbors.RadiusNeighborsClassifier.predict_proba in the current master. It is created in this branch.

How about

  • |Efficiency| Efficiency improvements for :func:`neighbors.RadiusNeighborsClassifier.predict` by changing implementation from using `scipy.stats.mode` to using `numpy.bincount` in `predict_proba`.

Right, of course. But then we can't say that the implementation of predict_proba changed from using scipy.stats.mode, as it didn't exist previously. Maybe just say that "predict is now computed from predict_proba" (and that it's faster), without going into much detail.
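
The point made in this thread, that `predict` can be derived from `predict_proba`, can be illustrated with a small NumPy sketch. This is not the code from the PR: `vote_proba` and `predict_from_proba` are hypothetical helper names, and the neighbor labels are assumed to be already encoded as integer indices into `classes`.

```python
import numpy as np


def vote_proba(neighbor_labels, n_classes):
    """Turn one sample's neighbor labels into class probabilities."""
    # bincount counts votes per class index; minlength pads unseen classes.
    counts = np.bincount(neighbor_labels, minlength=n_classes).astype(float)
    return counts / counts.sum()


def predict_from_proba(proba, classes):
    """Pick, for each row, the class with the highest probability."""
    return classes.take(proba.argmax(axis=1))


# Hypothetical data: three classes, two query samples with 3 and 4 neighbors.
classes = np.array([10, 20, 30])
neighbor_labels = [np.array([0, 0, 2]), np.array([1, 1, 1, 2])]

proba = np.vstack([vote_proba(nl, len(classes)) for nl in neighbor_labels])
y_pred = predict_from_proba(proba, classes)
```

With unweighted votes, the argmax of the vote counts picks the same class as the mode of the neighbor labels, which is why the two formulations agree for samples that have at least one neighbor.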

:pr:`9597` by :user:`Wenbo Zhao <webber26232>`.

- |Fix| KNearestRegressor now throws error when fit on non-square data and
metric = precomputed. :class:`neighbors.NeighborsBase`
181 changes: 149 additions & 32 deletions sklearn/neighbors/classification.py
@@ -10,8 +10,11 @@

import numpy as np
from scipy import stats
from six import string_types
from ..utils.extmath import weighted_mode
from ..utils.validation import _is_arraylike, _num_samples

import warnings
from .base import \
_check_weights, _get_weights, \
NeighborsBase, KNeighborsMixin,\
@@ -141,7 +144,6 @@ def __init__(self, n_neighbors=5,
weights='uniform', algorithm='auto', leaf_size=30,
p=2, metric='minkowski', metric_params=None, n_jobs=None,
**kwargs):

super().__init__(
n_neighbors=n_neighbors,
algorithm=algorithm,
@@ -151,7 +153,7 @@ def __init__(self, n_neighbors=5,
self.weights = _check_weights(weights)

def predict(self, X):
"""Predict the class labels for the provided data
"""Predict the class labels for the provided data.

Parameters
----------
@@ -174,7 +176,7 @@ def predict(self, X):
classes_ = [self.classes_]

n_outputs = len(classes_)
n_samples = X.shape[0]
n_samples = _num_samples(X)
weights = _get_weights(neigh_dist, self.weights)

y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype)
@@ -218,7 +220,7 @@ def predict_proba(self, X):
_y = self._y.reshape((-1, 1))
classes_ = [self.classes_]

n_samples = X.shape[0]
n_samples = _num_samples(X)

weights = _get_weights(neigh_dist, self.weights)
if weights is None:
@@ -302,10 +304,13 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
metric. See the documentation of the DistanceMetric class for a
list of available metrics.

outlier_label : int, optional (default = None)
Label, which is given for outlier samples (samples with no
neighbors on given radius).
If set to None, ValueError is raised, when outlier is detected.
outlier_label : {manual label, 'most_frequent'}, optional (default = None)
Label for outlier samples (samples with no neighbors in given radius).

- manual label: str or int label (should be the same type as y)
or list of manual labels if multi-output is used.
- 'most_frequent' : assign the most frequent label of y to outliers.
- None : when any outlier is detected, ValueError will be raised.

metric_params : dict, optional (default = None)
Additional keyword arguments for the metric function.
@@ -346,6 +351,8 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
RadiusNeighborsClassifier(...)
>>> print(neigh.predict([[1.5]]))
[0]
>>> print(neigh.predict_proba([[1.0]]))
[[0.66666667 0.33333333]]

See also
--------
@@ -375,8 +382,57 @@ def __init__(self, radius=1.0, weights='uniform',
self.weights = _check_weights(weights)
self.outlier_label = outlier_label

def fit(self, X, y):
"""Fit the model using X as training data and y as target values

Parameters
----------
X : {array-like, sparse matrix, BallTree, KDTree}
Training data. If array or matrix, shape [n_samples, n_features],
or [n_samples, n_samples] if metric='precomputed'.

y : {array-like, sparse matrix}
Target values of shape = [n_samples] or [n_samples, n_outputs]

"""

SupervisedIntegerMixin.fit(self, X, y)

classes_ = self.classes_
_y = self._y
if not self.outputs_2d_:
_y = self._y.reshape((-1, 1))
classes_ = [self.classes_]

if self.outlier_label is None:
outlier_label_ = None
elif self.outlier_label == 'most_frequent':
outlier_label_ = []
for k, classes_k in enumerate(classes_):
label_count = np.bincount(_y[:, k])
outlier_label_.append(classes_k[label_count.argmax()])
else:
if (_is_arraylike(self.outlier_label) and
not isinstance(self.outlier_label, string_types)):
if len(self.outlier_label) != len(classes_):
raise ValueError('The length of outlier_label: {} is '
'inconsistent with the output '
'length: {}'.format(self.outlier_label,
len(classes_)))
outlier_label_ = self.outlier_label
else:
outlier_label_ = [self.outlier_label] * len(classes_)
# ensure the dtype of outlier label is consistent with y
if any(np.append(classes, label).dtype != classes.dtype
for classes, label in zip(classes_, outlier_label_)):
raise TypeError('The dtype of outlier_label is '
'inconsistent with y')

self.outlier_label_ = outlier_label_
return self
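
The outlier-label handling in `fit` above can be summarized in a standalone sketch. It is a simplified illustration under assumptions, not the PR's implementation: `resolve_outlier_labels` is a hypothetical helper, `y_encoded` is assumed to hold integer class indices with one column per output, and the dtype consistency check is left out.

```python
import numpy as np


def resolve_outlier_labels(outlier_label, y_encoded, classes_per_output):
    """Return one outlier label per output, or None to raise on outliers.

    y_encoded: int array of shape (n_samples, n_outputs) with class indices.
    classes_per_output: list of 1-D arrays holding the class labels.
    """
    n_outputs = len(classes_per_output)
    if outlier_label is None:
        return None
    if isinstance(outlier_label, str) and outlier_label == "most_frequent":
        # Most frequent class per output, found by counting encoded labels.
        return [classes[np.bincount(y_encoded[:, k]).argmax()]
                for k, classes in enumerate(classes_per_output)]
    if isinstance(outlier_label, (list, tuple, np.ndarray)):
        # One manual label per output is required in the multi-output case.
        if len(outlier_label) != n_outputs:
            raise ValueError(
                "The length of outlier_label: {} is inconsistent with the "
                "output length: {}".format(outlier_label, n_outputs))
        return list(outlier_label)
    # A scalar manual label is repeated for every output.
    return [outlier_label] * n_outputs


# Hypothetical two-output problem with three training samples.
y_encoded = np.array([[0, 1], [0, 0], [1, 1]])
classes_per_output = [np.array(["a", "b"]), np.array([3, 7])]

labels = resolve_outlier_labels("most_frequent", y_encoded, classes_per_output)
```

Resolving the label once at fit time means the prediction methods only need to read a per-output list (or `None`) instead of re-interpreting the user's parameter on every call.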

def predict(self, X):
"""Predict the class labels for the provided data
"""Predict the class labels for the provided data.

Parameters
----------
@@ -388,54 +444,115 @@ def predict(self, X):
-------
y : array of shape [n_samples] or [n_samples, n_outputs]
Class labels for each data sample.
"""

probs = self.predict_proba(X)
classes_ = self.classes_

if not self.outputs_2d_:
probs = [probs]
classes_ = [self.classes_]

n_outputs = len(classes_)
n_samples = probs[0].shape[0]
y_pred = np.empty((n_samples, n_outputs),
dtype=classes_[0].dtype)

for k, prob in enumerate(probs):
max_prob_index = prob.argmax(axis=1)
y_pred[:, k] = classes_[k].take(max_prob_index)

outlier_zero_probs = (prob == 0).all(axis=1)
if outlier_zero_probs.any():
zero_prob_index = np.flatnonzero(outlier_zero_probs)
y_pred[zero_prob_index, k] = self.outlier_label_[k]

if not self.outputs_2d_:
y_pred = y_pred.ravel()

return y_pred
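
The new `predict` relies on a detail that is easy to miss: a row of all-zero probabilities marks an outlier whose label is not among the training classes, so the argmax result for that row has to be overwritten. A minimal single-output sketch of that step follows; `predict_with_outliers` is a hypothetical name and the probability matrix is made up for illustration.

```python
import numpy as np


def predict_with_outliers(proba, classes, outlier_label):
    """Argmax over class probabilities, with all-zero rows set to a label."""
    # take() returns a new array, so the assignment below does not
    # modify `classes`.
    y_pred = classes.take(proba.argmax(axis=1))
    # Rows with no probability mass would otherwise default to classes[0].
    zero_rows = np.flatnonzero((proba == 0).all(axis=1))
    y_pred[zero_rows] = outlier_label
    return y_pred


classes = np.array([0, 1, 2])
proba = np.array([[0.2, 0.8, 0.0],
                  [0.0, 0.0, 0.0],   # outlier: no neighbors, unseen label
                  [1.0, 0.0, 0.0]])

y_pred = predict_with_outliers(proba, classes, outlier_label=-1)
```

Without the override, the second sample would silently be assigned the first class, because `argmax` returns index 0 for a row of equal values.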

def predict_proba(self, X):
"""Return probability estimates for the test data X.

Parameters
----------
X : array-like, shape (n_query, n_features), \
or (n_query, n_indexed) if metric == 'precomputed'
Test samples.

Returns
-------
p : array of shape = [n_samples, n_classes], or a list of n_outputs
of such arrays if n_outputs > 1.
The class probabilities of the input samples. Classes are ordered
by lexicographic order.
"""

X = check_array(X, accept_sparse='csr')
n_samples = X.shape[0]
n_samples = _num_samples(X)

neigh_dist, neigh_ind = self.radius_neighbors(X)
inliers = [i for i, nind in enumerate(neigh_ind) if len(nind) != 0]
outliers = [i for i, nind in enumerate(neigh_ind) if len(nind) == 0]
outlier_mask = np.zeros(n_samples, dtype=np.bool)
outlier_mask[:] = [len(nind) == 0 for nind in neigh_ind]
outliers = np.flatnonzero(outlier_mask)
inliers = np.flatnonzero(~outlier_mask)

classes_ = self.classes_
_y = self._y
if not self.outputs_2d_:
_y = self._y.reshape((-1, 1))
classes_ = [self.classes_]
n_outputs = len(classes_)

if self.outlier_label is not None:
neigh_dist[outliers] = 1e-6
elif outliers:
if self.outlier_label_ is None and outliers.size > 0:
raise ValueError('No neighbors found for test samples %r, '
'you can try using larger radius, '
'give a label for outliers, '
'or consider removing them from your dataset.'
'giving a label for outliers, '
'or considering removing them from your dataset.'
% outliers)

weights = _get_weights(neigh_dist, self.weights)
if weights is not None:
weights = weights[inliers]

y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype)
probabilities = []
for k, classes_k in enumerate(classes_):
pred_labels = np.zeros(len(neigh_ind), dtype=object)
pred_labels[:] = [_y[ind, k] for ind in neigh_ind]

proba_k = np.zeros((n_samples, classes_k.size))
proba_inl = np.zeros((len(inliers), classes_k.size))

# samples have different size of neighbors within the same radius
if weights is None:
mode = np.array([stats.mode(pl)[0]
for pl in pred_labels[inliers]], dtype=np.int)
for i, idx in enumerate(pred_labels[inliers]):
proba_inl[i, :] = np.bincount(idx,
minlength=classes_k.size)
else:
mode = np.array(
[weighted_mode(pl, w)[0]
for (pl, w) in zip(pred_labels[inliers], weights[inliers])
], dtype=np.int)
for i, idx in enumerate(pred_labels[inliers]):
proba_inl[i, :] = np.bincount(idx,
weights[i],
minlength=classes_k.size)
proba_k[inliers, :] = proba_inl

if outliers.size > 0:
label_index = np.where(classes_k == self.outlier_label_[k])
if label_index[0].size != 0:
proba_k[outliers, label_index[0][0]] = 1.0
else:
warnings.warn('Outlier label {} is not in training '
'classes. All class probabilities of '
'outliers will be assigned with 0.'
''.format(self.outlier_label_[k]))

mode = mode.ravel()

y_pred[inliers, k] = classes_k.take(mode)
# normalize 'votes' into real [0,1] probabilities
normalizer = proba_k.sum(axis=1)[:, np.newaxis]
normalizer[normalizer == 0.0] = 1.0
proba_k /= normalizer

if outliers:
y_pred[outliers, :] = self.outlier_label
probabilities.append(proba_k)

if not self.outputs_2d_:
y_pred = y_pred.ravel()
probabilities = probabilities[0]

return y_pred
return probabilities
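
As a rough illustration of the `predict_proba` logic in this diff, the following single-output sketch counts (optionally weighted) votes over neighbor lists of varying length, gives outliers a probability of 1.0 for a chosen class, and normalizes each row. `radius_vote_proba` and `outlier_class` are hypothetical names, and the neighbor labels are assumed to be integer class indices.

```python
import numpy as np


def radius_vote_proba(neighbor_labels, n_classes, weights=None,
                      outlier_class=None):
    """Class probabilities from ragged per-sample neighbor label arrays."""
    proba = np.zeros((len(neighbor_labels), n_classes))
    for i, labels in enumerate(neighbor_labels):
        if len(labels) == 0:
            # Outlier: all mass goes to the outlier class if one is given,
            # otherwise the row stays at zero.
            if outlier_class is not None:
                proba[i, outlier_class] = 1.0
            continue
        w = None if weights is None else weights[i]
        proba[i] = np.bincount(labels, weights=w, minlength=n_classes)
    # Normalize votes into probabilities; avoid dividing zero rows by zero.
    totals = proba.sum(axis=1, keepdims=True)
    totals[totals == 0.0] = 1.0
    return proba / totals


# Hypothetical neighborhoods: the second query sample has no neighbors.
neighbor_labels = [np.array([0, 0, 1]), np.array([], dtype=int), np.array([2])]
weights = [np.array([1.0, 1.0, 2.0]), np.array([]), np.array([0.5])]

uniform = radius_vote_proba(neighbor_labels, 3, outlier_class=1)
weighted = radius_vote_proba(neighbor_labels, 3, weights=weights,
                             outlier_class=1)
```

The per-sample loop is needed because samples have different numbers of neighbors within the same radius, so the labels cannot be stacked into one rectangular array.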
5 changes: 2 additions & 3 deletions sklearn/neighbors/regression.py
@@ -15,7 +15,7 @@
import numpy as np
from scipy.sparse import issparse

from .base import _get_weights, _check_weights, NeighborsBase, KNeighborsMixin
from .base import _check_weights, _get_weights, NeighborsBase, KNeighborsMixin
from .base import RadiusNeighborsMixin, SupervisedFloatMixin
from ..base import RegressorMixin
from ..utils import check_array
@@ -347,12 +347,11 @@ def predict(self, X):
if len(ind) else empty_obs
for (i, ind) in enumerate(neigh_ind)])

if np.max(np.isnan(y_pred)):
if np.any(np.isnan(y_pred)):
empty_warning_msg = ("One or more samples have no neighbors "
"within specified radius; predicting NaN.")
warnings.warn(empty_warning_msg)


if self._y.ndim == 1:
y_pred = y_pred.ravel()
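
The switch from `np.max(np.isnan(y_pred))` to `np.any(np.isnan(y_pred))` in this hunk does not change the result, since the maximum of a boolean array is true exactly when any element is true, but it states the intent directly. A small sketch of the check, with a hypothetical helper name and made-up predictions:

```python
import warnings

import numpy as np


def warn_on_empty_neighborhoods(y_pred):
    """Warn if any prediction is NaN, i.e. some sample had no neighbors."""
    has_nan = bool(np.any(np.isnan(y_pred)))
    if has_nan:
        warnings.warn("One or more samples have no neighbors "
                      "within specified radius; predicting NaN.")
    return has_nan


y_pred = np.array([[1.5], [np.nan], [0.25]])

# Record warnings so the example can inspect them instead of printing them.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    flagged = warn_on_empty_neighborhoods(y_pred)
```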
