Skip to content

Commit

Permalink
change basis_metod to basis_sampling and clustered to kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
nateyoder committed May 3, 2014
1 parent e7bec1e commit 5f313f8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 33 deletions.
4 changes: 2 additions & 2 deletions doc/modules/kernel_approximation.rst
Expand Up @@ -36,9 +36,9 @@ The Nystroem method, as implemented in :class:`Nystroem` is a general method
for low-rank approximations of kernels. It achieves this by essentially subsampling
the data on which the kernel is evaluated.
The subsampling methodology used to generate the approximate kernel is specified by
the parameter ``basis_method`` which can either be ``random`` or ``clustered``.
the parameter ``basis_sampling`` which can either be ``random`` or ``kmeans``.
If the ``random`` method is specified randomly selected data will be utilized in
the approximation while the ``clustered`` method uses the cluster centers found via
the approximation while the ``kmeans`` method uses the cluster centers found via
k-means clustering. Further details concerning the subsampling methods can be found
in [ZK2010]_.
By default :class:`Nystroem` uses the ``rbf`` kernel, but it can use any
Expand Down
10 changes: 5 additions & 5 deletions examples/plot_kernel_approximation.py
Expand Up @@ -85,8 +85,8 @@
# create pipeline from kernel approximation
# and linear svm
feature_map_fourier = RBFSampler(gamma=kernel_gamma, random_state=0)
feature_map_random_nystroem = Nystroem(gamma=kernel_gamma, random_state=0, basis_method='random')
feature_map_clusted_nystroem_ = Nystroem(gamma=kernel_gamma, random_state=0, basis_method='clustered')
feature_map_random_nystroem = Nystroem(gamma=kernel_gamma, random_state=0, basis_sampling='random')
feature_map_clusted_nystroem_ = Nystroem(gamma=kernel_gamma, random_state=0, basis_sampling='kmeans')

fourier_approx_svm = pipeline.Pipeline([("feature_map", feature_map_fourier),
("svm", svm.LinearSVC())])
Expand Down Expand Up @@ -147,9 +147,9 @@
timescale.plot(sample_sizes, random_times, '--',
label='Random Nystroem approx. kernel')

accuracy.plot(sample_sizes, clustered_scores, label="Clustered Nystroem approx. kernel")
accuracy.plot(sample_sizes, clustered_scores, label="K-means Nystroem approx. kernel")
timescale.plot(sample_sizes, clustered_times, '--',
label='Clustered Nystroem approx. kernel')
label='K-means Nystroem approx. kernel')

accuracy.plot(sample_sizes, fourier_scores, label="Fourier approx. kernel")
timescale.plot(sample_sizes, fourier_times, '--',
Expand Down Expand Up @@ -204,7 +204,7 @@
'SVC with linear kernel',
'SVC (linear kernel)\n with Fourier rbf approx\n'
'n_components={}'.format(n_components_to_plot),
'SVC (linear kernel)\n with clustered Nystroem rbf approx\n'
'SVC (linear kernel)\n with K-means Nystroem rbf approx\n'
'n_components={}'.format(n_components_to_plot)]

plt.tight_layout()
Expand Down
13 changes: 6 additions & 7 deletions sklearn/kernel_approximation.py
Expand Up @@ -377,7 +377,7 @@ class Nystroem(BaseEstimator, TransformerMixin):
If int, random_state is the seed used by the random number generator;
if RandomState instance, random_state is the random number generator.
basis_method : string "random" or "clustered"
basis_sampling : string "random" or "kmeans"
Form approximation using randomly sampled columns or k-means
cluster centers to construct the Nystrom Approximation
Expand Down Expand Up @@ -420,15 +420,15 @@ class Nystroem(BaseEstimator, TransformerMixin):
"""
def __init__(self, kernel="rbf", gamma=None, coef0=1, degree=3,
kernel_params=None, n_components=100, random_state=None,
basis_method="random"):
basis_sampling="random"):
self.kernel = kernel
self.gamma = gamma
self.coef0 = coef0
self.degree = degree
self.kernel_params = kernel_params
self.n_components = n_components
self.random_state = random_state
self.basis_method = basis_method
self.basis_sampling = basis_sampling

def fit(self, X, y=None):
"""Fit estimator to data.
Expand Down Expand Up @@ -458,18 +458,17 @@ def fit(self, X, y=None):
else:
n_components = self.n_components

