Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Initial commit of DBScan #12

Merged
merged 11 commits into from
Dec 27, 2019
4 changes: 4 additions & 0 deletions linfa-clustering/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ approx = "0.3"
[[bench]]
name = "k_means"
harness = false

[[bench]]
name = "dbscan"
harness = false
38 changes: 38 additions & 0 deletions linfa-clustering/benches/dbscan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use criterion::{
black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
PlotConfiguration,
};
use linfa_clustering::{dbscan, generate_blobs, DBScanHyperParams};
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use rand_isaac::Isaac64Rng;

/// Benchmarks `dbscan` on generated blob datasets of increasing size.
///
/// One benchmark case is run per entry of `cluster_sizes`; the summary plot
/// uses a logarithmic axis because the sizes span several orders of magnitude.
fn dbscan_bench(c: &mut Criterion) {
    // Fixed seed so that every run of the benchmark sees the same data.
    let mut rng = Isaac64Rng::seed_from_u64(40);
    let cluster_sizes = vec![10, 100, 1000, 10000];
    let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic);

    let benchmark = ParameterizedBenchmark::new(
        "dbscan",
        move |bencher, &cluster_size| {
            // `min_points` doubles as the number of centroids handed to
            // `generate_blobs` and as the DBSCAN density threshold.
            let min_points = 4;
            let n_features = 3;
            let centroid_shape = (min_points, n_features);
            let centroids =
                Array2::random_using(centroid_shape, Uniform::new(-30., 30.), &mut rng);
            // NOTE(review): `dbscan` walks its input along `Axis(1)`, i.e. it
            // treats every column as one observation, while the centroids here
            // are laid out one per row - confirm that the array returned by
            // `generate_blobs` has the layout `dbscan` expects.
            let dataset = generate_blobs(cluster_size, &centroids, &mut rng);
            let hyperparams = DBScanHyperParams::new(min_points)
                .tolerance(1e-3)
                .build();
            // Only the clustering itself is timed; data generation happens above.
            bencher.iter(|| black_box(dbscan(&hyperparams, &dataset)));
        },
        cluster_sizes,
    )
    .plot_config(plot_config);

    c.bench("dbscan", benchmark);
}

// Registers `dbscan_bench` in the `benches` group using Criterion's default
// configuration, then generates the `main` function of the benchmark binary.
criterion_group!(benches, dbscan_bench);
criterion_main!(benches);
114 changes: 114 additions & 0 deletions linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use crate::dbscan::hyperparameters::DBScanHyperParams;
use ndarray::{s, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
use ndarray_stats::DeviationExt;

pub fn dbscan(
hyperparameters: &DBScanHyperParams,
observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> Array1<Option<usize>> {
let mut cluster_memberships = Array1::from_elem(observations.dim().1, None);
let mut current_cluster_id = 0;
for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
if cluster_memberships[i].is_some() {
continue;
}
let neighbors = find_neighbors(&obs, observations, hyperparameters.tolerance());
if neighbors.len() < hyperparameters.minimum_points() {
continue;
}
// Now go over the neighbours adding them to the cluster
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved
cluster_memberships[i] = Some(current_cluster_id);
let mut search_queue = neighbors
.iter()
.filter(|x| cluster_memberships[[**x]].is_none())
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved
.copied()
.collect::<Vec<_>>();

while !search_queue.is_empty() {
let candidate = search_queue.remove(0);

cluster_memberships[candidate] = Some(current_cluster_id);

let mut neighbors = find_neighbors(
&observations.slice(s![.., candidate]),
observations,
hyperparameters.tolerance(),
)
.iter()
.filter(|x| cluster_memberships[[**x]].is_none() && !search_queue.contains(x))
.copied()
.collect::<Vec<_>>();

search_queue.append(&mut neighbors);
}
current_cluster_id += 1;
}
cluster_memberships
}
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved

/// Returns the indices of all columns of `observations` whose L2 distance
/// from `candidate` is strictly smaller than `eps`.
///
/// Indices are returned in increasing order. When `candidate` is itself a
/// column of `observations` its own index is part of the result as long as
/// `eps` is positive (its distance to itself is zero).
///
/// # Panics
///
/// Panics if `l2_dist` returns an error (presumably when `candidate` and the
/// columns of `observations` differ in length - confirm against
/// `ndarray_stats`).
fn find_neighbors(
    candidate: &ArrayBase<impl Data<Elem = f64>, Ix1>,
    observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    eps: f64,
) -> Vec<usize> {
    observations
        .axis_iter(Axis(1))
        .enumerate()
        .filter(|(_, other)| candidate.l2_dist(other).unwrap() < eps)
        .map(|(index, _)| index)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, s, Array2};

    /// A square outline of points with a tight blob inside it must come out
    /// as two distinct clusters: the outline first, the blob second.
    #[test]
    fn nested_clusters() {
        let hyperparams = DBScanHyperParams::new(4).tolerance(1.0).build();

        // Points are stored column-wise: row 0 and row 1 hold the two coordinates.
        let mut points: Array2<f64> = Array2::zeros((2, 50));
        let ramp = Array1::linspace(0.0, 8.0, 10);

        // Two edges where the first coordinate sweeps 0..=8 and the second is fixed.
        points.slice_mut(s![0, 0..10]).assign(&ramp);
        points.slice_mut(s![1, 0..10]).fill(0.0);
        points.slice_mut(s![0, 10..20]).assign(&ramp);
        points.slice_mut(s![1, 10..20]).fill(8.0);

        // Two edges where the second coordinate sweeps 0..=8 and the first is fixed.
        points.slice_mut(s![0, 20..30]).fill(0.0);
        points.slice_mut(s![1, 20..30]).assign(&ramp);
        points.slice_mut(s![0, 30..40]).fill(8.0);
        points.slice_mut(s![1, 30..40]).assign(&ramp);

        // Ten coincident points inside the square.
        points.slice_mut(s![.., 40..]).fill(5.0);

        let labels = dbscan(&hyperparams, &points);

        assert!(labels
            .slice(s![..40])
            .iter()
            .all(|label| *label == Some(0)));
        assert!(labels
            .slice(s![40..])
            .iter()
            .all(|label| *label == Some(1)));
    }

    /// A lone point far away from a dense group must be left unlabelled.
    #[test]
    fn non_cluster_points() {
        let hyperparams = DBScanHyperParams::new(4).build();

        // Four points at the origin plus one outlier in the first column.
        let mut points: Array2<f64> = Array2::zeros((2, 5));
        let outlier = arr1(&[10.0, 10.0]);
        points.slice_mut(s![.., 0]).assign(&outlier);

        let labels = dbscan(&hyperparams, &points);

        let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
        assert_eq!(labels, expected);
    }

    /// With fewer points than `min_points` nothing can ever form a cluster.
    #[test]
    fn dataset_too_small() {
        let hyperparams = DBScanHyperParams::new(4).build();
        let points: Array2<f64> = Array2::zeros((2, 3));

        let labels = dbscan(&hyperparams, &points);

        assert!(labels.iter().all(Option::is_none));
    }
}
60 changes: 60 additions & 0 deletions linfa-clustering/src/dbscan/hyperparameters.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The set of hyperparameters that can be specified for the execution of
/// the [DBSCAN algorithm](fn.dbscan.html).
///
/// Values are created with `DBScanHyperParams::new(min_points)`, optionally
/// customised through the returned builder, and validated when the builder's
/// `build` method is called.
pub struct DBScanHyperParams {
/// Distance between points for them to be considered neighbours.
/// Must be greater than 0.
tolerance: f64,
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved
/// Minimum number of neighboring points a point needs to have to be a core
/// point and not a noise point.
/// Must be greater than 1, since a point always counts as its own neighbour.
min_points: usize,
}

