Permalink
Browse files

MISC: tweak example layout

  • Loading branch information...
1 parent c2b356b commit df558e84beb02665056e3af89ab0952569fc7cef @GaelVaroquaux GaelVaroquaux committed Mar 25, 2012
Showing with 18 additions and 20 deletions.
  1. +18 −20 examples/cluster/plot_cluster_comparison.py
@@ -25,31 +25,26 @@
import numpy as np
import pylab as pl
-from sklearn.cluster import MeanShift, estimate_bandwidth
-from sklearn.cluster import KMeans
-from sklearn.cluster import Ward
-from sklearn.cluster import SpectralClustering
-from sklearn.cluster import DBSCAN
-from sklearn.cluster import AffinityPropagation
+from sklearn import cluster, datasets
from sklearn.metrics import euclidean_distances
from sklearn.neighbors import kneighbors_graph
-from sklearn.datasets import make_circles, make_moons, make_blobs
from sklearn.preprocessing import Scaler
np.random.seed(0)
# Generate datasets
n_samples = 300
-noisy_circles = make_circles(n_samples=n_samples, factor=.5, noise=.05)
-noisy_moons = make_moons(n_samples=n_samples, noise=.05)
-blobs = make_blobs(n_samples=n_samples, random_state=8)
+noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5,
+ noise=.05)
+noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
+blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
no_structure = np.random.rand(n_samples, 2), None
colors = np.array([x for x in 'bgrcmykbgrcmykbgrcmykbgrcmyk'])
colors = np.hstack([colors] * 20)
-pl.figure(figsize=(14, 10))
-pl.subplots_adjust(left=.001, right=.999, bottom=.01, top=.95, wspace=.05,
+pl.figure(figsize=(14, 9.5))
+pl.subplots_adjust(left=.001, right=.999, bottom=.001, top=.96, wspace=.05,
hspace=.01)
plot_num = 1
@@ -60,7 +55,7 @@
X = Scaler().fit_transform(X)
# estimate bandwidth for mean shift
- bandwidth = estimate_bandwidth(X, quantile=0.3)
+ bandwidth = cluster.estimate_bandwidth(X, quantile=0.3)
# connectivity matrix for structured Ward
connectivity = kneighbors_graph(X, n_neighbors=10)
@@ -71,12 +66,12 @@
distances = euclidean_distances(X)
# create clustering estimators
- ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
- two_means = KMeans(k=2)
- ward_five = Ward(n_clusters=2, connectivity=connectivity)
- spectral = SpectralClustering(k=2, mode='arpack')
- dbscan = DBSCAN(eps=.3)
- affinity_propagation = AffinityPropagation(damping=.9)
+ ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
+ two_means = cluster.KMeans(k=2)
+ ward_five = cluster.Ward(n_clusters=2, connectivity=connectivity)
+ spectral = cluster.SpectralClustering(k=2, mode='arpack')
+ dbscan = cluster.DBSCAN(eps=.3)
+ affinity_propagation = cluster.AffinityPropagation(damping=.9)
for algorithm in [two_means, affinity_propagation, ms, spectral,
ward_five, dbscan]:
@@ -89,7 +84,10 @@
algorithm.fit(-distances, p=-20*distances.max())
else:
algorithm.fit(X)
- y_pred = algorithm.labels_.astype(np.int)
+ if hasattr(algorithm, 'labels_'):
+ y_pred = algorithm.labels_.astype(np.int)
+ else:
+ y_pred = algorithm.predict(X)
# plot
pl.subplot(4, 6, plot_num)

0 comments on commit df558e8

Please sign in to comment.