From 960824ededbec0fcace0843c89331b07049f8332 Mon Sep 17 00:00:00 2001 From: Charlie Martin Date: Fri, 21 Nov 2025 22:23:00 -0500 Subject: [PATCH 1/2] fix precision and recall calculations --- src/metrics/precision.rs | 109 +++++++++++++++++++++++++++++++-------- src/metrics/recall.rs | 85 ++++++++++++++++++++++-------- 2 files changed, 152 insertions(+), 42 deletions(-) diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs index dd09740c..bf00bdb7 100644 --- a/src/metrics/precision.rs +++ b/src/metrics/precision.rs @@ -4,7 +4,9 @@ //! //! \\[precision = \frac{tp}{tp + fp}\\] //! -//! where tp (true positive) - correct result, fp (false positive) - unexpected result +//! where tp (true positive) - correct result, fp (false positive) - unexpected result. +//! For binary classification, this is precision for the positive class (assumed to be 1.0). +//! For multiclass, this is macro-averaged precision (average of per-class precisions). //! //! Example: //! @@ -19,7 +21,8 @@ //! //! //! -use std::collections::HashSet; + +use std::collections::{HashMap, HashSet}; use std::marker::PhantomData; #[cfg(feature = "serde")] @@ -61,33 +64,65 @@ impl Metrics for Precision { ); } - let mut classes = HashSet::new(); - for i in 0..y_true.shape() { - classes.insert(y_true.get(i).to_f64_bits()); + let n = y_true.shape(); + + let mut classes_set: HashSet = HashSet::new(); + for i in 0..n { + classes_set.insert(y_true.get(i).to_f64_bits()); } - let classes = classes.len(); - - let mut tp = 0; - let mut fp = 0; - for i in 0..y_true.shape() { - if y_pred.get(i) == y_true.get(i) { - if classes == 2 { - if *y_true.get(i) == T::one() { + let classes: usize = classes_set.len(); + + if classes == 2 { + // Binary case: precision for positive class (assumed T::one()) + let positive = T::one(); + let mut tp: usize = 0; + let mut fp_count: usize = 0; + for i in 0..n { + let t = *y_true.get(i); + let p = *y_pred.get(i); + if p == t { + if t == positive { tp += 1; } } else { - tp += 1; + if t != positive { + fp_count += 1; + } } - } else if classes == 2 { - if *y_true.get(i) == T::one() { - fp += 1; + } + if tp + fp_count == 0 { + 0.0 + } else { + tp as f64 / (tp + fp_count) as f64 + } + } else { + // Multiclass case: macro-averaged precision + let mut predicted: HashMap = HashMap::new(); + let mut tp_map: HashMap = HashMap::new(); + for i in 0..n { + let p_bits = y_pred.get(i).to_f64_bits(); + *predicted.entry(p_bits).or_insert(0) += 1; + if *y_true.get(i) == *y_pred.get(i) { + *tp_map.entry(p_bits).or_insert(0) += 1; } + } + let mut precision_sum = 0.0; + for &bits in &classes_set { + let pred_count = *predicted.get(&bits).unwrap_or(&0); + let tp = *tp_map.get(&bits).unwrap_or(&0); + let prec = if pred_count > 0 { + tp as f64 / pred_count as f64 + } else { + 0.0 + }; + precision_sum += prec; + } + if classes == 0 { + 0.0 } else { - fp += 1; + precision_sum / classes as f64 } } - - tp as f64 / (tp as f64 + fp as f64) } } @@ -114,7 +149,7 @@ mod tests { let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; let score3: f64 = Precision::new().get_score(&y_true, &y_pred); - assert!((score3 - 0.6666666666).abs() < 1e-8); + assert!((score3 - 0.5).abs() < 1e-8); } #[cfg_attr( @@ -132,4 +167,36 @@ mod tests { assert!((score1 - 0.333333333).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8); } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn precision_multiclass_imbalanced() { + let y_true: Vec = vec![0., 0., 1., 2., 2., 2.]; + let y_pred: Vec = vec![0., 1., 1., 2., 0., 2.]; + + let score: f64 = Precision::new().get_score(&y_true, &y_pred); + let expected = (0.5 + 0.5 + 1.0) / 3.0; + assert!((score - expected).abs() < 1e-8); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn precision_multiclass_unpredicted_class() { + let y_true: Vec = vec![0., 0., 1., 2., 2., 2., 3.]; + let y_pred: Vec = vec![0., 1., 1., 2., 0., 2., 0.]; + + let score: f64 = Precision::new().get_score(&y_true, &y_pred); + // Class 0: pred=3, tp=1 -> 1/3 ≈0.333 + // Class 1: pred=2, tp=1 -> 0.5 + // Class 2: pred=2, tp=2 -> 1.0 + // Class 3: pred=0, tp=0 -> 0.0 + let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0; + assert!((score - expected).abs() < 1e-8); + } } diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs index ab76d972..94850f13 100644 --- a/src/metrics/recall.rs +++ b/src/metrics/recall.rs @@ -4,7 +4,9 @@ //! //! \\[recall = \frac{tp}{tp + fn}\\] //! -//! where tp (true positive) - correct result, fn (false negative) - missing result +//! where tp (true positive) - correct result, fn (false negative) - missing result. +//! For binary classification, this is recall for the positive class (assumed to be 1.0). +//! For multiclass, this is macro-averaged recall (average of per-class recalls). //! //! Example: //! @@ -20,7 +22,7 @@ //! //! -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::convert::TryInto; use std::marker::PhantomData; @@ -52,7 +54,7 @@ impl Metrics for Recall { } } /// Calculated recall score - /// * `y_true` - cround truth (correct) labels. + /// * `y_true` - ground truth (correct) labels. /// * `y_pred` - predicted labels, as returned by a classifier. fn get_score(&self, y_true: &dyn ArrayView1, y_pred: &dyn ArrayView1) -> f64 { if y_true.shape() != y_pred.shape() { @@ -63,32 +65,59 @@ impl Metrics for Recall { ); } - let mut classes = HashSet::new(); - for i in 0..y_true.shape() { - classes.insert(y_true.get(i).to_f64_bits()); + let n = y_true.shape(); + + let mut classes_set = HashSet::new(); + for i in 0..n { + classes_set.insert(y_true.get(i).to_f64_bits()); } - let classes: i64 = classes.len().try_into().unwrap(); - - let mut tp = 0; - let mut fne = 0; - for i in 0..y_true.shape() { - if y_pred.get(i) == y_true.get(i) { - if classes == 2 { - if *y_true.get(i) == T::one() { + let classes: usize = classes_set.len(); + + if classes == 2 { + // Binary case: recall for positive class (assumed T::one()) + let positive = T::one(); + let mut tp: usize = 0; + let mut fn_count: usize = 0; + for i in 0..n { + let t = *y_true.get(i); + let p = *y_pred.get(i); + if p == t { + if t == positive { tp += 1; } } else { - tp += 1; + if t == positive { + fn_count += 1; + } } - } else if classes == 2 { - if *y_true.get(i) != T::one() { - fne += 1; + } + if tp + fn_count == 0 { + 0.0 + } else { + tp as f64 / (tp + fn_count) as f64 + } + } else { + // Multiclass case: macro-averaged recall + let mut support: HashMap = HashMap::new(); + let mut tp_map: HashMap = HashMap::new(); + for i in 0..n { + let t_bits = y_true.get(i).to_f64_bits(); + *support.entry(t_bits).or_insert(0) += 1; + if *y_true.get(i) == *y_pred.get(i) { + *tp_map.entry(t_bits).or_insert(0) += 1; } + } + let mut recall_sum = 0.0; + for (&bits, &sup) in &support { + let tp = *tp_map.get(&bits).unwrap_or(&0); + recall_sum += tp as f64 / sup as f64; + } + if support.is_empty() { + 0.0 } else { - fne += 1; + recall_sum / support.len() as f64 } } - tp as f64 / (tp as f64 + fne as f64) } } @@ -115,7 +144,7 @@ mod tests { let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; let score3: f64 = Recall::new().get_score(&y_true, &y_pred); - assert!((score3 - 0.5).abs() < 1e-8); + assert!((score3 - (2.0 / 3.0)).abs() < 1e-8); } #[cfg_attr( @@ -133,4 +162,18 @@ mod tests { assert!((score1 - 0.333333333).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8); } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn recall_multiclass_imbalanced() { + let y_true: Vec = vec![0., 0., 1., 2., 2., 2.]; + let y_pred: Vec = vec![0., 1., 1., 2., 0., 2.]; + + let score: f64 = Recall::new().get_score(&y_true, &y_pred); + let expected = (0.5 + 1.0 + (2.0 / 3.0)) / 3.0; + assert!((score - expected).abs() < 1e-8); + } } From 6db62f20649b0cc23dcbe6ee16e0fb505bb3b09f Mon Sep 17 00:00:00 2001 From: Charlie Martin Date: Fri, 21 Nov 2025 22:34:51 -0500 Subject: [PATCH 2/2] clippy --- src/metrics/precision.rs | 6 ++---- src/metrics/recall.rs | 7 ++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs index bf00bdb7..84444b6b 100644 --- a/src/metrics/precision.rs +++ b/src/metrics/precision.rs @@ -84,10 +84,8 @@ impl Metrics for Precision { if t == positive { tp += 1; } - } else { - if t != positive { - fp_count += 1; - } + } else if t != positive { + fp_count += 1; } } if tp + fp_count == 0 { diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs index 94850f13..e7418511 100644 --- a/src/metrics/recall.rs +++ b/src/metrics/recall.rs @@ -23,7 +23,6 @@ //! use std::collections::{HashMap, HashSet}; -use std::convert::TryInto; use std::marker::PhantomData; #[cfg(feature = "serde")] @@ -85,10 +84,8 @@ impl Metrics for Recall { if t == positive { tp += 1; } - } else { - if t == positive { - fn_count += 1; - } + } else if t == positive { + fn_count += 1; } } if tp + fn_count == 0 {