From 918828091eb84c6f5399ce0fd557e9107b4bac91 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Sat, 27 Sep 2025 10:51:19 +0100 Subject: [PATCH 01/10] Implement cosine similarity and cosinepair --- src/algorithm/neighbour/cosinepair.rs | 772 ++++++++++++++++++++++++++ src/algorithm/neighbour/mod.rs | 2 + src/metrics/distance/cosine.rs | 219 ++++++++ src/metrics/distance/mod.rs | 2 + 4 files changed, 995 insertions(+) create mode 100644 src/algorithm/neighbour/cosinepair.rs create mode 100644 src/metrics/distance/cosine.rs diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs new file mode 100644 index 00000000..3c2effe6 --- /dev/null +++ b/src/algorithm/neighbour/cosinepair.rs @@ -0,0 +1,772 @@ +/// +/// ### CosinePair: Data-structure for the dynamic closest-pair problem. +/// +/// Reference: +/// Eppstein, David: Fast hierarchical clustering and other applications of +/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1. +/// +/// Example: +/// ``` +/// use smartcore::metrics::distance::PairwiseDistance; +/// use smartcore::linalg::basic::matrix::DenseMatrix; +/// use smartcore::algorithm::neighbour::cosinepair::CosinePair; +/// let x = DenseMatrix::::from_2d_array(&[ +/// &[5.1, 3.5, 1.4, 0.2], +/// &[4.9, 3.0, 1.4, 0.2], +/// &[4.7, 3.2, 1.3, 0.2], +/// &[4.6, 3.1, 1.5, 0.2], +/// &[5.0, 3.6, 1.4, 0.2], +/// &[5.4, 3.9, 1.7, 0.4], +/// ]).unwrap(); +/// let cosinepair = CosinePair::new(&x); +/// let closest_pair: PairwiseDistance = cosinepair.unwrap().closest_pair(); +/// ``` +/// +/// +use std::collections::HashMap; + +use num::Bounded; + +use crate::error::{Failed, FailedError}; +use crate::linalg::basic::arrays::{Array1, Array2}; +use crate::metrics::distance::cosine::Cosine; +use crate::metrics::distance::{Distance, PairwiseDistance}; +use crate::numbers::floatnum::FloatNumber; +use crate::numbers::realnum::RealNumber; + +/// +/// Inspired by Python implementation: +/// +/// MIT License (MIT) Copyright (c) 2016 Carson Farmer +/// +/// affinity used is Cosine as it is the most used +/// +#[derive(Debug, Clone)] +pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2> { + /// initial matrix + pub samples: &'a M, + /// closest pair hashmap (connectivity matrix for closest pairs) + pub distances: HashMap>, + /// conga line used to keep track of the closest pair + pub neighbours: Vec, +} + +impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { + /// Constructor + /// Instantiate and initialize the algorithm + pub fn new(m: &'a M) -> Result { + if m.shape().0 < 2 { + return Err(Failed::because( + FailedError::FindFailed, + "min number of rows should be 2", + )); + } + + let mut init = Self { + samples: m, + // to be computed in init(..) + distances: HashMap::with_capacity(m.shape().0), + neighbours: Vec::with_capacity(m.shape().0 + 1), + }; + init.init(); + Ok(init) + } + + /// Initialise `CosinePair` by passing a `Array2`. + /// Build a CosinePairs data-structure from a set of (new) points. + fn init(&mut self) { + // basic measures + let len = self.samples.shape().0; + let max_index = self.samples.shape().0 - 1; + + // Store all closest neighbors + let _distances = Box::new(HashMap::with_capacity(len)); + let _neighbours = Box::new(Vec::with_capacity(len)); + + let mut distances = *_distances; + let mut neighbours = *_neighbours; + + // fill neighbours with -1 values + neighbours.extend(0..len); + + // init closest neighbour pairwise data + for index_row_i in 0..(max_index) { + distances.insert( + index_row_i, + PairwiseDistance { + node: index_row_i, + neighbour: Option::None, + distance: Some(::max_value()), + }, + ); + } + + // loop through indeces and neighbours + for index_row_i in 0..(len) { + // start looking for the neighbour in the second element + let mut index_closest = index_row_i + 1; // closest neighbour index + let mut nbd: Option = distances[&index_row_i].distance; // init neighbour distance + for index_row_j in (index_row_i + 1)..len { + distances.insert( + index_row_j, + PairwiseDistance { + node: index_row_j, + neighbour: Some(index_row_i), + distance: nbd, + }, + ); + + let d = Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(index_row_i).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(index_row_j).iterator(0).copied(), + self.samples.shape().1, + ), + ); + if d < nbd.unwrap().to_f64().unwrap() { + // set this j-value to be the closest neighbour + index_closest = index_row_j; + nbd = Some(T::from(d).unwrap()); + } + } + + // Add that edge + distances.entry(index_row_i).and_modify(|e| { + e.distance = nbd; + e.neighbour = Some(index_closest); + }); + } + // No more neighbors, terminate conga line. + // Last person on the line has no neigbors + distances.get_mut(&max_index).unwrap().neighbour = Some(max_index); + distances.get_mut(&(len - 1)).unwrap().distance = Some(::max_value()); + + // compute sparse matrix (connectivity matrix) + let mut sparse_matrix = M::zeros(len, len); + for (_, p) in distances.iter() { + sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap()); + } + + self.distances = distances; + self.neighbours = neighbours; + } + + /// Query k nearest neighbors for a row that's already in the dataset + pub fn query_row(&self, query_row_index: usize, k: usize) -> Result, Failed> { + if query_row_index >= self.samples.shape().0 { + return Err(Failed::because( + FailedError::FindFailed, + "Query row index out of bounds" + )); + } + + if k == 0 { + return Ok(Vec::new()); + } + + // Get distances to all other points + let mut distances = self.distances_from(query_row_index); + + // Sort by distance (ascending) + distances.sort_by(|a, b| { + a.distance.unwrap().partial_cmp(&b.distance.unwrap()).unwrap_or(std::cmp::Ordering::Equal) + }); + + // Take top k neighbors and convert to (distance, index) format + let neighbors: Vec<(T, usize)> = distances + .into_iter() + .take(k) + .map(|pd| (pd.distance.unwrap(), pd.neighbour.unwrap())) + .collect(); + + Ok(neighbors) + } + + /// Query k nearest neighbors for an external query vector + pub fn query(&self, query_vector: &Vec, k: usize) -> Result, Failed> { + if query_vector.len() != self.samples.shape().1 { + return Err(Failed::because( + FailedError::FindFailed, + "Query vector dimension mismatch" + )); + } + + if k == 0 { + return Ok(Vec::new()); + } + + // Compute distances from query vector to all points in the dataset + let mut distances = Vec::>::with_capacity(self.samples.shape().0); + + for i in 0..self.samples.shape().0 { + let dataset_point = Vec::from_iterator( + self.samples.get_row(i).iterator(0).copied(), + self.samples.shape().1, + ); + + let distance = T::from(Cosine::new().distance(query_vector, &dataset_point)).unwrap(); + + distances.push(PairwiseDistance { + node: i, // This represents the dataset point index + neighbour: Some(i), + distance: Some(distance), + }); + } + + // Sort by distance (ascending) + distances.sort_by(|a, b| { + a.distance.unwrap().partial_cmp(&b.distance.unwrap()).unwrap_or(std::cmp::Ordering::Equal) + }); + + // Take top k neighbors and convert to (distance, index) format + let neighbors: Vec<(T, usize)> = distances + .into_iter() + .take(k) + .map(|pd| (pd.distance.unwrap(), pd.node)) + .collect(); + + Ok(neighbors) + } + + /// Optimized version that reuses the existing distances_from method + /// This is more efficient for queries that are points already in the dataset + pub fn query_optimized(&self, query_row_index: usize, k: usize) -> Result, Failed> { + // Reuse existing method and sort the results + self.query_row(query_row_index, k) + } + + /// Find closest pair by scanning list of nearest neighbors. + #[allow(dead_code)] + pub fn closest_pair(&self) -> PairwiseDistance { + let mut a = self.neighbours[0]; // Start with first point + let mut d = self.distances[&a].distance; + for p in self.neighbours.iter() { + if self.distances[p].distance < d { + a = *p; // Update `a` and distance `d` + d = self.distances[p].distance; + } + } + let b = self.distances[&a].neighbour; + PairwiseDistance { + node: a, + neighbour: b, + distance: d, + } + } + + /// + /// Return order dissimilarities from closest to furthest + /// + #[allow(dead_code)] + pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance> { + // improvement: implement this to return `impl Iterator>` + // need to implement trait `Iterator` for `Vec<&PairwiseDistance>` + let mut distances = self + .distances + .values() + .collect::>>(); + distances.sort_by(|a, b| a.partial_cmp(b).unwrap()); + distances.into_iter() + } + + // + // Compute distances from input to all other points in data-structure. + // input is the row index of the sample matrix + // + #[allow(dead_code)] + fn distances_from(&self, index_row: usize) -> Vec> { + let mut distances = Vec::>::with_capacity(self.samples.shape().0); + for other in self.neighbours.iter() { + if index_row != *other { + distances.push(PairwiseDistance { + node: index_row, + neighbour: Some(*other), + distance: Some( + T::from(Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(index_row).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(*other).iterator(0).copied(), + self.samples.shape().1, + ), + )) + .unwrap(), + ), + }) + } + } + distances + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; + use approx::assert_relative_eq; + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_initialization() { + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x); + + assert!(cosine_pair.is_ok()); + let cp = cosine_pair.unwrap(); + + assert_eq!(cp.samples.shape().0, 6); + assert_eq!(cp.distances.len(), 6); + assert_eq!(cp.neighbours.len(), 6); + assert!(!cp.distances.is_empty()); + assert!(!cp.neighbours.is_empty()); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_minimum_rows_error() { + // Test with only one row - should fail + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + ]).unwrap(); + + let result = CosinePair::new(&x); + assert!(result.is_err()); + + if let Err(e) = result { + let expected_error = Failed::because( + FailedError::FindFailed, + "min number of rows should be 2" + ); + assert_eq!(e, expected_error); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_closest_pair() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0], + &[0.0, 1.0], + &[1.0, 1.0], + &[2.0, 2.0], // This should be closest to [1.0, 1.0] with cosine distance + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let closest_pair = cosine_pair.closest_pair(); + + // Verify structure + assert!(closest_pair.distance.is_some()); + assert!(closest_pair.neighbour.is_some()); + + // The closest pair should have the smallest cosine distance + let distance = closest_pair.distance.unwrap(); + assert!(distance >= 0.0 && distance <= 2.0); // Cosine distance range + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_identical_vectors() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0, 3.0], + &[1.0, 2.0, 3.0], // Identical vector + &[4.0, 5.0, 6.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let closest_pair = cosine_pair.closest_pair(); + + // Distance between identical vectors should be 0 + let distance = closest_pair.distance.unwrap(); + assert!((distance - 0.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_orthogonal_vectors() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0], + &[0.0, 1.0], // Orthogonal to first + &[2.0, 3.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Check that orthogonal vectors have cosine distance of 1.0 + let distances_from_first = cosine_pair.distances_from(0); + let orthogonal_distance = distances_from_first.iter() + .find(|pd| pd.neighbour == Some(1)) + .unwrap() + .distance + .unwrap(); + + assert!((orthogonal_distance - 1.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_ordered_pairs() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0], + &[2.0, 1.0], + &[3.0, 4.0], + &[4.0, 3.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let ordered_pairs: Vec<_> = cosine_pair.ordered_pairs().collect(); + + assert_eq!(ordered_pairs.len(), 4); + + // Check that pairs are ordered by distance (ascending) + for i in 1..ordered_pairs.len() { + let prev_distance = ordered_pairs[i-1].distance.unwrap(); + let curr_distance = ordered_pairs[i].distance.unwrap(); + assert!(prev_distance <= curr_distance); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_row() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0, 0.0], + &[0.0, 1.0, 0.0], + &[0.0, 0.0, 1.0], + &[1.0, 1.0, 0.0], + &[0.0, 1.0, 1.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query k=2 nearest neighbors for row 0 + let neighbors = cosine_pair.query_row(0, 2).unwrap(); + + assert_eq!(neighbors.len(), 2); + + // Check that distances are in ascending order + assert!(neighbors[0].0 <= neighbors[1].0); + + // All distances should be valid cosine distances (0 to 2) + for (distance, _) in &neighbors { + assert!(*distance >= 0.0 && *distance <= 2.0); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_row_bounds_error() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0], + &[3.0, 4.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query with out-of-bounds row index + let result = cosine_pair.query_row(5, 1); + assert!(result.is_err()); + + if let Err(e) = result { + let expected_error = Failed::because( + FailedError::FindFailed, + "Query row index out of bounds" + ); + assert_eq!(e, expected_error); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_row_k_zero() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0], + &[3.0, 4.0], + &[5.0, 6.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let neighbors = cosine_pair.query_row(0, 0).unwrap(); + + assert_eq!(neighbors.len(), 0); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_external_vector() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0, 0.0], + &[0.0, 1.0, 0.0], + &[0.0, 0.0, 1.0], + &[1.0, 1.0, 0.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query with external vector + let query_vector = vec![1.0, 0.5, 0.0]; + let neighbors = cosine_pair.query(&query_vector, 2).unwrap(); + + assert_eq!(neighbors.len(), 2); + + // Verify distances are valid and ordered + assert!(neighbors[0].0 <= neighbors[1].0); + for (distance, index) in &neighbors { + assert!(*distance >= 0.0 && *distance <= 2.0); + assert!(*index < x.shape().0); + } + } + + #[test] + fn cosine_pair_query_dimension_mismatch() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0, 3.0], + &[4.0, 5.0, 6.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query with mismatched dimensions + let query_vector = vec![1.0, 2.0]; // Only 2 dimensions, but data has 3 + let result = cosine_pair.query(&query_vector, 1); + + assert!(result.is_err()); + if let Err(e) = result { + let expected_error = Failed::because( + FailedError::FindFailed, + "Query vector dimension mismatch" + ); + assert_eq!(e, expected_error); + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_query_k_zero_external() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0], + &[3.0, 4.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let query_vector = vec![1.0, 1.0]; + let neighbors = cosine_pair.query(&query_vector, 0).unwrap(); + + assert_eq!(neighbors.len(), 0); + } + + #[test] + fn cosine_pair_large_dataset() { + // Test with larger dataset (similar to Iris) + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + assert_eq!(cosine_pair.samples.shape().0, 15); + assert_eq!(cosine_pair.distances.len(), 15); + assert_eq!(cosine_pair.neighbours.len(), 15); + + // Test closest pair computation + let closest_pair = cosine_pair.closest_pair(); + assert!(closest_pair.distance.is_some()); + assert!(closest_pair.neighbour.is_some()); + + let distance = closest_pair.distance.unwrap(); + assert!(distance >= 0.0 && distance <= 2.0); + } + + #[test] + fn cosine_pair_float_precision() { + // Test with f32 precision + let x = DenseMatrix::::from_2d_array(&[ + &[1.0f32, 2.0, 3.0], + &[4.0f32, 5.0, 6.0], + &[7.0f32, 8.0, 9.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let closest_pair = cosine_pair.closest_pair(); + + assert!(closest_pair.distance.is_some()); + let distance = closest_pair.distance.unwrap(); + assert!(distance >= 0.0 && distance <= 2.0); + + // Test querying + let neighbors = cosine_pair.query_row(0, 2).unwrap(); + assert_eq!(neighbors.len(), 2); + assert_eq!(neighbors[0].1, 1); + assert_relative_eq!(neighbors[0].0, 0.025368154); + assert_eq!(neighbors[1].1, 2); + assert_relative_eq!(neighbors[1].0, 0.040588055); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_distances_from() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0], + &[0.0, 1.0], + &[1.0, 1.0], + &[2.0, 0.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let distances = cosine_pair.distances_from(0); + + // Should have 3 distances (excluding self) + assert_eq!(distances.len(), 3); + + // All should be from node 0 + for pd in &distances { + assert_eq!(pd.node, 0); + assert!(pd.neighbour.is_some()); + assert!(pd.distance.is_some()); + assert!(pd.neighbour.unwrap() != 0); // Should not include self + } + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_pair_consistency_check() { + // Verify that different query methods return consistent results + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0, 3.0], + &[4.0, 5.0, 6.0], + &[7.0, 8.0, 9.0], + &[2.0, 3.0, 4.0], + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + + // Query row 0 using internal method + let neighbors_internal = cosine_pair.query_row(0, 2).unwrap(); + + // Query row 0 using optimized method (should be same) + let neighbors_optimized = cosine_pair.query_optimized(0, 2).unwrap(); + + assert_eq!(neighbors_internal.len(), neighbors_optimized.len()); + for i in 0..neighbors_internal.len() { + let (dist1, idx1) = neighbors_internal[i]; + let (dist2, idx2) = neighbors_optimized[i]; + assert!((dist1 - dist2).abs() < 1e-10); + assert_eq!(idx1, idx2); + } + } + + // Brute force algorithm for testing/comparison + fn closest_pair_brute_force(cosine_pair: &CosinePair<'_, f64, DenseMatrix>) -> PairwiseDistance { + use itertools::Itertools; + + let m = cosine_pair.samples.shape().0; + let mut closest_pair = PairwiseDistance { + node: 0, + neighbour: None, + distance: Some(f64::MAX), + }; + + for pair in (0..m).combinations(2) { + let d = Cosine::new().distance( + &Vec::from_iterator( + cosine_pair.samples.get_row(pair[0]).iterator(0).copied(), + cosine_pair.samples.shape().1, + ), + &Vec::from_iterator( + cosine_pair.samples.get_row(pair[1]).iterator(0).copied(), + cosine_pair.samples.shape().1, + ), + ); + + if d < closest_pair.distance.unwrap() { + closest_pair.node = pair[0]; + closest_pair.neighbour = Some(pair[1]); + closest_pair.distance = Some(d); + } + } + + closest_pair + } + + #[test] + fn cosine_pair_vs_brute_force() { + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 2.0, 3.0], + &[4.0, 5.0, 6.0], + &[7.0, 8.0, 9.0], + &[1.1, 2.1, 3.1], // Close to first point + ]).unwrap(); + + let cosine_pair = CosinePair::new(&x).unwrap(); + let cp_result = cosine_pair.closest_pair(); + let brute_result = closest_pair_brute_force(&cosine_pair); + + // Results should be identical or very close + assert!((cp_result.distance.unwrap() - brute_result.distance.unwrap()).abs() < 1e-10); + } +} diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index 3bee93aa..73d3b9e9 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -43,6 +43,8 @@ pub(crate) mod bbd_tree; pub mod cover_tree; /// fastpair closest neighbour algorithm pub mod fastpair; +/// a variant of fastpair using cosine distance +pub mod cosinepair; /// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched. pub mod linear_search; diff --git a/src/metrics/distance/cosine.rs b/src/metrics/distance/cosine.rs new file mode 100644 index 00000000..e783adc4 --- /dev/null +++ b/src/metrics/distance/cosine.rs @@ -0,0 +1,219 @@ +//! # Cosine Distance Metric +//! +//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as: +//! +//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\] +//! +//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\) +//! are their respective magnitudes (Euclidean norms). +//! +//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2. +//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate +//! greater angular separation. +//! +//! Example: +//! +//! ``` +//! use smartcore::metrics::distance::Distance; +//! use smartcore::metrics::distance::cosine::Cosine; +//! +//! let x = vec![1., 1.]; +//! let y = vec![2., 2.]; +//! +//! let cosine_dist: f64 = Cosine::new().distance(&x, &y); +//! ``` +//! +//! +//! +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; + +use crate::linalg::basic::arrays::ArrayView1; +use crate::numbers::basenum::Number; + +use super::Distance; + +/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space. +/// It is defined as 1 minus the cosine similarity of the vectors. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct Cosine { + _t: PhantomData, +} + +impl Default for Cosine { + fn default() -> Self { + Self::new() + } +} + +impl Cosine { + /// Instantiate the initial structure + pub fn new() -> Cosine { + Cosine { _t: PhantomData } + } + + /// Calculate the dot product of two vectors using smartcore's ArrayView1 trait + #[inline] + pub(crate) fn dot_product>(x: &A, y: &A) -> f64 { + if x.shape() != y.shape() { + panic!("Input vector sizes are different."); + } + + // Use the built-in dot product method from ArrayView1 trait + x.dot(y).to_f64().unwrap() + } + + /// Calculate the squared magnitude (norm squared) of a vector + #[inline] + #[allow(dead_code)] + pub(crate) fn squared_magnitude>(x: &A) -> f64 { + x.iterator(0) + .map(|&a| { + let val = a.to_f64().unwrap(); + val * val + }) + .sum() + } + + /// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method + #[inline] + pub(crate) fn magnitude>(x: &A) -> f64 { + // Use the built-in norm2 method from ArrayView1 trait + x.norm2() + } + + /// Calculate cosine similarity between two vectors + #[inline] + pub(crate) fn cosine_similarity>(x: &A, y: &A) -> f64 { + let dot_product = Self::dot_product(x, y); + let magnitude_x = Self::magnitude(x); + let magnitude_y = Self::magnitude(y); + + if magnitude_x == 0.0 || magnitude_y == 0.0 { + panic!("Cannot compute cosine distance for zero-magnitude vectors."); + } + + dot_product / (magnitude_x * magnitude_y) + } +} + +impl> Distance for Cosine { + fn distance(&self, x: &A, y: &A) -> f64 { + let similarity = Cosine::cosine_similarity(x, y); + 1.0 - similarity + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_identical_vectors() { + let a = vec![1, 2, 3]; + let b = vec![1, 2, 3]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + assert!((dist - 0.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_orthogonal_vectors() { + let a = vec![1, 0]; + let b = vec![0, 1]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + assert!((dist - 1.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_opposite_vectors() { + let a = vec![1, 2, 3]; + let b = vec![-1, -2, -3]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + assert!((dist - 2.0).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_general_case() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![2.0, 1.0, 3.0]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + // Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9)) + // = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286 + // So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714 + let expected_dist = 1.0 - (13.0 / 14.0); + assert!((dist - expected_dist).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + #[should_panic(expected = "Input vector sizes are different.")] + fn cosine_distance_different_sizes() { + let a = vec![1, 2]; + let b = vec![1, 2, 3]; + + let _dist: f64 = Cosine::new().distance(&a, &b); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + #[should_panic(expected = "Cannot compute cosine distance for zero-magnitude vectors.")] + fn cosine_distance_zero_vector() { + let a = vec![0, 0, 0]; + let b = vec![1, 2, 3]; + + let _dist: f64 = Cosine::new().distance(&a, &b); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn cosine_distance_float_precision() { + let a = vec![1.0f32, 2.0, 3.0]; + let b = vec![4.0f32, 5.0, 6.0]; + + let dist: f64 = Cosine::new().distance(&a, &b); + + // Calculate expected value manually + let dot_product = 1.0*4.0 + 2.0*5.0 + 3.0*6.0; // = 32 + let mag_a = (1.0*1.0 + 2.0*2.0 + 3.0*3.0_f64).sqrt(); // = sqrt(14) + let mag_b = (4.0*4.0 + 5.0*5.0 + 6.0*6.0_f64).sqrt(); // = sqrt(77) + let expected_similarity = dot_product / (mag_a * mag_b); + let expected_distance = 1.0 - expected_similarity; + + assert!((dist - expected_distance).abs() < 1e-6); + } +} diff --git a/src/metrics/distance/mod.rs b/src/metrics/distance/mod.rs index 193d7a19..1f044aa1 100644 --- a/src/metrics/distance/mod.rs +++ b/src/metrics/distance/mod.rs @@ -23,6 +23,8 @@ pub mod mahalanobis; pub mod manhattan; /// A generalization of both the Euclidean distance and the Manhattan distance. pub mod minkowski; +/// Cosine distance +pub mod cosine; use std::cmp::{Eq, Ordering, PartialOrd}; From bb6cf19d4db3d5e0c767cedd08d061128c9f8da5 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Sat, 27 Sep 2025 10:53:38 +0100 Subject: [PATCH 02/10] bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 3c1b8ab9..bd9db328 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "smartcore" description = "Machine Learning in Rust." homepage = "https://smartcorelib.org" -version = "0.4.2" +version = "0.4.3" authors = ["smartcore Developers"] edition = "2021" license = "Apache-2.0" From 91fe93c549dc913799ae6b4402e6e6a476ca4504 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Sat, 27 Sep 2025 10:55:06 +0100 Subject: [PATCH 03/10] formatting --- src/algorithm/neighbour/cosinepair.rs | 257 +++++++++++++------------- src/algorithm/neighbour/mod.rs | 4 +- src/metrics/distance/cosine.rs | 12 +- src/metrics/distance/mod.rs | 4 +- 4 files changed, 141 insertions(+), 136 deletions(-) diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs index 3c2effe6..0a38f993 100644 --- a/src/algorithm/neighbour/cosinepair.rs +++ b/src/algorithm/neighbour/cosinepair.rs @@ -159,81 +159,91 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { if query_row_index >= self.samples.shape().0 { return Err(Failed::because( FailedError::FindFailed, - "Query row index out of bounds" + "Query row index out of bounds", )); } - + if k == 0 { return Ok(Vec::new()); } - + // Get distances to all other points let mut distances = self.distances_from(query_row_index); - + // Sort by distance (ascending) distances.sort_by(|a, b| { - a.distance.unwrap().partial_cmp(&b.distance.unwrap()).unwrap_or(std::cmp::Ordering::Equal) + a.distance + .unwrap() + .partial_cmp(&b.distance.unwrap()) + .unwrap_or(std::cmp::Ordering::Equal) }); - + // Take top k neighbors and convert to (distance, index) format let neighbors: Vec<(T, usize)> = distances .into_iter() .take(k) .map(|pd| (pd.distance.unwrap(), pd.neighbour.unwrap())) .collect(); - + Ok(neighbors) } - + /// Query k nearest neighbors for an external query vector pub fn query(&self, query_vector: &Vec, k: usize) -> Result, Failed> { if query_vector.len() != self.samples.shape().1 { return Err(Failed::because( FailedError::FindFailed, - "Query vector dimension mismatch" + "Query vector dimension mismatch", )); } - + if k == 0 { return Ok(Vec::new()); } - + // Compute distances from query vector to all points in the dataset let mut distances = Vec::>::with_capacity(self.samples.shape().0); - + for i in 0..self.samples.shape().0 { let dataset_point = Vec::from_iterator( self.samples.get_row(i).iterator(0).copied(), self.samples.shape().1, ); - + let distance = T::from(Cosine::new().distance(query_vector, &dataset_point)).unwrap(); - + distances.push(PairwiseDistance { node: i, // This represents the dataset point index neighbour: Some(i), distance: Some(distance), }); } - + // Sort by distance (ascending) distances.sort_by(|a, b| { - a.distance.unwrap().partial_cmp(&b.distance.unwrap()).unwrap_or(std::cmp::Ordering::Equal) + a.distance + .unwrap() + .partial_cmp(&b.distance.unwrap()) + .unwrap_or(std::cmp::Ordering::Equal) }); - + // Take top k neighbors and convert to (distance, index) format let neighbors: Vec<(T, usize)> = distances .into_iter() .take(k) .map(|pd| (pd.distance.unwrap(), pd.node)) .collect(); - + Ok(neighbors) } - + /// Optimized version that reuses the existing distances_from method /// This is more efficient for queries that are points already in the dataset - pub fn query_optimized(&self, query_row_index: usize, k: usize) -> Result, Failed> { + pub fn query_optimized( + &self, + query_row_index: usize, + k: usize, + ) -> Result, Failed> { // Reuse existing method and sort the results self.query_row(query_row_index, k) } @@ -323,13 +333,14 @@ mod tests { &[4.6, 3.1, 1.5, 0.2], &[5.0, 3.6, 1.4, 0.2], &[5.4, 3.9, 1.7, 0.4], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x); - + assert!(cosine_pair.is_ok()); let cp = cosine_pair.unwrap(); - + assert_eq!(cp.samples.shape().0, 6); assert_eq!(cp.distances.len(), 6); assert_eq!(cp.neighbours.len(), 6); @@ -341,21 +352,17 @@ mod tests { all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test )] - #[test] + #[test] fn cosine_pair_minimum_rows_error() { // Test with only one row - should fail - let x = DenseMatrix::::from_2d_array(&[ - &[5.1, 3.5, 1.4, 0.2], - ]).unwrap(); - + let x = DenseMatrix::::from_2d_array(&[&[5.1, 3.5, 1.4, 0.2]]).unwrap(); + let result = CosinePair::new(&x); assert!(result.is_err()); - + if let Err(e) = result { - let expected_error = Failed::because( - FailedError::FindFailed, - "min number of rows should be 2" - ); + let expected_error = + Failed::because(FailedError::FindFailed, "min number of rows should be 2"); assert_eq!(e, expected_error); } } @@ -371,15 +378,16 @@ mod tests { &[0.0, 1.0], &[1.0, 1.0], &[2.0, 2.0], // This should be closest to [1.0, 1.0] with cosine distance - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let closest_pair = cosine_pair.closest_pair(); - + // Verify structure assert!(closest_pair.distance.is_some()); assert!(closest_pair.neighbour.is_some()); - + // The closest pair should have the smallest cosine distance let distance = closest_pair.distance.unwrap(); assert!(distance >= 0.0 && distance <= 2.0); // Cosine distance range @@ -395,11 +403,12 @@ mod tests { &[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], // Identical vector &[4.0, 5.0, 6.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let closest_pair = cosine_pair.closest_pair(); - + // Distance between identical vectors should be 0 let distance = closest_pair.distance.unwrap(); assert!((distance - 0.0).abs() < 1e-8); @@ -415,18 +424,20 @@ mod tests { &[1.0, 0.0], &[0.0, 1.0], // Orthogonal to first &[2.0, 3.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); - + // Check that orthogonal vectors have cosine distance of 1.0 let distances_from_first = cosine_pair.distances_from(0); - let orthogonal_distance = distances_from_first.iter() + let orthogonal_distance = distances_from_first + .iter() .find(|pd| pd.neighbour == Some(1)) .unwrap() .distance .unwrap(); - + assert!((orthogonal_distance - 1.0).abs() < 1e-8); } @@ -441,16 +452,17 @@ mod tests { &[2.0, 1.0], &[3.0, 4.0], &[4.0, 3.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let ordered_pairs: Vec<_> = cosine_pair.ordered_pairs().collect(); - + assert_eq!(ordered_pairs.len(), 4); - + // Check that pairs are ordered by distance (ascending) for i in 1..ordered_pairs.len() { - let prev_distance = ordered_pairs[i-1].distance.unwrap(); + let prev_distance = ordered_pairs[i - 1].distance.unwrap(); let curr_distance = ordered_pairs[i].distance.unwrap(); assert!(prev_distance <= curr_distance); } @@ -468,18 +480,19 @@ mod tests { &[0.0, 0.0, 1.0], &[1.0, 1.0, 0.0], &[0.0, 1.0, 1.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); - + // Query k=2 nearest neighbors for row 0 let neighbors = cosine_pair.query_row(0, 2).unwrap(); - + assert_eq!(neighbors.len(), 2); - + // Check that distances are in ascending order assert!(neighbors[0].0 <= neighbors[1].0); - + // All distances should be valid cosine distances (0 to 2) for (distance, _) in &neighbors { assert!(*distance >= 0.0 && *distance <= 2.0); @@ -492,22 +505,17 @@ mod tests { )] #[test] fn cosine_pair_query_row_bounds_error() { - let x = DenseMatrix::::from_2d_array(&[ - &[1.0, 2.0], - &[3.0, 4.0], - ]).unwrap(); - + let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); - + // Query with out-of-bounds row index let result = cosine_pair.query_row(5, 1); assert!(result.is_err()); - + if let Err(e) = result { - let expected_error = Failed::because( - FailedError::FindFailed, - "Query row index out of bounds" - ); + let expected_error = + Failed::because(FailedError::FindFailed, "Query row index out of bounds"); assert_eq!(e, expected_error); } } @@ -518,15 +526,12 @@ mod tests { )] #[test] fn cosine_pair_query_row_k_zero() { - let x = DenseMatrix::::from_2d_array(&[ - &[1.0, 2.0], - &[3.0, 4.0], - &[5.0, 6.0], - ]).unwrap(); - + let x = + DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]).unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let neighbors = cosine_pair.query_row(0, 0).unwrap(); - + assert_eq!(neighbors.len(), 0); } @@ -538,19 +543,20 @@ mod tests { fn cosine_pair_query_external_vector() { let x = DenseMatrix::::from_2d_array(&[ &[1.0, 0.0, 0.0], - &[0.0, 1.0, 0.0], + &[0.0, 1.0, 0.0], &[0.0, 0.0, 1.0], &[1.0, 1.0, 0.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); - + // Query with external vector let query_vector = vec![1.0, 0.5, 0.0]; let neighbors = cosine_pair.query(&query_vector, 2).unwrap(); - + assert_eq!(neighbors.len(), 2); - + // Verify distances are valid and ordered assert!(neighbors[0].0 <= neighbors[1].0); for (distance, index) in &neighbors { @@ -561,23 +567,18 @@ mod tests { #[test] fn cosine_pair_query_dimension_mismatch() { - let x = DenseMatrix::::from_2d_array(&[ - &[1.0, 2.0, 3.0], - &[4.0, 5.0, 6.0], - ]).unwrap(); - + let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); - + // Query with mismatched dimensions let query_vector = vec![1.0, 2.0]; // Only 2 dimensions, but data has 3 let result = cosine_pair.query(&query_vector, 1); - + assert!(result.is_err()); if let Err(e) = result { - let expected_error = Failed::because( - FailedError::FindFailed, - "Query vector dimension mismatch" - ); + let expected_error = + Failed::because(FailedError::FindFailed, "Query vector dimension mismatch"); assert_eq!(e, expected_error); } } @@ -588,15 +589,12 @@ mod tests { )] #[test] fn cosine_pair_query_k_zero_external() { - let x = DenseMatrix::::from_2d_array(&[ - &[1.0, 2.0], - &[3.0, 4.0], - ]).unwrap(); - + let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let query_vector = vec![1.0, 1.0]; let neighbors = cosine_pair.query(&query_vector, 0).unwrap(); - + assert_eq!(neighbors.len(), 0); } @@ -619,19 +617,20 @@ mod tests { &[6.9, 3.1, 4.9, 1.5], &[5.5, 2.3, 4.0, 1.3], &[6.5, 2.8, 4.6, 1.5], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); - + assert_eq!(cosine_pair.samples.shape().0, 15); assert_eq!(cosine_pair.distances.len(), 15); assert_eq!(cosine_pair.neighbours.len(), 15); - + // Test closest pair computation let closest_pair = cosine_pair.closest_pair(); assert!(closest_pair.distance.is_some()); assert!(closest_pair.neighbour.is_some()); - + let distance = closest_pair.distance.unwrap(); assert!(distance >= 0.0 && distance <= 2.0); } @@ -643,15 +642,16 @@ mod tests { &[1.0f32, 2.0, 3.0], &[4.0f32, 5.0, 6.0], &[7.0f32, 8.0, 9.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let closest_pair = cosine_pair.closest_pair(); - + assert!(closest_pair.distance.is_some()); let distance = closest_pair.distance.unwrap(); assert!(distance >= 0.0 && distance <= 2.0); - + // Test querying let neighbors = cosine_pair.query_row(0, 2).unwrap(); assert_eq!(neighbors.len(), 2); @@ -665,21 +665,22 @@ mod tests { all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test )] - #[test] + #[test] fn cosine_pair_distances_from() { let x = DenseMatrix::::from_2d_array(&[ &[1.0, 0.0], &[0.0, 1.0], &[1.0, 1.0], &[2.0, 0.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let distances = cosine_pair.distances_from(0); - + // Should have 3 distances (excluding self) assert_eq!(distances.len(), 3); - + // All should be from node 0 for pd in &distances { assert_eq!(pd.node, 0); @@ -698,19 +699,20 @@ mod tests { // Verify that different query methods return consistent results let x = DenseMatrix::::from_2d_array(&[ &[1.0, 2.0, 3.0], - &[4.0, 5.0, 6.0], + &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0], &[2.0, 3.0, 4.0], - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); - + // Query row 0 using internal method let neighbors_internal = cosine_pair.query_row(0, 2).unwrap(); - + // Query row 0 using optimized method (should be same) let neighbors_optimized = cosine_pair.query_optimized(0, 2).unwrap(); - + assert_eq!(neighbors_internal.len(), neighbors_optimized.len()); for i in 0..neighbors_internal.len() { let (dist1, idx1) = neighbors_internal[i]; @@ -721,16 +723,18 @@ mod tests { } // Brute force algorithm for testing/comparison - fn closest_pair_brute_force(cosine_pair: &CosinePair<'_, f64, DenseMatrix>) -> PairwiseDistance { + fn closest_pair_brute_force( + cosine_pair: &CosinePair<'_, f64, DenseMatrix>, + ) -> PairwiseDistance { use itertools::Itertools; - + let m = cosine_pair.samples.shape().0; let mut closest_pair = PairwiseDistance { node: 0, neighbour: None, distance: Some(f64::MAX), }; - + for pair in (0..m).combinations(2) { let d = Cosine::new().distance( &Vec::from_iterator( @@ -742,14 +746,14 @@ mod tests { cosine_pair.samples.shape().1, ), ); - + if d < closest_pair.distance.unwrap() { closest_pair.node = pair[0]; closest_pair.neighbour = Some(pair[1]); closest_pair.distance = Some(d); } } - + closest_pair } @@ -760,12 +764,13 @@ mod tests { &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0], &[1.1, 2.1, 3.1], // Close to first point - ]).unwrap(); - + ]) + .unwrap(); + let cosine_pair = CosinePair::new(&x).unwrap(); let cp_result = cosine_pair.closest_pair(); let brute_result = closest_pair_brute_force(&cosine_pair); - + // Results should be identical or very close assert!((cp_result.distance.unwrap() - brute_result.distance.unwrap()).abs() < 1e-10); } diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index 73d3b9e9..c13e914a 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -39,12 +39,12 @@ use crate::numbers::basenum::Number; use serde::{Deserialize, Serialize}; pub(crate) mod bbd_tree; +/// a variant of fastpair using cosine distance +pub mod cosinepair; /// tree data structure for fast nearest neighbor search pub mod cover_tree; /// fastpair closest neighbour algorithm pub mod fastpair; -/// a variant of fastpair using cosine distance -pub mod cosinepair; /// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched. pub mod linear_search; diff --git a/src/metrics/distance/cosine.rs b/src/metrics/distance/cosine.rs index e783adc4..ea065a07 100644 --- a/src/metrics/distance/cosine.rs +++ b/src/metrics/distance/cosine.rs @@ -60,7 +60,7 @@ impl Cosine { if x.shape() != y.shape() { panic!("Input vector sizes are different."); } - + // Use the built-in dot product method from ArrayView1 trait x.dot(y).to_f64().unwrap() } @@ -206,14 +206,14 @@ mod tests { let b = vec![4.0f32, 5.0, 6.0]; let dist: f64 = Cosine::new().distance(&a, &b); - + // Calculate expected value manually - let dot_product = 1.0*4.0 + 2.0*5.0 + 3.0*6.0; // = 32 - let mag_a = (1.0*1.0 + 2.0*2.0 + 3.0*3.0_f64).sqrt(); // = sqrt(14) - let mag_b = (4.0*4.0 + 5.0*5.0 + 6.0*6.0_f64).sqrt(); // = sqrt(77) + let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32 + let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14) + let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77) let expected_similarity = dot_product / (mag_a * mag_b); let expected_distance = 1.0 - expected_similarity; - + assert!((dist - expected_distance).abs() < 1e-6); } } diff --git a/src/metrics/distance/mod.rs b/src/metrics/distance/mod.rs index 1f044aa1..6fdbaa46 100644 --- a/src/metrics/distance/mod.rs +++ b/src/metrics/distance/mod.rs @@ -13,6 +13,8 @@ //! //! +/// Cosine distance +pub mod cosine; /// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points. pub mod euclidian; /// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different. @@ -23,8 +25,6 @@ pub mod mahalanobis; pub mod manhattan; /// A generalization of both the Euclidean distance and the Manhattan distance. pub mod minkowski; -/// Cosine distance -pub mod cosine; use std::cmp::{Eq, Ordering, PartialOrd}; From a438f649ea6f0c2c551da8bfe2989a7e8b8805b3 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Sat, 27 Sep 2025 11:05:05 +0100 Subject: [PATCH 04/10] fix clippy --- src/algorithm/sort/quick_sort.rs | 1 + src/optimization/line_search.rs | 8 ++++---- src/tree/decision_tree_classifier.rs | 17 +++++++++++------ src/xgboost/xgb_regressor.rs | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/algorithm/sort/quick_sort.rs b/src/algorithm/sort/quick_sort.rs index e64c4243..56efec94 100644 --- a/src/algorithm/sort/quick_sort.rs +++ b/src/algorithm/sort/quick_sort.rs @@ -1,6 +1,7 @@ use num_traits::Num; pub trait QuickArgSort { + #[allow(dead_code)] fn quick_argsort_mut(&mut self) -> Vec; #[allow(dead_code)] diff --git a/src/optimization/line_search.rs b/src/optimization/line_search.rs index 8357d8da..98d2982c 100644 --- a/src/optimization/line_search.rs +++ b/src/optimization/line_search.rs @@ -6,8 +6,8 @@ pub trait LineSearchMethod { /// Find alpha that satisfies strong Wolfe conditions. fn search( &self, - f: &(dyn Fn(T) -> T), - df: &(dyn Fn(T) -> T), + f: &dyn Fn(T) -> T, + df: &dyn Fn(T) -> T, alpha: T, f0: T, df0: T, @@ -55,8 +55,8 @@ impl Default for Backtracking { impl LineSearchMethod for Backtracking { fn search( &self, - f: &(dyn Fn(T) -> T), - _: &(dyn Fn(T) -> T), + f: &dyn Fn(T) -> T, + _: &dyn Fn(T) -> T, alpha: T, f0: T, df0: T, diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 5679516a..96007677 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -674,15 +674,20 @@ impl, Y: Array1> ) -> bool { let (n_rows, n_attr) = visitor.x.shape(); - let mut label = Option::None; + let mut label = None; let mut is_pure = true; for i in 0..n_rows { if visitor.samples[i] > 0 { - if label.is_none() { - label = Option::Some(visitor.y[i]); - } else if visitor.y[i] != label.unwrap() { - is_pure = false; - break; + match label { + None => { + label = Some(visitor.y[i]); + } + Some(current_label) => { + if visitor.y[i] != current_label { + is_pure = false; + break; + } + } } } } diff --git a/src/xgboost/xgb_regressor.rs b/src/xgboost/xgb_regressor.rs index ac6ec752..75c77a54 100644 --- a/src/xgboost/xgb_regressor.rs +++ b/src/xgboost/xgb_regressor.rs @@ -96,7 +96,7 @@ impl Objective { pub fn gradient>(&self, y_true: &Y, y_pred: &Vec) -> Vec { match self { Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred) - .map(|(true_val, pred_val)| (*pred_val - true_val.to_f64().unwrap())) + .map(|(true_val, pred_val)| *pred_val - true_val.to_f64().unwrap()) .collect(), } } From d6e210e6ed5babafc3d11bec617e50395c8aab71 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Sat, 27 Sep 2025 17:22:39 +0100 Subject: [PATCH 05/10] Add top k CosinePair --- Cargo.toml | 1 + src/algorithm/neighbour/cosinepair.rs | 392 +++++++++++++++++++++----- 2 files changed, 318 insertions(+), 75 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bd9db328..f1ffea75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ num = "0.4" rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } rand_distr = { version = "0.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } +ordered-float = "*" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] typetag = { version = "0.2", optional = true } diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs index 0a38f993..b2abbdb8 100644 --- a/src/algorithm/neighbour/cosinepair.rs +++ b/src/algorithm/neighbour/cosinepair.rs @@ -23,7 +23,10 @@ /// ``` /// /// -use std::collections::HashMap; +use ordered_float::{FloatCore, OrderedFloat}; + +use std::collections::{BinaryHeap, HashMap}; +use std::cmp::Reverse; use num::Bounded; @@ -34,6 +37,25 @@ use crate::metrics::distance::{Distance, PairwiseDistance}; use crate::numbers::floatnum::FloatNumber; use crate::numbers::realnum::RealNumber; + +/// Parameters for CosinePair construction +#[derive(Debug, Clone)] +pub struct CosinePairParameters { + /// Maximum number of neighbors to consider per point (default: all points) + pub top_k: Option, + /// Whether to use approximate nearest neighbor search + pub approximate: bool, +} + +impl Default for CosinePairParameters { + fn default() -> Self { + Self { + top_k: None, + approximate: false, + } + } +} + /// /// Inspired by Python implementation: /// @@ -49,12 +71,26 @@ pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2> { pub distances: HashMap>, /// conga line used to keep track of the closest pair pub neighbours: Vec, + /// parameters used during construction + pub parameters: CosinePairParameters, } -impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { - /// Constructor - /// Instantiate and initialize the algorithm +impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T, M> { + /// Constructor with default parameters (backward compatibility) pub fn new(m: &'a M) -> Result { + Self::with_parameters(m, CosinePairParameters::default()) + } + + /// Constructor with top-k limiting for faster performance + pub fn with_top_k(m: &'a M, top_k: usize) -> Result { + Self::with_parameters(m, CosinePairParameters { + top_k: Some(top_k), + approximate: false, + }) + } + + /// Constructor with full parameter control + pub fn with_parameters(m: &'a M, parameters: CosinePairParameters) -> Result { if m.shape().0 < 2 { return Err(Failed::because( FailedError::FindFailed, @@ -64,96 +100,150 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { let mut init = Self { samples: m, - // to be computed in init(..) distances: HashMap::with_capacity(m.shape().0), - neighbours: Vec::with_capacity(m.shape().0 + 1), + neighbours: Vec::with_capacity(m.shape().0), + parameters, }; init.init(); Ok(init) } - /// Initialise `CosinePair` by passing a `Array2`. - /// Build a CosinePairs data-structure from a set of (new) points. + /// Helper function to create ordered float wrapper + fn ordered_float(value: T) -> OrderedFloat { + return OrderedFloat(value); + } + + /// Helper function to extract value from ordered float wrapper + fn extract_float(ordered: OrderedFloat) -> T { + return ordered.into_inner(); + } + + /// Optimized initialization with top-k neighbor limiting fn init(&mut self) { - // basic measures let len = self.samples.shape().0; - let max_index = self.samples.shape().0 - 1; + let max_neighbors: usize = self.parameters.top_k.unwrap_or(len - 1).min(len - 1); - // Store all closest neighbors - let _distances = Box::new(HashMap::with_capacity(len)); - let _neighbours = Box::new(Vec::with_capacity(len)); + let mut distances = HashMap::with_capacity(len); + let mut neighbours = Vec::with_capacity(len); - let mut distances = *_distances; - let mut neighbours = *_neighbours; - - // fill neighbours with -1 values neighbours.extend(0..len); - // init closest neighbour pairwise data - for index_row_i in 0..(max_index) { + // Initialize with max distances + for i in 0..len { distances.insert( - index_row_i, + i, PairwiseDistance { - node: index_row_i, - neighbour: Option::None, + node: i, + neighbour: None, distance: Some(::max_value()), }, ); } - // loop through indeces and neighbours - for index_row_i in 0..(len) { - // start looking for the neighbour in the second element - let mut index_closest = index_row_i + 1; // closest neighbour index - let mut nbd: Option = distances[&index_row_i].distance; // init neighbour distance - for index_row_j in (index_row_i + 1)..len { - distances.insert( - index_row_j, - PairwiseDistance { - node: index_row_j, - neighbour: Some(index_row_i), - distance: nbd, - }, - ); - - let d = Cosine::new().distance( - &Vec::from_iterator( - self.samples.get_row(index_row_i).iterator(0).copied(), - self.samples.shape().1, - ), - &Vec::from_iterator( - self.samples.get_row(index_row_j).iterator(0).copied(), - self.samples.shape().1, - ), - ); - if d < nbd.unwrap().to_f64().unwrap() { - // set this j-value to be the closest neighbour - index_closest = index_row_j; - nbd = Some(T::from(d).unwrap()); + // Compute distances for each point using top-k optimization + for i in 0..len { + let mut candidate_distances = BinaryHeap::new(); + + for j in 0..len { + if i != j { + let distance = T::from(Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(i).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(j).iterator(0).copied(), + self.samples.shape().1, + ), + )).unwrap(); + + // Use OrderedFloat for stable ordering + candidate_distances.push(Reverse((Self::ordered_float(distance), j))); + + if candidate_distances.len() > max_neighbors { + candidate_distances.pop(); + } } } - // Add that edge - distances.entry(index_row_i).and_modify(|e| { - e.distance = nbd; - e.neighbour = Some(index_closest); - }); - } - // No more neighbors, terminate conga line. - // Last person on the line has no neigbors - distances.get_mut(&max_index).unwrap().neighbour = Some(max_index); - distances.get_mut(&(len - 1)).unwrap().distance = Some(::max_value()); - - // compute sparse matrix (connectivity matrix) - let mut sparse_matrix = M::zeros(len, len); - for (_, p) in distances.iter() { - sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap()); + // Find the closest neighbor from candidates + if let Some(Reverse((closest_distance, closest_neighbor))) = + candidate_distances.iter().min_by_key(|Reverse((d, _))| *d) { + distances.entry(i).and_modify(|e| { + e.distance = Some(Self::extract_float(*closest_distance)); + e.neighbour = Some(*closest_neighbor); + }); + } } self.distances = distances; self.neighbours = neighbours; } + /// Fast query using top-k pre-computed neighbors with ordered-float + pub fn query_row_top_k(&self, query_row_index: usize, k: usize) -> Result, Failed> { + if query_row_index >= self.samples.shape().0 { + return Err(Failed::because( + FailedError::FindFailed, + "Query row index out of bounds", + )); + } + + if k == 0 { + return Ok(Vec::new()); + } + + let max_candidates = self.parameters.top_k.unwrap_or(self.samples.shape().0); + let actual_k: usize = k.min(max_candidates); + + // Use binary heap with ordered-float for reliable ordering + let mut heap = BinaryHeap::with_capacity(actual_k + 1); + + let candidates = if let Some(top_k) = self.parameters.top_k { + let step = (self.samples.shape().0 / top_k).max(1); + (0..self.samples.shape().0) + .step_by(step) + .filter(|&i| i != query_row_index) + .take(top_k) + .collect::>() + } else { + (0..self.samples.shape().0) + .filter(|&i| i != query_row_index) + .collect::>() + }; + + for &candidate_idx in &candidates { + let distance = T::from(Cosine::new().distance( + &Vec::from_iterator( + self.samples.get_row(query_row_index).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(candidate_idx).iterator(0).copied(), + self.samples.shape().1, + ), + )).unwrap(); + + heap.push(Reverse((Self::ordered_float(distance), candidate_idx))); + + if heap.len() > actual_k { + heap.pop(); + } + } + + // Convert heap to sorted vector + let mut neighbors: Vec<_> = heap.into_vec() + .into_iter() + .map(|Reverse((dist, idx))| (Self::extract_float(dist), idx)) + .collect(); + + neighbors.sort_by(|a, b| { + Self::ordered_float(a.0).cmp(&Self::ordered_float(b.0)) + }); + + Ok(neighbors) + } + /// Query k nearest neighbors for a row that's already in the dataset pub fn query_row(&self, query_row_index: usize, k: usize) -> Result, Failed> { if query_row_index >= self.samples.shape().0 { @@ -318,7 +408,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> CosinePair<'a, T, M> { mod tests { use super::*; use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; - use approx::assert_relative_eq; + use approx::{relative_eq, assert_relative_eq}; #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), @@ -499,10 +589,6 @@ mod tests { } } - #[cfg_attr( - all(target_arch = "wasm32", not(target_os = "wasi")), - wasm_bindgen_test::wasm_bindgen_test - )] #[test] fn cosine_pair_query_row_bounds_error() { let x = DenseMatrix::::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap(); @@ -520,10 +606,6 @@ mod tests { } } - #[cfg_attr( - all(target_arch = "wasm32", not(target_os = "wasi")), - wasm_bindgen_test::wasm_bindgen_test - )] #[test] fn cosine_pair_query_row_k_zero() { let x = @@ -635,6 +717,166 @@ mod tests { assert!(distance >= 0.0 && distance <= 2.0); } + #[test] + fn query_row_top_k_top_k_limiting() { + // Test that query_row_top_k respects top_k parameter and returns correct results + let x = DenseMatrix::::from_2d_array(&[ + &[1.0, 0.0, 0.0], // Point 0 + &[0.0, 1.0, 0.0], // Point 1 - orthogonal to point 0 + &[0.0, 0.0, 1.0], // Point 2 - orthogonal to point 0 + &[1.0, 1.0, 0.0], // Point 3 - closer to point 0 than points 1,2 + &[0.5, 0.0, 0.0], // Point 4 - very close to point 0 (parallel) + &[2.0, 0.0, 0.0], // Point 5 - very close to point 0 (parallel) + &[0.0, 1.0, 1.0], // Point 6 - far from point 0 + &[3.0, 3.0, 3.0], // Point 7 - moderately close to point 0 + ]).unwrap(); + + // Create CosinePair with top_k=4 to limit candidates + let cosine_pair = CosinePair::with_top_k(&x, 4).unwrap(); + + // Query for 3 nearest neighbors to point 0 + let neighbors = cosine_pair.query_row_top_k(0, 3).unwrap(); + + // Should return exactly 3 neighbors + assert_eq!(neighbors.len(), 3); + + // Verify that distances are in ascending order + for i in 1..neighbors.len() { + assert!(neighbors[i-1].0 <= neighbors[i].0, + "Distances should be in ascending order: {} <= {}", + neighbors[i-1].0, neighbors[i].0); + } + + // All distances should be valid cosine distances (0 to 2) + for (distance, index) in &neighbors { + assert!(*distance >= 0.0 && *distance <= 2.0, + "Cosine distance {} should be between 0 and 2", distance); + assert!(*index < x.shape().0, + "Neighbor index {} should be less than dataset size {}", index, x.shape().0); + assert!(*index != 0, + "Neighbor index should not include query point itself"); + } + + // The closest neighbor should be either point 4 or 5 (parallel vectors) + // These should have cosine distance ≈ 0 + let closest_distance = neighbors[0].0; + assert!(closest_distance < 0.01, + "Closest parallel vector should have distance close to 0, got {}", closest_distance); + + // Verify that we get different results with different top_k values + let cosine_pair_full = CosinePair::new(&x).unwrap(); + let neighbors_full = cosine_pair_full.query_row(0, 3).unwrap(); + + // Results should be the same or very close since we're asking for top 3 + // but the algorithm might find different candidates due to top_k limiting + assert_eq!(neighbors.len(), neighbors_full.len()); + + // The closest neighbor should be the same in both cases + let closest_idx_fast = neighbors[0].1; + let closest_idx_full = neighbors_full[0].1; + let closest_dist_fast = neighbors[0].0; + let closest_dist_full = neighbors_full[0].0; + + // Either we get the same closest neighbor, or distances are very close + if closest_idx_fast == closest_idx_full { + assert!(relative_eq!(closest_dist_fast, closest_dist_full, epsilon = 1e-10)); + } else { + // Different neighbors, but distances should be very close (parallel vectors) + assert!(relative_eq!(closest_dist_fast, closest_dist_full, epsilon = 1e-6)); + } + } + + #[test] + fn query_row_top_k_performance_vs_accuracy() { + // Test that query_row_top_k provides reasonable performance/accuracy tradeoff + // and handles edge cases properly + let large_dataset = DenseMatrix::::from_2d_array(&[ + &[1.0f32, 2.0, 3.0, 4.0], // Point 0 - query point + &[1.1f32, 2.1, 3.1, 4.1], // Point 1 - very close to 0 + &[1.05f32, 2.05, 3.05, 4.05], // Point 2 - very close to 0 + &[2.0f32, 4.0, 6.0, 8.0], // Point 3 - parallel to 0 (2x scaling) + &[0.5f32, 1.0, 1.5, 2.0], // Point 4 - parallel to 0 (0.5x scaling) + &[-1.0f32, -2.0, -3.0, -4.0], // Point 5 - opposite to 0 + &[4.0f32, 3.0, 2.0, 1.0], // Point 6 - different direction + &[0.0f32, 0.0, 0.0, 0.1], // Point 7 - mostly orthogonal + &[10.0f32, 20.0, 30.0, 40.0], // Point 8 - parallel but far + &[1.0f32, 0.0, 0.0, 0.0], // Point 9 - partially similar + &[0.0f32, 2.0, 0.0, 0.0], // Point 10 - partially similar + &[0.0f32, 0.0, 3.0, 0.0], // Point 11 - partially similar + ]).unwrap(); + + // Test with aggressive top_k limiting (only consider 5 out of 11 other points) + let cosine_pair_limited = CosinePair::with_top_k(&large_dataset, 5).unwrap(); + + // Query for 4 nearest neighbors + let neighbors_limited = cosine_pair_limited.query_row_top_k(0, 4).unwrap(); + + // Should return exactly 4 neighbors + assert_eq!(neighbors_limited.len(), 4); + + // Test error handling - out of bounds query + let result_oob = cosine_pair_limited.query_row_top_k(15, 2); + assert!(result_oob.is_err()); + if let Err(e) = result_oob { + assert_eq!(e, Failed::because( + FailedError::FindFailed, + "Query row index out of bounds" + )); + } + + // Test k=0 case + let neighbors_zero = cosine_pair_limited.query_row_top_k(0, 0).unwrap(); + assert_eq!(neighbors_zero.len(), 0); + + // Test k > available candidates + let neighbors_large_k = cosine_pair_limited.query_row_top_k(0, 20).unwrap(); + assert!(neighbors_large_k.len() <= 11); // At most 11 other points + + // Verify ordering is correct + for i in 1..neighbors_limited.len() { + assert!(neighbors_limited[i-1].0 <= neighbors_limited[i].0, + "Distance ordering violation at position {}: {} > {}", + i, neighbors_limited[i-1].0, neighbors_limited[i].0); + } + + // The closest neighbors should be the parallel vectors (points 1, 2, 3, 4) + // since they have the smallest cosine distances + let closest_distance = neighbors_limited[0].0; + assert!(closest_distance < 0.1, + "Closest neighbor should be nearly parallel, distance: {}", closest_distance); + + // Compare with full algorithm for accuracy assessment + let cosine_pair_full = CosinePair::new(&large_dataset).unwrap(); + let neighbors_full = cosine_pair_full.query_row(0, 4).unwrap(); + + // The fast version might not find the exact same neighbors due to sampling, + // but the closest neighbor's distance should be very similar + let dist_diff = (neighbors_limited[0].0 - neighbors_full[0].0).abs(); + assert!(dist_diff < 0.01, + "Fast and full algorithms should give similar closest distances. Diff: {}", dist_diff); + + // Verify that all returned indices are valid and unique + let mut indices: Vec = neighbors_limited.iter().map(|(_, idx)| *idx).collect(); + indices.sort(); + indices.dedup(); + assert_eq!(indices.len(), neighbors_limited.len(), + "All neighbor indices should be unique"); + + for &idx in &indices { + assert!(idx < large_dataset.shape().0, + "Neighbor index {} should be valid", idx); + assert!(idx != 0, + "Neighbor should not include query point itself"); + } + + // Test with f32 precision to ensure type compatibility + for (distance, _) in &neighbors_limited { + assert!(!distance.is_nan(), "Distance should not be NaN"); + assert!(distance.is_finite(), "Distance should be finite"); + assert!(*distance >= 0.0, "Distance should be non-negative"); + } + } + #[test] fn cosine_pair_float_precision() { // Test with f32 precision From 1d557994881218785a99230e296270423d40f787 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Tue, 30 Sep 2025 17:44:32 +0100 Subject: [PATCH 06/10] fix distance computation --- src/metrics/distance/cosine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metrics/distance/cosine.rs b/src/metrics/distance/cosine.rs index ea065a07..dd9c600d 100644 --- a/src/metrics/distance/cosine.rs +++ b/src/metrics/distance/cosine.rs @@ -92,7 +92,7 @@ impl Cosine { let magnitude_y = Self::magnitude(y); if magnitude_x == 0.0 || magnitude_y == 0.0 { - panic!("Cannot compute cosine distance for zero-magnitude vectors."); + return f64::MAX; } dot_product / (magnitude_x * magnitude_y) From cfb2c6a49e544d4d003bb7901f7daa473af74053 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Thu, 2 Oct 2025 15:26:24 +0100 Subject: [PATCH 07/10] set min similarity for constant zeros --- src/metrics/distance/cosine.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/metrics/distance/cosine.rs b/src/metrics/distance/cosine.rs index ea065a07..8c7a2c00 100644 --- a/src/metrics/distance/cosine.rs +++ b/src/metrics/distance/cosine.rs @@ -92,7 +92,7 @@ impl Cosine { let magnitude_y = Self::magnitude(y); if magnitude_x == 0.0 || magnitude_y == 0.0 { - panic!("Cannot compute cosine distance for zero-magnitude vectors."); + return f64::MIN; } dot_product / (magnitude_x * magnitude_y) @@ -188,12 +188,12 @@ mod tests { wasm_bindgen_test::wasm_bindgen_test )] #[test] - #[should_panic(expected = "Cannot compute cosine distance for zero-magnitude vectors.")] fn cosine_distance_zero_vector() { let a = vec![0, 0, 0]; let b = vec![1, 2, 3]; - let _dist: f64 = Cosine::new().distance(&a, &b); + let dist: f64 = Cosine::new().distance(&a, &b); + assert!(dist > 1e300) } #[cfg_attr( From a3dbe784881f3684468a62283e5e17c1132c7eb0 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Thu, 9 Oct 2025 17:10:14 +0100 Subject: [PATCH 08/10] bump version to 0.4.5 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f1ffea75..c05680fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "smartcore" description = "Machine Learning in Rust." homepage = "https://smartcorelib.org" -version = "0.4.3" +version = "0.4.5" authors = ["smartcore Developers"] edition = "2021" license = "Apache-2.0" From 320d083d2396e9ab9f6fcec81be4b62b362efce3 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Thu, 9 Oct 2025 17:19:24 +0100 Subject: [PATCH 09/10] apply formatting --- src/algorithm/neighbour/cosinepair.rs | 212 ++++++++++++++++---------- 1 file changed, 130 insertions(+), 82 deletions(-) diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs index b2abbdb8..d5dcdc8a 100644 --- a/src/algorithm/neighbour/cosinepair.rs +++ b/src/algorithm/neighbour/cosinepair.rs @@ -25,8 +25,8 @@ /// use ordered_float::{FloatCore, OrderedFloat}; -use std::collections::{BinaryHeap, HashMap}; use std::cmp::Reverse; +use std::collections::{BinaryHeap, HashMap}; use num::Bounded; @@ -37,7 +37,6 @@ use crate::metrics::distance::{Distance, PairwiseDistance}; use crate::numbers::floatnum::FloatNumber; use crate::numbers::realnum::RealNumber; - /// Parameters for CosinePair construction #[derive(Debug, Clone)] pub struct CosinePairParameters { @@ -83,10 +82,13 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T /// Constructor with top-k limiting for faster performance pub fn with_top_k(m: &'a M, top_k: usize) -> Result { - Self::with_parameters(m, CosinePairParameters { - top_k: Some(top_k), - approximate: false, - }) + Self::with_parameters( + m, + CosinePairParameters { + top_k: Some(top_k), + approximate: false, + }, + ) } /// Constructor with full parameter control @@ -109,7 +111,7 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T } /// Helper function to create ordered float wrapper - fn ordered_float(value: T) -> OrderedFloat { + fn ordered_float(value: T) -> OrderedFloat { return OrderedFloat(value); } @@ -155,7 +157,8 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T self.samples.get_row(j).iterator(0).copied(), self.samples.shape().1, ), - )).unwrap(); + )) + .unwrap(); // Use OrderedFloat for stable ordering candidate_distances.push(Reverse((Self::ordered_float(distance), j))); @@ -167,8 +170,9 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T } // Find the closest neighbor from candidates - if let Some(Reverse((closest_distance, closest_neighbor))) = - candidate_distances.iter().min_by_key(|Reverse((d, _))| *d) { + if let Some(Reverse((closest_distance, closest_neighbor))) = + candidate_distances.iter().min_by_key(|Reverse((d, _))| *d) + { distances.entry(i).and_modify(|e| { e.distance = Some(Self::extract_float(*closest_distance)); e.neighbour = Some(*closest_neighbor); @@ -181,7 +185,11 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T } /// Fast query using top-k pre-computed neighbors with ordered-float - pub fn query_row_top_k(&self, query_row_index: usize, k: usize) -> Result, Failed> { + pub fn query_row_top_k( + &self, + query_row_index: usize, + k: usize, + ) -> Result, Failed> { if query_row_index >= self.samples.shape().0 { return Err(Failed::because( FailedError::FindFailed, @@ -198,7 +206,7 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T // Use binary heap with ordered-float for reliable ordering let mut heap = BinaryHeap::with_capacity(actual_k + 1); - + let candidates = if let Some(top_k) = self.parameters.top_k { let step = (self.samples.shape().0 / top_k).max(1); (0..self.samples.shape().0) @@ -222,25 +230,25 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T self.samples.get_row(candidate_idx).iterator(0).copied(), self.samples.shape().1, ), - )).unwrap(); + )) + .unwrap(); heap.push(Reverse((Self::ordered_float(distance), candidate_idx))); - + if heap.len() > actual_k { heap.pop(); } } // Convert heap to sorted vector - let mut neighbors: Vec<_> = heap.into_vec() + let mut neighbors: Vec<_> = heap + .into_vec() .into_iter() .map(|Reverse((dist, idx))| (Self::extract_float(dist), idx)) .collect(); - - neighbors.sort_by(|a, b| { - Self::ordered_float(a.0).cmp(&Self::ordered_float(b.0)) - }); - + + neighbors.sort_by(|a, b| Self::ordered_float(a.0).cmp(&Self::ordered_float(b.0))); + Ok(neighbors) } @@ -408,7 +416,7 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T mod tests { use super::*; use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; - use approx::{relative_eq, assert_relative_eq}; + use approx::{assert_relative_eq, relative_eq}; #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), @@ -721,68 +729,92 @@ mod tests { fn query_row_top_k_top_k_limiting() { // Test that query_row_top_k respects top_k parameter and returns correct results let x = DenseMatrix::::from_2d_array(&[ - &[1.0, 0.0, 0.0], // Point 0 - &[0.0, 1.0, 0.0], // Point 1 - orthogonal to point 0 - &[0.0, 0.0, 1.0], // Point 2 - orthogonal to point 0 - &[1.0, 1.0, 0.0], // Point 3 - closer to point 0 than points 1,2 - &[0.5, 0.0, 0.0], // Point 4 - very close to point 0 (parallel) - &[2.0, 0.0, 0.0], // Point 5 - very close to point 0 (parallel) - &[0.0, 1.0, 1.0], // Point 6 - far from point 0 - &[3.0, 3.0, 3.0], // Point 7 - moderately close to point 0 - ]).unwrap(); + &[1.0, 0.0, 0.0], // Point 0 + &[0.0, 1.0, 0.0], // Point 1 - orthogonal to point 0 + &[0.0, 0.0, 1.0], // Point 2 - orthogonal to point 0 + &[1.0, 1.0, 0.0], // Point 3 - closer to point 0 than points 1,2 + &[0.5, 0.0, 0.0], // Point 4 - very close to point 0 (parallel) + &[2.0, 0.0, 0.0], // Point 5 - very close to point 0 (parallel) + &[0.0, 1.0, 1.0], // Point 6 - far from point 0 + &[3.0, 3.0, 3.0], // Point 7 - moderately close to point 0 + ]) + .unwrap(); // Create CosinePair with top_k=4 to limit candidates let cosine_pair = CosinePair::with_top_k(&x, 4).unwrap(); - + // Query for 3 nearest neighbors to point 0 let neighbors = cosine_pair.query_row_top_k(0, 3).unwrap(); - + // Should return exactly 3 neighbors assert_eq!(neighbors.len(), 3); - + // Verify that distances are in ascending order for i in 1..neighbors.len() { - assert!(neighbors[i-1].0 <= neighbors[i].0, - "Distances should be in ascending order: {} <= {}", - neighbors[i-1].0, neighbors[i].0); + assert!( + neighbors[i - 1].0 <= neighbors[i].0, + "Distances should be in ascending order: {} <= {}", + neighbors[i - 1].0, + neighbors[i].0 + ); } - + // All distances should be valid cosine distances (0 to 2) for (distance, index) in &neighbors { - assert!(*distance >= 0.0 && *distance <= 2.0, - "Cosine distance {} should be between 0 and 2", distance); - assert!(*index < x.shape().0, - "Neighbor index {} should be less than dataset size {}", index, x.shape().0); - assert!(*index != 0, - "Neighbor index should not include query point itself"); + assert!( + *distance >= 0.0 && *distance <= 2.0, + "Cosine distance {} should be between 0 and 2", + distance + ); + assert!( + *index < x.shape().0, + "Neighbor index {} should be less than dataset size {}", + index, + x.shape().0 + ); + assert!( + *index != 0, + "Neighbor index should not include query point itself" + ); } - + // The closest neighbor should be either point 4 or 5 (parallel vectors) // These should have cosine distance ≈ 0 let closest_distance = neighbors[0].0; - assert!(closest_distance < 0.01, - "Closest parallel vector should have distance close to 0, got {}", closest_distance); - + assert!( + closest_distance < 0.01, + "Closest parallel vector should have distance close to 0, got {}", + closest_distance + ); + // Verify that we get different results with different top_k values let cosine_pair_full = CosinePair::new(&x).unwrap(); let neighbors_full = cosine_pair_full.query_row(0, 3).unwrap(); - + // Results should be the same or very close since we're asking for top 3 // but the algorithm might find different candidates due to top_k limiting assert_eq!(neighbors.len(), neighbors_full.len()); - + // The closest neighbor should be the same in both cases let closest_idx_fast = neighbors[0].1; let closest_idx_full = neighbors_full[0].1; let closest_dist_fast = neighbors[0].0; let closest_dist_full = neighbors_full[0].0; - + // Either we get the same closest neighbor, or distances are very close if closest_idx_fast == closest_idx_full { - assert!(relative_eq!(closest_dist_fast, closest_dist_full, epsilon = 1e-10)); + assert!(relative_eq!( + closest_dist_fast, + closest_dist_full, + epsilon = 1e-10 + )); } else { // Different neighbors, but distances should be very close (parallel vectors) - assert!(relative_eq!(closest_dist_fast, closest_dist_full, epsilon = 1e-6)); + assert!(relative_eq!( + closest_dist_fast, + closest_dist_full, + epsilon = 1e-6 + )); } } @@ -803,72 +835,88 @@ mod tests { &[1.0f32, 0.0, 0.0, 0.0], // Point 9 - partially similar &[0.0f32, 2.0, 0.0, 0.0], // Point 10 - partially similar &[0.0f32, 0.0, 3.0, 0.0], // Point 11 - partially similar - ]).unwrap(); + ]) + .unwrap(); // Test with aggressive top_k limiting (only consider 5 out of 11 other points) let cosine_pair_limited = CosinePair::with_top_k(&large_dataset, 5).unwrap(); - + // Query for 4 nearest neighbors let neighbors_limited = cosine_pair_limited.query_row_top_k(0, 4).unwrap(); - + // Should return exactly 4 neighbors assert_eq!(neighbors_limited.len(), 4); - + // Test error handling - out of bounds query let result_oob = cosine_pair_limited.query_row_top_k(15, 2); assert!(result_oob.is_err()); if let Err(e) = result_oob { - assert_eq!(e, Failed::because( - FailedError::FindFailed, - "Query row index out of bounds" - )); + assert_eq!( + e, + Failed::because(FailedError::FindFailed, "Query row index out of bounds") + ); } - + // Test k=0 case let neighbors_zero = cosine_pair_limited.query_row_top_k(0, 0).unwrap(); assert_eq!(neighbors_zero.len(), 0); - + // Test k > available candidates let neighbors_large_k = cosine_pair_limited.query_row_top_k(0, 20).unwrap(); assert!(neighbors_large_k.len() <= 11); // At most 11 other points - + // Verify ordering is correct for i in 1..neighbors_limited.len() { - assert!(neighbors_limited[i-1].0 <= neighbors_limited[i].0, - "Distance ordering violation at position {}: {} > {}", - i, neighbors_limited[i-1].0, neighbors_limited[i].0); + assert!( + neighbors_limited[i - 1].0 <= neighbors_limited[i].0, + "Distance ordering violation at position {}: {} > {}", + i, + neighbors_limited[i - 1].0, + neighbors_limited[i].0 + ); } - + // The closest neighbors should be the parallel vectors (points 1, 2, 3, 4) // since they have the smallest cosine distances let closest_distance = neighbors_limited[0].0; - assert!(closest_distance < 0.1, - "Closest neighbor should be nearly parallel, distance: {}", closest_distance); - + assert!( + closest_distance < 0.1, + "Closest neighbor should be nearly parallel, distance: {}", + closest_distance + ); + // Compare with full algorithm for accuracy assessment let cosine_pair_full = CosinePair::new(&large_dataset).unwrap(); let neighbors_full = cosine_pair_full.query_row(0, 4).unwrap(); - + // The fast version might not find the exact same neighbors due to sampling, // but the closest neighbor's distance should be very similar let dist_diff = (neighbors_limited[0].0 - neighbors_full[0].0).abs(); - assert!(dist_diff < 0.01, - "Fast and full algorithms should give similar closest distances. Diff: {}", dist_diff); - + assert!( + dist_diff < 0.01, + "Fast and full algorithms should give similar closest distances. Diff: {}", + dist_diff + ); + // Verify that all returned indices are valid and unique let mut indices: Vec = neighbors_limited.iter().map(|(_, idx)| *idx).collect(); indices.sort(); indices.dedup(); - assert_eq!(indices.len(), neighbors_limited.len(), - "All neighbor indices should be unique"); - + assert_eq!( + indices.len(), + neighbors_limited.len(), + "All neighbor indices should be unique" + ); + for &idx in &indices { - assert!(idx < large_dataset.shape().0, - "Neighbor index {} should be valid", idx); - assert!(idx != 0, - "Neighbor should not include query point itself"); + assert!( + idx < large_dataset.shape().0, + "Neighbor index {} should be valid", + idx + ); + assert!(idx != 0, "Neighbor should not include query point itself"); } - + // Test with f32 precision to ensure type compatibility for (distance, _) in &neighbors_limited { assert!(!distance.is_nan(), "Distance should not be NaN"); From 90a979c4a72a2b8d48841c99e163af5e5644217d Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Thu, 9 Oct 2025 17:24:05 +0100 Subject: [PATCH 10/10] fix clippy --- src/algorithm/neighbour/cosinepair.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/algorithm/neighbour/cosinepair.rs b/src/algorithm/neighbour/cosinepair.rs index d5dcdc8a..889be689 100644 --- a/src/algorithm/neighbour/cosinepair.rs +++ b/src/algorithm/neighbour/cosinepair.rs @@ -46,6 +46,7 @@ pub struct CosinePairParameters { pub approximate: bool, } +#[allow(clippy::derivable_impls)] impl Default for CosinePairParameters { fn default() -> Self { Self { @@ -112,12 +113,12 @@ impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2> CosinePair<'a, T /// Helper function to create ordered float wrapper fn ordered_float(value: T) -> OrderedFloat { - return OrderedFloat(value); + OrderedFloat(value) } /// Helper function to extract value from ordered float wrapper fn extract_float(ordered: OrderedFloat) -> T { - return ordered.into_inner(); + ordered.into_inner() } /// Optimized initialization with top-k neighbor limiting