
Bug in MeanShift with small number of samples #2356

Closed
larsmans opened this Issue · 13 comments

6 participants

@larsmans
Owner

The code from this SO question, reposted below for reference, sometimes works and sometimes fails. I'm too tired to check if this is really a bug, but if it isn't, the error message could be made friendlier:

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Generate sample data
centers = [
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]
X, _ = make_blobs(n_samples=100, centers=centers, cluster_std=0.6)

# Compute clustering with MeanShift

# The bandwidth can be detected automatically using estimate_bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)
@SHoltzen

This code runs successfully in my environment and produces the following output:

number of estimated clusters : 1

What version of scikit-learn are you running?

@larsmans
Owner

Bleeding edge on Linux x86-64 w/ NumPy 1.5 and variously SciPy 0.9 and 0.10. Which version are you running, and on which platform?

@SHoltzen

Aha, that would explain it: I'm using an outdated version, 0.13.1 (OS is Ubuntu 12.04 w/ NumPy 1.5). Forgot to update this VM!

This would mean that something went wrong somewhere between 0.13.1 and the current development version.

@amueller
Owner

I would imagine the problem is in estimate_bandwidth. Is that right? What do you get as bandwidth?

@larsmans
Owner

Somewhere around 3. It doesn't seem to matter for the crash, though.

@larsmans
Owner

Ah: sorted_centers in mean_shift is empty, so the estimated number of clusters seems to be zero for some reason. That empty array is what makes the NearestNeighbors call crash.
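One way to make the error friendlier, as asked in the opening comment, is to check for an empty set of centers before the neighbors search runs. The sketch below is a hypothetical guard: the name `require_centers` and the message text are illustrative, not scikit-learn's API.

```python
import numpy as np


def require_centers(sorted_centers, bandwidth):
    """Hypothetical guard to run before fitting NearestNeighbors on the centers.

    Raises a descriptive ValueError instead of letting the neighbors search
    fail on an empty array.
    """
    sorted_centers = np.asarray(sorted_centers, dtype=float)
    if sorted_centers.shape[0] == 0:
        raise ValueError(
            "No point was within bandwidth=%g of any seed; try a larger "
            "bandwidth or a different seeding strategy." % bandwidth
        )
    return sorted_centers


# Usage: an empty set of centers now produces a readable error
try:
    require_centers(np.empty((0, 15)), bandwidth=3.0)
except ValueError as exc:
    print(exc)
```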

@jbohg

I ran into the same problem @larsmans describes above, where sorted_centers in mean_shift is empty. I tracked it down to get_bin_seeds(...).

from collections import defaultdict

import numpy as np


def get_bin_seeds(X, bin_size, min_bin_freq=1):
    # Bin points
    bin_sizes = defaultdict(int)
    for point in X:
        # Truncates toward zero; same as the original np.cast[np.int32](point / bin_size)
        binned_point = (point / bin_size).astype(np.int32)
        bin_sizes[tuple(binned_point)] += 1

    # Select only those bins as seeds which have enough members
    bin_seeds = np.array([point for point, freq in bin_sizes.items() if
                          freq >= min_bin_freq], dtype=np.float32)
    bin_seeds = bin_seeds * bin_size
    return bin_seeds

The problem is caused by

binned_point = np.cast[np.int32](point / bin_size)

for two reasons. First, casting to int is not flooring: it truncates toward zero. With bin_size=1, every number in the open interval ]-1,1[ is therefore added to bin 0.
Second, all seed points end up at the corners of the bins rather than at their centres, which is what the function mean_shift assumes when it looks for the neighbours of each seed:

# Find mean of points within bandwidth
i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth,
                               return_distance=False)[0]

Therefore, in extreme cases, no point lies within a neighbourhood of radius bandwidth around any seed point, and sorted_centers ends up empty.
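A quick NumPy check of the two effects described above (a sketch; the sample values and bin_size=1 are arbitrary choices for illustration). It compares truncation, flooring and rounding, and measures how far each point is from the seed of its bin, taking the seed to be bin index times bin_size as in get_bin_seeds:

```python
import numpy as np

bin_size = 1.0
x = np.array([-0.9, -0.5, 0.5, 0.9])

truncated = (x / bin_size).astype(np.int32)  # what np.cast[np.int32] does
floored = np.floor(x / bin_size)             # flooring, for comparison
rounded = np.round(x / bin_size)             # nearest bin index

# Truncation sends every value in ]-1, 1[ to bin 0; flooring would not
print(truncated.tolist())  # [0, 0, 0, 0]
print(floored.tolist())    # [-1.0, -1.0, 0.0, 0.0]

# Distance from each point to the seed of its bin (seed = index * bin_size)
dist_trunc = np.abs(x - truncated * bin_size)
dist_round = np.abs(x - rounded * bin_size)
print(dist_trunc.max(), dist_round.max())  # 0.9 0.5
```

With truncation a point can sit almost a full bin away from its seed along each axis; with rounding it is at most half a bin away.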

I fixed this by replacing the cast with a rounding function:

binned_point = np.round(point / bin_size)

This fixes both problems mentioned above, so sorted_centers should never end up empty.
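Below is a sketch of get_bin_seeds with the np.round change applied. The function name `get_bin_seeds_rounded` and the three sample points are hypothetical; the body otherwise follows the function quoted above. The last lines count how many bins the same points would occupy under truncation, for comparison.

```python
from collections import defaultdict

import numpy as np


def get_bin_seeds_rounded(X, bin_size, min_bin_freq=1):
    """Sketch of get_bin_seeds using np.round instead of an int cast."""
    bin_sizes = defaultdict(int)
    for point in X:
        # Nearest grid index instead of truncation toward zero
        binned_point = np.round(point / bin_size)
        bin_sizes[tuple(binned_point)] += 1

    # Keep bins with enough members; grid index * bin_size = seed position
    bin_seeds = np.array([point for point, freq in bin_sizes.items()
                          if freq >= min_bin_freq], dtype=np.float32)
    return bin_seeds * bin_size


X = np.array([[0.9, 0.9], [1.1, 1.0], [0.95, 1.05]])
seeds = get_bin_seeds_rounded(X, bin_size=1.0)
print(seeds)  # a single seed at (1, 1), near all three points

# With truncation the same points fall into (0, 0), (1, 1) and (0, 1)
n_truncated_bins = len({tuple((p / 1.0).astype(np.int32)) for p in X})
print(n_truncated_bins)  # 3
```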

@larsmans
Owner

I can still get it to crash even with np.round as @jbohg suggests.

@davesque

I'm having this issue as well, and I'm not using bin seeding. Using sklearn 0.14.1, scipy 0.13.3, and numpy 1.8.0:

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

arr = np.array([[618.97993024, 616.77224179],
                [618.97993024, 616.77224179],
                [622.95978882, 621.22766862],
                [622.95978882, 621.22766862]])

# Bandwidth ends up being 0.0
bandwidth = estimate_bandwidth(arr, quantile=0.6)

clustering = MeanShift(bandwidth=bandwidth, bin_seeding=False)

# This line crashes with the same issue (sorted_centers is empty, causing
# NearestNeighbors to crash)
clustering.fit(arr)
@jbohg

I can reproduce this bug as well, especially in high dimensions (see the code example by @larsmans). Replacing np.cast with np.round definitely fixed one bug.

I think the problem is that get_bin_seeds uses a grid-based approach, while the nearest-neighbour search nbrs.radius_neighbors uses a radius-based approach. That means points in the corners of the bins won't be associated with the seeds, which in some cases leads to sorted_centers being empty.
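The grid-versus-radius mismatch can be quantified. A point in a bin can be up to half a bin width from the bin's centre along every axis, so up to bin_size·√d/2 away in Euclidean distance. The sketch below assumes bin_size equals the bandwidth (the grid is built from the bandwidth) and shows that with rounding the worst case exceeds the bandwidth once d > 4; with truncation it approaches bin_size·√d. The helper name is hypothetical.

```python
import numpy as np


def worst_case_seed_distance(d, bin_size=1.0, rounding=True):
    """Largest distance from a point in a bin to that bin's seed, in d dims.

    Rounding: half a bin along each axis. Truncation: the supremum is a full
    bin along each axis (approached, never reached, by points inside the bin).
    """
    per_axis = bin_size / 2 if rounding else bin_size
    return per_axis * np.sqrt(d)


bandwidth = 1.0  # assumption: bin_size == bandwidth
for d in (1, 2, 4, 5, 15):
    r = worst_case_seed_distance(d, bin_size=bandwidth)
    t = worst_case_seed_distance(d, bin_size=bandwidth, rounding=False)
    print("d=%2d  rounding=%.3f  truncation=%.3f  exceeds bandwidth: %s"
          % (d, r, t, r > bandwidth))
```

For the 15-dimensional data in the original report this gives about 1.94 bandwidths with rounding and up to about 3.87 with truncation, so a seed can easily have no data point within its radius.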

@kastnerkyle
Owner

Using @larsmans's original code from above, when I switch to np.round I get "number of estimated clusters : 2". So things are working for me at least... What is happening or changing to get the error in high dimensions? Does anyone have a gist of what they modified?

Versions (Anaconda 2.0):
Ubuntu 12.04
Python 3.4.1 |Anaconda 2.0.0 (64-bit)| (default, May 19 2014, 13:02:41)
[GCC 4.1.2 20080704 (Red Hat 4.1.2-54)] on linux

amueller modified the milestone: 0.15.1, 0.15
@amueller
Owner

@davesque, your bandwidth is 0.0. Still, the error message is not great.
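Why the bandwidth collapses to 0.0 here: as far as I understand the heuristic, estimate_bandwidth averages, over all points, the distance to the k-th nearest neighbour (counting the point itself), with k = int(n_samples * quantile). The sketch below is a NumPy-only re-implementation of that assumption, not scikit-learn's code, plus a hypothetical `check_bandwidth` that reports the problem directly. With four points and quantile=0.6, k = 2, and each point's two nearest neighbours are itself and its exact duplicate.

```python
import numpy as np


def knn_bandwidth(X, quantile=0.3):
    """Assumed k-nearest-neighbour bandwidth heuristic (sketch)."""
    X = np.asarray(X, dtype=float)
    k = max(int(X.shape[0] * quantile), 1)
    # Pairwise Euclidean distances; each row includes the point itself (0.0)
    dists = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    # Distance to the k-th nearest neighbour of every point, averaged
    return np.sort(dists, axis=1)[:, k - 1].mean()


def check_bandwidth(bandwidth):
    """Hypothetical validation with a clearer message than the later crash."""
    if not bandwidth > 0:
        raise ValueError(
            "Estimated bandwidth is %g; the data may contain too many duplicate "
            "points for this quantile. Try a larger quantile or pass a "
            "bandwidth explicitly." % bandwidth
        )
    return bandwidth


arr = np.array([[618.97993024, 616.77224179],
                [618.97993024, 616.77224179],
                [622.95978882, 621.22766862],
                [622.95978882, 621.22766862]])

try:
    check_bandwidth(knn_bandwidth(arr, quantile=0.6))
except ValueError as exc:
    print(exc)
```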

@amueller
Owner

I agree with @jbohg that we should use round; see below for an illustration:
[Figure: mean_shift_bin_seeding]

That does not fully solve the problem, though. The reason this breaks in high dimensions is the curse of dimensionality (and our estimate_bandwidth heuristic). We estimate the bandwidth based on the maximum distance to some neighbors and then use it to construct a grid. But as the dimensionality rises, the distance from a point to the nearest grid point is no longer bounded by the bandwidth: it grows with the square root of the dimension. I think the simplest solution is to bail out of get_bin_seeds if each point lies in its own bin.
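A minimal sketch of that bail-out, assuming rounding to the nearest grid index. The function name `bin_seeds_with_fallback` and the warning text are hypothetical; this version counts bins with np.unique rather than the defaultdict used in get_bin_seeds. If binning leaves every point in its own bin, the grid gives no reduction, so the data points themselves are returned as seeds.

```python
import warnings

import numpy as np


def bin_seeds_with_fallback(X, bin_size, min_bin_freq=1):
    """Grid-based seeds; fall back to the data points if binning merges nothing."""
    X = np.asarray(X, dtype=float)
    # Round every point to its nearest grid index and count points per bin
    bins, counts = np.unique(np.round(X / bin_size), axis=0,
                             return_counts=True)
    if len(bins) == len(X):
        # Each point occupies its own bin (typical in high dimensions)
        warnings.warn("Binning failed with bin_size=%g; using the data "
                      "points as seeds." % bin_size)
        return X
    # Keep well-populated bins; grid index * bin_size = seed position
    return bins[counts >= min_bin_freq] * bin_size


# Usage: two nearby points share a bin, so binning is kept
print(bin_seeds_with_fallback([[0.1, 0.1], [0.2, 0.3], [5.0, 5.0]], 1.0))

# Usage: widely spread points each get their own bin, so X is returned
print(bin_seeds_with_fallback(np.eye(5) * 10, 1.0).shape)
```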

amueller closed this in #4176