Skip to content

Commit

Permalink
Fix max_n_iterations in k_means algorithm (rust-ml#244)
Browse files Browse the repository at this point in the history
* initial commit

* use loop over while and related changes
  • Loading branch information
quettabit authored Sep 9, 2022
1 parent e206bf4 commit d34313c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
64 changes: 37 additions & 27 deletions algorithms/linfa-clustering/src/k_means/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,19 +236,17 @@ impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>

let mut min_inertia = F::infinity();
let mut best_centroids = None;
let mut best_iter = None;
let mut memberships = Array1::zeros(n_samples);
let mut dists = Array1::zeros(n_samples);

let n_runs = self.n_runs();

for _ in 0..n_runs {
let mut inertia = min_inertia;
let mut centroids =
self.init_method()
.run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
let mut converged_iter: Option<u64> = None;
for n_iter in 0..self.max_n_iterations() {
let mut n_iter = 0;
let inertia = loop {
update_memberships_and_dists(
self.dist_fn(),
&centroids,
Expand All @@ -257,44 +255,39 @@ impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
&mut dists,
);
let new_centroids = compute_centroids(&centroids, &observations, &memberships);
inertia = dists.sum();
let distance = self
.dist_fn()
.distance(centroids.view(), new_centroids.view());
centroids = new_centroids;
if distance < self.tolerance() {
converged_iter = Some(n_iter);
break;
n_iter += 1;
if distance < self.tolerance() || n_iter == self.max_n_iterations() {
break dists.sum();
}
}
};

// We keep the centroids which minimize the inertia (defined as the sum of
// the squared distances from each observation to its closest centroid)
// over the n runs of the KMeans algorithm.
if inertia < min_inertia {
min_inertia = inertia;
best_centroids = Some(centroids.clone());
best_iter = converged_iter;
}
}

match best_iter {
Some(_n_iter) => match best_centroids {
Some(centroids) => {
let mut cluster_count = Array1::zeros(self.n_clusters());
memberships
.iter()
.for_each(|&c| cluster_count[c] += F::one());
Ok(KMeans {
centroids,
cluster_count,
inertia: min_inertia / F::cast(dataset.nsamples()),
dist_fn: self.dist_fn().clone(),
})
}
_ => Err(KMeansError::InertiaError),
},
None => Err(KMeansError::NotConverged),
match best_centroids {
Some(centroids) => {
let mut cluster_count = Array1::zeros(self.n_clusters());
memberships
.iter()
.for_each(|&c| cluster_count[c] += F::one());
Ok(KMeans {
centroids,
cluster_count,
inertia: min_inertia / F::cast(dataset.nsamples()),
dist_fn: self.dist_fn().clone(),
})
}
_ => Err(KMeansError::InertiaError),
}
}
}
Expand Down Expand Up @@ -851,6 +844,23 @@ mod tests {
assert!(params.fit_with(None, &data).is_ok());
}

// Regression test: fitting must return a model, not an error, when the
// iteration cap is reached before the centroids have converged.
#[test]
fn test_max_n_iterations() {
    let mut rng = Xoshiro256Plus::seed_from_u64(42);
    let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
    let yt = function_test_1d(&xt);
    let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
    // `data` is not used again, so it is moved into the dataset rather than cloned.
    let dataset = DatasetBase::from(data);
    // For data created using the above rng and seed, with 6 clusters, it takes 8 iterations
    // to converge. With max_n_iterations set to 5, the algorithm should stop early gracefully.
    // `rng` is likewise not used again, so it is moved instead of cloned; the generator state
    // handed to the parameters is the same either way.
    let _model = KMeans::params_with(6, rng, L2Dist)
        .n_runs(1)
        .max_n_iterations(5)
        .init_method(KMeansInit::Random)
        .fit(&dataset)
        .expect("KMeans fitted");
}

// Compile-time check only: accepts any value whose type implements
// `Fit<Array2<f64>, (), KMeansError>` and does nothing with it.
fn fittable<T>(_: T)
where
    T: Fit<Array2<f64>, (), KMeansError>,
{
}
#[test]
fn thread_rng_fittable() {
Expand Down
3 changes: 0 additions & 3 deletions algorithms/linfa-clustering/src/k_means/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ pub enum KMeansError {
/// When inertia computation fails
#[error("Fitting failed: No inertia improvement (-inf)")]
InertiaError,
/// When fitting algorithm does not converge
#[error("Fitting failed: Did not converge. Try different init parameters or check for degenerate data.")]
NotConverged,
#[error(transparent)]
LinfaError(#[from] linfa::error::Error),
}
Expand Down

0 comments on commit d34313c

Please sign in to comment.