diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index b19f6c9063e67..5faf71c2f9fcc 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -59,7 +59,7 @@ Nearest Neighbors Classification
 Neighbors-based classification is a type of *instance-based learning* or
 *non-generalizing learning*: it does not attempt to construct a general
 internal model, but simply stores instances of the training data.
-Classification is computed from a simple majority vote of the nearest
+The basic classification is computed from a simple majority vote of the nearest
 neighbors of each point: a query point is assigned the data class which
 has the most representatives within the nearest neighbors of the point.
 
@@ -94,7 +94,25 @@ be accomplished through the ``weights`` keyword.  The default value,
 distance from the query point.  Alternatively, a user-defined function of the
 distance can be supplied which is used to compute the weights.
 
-
+There is a probabilistic interpretation of nearest neighbors classification:
+a query point :math:`x` is assigned to the class
+:math:`C_k` to which it has the highest probability of belonging.  This
+*posterior probability* is computed using Bayes' rule:
+:math:`P(C_k \mid x) = \frac{P(x \mid C_k) P(C_k)}{P(x)}`.
+The basic nearest neighbors classification (when ``class_prior='default'``)
+uses a default *prior probability* :math:`P(C_k)` equal to the proportion of
+training points which belong to class :math:`C_k`.  In contrast, using
+a flat prior (``class_prior='flat'``) assigns the same value (1 over the
+number of classes) to each class prior probability :math:`P(C_k)`.
+Alternatively, a user-defined list of the class prior probabilities (in
+increasing order of class labels) can be supplied which is used to classify
+the query points.
+
+The second example below illustrates the effect of assigning a much greater
+prior probability (0.8) to the first class (in red) than the other two: in
+regions where few data points appear, for example around the point (7, 4.5),
+the model is more biased toward the red class than it was in the first
+example.
 
 .. |classification_1| image:: ../auto_examples/neighbors/images/plot_classification_1.png
    :target: ../auto_examples/neighbors/plot_classification.html
@@ -111,6 +129,11 @@ distance can be supplied which is used to compute the weights.
   * :ref:`example_neighbors_plot_classification.py`: an example of
     classification using nearest neighbors.
 
+.. topic:: References:
+
+   * `Pattern Recognition and Machine Learning`,
+     Bishop, C.M., New York: Springer (2006), p. 124-127
+
 .. _regression:
 
 Nearest Neighbors Regression
@@ -118,7 +141,7 @@ Nearest Neighbors Regression
 Neighbors-based regression can be used in cases where the data labels are
 continuous rather than discrete variables.  The label assigned to a query
-point is computed based the mean of the labels of its nearest neighbors.
+point is computed based on the mean of the labels of its nearest neighbors.
 
 scikit-learn implements two different neighbors regressors:
 :class:`KNeighborsRegressor` implements learning based on the :math:`k`
diff --git a/doc/tutorial/statistical_inference/supervised_learning.rst b/doc/tutorial/statistical_inference/supervised_learning.rst
index ca57af93b3eaa..1242bcca46c62 100644
--- a/doc/tutorial/statistical_inference/supervised_learning.rst
+++ b/doc/tutorial/statistical_inference/supervised_learning.rst
@@ -95,8 +95,8 @@ Scikit-learn documentation for more information about this type of classifier.)
     >>> from sklearn.neighbors import KNeighborsClassifier
     >>> knn = KNeighborsClassifier()
     >>> knn.fit(iris_X_train, iris_y_train)
-    KNeighborsClassifier(algorithm='auto', leaf_size=30, n_neighbors=5, p=2,
-               warn_on_equidistant=True, weights='uniform')
+    KNeighborsClassifier(algorithm='auto', class_prior='default', leaf_size=30,
+               n_neighbors=5, p=2, warn_on_equidistant=True, weights='uniform')
     >>> knn.predict(iris_X_test)
     array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
     >>> iris_y_test
diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py
index 209820159b9ed..8bfd2bb5e77c6 100644
--- a/examples/neighbors/plot_classification.py
+++ b/examples/neighbors/plot_classification.py
@@ -27,9 +27,11 @@
 cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
 cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])
 
-for weights in ['uniform', 'distance']:
+for weights, class_prior in zip(['uniform', 'distance'],
+                                ['default', [0.8, 0.1, 0.1]]):
     # we create an instance of Neighbours Classifier and fit the data.
-    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)
+    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights,
+                                         class_prior=class_prior)
     clf.fit(X, y)
 
     # Plot the decision boundary. For that, we will asign a color to each
@@ -47,8 +49,8 @@
 
     # Plot also the training points
    pl.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)
-    pl.title("3-Class classification (k = %i, weights = '%s')"
-             % (n_neighbors, weights))
+    pl.title("3-Class classification (k = %i,\nweights = '%s', class_prior = '%s')"
+             % (n_neighbors, weights, class_prior))
     pl.axis('tight')
 
 pl.show()
diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py
index 357cd3eadfa09..8ca57f55f1654 100644
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -15,6 +15,7 @@
 from ..base import BaseEstimator
 from ..metrics import pairwise_distances
 from ..utils import safe_asarray, atleast2d_or_csr
+from ..utils.fixes import unique
 
 
 class NeighborsWarning(UserWarning):
@@ -45,14 +46,14 @@ def _get_weights(dist, weights):
     """Get the weights from an array of distances and a parameter ``weights``
 
     Parameters
-    ===========
+    ----------
     dist: ndarray
         The input distances
 
     weights: {'uniform', 'distance' or a callable}
        The kind of weighting used
 
     Returns
-    ========
+    -------
     weights_arr: array of the same shape as ``dist``
         if ``weights == 'uniform'``, then returns None
     """
@@ -68,6 +69,39 @@ def _get_weights(dist, weights):
         raise ValueError("weights not recognized: should be 'uniform', "
                          "'distance', or a callable function")
 
+
+def _check_class_prior(class_prior):
+    """Check to make sure the class prior is valid."""
+    if isinstance(class_prior, (list, np.ndarray)):
+        return class_prior
+    elif class_prior in (None, 'default', 'flat'):
+        return class_prior
+    else:
+        raise ValueError("class prior not recognized: should be 'default', "
+                         "'flat', or a list or ndarray")
+
+
+def _get_class_prior(y, class_prior):
+    """Get the class prior from targets ``y`` and parameter ``class_prior``
+
+    Parameters
+    ----------
+    y : ndarray
+        The target labels, from 0 to ``n - 1`` (thus ``n`` classes)
+
+    class_prior : {'default', 'flat'}, list or ndarray
+        The class prior probabilities to use
+
+    Returns
+    -------
+    class_prior_arr : array of the same shape as ``np.unique(y)``
+    """
+    if isinstance(class_prior, (list, np.ndarray)):
+        return class_prior
+    elif class_prior in (None, 'default'):
+        return np.bincount(y).astype(float) / len(y)
+    elif class_prior == 'flat':
+        return np.ones((len(np.unique(y)),)) / len(np.unique(y))
+    else:
+        raise ValueError("class prior not recognized: should be 'default', "
+                         "'flat', or a list or ndarray")
+
 
 class NeighborsBase(BaseEstimator):
     """Base class for nearest neighbors estimators."""
@@ -567,8 +601,7 @@ def fit(self, X, y):
         y : {array-like, sparse matrix}, shape = [n_samples]
             Target values, array of integer values.
         """
-        self._y = np.asarray(y)
-        self._classes = np.sort(np.unique(y))
+        self._classes, self._y = unique(y, return_inverse=True)
         return self._fit(X)
 
diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py
index 4f2ada47460e4..3b2c7bf46268b 100644
--- a/sklearn/neighbors/classification.py
+++ b/sklearn/neighbors/classification.py
@@ -13,6 +13,7 @@
 
 from .base import \
     _check_weights, _get_weights, \
+    _check_class_prior, _get_class_prior, \
     NeighborsBase, KNeighborsMixin,\
     RadiusNeighborsMixin, SupervisedIntegerMixin
 from ..base import ClassifierMixin
@@ -42,6 +43,18 @@ class KNeighborsClassifier(NeighborsBase, KNeighborsMixin,
 
         Uniform weights are used by default.
 
+    class_prior : str, list or ndarray, optional (default = 'default')
+        Class prior probabilities used in prediction.  Possible values:
+
+        - 'default': default prior probabilities.  For each class, its
+          prior probability is the proportion of points in the dataset
+          that are in this class.
+        - 'flat': equiprobable prior probabilities.  If there are C classes,
+          then the prior probability for every class is 1/C.
+        - [list or ndarray]: a user-defined list or ndarray, listing
+          the prior class probability for each class, in increasing order
+          of class label.
+
     algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
         Algorithm used to compute the nearest neighbors:
@@ -86,6 +99,11 @@ class KNeighborsClassifier(NeighborsBase, KNeighborsMixin,
     [0]
     >>> print(neigh.predict_proba([[0.9]]))
     [[ 0.66666667  0.33333333]]
+    >>> neigh = KNeighborsClassifier(n_neighbors=3, class_prior=[0.75, 0.25])
+    >>> neigh.fit(X, y) # doctest: +ELLIPSIS
+    KNeighborsClassifier(...)
+    >>> print(neigh.predict_proba([[2.0]]))
+    [[ 0.6  0.4]]
 
     See also
     --------
@@ -100,10 +118,16 @@ class KNeighborsClassifier(NeighborsBase, KNeighborsMixin,
        for a discussion of the choice of ``algorithm`` and ``leaf_size``.
 
     http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
+
+    References
+    ----------
+    Bishop, Christopher M. *Pattern Recognition and Machine Learning*.
+    New York: Springer, 2006, p. 124-7.
     """
 
     def __init__(self, n_neighbors=5,
                  weights='uniform',
+                 class_prior='default',
                  algorithm='auto', leaf_size=30,
                  warn_on_equidistant=True, p=2):
         self._init_params(n_neighbors=n_neighbors,
@@ -112,6 +136,7 @@ def __init__(self, n_neighbors=5,
                           warn_on_equidistant=warn_on_equidistant,
                           p=p)
         self.weights = _check_weights(weights)
+        self.class_prior = _check_class_prior(class_prior)
 
     def predict(self, X):
         """Predict the class labels for the provided data
@@ -126,19 +151,8 @@ def predict(self, X):
         labels: array
             List of class labels (one for each data sample).
         """
-        X = atleast2d_or_csr(X)
-
-        neigh_dist, neigh_ind = self.kneighbors(X)
-        pred_labels = self._y[neigh_ind]
-
-        weights = _get_weights(neigh_dist, self.weights)
-
-        if weights is None:
-            mode, _ = stats.mode(pred_labels, axis=1)
-        else:
-            mode, _ = weighted_mode(pred_labels, weights, axis=1)
-
-        return mode.flatten().astype(np.int)
+        probabilities = self.predict_proba(X)
+        return self._classes[probabilities.argmax(axis=1)].astype(np.int)
 
     def predict_proba(self, X):
         """Return probability estimates for the test data X.
@@ -157,31 +171,29 @@ def predict_proba(self, X):
         X = atleast2d_or_csr(X)
 
         neigh_dist, neigh_ind = self.kneighbors(X)
-        pred_labels = self._y[neigh_ind]
+        pred_indices = self._y[neigh_ind]
 
         weights = _get_weights(neigh_dist, self.weights)
 
         if weights is None:
-            weights = np.ones_like(pred_labels)
+            weights = np.ones_like(pred_indices)
 
         probabilities = np.zeros((X.shape[0], self._classes.size))
 
-        # Translate class label to a column index in probabilities array.
-        # This may not be needed provided classes labels are guaranteed to be
-        # np.arange(n_classes) (e.g. consecutive and starting with 0)
-        pred_indices = pred_labels.copy()
-        for k, c in enumerate(self._classes):
-            pred_indices[pred_labels == c] = k
-
         # a simple ':' index doesn't work right
         all_rows = np.arange(X.shape[0])
         for i, idx in enumerate(pred_indices.T):  # loop is O(n_neighbors)
             probabilities[all_rows, idx] += weights[:, i]
 
+        # Compute the unnormalized posterior probability, taking the class
+        # prior probabilities (self.class_prior) into consideration.
+        class_count = np.bincount(self._y)
+        class_prior = _get_class_prior(self._y, self.class_prior)
+        probabilities = (probabilities / class_count) * class_prior
+
         # normalize 'votes' into real [0,1] probabilities
         probabilities = (probabilities.T / probabilities.sum(axis=1)).T
-
         return probabilities
@@ -209,6 +221,18 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
 
         Uniform weights are used by default.
 
+    class_prior : str, list or ndarray, optional (default = 'default')
+        Class prior probabilities used in prediction.  Possible values:
+
+        - 'default': default prior probabilities.  For each class, its
+          prior probability is the proportion of points in the dataset
+          that are in this class.
+        - 'flat': equiprobable prior probabilities.  If there are C classes,
+          then the prior probability for every class is 1/C.
+        - [list or ndarray]: a user-defined list or ndarray, listing
+          the prior class probability for each class, in increasing order
+          of class label.
+
     algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
         Algorithm used to compute the nearest neighbors:
@@ -248,6 +272,13 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
     RadiusNeighborsClassifier(...)
     >>> print(neigh.predict([[1.5]]))
     [0]
+    >>> print(neigh.predict_proba([[1.0]]))
+    [[ 0.66666667  0.33333333]]
+    >>> neigh = RadiusNeighborsClassifier(radius=1.0, class_prior=[0.2, 0.8])
+    >>> neigh.fit(X, y) # doctest: +ELLIPSIS
+    RadiusNeighborsClassifier(...)
+    >>> print(neigh.predict([[1.5]]))
+    [1]
 
     See also
     --------
@@ -262,15 +293,21 @@ class RadiusNeighborsClassifier(NeighborsBase, RadiusNeighborsMixin,
        for a discussion of the choice of ``algorithm`` and ``leaf_size``.
 
     http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
+
+    References
+    ----------
+    Bishop, Christopher M. *Pattern Recognition and Machine Learning*.
+    New York: Springer, 2006, p. 124-7.
     """
 
-    def __init__(self, radius=1.0, weights='uniform',
+    def __init__(self, radius=1.0, weights='uniform', class_prior='default',
                  algorithm='auto', leaf_size=30, p=2, outlier_label=None):
         self._init_params(radius=radius,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
                           p=p)
         self.weights = _check_weights(weights)
+        self.class_prior = _check_class_prior(class_prior)
         self.outlier_label = outlier_label
 
     def predict(self, X):
@@ -286,19 +323,52 @@ def predict(self, X):
         labels: array
             List of class labels (one for each data sample).
""" + if self.outlier_label != None: + probabilities, outliers = self.predict_proba(X) + else: + probabilities = self.predict_proba(X) + # Predict the class of each row, based on the maximum posterior + # probability. If needed, correct the predictions for outliers. + preds = self._classes[probabilities.argmax(axis=1)].astype(np.int) + if self.outlier_label != None: + preds[outliers] = self.outlier_label + + return preds + + def predict_proba(self, X): + """Return probability estimates for the test data X. + + Parameters + ---------- + X: array, shape = (n_samples, n_features) + A 2-D array representing the test points. + + Returns + ------- + probabilities : array, shape = [n_samples, n_classes] + Probabilities of the samples for each class in the model, + where classes are ordered arithmetically. If an outlier label + has been provided and is part of the actual classes, then + outliers will be assigned to that label with probability 1; if + the outlier label (e.g. -1) is not part of the actual classes, + then outliers will have probability 0 for every actual class. + outliers : list, length = n_samples + List of row indices in X that are outliers. Returned only if + self.outlier_label is not set to None. + """ X = atleast2d_or_csr(X) neigh_dist, neigh_ind = self.radius_neighbors(X) pred_labels = [self._y[ind] for ind in neigh_ind] - if self.outlier_label: - outlier_label = np.array((self.outlier_label, )) - small_value = np.array((1e-6, )) + outliers = [] # row indices of the outliers (if any) + # Test with None, since outlier_label could legitimately be 0 + if self.outlier_label != None: for i, pl in enumerate(pred_labels): # Check that all have at least 1 neighbor if len(pl) < 1: - pred_labels[i] = outlier_label - neigh_dist[i] = small_value + # We'll impose the label for that row later. + outliers.append(i) else: for pl in pred_labels: # Check that all have at least 1 neighbor @@ -310,13 +380,42 @@ def predict(self, X): 'dataset') weights = _get_weights(neigh_dist, self.weights) - if weights is None: - mode = np.asarray([stats.mode(pl)[0] for pl in pred_labels], - dtype=np.int) - else: - mode = np.asarray([weighted_mode(pl, w)[0] - for (pl, w) in zip(pred_labels, weights)], - dtype=np.int) + # `neigh_dist` is an array of objects, where each + # object is a 1D array of indices. + weights = np.array([np.ones(len(row)) for row in neigh_dist]) - return mode.flatten().astype(np.int) + probabilities = np.zeros((X.shape[0], self._classes.size)) + + # We cannot vectorize the following because of the way Python handles + # M += 1: if a predicted index was to occur more than once (for a + # given tested point), the corresponding element in `probabilities` + # would still be incremented only once. + for i, pi in enumerate(pred_labels): + if len(pi) < 1: + probabilities[i] = 1e-6 # prevent division by zero later + continue # outlier + # When we support NumPy >= 1.6, we'll be able to simply use: + # np.bincount(pi, weights, minlength=self._classes.size) + unpadded_probs = np.bincount(pi, weights[i]) + probabilities[i] = np.append(unpadded_probs, + np.zeros(self._classes.size - + unpadded_probs.shape[0])) + + # Compute the unnormalized posterior probability, taking + # self.class_prior_ into consideration. 
+        class_count = np.bincount(self._y)
+        class_prior = _get_class_prior(self._y, self.class_prior)
+        probabilities = (probabilities / class_count) * class_prior
+
+        # normalize 'votes' into real [0,1] probabilities
+        probabilities = (probabilities.T / probabilities.sum(axis=1)).T
+
+        if self.outlier_label is not None:
+            probabilities[outliers] = 0.
+            outlier_indices = np.nonzero(self._classes ==
+                                         self.outlier_label)[0]
+            if outlier_indices.size > 0:
+                probabilities[outliers, outlier_indices[0]] = 1
+            return probabilities, outliers
+        else:
+            return probabilities
\ No newline at end of file
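
For reference, the standalone NumPy sketch below (not part of the patch; all names are local to the sketch) reproduces the posterior computation that the patched ``predict_proba`` implements: weighted neighbor votes are divided by the per-class training counts, multiplied by the class prior, and normalized. The data and the expected result match the KNeighborsClassifier doctest added above.

    import numpy as np

    # Training data and query taken from the KNeighborsClassifier doctest
    # added in this patch: X = [[0], [1], [2], [3]], y = [0, 0, 1, 1].
    X = np.array([[0.], [1.], [2.], [3.]])
    y = np.array([0, 0, 1, 1])
    n_classes = 2
    query = 2.0

    # The k = 3 nearest neighbors of the query point (uniform weights).
    k = 3
    dist = np.abs(X[:, 0] - query)
    neigh_ind = np.argsort(dist)[:k]

    # Unnormalized class 'votes' from the neighbors.
    votes = np.bincount(y[neigh_ind], minlength=n_classes).astype(float)

    # Same arithmetic as the patched predict_proba: divide the votes by the
    # per-class training counts, multiply by the prior P(C_k), normalize.
    class_count = np.bincount(y, minlength=n_classes)
    class_prior = np.array([0.75, 0.25])
    posterior = (votes / class_count) * class_prior
    posterior /= posterior.sum()

    print(posterior)  # [0.6 0.4], matching the doctest value above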