Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize K-means #97

Merged
merged 15 commits into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions algorithms/linfa-clustering/benches/appx_dbscan.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use criterion::{
black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
black_box, criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion,
PlotConfiguration,
};
use linfa::traits::Transformer;
Expand All @@ -19,28 +19,32 @@ fn appx_dbscan_bench(c: &mut Criterion) {
/*(10000, 0.1),*/
];

let benchmark = ParameterizedBenchmark::new(
"appx_dbscan",
move |bencher, &cluster_size_and_slack| {
let min_points = 4;
let n_features = 3;
let tolerance = 0.3;
let centroids =
Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), &mut rng);
let dataset = generate_blobs(cluster_size_and_slack.0, &centroids, &mut rng);
bencher.iter(|| {
black_box(
AppxDbscan::params(min_points)
.tolerance(tolerance)
.slack(cluster_size_and_slack.1)
.transform(&dataset),
)
});
},
cluster_sizes_and_slacks,
)
.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
c.bench("appx_dbscan", benchmark);
let mut benchmark = c.benchmark_group("appx_dbscan");
benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for cluster_size_and_slack in cluster_sizes_and_slacks {
let rng = &mut rng;
benchmark.bench_with_input(
BenchmarkId::new("appx_dbscan", cluster_size_and_slack.0),
&cluster_size_and_slack,
move |bencher, &cluster_size_and_slack| {
let min_points = 4;
let n_features = 3;
let tolerance = 0.3;
let centroids =
Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), rng);
let dataset = generate_blobs(cluster_size_and_slack.0, &centroids, rng);
bencher.iter(|| {
black_box(
AppxDbscan::params(min_points)
.tolerance(tolerance)
.slack(cluster_size_and_slack.1)
.transform(&dataset),
)
});
},
);
}
benchmark.finish();
}

criterion_group! {
Expand Down
48 changes: 26 additions & 22 deletions algorithms/linfa-clustering/benches/dbscan.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use criterion::{
black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
black_box, criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion,
PlotConfiguration,
};
use linfa::traits::Transformer;
Expand All @@ -14,27 +14,31 @@ fn dbscan_bench(c: &mut Criterion) {
let mut rng = Isaac64Rng::seed_from_u64(40);
let cluster_sizes = vec![10, 100, 1000, 10000];

let benchmark = ParameterizedBenchmark::new(
"dbscan",
move |bencher, &cluster_size| {
let min_points = 4;
let n_features = 3;
let tolerance = 0.3;
let centroids =
Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), &mut rng);
let dataset = generate_blobs(cluster_size, &centroids, &mut rng);
bencher.iter(|| {
black_box(
Dbscan::params(min_points)
.tolerance(tolerance)
.transform(&dataset),
)
});
},
cluster_sizes,
)
.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
c.bench("dbscan", benchmark);
let mut benchmark = c.benchmark_group("dbscan");
benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for cluster_size in cluster_sizes {
let rng = &mut rng;
benchmark.bench_with_input(
BenchmarkId::new("dbscan", cluster_size),
&cluster_size,
move |bencher, &cluster_size| {
let min_points = 4;
let n_features = 3;
let tolerance = 0.3;
let centroids =
Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), rng);
let dataset = generate_blobs(cluster_size, &centroids, rng);
bencher.iter(|| {
black_box(
Dbscan::params(min_points)
.tolerance(tolerance)
.transform(&dataset),
)
});
},
);
}
benchmark.finish()
}

