Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Initial commit of DBScan #12

Merged
merged 11 commits into from
Dec 27, 2019
108 changes: 108 additions & 0 deletions linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use crate::dbscan::hyperparameters::DBScanHyperParams;
use ndarray::{s, Array1, ArrayBase, Axis, Data, Ix1, Ix2};
use ndarray_stats::DeviationExt;

pub fn predict(
hyperparameters: &DBScanHyperParams,
observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> Array1<Option<usize>> {
let mut result = Array1::from_elem(observations.dim().1, None);
let mut latest_id = 0;
for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
if result[i].is_some() {
continue;
}
let n = find_neighbors(&obs, observations, hyperparameters.tolerance());
if n.len() < hyperparameters.minimum_points() {
continue;
}
// Now go over the neighbours adding them to the cluster
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved
let mut search_queue = n
.iter()
.filter(|x| result[[**x]].is_none())
.copied()
.collect::<Vec<_>>();
while !search_queue.is_empty() {
let cand = search_queue.remove(0);

result[cand] = Some(latest_id);

let mut n = find_neighbors(
&observations.slice(s![.., cand]),
observations,
hyperparameters.tolerance(),
)
.iter()
.filter(|x| result[[**x]].is_none() && !search_queue.contains(x))
.copied()
.collect::<Vec<_>>();

search_queue.append(&mut n);
}
latest_id += 1;
}
result
}
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved

fn find_neighbors(
candidate: &ArrayBase<impl Data<Elem = f64>, Ix1>,
observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
eps: f64,
) -> Vec<usize> {
let mut res = vec![];
for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
if candidate.l2_dist(&obs).unwrap() < eps {
res.push(i);
}
}
res
}

#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr1, s, Array2};

#[test]
fn nested_clusters() {
// Create a circuit of points and then a cluster in the centre
// and ensure they are identified as two separate clusters
let params = DBScanHyperParams::new(4).tolerance(1.0).build();

let mut data: Array2<f64> = Array2::zeros((2, 50));
let rising = Array1::linspace(0.0, 8.0, 10);
data.slice_mut(s![0, 0..10]).assign(&rising);
data.slice_mut(s![0, 10..20]).assign(&rising);
data.slice_mut(s![1, 20..30]).assign(&rising);
data.slice_mut(s![1, 30..40]).assign(&rising);

data.slice_mut(s![1, 0..10]).fill(0.0);
data.slice_mut(s![1, 10..20]).fill(8.0);
data.slice_mut(s![0, 20..30]).fill(0.0);
data.slice_mut(s![0, 30..40]).fill(8.0);

data.slice_mut(s![.., 40..]).fill(5.0);

let labels = predict(&params, &data);

assert!(labels.slice(s![..40]).iter().all(|x| x == &Some(0)));
assert!(labels.slice(s![40..]).iter().all(|x| x == &Some(1)));
}

#[test]
fn non_cluster_points() {
let params = DBScanHyperParams::new(4).build();

let data: Array2<f64> = Array2::zeros((2, 3));

let labels = predict(&params, &data);
assert!(labels.iter().all(|x| x.is_none()));
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved

let mut data: Array2<f64> = Array2::zeros((2, 5));
data.slice_mut(s![.., 0]).assign(&arr1(&[10.0, 10.0]));

let labels = predict(&params, &data);
let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
assert_eq!(labels, expected);
}
}
60 changes: 60 additions & 0 deletions linfa-clustering/src/dbscan/hyperparameters.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The set of hyperparameters that can be specified for the execution of
/// the [DBSCAN algorithm](struct.DBScan.html).
pub struct DBScanHyperParams {
/// Distance between points for them to be considered neighbours.
tolerance: f64,
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved
/// Minimum number of neighboring points a point needs to have to be a core
/// point and not a noise point.
min_points: usize,
}

/// Helper struct used to construct a set of hyperparameters for
pub struct DBScanHyperParamsBuilder {
tolerance: f64,
min_points: usize,
}

impl DBScanHyperParamsBuilder {
pub fn tolerance(mut self, tolerance: f64) -> Self {
self.tolerance = tolerance;
self
}

pub fn build(self) -> DBScanHyperParams {
DBScanHyperParams::build(self.tolerance, self.min_points)
}
}

impl DBScanHyperParams {
pub fn new(min_points: usize) -> DBScanHyperParamsBuilder {
DBScanHyperParamsBuilder {
min_points,
tolerance: 1e-4,
}
}

pub fn tolerance(&self) -> f64 {
self.tolerance
}

pub fn minimum_points(&self) -> usize {
self.min_points
}

fn build(tolerance: f64, min_points: usize) -> Self {
if tolerance <= 0. {
panic!("`tolerance` must be greater than 0!");
}
// There is always at least one neighbor to a point (itself)
if min_points <= 1 {
panic!("`min_points` must be greater than 1!");
}
Self {
tolerance,
min_points,
}
}
}
5 changes: 5 additions & 0 deletions linfa-clustering/src/dbscan/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod algorithm;
mod hyperparameters;

pub use algorithm::*;
pub use hyperparameters::*;
2 changes: 2 additions & 0 deletions linfa-clustering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
//! Implementation choices, algorithmic details and a tutorial can be found [here](struct.KMeans.html).
//!
//! Check [here](https://github.com/LukeMathWalker/clustering-benchmarks) for extensive benchmarks against `scikit-learn`'s K-means implementation.
mod dbscan;
#[allow(clippy::new_ret_no_self)]
mod k_means;
mod utils;

pub use dbscan::*;
pub use k_means::*;
pub use utils::*;