diff --git a/Cargo.toml b/Cargo.toml
index 88463e47..c05680fd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -2,7 +2,7 @@
name = "smartcore"
description = "Machine Learning in Rust."
homepage = "https://smartcorelib.org"
-version = "0.4.4"
+version = "0.4.5"
authors = ["smartcore Developers"]
edition = "2021"
license = "Apache-2.0"
@@ -28,6 +28,7 @@ num = "0.4"
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
+# Pinned to a major version: a bare "*" requirement is unbounded and is rejected by crates.io on publish.
+ordered-float = "5"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
typetag = { version = "0.2", optional = true }
diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs
index 0a38f993..889be689 100644
--- a/src/algorithm/neighbour/cosinepair.rs
+++ b/src/algorithm/neighbour/cosinepair.rs
@@ -23,7 +23,10 @@
/// ```
///
///
-use std::collections::HashMap;
+use ordered_float::{FloatCore, OrderedFloat};
+
+use std::cmp::Reverse;
+use std::collections::{BinaryHeap, HashMap};
use num::Bounded;
@@ -34,6 +37,25 @@ use crate::metrics::distance::{Distance, PairwiseDistance};
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
+/// Parameters for CosinePair construction.
+///
+/// `Default` yields `top_k: None` and `approximate: false`.
+#[derive(Debug, Clone, Default)]
+pub struct CosinePairParameters {
+    /// Maximum number of neighbors to consider per point
+    /// (`None`, the default, means all other points)
+    pub top_k: Option<usize>,
+    /// Whether to use approximate nearest neighbor search (default: `false`).
+    // NOTE(review): no code visible in this change reads this flag - confirm it is consumed elsewhere.
+    pub approximate: bool,
+}
+
///
/// Inspired by Python implementation:
///
@@ -49,12 +71,29 @@ pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2> {
pub distances: HashMap>,
/// conga line used to keep track of the closest pair
pub neighbours: Vec,
+ /// parameters used during construction
+ pub parameters: CosinePairParameters,
}
-impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> {
- /// Constructor
- /// Instantiate and initialize the algorithm
+impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T, M> {
+    /// Constructor with default parameters (backward compatibility).
+    ///
+    /// Equivalent to `with_parameters(m, CosinePairParameters::default())`, i.e. no
+    /// `top_k` limit and `approximate == false`.
+    ///
+    /// # Errors
+    /// Propagates the error of `with_parameters`, which rejects a matrix with fewer
+    /// than two rows.
     pub fn new(m: &'a M) -> Result<Self, Failed> {
+        Self::with_parameters(m, CosinePairParameters::default())
+    }
+
+    /// Constructor with top-k limiting for faster performance.
+    ///
+    /// Builds the structure with `top_k: Some(top_k)` and `approximate: false`.
+    ///
+    /// # Errors
+    /// Propagates the error of `with_parameters`, which rejects a matrix with fewer
+    /// than two rows.
+    pub fn with_top_k(m: &'a M, top_k: usize) -> Result<Self, Failed> {
+        let parameters = CosinePairParameters {
+            top_k: Some(top_k),
+            approximate: false,
+        };
+        Self::with_parameters(m, parameters)
+    }
+
+ /// Constructor with full parameter control
+ pub fn with_parameters(m: &'a M, parameters: CosinePairParameters) -> Result {
if m.shape().0 < 2 {
return Err(Failed::because(
FailedError::FindFailed,
@@ -64,96 +103,156 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> {
let mut init = Self {
samples: m,
- // to be computed in init(..)
distances: HashMap::with_capacity(m.shape().0),
- neighbours: Vec::with_capacity(m.shape().0 + 1),
+ neighbours: Vec::with_capacity(m.shape().0),
+ parameters,
};
init.init();
Ok(init)
}
- /// Initialise `CosinePair` by passing a `Array2`.
- /// Build a CosinePairs data-structure from a set of (new) points.
+    /// Wraps `value` in an [`OrderedFloat`] so that it can be stored in a `BinaryHeap`
+    /// or compared with `cmp`, both of which need a total order.
+    fn ordered_float(value: T) -> OrderedFloat<T> {
+        OrderedFloat(value)
+    }
+
+    /// Unwraps an [`OrderedFloat`] back into the plain value it carries.
+    fn extract_float(ordered: OrderedFloat<T>) -> T {
+        ordered.into_inner()
+    }
+
+    /// Populates `distances` and `neighbours` from `samples`.
+    ///
+    /// Every row starts with no neighbour and the maximum representable distance.
+    /// Then, for each row `i`, the cosine distance to every other row is computed and
+    /// the closest row is stored in `distances[i]`. While scanning, at most
+    /// `top_k` (clamped to `len - 1`) candidates are kept in a bounded max-heap; on
+    /// overflow the *farthest* candidate is evicted, so the nearest ones survive.
+    ///
+    /// If no candidate survives (only possible with `top_k == Some(0)`), the row keeps
+    /// `neighbour: None` and the maximum distance.
+    ///
+    /// # Panics
+    /// Panics if `T::from` cannot convert a computed distance.
     fn init(&mut self) {
-        // basic measures
         let len = self.samples.shape().0;
-        let max_index = self.samples.shape().0 - 1;
+        let n_features = self.samples.shape().1;
+        // `with_parameters` rejects fewer than two rows before calling `init`,
+        // so `len - 1` does not underflow on that path.
+        let max_neighbors: usize = self.parameters.top_k.unwrap_or(len - 1).min(len - 1);
-        // Store all closest neighbors
-        let _distances = Box::new(HashMap::with_capacity(len));
-        let _neighbours = Box::new(Vec::with_capacity(len));
+        let mut distances = HashMap::with_capacity(len);
+        let mut neighbours = Vec::with_capacity(len);
-        let mut distances = *_distances;
-        let mut neighbours = *_neighbours;
-
-        // fill neighbours with -1 values
         neighbours.extend(0..len);
-        // init closest neighbour pairwise data
-        for index_row_i in 0..(max_index) {
+        // Seed every row with "no neighbour yet" at the maximum representable distance
+        for i in 0..len {
             distances.insert(
-                index_row_i,
+                i,
                 PairwiseDistance {
-                    node: index_row_i,
-                    neighbour: Option::None,
+                    node: i,
+                    neighbour: None,
                     distance: Some(<T as Bounded>::max_value()),
                 },
             );
         }
-        // loop through indeces and neighbours
-        for index_row_i in 0..(len) {
-            // start looking for the neighbour in the second element
-            let mut index_closest = index_row_i + 1; // closest neighbour index
-            let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
-            for index_row_j in (index_row_i + 1)..len {
-                distances.insert(
-                    index_row_j,
-                    PairwiseDistance {
-                        node: index_row_j,
-                        neighbour: Some(index_row_i),
-                        distance: nbd,
-                    },
-                );
-
-                let d = Cosine::new().distance(
-                    &Vec::from_iterator(
-                        self.samples.get_row(index_row_i).iterator(0).copied(),
-                        self.samples.shape().1,
-                    ),
-                    &Vec::from_iterator(
-                        self.samples.get_row(index_row_j).iterator(0).copied(),
-                        self.samples.shape().1,
-                    ),
-                );
-                if d < nbd.unwrap().to_f64().unwrap() {
-                    // set this j-value to be the closest neighbour
-                    index_closest = index_row_j;
-                    nbd = Some(T::from(d).unwrap());
+        // Find the nearest neighbour of every row
+        for i in 0..len {
+            // Row `i` does not change during the inner scan, so materialise it once
+            let row_i: Vec<T> = Vec::from_iterator(
+                self.samples.get_row(i).iterator(0).copied(),
+                n_features,
+            );
+            // Max-heap keyed on (distance, index): `pop` removes the farthest candidate
+            let mut candidate_distances = BinaryHeap::new();
+
+            for j in 0..len {
+                if i != j {
+                    let row_j: Vec<T> = Vec::from_iterator(
+                        self.samples.get_row(j).iterator(0).copied(),
+                        n_features,
+                    );
+                    let distance = T::from(Cosine::new().distance(&row_i, &row_j)).unwrap();
+
+                    // OrderedFloat supplies the total order the heap requires
+                    candidate_distances.push((Self::ordered_float(distance), j));
+
+                    // Keep only the `max_neighbors` nearest candidates
+                    if candidate_distances.len() > max_neighbors {
+                        candidate_distances.pop();
+                    }
                 }
             }
-            // Add that edge
-            distances.entry(index_row_i).and_modify(|e| {
-                e.distance = nbd;
-                e.neighbour = Some(index_closest);
-            });
-        }
-        // No more neighbors, terminate conga line.
-        // Last person on the line has no neigbors
-        distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
-        distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
-
-        // compute sparse matrix (connectivity matrix)
-        let mut sparse_matrix = M::zeros(len, len);
-        for (_, p) in distances.iter() {
-            sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
+            // `min` over (distance, index): smallest distance, lowest index on ties
+            if let Some((closest_distance, closest_neighbor)) = candidate_distances.iter().min() {
+                distances.entry(i).and_modify(|e| {
+                    e.distance = Some(Self::extract_float(*closest_distance));
+                    e.neighbour = Some(*closest_neighbor);
+                });
+            }
         }
         self.distances = distances;
         self.neighbours = neighbours;
     }
+    /// Returns up to `k` nearest neighbours of row `query_row_index`, chosen from a
+    /// strided sample of the dataset rows.
+    ///
+    /// When `parameters.top_k` is `Some(t)`, candidates are taken every
+    /// `max(1, n_rows / t)` rows starting at row 0, the query row is skipped, at most
+    /// `t` candidates are examined, and `k` is clamped to `t`. When it is `None`, every
+    /// other row is a candidate. Distances are computed on demand; the `distances`
+    /// table built by `init` is not consulted.
+    ///
+    /// The result holds `(distance, row_index)` pairs in ascending distance order
+    /// (ties by ascending index). It can be shorter than `k` when fewer candidates
+    /// were sampled, and is empty when `k == 0` or `top_k == Some(0)`.
+    ///
+    /// # Errors
+    /// Returns `FailedError::FindFailed` when `query_row_index` is not a valid row.
+    ///
+    /// # Panics
+    /// Panics if `T::from` cannot convert a computed distance.
+    pub fn query_row_top_k(
+        &self,
+        query_row_index: usize,
+        k: usize,
+    ) -> Result<Vec<(T, usize)>, Failed> {
+        let (n_samples, n_features) = self.samples.shape();
+        if query_row_index >= n_samples {
+            return Err(Failed::because(
+                FailedError::FindFailed,
+                "Query row index out of bounds",
+            ));
+        }
+
+        let max_candidates = self.parameters.top_k.unwrap_or(n_samples);
+        let actual_k: usize = k.min(max_candidates);
+        // Covers `k == 0` and `top_k == Some(0)`; the latter would otherwise divide
+        // by zero when the stride is computed below.
+        if actual_k == 0 {
+            return Ok(Vec::new());
+        }
+
+        // Stride between sampled rows; `top_k >= 1` is guaranteed by the check above
+        let step = match self.parameters.top_k {
+            Some(top_k) => (n_samples / top_k).max(1),
+            None => 1,
+        };
+
+        // The query row is the same for every candidate, so materialise it once
+        let query: Vec<T> = Vec::from_iterator(
+            self.samples.get_row(query_row_index).iterator(0).copied(),
+            n_features,
+        );
+
+        // Min-heap (via `Reverse`) over all sampled candidates: `pop` yields the nearest
+        // remaining one, so nothing is evicted while collecting.
+        let mut heap: BinaryHeap<Reverse<(OrderedFloat<T>, usize)>> = (0..n_samples)
+            .step_by(step)
+            .filter(|&i| i != query_row_index)
+            .take(max_candidates)
+            .map(|candidate_idx| {
+                let candidate: Vec<T> = Vec::from_iterator(
+                    self.samples.get_row(candidate_idx).iterator(0).copied(),
+                    n_features,
+                );
+                let distance = T::from(Cosine::new().distance(&query, &candidate)).unwrap();
+                Reverse((Self::ordered_float(distance), candidate_idx))
+            })
+            .collect();
+
+        // Pop the `actual_k` nearest; they come out already sorted ascending.
+        // Capacity is bounded by the heap size so a huge `k`/`top_k` cannot over-allocate.
+        let mut neighbors = Vec::with_capacity(actual_k.min(heap.len()));
+        while let Some(Reverse((dist, idx))) = heap.pop() {
+            neighbors.push((Self::extract_float(dist), idx));
+            if neighbors.len() == actual_k {
+                break;
+            }
+        }
+
+        Ok(neighbors)
+    }
+
/// Query k nearest neighbors for a row that's already in the dataset
pub fn query_row(&self, query_row_index: usize, k: usize) -> Result, Failed> {
if query_row_index >= self.samples.shape().0 {
@@ -318,7 +417,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> {
mod tests {
use super::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
- use approx::assert_relative_eq;
+ use approx::{assert_relative_eq, relative_eq};
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -499,10 +598,6 @@ mod tests {
}
}
- #[cfg_attr(
- all(target_arch = "wasm32", not(target_os = "wasi")),
- wasm_bindgen_test::wasm_bindgen_test
- )]
#[test]
fn cosine_pair_query_row_bounds_error() {
let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
@@ -520,10 +615,6 @@ mod tests {
}
}
- #[cfg_attr(
- all(target_arch = "wasm32", not(target_os = "wasi")),
- wasm_bindgen_test::wasm_bindgen_test
- )]
#[test]
fn cosine_pair_query_row_k_zero() {
let x =
@@ -635,6 +726,206 @@ mod tests {
assert!(distance >= 0.0 && distance <= 2.0);
}
+    #[test]
+    fn query_row_top_k_top_k_limiting() {
+        // Checks the size, ordering and value range of `query_row_top_k` results when the
+        // structure was built with a `top_k` limit, and compares the closest hit with `query_row`.
+        let x = DenseMatrix::<f64>::from_2d_array(&[
+            &[1.0, 0.0, 0.0], // Point 0 - query point
+            &[0.0, 1.0, 0.0], // Point 1 - orthogonal to point 0
+            &[0.0, 0.0, 1.0], // Point 2 - orthogonal to point 0
+            &[1.0, 1.0, 0.0], // Point 3 - closer to point 0 than points 1,2
+            &[0.5, 0.0, 0.0], // Point 4 - parallel to point 0
+            &[2.0, 0.0, 0.0], // Point 5 - parallel to point 0
+            &[0.0, 1.0, 1.0], // Point 6 - orthogonal to point 0
+            &[3.0, 3.0, 3.0], // Point 7 - moderately close to point 0
+        ])
+        .unwrap();
+
+        // top_k = 4 on 8 rows gives a sampling stride of 2, so for query row 0 the
+        // candidates examined are rows 2, 4 and 6.
+        let cosine_pair = CosinePair::with_top_k(&x, 4).unwrap();
+
+        // Query for 3 nearest neighbors to point 0
+        let neighbors = cosine_pair.query_row_top_k(0, 3).unwrap();
+
+        // Should return exactly 3 neighbors
+        assert_eq!(neighbors.len(), 3);
+
+        // Verify that distances are in ascending order
+        for i in 1..neighbors.len() {
+            assert!(
+                neighbors[i - 1].0 <= neighbors[i].0,
+                "Distances should be in ascending order: {} <= {}",
+                neighbors[i - 1].0,
+                neighbors[i].0
+            );
+        }
+
+        // All distances should be valid cosine distances (0 to 2) and all indices
+        // valid rows other than the query row
+        for (distance, index) in &neighbors {
+            assert!(
+                *distance >= 0.0 && *distance <= 2.0,
+                "Cosine distance {} should be between 0 and 2",
+                distance
+            );
+            assert!(
+                *index < x.shape().0,
+                "Neighbor index {} should be less than dataset size {}",
+                index,
+                x.shape().0
+            );
+            assert!(
+                *index != 0,
+                "Neighbor index should not include query point itself"
+            );
+        }
+
+        // The closest sampled neighbor is a vector parallel to point 0,
+        // so its cosine distance should be ≈ 0
+        let closest_distance = neighbors[0].0;
+        assert!(
+            closest_distance < 0.01,
+            "Closest parallel vector should have distance close to 0, got {}",
+            closest_distance
+        );
+
+        // Compare against the unrestricted structure queried through `query_row`
+        let cosine_pair_full = CosinePair::new(&x).unwrap();
+        let neighbors_full = cosine_pair_full.query_row(0, 3).unwrap();
+
+        // Both queries asked for 3 neighbors, so the result sizes must match even
+        // though the sampled candidates may differ
+        assert_eq!(neighbors.len(), neighbors_full.len());
+
+        // The closest neighbor's distance should agree in both cases
+        let closest_idx_fast = neighbors[0].1;
+        let closest_idx_full = neighbors_full[0].1;
+        let closest_dist_fast = neighbors[0].0;
+        let closest_dist_full = neighbors_full[0].0;
+
+        // Same closest row: distances must match tightly; different rows: allow a
+        // looser tolerance since both should still be parallel vectors
+        if closest_idx_fast == closest_idx_full {
+            assert!(relative_eq!(
+                closest_dist_fast,
+                closest_dist_full,
+                epsilon = 1e-10
+            ));
+        } else {
+            // Different neighbors, but distances should be very close (parallel vectors)
+            assert!(relative_eq!(
+                closest_dist_fast,
+                closest_dist_full,
+                epsilon = 1e-6
+            ));
+        }
+    }
+
+    #[test]
+    fn query_row_top_k_performance_vs_accuracy() {
+        // Exercises `query_row_top_k` on an f32 dataset with a tight `top_k` limit and
+        // checks its error path, the k = 0 and oversized-k cases, and result ordering.
+        let large_dataset = DenseMatrix::<f32>::from_2d_array(&[
+            &[1.0f32, 2.0, 3.0, 4.0],     // Point 0 - query point
+            &[1.1f32, 2.1, 3.1, 4.1],     // Point 1 - very close to 0
+            &[1.05f32, 2.05, 3.05, 4.05], // Point 2 - very close to 0
+            &[2.0f32, 4.0, 6.0, 8.0],     // Point 3 - parallel to 0 (2x scaling)
+            &[0.5f32, 1.0, 1.5, 2.0],     // Point 4 - parallel to 0 (0.5x scaling)
+            &[-1.0f32, -2.0, -3.0, -4.0], // Point 5 - opposite to 0
+            &[4.0f32, 3.0, 2.0, 1.0],     // Point 6 - different direction
+            &[0.0f32, 0.0, 0.0, 0.1],     // Point 7 - single-axis vector
+            &[10.0f32, 20.0, 30.0, 40.0], // Point 8 - parallel to 0 (10x scaling)
+            &[1.0f32, 0.0, 0.0, 0.0],     // Point 9 - partially similar
+            &[0.0f32, 2.0, 0.0, 0.0],     // Point 10 - partially similar
+            &[0.0f32, 0.0, 3.0, 0.0],     // Point 11 - partially similar
+        ])
+        .unwrap();
+
+        // Aggressive top_k limiting: at most 5 of the 11 other points are examined.
+        // With 12 rows the sampling stride is 2, so for query row 0 the candidates
+        // are rows 2, 4, 6, 8 and 10.
+        let cosine_pair_limited = CosinePair::with_top_k(&large_dataset, 5).unwrap();
+
+        // Query for 4 nearest neighbors
+        let neighbors_limited = cosine_pair_limited.query_row_top_k(0, 4).unwrap();
+
+        // Should return exactly 4 neighbors
+        assert_eq!(neighbors_limited.len(), 4);
+
+        // Test error handling - out of bounds query
+        let result_oob = cosine_pair_limited.query_row_top_k(15, 2);
+        assert!(result_oob.is_err());
+        if let Err(e) = result_oob {
+            assert_eq!(
+                e,
+                Failed::because(FailedError::FindFailed, "Query row index out of bounds")
+            );
+        }
+
+        // Test k=0 case
+        let neighbors_zero = cosine_pair_limited.query_row_top_k(0, 0).unwrap();
+        assert_eq!(neighbors_zero.len(), 0);
+
+        // Test k > available candidates
+        let neighbors_large_k = cosine_pair_limited.query_row_top_k(0, 20).unwrap();
+        assert!(neighbors_large_k.len() <= 11); // At most 11 other points
+
+        // Verify ordering is correct
+        for i in 1..neighbors_limited.len() {
+            assert!(
+                neighbors_limited[i - 1].0 <= neighbors_limited[i].0,
+                "Distance ordering violation at position {}: {} > {}",
+                i,
+                neighbors_limited[i - 1].0,
+                neighbors_limited[i].0
+            );
+        }
+
+        // Among the sampled rows, 4 and 8 are parallel to the query and row 2 is
+        // nearly parallel, so the closest returned distance should be near zero
+        let closest_distance = neighbors_limited[0].0;
+        assert!(
+            closest_distance < 0.1,
+            "Closest neighbor should be nearly parallel, distance: {}",
+            closest_distance
+        );
+
+        // Compare with the unrestricted structure for accuracy assessment
+        let cosine_pair_full = CosinePair::new(&large_dataset).unwrap();
+        let neighbors_full = cosine_pair_full.query_row(0, 4).unwrap();
+
+        // The sampled query may not return the same rows as the full one,
+        // but the closest neighbor's distance should be very similar
+        let dist_diff = (neighbors_limited[0].0 - neighbors_full[0].0).abs();
+        assert!(
+            dist_diff < 0.01,
+            "Fast and full algorithms should give similar closest distances. Diff: {}",
+            dist_diff
+        );
+
+        // Verify that all returned indices are valid and unique
+        let mut indices: Vec<usize> = neighbors_limited.iter().map(|(_, idx)| *idx).collect();
+        indices.sort();
+        indices.dedup();
+        assert_eq!(
+            indices.len(),
+            neighbors_limited.len(),
+            "All neighbor indices should be unique"
+        );
+
+        for &idx in &indices {
+            assert!(
+                idx < large_dataset.shape().0,
+                "Neighbor index {} should be valid",
+                idx
+            );
+            assert!(idx != 0, "Neighbor should not include query point itself");
+        }
+
+        // f32 results must be well-formed numbers
+        for (distance, _) in &neighbors_limited {
+            assert!(!distance.is_nan(), "Distance should not be NaN");
+            assert!(distance.is_finite(), "Distance should be finite");
+            assert!(*distance >= 0.0, "Distance should be non-negative");
+        }
+    }
+
#[test]
fn cosine_pair_float_precision() {
// Test with f32 precision