criterion_group! {
Expand Down
54 changes: 29 additions & 25 deletions algorithms/linfa-clustering/benches/gaussian_mixture.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use criterion::{
black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
black_box, criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion,
PlotConfiguration,
};
use linfa::traits::Fit;
Expand All @@ -15,30 +15,34 @@ fn gaussian_mixture_bench(c: &mut Criterion) {
let mut rng = Isaac64Rng::seed_from_u64(40);
let cluster_sizes = vec![10, 100, 1000, 10000];

let benchmark = ParameterizedBenchmark::new(
"gaussian_mixture",
move |bencher, &cluster_size| {
let n_clusters = 4;
let n_features = 3;
let centroids =
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), &mut rng);
let dataset: DatasetBase<_, _> =
(generate_blobs(cluster_size, &centroids, &mut rng)).into();
bencher.iter(|| {
black_box(
GaussianMixtureModel::params(n_clusters)
.with_rng(rng.clone())
.with_tolerance(1e-3)
.with_max_n_iterations(1000)
.fit(&dataset)
.expect("GMM fitting fail"),
)
});
},
cluster_sizes,
)
.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
c.bench("gaussian_mixture", benchmark);
let mut benchmark = c.benchmark_group("gaussian_mixture");
benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for cluster_size in cluster_sizes {
let rng = &mut rng;
benchmark.bench_with_input(
BenchmarkId::new("gaussian_mixture", cluster_size),
&cluster_size,
move |bencher, &cluster_size| {
let n_clusters = 4;
let n_features = 3;
let centroids =
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng);
let dataset: DatasetBase<_, _> =
(generate_blobs(cluster_size, &centroids, rng)).into();
bencher.iter(|| {
black_box(
GaussianMixtureModel::params(n_clusters)
.with_rng(rng.clone())
.with_tolerance(1e-3)
.with_max_n_iterations(1000)
.fit(&dataset)
.expect("GMM fitting fail"),
)
});
},
);
}
benchmark.finish();
}

criterion_group! {
Expand Down
39 changes: 19 additions & 20 deletions algorithms/linfa-clustering/benches/k_means.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use criterion::{
black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
black_box, criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion,
PlotConfiguration,
};
use linfa::traits::Fit;
Expand All @@ -15,27 +15,26 @@ fn k_means_bench(c: &mut Criterion) {
let mut rng = Isaac64Rng::seed_from_u64(40);
let cluster_sizes = vec![10, 100, 1000, 10000];

let benchmark = ParameterizedBenchmark::new(
"naive_k_means",
move |bencher, &cluster_size| {
let n_clusters = 4;
let n_features = 3;
let centroids =
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), &mut rng);
let dataset = DatasetBase::from(generate_blobs(cluster_size, &centroids, &mut rng));
let mut benchmark = c.benchmark_group("naive_k_means");
benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for cluster_size in cluster_sizes {
let rng = &mut rng;
let n_clusters = 4;
let n_features = 3;
let centroids =
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng);
let dataset = DatasetBase::from(generate_blobs(cluster_size, &centroids, rng));
benchmark.bench_function(BenchmarkId::new("naive_k_means", cluster_size), |bencher| {
bencher.iter(|| {
black_box(
KMeans::params_with_rng(n_clusters, rng.clone())
.max_n_iterations(1000)
.tolerance(1e-3)
.fit(&dataset),
)
KMeans::params_with_rng(black_box(n_clusters), black_box(rng.clone()))
.max_n_iterations(black_box(1000))
.tolerance(black_box(1e-3))
.fit(&dataset)
});
},
cluster_sizes,
)
.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
c.bench("naive_k_means", benchmark);
});
}

benchmark.finish();
}

