Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hcluster v2 #86

Merged
merged 11 commits into from
Apr 1, 2011
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import pylab as pl
from scipy import linalg, ndimage

from scikits.learn.feature_extraction.image import img_to_graph
from scikits.learn.feature_extraction.image import grid_to_graph
from scikits.learn import feature_selection
from scikits.learn.cluster import WardAgglomeration
from scikits.learn.linear_model import BayesianRidge
Expand Down Expand Up @@ -62,8 +62,9 @@
mem = Memory(cachedir='.', verbose=1)

# Ward agglomeration followed by BayesianRidge
A = img_to_graph(mask, mask)
ward = WardAgglomeration(n_clusters=10, connectivity=A, memory=mem, n_comp=1)
A = grid_to_graph(n_x=size, n_y=size)
ward = WardAgglomeration(n_clusters=10, connectivity=A, memory=mem,
n_components=1)
clf = Pipeline([('ward', ward), ('ridge', ridge)])
parameters = {'ward__n_clusters': [10, 20, 30]}
# Select the optimal number of parcels with grid search
Expand Down
6 changes: 3 additions & 3 deletions examples/cluster/plot_lena_ward_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import scipy as sp
import pylab as pl
from scikits.learn.feature_extraction.image import img_to_graph
from scikits.learn.feature_extraction.image import grid_to_graph
from scikits.learn.cluster import Ward

###############################################################################
Expand All @@ -31,8 +31,8 @@
X = np.atleast_2d(lena[mask]).T

###############################################################################
# Define the structure A of the data. Here a 10 nearest neighbors
connectivity = img_to_graph(mask, mask)
# Define the structure A of the data. Pixels connected to their neighbors.
connectivity = grid_to_graph(*lena.shape)

###############################################################################
# Compute clustering
Expand Down
6 changes: 3 additions & 3 deletions examples/cluster/plot_ward_unstructured.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
===========================================================

Example builds a swiss roll dataset and runs the hierarchical
clustering on k-Nearest Neighbors graph. It's a hierarchical
clustering without structure prior.
clustering on the positions of its points. It is a hierarchical
clustering without a structure prior.

"""

Expand All @@ -24,7 +24,7 @@

###############################################################################
# Generate data (swiss roll dataset)
n_samples = 500
n_samples = 1000
noise = 0.05
X = swiss_roll(n_samples, noise)

Expand Down
647 changes: 60 additions & 587 deletions scikits/learn/cluster/_inertia.c

Large diffs are not rendered by default.

21 changes: 0 additions & 21 deletions scikits/learn/cluster/_inertia.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,9 @@ ctypedef np.int_t INT
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def compute_inertia(np.ndarray[DOUBLE, ndim=1] m_1,\
np.ndarray[DOUBLE, ndim=2] m_2,\
np.ndarray[DOUBLE, ndim=2] m_3,\
np.ndarray[INT, ndim=1] coord_row,
np.ndarray[INT, ndim=1] coord_col,\
np.ndarray[DOUBLE, ndim=1] res):
cdef int size_max = coord_row.shape[0]
cdef int n_features = m_3.shape[1]
cdef int i, j, row, col
cdef DOUBLE pa, n
for i in range(size_max):
row = coord_row[i]
col = coord_col[i]
n = m_1[row] + m_1[col]
pa = 0.
for j in range(n_features):
pa += m_3[row, j] + m_3[col, j]
pa -= ((m_2[row, j] + m_2[col, j])**2) / n
res[i] = pa
return res

def compute_ward_dist(np.ndarray[DOUBLE, ndim=1] m_1,\
np.ndarray[DOUBLE, ndim=2] m_2,\
np.ndarray[DOUBLE, ndim=2] m_3,\
np.ndarray[INT, ndim=1] coord_row,
np.ndarray[INT, ndim=1] coord_col,\
np.ndarray[DOUBLE, ndim=1] res):
Expand Down
104 changes: 38 additions & 66 deletions scikits/learn/cluster/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
from scipy import sparse
from scipy.cluster import hierarchy

from ..base import BaseEstimator
from ..utils._csgraph import cs_graph_components
Expand All @@ -23,8 +24,7 @@
###############################################################################
# Ward's algorithm

def ward_tree(X, connectivity=None, n_components=None, copy=True,
inertia_criterion=False):
def ward_tree(X, connectivity=None, n_components=None, copy=True):
"""Ward clustering based on a Feature matrix. Heapq-based representation
of the inertia matrix.

Expand All @@ -49,9 +49,6 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
copy : bool (optional)
Make a copy of connectivity or work inplace. If connectivity
is not of LIL type there will be a copy in any case.

inertia_criterion: bool (optional)
Use an inertia criterion instead of classical Ward's criterion

Returns
-------
Expand Down Expand Up @@ -79,56 +76,47 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
" connectivity matrix is %d > 1. The tree will be stopped early."
% n_components)
else:
n_components = 1
out = hierarchy.ward(X)
children_ = out[:, :2].astype(np.int)
return children_, 1, n_samples