/// Helper struct used to construct a set of hyperparameters for
/// the DBSCAN algorithm.
///
/// It is obtained from `DBScanHyperParams::new` and turned into a validated
/// `DBScanHyperParams` by calling `build`.
pub struct DBScanHyperParamsBuilder {
// Neighbourhood distance; `DBScanHyperParams::new` initialises it to `1e-4`.
tolerance: f64,
// Density threshold supplied by the caller of `DBScanHyperParams::new`.
min_points: usize,
}

impl DBScanHyperParamsBuilder {
    /// Sets the distance below which two points are considered neighbours
    /// and returns the updated builder.
    pub fn tolerance(self, tolerance: f64) -> Self {
        Self { tolerance, ..self }
    }

    /// Consumes the builder and returns the validated hyperparameters.
    ///
    /// # Panics
    ///
    /// Panics if `tolerance` is not greater than 0 or if `min_points` is not
    /// greater than 1.
    pub fn build(self) -> DBScanHyperParams {
        let Self {
            tolerance,
            min_points,
        } = self;
        DBScanHyperParams::build(tolerance, min_points)
    }
}

impl DBScanHyperParams {
    /// Starts building a set of hyperparameters with the given `min_points`
    /// and a default `tolerance` of `1e-4`.
    ///
    /// The values are only validated when `build` is called on the returned
    /// builder.
    // `new` deliberately hands back the builder instead of `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(min_points: usize) -> DBScanHyperParamsBuilder {
        DBScanHyperParamsBuilder {
            min_points,
            tolerance: 1e-4,
        }
    }

    /// Distance below which two points are considered neighbours.
    pub fn tolerance(&self) -> f64 {
        self.tolerance
    }

    /// Minimum number of neighbours (the point itself included) a point needs
    /// in order to start a cluster.
    pub fn minimum_points(&self) -> usize {
        self.min_points
    }

    /// Validates the raw values and assembles the hyperparameters.
    ///
    /// # Panics
    ///
    /// Panics if `tolerance` is NaN or not greater than 0, or if `min_points`
    /// is not greater than 1.
    fn build(tolerance: f64, min_points: usize) -> Self {
        // NaN compares false against everything, so `tolerance <= 0.` alone
        // would let it through and every distance check would then fail
        // silently.
        if tolerance.is_nan() || tolerance <= 0. {
            panic!("`tolerance` must be greater than 0!");
        }
        // There is always at least one neighbor to a point (itself)
        if min_points <= 1 {
            panic!("`min_points` must be greater than 1!");
        }
        Self {
            tolerance,
            min_points,
        }
    }
}
5 changes: 5 additions & 0 deletions linfa-clustering/src/dbscan/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//! DBSCAN clustering: the algorithm entry point and its hyperparameters.
mod algorithm;
mod hyperparameters;

// Both submodules are private; their public items are re-exported so that
// callers can reach them directly through this module.
pub use algorithm::*;
pub use hyperparameters::*;
2 changes: 2 additions & 0 deletions linfa-clustering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
//! Implementation choices, algorithmic details and a tutorial can be found [here](struct.KMeans.html).
//!
//! Check [here](https://github.com/LukeMathWalker/clustering-benchmarks) for extensive benchmarks against `scikit-learn`'s K-means implementation.
mod dbscan;
// NOTE(review): presumably needed because a `new` inside `k_means` returns a
// builder rather than `Self` - confirm. `DBScanHyperParams::new` in `dbscan`
// also returns a builder and is not covered by this attribute.
#[allow(clippy::new_ret_no_self)]
mod k_means;
mod utils;

// Flatten the public API: everything public in the submodules is available
// from the crate root.
pub use dbscan::*;
pub use k_means::*;
pub use utils::*;