diff --git a/Cargo.toml b/Cargo.toml index 88463e47..c05680fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "smartcore" description = "Machine Learning in Rust." homepage = "https://smartcorelib.org" -version = "0.4.4" +version = "0.4.5" authors = ["smartcore Developers"] edition = "2021" license = "Apache-2.0" @@ -28,6 +28,7 @@ num = "0.4" rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } rand_distr = { version = "0.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } +ordered-float = "*" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] typetag = { version = "0.2", optional = true } diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs index 0a38f993..889be689 100644 --- a/src/algorithm/neighbour/cosinepair.rs +++ b/src/algorithm/neighbour/cosinepair.rs @@ -23,7 +23,10 @@ /// ``` /// /// -use std::collections::HashMap; +use ordered_float::{FloatCore, OrderedFloat}; + +use std::cmp::Reverse; +use std::collections::{BinaryHeap, HashMap}; use num::Bounded; @@ -34,6 +37,25 @@ use crate::metrics::distance::{Distance, PairwiseDistance}; use crate::numbers::floatnum::FloatNumber; use crate::numbers::realnum::RealNumber; +/// Parameters for CosinePair construction +#[derive(Debug, Clone)] +pub struct CosinePairParameters { + /// Maximum number of neighbors to consider per point (default: all points) + pub top_k: Option, + /// Whether to use approximate nearest neighbor search + pub approximate: bool, +} + +#[allow(clippy::derivable_impls)] +impl Default for CosinePairParameters { + fn default() -> Self { + Self { + top_k: None, + approximate: false, + } + } +} + /// /// Inspired by Python implementation: /// @@ -49,12 +71,29 @@ pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2> { pub distances: HashMap>, /// conga line used to keep track of the closest pair pub neighbours: Vec, + /// parameters used during construction + pub parameters: CosinePairParameters, } -impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { - /// Constructor - /// Instantiate and initialize the algorithm +impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T, M> { + /// Constructor with default parameters (backward compatibility) pub fn new(m: &'a M) -> Result { + Self::with_parameters(m, CosinePairParameters::default()) + } + + /// Constructor with top-k limiting for faster performance + pub fn with_top_k(m: &'a M, top_k: usize) -> Result { + Self::with_parameters( + m, + CosinePairParameters { + top_k: Some(top_k), + approximate: false, + }, + ) + } + + /// Constructor with full parameter control + pub fn with_parameters(m: &'a M, parameters: CosinePairParameters) -> Result { if m.shape().0 < 2 { return Err(Failed::because( FailedError::FindFailed, @@ -64,96 +103,156 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { let mut init = Self { samples: m, - // to be computed in init(..) distances: HashMap::with_capacity(m.shape().0), - neighbours: Vec::with_capacity(m.shape().0 + 1), + neighbours: Vec::with_capacity(m.shape().0), + parameters, }; init.init(); Ok(init) } - /// Initialise `CosinePair` by passing a `Array2`. - /// Build a CosinePairs data-structure from a set of (new) points. + /// Helper function to create ordered float wrapper + fn ordered_float(value: T) -> OrderedFloat { + OrderedFloat(value) + } + + /// Helper function to extract value from ordered float wrapper + fn extract_float(ordered: OrderedFloat) -> T { + ordered.into_inner() + } + + /// Optimized initialization with top-k neighbor limiting fn init(&mut self) { - // basic measures let len = self.samples.shape().0; - let max_index = self.samples.shape().0 - 1; + let max_neighbors: usize = self.parameters.top_k.unwrap_or(len - 1).min(len - 1); - // Store all closest neighbors - let _distances = Box::new(HashMap::with_capacity(len)); - let _neighbours = Box::new(Vec::with_capacity(len)); + let mut distances = HashMap::with_capacity(len); + let mut neighbours = Vec::with_capacity(len); - let mut distances = *_distances; - let mut neighbours = *_neighbours; - - // fill neighbours with -1 values neighbours.extend(0..len); - // init closest neighbour pairwise data - for index_row_i in 0..(max_index) { + // Initialize with max distances + for i in 0..len { distances.insert( - index_row_i, + i, PairwiseDistance { - node: index_row_i, - neighbour: Option::None, + node: i, + neighbour: None, distance: Some(::max_value()), }, ); } - // loop through indeces and neighbours - for index_row_i in 0..(len) { - // start looking for the neighbour in the second element - let mut index_closest = index_row_i + 1; // closest neighbour index - let mut nbd: Option = distances[&index_row_i].distance; // init neighbour distance - for index_row_j in (index_row_i + 1)..len { - distances.insert( - index_row_j, - PairwiseDistance { - node: index_row_j, - neighbour: Some(index_row_i), - distance: nbd, - }, - ); - - let d = Cosine::new().distance( - &Vec::from_iterator( - self.samples.get_row(index_row_i).iterator(0).copied(), - self.samples.shape().1, - ), - &Vec::from_iterator( - self.samples.get_row(index_row_j).iterator(0).copied(), - self.samples.shape().1, - ), - ); - if d < nbd.unwrap().to_f64().unwrap() { - // set this j-value to be the closest neighbour - index_closest = index_row_j; - nbd = Some(T::from(d).unwrap()); + // Compute distances for each point using top-k optimization + for i in 0..len { + let mut candidate_distances = BinaryHeap::new(); + + for j in 0..len { + if i != j { + let distance = T::from(Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(i).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(j).iterator(0).copied(), + self.samples.shape().1, + ), + )) + .unwrap(); + + // Use OrderedFloat for stable ordering + candidate_distances.push(Reverse((Self::ordered_float(distance), j))); + + if candidate_distances.len() > max_neighbors { + candidate_distances.pop(); + } } } - // Add that edge - distances.entry(index_row_i).and_modify(|e| { - e.distance = nbd; - e.neighbour = Some(index_closest); - }); - } - // No more neighbors, terminate conga line. - // Last person on the line has no neigbors - distances.get_mut(&max_index).unwrap().neighbour = Some(max_index); - distances.get_mut(&(len - 1)).unwrap().distance = Some(::max_value()); - - // compute sparse matrix (connectivity matrix) - let mut sparse_matrix = M::zeros(len, len); - for (_, p) in distances.iter() { - sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap()); + // Find the closest neighbor from candidates + if let Some(Reverse((closest_distance, closest_neighbor))) = + candidate_distances.iter().min_by_key(|Reverse((d, _))| *d) + { + distances.entry(i).and_modify(|e| { + e.distance = Some(Self::extract_float(*closest_distance)); + e.neighbour = Some(*closest_neighbor); + }); + } } self.distances = distances; self.neighbours = neighbours; } + /// Fast query using top-k pre-computed neighbors with ordered-float + pub fn query_row_top_k( + &self, + query_row_index: usize, + k: usize, + ) -> Result, Failed> { + if query_row_index >= self.samples.shape().0 { + return Err(Failed::because( + FailedError::FindFailed, + "Query row index out of bounds", + )); + } + + if k == 0 { + return Ok(Vec::new()); + } + + let max_candidates = self.parameters.top_k.unwrap_or(self.samples.shape().0); + let actual_k: usize = k.min(max_candidates); + + // Use binary heap with ordered-float for reliable ordering + let mut heap = BinaryHeap::with_capacity(actual_k + 1); + + let candidates = if let Some(top_k) = self.parameters.top_k { + let step = (self.samples.shape().0 / top_k).max(1); + (0..self.samples.shape().0) + .step_by(step) + .filter(|&i| i != query_row_index) + .take(top_k) + .collect::>() + } else { + (0..self.samples.shape().0) + .filter(|&i| i != query_row_index) + .collect::>() + }; + + for &candidate_idx in &candidates { + let distance = T::from(Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(query_row_index).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(candidate_idx).iterator(0).copied(), + self.samples.shape().1, + ), + )) + .unwrap(); + + heap.push(Reverse((Self::ordered_float(distance), candidate_idx))); + + if heap.len() > actual_k { + heap.pop(); + } + } + + // Convert heap to sorted vector + let mut neighbors: Vec<_> = heap + .into_vec() + .into_iter() + .map(|Reverse((dist, idx))| (Self::extract_float(dist), idx)) + .collect(); + + neighbors.sort_by(|a, b| Self::ordered_float(a.0).cmp(&Self::ordered_float(b.0))); + + Ok(neighbors) + } + /// Query k nearest neighbors for a row that's already in the dataset pub fn query_row(&self, query_row_index: usize, k: usize) -> Result, Failed> { if query_row_index >= self.samples.shape().0 { @@ -318,7 +417,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { mod tests { use super::*; use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; - use approx::assert_relative_eq; + use approx::{assert_relative_eq, relative_eq}; #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), @@ -499,10 +598,6 @@ mod tests { } } - #[cfg_attr( - all(target_arch = "wasm32", not(target_os = "wasi")), - wasm_bindgen_test::wasm_bindgen_test - )] #[test] fn cosine_pair_query_row_bounds_error() { let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap(); @@ -520,10 +615,6 @@ mod tests { } } - #[cfg_attr( - all(target_arch = "wasm32", not(target_os = "wasi")), - wasm_bindgen_test::wasm_bindgen_test - )] #[test] fn cosine_pair_query_row_k_zero() { let x = @@ -635,6 +726,206 @@ mod tests { assert!(distance >= 0.0 && distance <= 2.0); } + #[test] + fn query_row_top_k_top_k_limiting() { + // Test that query_row_top_k respects top_k parameter and returns correct results + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0, 0.0], // Point 0 + &[0.0, 1.0, 0.0], // Point 1 - orthogonal to point 0 + &[0.0, 0.0, 1.0], // Point 2 - orthogonal to point 0 + &[1.0, 1.0, 0.0], // Point 3 - closer to point 0 than points 1,2 + &[0.5, 0.0, 0.0], // Point 4 - very close to point 0 (parallel) + &[2.0, 0.0, 0.0], // Point 5 - very close to point 0 (parallel) + &[0.0, 1.0, 1.0], // Point 6 - far from point 0 + &[3.0, 3.0, 3.0], // Point 7 - moderately close to point 0 + ]) + .unwrap(); + + // Create CosinePair with top_k=4 to limit candidates + let cosine_pair = CosinePair::with_top_k(&x, 4).unwrap(); + + // Query for 3 nearest neighbors to point 0 + let neighbors = cosine_pair.query_row_top_k(0, 3).unwrap(); + + // Should return exactly 3 neighbors + assert_eq!(neighbors.len(), 3); + + // Verify that distances are in ascending order + for i in 1..neighbors.len() { + assert!( + neighbors[i - 1].0 <= neighbors[i].0, + "Distances should be in ascending order: {} <= {}", + neighbors[i - 1].0, + neighbors[i].0 + ); + } + + // All distances should be valid cosine distances (0 to 2) + for (distance, index) in &neighbors { + assert!( + *distance >= 0.0 && *distance <= 2.0, + "Cosine distance {} should be between 0 and 2", + distance + ); + assert!( + *index < x.shape().0, + "Neighbor index {} should be less than dataset size {}", + index, + x.shape().0 + ); + assert!( + *index != 0, + "Neighbor index should not include query point itself" + ); + } + + // The closest neighbor should be either point 4 or 5 (parallel vectors) + // These should have cosine distance ≈ 0 + let closest_distance = neighbors[0].0; + assert!( + closest_distance < 0.01, + "Closest parallel vector should have distance close to 0, got {}", + closest_distance + ); + + // Verify that we get different results with different top_k values + let cosine_pair_full = CosinePair::new(&x).unwrap(); + let neighbors_full = cosine_pair_full.query_row(0, 3).unwrap(); + + // Results should be the same or very close since we're asking for top 3 + // but the algorithm might find different candidates due to top_k limiting + assert_eq!(neighbors.len(), neighbors_full.len()); + + // The closest neighbor should be the same in both cases + let closest_idx_fast = neighbors[0].1; + let closest_idx_full = neighbors_full[0].1; + let closest_dist_fast = neighbors[0].0; + let closest_dist_full = neighbors_full[0].0; + + // Either we get the same closest neighbor, or distances are very close + if closest_idx_fast == closest_idx_full { + assert!(relative_eq!( + closest_dist_fast, + closest_dist_full, + epsilon = 1e-10 + )); + } else { + // Different neighbors, but distances should be very close (parallel vectors) + assert!(relative_eq!( + closest_dist_fast, + closest_dist_full, + epsilon = 1e-6 + )); + } + } + + #[test] + fn query_row_top_k_performance_vs_accuracy() { + // Test that query_row_top_k provides reasonable performance/accuracy tradeoff + // and handles edge cases properly + let large_dataset = DenseMatrix::::from_2d_array(&[ + &[1.0f32, 2.0, 3.0, 4.0], // Point 0 - query point + &[1.1f32, 2.1, 3.1, 4.1], // Point 1 - very close to 0 + &[1.05f32, 2.05, 3.05, 4.05], // Point 2 - very close to 0 + &[2.0f32, 4.0, 6.0, 8.0], // Point 3 - parallel to 0 (2x scaling) + &[0.5f32, 1.0, 1.5, 2.0], // Point 4 - parallel to 0 (0.5x scaling) + &[-1.0f32, -2.0, -3.0, -4.0], // Point 5 - opposite to 0 + &[4.0f32, 3.0, 2.0, 1.0], // Point 6 - different direction + &[0.0f32, 0.0, 0.0, 0.1], // Point 7 - mostly orthogonal + &[10.0f32, 20.0, 30.0, 40.0], // Point 8 - parallel but far + &[1.0f32, 0.0, 0.0, 0.0], // Point 9 - partially similar + &[0.0f32, 2.0, 0.0, 0.0], // Point 10 - partially similar + &[0.0f32, 0.0, 3.0, 0.0], // Point 11 - partially similar + ]) + .unwrap(); + + // Test with aggressive top_k limiting (only consider 5 out of 11 other points) + let cosine_pair_limited = CosinePair::with_top_k(&large_dataset, 5).unwrap(); + + // Query for 4 nearest neighbors + let neighbors_limited = cosine_pair_limited.query_row_top_k(0, 4).unwrap(); + + // Should return exactly 4 neighbors + assert_eq!(neighbors_limited.len(), 4); + + // Test error handling - out of bounds query + let result_oob = cosine_pair_limited.query_row_top_k(15, 2); + assert!(result_oob.is_err()); + if let Err(e) = result_oob { + assert_eq!( + e, + Failed::because(FailedError::FindFailed, "Query row index out of bounds") + ); + } + + // Test k=0 case + let neighbors_zero = cosine_pair_limited.query_row_top_k(0, 0).unwrap(); + assert_eq!(neighbors_zero.len(), 0); + + // Test k > available candidates + let neighbors_large_k = cosine_pair_limited.query_row_top_k(0, 20).unwrap(); + assert!(neighbors_large_k.len() <= 11); // At most 11 other points + + // Verify ordering is correct + for i in 1..neighbors_limited.len() { + assert!( + neighbors_limited[i - 1].0 <= neighbors_limited[i].0, + "Distance ordering violation at position {}: {} > {}", + i, + neighbors_limited[i - 1].0, + neighbors_limited[i].0 + ); + } + + // The closest neighbors should be the parallel vectors (points 1, 2, 3, 4) + // since they have the smallest cosine distances + let closest_distance = neighbors_limited[0].0; + assert!( + closest_distance < 0.1, + "Closest neighbor should be nearly parallel, distance: {}", + closest_distance + ); + + // Compare with full algorithm for accuracy assessment + let cosine_pair_full = CosinePair::new(&large_dataset).unwrap(); + let neighbors_full = cosine_pair_full.query_row(0, 4).unwrap(); + + // The fast version might not find the exact same neighbors due to sampling, + // but the closest neighbor's distance should be very similar + let dist_diff = (neighbors_limited[0].0 - neighbors_full[0].0).abs(); + assert!( + dist_diff < 0.01, + "Fast and full algorithms should give similar closest distances. Diff: {}", + dist_diff + ); + + // Verify that all returned indices are valid and unique + let mut indices: Vec = neighbors_limited.iter().map(|(_, idx)| *idx).collect(); + indices.sort(); + indices.dedup(); + assert_eq!( + indices.len(), + neighbors_limited.len(), + "All neighbor indices should be unique" + ); + + for &idx in &indices { + assert!( + idx < large_dataset.shape().0, + "Neighbor index {} should be valid", + idx + ); + assert!(idx != 0, "Neighbor should not include query point itself"); + } + + // Test with f32 precision to ensure type compatibility + for (distance, _) in &neighbors_limited { + assert!(!distance.is_nan(), "Distance should not be NaN"); + assert!(distance.is_finite(), "Distance should be finite"); + assert!(*distance >= 0.0, "Distance should be non-negative"); + } + } + #[test] fn cosine_pair_float_precision() { // Test with f32 precision