Skip to content
Permalink
Browse files

Merge pull request #3811 from MechCoder/fix_repeated_checking

[MRG] Fix repeated calls of check_pairwise and type casting in pairwise_distances_argmin_min
  • Loading branch information...
agramfort committed Oct 29, 2014
2 parents 626f672 + b500e3e commit 3f49cee020a91a0be5d0d5602d29b3eefce9d758
Showing with 41 additions and 11 deletions.
  1. +4 −0 doc/whats_new.rst
  2. +37 −11 sklearn/metrics/pairwise.py
@@ -91,6 +91,10 @@ Enhancements
handle unknown categorical features more gracefully during transform.
By `Manoj Kumar`_

- Added option ``check_X_y`` to :func:`metrics.pairwise_distances_argmin_min`
that can give speed improvements by avoiding repeated checking when set to
False. By `Manoj Kumar`_

Documentation improvements
..........................

@@ -56,6 +56,30 @@


# Utility Functions
def _return_float_dtype(X, Y):
"""
1. If dtype of X and Y is float32, then dtype float32 is returned.
2. Else dtype float is returned.
"""
if not issparse(X) and not isinstance(X, np.ndarray):
X = np.asarray(X)

if Y is None:
Y_dtype = X.dtype
elif not issparse(Y) and not isinstance(Y, np.ndarray):
Y = np.asarray(Y)
Y_dtype = Y.dtype
else:
Y_dtype = Y.dtype

if X.dtype == Y_dtype == np.float32:
dtype = np.float32
else:
dtype = np.float

return X, Y, dtype


def check_pairwise_arrays(X, Y):
""" Set X and Y appropriately and checks inputs
@@ -85,22 +109,18 @@ def check_pairwise_arrays(X, Y):
If Y was None, safe_Y will be a pointer to X.
"""
X, Y, dtype = _return_float_dtype(X, Y)

if Y is X or Y is None:
X = Y = check_array(X, accept_sparse='csr')
X = Y = check_array(X, accept_sparse='csr', dtype=dtype)
else:
X = check_array(X, accept_sparse='csr')
Y = check_array(Y, accept_sparse='csr')
X = check_array(X, accept_sparse='csr', dtype=dtype)
Y = check_array(Y, accept_sparse='csr', dtype=dtype)
if X.shape[1] != Y.shape[1]:
raise ValueError("Incompatible dimension for X and Y matrices: "
"X.shape[1] == %d while Y.shape[1] == %d" % (
X.shape[1], Y.shape[1]))

if not (X.dtype == Y.dtype == np.float32):
if Y is X:
X = Y = check_array(X, ['csr', 'csc', 'coo'], dtype=np.float)
else:
X = check_array(X, ['csr', 'csc', 'coo'], dtype=np.float)
Y = check_array(Y, ['csr', 'csc', 'coo'], dtype=np.float)
return X, Y


@@ -225,7 +245,8 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False):


def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
batch_size=500, metric_kwargs=None):
batch_size=500, metric_kwargs=None,
check_X_y=True):
"""Compute minimum distances between one point and a set of points.
This function computes for each row in X, the index of the row of Y which
@@ -280,6 +301,10 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
metric_kwargs : dict, optional
Keyword arguments to pass to specified metric function.
check_X_y : bool, default True
Whether or not to check X and y for shape, validity and dtype. Speed
improvements possible if set to False when called repeatedly.
Returns
-------
argmin : numpy.ndarray
@@ -300,7 +325,8 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
elif not callable(metric) and not isinstance(metric, str):
raise ValueError("'metric' must be a string or a callable")

X, Y = check_pairwise_arrays(X, Y)
if check_X_y:
X, Y = check_pairwise_arrays(X, Y)

if metric_kwargs is None:
metric_kwargs = {}

0 comments on commit 3f49cee

Please sign in to comment.
You can’t perform that action at this time.