if self.basis_method == "random":
if self.basis_sampling == "random":
inds = rnd.permutation(n_samples)
basis_inds = inds[:n_components]
basis = X[basis_inds]
elif self.basis_method == "clustered":
elif self.basis_sampling == "kmeans":
# Zhang and Kwok use 5 in their paper so lets do that
basis, _, _ = k_means(X, n_components, init='random', max_iter=5, n_init=1, random_state=rnd)
#If we are using k_means centers as input, cannot record basis_inds
basis_inds = None

else:
raise NameError('{0} is not a supported basis_method'.format(self.basis_method))
raise NameError('{0} is not a supported basis_sampling method'.format(self.basis_sampling))

basis_kernel = pairwise_kernels(basis, metric=self.kernel,
filter_params=True,
Expand Down
38 changes: 19 additions & 19 deletions sklearn/tests/test_kernel_approximation.py
Expand Up @@ -144,13 +144,13 @@ def test_nystroem_approximation_with_number_samples_is_exact():
X = rnd.uniform(size=(10, 4))

# With n_components = n_samples this is exact
ny_random = Nystroem(n_components=X.shape[0], basis_method='random')
ny_random = Nystroem(n_components=X.shape[0], basis_sampling='random')
X_transformed_random = ny_random.fit_transform(X)
K = rbf_kernel(X)
assert_array_equal(np.sort(ny_random.component_indices_), np.arange(X.shape[0]))
assert_array_almost_equal(np.dot(X_transformed_random, X_transformed_random.T), K)

ny_clustered = Nystroem(n_components=X.shape[0], basis_method='clustered')
ny_clustered = Nystroem(n_components=X.shape[0], basis_sampling='kmeans')
X_transformed_clustered = ny_clustered.fit_transform(X)
K = rbf_kernel(X)
# No component indicies to report for k-means
Expand All @@ -162,13 +162,13 @@ def test_nystroem_approximation_returns_appropriate_indices():
rnd = np.random.RandomState(0)
X = rnd.uniform(size=(10, 4))

ny_random = Nystroem(n_components=2, basis_method='random')
ny_random = Nystroem(n_components=2, basis_sampling='random')
X_transformed = ny_random.fit_transform(X)
assert_equal(X_transformed.shape, (X.shape[0], 2))
assert_equal(len(ny_random.component_indices_), 2)
assert_array_almost_equal(ny_random.components_, X[ny_random.component_indices_])

ny_clustered = Nystroem(n_components=2, basis_method='clustered')
ny_clustered = Nystroem(n_components=2, basis_sampling='kmeans')
ny_clustered.fit_transform(X)
# No component indicies to report for k-means
assert_equal(ny_clustered.component_indices_, None)
Expand All @@ -182,7 +182,7 @@ def test_nystroem_approximation_with_singular_kernel_matrix():
K = rbf_kernel(X)
assert_equal(np.linalg.matrix_rank(K), 10)

ny_random = Nystroem(n_components=X.shape[0], basis_method='random')
ny_random = Nystroem(n_components=X.shape[0], basis_sampling='random')
X_transformed = ny_random.fit_transform(X)
assert_equal(X_transformed.shape, (X.shape[0], 12))
assert_array_almost_equal(np.dot(X_transformed, X_transformed.T), K)
Expand All @@ -193,50 +193,50 @@ def test_nystroem_approximation_for_multiple_kernels():
rnd = np.random.RandomState(0)
X = rnd.uniform(size=(10, 4))
trans_not_valid = Nystroem(n_components=2, random_state=rnd,
basis_method="not_a_valid_basis_method")
basis_sampling="not_a_valid_basis_sampling")
assert_raises(NameError, trans_not_valid.fit, X)

