Permalink
Browse files

ENH cluster comparison example (starting)

  • Loading branch information...
1 parent 40c45ca commit db1e1868beff8d3122102922f1ab7aad6f7780aa @amueller amueller committed Mar 3, 2012
Showing with 65 additions and 3 deletions.
  1. +62 −0 examples/cluster/plot_cluster_comparison.py
  2. +3 −3 sklearn/datasets/samples_generator.py
@@ -0,0 +1,62 @@
+"""
+=========================================================
+Comparing different clustering algorithms on toy datasets
+=========================================================
+
+This example aims at showing characteristics of different
+clustering algorithms on datasets that are "interesting"
+but still in 2D.
+
+While these examples give some intuition about the algorithms,
+this intuition might not apply to very high dimensional data.
+"""
+print __doc__
+
+import numpy as np
+import pylab as pl
+
+from sklearn.cluster import MeanShift, estimate_bandwidth
+from sklearn.cluster import KMeans
+from sklearn.cluster import Ward
+from sklearn.datasets import make_circles, make_moons, make_blobs
+
+# Generate datasets
+n_samples = 50
+circles = make_circles(n_samples=n_samples, factor=.5)
+moons = make_moons(n_samples=n_samples)
+noisy_moons = make_moons(n_samples=n_samples, noise=.1)
+blobs = make_blobs(n_samples=n_samples)
+
+i = 1
+colors = np.array([x for x in 'bgrcmykbgrcmykbgrcmykbgrcmyk'])
+colors = np.hstack([colors] * 5)
+
+for dataset in [circles, moons, noisy_moons, blobs]:
+ X, y = dataset
+ # estimate bandwidth for mean shift
+ bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=100)
+ #bandwidth = .2
+ print(bandwidth)
+
+ # create clustering estimators
+ ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
+ two_means = KMeans(k=2)
+ ten_means = KMeans(k=10)
+ #affinity = AffinityPropagation()
+ ward_five = Ward(n_clusters=5)
+ ward_two = Ward(n_clusters=2)
+
+ for algorithm in [two_means, ten_means, ms, ward_five]:
+ # predict cluster memberships
+ algorithm.fit(X)
+ y_pred = algorithm.labels_
+ pl.subplot(4, 4, i)
+ pl.title(str(algorithm).split('(')[0])
+ pl.scatter(X[:, 0], X[:, 1], color=colors[y_pred])
+ pl.xticks(())
+ pl.yticks(())
+
+ i += 1
+
+pl.show()
+
@@ -434,7 +434,7 @@ def make_regression(n_samples=100, n_features=100, n_informative=10, bias=0.0,
return X, y
-def make_circles(n_samples=100, shuffle=True, random_state=None):
+def make_circles(n_samples=100, shuffle=True, random_state=None, factor=.8):
"""Make a large circle containing a smaller circle in 2di
A simple toy dataset to visualize clustering and classification
@@ -460,8 +460,8 @@ def make_circles(n_samples=100, shuffle=True, random_state=None):
n_samples_out, n_samples_in = n_samples_out + 1, n_samples_in + 1
outer_circ_x = np.cos(np.linspace(0, 2 * np.pi, n_samples_out)[:-1])
outer_circ_y = np.sin(np.linspace(0, 2 * np.pi, n_samples_out)[:-1])
- inner_circ_x = np.cos(np.linspace(0, 2 * np.pi, n_samples_in)[:-1]) * 0.8
- inner_circ_y = np.sin(np.linspace(0, 2 * np.pi, n_samples_in)[:-1]) * 0.8
+ inner_circ_x = np.cos(np.linspace(0, 2 * np.pi, n_samples_in)[:-1]) * factor
+ inner_circ_y = np.sin(np.linspace(0, 2 * np.pi, n_samples_in)[:-1]) * factor
X = np.vstack((np.append(outer_circ_x, inner_circ_x),\
np.append(outer_circ_y, inner_circ_y))).T

0 comments on commit db1e186

Please sign in to comment.