Skip to content

Commit

Permalink
Changes based on review
Browse files Browse the repository at this point in the history
* Docstring updates for clarity

* Removed redundant attribute n_clusters_

* Fixed tests

* Changes to FeatureAgglomeration to include distance threshold
  • Loading branch information
VathsalaAchar committed Oct 30, 2017
1 parent 6c5f957 commit 8ea9afa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 51 deletions.
39 changes: 16 additions & 23 deletions sklearn/cluster/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,6 @@ def _hc_cut(n_clusters, children, n_leaves,
----------
n_clusters : int or ndarray
The number of clusters to form.
NOTE: You should set either ``n_clusters`` or ``distance_threshold``,
NOT both.
children : 2D array, shape (n_nodes-1, 2)
The children of each non-leaf node. Values less than `n_samples`
Expand All @@ -563,11 +561,11 @@ def _hc_cut(n_clusters, children, n_leaves,
n_leaves : int
Number of leaves of the tree.
distance_threshold : int (optional)
distance_threshold : float (optional)
The distance threshold to cluster at.
If ``distance_threshold`` is set then ``n_clusters`` will be the
number of clusters at which the distances exceed the
``distance_threshold``
If distance_threshold is set then n_clusters will be ignored. Instead,
the number of clusters will be determined by when distances between
clusters first exceed the distance_threshold during agglomeration.
NOTE: You should set either ``n_clusters`` or ``distance_threshold``,
NOT both. If the ``distance_threshold`` is set then ``n_clusters`` is
ignored.
Expand Down Expand Up @@ -623,10 +621,8 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
----------
n_clusters : int, default=2
The number of clusters to find.
NOTE: You should set either ``n_clusters`` or ``distance_threshold``,
NOT both.
distance_threshold : int (optional)
distance_threshold : float (optional)
The distance threshold to cluster at.
NOTE: You should set either ``n_clusters`` or ``distance_threshold``,
NOT both. If the ``distance_threshold`` is set then ``n_clusters`` is
Expand Down Expand Up @@ -733,17 +729,15 @@ def fit(self, X, y=None):
X = check_array(X, ensure_min_samples=2, estimator=self)
memory = check_memory(self.memory)

self.n_clusters_ = self.n_clusters

if self.n_clusters_ is not None and self.n_clusters_ <= 0:
if self.n_clusters is not None and self.n_clusters <= 0:
raise ValueError("n_clusters should be an integer greater than 0."
" %s was provided." % str(self.n_clusters_))
" %s was provided." % str(self.n_clusters))

if self.n_clusters_ is None and self.distance_threshold is None:
if self.n_clusters is None and self.distance_threshold is None:
raise ValueError("Either n_clusters (>0) or distance_threshold "
"needs to be set, got n_clusters={} and "
"distance_threshold={} instead.".format(
self.n_clusters_, self.distance_threshold))
self.n_clusters, self.distance_threshold))

if self.linkage == "ward" and self.affinity != "euclidean":
raise ValueError("%s was provided as affinity. Ward can only "
Expand Down Expand Up @@ -771,8 +765,8 @@ def fit(self, X, y=None):
# Early stopping is likely to give a speed up only for
# a large number of clusters. The actual threshold
# implemented here is heuristic
compute_full_tree = self.n_clusters_ < max(100, .02 * n_samples)
n_clusters = self.n_clusters_
compute_full_tree = self.n_clusters < max(100, .02 * n_samples)
n_clusters = self.n_clusters
if compute_full_tree:
n_clusters = None

Expand All @@ -799,11 +793,11 @@ def fit(self, X, y=None):
# Cut the tree
if compute_full_tree:
if distance_threshold is not None:
self.labels_ = _hc_cut(self.n_clusters_, self.children_,
self.labels_ = _hc_cut(self.n_clusters, self.children_,
self.n_leaves_,
distance_threshold, distances)
else:
self.labels_ = _hc_cut(self.n_clusters_, self.children_,
self.labels_ = _hc_cut(self.n_clusters, self.children_,
self.n_leaves_)
else:
labels = _hierarchical.hc_get_heads(parents, copy=False)
Expand All @@ -826,10 +820,8 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
----------
n_clusters : int, default 2
The number of clusters to find.
NOTE: You should set either ``n_clusters`` or ``distance_threshold``,
NOT both.
distance_threshold : int (optional)
distance_threshold : float (optional)
The distance threshold to cluster at.
NOTE: You should set either ``n_clusters`` or ``distance_threshold``,
NOT both. If the ``distance_threshold`` is set then ``n_clusters`` is
Expand Down Expand Up @@ -897,7 +889,8 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
are merged to form node `n_features + i`
"""

def __init__(self, n_clusters=2, affinity="euclidean",
def __init__(self, n_clusters=2, distance_threshold=None,
affinity="euclidean",
memory=None,
connectivity=None, compute_full_tree='auto',
linkage='ward', pooling_func=np.mean):
Expand Down
28 changes: 0 additions & 28 deletions sklearn/cluster/tests/test_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,6 @@ def test_agg_n_cluster_and_distance_threshold():
agc.fit(X)
# Expecting no errors here
assert_true(agc.n_clusters == n_clus)
assert_true(agc.n_clusters_ == n_clus)


def test_affinity_passed_to_fix_connectivity():
Expand Down Expand Up @@ -643,30 +642,3 @@ def test_agglomerative_clustering_with_distance_threshold_edge_case():
linkage=linkage)
y_pred = clusterer.fit_predict(X)
assert_equal(1, adjusted_rand_score(y_true, y_pred))


def test_affinity_passed_to_fix_connectivity():
    """Check that the ``affinity`` argument is passed on by ``linkage_tree``.

    A counting stand-in is supplied as the affinity callable; after running
    ``linkage_tree`` on a small masked-grid connectivity, its call counter
    is compared against the expected value.
    """

    class FakeAffinity:
        # Stand-in affinity: counts invocations and returns the running count.
        def __init__(self):
            self.counter = 0

        def increment(self, *args, **kwargs):
            self.counter += 1
            return self.counter

    side = 2
    mask = np.array([True, False, False, True])
    X = np.random.RandomState(0).randn(side, side)
    connectivity = grid_to_graph(n_x=side, n_y=side, mask=mask,
                                 return_as=np.ndarray)

    counting_affinity = FakeAffinity()
    linkage_tree(X, connectivity=connectivity,
                 affinity=counting_affinity.increment)

    # NOTE(review): the expected count of 3 is carried over from the original
    # test; it depends on how linkage_tree repairs this connectivity.
    assert_equal(counting_affinity.counter, 3)

0 comments on commit 8ea9afa

Please sign in to comment.