From e432b568f12d68aa762af49372c7014325e96400 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 14:57:53 +0000
Subject: [PATCH 01/11] Initial commit of DBScan

---
 linfa-clustering/src/dbscan/algorithm.rs       | 65 +++++++++++++++++++
 .../src/dbscan/hyperparameters.rs              | 59 +++++++++++++++++
 linfa-clustering/src/dbscan/mod.rs             |  5 ++
 linfa-clustering/src/lib.rs                    |  2 +
 4 files changed, 131 insertions(+)
 create mode 100644 linfa-clustering/src/dbscan/algorithm.rs
 create mode 100644 linfa-clustering/src/dbscan/hyperparameters.rs
 create mode 100644 linfa-clustering/src/dbscan/mod.rs

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
new file mode 100644
index 000000000..a4b34267e
--- /dev/null
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -0,0 +1,65 @@
+use crate::dbscan::hyperparameters::DBScanHyperParams;
+use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};
+
+pub fn predict(
+    hyperparameters: DBScanHyperParams,
+    observations: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
+) -> Array1<Option<usize>> {
+    let mut result = Array1::from_elem(observations.dim().1, None);
+    let mut latest_id = 0;
+    for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
+        if result[i].is_some() {
+            continue;
+        }
+        let n = find_neighbors(&obs, observations, hyperparameters.tolerance());
+        if n.len() < hyperparameters.minimum_points() {
+            continue;
+        }
+        // Now go over the neighbours adding them to the cluster
+        let mut search_queue = n
+            .iter()
+            .filter(|x| result[[**x]].is_none())
+            .copied()
+            .collect::<Vec<_>>();
+        while !search_queue.is_empty() {
+            let cand = search_queue.remove(0);
+
+            result[cand] = Some(latest_id);
+
+            let mut n = find_neighbors(&obs, observations, hyperparameters.tolerance())
+                .iter()
+                .filter(|x| result[[**x]].is_none() && !search_queue.contains(x))
+                .copied()
+                .collect::<Vec<_>>();
+
+            search_queue.append(&mut n);
+        }
+        latest_id += 1;
+    }
+    result
+}
+
+fn find_neighbors(
+    candidate: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
+    observations: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
+    eps: f64,
+) -> Vec<usize> {
+    let mut res = vec![];
+    for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
+        if distance(candidate, &obs) < eps {
+            res.push(i);
+        }
+    }
+    res
+}
+
+fn distance(
+    lhs: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
+    rhs: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
+) -> f64 {
+    let mut acc = 0.0;
+    for (l, r) in lhs.iter().zip(rhs.iter()) {
+        acc += (l - r).powi(2);
+    }
+    acc.sqrt()
+}
diff --git a/linfa-clustering/src/dbscan/hyperparameters.rs b/linfa-clustering/src/dbscan/hyperparameters.rs
new file mode 100644
index 000000000..42f1d0d1b
--- /dev/null
+++ b/linfa-clustering/src/dbscan/hyperparameters.rs
@@ -0,0 +1,59 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+/// The set of hyperparameters that can be specified for the execution of
+/// the [DBSCAN algorithm](struct.DBScan.html).
+pub struct DBScanHyperParams {
+    /// Distance between points for them to be considered neighbours.
+    tolerance: f64,
+    /// Minimum number of neighboring points a point needs to have to be a core
+    /// point and not a noise point.
+    min_points: usize,
+}
+
+/// Helper struct used to construct a set of hyperparameters for
+pub struct DBScanHyperParamsBuilder {
+    tolerance: f64,
+    min_points: usize,
+}
+
+impl DBScanHyperParamsBuilder {
+    pub fn tolerance(mut self, tolerance: f64) -> Self {
+        self.tolerance = tolerance;
+        self
+    }
+
+    pub fn build(self) -> DBScanHyperParams {
+        DBScanHyperParams::build(self.tolerance, self.min_points)
+    }
+}
+
+impl DBScanHyperParams {
+    pub fn new(min_points: usize) -> DBScanHyperParamsBuilder {
+        DBScanHyperParamsBuilder {
+            min_points,
+            tolerance: 1e-4,
+        }
+    }
+
+    pub fn tolerance(&self) -> f64 {
+        self.tolerance
+    }
+
+    pub fn minimum_points(&self) -> usize {
+        self.min_points
+    }
+
+    fn build(tolerance: f64, min_points: usize) -> Self {
+        if tolerance <= 0. {
+            panic!("`tolerance` must be greater than 0!");
+        }
+        if min_points == 0 {
+            panic!("`min_points` cannot be 0!");
+        }
+        Self {
+            tolerance,
+            min_points,
+        }
+    }
+}
diff --git a/linfa-clustering/src/dbscan/mod.rs b/linfa-clustering/src/dbscan/mod.rs
new file mode 100644
index 000000000..988c809bf
--- /dev/null
+++ b/linfa-clustering/src/dbscan/mod.rs
@@ -0,0 +1,5 @@
+mod algorithm;
+mod hyperparameters;
+
+pub use algorithm::*;
+pub use hyperparameters::*;
diff --git a/linfa-clustering/src/lib.rs b/linfa-clustering/src/lib.rs
index ec910c320..9f262177f 100644
--- a/linfa-clustering/src/lib.rs
+++ b/linfa-clustering/src/lib.rs
@@ -18,9 +18,11 @@
 //! Implementation choices, algorithmic details and a tutorial can be found [here](struct.KMeans.html).
 //!
 //! Check [here](https://github.com/LukeMathWalker/clustering-benchmarks) for extensive benchmarks against `scikit-learn`'s K-means implementation.
+mod dbscan;
 #[allow(clippy::new_ret_no_self)]
 mod k_means;
 mod utils;

+pub use dbscan::*;
 pub use k_means::*;
 pub use utils::*;

From cc69fafbe4cd37a31af71039ea4c688a22c6b086 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 16:01:27 +0000
Subject: [PATCH 02/11] Finish tests and impl

* Remove Sync trait bounds
* Use ndarray_stats l2_norm
* Actually use observation in search queue to get neighbours
* Add two tests for noise points and nested dense clusters
---
 linfa-clustering/src/dbscan/algorithm.rs       | 81 ++++++++++++++-----
 .../src/dbscan/hyperparameters.rs              |  5 +-
 2 files changed, 65 insertions(+), 21 deletions(-)

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index a4b34267e..bbdc9344d 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -1,9 +1,10 @@
 use crate::dbscan::hyperparameters::DBScanHyperParams;
-use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};
+use ndarray::{s, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
+use ndarray_stats::DeviationExt;

 pub fn predict(
-    hyperparameters: DBScanHyperParams,
-    observations: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
+    hyperparameters: &DBScanHyperParams,
+    observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
 ) -> Array1<Option<usize>> {
     let mut result = Array1::from_elem(observations.dim().1, None);
     let mut latest_id = 0;
@@ -26,11 +27,15 @@ pub fn predict(

             result[cand] = Some(latest_id);

-            let mut n = find_neighbors(&obs, observations, hyperparameters.tolerance())
-                .iter()
-                .filter(|x| result[[**x]].is_none() && !search_queue.contains(x))
-                .copied()
-                .collect::<Vec<_>>();
+            let mut n = find_neighbors(
+                &observations.slice(s![.., cand]),
+                observations,
+                hyperparameters.tolerance(),
+            )
+            .iter()
+            .filter(|x| result[[**x]].is_none() && !search_queue.contains(x))
+            .copied()
+            .collect::<Vec<_>>();

             search_queue.append(&mut n);
         }
@@ -40,26 +45,64 @@ pub fn predict(
 }

 fn find_neighbors(
-    candidate: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
-    observations: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
+    candidate: &ArrayBase<impl Data<Elem = f64>, Ix1>,
+    observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
     eps: f64,
 ) -> Vec<usize> {
     let mut res = vec![];
     for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
-        if distance(candidate, &obs) < eps {
+        if candidate.l2_dist(&obs).unwrap() < eps {
             res.push(i);
         }
     }
     res
 }

-fn distance(
-    lhs: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
-    rhs: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
-) -> f64 {
-    let mut acc = 0.0;
-    for (l, r) in lhs.iter().zip(rhs.iter()) {
-        acc += (l - r).powi(2);
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ndarray::{arr1, s, Array2};
+
+    #[test]
+    fn nested_clusters() {
+        // Create a circuit of points and then a cluster in the centre
+        // and ensure they are identified as two separate clusters
+        let params = DBScanHyperParams::new(4).tolerance(1.0).build();
+
+        let mut data: Array2<f64> = Array2::zeros((2, 50));
+        let rising = Array1::linspace(0.0, 8.0, 10);
+        data.slice_mut(s![0, 0..10]).assign(&rising);
+        data.slice_mut(s![0, 10..20]).assign(&rising);
+        data.slice_mut(s![1, 20..30]).assign(&rising);
+        data.slice_mut(s![1, 30..40]).assign(&rising);
+
+        data.slice_mut(s![1, 0..10]).fill(0.0);
+        data.slice_mut(s![1, 10..20]).fill(8.0);
+        data.slice_mut(s![0, 20..30]).fill(0.0);
+        data.slice_mut(s![0, 30..40]).fill(8.0);
+
+        data.slice_mut(s![.., 40..]).fill(5.0);
+
+        let labels = predict(&params, &data);
+
+        assert!(labels.slice(s![..40]).iter().all(|x| x == &Some(0)));
+        assert!(labels.slice(s![40..]).iter().all(|x| x == &Some(1)));
+    }
+
+    #[test]
+    fn non_cluster_points() {
+        let params = DBScanHyperParams::new(4).build();
+
+        let data: Array2<f64> = Array2::zeros((2, 3));
+
+        let labels = predict(&params, &data);
+        assert!(labels.iter().all(|x| x.is_none()));
+
+        let mut data: Array2<f64> = Array2::zeros((2, 5));
+        data.slice_mut(s![.., 0]).assign(&arr1(&[10.0, 10.0]));
+
+        let labels = predict(&params, &data);
+        let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
+        assert_eq!(labels, expected);
     }
-    acc.sqrt()
 }
diff --git a/linfa-clustering/src/dbscan/hyperparameters.rs b/linfa-clustering/src/dbscan/hyperparameters.rs
index 42f1d0d1b..6bb9ed6e0 100644
--- a/linfa-clustering/src/dbscan/hyperparameters.rs
+++ b/linfa-clustering/src/dbscan/hyperparameters.rs
@@ -48,8 +48,9 @@ impl DBScanHyperParams {
         if tolerance <= 0. {
             panic!("`tolerance` must be greater than 0!");
         }
-        if min_points == 0 {
-            panic!("`min_points` cannot be 0!");
+        // There is always at least one neighbor to a point (itself)
+        if min_points <= 1 {
+            panic!("`min_points` must be greater than 1!");
         }
         Self {
             tolerance,
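A rough standalone sketch of the `l2_dist` call that patch 02 switches to, in place of the hand-rolled `distance` helper (the values below are made up for illustration and are not part of the patch series):

    use ndarray::array;
    use ndarray_stats::DeviationExt;

    fn main() {
        let a = array![0.0, 0.0];
        let b = array![3.0, 4.0];
        // Euclidean (L2) distance; the Result only errs if the shapes differ
        let d: f64 = a.l2_dist(&b).unwrap();
        assert!((d - 5.0).abs() < 1e-12);
    }
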
From 2da1e64f175610a7f169546abeeb04fe8f52c16f Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 16:07:51 +0000
Subject: [PATCH 03/11] Add candidate to results as soon as possible

---
 linfa-clustering/src/dbscan/algorithm.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index bbdc9344d..d04d5afb5 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -19,9 +19,11 @@ pub fn predict(
         // Now go over the neighbours adding them to the cluster
         let mut search_queue = n
             .iter()
-            .filter(|x| result[[**x]].is_none())
+            .filter(|x| result[[**x]].is_none() && **x != i)
             .copied()
             .collect::<Vec<_>>();
+
+        result[i] = Some(latest_id);
         while !search_queue.is_empty() {
             let cand = search_queue.remove(0);


From 0d9f9fd2a413ddbe69033bc48e7d2b6efe80bdd2 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 16:13:02 +0000
Subject: [PATCH 04/11] Apply renaming suggestions

---
 linfa-clustering/src/dbscan/algorithm.rs | 32 ++++++++++++------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index d04d5afb5..9bbe9bf1a 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -6,44 +6,44 @@ pub fn predict(
     hyperparameters: &DBScanHyperParams,
     observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
 ) -> Array1<Option<usize>> {
-    let mut result = Array1::from_elem(observations.dim().1, None);
-    let mut latest_id = 0;
+    let mut cluster_memberships = Array1::from_elem(observations.dim().1, None);
+    let mut current_cluster_id = 0;
     for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
-        if result[i].is_some() {
+        if cluster_memberships[i].is_some() {
             continue;
         }
-        let n = find_neighbors(&obs, observations, hyperparameters.tolerance());
-        if n.len() < hyperparameters.minimum_points() {
+        let neighbors = find_neighbors(&obs, observations, hyperparameters.tolerance());
+        if neighbors.len() < hyperparameters.minimum_points() {
             continue;
         }
         // Now go over the neighbours adding them to the cluster
-        let mut search_queue = n
+        let mut search_queue = neighbors
             .iter()
-            .filter(|x| result[[**x]].is_none() && **x != i)
+            .filter(|x| cluster_memberships[[**x]].is_none() && **x != i)
             .copied()
             .collect::<Vec<_>>();

-        result[i] = Some(latest_id);
+        cluster_memberships[i] = Some(current_cluster_id);
         while !search_queue.is_empty() {
-            let cand = search_queue.remove(0);
+            let candidate = search_queue.remove(0);

-            result[cand] = Some(latest_id);
+            cluster_memberships[candidate] = Some(current_cluster_id);

-            let mut n = find_neighbors(
-                &observations.slice(s![.., cand]),
+            let mut neighbors = find_neighbors(
+                &observations.slice(s![.., candidate]),
                 observations,
                 hyperparameters.tolerance(),
             )
             .iter()
-            .filter(|x| result[[**x]].is_none() && !search_queue.contains(x))
+            .filter(|x| cluster_memberships[[**x]].is_none() && !search_queue.contains(x))
             .copied()
             .collect::<Vec<_>>();

-            search_queue.append(&mut n);
+            search_queue.append(&mut neighbors);
         }
-        latest_id += 1;
+        current_cluster_id += 1;
     }
-    result
+    cluster_memberships
 }

 fn find_neighbors(
From 58d280a2cd13bae513f83f83913bd4eff12ee300 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 16:16:05 +0000
Subject: [PATCH 05/11] Apply comments

---
 linfa-clustering/src/dbscan/algorithm.rs | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index 9bbe9bf1a..a75619e73 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -17,13 +17,13 @@ pub fn predict(
             continue;
         }
         // Now go over the neighbours adding them to the cluster
+        cluster_memberships[i] = Some(current_cluster_id);
         let mut search_queue = neighbors
             .iter()
-            .filter(|x| cluster_memberships[[**x]].is_none() && **x != i)
+            .filter(|x| cluster_memberships[[**x]].is_none())
             .copied()
             .collect::<Vec<_>>();

-        cluster_memberships[i] = Some(current_cluster_id);
         while !search_queue.is_empty() {
             let candidate = search_queue.remove(0);

@@ -94,12 +94,6 @@ mod tests {
     #[test]
     fn non_cluster_points() {
         let params = DBScanHyperParams::new(4).build();
-
-        let data: Array2<f64> = Array2::zeros((2, 3));
-
-        let labels = predict(&params, &data);
-        assert!(labels.iter().all(|x| x.is_none()));
-
         let mut data: Array2<f64> = Array2::zeros((2, 5));
         data.slice_mut(s![.., 0]).assign(&arr1(&[10.0, 10.0]));

@@ -107,4 +101,14 @@ mod tests {
         let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
         assert_eq!(labels, expected);
     }
+
+    #[test]
+    fn dataset_too_small() {
+        let params = DBScanHyperParams::new(4).build();
+
+        let data: Array2<f64> = Array2::zeros((2, 3));
+
+        let labels = predict(&params, &data);
+        assert!(labels.iter().all(|x| x.is_none()));
+    }
 }

From 596290dafa9baf8062eece762a534cbecc9a1e05 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 16:33:55 +0000
Subject: [PATCH 06/11] Add benchmark

---
 linfa-clustering/Cargo.toml              |  4 +++
 linfa-clustering/benches/dbscan.rs       | 38 ++++++++++++++++++++
 linfa-clustering/src/dbscan/algorithm.rs |  8 ++---
 3 files changed, 46 insertions(+), 4 deletions(-)
 create mode 100644 linfa-clustering/benches/dbscan.rs

diff --git a/linfa-clustering/Cargo.toml b/linfa-clustering/Cargo.toml
index e7fd9fd57..c9e903bf8 100644
--- a/linfa-clustering/Cargo.toml
+++ b/linfa-clustering/Cargo.toml
@@ -28,3 +28,7 @@ approx = "0.3"
 [[bench]]
 name = "k_means"
 harness = false
+
+[[bench]]
+name = "dbscan"
+harness = false
diff --git a/linfa-clustering/benches/dbscan.rs b/linfa-clustering/benches/dbscan.rs
new file mode 100644
index 000000000..a22b25bdb
--- /dev/null
+++ b/linfa-clustering/benches/dbscan.rs
@@ -0,0 +1,38 @@
+use criterion::{
+    black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
+    PlotConfiguration,
+};
+use linfa_clustering::{dbscan, generate_blobs, DBScanHyperParams};
+use ndarray::Array2;
+use ndarray_rand::rand::SeedableRng;
+use ndarray_rand::rand_distr::Uniform;
+use ndarray_rand::RandomExt;
+use rand_isaac::Isaac64Rng;
+
+fn dbscan_bench(c: &mut Criterion) {
+    let mut rng = Isaac64Rng::seed_from_u64(40);
+    let cluster_sizes = vec![10, 100, 1000, 10000];
+
+    let benchmark = ParameterizedBenchmark::new(
+        "dbscan",
+        move |bencher, &cluster_size| {
+            let min_points = 4;
+            let n_features = 3;
+            let centroids =
+                Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), &mut rng);
+            let dataset = generate_blobs(cluster_size, &centroids, &mut rng);
+            let hyperparams = DBScanHyperParams::new(min_points).tolerance(1e-3).build();
+            bencher.iter(|| black_box(dbscan(&hyperparams, &dataset)));
+        },
+        cluster_sizes,
+    )
+    .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
+    c.bench("dbscan", benchmark);
+}
+
+criterion_group! {
+    name = benches;
+    config = Criterion::default();
+    targets = dbscan_bench
+}
+criterion_main!(benches);
diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index a75619e73..db69a97e8 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -2,7 +2,7 @@ use crate::dbscan::hyperparameters::DBScanHyperParams;
 use ndarray::{s, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
 use ndarray_stats::DeviationExt;

-pub fn predict(
+pub fn dbscan(
     hyperparameters: &DBScanHyperParams,
     observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
 ) -> Array1<Option<usize>> {
@@ -85,7 +85,7 @@ mod tests {

         data.slice_mut(s![.., 40..]).fill(5.0);

-        let labels = predict(&params, &data);
+        let labels = dbscan(&params, &data);

         assert!(labels.slice(s![..40]).iter().all(|x| x == &Some(0)));
         assert!(labels.slice(s![40..]).iter().all(|x| x == &Some(1)));
@@ -97,7 +97,7 @@ mod tests {
         let mut data: Array2<f64> = Array2::zeros((2, 5));
         data.slice_mut(s![.., 0]).assign(&arr1(&[10.0, 10.0]));

-        let labels = predict(&params, &data);
+        let labels = dbscan(&params, &data);
         let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
         assert_eq!(labels, expected);
     }
@@ -108,7 +108,7 @@ mod tests {

         let data: Array2<f64> = Array2::zeros((2, 3));

-        let labels = predict(&params, &data);
+        let labels = dbscan(&params, &data);
         assert!(labels.iter().all(|x| x.is_none()));
     }
 }
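The Criterion benchmark added above is registered as its own bench target in Cargo.toml (with `harness = false`), so it can be run in isolation; from the `linfa-clustering` crate root something like `cargo bench --bench dbscan` should pick it up.
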
From 06d9368ce486ea708b718a821fa4c6b4758bfc60 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 17:06:03 +0000
Subject: [PATCH 07/11] Change find_neighbors for gains

Change to return reference to the neighbour data as well to avoid lookups
---
 linfa-clustering/src/dbscan/algorithm.rs | 29 +++++++++++-------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index db69a97e8..10c0b8003 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -1,5 +1,5 @@
 use crate::dbscan::hyperparameters::DBScanHyperParams;
-use ndarray::{s, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
+use ndarray::{Array1, ArrayBase, ArrayView, Axis, Data, Ix1, Ix2};
 use ndarray_stats::DeviationExt;

@@ -20,24 +20,21 @@ pub fn dbscan(
         cluster_memberships[i] = Some(current_cluster_id);
         let mut search_queue = neighbors
             .iter()
-            .filter(|x| cluster_memberships[[**x]].is_none())
+            .filter(|x| cluster_memberships[[x.0]].is_none())
             .copied()
             .collect::<Vec<_>>();

         while !search_queue.is_empty() {
             let candidate = search_queue.remove(0);

-            cluster_memberships[candidate] = Some(current_cluster_id);
+            cluster_memberships[candidate.0] = Some(current_cluster_id);

-            let mut neighbors = find_neighbors(
-                &observations.slice(s![.., candidate]),
-                observations,
-                hyperparameters.tolerance(),
-            )
-            .iter()
-            .filter(|x| cluster_memberships[[**x]].is_none() && !search_queue.contains(x))
-            .copied()
-            .collect::<Vec<_>>();
+            let mut neighbors =
+                find_neighbors(&candidate.1, observations, hyperparameters.tolerance())
+                    .iter()
+                    .filter(|x| cluster_memberships[[x.0]].is_none() && !search_queue.contains(x))
+                    .copied()
+                    .collect::<Vec<_>>();

             search_queue.append(&mut neighbors);
         }
@@ -46,15 +43,15 @@ pub fn dbscan(
     cluster_memberships
 }

-fn find_neighbors(
+fn find_neighbors<'a>(
     candidate: &ArrayBase<impl Data<Elem = f64>, Ix1>,
-    observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
+    observations: &'a ArrayBase<impl Data<Elem = f64>, Ix2>,
     eps: f64,
-) -> Vec<usize> {
+) -> Vec<(usize, ArrayView<'a, f64, Ix1>)> {
     let mut res = vec![];
     for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
         if candidate.l2_dist(&obs).unwrap() < eps {
-            res.push(i);
+            res.push((i, obs));
         }
     }
     res

From ca28c4fb072f427e55cb49a63f3cada64906296f Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 18:51:34 +0000
Subject: [PATCH 08/11] Return search queue directly

---
 linfa-clustering/src/dbscan/algorithm.rs | 37 ++++++++++++++----------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index 10c0b8003..feff5894a 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -12,29 +12,29 @@ pub fn dbscan(
         if cluster_memberships[i].is_some() {
             continue;
         }
-        let neighbors = find_neighbors(&obs, observations, hyperparameters.tolerance());
-        if neighbors.len() < hyperparameters.minimum_points() {
+        let (neighbor_count, mut search_queue) = find_neighbors(
+            &obs,
+            observations,
+            hyperparameters.tolerance(),
+            &cluster_memberships,
+        );
+        if neighbor_count < hyperparameters.minimum_points() {
             continue;
         }
         // Now go over the neighbours adding them to the cluster
         cluster_memberships[i] = Some(current_cluster_id);
-        let mut search_queue = neighbors
-            .iter()
-            .filter(|x| cluster_memberships[[x.0]].is_none())
-            .copied()
-            .collect::<Vec<_>>();

         while !search_queue.is_empty() {
             let candidate = search_queue.remove(0);

             cluster_memberships[candidate.0] = Some(current_cluster_id);

-            let mut neighbors =
-                find_neighbors(&candidate.1, observations, hyperparameters.tolerance())
-                    .iter()
-                    .filter(|x| cluster_memberships[[x.0]].is_none() && !search_queue.contains(x))
-                    .copied()
-                    .collect::<Vec<_>>();
+            let (_, mut neighbors) = find_neighbors(
+                &candidate.1,
+                observations,
+                hyperparameters.tolerance(),
+                &cluster_memberships,
+            );

             search_queue.append(&mut neighbors);
         }
@@ -47,14 +47,19 @@ fn find_neighbors<'a>(
     candidate: &ArrayBase<impl Data<Elem = f64>, Ix1>,
     observations: &'a ArrayBase<impl Data<Elem = f64>, Ix2>,
     eps: f64,
-) -> Vec<(usize, ArrayView<'a, f64, Ix1>)> {
+    clusters: &Array1<Option<usize>>,
+) -> (usize, Vec<(usize, ArrayView<'a, f64, Ix1>)>) {
     let mut res = vec![];
+    let mut count = 0;
     for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
         if candidate.l2_dist(&obs).unwrap() < eps {
-            res.push((i, obs));
+            count += 1;
+            if clusters[[i]].is_none() {
+                res.push((i, obs));
+            }
         }
     }
-    res
+    (count, res)
 }

From 8c924869af321c225138359357c08b0b7f7f93e7 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Tue, 24 Dec 2019 19:34:30 +0000
Subject: [PATCH 09/11] Move to zip in find_neighbors

---
 linfa-clustering/src/dbscan/algorithm.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index feff5894a..7f90eb8bc 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -51,10 +51,14 @@ fn find_neighbors<'a>(
 ) -> (usize, Vec<(usize, ArrayView<'a, f64, Ix1>)>) {
     let mut res = vec![];
     let mut count = 0;
-    for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
+    for (i, (obs, cluster)) in observations
+        .axis_iter(Axis(1))
+        .zip(clusters.iter())
+        .enumerate()
+    {
         if candidate.l2_dist(&obs).unwrap() < eps {
             count += 1;
-            if clusters[[i]].is_none() {
+            if cluster.is_none() {
                 res.push((i, obs));
             }
         }

From 384de08f62b37e10edda6faff3b92139090c90ed Mon Sep 17 00:00:00 2001
From: xd009642
Date: Thu, 26 Dec 2019 12:20:38 +0000
Subject: [PATCH 10/11] Added examples, docs, Dbscan struct

Rename DBScan to Dbscan because the whole thing is an acronym
---
 linfa-clustering/benches/dbscan.rs            |   6 +-
 linfa-clustering/examples/dbscan.rs           |  33 ++
 .../examples/{main.rs => kmeans.rs}           |   0
 linfa-clustering/src/dbscan/algorithm.rs      | 150 +++++++++++++-----
 .../src/dbscan/hyperparameters.rs             |  18 +--
 5 files changed, 154 insertions(+), 53 deletions(-)
 create mode 100644 linfa-clustering/examples/dbscan.rs
 rename linfa-clustering/examples/{main.rs => kmeans.rs} (100%)

diff --git a/linfa-clustering/benches/dbscan.rs b/linfa-clustering/benches/dbscan.rs
index a22b25bdb..d3e421c94 100644
--- a/linfa-clustering/benches/dbscan.rs
+++ b/linfa-clustering/benches/dbscan.rs
@@ -2,7 +2,7 @@ use criterion::{
     black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
     PlotConfiguration,
 };
-use linfa_clustering::{dbscan, generate_blobs, DBScanHyperParams};
+use linfa_clustering::{generate_blobs, Dbscan, DbscanHyperParams};
 use ndarray::Array2;
 use ndarray_rand::rand::SeedableRng;
 use ndarray_rand::rand_distr::Uniform;
@@ -21,8 +21,8 @@ fn dbscan_bench(c: &mut Criterion) {
             let centroids =
                 Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), &mut rng);
             let dataset = generate_blobs(cluster_size, &centroids, &mut rng);
-            let hyperparams = DBScanHyperParams::new(min_points).tolerance(1e-3).build();
-            bencher.iter(|| black_box(dbscan(&hyperparams, &dataset)));
+            let hyperparams = DbscanHyperParams::new(min_points).tolerance(1e-3).build();
+            bencher.iter(|| black_box(Dbscan::predict(&hyperparams, &dataset)));
         },
         cluster_sizes,
     )
diff --git a/linfa-clustering/examples/dbscan.rs b/linfa-clustering/examples/dbscan.rs
new file mode 100644
index 000000000..06a3e78c4
--- /dev/null
+++ b/linfa-clustering/examples/dbscan.rs
@@ -0,0 +1,33 @@
+use linfa_clustering::{generate_blobs, Dbscan, DbscanHyperParams};
+use ndarray::array;
+use ndarray_npy::write_npy;
+use ndarray_rand::rand::SeedableRng;
+use rand_isaac::Isaac64Rng;
+
+// A routine DBScan task: build a synthetic dataset, predict clusters for it
+// and save both training data and predictions to disk.
+fn main() {
+    // Our random number generator, seeded for reproducibility
+    let mut rng = Isaac64Rng::seed_from_u64(42);
+
+    // For each of our expected centroids, generate `n` data points around it (a "blob")
+    let expected_centroids = array![[10., 10.], [1., 12.], [20., 30.], [-20., 30.],];
+    let n = 10000;
+    let dataset = generate_blobs(n, &expected_centroids, &mut rng);
+
+    // Configure our training algorithm
+    let min_points = 3;
+    let hyperparams = DbscanHyperParams::new(min_points).tolerance(1e-5).build();
+
+    // Assign each observation to a dense cluster, or mark it as noise
+    let cluster_memberships = Dbscan::predict(&hyperparams, &dataset);
+
+    // Save to disk our dataset (and the cluster label assigned to each observation)
+    // We use the `npy` format for compatibility with NumPy
+    write_npy("clustered_dataset.npy", dataset).expect("Failed to write .npy file");
+    write_npy(
+        "clustered_memberships.npy",
+        cluster_memberships.map(|&x| x.map(|c| c as i64).unwrap_or(-1)),
+    )
+    .expect("Failed to write .npy file");
+}
diff --git a/linfa-clustering/examples/main.rs b/linfa-clustering/examples/kmeans.rs
similarity index 100%
rename from linfa-clustering/examples/main.rs
rename to linfa-clustering/examples/kmeans.rs
diff --git a/linfa-clustering/src/dbscan/algorithm.rs b/linfa-clustering/src/dbscan/algorithm.rs
index 7f90eb8bc..f46f766de 100644
--- a/linfa-clustering/src/dbscan/algorithm.rs
+++ b/linfa-clustering/src/dbscan/algorithm.rs
@@ -1,46 +1,114 @@
-use crate::dbscan::hyperparameters::DBScanHyperParams;
+use crate::dbscan::hyperparameters::DbscanHyperParams;
 use ndarray::{Array1, ArrayBase, ArrayView, Axis, Data, Ix1, Ix2};
 use ndarray_stats::DeviationExt;
-
-pub fn dbscan(
-    hyperparameters: &DBScanHyperParams,
-    observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
-) -> Array1<Option<usize>> {
-    let mut cluster_memberships = Array1::from_elem(observations.dim().1, None);
-    let mut current_cluster_id = 0;
-    for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
-        if cluster_memberships[i].is_some() {
-            continue;
-        }
-        let (neighbor_count, mut search_queue) = find_neighbors(
-            &obs,
-            observations,
-            hyperparameters.tolerance(),
-            &cluster_memberships,
-        );
-        if neighbor_count < hyperparameters.minimum_points() {
-            continue;
-        }
-        // Now go over the neighbours adding them to the cluster
-        cluster_memberships[i] = Some(current_cluster_id);
-
-        while !search_queue.is_empty() {
-            let candidate = search_queue.remove(0);
-
-            cluster_memberships[candidate.0] = Some(current_cluster_id);
-
-            let (_, mut neighbors) = find_neighbors(
-                &candidate.1,
-                observations,
-                hyperparameters.tolerance(),
-                &cluster_memberships,
-            );
-
-            search_queue.append(&mut neighbors);
-        }
-        current_cluster_id += 1;
-    }
-    cluster_memberships
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
+/// clusters together points which are close together with enough neighbors and
+/// labels points which are sparsely neighbored as noise. As points may be
+/// part of a cluster or noise the predict method returns
+/// `Array1<Option<usize>>`
+///
+/// As it groups together points in dense regions, the number of clusters is
+/// determined by the dataset and distance tolerance, not the user.
+///
+/// We provide an implementation of the standard O(N^2) query-based algorithm,
+/// more details on which can be found in the next section or
+/// [here](https://en.wikipedia.org/wiki/DBSCAN).
+///
+/// The standard DBSCAN algorithm isn't iterative and therefore there's
+/// no fit method provided, only predict.
+///
+/// ## The algorithm
+///
+/// The algorithm iterates over each point in the dataset and, for every point
+/// not yet assigned to a cluster:
+/// - Find all points within the neighborhood of size `tolerance`
+/// - If the number of points in the neighborhood is below a minimum size, label it
+/// as noise
+/// - Otherwise label the point with the cluster ID and repeat with each of the
+/// neighbours
+///
+/// ## Tutorial
+///
+/// Let's do a walkthrough of an example running DBSCAN on some data.
+///
+/// ```
+/// use linfa_clustering::{DbscanHyperParams, Dbscan, generate_blobs};
+/// use ndarray::{Axis, array, s};
+/// use ndarray_rand::rand::SeedableRng;
+/// use rand_isaac::Isaac64Rng;
+/// use approx::assert_abs_diff_eq;
+///
+/// // Our random number generator, seeded for reproducibility
+/// let seed = 42;
+/// let mut rng = Isaac64Rng::seed_from_u64(seed);
+///
+/// // `expected_centroids` has shape `(n_centroids, n_features)`
+/// // i.e. three points in the 2-dimensional plane
+/// let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
+/// // Let's generate a synthetic dataset: three blobs of observations
+/// // (100 points each) centered around our `expected_centroids`
+/// let observations = generate_blobs(100, &expected_centroids, &mut rng);
+///
+/// // Let's configure and run our DBSCAN algorithm
+/// // We use the builder pattern to specify the hyperparameters
+/// // `min_points` is the only mandatory parameter.
+/// // If you don't specify the others (e.g. `tolerance`)
+/// // default values will be used.
+/// let min_points = 3;
+/// let hyperparams = DbscanHyperParams::new(min_points)
+///     .tolerance(1e-2)
+///     .build();
+/// // Let's run the algorithm!
+/// let clusters = Dbscan::predict(&hyperparams, &observations);
+/// // Points are `None` if noise, `Some(id)` if belonging to a cluster.
+/// ```
+///
+pub struct Dbscan;
+
+impl Dbscan {
+    pub fn predict(
+        hyperparameters: &DbscanHyperParams,
+        observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
+    ) -> Array1<Option<usize>> {
+        let mut cluster_memberships = Array1::from_elem(observations.dim().1, None);
+        let mut current_cluster_id = 0;
+        for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
+            if cluster_memberships[i].is_some() {
+                continue;
+            }
+            let (neighbor_count, mut search_queue) = find_neighbors(
+                &obs,
+                observations,
+                hyperparameters.tolerance(),
+                &cluster_memberships,
+            );
+            if neighbor_count < hyperparameters.minimum_points() {
+                continue;
+            }
+            // Now go over the neighbours adding them to the cluster
+            cluster_memberships[i] = Some(current_cluster_id);
+
+            while !search_queue.is_empty() {
+                let candidate = search_queue.remove(0);
+
+                let (neighbor_count, mut neighbors) = find_neighbors(
+                    &candidate.1,
+                    observations,
+                    hyperparameters.tolerance(),
+                    &cluster_memberships,
+                );
+                if neighbor_count >= hyperparameters.minimum_points() {
+                    cluster_memberships[candidate.0] = Some(current_cluster_id);
+                    search_queue.append(&mut neighbors);
+                }
+            }
+            current_cluster_id += 1;
+        }
+        cluster_memberships
+    }
 }

 fn find_neighbors<'a>(
@@ -75,7 +143,7 @@ mod tests {
     fn nested_clusters() {
         // Create a circuit of points and then a cluster in the centre
         // and ensure they are identified as two separate clusters
-        let params = DBScanHyperParams::new(4).tolerance(1.0).build();
+        let params = DbscanHyperParams::new(2).tolerance(1.0).build();

         let mut data: Array2<f64> = Array2::zeros((2, 50));
         let rising = Array1::linspace(0.0, 8.0, 10);
@@ -91,7 +159,7 @@ mod tests {

         data.slice_mut(s![.., 40..]).fill(5.0);

-        let labels = dbscan(&params, &data);
+        let labels = Dbscan::predict(&params, &data);

         assert!(labels.slice(s![..40]).iter().all(|x| x == &Some(0)));
         assert!(labels.slice(s![40..]).iter().all(|x| x == &Some(1)));
@@ -99,22 +167,22 @@ mod tests {

     #[test]
     fn non_cluster_points() {
-        let params = DBScanHyperParams::new(4).build();
+        let params = DbscanHyperParams::new(4).build();
         let mut data: Array2<f64> = Array2::zeros((2, 5));
         data.slice_mut(s![.., 0]).assign(&arr1(&[10.0, 10.0]));

-        let labels = dbscan(&params, &data);
+        let labels = Dbscan::predict(&params, &data);
         let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
         assert_eq!(labels, expected);
     }

     #[test]
     fn dataset_too_small() {
-        let params = DBScanHyperParams::new(4).build();
+        let params = DbscanHyperParams::new(4).build();

         let data: Array2<f64> = Array2::zeros((2, 3));

-        let labels = dbscan(&params, &data);
+        let labels = Dbscan::predict(&params, &data);
         assert!(labels.iter().all(|x| x.is_none()));
     }
 }
diff --git a/linfa-clustering/src/dbscan/hyperparameters.rs b/linfa-clustering/src/dbscan/hyperparameters.rs
index 6bb9ed6e0..65928bec2 100644
--- a/linfa-clustering/src/dbscan/hyperparameters.rs
+++ b/linfa-clustering/src/dbscan/hyperparameters.rs
@@ -2,8 +2,8 @@ use serde::{Deserialize, Serialize};

 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 /// The set of hyperparameters that can be specified for the execution of
-/// the [DBSCAN algorithm](struct.DBScan.html).
-pub struct DBScanHyperParams {
+/// the [DBSCAN algorithm](struct.Dbscan.html).
+pub struct DbscanHyperParams {
     /// Distance between points for them to be considered neighbours.
     tolerance: f64,
     /// Minimum number of neighboring points a point needs to have to be a core
     /// point and not a noise point.
@@ -12,25 +12,25 @@ pub struct DbscanHyperParams {
 }

 /// Helper struct used to construct a set of hyperparameters for
-pub struct DBScanHyperParamsBuilder {
+pub struct DbscanHyperParamsBuilder {
     tolerance: f64,
     min_points: usize,
 }

-impl DBScanHyperParamsBuilder {
+impl DbscanHyperParamsBuilder {
     pub fn tolerance(mut self, tolerance: f64) -> Self {
         self.tolerance = tolerance;
         self
     }

-    pub fn build(self) -> DBScanHyperParams {
-        DBScanHyperParams::build(self.tolerance, self.min_points)
+    pub fn build(self) -> DbscanHyperParams {
+        DbscanHyperParams::build(self.tolerance, self.min_points)
     }
 }

-impl DBScanHyperParams {
-    pub fn new(min_points: usize) -> DBScanHyperParamsBuilder {
-        DBScanHyperParamsBuilder {
+impl DbscanHyperParams {
+    pub fn new(min_points: usize) -> DbscanHyperParamsBuilder {
+        DbscanHyperParamsBuilder {
             min_points,
             tolerance: 1e-4,
         }

From d7f8236c8a183c368cf4c38567827f45773b32a2 Mon Sep 17 00:00:00 2001
From: xd009642
Date: Thu, 26 Dec 2019 14:31:40 +0000
Subject: [PATCH 11/11] Update doc comments for HyperParams

---
 linfa-clustering/src/dbscan/hyperparameters.rs | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/linfa-clustering/src/dbscan/hyperparameters.rs b/linfa-clustering/src/dbscan/hyperparameters.rs
index 65928bec2..5ceaeaa3b 100644
--- a/linfa-clustering/src/dbscan/hyperparameters.rs
+++ b/linfa-clustering/src/dbscan/hyperparameters.rs
@@ -18,17 +18,27 @@ pub struct DbscanHyperParamsBuilder {
 }

 impl DbscanHyperParamsBuilder {
+    /// Distance between points for them to be considered neighbours.
     pub fn tolerance(mut self, tolerance: f64) -> Self {
         self.tolerance = tolerance;
         self
     }

+    /// Return an instance of `DbscanHyperParams` after having performed
+    /// validation checks on all hyperparameters.
+    ///
+    /// **Panics** if any of the validation checks fail.
     pub fn build(self) -> DbscanHyperParams {
         DbscanHyperParams::build(self.tolerance, self.min_points)
     }
 }

 impl DbscanHyperParams {
+    /// Minimum number of neighboring points a point needs to have to be a core
+    /// point and not a noise point.
+    ///
+    /// Defaults are provided if the optional parameters are not specified:
+    /// * `tolerance = 1e-4`
     pub fn new(min_points: usize) -> DbscanHyperParamsBuilder {
         DbscanHyperParamsBuilder {
             min_points,
             tolerance: 1e-4,
         }
     }

+    /// Two points are considered neighbors if the Euclidean distance between
+    /// them is below the tolerance
     pub fn tolerance(&self) -> f64 {
         self.tolerance
     }

+    /// Minimum number of points in a neighborhood around a point for it to
+    /// not be considered noise
     pub fn minimum_points(&self) -> usize {
         self.min_points
     }
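A rough usage sketch of the public API as it stands at the end of this series (the data values below are made up for illustration; the layout follows the column-per-observation convention that `predict` iterates over):

    use linfa_clustering::{Dbscan, DbscanHyperParams};
    use ndarray::array;

    fn main() {
        // Observations are laid out as (n_features, n_observations):
        // three points near (1, 2) and one far-away outlier
        let observations = array![
            [1.0, 1.1, 0.9, 10.0],
            [2.0, 2.1, 1.9, 10.0],
        ];
        // `min_points` is mandatory; `tolerance` defaults to 1e-4 if unset
        let hyperparams = DbscanHyperParams::new(3).tolerance(0.5).build();
        let clusters = Dbscan::predict(&hyperparams, &observations);
        // Expected output: [Some(0), Some(0), Some(0), None]
        println!("{:?}", clusters);
    }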