Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Initial commit of DBScan #12

Merged
merged 11 commits into from
Dec 27, 2019
65 changes: 65 additions & 0 deletions linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use crate::dbscan::hyperparameters::DBScanHyperParams;
use ndarray::{Array1, ArrayBase, Axis, Data, Ix1, Ix2};

/// Assigns a cluster label to every observation using DBSCAN.
///
/// Observations are read along `Axis(1)`: each column of `observations` is
/// treated as one point, and the returned array holds one entry per column.
///
/// An entry is `Some(id)` when the point was claimed by a cluster (ids start
/// at 0 and grow by one for every cluster discovered) and `None` when the
/// point was never claimed, i.e. it is noise.
///
/// A point is a core point when at least `hyperparameters.minimum_points()`
/// observations (the point itself included) lie strictly closer than
/// `hyperparameters.tolerance()`.
pub fn predict(
    hyperparameters: DBScanHyperParams,
    observations: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
) -> Array1<Option<usize>> {
    let tolerance = hyperparameters.tolerance();
    let min_points = hyperparameters.minimum_points();

    let mut result = Array1::from_elem(observations.dim().1, None);
    let mut latest_id = 0;
    for (i, obs) in observations.axis_iter(Axis(1)).enumerate() {
        // Already claimed by a previously discovered cluster.
        if result[i].is_some() {
            continue;
        }
        let seeds = find_neighbors(&obs, observations, tolerance);
        // Not a core point: leave it unlabelled for now. A later cluster may
        // still claim it as a border point.
        if seeds.len() < min_points {
            continue;
        }
        // Now go over the neighbours adding them to the cluster.
        // The visiting order does not influence which points are reached, so
        // the pending list is used as a stack (O(1) pop) rather than popping
        // from the front of a `Vec`.
        let mut search_queue = seeds
            .into_iter()
            .filter(|&x| result[x].is_none())
            .collect::<Vec<_>>();
        while let Some(cand) = search_queue.pop() {
            result[cand] = Some(latest_id);

            // Expand from the candidate itself, not from the seed point that
            // started this cluster.
            let cand_obs = observations.index_axis(Axis(1), cand);
            let cand_neighbors = find_neighbors(&cand_obs, observations, tolerance);
            // Border points join the cluster but do not extend it.
            if cand_neighbors.len() < min_points {
                continue;
            }

            let mut unvisited = cand_neighbors
                .into_iter()
                .filter(|x| result[*x].is_none() && !search_queue.contains(x))
                .collect::<Vec<_>>();
            search_queue.append(&mut unvisited);
        }
        latest_id += 1;
    }
    result
}
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved

/// Returns the indices of all observations lying strictly closer than `eps`
/// to `candidate`.
///
/// Observations are enumerated along `Axis(1)`, so every returned index
/// identifies a column of `observations`. Indices are produced in ascending
/// order. Closeness is measured with [`distance`].
fn find_neighbors(
    candidate: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
    observations: &ArrayBase<impl Data<Elem = f64> + Sync, Ix2>,
    eps: f64,
) -> Vec<usize> {
    observations
        .axis_iter(Axis(1))
        .enumerate()
        .filter(|(_, point)| distance(candidate, point) < eps)
        .map(|(idx, _)| idx)
        .collect()
}

/// Computes the Euclidean distance between two one-dimensional arrays.
///
/// Elements are paired positionally; if the arrays differ in length the
/// surplus elements of the longer one are ignored, because the pairing stops
/// at the shorter input.
fn distance(
    lhs: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
    rhs: &ArrayBase<impl Data<Elem = f64> + Sync, Ix1>,
) -> f64 {
    let squared_sum = lhs
        .iter()
        .zip(rhs.iter())
        .fold(0.0_f64, |total, (a, b)| total + (a - b).powi(2));
    squared_sum.sqrt()
}
59 changes: 59 additions & 0 deletions linfa-clustering/src/dbscan/hyperparameters.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use serde::{Deserialize, Serialize};

/// The set of hyperparameters that can be specified for the execution of
/// the [DBSCAN algorithm](struct.DBScan.html).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DBScanHyperParams {
    /// Distance between points for them to be considered neighbours.
    tolerance: f64,
    /// Minimum number of neighbouring points a point needs to have to be a
    /// core point and not a noise point.
    min_points: usize,
}

/// Helper struct used to construct a set of hyperparameters for the DBSCAN
/// algorithm.
///
/// It is obtained from `DBScanHyperParams::new` and turned into a
/// `DBScanHyperParams` by calling `build`.
pub struct DBScanHyperParamsBuilder {
    // Distance below which two points count as neighbours.
    tolerance: f64,
    // Minimum neighbour count for a point to be a core point.
    min_points: usize,
}

impl DBScanHyperParamsBuilder {
    /// Overrides the neighbourhood distance and hands the builder back so
    /// calls can be chained.
    pub fn tolerance(self, tolerance: f64) -> Self {
        Self { tolerance, ..self }
    }

    /// Consumes the builder and produces the final hyperparameters.
    ///
    /// Validation is delegated to `DBScanHyperParams::build`, which panics
    /// on invalid values.
    pub fn build(self) -> DBScanHyperParams {
        let Self {
            tolerance,
            min_points,
        } = self;
        DBScanHyperParams::build(tolerance, min_points)
    }
}

impl DBScanHyperParams {
    /// Starts building a set of hyperparameters.
    ///
    /// `min_points` is the minimum number of neighbouring points required
    /// for a point to be a core point. The tolerance starts at `1e-4` and
    /// can be changed through the returned builder.
    // `new` intentionally returns the builder rather than `Self`; the crate
    // root silences the same lint for the `k_means` module.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(min_points: usize) -> DBScanHyperParamsBuilder {
        DBScanHyperParamsBuilder {
            min_points,
            tolerance: 1e-4,
        }
    }

    /// Distance below which two points are considered neighbours.
    pub fn tolerance(&self) -> f64 {
        self.tolerance
    }

    /// Minimum number of neighbouring points a core point must have.
    pub fn minimum_points(&self) -> usize {
        self.min_points
    }

    /// Validates the supplied values and assembles the hyperparameters.
    ///
    /// # Panics
    ///
    /// Panics if `tolerance` is not strictly greater than 0 (this includes
    /// NaN) or if `min_points` is 0.
    fn build(tolerance: f64, min_points: usize) -> Self {
        // `NaN <= 0.` is false, so NaN has to be rejected explicitly;
        // otherwise every later `distance < tolerance` test would silently
        // be false.
        if tolerance.is_nan() || tolerance <= 0. {
            panic!("`tolerance` must be greater than 0!");
        }
        if min_points == 0 {
            panic!("`min_points` cannot be 0!");
        }
        Self {
            tolerance,
            min_points,
        }
    }
}
5 changes: 5 additions & 0 deletions linfa-clustering/src/dbscan/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod algorithm;
mod hyperparameters;

pub use algorithm::*;
pub use hyperparameters::*;
2 changes: 2 additions & 0 deletions linfa-clustering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
//! Implementation choices, algorithmic details and a tutorial can be found [here](struct.KMeans.html).
//!
//! Check [here](https://github.com/LukeMathWalker/clustering-benchmarks) for extensive benchmarks against `scikit-learn`'s K-means implementation.
mod dbscan;
#[allow(clippy::new_ret_no_self)]
mod k_means;
mod utils;

pub use dbscan::*;
pub use k_means::*;
pub use utils::*;