# Kernel tests to perform with each basis method used
def test_nystroem_approximation_with_basis(tested_basis):
# Test default kernel
trans = Nystroem(n_components=2, random_state=rnd, basis_method=tested_basis)
trans = Nystroem(n_components=2, random_state=rnd, basis_sampling=tested_basis)
transformed = trans.fit(X).transform(X)
assert_equal(transformed.shape, (X.shape[0], 2))

# test callable kernel
linear_kernel = lambda X, Y: np.dot(X, Y.T)
trans = Nystroem(n_components=2, kernel=linear_kernel, random_state=rnd, basis_method=tested_basis)
trans = Nystroem(n_components=2, kernel=linear_kernel, random_state=rnd, basis_sampling=tested_basis)
transformed = trans.fit(X).transform(X)
assert_equal(transformed.shape, (X.shape[0], 2))

# test that available kernels fit and transform
kernels_available = kernel_metrics()
for kern in kernels_available:
trans = Nystroem(n_components=2, kernel=kern, random_state=rnd, basis_method=tested_basis)
trans = Nystroem(n_components=2, kernel=kern, random_state=rnd, basis_sampling=tested_basis)
transformed = trans.fit(X).transform(X)
assert_equal(transformed.shape, (X.shape[0], 2))

# Test default kernel
trans = Nystroem(n_components=2, random_state=rnd, basis_method=tested_basis)
trans = Nystroem(n_components=2, random_state=rnd, basis_sampling=tested_basis)
transformed = trans.fit(X).transform(X)
assert_equal(transformed.shape, (X.shape[0], 2))

# test callable kernel
linear_kernel = lambda X, Y: np.dot(X, Y.T)
trans = Nystroem(n_components=2, kernel=linear_kernel, random_state=rnd, basis_method=tested_basis)
trans = Nystroem(n_components=2, kernel=linear_kernel, random_state=rnd, basis_sampling=tested_basis)
transformed = trans.fit(X).transform(X)
assert_equal(transformed.shape, (X.shape[0], 2))

# test that available kernels fit and transform
kernels_available = kernel_metrics()
for kern in kernels_available:
trans = Nystroem(n_components=2, kernel=kern, random_state=rnd, basis_method=tested_basis)
trans = Nystroem(n_components=2, kernel=kern, random_state=rnd, basis_sampling=tested_basis)
transformed = trans.fit(X).transform(X)
assert_equal(transformed.shape, (X.shape[0], 2))

# Go through all the kernels with each basis_method
basis_methods = ("random", "clustered")
for current_basis in basis_methods:
# Go through all the kernels with each basis_sampling method
basis_sampling_methods = ("random", "kmeans")
for current_basis in basis_sampling_methods:
yield test_nystroem_approximation_with_basis, current_basis


Expand All @@ -247,9 +247,9 @@ def test_nystroem_poly_kernel_params():

K = polynomial_kernel(X, degree=3.1, coef0=.1)
nystroem_random = Nystroem(kernel="polynomial", n_components=X.shape[0],
degree=3.1, coef0=.1, basis_method="random")
degree=3.1, coef0=.1, basis_sampling="random")
nystroem_k_means = Nystroem(kernel="polynomial", n_components=X.shape[0],
degree=3.1, coef0=.1, basis_method="clustered")
degree=3.1, coef0=.1, basis_sampling="kmeans")

transformed_k_means = nystroem_k_means.fit_transform(X)
transformed_random = nystroem_random.fit_transform(X)
Expand All @@ -274,13 +274,13 @@ def logging_histogram_kernel(x, y, log):
kernel_log = []
Nystroem(kernel=logging_histogram_kernel,
n_components=(n_samples - 1),
kernel_params={'log': kernel_log}, basis_method="clustered").fit(X)
kernel_params={'log': kernel_log}, basis_sampling="kmeans").fit(X)

assert_equal(len(kernel_log), n_samples * (n_samples - 1) / 2)

kernel_log = []
Nystroem(kernel=logging_histogram_kernel,
n_components=(n_samples - 1),
kernel_params={'log': kernel_log}, basis_method="random").fit(X)
kernel_params={'log': kernel_log}, basis_sampling="random").fit(X)

assert_equal(len(kernel_log), n_samples * (n_samples - 1) / 2)

0 comments on commit 5f313f8

Please sign in to comment.