Bug in MeanShift with small number of samples #2356

Closed
larsmans opened this Issue Aug 10, 2013 · 13 comments

6 participants

larsmans commented Aug 10, 2013

The code from this SO question, reposted below for reference, sometimes works and sometimes fails. I'm too tired to check if this is really a bug, but if it isn't, the error message could be made friendlier:

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Generate sample data
centers = [
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]
X, _ = make_blobs(n_samples=100, centers=centers, cluster_std=0.6)

# Compute clustering with MeanShift

# The following bandwidth can be automatically detected using estimate_bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

This code runs successfully in my environment and produces the following output:

number of estimated clusters : 1

What version of scikit-learn are you running?

larsmans commented Aug 16, 2013

Bleeding edge on Linux x86-64 w/ NumPy 1.5 and variously SciPy 0.9 and 0.10. Which version are you running, and on which platform?

Aha, that would explain it: I'm using an outdated version, 0.13.1 (OS is Ubuntu 12.04 w/ NumPy 1.5). Forgot to update this VM!

This would mean that something went wrong somewhere between scikit-learn 0.13.1 and the current development version (1.5 is the NumPy version, not a scikit-learn release).

amueller commented Aug 17, 2013

I would imagine the problem is in estimate_bandwidth. Is that right? What bandwidth do you get?

larsmans commented Aug 18, 2013

Somewhere around 3. It doesn't seem to matter for the crash, though.

larsmans commented Aug 18, 2013

Ah: sorted_centers in mean_shift is empty, so the estimated number of clusters seems to be zero for some reason. That makes the subsequent NearestNeighbors call crash.
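A friendlier failure mode could be a guard that checks for an empty set of centres before the neighbour search runs. The sketch below is hypothetical: `check_centers` is an illustrative helper, not part of scikit-learn, and the error text is only a suggestion.

```python
import numpy as np


def check_centers(sorted_centers, bandwidth):
    """Raise a descriptive error instead of letting the neighbour search crash.

    Hypothetical helper: the name and message are illustrative only.
    """
    centers = np.asarray(sorted_centers)
    if centers.shape[0] == 0:
        raise ValueError(
            "No cluster centers were found with bandwidth=%r; "
            "try a larger bandwidth." % (bandwidth,)
        )
    return centers


# An empty array of 15-dimensional centres triggers the descriptive error.
message = None
try:
    check_centers(np.empty((0, 15)), bandwidth=3.0)
except ValueError as exc:
    message = str(exc)
```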

jbohg commented Aug 28, 2013

I ran into the same problem as @larsmans describes above where sorted_centers in mean_shift is empty. I tracked it down to get_bin_seeds(...).

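from collections import defaultdict

import numpy as np


def get_bin_seeds(X, bin_size, min_bin_freq=1):
    # Bin points: map each point to an integer grid index
    bin_sizes = defaultdict(int)
    for point in X:
        # Equivalent to the original np.cast[np.int32](point / bin_size)
        # (np.cast was removed in NumPy 2.0); the cast truncates toward zero
        binned_point = (point / bin_size).astype(np.int32)
        bin_sizes[tuple(binned_point)] += 1

    # Select only those bins as seeds which have enough members
    # (the original iterated with six.iteritems for Python 2 compatibility)
    bin_seeds = np.array([point for point, freq in bin_sizes.items() if
                          freq >= min_bin_freq], dtype=np.float32)
    # Scale grid indices back to coordinates
    bin_seeds = bin_seeds * bin_size
    return bin_seeds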

The problem is caused by

binned_point = np.cast[np.int32](point / bin_size)

First, casting to int truncates toward zero rather than flooring. Let's say bin_size=1: then every number in the open interval (-1, 1) is added to bin 0, which is therefore twice as wide as every other bin.
Second, all seed points end up at the corners of the bins rather than at their centres, as assumed in the function mean_shift:

# Find mean of points within bandwidth
i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth,
                               return_distance=False)[0]

Therefore, in extreme cases, no data point lies within distance bandwidth of a seed point, and sorted_centers ends up empty.
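Both effects can be seen in one dimension with plain NumPy. Here `.astype(np.int32)` stands in for `np.cast[np.int32]`, which casts the same way. The values are made up for illustration:

```python
import numpy as np

bin_size = 1.0

# 1) Casting to int truncates toward zero instead of flooring, so values on
#    both sides of zero share bin 0, which spans (-1, 1): twice the width
#    of every other bin.
near_zero = np.array([-0.9, -0.1, 0.1, 0.9])
truncated = (near_zero / bin_size).astype(np.int32)        # all zeros
floored = np.floor(near_zero / bin_size).astype(np.int32)  # split into -1 and 0

# 2) A seed is rebuilt as bin_index * bin_size, which is the lower corner of
#    the bin [1, 2), not its centre at 1.5.
positive = np.array([1.2, 1.9])
bin_index = (positive / bin_size).astype(np.int32)
seed = bin_index[0] * bin_size
bin_centre = (bin_index[0] + 0.5) * bin_size
```

In d dimensions the same offset applies along every axis, so the gap between a corner seed and the data it collected grows with d.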

I fixed this by replacing the cast with a rounding function:

binned_point = np.round(point / bin_size)

This fixes both problems mentioned above, so sorted_centers should never end up empty.
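The following sketch shows get_bin_seeds with @jbohg's one-line change applied. It is renamed `get_bin_seeds_rounded` here (a hypothetical name) and uses `dict.items()` in place of the original `six.iteritems`, assuming Python 3:

```python
from collections import defaultdict

import numpy as np


def get_bin_seeds_rounded(X, bin_size, min_bin_freq=1):
    """Grid-based seeding with rounding instead of truncation (sketch)."""
    bin_sizes = defaultdict(int)
    for point in X:
        # Rounding assigns each point to its nearest grid node, so all bins
        # have the same width and the node sits at the centre of its bin.
        binned_point = np.round(point / bin_size)
        bin_sizes[tuple(binned_point)] += 1

    # Keep only bins with enough members, then map grid indices back to
    # coordinates.
    bin_seeds = np.array([point for point, freq in bin_sizes.items()
                          if freq >= min_bin_freq], dtype=np.float32)
    return bin_seeds * bin_size


X = np.array([[1.2, 1.9], [0.8, 2.1], [-0.4, -0.6]])
seeds = get_bin_seeds_rounded(X, bin_size=1.0)
```

The first two points share the node (1, 2) and the third rounds to (0, -1). Every point is now at most sqrt(d) * bin_size / 2 from its seed.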

larsmans commented Mar 8, 2014

I can still get it to crash even with np.round as @jbohg suggests.

I'm having this issue as well, and I'm not using bin seeding. I'm using sklearn 0.14.1, scipy 0.13.3, and numpy 1.8.0:

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

arr = np.array([[618.97993024, 616.77224179],
                [618.97993024, 616.77224179],
                [622.95978882, 621.22766862],
                [622.95978882, 621.22766862]])

# Bandwidth ends up being 0.0
bandwidth = estimate_bandwidth(arr, quantile=0.6)

# The data goes to fit(), not to the constructor (whose first
# positional parameter is bandwidth)
clustering = MeanShift(bandwidth=bandwidth, bin_seeding=False)

# This line crashes with the same issue (sorted_centers is empty, causing
# NearestNeighbors to crash)
clustering.fit(arr)
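A NumPy-only sketch of the heuristic shows why the bandwidth collapses to zero here. It assumes that estimate_bandwidth averages, over all points, the distance to each point's k-th nearest neighbour (the point itself included), with k = int(n_samples * quantile). Check the installed version's source before relying on that detail.

```python
import numpy as np

arr = np.array([[618.97993024, 616.77224179],
                [618.97993024, 616.77224179],
                [622.95978882, 621.22766862],
                [622.95978882, 621.22766862]])
quantile = 0.6

# Assumed heuristic: k = int(n_samples * quantile), and each point contributes
# the distance to its k-th nearest neighbour, the point itself included.
k = max(1, int(arr.shape[0] * quantile))  # int(4 * 0.6) == 2

# Full pairwise Euclidean distance matrix (fine for a handful of points).
diff = arr[:, None, :] - arr[None, :, :]
dists = np.sqrt((diff ** 2).sum(axis=-1))

# Sort each row; column k - 1 holds the k-th nearest neighbour distance.
kth_dist = np.sort(dists, axis=1)[:, k - 1]
bandwidth = kth_dist.mean()
```

Every point has an exact duplicate, so its second-nearest neighbour is at distance 0 and the average is 0.0.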

jbohg commented Mar 20, 2014

I can reproduce this bug as well, especially in high dimensions (see the code example by @larsmans). Replacing np.cast with np.round definitely fixed one bug.

I think the problem is that get_bin_seeds uses a grid-based approach, while the nearest-neighbour search nbrs.radius_neighbors uses a radius-based approach. As a result, points in the corners of the bins are not associated with the seeds, which in some cases leaves sorted_centers empty.
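The mismatch between the grid and the radius search can be quantified. Assume each bin is a hypercube of side bin_size with the seed at its centre (the np.round variant). Then the farthest point in the bin is a corner at distance sqrt(d) * bin_size / 2, which exceeds a search radius equal to bin_size once d > 4. Setting bin_size and bandwidth both to 1.0 is an assumption made for illustration.

```python
import math

bin_size = 1.0
bandwidth = 1.0  # assumed equal to bin_size for this illustration


def corner_distance(d, bin_size):
    """Distance from the centre of a d-dimensional cubic bin to a corner."""
    return math.sqrt(d) * bin_size / 2.0


# Dimensions (up to the 15 used in the original example) in which a point at
# a bin corner falls outside the radius-`bandwidth` ball around the seed.
unreachable_dims = [d for d in range(1, 16)
                    if corner_distance(d, bin_size) > bandwidth]
```

With the original truncation, the seed sits at a corner, so the opposite corner is sqrt(d) * bin_size away and the problem already appears for d >= 2.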

kastnerkyle commented Jun 2, 2014

Using @larsmans' original code from above with np.round, I get "number of estimated clusters : 2". So things are working for me, at least. What is happening or changing to trigger the error in high dimensions? Does anyone have a gist of what they modified?

Versions (Anaconda 2.0):
Ubuntu 12.04
Python 3.4.1 |Anaconda 2.0.0 (64-bit)| (default, May 19 2014, 13:02:41)
[GCC 4.1.2 20080704 (Red Hat 4.1.2-54)] on linux

amueller modified the milestones: 0.15.1, 0.15 on Jul 18, 2014

amueller commented Jan 28, 2015

@davesque, your bandwidth is 0.0. Still, the error message is not great.
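One way to make the zero-bandwidth case fail clearly is to validate the value before fitting. `check_bandwidth` below is a hypothetical helper for illustration, not scikit-learn's API. It lets None through, on the assumption that None means "estimate the bandwidth automatically", as in MeanShift's default.

```python
def check_bandwidth(bandwidth):
    """Reject non-positive bandwidths with an explicit message (hypothetical)."""
    if bandwidth is not None and bandwidth <= 0:
        raise ValueError(
            "bandwidth must be positive, got %r. estimate_bandwidth can "
            "return 0.0 when samples are exact duplicates; pass an explicit "
            "bandwidth or try a larger quantile." % (bandwidth,)
        )
    return bandwidth


error_text = None
try:
    check_bandwidth(0.0)
except ValueError as exc:
    error_text = str(exc)
```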

amueller commented Jan 28, 2015

I agree with @jbohg that we should use round; see below for an illustration:
[Figure: mean_shift_bin_seeding]

That does not fully solve the problem, though. The reason this breaks in high dimensions is the curse of dimensionality (and our estimate_bandwidth heuristic). We estimate the bandwidth from the maximum distance to some neighbours and then use it to construct a grid. But as dimensionality rises, the distance from a point to its nearest grid point grows without bound: with bins of side bin_size in d dimensions it can reach sqrt(d) * bin_size / 2, which eventually exceeds the bandwidth. I think the simplest solution is to bail out of get_bin_seeds if each point lies in its own bin.
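Here is a minimal sketch of that bail-out, assuming the np.round binning suggested by @jbohg. `get_bin_seeds_or_points` is a hypothetical name. Falling back to the data points themselves as seeds is one reasonable choice, not necessarily what the eventual fix does.

```python
from collections import defaultdict

import numpy as np


def get_bin_seeds_or_points(X, bin_size, min_bin_freq=1):
    """Bin-based seeding that falls back to the data when binning is useless."""
    bin_sizes = defaultdict(int)
    for point in X:
        bin_sizes[tuple(np.round(point / bin_size))] += 1

    # Every point landed in its own bin: the grid compressed nothing, and in
    # high dimensions its nodes can lie farther than the bandwidth from any
    # point, so use the data points themselves as seeds instead.
    if len(bin_sizes) == len(X):
        return np.asarray(X, dtype=float)

    bin_seeds = np.array([point for point, freq in bin_sizes.items()
                          if freq >= min_bin_freq], dtype=np.float32)
    return bin_seeds * bin_size


# 15 well-separated points in 15 dimensions: each gets its own bin.
X_sparse = 5.0 * np.eye(15)
# Two nearby points share a bin, so binning is kept.
X_dense = np.array([[0.1, 0.1], [0.2, -0.1], [3.0, 3.0]])
```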

amueller closed this in #4176 on Feb 7, 2015
