Skip to content

Commit

Permalink
Clustering: DBSCAN (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
xd009642 authored and LukeMathWalker committed Dec 27, 2019
1 parent 370671c commit 6fc1f69
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 0 deletions.
4 changes: 4 additions & 0 deletions linfa-clustering/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ approx = "0.3"
[[bench]]
name = "k_means"
harness = false

[[bench]]
name = "dbscan"
harness = false
38 changes: 38 additions & 0 deletions linfa-clustering/benches/dbscan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use criterion::{
black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
PlotConfiguration,
};
use linfa_clustering::{generate_blobs, Dbscan, DbscanHyperParams};
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use rand_isaac::Isaac64Rng;

fn dbscan_bench(c: &mut Criterion) {
let mut rng = Isaac64Rng::seed_from_u64(40);
let cluster_sizes = vec![10, 100, 1000, 10000];

let benchmark = ParameterizedBenchmark::new(
"dbscan",
move |bencher, &cluster_size| {
let min_points = 4;
let n_features = 3;
let centroids =
Array2::random_using((min_points, n_features), Uniform::new(-30., 30.), &mut rng);
let dataset = generate_blobs(cluster_size, &centroids, &mut rng);
let hyperparams = DbscanHyperParams::new(min_points).tolerance(1e-3).build();
bencher.iter(|| black_box(Dbscan::predict(&hyperparams, &dataset)));
},
cluster_sizes,
)
.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
c.bench("dbscan", benchmark);
}

criterion_group! {
name = benches;
config = Criterion::default();
targets = dbscan_bench
}
criterion_main!(benches);
33 changes: 33 additions & 0 deletions linfa-clustering/examples/dbscan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use linfa_clustering::{generate_blobs, Dbscan, DbscanHyperParams};
use ndarray::array;
use ndarray_npy::write_npy;
use ndarray_rand::rand::SeedableRng;
use rand_isaac::Isaac64Rng;

// A routine DBScan task: build a synthetic dataset, predict clusters for it
// and save both training data and predictions to disk.
fn main() {
// Our random number generator, seeded for reproducibility
let mut rng = Isaac64Rng::seed_from_u64(42);

// For each our expected centroids, generate `n` data points around it (a "blob")
let expected_centroids = array![[10., 10.], [1., 12.], [20., 30.], [-20., 30.],];
let n = 10000;
let dataset = generate_blobs(n, &expected_centroids, &mut rng);

// Configure our training algorithm
let min_points = 3;
let hyperparams = DbscanHyperParams::new(min_points).tolerance(1e-5).build();

// Infer an optimal set of centroids based on the training data distribution
let cluster_memberships = Dbscan::predict(&hyperparams, &dataset);

// Save to disk our dataset (and the cluster label assigned to each observation)
// We use the `npy` format for compatibility with NumPy
write_npy("clustered_dataset.npy", dataset).expect("Failed to write .npy file");
write_npy(
"clustered_memberships.npy",
cluster_memberships.map(|&x| x.map(|c| c as i64).unwrap_or(-1)),
)
.expect("Failed to write .npy file");
}
File renamed without changes.
188 changes: 188 additions & 0 deletions linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
use crate::dbscan::hyperparameters::DbscanHyperParams;
use ndarray::{Array1, ArrayBase, ArrayView, Axis, Data, Ix1, Ix2};
use ndarray_stats::DeviationExt;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
/// clusters together points which are close together with enough neighbors
/// labelled points which are sparsely neighbored as noise. As points may be
/// part of a cluster or noise the predict method returns
/// `Array1<Option<usize>>`
///
/// As it groups together points in dense regions the number of clusters is
/// determined by the dataset and distance tolerance not the user.
///
/// We provide an implemention of the standard O(N^2) query-based algorithm
/// of which more details can be found in the next section or
/// [here](https://en.wikipedia.org/wiki/DBSCAN).
///
/// The standard DBSCAN algorithm isn't iterative and therefore there's
/// no fit method provided only predict.
///
/// ## The algorithm
///
/// The algorithm iterates over each point in the dataset and for every point
/// not yet assigned to a cluster:
/// - Find all points within the neighborhood of size `tolerance`
/// - If the number of points in the neighborhood is below a minimum size label
/// as noise
/// - Otherwise label the point with the cluster ID and repeat with each of the
/// neighbours
///
/// ## Tutorial
///
/// Let's do a walkthrough of an example running DBSCAN on some data.
///
/// ```
/// use linfa_clustering::{DbscanHyperParams, Dbscan, generate_blobs};
/// use ndarray::{Axis, array, s};
/// use ndarray_rand::rand::SeedableRng;
/// use rand_isaac::Isaac64Rng;
/// use approx::assert_abs_diff_eq;
///
/// // Our random number generator, seeded for reproducibility
/// let seed = 42;
/// let mut rng = Isaac64Rng::seed_from_u64(seed);
///
/// // `expected_centroids` has shape `(n_centroids, n_features)`
/// // i.e. three points in the 2-dimensional plane
/// let expected_centroids = array![[0., 1.], [-10., 20.], [-1., 10.]];
/// // Let's generate a synthetic dataset: three blobs of observations
/// // (100 points each) centered around our `expected_centroids`
/// let observations = generate_blobs(100, &expected_centroids, &mut rng);
///
/// // Let's configure and run our DBSCAN algorithm
/// // We use the builder pattern to specify the hyperparameters
/// // `min_points` is the only mandatory parameter.
/// // If you don't specify the others (e.g. `tolerance`)
/// // default values will be used.
/// let min_points = 3;
/// let hyperparams = DbscanHyperParams::new(min_points)
/// .tolerance(1e-2)
/// .build();
/// // Let's run the algorithm!
/// let clusters = Dbscan::predict(&hyperparams, &observations);
/// // Points are `None` if noise `Some(id)` if belonging to a cluster.
/// ```
///
pub struct Dbscan;

impl Dbscan {
pub fn predict(
hyperparameters: &DbscanHyperParams,
observations: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> Array1<Option<usize>> {
let mut cluster_memberships = Array1::from_elem(observations.dim().1, None);
let mut current_cluster_id = 0;
for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
if cluster_memberships[i].is_some() {
continue;
}
let (neighbor_count, mut search_queue) = find_neighbors(
&obs,
observations,
hyperparameters.tolerance(),
&cluster_memberships,
);
if neighbor_count < hyperparameters.minimum_points() {
continue;
}
// Now go over the neighbours adding them to the cluster
cluster_memberships[i] = Some(current_cluster_id);

while !search_queue.is_empty() {
let candidate = search_queue.remove(0);

let (neighbor_count, mut neighbors) = find_neighbors(
&candidate.1,
observations,
hyperparameters.tolerance(),
&cluster_memberships,
);
if neighbor_count >= hyperparameters.minimum_points() {
cluster_memberships[candidate.0] = Some(current_cluster_id);
search_queue.append(&mut neighbors);
}
}
current_cluster_id += 1;
}
cluster_memberships
}
}

fn find_neighbors<'a>(
candidate: &ArrayBase<impl Data<Elem = f64>, Ix1>,
observations: &'a ArrayBase<impl Data<Elem = f64>, Ix2>,
eps: f64,
clusters: &Array1<Option<usize>>,
) -> (usize, Vec<(usize, ArrayView<'a, f64, Ix1>)>) {
let mut res = vec![];
let mut count = 0;
for (i, (obs, cluster)) in observations
.axis_iter(Axis(1))
.zip(clusters.iter())
.enumerate()
{
if candidate.l2_dist(&obs).unwrap() < eps {
count += 1;
if cluster.is_none() {
res.push((i, obs));
}
}
}
(count, res)
}

#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr1, s, Array2};

#[test]
fn nested_clusters() {
// Create a circuit of points and then a cluster in the centre
// and ensure they are identified as two separate clusters
let params = DbscanHyperParams::new(2).tolerance(1.0).build();

let mut data: Array2<f64> = Array2::zeros((2, 50));
let rising = Array1::linspace(0.0, 8.0, 10);
data.slice_mut(s![0, 0..10]).assign(&rising);
data.slice_mut(s![0, 10..20]).assign(&rising);
data.slice_mut(s![1, 20..30]).assign(&rising);
data.slice_mut(s![1, 30..40]).assign(&rising);

data.slice_mut(s![1, 0..10]).fill(0.0);
data.slice_mut(s![1, 10..20]).fill(8.0);
data.slice_mut(s![0, 20..30]).fill(0.0);
data.slice_mut(s![0, 30..40]).fill(8.0);

data.slice_mut(s![.., 40..]).fill(5.0);

let labels = Dbscan::predict(&params, &data);

assert!(labels.slice(s![..40]).iter().all(|x| x == &Some(0)));
assert!(labels.slice(s![40..]).iter().all(|x| x == &Some(1)));
}

#[test]
fn non_cluster_points() {
let params = DbscanHyperParams::new(4).build();
let mut data: Array2<f64> = Array2::zeros((2, 5));
data.slice_mut(s![.., 0]).assign(&arr1(&[10.0, 10.0]));

let labels = Dbscan::predict(&params, &data);
let expected = arr1(&[None, Some(0), Some(0), Some(0), Some(0)]);
assert_eq!(labels, expected);
}

#[test]
fn dataset_too_small() {
let params = DbscanHyperParams::new(4).build();

let data: Array2<f64> = Array2::zeros((2, 3));

let labels = Dbscan::predict(&params, &data);
assert!(labels.iter().all(|x| x.is_none()));
}
}
74 changes: 74 additions & 0 deletions linfa-clustering/src/dbscan/hyperparameters.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// The set of hyperparameters that can be specified for the execution of
/// the [DBSCAN algorithm](struct.Dbscan.html).
pub struct DbscanHyperParams {
/// Distance between points for them to be considered neighbours.
tolerance: f64,
/// Minimum number of neighboring points a point needs to have to be a core
/// point and not a noise point.
min_points: usize,
}