n_nodes = 2 * n_samples - n_components

if connectivity is None:
coord_row, coord_col = np.where(np.tril(np.ones((n_samples, n_samples),
dtype=np.bool), k=-1))
A = [range(0, ind) + range(ind + 1, n_samples)
for ind in range(n_samples)]
if (connectivity.shape[0] != n_samples or
connectivity.shape[1] != n_samples):
raise ValueError('Wrong shape for connectivity matrix: %s '
'when X is %s' % (connectivity.shape, X.shape))
# convert connectivity matrix to LIL format, possibly with a copy
if sparse.isspmatrix_lil(connectivity) and copy:
connectivity = connectivity.copy()
else:
if (connectivity.shape[0] != n_samples or
connectivity.shape[1] != n_samples):
raise ValueError('Wrong shape for connectivity matrix: %s '
'when X is %s' % (connectivity.shape, X.shape))
# convert connectivity matrix to LIL format, possibly with a copy
if sparse.isspmatrix_lil(connectivity) and copy:
connectivity = connectivity.copy()
else:
connectivity = connectivity.tolil()

# Remove diagonal from connectivity matrix
connectivity.setdiag(np.zeros(connectivity.shape[0]))

# create inertia matrix
coord_row = []
coord_col = []
A = []
for ind, row in enumerate(connectivity.rows):
A.append(row)
# We keep only the upper triangular for the moments
# Generator expressions are faster than arrays on the following
row = [i for i in row if i < ind]
coord_row.extend(len(row) * [ind,])
coord_col.extend(row)
coord_row = np.array(coord_row, dtype=np.int)
coord_col = np.array(coord_col, dtype=np.int)
connectivity = connectivity.tolil()

# Remove diagonal from connectivity matrix
connectivity.setdiag(np.zeros(connectivity.shape[0]))

# create inertia matrix
coord_row = []
coord_col = []
A = []
for ind, row in enumerate(connectivity.rows):
A.append(row)
# We keep only the upper triangular for the moments
# Generator expressions are faster than arrays on the following
row = [i for i in row if i < ind]
coord_row.extend(len(row) * [ind, ])
coord_col.extend(row)

coord_row = np.array(coord_row, dtype=np.int)
coord_col = np.array(coord_col, dtype=np.int)

# build moments as a list
moments = [np.zeros(n_nodes), np.zeros((n_nodes, n_features)),
np.zeros((n_nodes, n_features))]
moments = [np.zeros(n_nodes), np.zeros((n_nodes, n_features))]
moments[0][:n_samples] = 1
moments[1][:n_samples] = X
moments[2][:n_samples] = X ** 2
inertia = np.empty(len(coord_row), dtype=np.float)
if inertia_criterion:
_inertia.compute_inertia(moments[0], moments[1], moments[2],
coord_row, coord_col, inertia)
else:
_inertia.compute_ward_dist(moments[0], moments[1], moments[2],
coord_row, coord_col, inertia)
_inertia.compute_ward_dist(moments[0], moments[1],
coord_row, coord_col, inertia)
inertia = zip(inertia, coord_row, coord_col)
heapq.heapify(inertia)

Expand All @@ -152,7 +140,7 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
used_node[i], used_node[j] = False, False

# update the moments
for p in range(3):
for p in range(2):
moments[p][k] = moments[p][i] + moments[p][j]

# update the structure matrix A and the inertia matrix
Expand All @@ -166,12 +154,9 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True,
coord_row = np.empty_like(coord_col)
coord_row.fill(k)
ini = np.empty(len(coord_row), dtype=np.float)
if inertia_criterion:
_inertia.compute_inertia(moments[0], moments[1], moments[2],
coord_row, coord_col, ini)
else:
_inertia.compute_ward_dist(moments[0], moments[1], moments[2],
coord_row, coord_col, ini)

_inertia.compute_ward_dist(moments[0], moments[1],
coord_row, coord_col, ini)
for tupl in itertools.izip(ini, coord_row, coord_col):
heapq.heappush(inertia, tupl)

Expand Down Expand Up @@ -225,26 +210,13 @@ def _hc_cut(n_clusters, children, n_leaves):
n_clusters : int or ndarray
The number of clusters to form.

parent : array-like, shape = [n_nodes]
Int. Gives the parent node for each node, i.e. parent[i] is the
parent node of the node i. The last value of parent is the
root node, that is its self parent, so the last value is taken
3 times in the array.
The n_nodes is equal at (2*n_samples - 1), and takes into
account the nb_samples leaves, and the unique root.

children : list of pairs. Length of n_nodes
List of the children of each node.
Leaves have empty list of children and are not stored.

n_leaves : int
Number of leaves of the tree.

heights : array-like, shape = [n_nodes]
Gives the inertia of the created nodes. The n_samples first
values of the array are 0, and thus the values are positive (or
null) and are ranked in an increasing order.