criterion_group! {
Expand Down
65 changes: 26 additions & 39 deletions algorithms/linfa-clustering/src/k_means/algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use crate::k_means::errors::{KMeansError, Result};
use crate::k_means::helpers::IncrementalMean;
use crate::k_means::hyperparameters::{KMeansHyperParams, KMeansHyperParamsBuilder};
use linfa::{traits::*, DatasetBase, Float};
use ndarray::{s, Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2, Zip};
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2, Zip};
use ndarray_rand::rand;
use ndarray_rand::rand::Rng;
use ndarray_stats::DeviationExt;
use rand_isaac::Isaac64Rng;
use std::collections::HashMap;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
Expand Down Expand Up @@ -258,54 +256,36 @@ fn compute_inertia<F: Float>(
/// If you check the `compute_cluster_memberships` function,
/// you can see that it expects to receive centroids as a 2-dimensional array.
///
/// `compute_centroids` wraps our `compute_centroids_hashmap` to return a 2-dimensional array,
/// `compute_centroids` returns a 2-dimensional array,
/// where the i-th row corresponds to the i-th cluster.
fn compute_centroids<F: Float>(
// The number of clusters could be inferred from `centroids_hashmap`,
// but it is indeed possible for a cluster to become empty during the
// multiple rounds of assignment-update optimisations
// This would lead to an underestimate of the number of clusters
// and several errors down the line due to shape mismatches
n_clusters: usize,
// (n_observations, n_features)
observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
// (n_observations,)
cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
) -> Array2<F> {
let centroids_hashmap = compute_centroids_hashmap(&observations, &cluster_memberships);
let (_, n_features) = observations.dim();

let mut counts: Array1<usize> = Array1::zeros(n_clusters);
let mut centroids: Array2<F> = Array2::zeros((n_clusters, n_features));
for (centroid_index, centroid) in centroids_hashmap.into_iter() {
centroids
.slice_mut(s![centroid_index, ..])
.assign(&centroid.current_mean);
}
centroids
}

/// Iterate over our observations and capture in a HashMap the new centroids.
/// The HashMap is a (cluster_index => new centroid) mapping.
fn compute_centroids_hashmap<F: Float>(
// (n_observations, n_features)
observations: &ArrayBase<impl Data<Elem = F>, Ix2>,
// (n_observations,)
cluster_memberships: &ArrayBase<impl Data<Elem = usize>, Ix1>,
) -> HashMap<usize, IncrementalMean<F>> {
let mut new_centroids: HashMap<usize, IncrementalMean<F>> = HashMap::new();
Zip::from(observations.genrows())
.and(cluster_memberships)
.apply(|observation, cluster_membership| {
if let Some(incremental_mean) = new_centroids.get_mut(cluster_membership) {
incremental_mean.update(&observation);
} else {
new_centroids.insert(
*cluster_membership,
IncrementalMean::new(observation.to_owned()),
);
.apply(|observation, &cluster_membership| {
let mut centroid = centroids.row_mut(cluster_membership);
centroid += &observation;
counts[cluster_membership] += 1;
});

Zip::from(centroids.genrows_mut())
.and(&counts)
.apply(|mut centroid, &cnt| {
if cnt != 0 {
centroid /= F::from(cnt).unwrap();
}
});
new_centroids
centroids
}

/// Given a matrix of centroids with shape (n_centroids, n_features)
Expand Down Expand Up @@ -351,11 +331,9 @@ fn closest_centroid<F: Float>(
// (n_features)
observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
) -> usize {
let mut iterator = centroids.genrows().into_iter().peekable();
let iterator = centroids.genrows().into_iter();

let first_centroid = iterator
.peek()
.expect("There has to be at least one centroid");
let first_centroid = centroids.row(0);
let (mut closest_index, mut minimum_distance) = (
0,
first_centroid
Expand Down Expand Up @@ -472,6 +450,15 @@ mod tests {
assert_eq!(centroids.len_of(Axis(0)), 2);
}

#[test]
fn test_compute_extra_centroids() {
    // One observation, assigned to cluster 0; cluster 1 receives nothing.
    let observations = array![[1.0, 2.0]];
    let memberships = array![0];
    // An empty cluster must yield the zero vector as its "average",
    // rather than panicking or producing NaN from a division by zero.
    let expected = array![[1.0, 2.0], [0.0, 0.0]];
    let actual = compute_centroids(2, &observations, &memberships);
    assert_abs_diff_eq!(actual, expected);
}

#[test]
// An observation is closest to itself.
fn nothing_is_closer_than_self() {
Expand Down
Loading