/// Helper struct used to construct a set of hyperparameters for
pub struct DbscanHyperParamsBuilder {
tolerance: f64,
min_points: usize,
}

impl DbscanHyperParamsBuilder {
/// Distance between points for them to be considered neighbours.
pub fn tolerance(mut self, tolerance: f64) -> Self {
self.tolerance = tolerance;
self
}

/// Return an instance of `DbscanHyperParams` after having performed
/// validation checks on all hyperparameters.
///
/// **Panics** if any of the validation checks fail.
pub fn build(self) -> DbscanHyperParams {
DbscanHyperParams::build(self.tolerance, self.min_points)
}
}

impl DbscanHyperParams {
/// Minimum number of neighboring points a point needs to have to be a core
/// point and not a noise point.
///
/// Defaults are provided if the optional parameters are not specified:
/// * `tolerance = 1e-4`
pub fn new(min_points: usize) -> DbscanHyperParamsBuilder {
DbscanHyperParamsBuilder {
min_points,
tolerance: 1e-4,
}
}

/// Two points are considered neighbors if the euclidean distance between
/// them is below the tolerance
pub fn tolerance(&self) -> f64 {
self.tolerance
}

/// Minimum number of a points in a neighborhood around a point for it to
/// not be considered noise
pub fn minimum_points(&self) -> usize {
self.min_points
}

fn build(tolerance: f64, min_points: usize) -> Self {
if tolerance <= 0. {
panic!("`tolerance` must be greater than 0!");
}
// There is always at least one neighbor to a point (itself)
if min_points <= 1 {
panic!("`min_points` must be greater than 1!");
}
Self {
tolerance,
min_points,
}
}
}
5 changes: 5 additions & 0 deletions linfa-clustering/src/dbscan/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod algorithm;
mod hyperparameters;

pub use algorithm::*;
pub use hyperparameters::*;
2 changes: 2 additions & 0 deletions linfa-clustering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
//! Implementation choices, algorithmic details and a tutorial can be found [here](struct.KMeans.html).
//!
//! Check [here](https://github.com/LukeMathWalker/clustering-benchmarks) for extensive benchmarks against `scikit-learn`'s K-means implementation.
mod dbscan;
#[allow(clippy::new_ret_no_self)]
mod k_means;
mod utils;

pub use dbscan::*;
pub use k_means::*;
pub use utils::*;

0 comments on commit 6fc1f69

Please sign in to comment.