Return
------
labels_ : array [n_points]
Expand Down
46 changes: 41 additions & 5 deletions scikits/learn/cluster/tests/test_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
"""

import numpy as np
from scipy.cluster import hierarchy

from scikits.learn.cluster import Ward, WardAgglomeration, ward_tree
from scikits.learn.feature_extraction.image import img_to_graph
from scikits.learn.cluster.hierarchical import _hc_cut
from scikits.learn.feature_extraction.image import grid_to_graph


def test_structured_ward_tree():
Expand All @@ -16,7 +19,7 @@ def test_structured_ward_tree():
np.random.seed(0)
mask = np.ones([10, 10], dtype=np.bool)
X = np.random.randn(50, 100)
connectivity = img_to_graph(mask, mask)
connectivity = grid_to_graph(*mask.shape)
children, n_components, n_leaves = ward_tree(X.T, connectivity)
n_nodes = 2 * X.shape[1] - 1
assert(len(children) + n_leaves == n_nodes)
Expand All @@ -40,7 +43,7 @@ def test_height_ward_tree():
np.random.seed(0)
mask = np.ones([10, 10], dtype=np.bool)
X = np.random.randn(50, 100)
connectivity = img_to_graph(mask, mask)
connectivity = grid_to_graph(*mask.shape)
children, n_nodes, n_leaves = ward_tree(X.T, connectivity)
n_nodes = 2 * X.shape[1] - 1
assert(len(children) + n_leaves == n_nodes)
Expand All @@ -53,7 +56,7 @@ def test_ward_clustering():
np.random.seed(0)
mask = np.ones([10, 10], dtype=np.bool)
X = np.random.randn(100, 50)
connectivity = img_to_graph(mask, mask)
connectivity = grid_to_graph(*mask.shape)
clustering = Ward(n_clusters=10, connectivity=connectivity)
clustering.fit(X)
assert(np.size(np.unique(clustering.labels_)) == 10)
Expand All @@ -66,7 +69,7 @@ def test_ward_agglomeration():
np.random.seed(0)
mask = np.ones([10, 10], dtype=np.bool)
X = np.random.randn(50, 100)
connectivity = img_to_graph(mask, mask)
connectivity = grid_to_graph(*mask.shape)
ward = WardAgglomeration(n_clusters=5, connectivity=connectivity)
ward.fit(X)
assert(np.size(np.unique(ward.labels_)) == 5)
Expand All @@ -77,6 +80,39 @@ def test_ward_agglomeration():
assert(np.unique(Xfull[0]).size == 5)


def assess_same_labelling(cut1, cut2):
    """Assert that two labellings describe the same partition.

    Helper for the comparison with scipy. Cluster ids may be numbered
    differently in ``cut1`` and ``cut2``, so each labelling is turned into
    a co-membership matrix (entry ``[i, j]`` is 1 when points ``i`` and
    ``j`` carry the same label) and the two matrices are compared.

    Parameters
    ----------
    cut1, cut2 : ndarray of int
        Cluster label of each point; used both for ``.max()`` and as a
        column index, so they must be integer arrays.

    Raises
    ------
    AssertionError
        If the two co-membership matrices differ.
    """
    def _co_membership(labels):
        # One-hot encode the labels, then take the Gram matrix of the
        # encoding to get the point-by-point co-membership matrix.
        n_points = len(labels)
        indicator = np.zeros((n_points, labels.max() + 1))
        indicator[np.arange(n_points), labels] = 1
        return np.dot(indicator, indicator.T)

    first = _co_membership(cut1)
    second = _co_membership(cut2)
    assert (first == second).all()


def test_scikit_vs_scipy():
    """Test scikit ward with full connectivity against scipy.

    With a connectivity matrix made only of ones (i.e. unstructured),
    ``ward_tree`` is expected to produce the same partition as
    ``scipy.cluster.hierarchy.ward`` once both trees are cut into ``k``
    clusters.
    """
    from scipy.sparse import lil_matrix
    # Seed the generator like the other tests of this module do, so that
    # a failure can be reproduced.
    np.random.seed(0)
    n, p, k = 10, 5, 3

    connectivity = lil_matrix(np.ones((n, n)))
    for i in range(5):
        X = .1 * np.random.normal(size=(n, p))
        X -= 4 * np.arange(n)[:, np.newaxis]
        # NOTE(review): centering each row removes the per-row offset
        # subtracted on the previous line -- confirm axis=1 is intended.
        X -= X.mean(axis=1)[:, np.newaxis]

        out = hierarchy.ward(X)

        # The first two columns of the scipy linkage matrix are the merged
        # node ids. Builtin ``int`` is what the ``np.int`` alias pointed to;
        # the alias is gone from recent NumPy releases.
        children_ = out[:, :2].astype(int)
        children, _, n_leaves = ward_tree(X, connectivity)

        cut = _hc_cut(k, children, n_leaves)
        cut_ = _hc_cut(k, children_, n_leaves)
        assess_same_labelling(cut, cut_)

if __name__ == '__main__':
import nose
nose.run(argv=['', __file__])
Loading