
Conversation

IshankGulati
Contributor

added kernel k-means with tests and examples (Issue #5373)

Member

I believe one extra label update is needed here, after the loop: because cluster centers have been updated but training labels haven't, it's possible for the loop to finish in a way that self.fit(X).predict(X) returns different results than self.fit_predict(X). See e.g. #5231
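A minimal sketch of what that extra pass might look like at the end of fit (the _compute_dist helper, the within_distances buffer, and the update_within flag follow names discussed elsewhere in this PR, so treat this as an assumption about the implementation rather than its actual code):

import numpy as np

def _final_label_update(self, K, within_distances):
    # One last distance computation against the final cluster assignments,
    # without updating the within-cluster terms, so that labels_ agrees
    # with what predict(X) would return on the training data.
    n_samples = K.shape[0]
    dist = np.zeros((n_samples, self.n_clusters))
    self._compute_dist(K, dist, within_distances, update_within=False)
    self.labels_ = dist.argmin(axis=1)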

@jakevdp
Member

jakevdp commented Oct 20, 2015

Really great start! Nice, clean, easy to read code.

One overall comment: I think an n_init parameter would be important to add to the estimator as well.

Member

Print statement needs parentheses. Also, we probably want from __future__ import print_function at the top of the file.
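For example (an illustration of the Python 2/3-compatible form, not code from this branch):

from __future__ import print_function

n_iter = 10  # placeholder value
print("Converged at iteration", n_iter)  # function-call syntax works on Python 2 and 3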

Member

from __future__ import division would be useful as well – then you wouldn't have to cast to float in e.g. float(n_same) / n_samples and other places (also prevents potential hard-to-track errors between Python 2 & 3)
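For example (again only an illustration):

from __future__ import division

n_same, n_samples = 3, 4
score = n_same / n_samples  # 0.75 on both Python 2 and 3, no float() cast needed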

@IshankGulati
Contributor Author

@jakevdp I have made some of the corrections. This code was not written entirely by me; I added the tests, the examples, and some minor modifications, so I have a few doubts:
1. Is n_init required to take care of the random label initialization?
2. If yes, what should the criterion be for selecting the best run among all?
3. Is the last label update (outside the loop) needed because of the earlier within_distances update?

Contributor

mask_j? Or is it too heavy?

@jakevdp
Member

jakevdp commented Oct 21, 2015

If yes, what should the criterion be for selecting the best run among all?

In standard k-means, the chosen result is the one that minimizes the "inertia", i.e. the sum of squared distances of the samples to their cluster centers. We could do something similar here.

@jakevdp
Member

jakevdp commented Oct 21, 2015

Is the last label update (outside the loop) needed because of the earlier within_distances update?

The last label update is needed because the cluster centers might change at the end of the final iteration; the labels_ attribute should match these final cluster centers.

@IshankGulati
Contributor Author

@jakevdp So what do you suggest should be used in place of inertia?

@jakevdp
Member

jakevdp commented Oct 21, 2015

I think you can generalize the inertia to be the sum-of-square-kernel-dissimilarities. These dissimilarities are already computed within each iteration; I think it's as easy as doing inertia = np.sum((dist[self.labels_] ** 2)), though I'd want a second opinion on that 😄 (perhaps @mblondel can chime in)
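A rough sketch of how an n_init loop could keep the best run under that criterion (the _kernel_kmeans_single helper and its return values are hypothetical, purely to illustrate the selection logic):

import numpy as np

def fit_best_of_n_init(K, n_clusters, n_init):
    # Keep the run with the lowest generalized inertia, i.e. the sum of
    # squared kernel-space distances of each sample to its assigned cluster.
    best_inertia, best_labels = np.inf, None
    for seed in range(n_init):
        # Hypothetical helper: one run from a random label initialization,
        # returning the labels and an (n_samples, n_clusters) distance matrix.
        labels, dist = _kernel_kmeans_single(K, n_clusters, random_state=seed)
        inertia = np.sum(dist[np.arange(len(labels)), labels] ** 2)
        if inertia < best_inertia:
            best_inertia, best_labels = inertia, labels
    return best_labels, best_inertia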

@IshankGulati force-pushed the kmeans branch 9 times, most recently from ca4bc24 to f629b54 on February 21, 2016
@IshankGulati
Contributor Author

Sorry for the long absence. I completely forgot about this PR.
I have made some changes.
@jakevdp Can you please review?

@IshankGulati force-pushed the kmeans branch 2 times, most recently from 1800df7 to e40c184 on February 21, 2016
@jakevdp
Member

jakevdp commented Feb 21, 2016

I'd rebase on master - I think those CircleCI issues have been fixed.

@IshankGulati
Contributor Author

@jakevdp I have already done that, but the CircleCI build is still failing.


kernel : {'linear_kernel', 'polynomial_kernel', 'sigmoid_kernel',
'rbf_kernel', 'chi2_kernel'}, default: 'linear_kernel'
The type of kernel to be used in kernel k means algorithm.
Member

I don't think this will render well in the HTML docs. After the colon you should specify the type, not the allowed values. See svm for an example of what I mean.
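Something along these lines would match the convention used elsewhere (the exact wording here is only a suggestion):

kernel : string or callable, default='linear'
    Kernel mapping used to compute the kernel matrix. Must be one of the
    kernels supported by sklearn.metrics.pairwise.pairwise_kernels
    (e.g. 'linear', 'poly', 'rbf', 'sigmoid', 'chi2') or a callable.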

Member

also, ending each with _kernel is a bit redundant. Throughout the rest of the package we tend to use just "linear", "sigmoid", etc.

Contributor Author

I have corrected this. Please have a look.

@jakevdp
Member

jakevdp commented Feb 21, 2016

Looks pretty good, but tests are currently lacking. There are about 10 parameters that control the behavior of the KernelKMeans object, but only a few of them are ever exercised in the test script as far as I can tell. In particular, we should run a test of all the valid kernel types.

@IshankGulati force-pushed the kmeans branch 2 times, most recently from fa4b135 to a1982f6 on February 21, 2016
@IshankGulati force-pushed the kmeans branch 5 times, most recently from 4efa460 to b25c223 on February 7, 2017
@IshankGulati
Contributor Author

@tjnycum @amueller I have completed the requested changes, but one common test is failing where denom is calculated in _compute_dist():
denom = sw[mask].sum()
I have tried replacing it with
denom = sum([x * y for x, y in zip(sw, mask)])
but that doesn't solve the problem.

@IshankGulati
Contributor Author

@amueller @tjnycum Ignore the above comment. All the tests are passing now.

Member

@jnothman left a comment

a first pass


* - :ref:'Kernel K-Means'
- number of clusters
- Medium ``n_clusters`` and ``n_samples``
Member

Indentation

Kernel K-Means
==============

The :class:`KernelKMeans` algorithm is an enhancement of the :class:`KMeans` algorithm
Member

Please try to keep lines under 80 chars

@@ -0,0 +1,287 @@
"""
Kernel K-means clustering
Member

Please put on first line

Data used in clustering.
n_iter_ : Iteration in which algorithm converged
"""
Member

References section belongs here, as does perhaps a brief doctest example


# sanity check: re-predict labeling for training set samples
pred = clf.predict(X)
assert_array_equal(pred, clf.labels_)
Member

You should also check that prediction is possible on a non-training dataset, and results should at least differ between datasets.
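A sketch of such a test, assuming the constructor arguments discussed in this PR (the dataset construction is made up for illustration; a real test would reuse the module's existing fixtures):

import numpy as np
from numpy.testing import assert_array_equal
from sklearn.datasets import make_blobs
from sklearn.cluster import KernelKMeans  # as added on this branch

X, _ = make_blobs(n_samples=60, centers=3, random_state=0)
X_new, _ = make_blobs(n_samples=30, centers=3, random_state=1)

clf = KernelKMeans(n_clusters=3, random_state=0).fit(X)
assert_array_equal(clf.predict(X), clf.labels_)  # training predictions match labels_
pred_new = clf.predict(X_new)                    # prediction on unseen data works
assert pred_new.shape == (X_new.shape[0],)
# A fuller test could also check that the two labelings are not identical.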

==============

The :class:`KernelKMeans` algorithm is an enhancement of the :class:`KMeans` algorithm
which uses a kernel function to generate an appropriate non-linear mapping drom the
Member

drom -> from

The :class:`KernelKMeans` algorithm is an enhancement of the :class:`KMeans` algorithm
which uses a kernel function to generate an appropriate non-linear mapping drom the
original (input) space to a higher-dimensional feature space to extract clusters
that are non-linearly seperable in input space.
Member

"non-" -> "not"

sample_weight_ : array-like, shape=(n_samples,)
labels_ : shape=(n_samples,)
Labels of each point.
Member

*each training sample

K : Kernel matrix
dist : array-like, shape=(n_samples, n_clusters)
Distance of each sample from cluster centers.
Member

Note that this array is used for output; as an input, it should be all-zero.

within_distances : array, shape=(n_clusters,)
Distance update.
update_within : {true, false}
Member

{True, False} or just bool

@IshankGulati force-pushed the kmeans branch 3 times, most recently from d354daa to 7e235f6 on February 24, 2017
@IshankGulati
Contributor Author

@jnothman I have completed all the changes.

@codecov

codecov bot commented Feb 24, 2017

Codecov Report

Merging #5483 into master will decrease coverage by 0.7%.
The diff coverage is 98.3%.


@@            Coverage Diff             @@
##           master    #5483      +/-   ##
==========================================
- Coverage   96.19%   95.48%   -0.71%     
==========================================
  Files         348      344       -4     
  Lines       64645    61084    -3561     
==========================================
- Hits        62187    58328    -3859     
- Misses       2458     2756     +298
Impacted Files Coverage Δ
sklearn/cluster/__init__.py 100% <100%> (ø) ⬆️
sklearn/cluster/tests/test_kernel_kmeans.py 100% <100%> (ø)
sklearn/cluster/kernel_kmeans.py 97.32% <97.32%> (ø)
sklearn/utils/random.py 59.29% <0%> (-32.2%) ⬇️
sklearn/utils/arpack.py 43.05% <0%> (-31.95%) ⬇️
sklearn/datasets/tests/test_kddcup99.py 29.16% <0%> (-10.84%) ⬇️
sklearn/manifold/mds.py 84.46% <0%> (-9.71%) ⬇️
sklearn/linear_model/tests/test_bayes.py 83.01% <0%> (-8.16%) ⬇️
sklearn/manifold/spectral_embedding_.py 85.16% <0%> (-6.25%) ⬇️
sklearn/multioutput.py 87.05% <0%> (-5.5%) ⬇️
... and 260 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 3e29334...6ca0e63

@lrusnac

lrusnac commented Mar 10, 2017

Hi @IshankGulati, I got here because I'm trying to use k-means with cosine distance. It looks like the kernel option can be used for that? I see a callable option but couldn't find an example of how to write and use a custom kernel. Could you point me in the right direction, please? Would it work to implement it like this and set kernel=my_cosine_function? Thanks!

@IshankGulati
Contributor Author

Hi @lrusnac, sklearn already has a cosine kernel. If you want to use a custom kernel, you can follow the link I mentioned to implement one and then pass it as a callable. If you face any problem, feel free to ping me.
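A short sketch of both options, assuming this branch is installed and the KernelKMeans constructor accepts either a kernel name understood by pairwise_kernels or a callable:

from sklearn.cluster import KernelKMeans  # added on this branch
from sklearn.datasets import make_blobs
from sklearn.metrics.pairwise import cosine_similarity

X, _ = make_blobs(n_samples=100, centers=5, random_state=0)

# Built-in cosine kernel by name...
km = KernelKMeans(n_clusters=5, kernel="cosine", random_state=0).fit(X)

# ...or any callable that maps (X, Y) to a kernel matrix.
km = KernelKMeans(n_clusters=5, kernel=cosine_similarity, random_state=0).fit(X)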

@jnothman
Member

jnothman commented Mar 12, 2017 via email

@lrusnac

lrusnac commented Mar 14, 2017

@IshankGulati and @jnothman thanks.
I was able to use cosine_similarity, but I ran into an issue.

KernelKMeans has no n_jobs parameter, so it doesn't run in parallel. I found that the bottleneck is in the function _get_kernel, where it calls pairwise_kernels. That function has an n_jobs parameter, but it defaults to 1, so only one core is used. Setting it to -1 increased the speed of clustering. Is there a reason KernelKMeans has no parallelisation, or is it something that got forgotten?
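For reference, the underlying call that benefits from parallelism looks roughly like this (a standalone sketch, not the PR's exact code):

import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels

X = np.random.RandomState(0).rand(1000, 20)

# pairwise_kernels already accepts n_jobs; the estimator could simply forward it.
K = pairwise_kernels(X, metric="rbf", gamma=0.1, n_jobs=-1)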

@jnothman
Member

jnothman commented Mar 14, 2017 via email

@ajilling

Is anyone planning on merging this?

@jnothman
Member

I suppose we need to decide if it's something we're keen on, or whether it would better reside in scikit-learn-extra. Have you tried using it, @ajilling? What does it help you solve, and is it efficient?

@ajilling

@jnothman I have used it. It has one weakness in that it fails whenever it encounters an empty cluster - this differs from most of the other sklearn clustering implementations which seem to be able to handle that. Efficiency is on par with everything else.

I'm working on a project where I apply different clustering algorithms to datasets and compare the results. Sklearn is the first place I look, so to me, the more algorithms in there the better. Otherwise I have to resort to unknown packages or embedded R code.

@jnothman
Member

jnothman commented Jun 22, 2019 via email

@ogrisel
Member

ogrisel commented Oct 20, 2021

Just in case someone finds this closed issue by googling, one alternative would be to do:

from sklearn.pipeline import make_pipeline
from sklearn.decomposition import PCA
from sklearn.kernel_approximation import Nystroem
from sklearn.cluster import KMeans


X_train = ...

fake_kernel_kmeans = make_pipeline(
    Nystroem(n_components=1000, kernel="rbf", gamma=1e-3),  # gamma needs to be tuned
    PCA(50),
    KMeans(n_clusters=10),
).fit(X_train)

Note: I am not sure whether the PCA step to reduce the dimensionality prior to k-means is helpful or detrimental. Maybe dropping the PCA step and reducing n_components, or running k-means in high dimension, are better options, for a given notion of "better" that covers computational concerns, qualitative clustering quality (which probably depends on n_clusters), or downstream predictive accuracy metrics if the clustering is used as a feature extraction strategy for a downstream supervised learning task.

It's probably also necessary to add a column transformer that one-hot encodes the categorical features and applies StandardScaler, QuantileTransformer, or SplineTransformer to the numerical features to get a meaningful kernel expansion.
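A rough sketch of that preprocessing, reusing the pipeline above (cat_cols and num_cols are hypothetical column lists for a DataFrame X_train):

from sklearn.compose import make_column_transformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

preprocessing = make_column_transformer(
    (OneHotEncoder(handle_unknown="ignore"), cat_cols),
    (StandardScaler(), num_cols),
)

fake_kernel_kmeans = make_pipeline(
    preprocessing,
    Nystroem(n_components=1000, kernel="rbf", gamma=1e-3),  # gamma needs to be tuned
    PCA(50),
    KMeans(n_clusters=10),
).fit(X_train)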
