Skip to content

Commit 70d8a0f

Browse files
authored
fix precision and recall calculations (#338)
* fix precision and recall calculations
1 parent 0e42a97 commit 70d8a0f

File tree

2 files changed

+150
-45
lines changed

2 files changed

+150
-45
lines changed

src/metrics/precision.rs

Lines changed: 87 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,9 @@
44
//!
55
//! \\[precision = \frac{tp}{tp + fp}\\]
66
//!
7-
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
7+
//! where tp (true positive) - correct result, fp (false positive) - unexpected result.
8+
//! For binary classification, this is precision for the positive class (assumed to be 1.0).
9+
//! For multiclass, this is macro-averaged precision (average of per-class precisions).
810
//!
911
//! Example:
1012
//!
@@ -19,7 +21,8 @@
1921
//!
2022
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
2123
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
22-
use std::collections::HashSet;
24+
25+
use std::collections::{HashMap, HashSet};
2326
use std::marker::PhantomData;
2427

2528
#[cfg(feature = "serde")]
@@ -61,33 +64,63 @@ impl<T: RealNumber> Metrics<T> for Precision<T> {
6164
);
6265
}
6366

64-
let mut classes = HashSet::new();
65-
for i in 0..y_true.shape() {
66-
classes.insert(y_true.get(i).to_f64_bits());
67+
let n = y_true.shape();
68+
69+
let mut classes_set: HashSet<u64> = HashSet::new();
70+
for i in 0..n {
71+
classes_set.insert(y_true.get(i).to_f64_bits());
6772
}
68-
let classes = classes.len();
69-
70-
let mut tp = 0;
71-
let mut fp = 0;
72-
for i in 0..y_true.shape() {
73-
if y_pred.get(i) == y_true.get(i) {
74-
if classes == 2 {
75-
if *y_true.get(i) == T::one() {
73+
let classes: usize = classes_set.len();
74+
75+
if classes == 2 {
76+
// Binary case: precision for positive class (assumed T::one())
77+
let positive = T::one();
78+
let mut tp: usize = 0;
79+
let mut fp_count: usize = 0;
80+
for i in 0..n {
81+
let t = *y_true.get(i);
82+
let p = *y_pred.get(i);
83+
if p == t {
84+
if t == positive {
7685
tp += 1;
7786
}
78-
} else {
79-
tp += 1;
87+
} else if t != positive {
88+
fp_count += 1;
8089
}
81-
} else if classes == 2 {
82-
if *y_true.get(i) == T::one() {
83-
fp += 1;
90+
}
91+
if tp + fp_count == 0 {
92+
0.0
93+
} else {
94+
tp as f64 / (tp + fp_count) as f64
95+
}
96+
} else {
97+
// Multiclass case: macro-averaged precision
98+
let mut predicted: HashMap<u64, usize> = HashMap::new();
99+
let mut tp_map: HashMap<u64, usize> = HashMap::new();
100+
for i in 0..n {
101+
let p_bits = y_pred.get(i).to_f64_bits();
102+
*predicted.entry(p_bits).or_insert(0) += 1;
103+
if *y_true.get(i) == *y_pred.get(i) {
104+
*tp_map.entry(p_bits).or_insert(0) += 1;
84105
}
106+
}
107+
let mut precision_sum = 0.0;
108+
for &bits in &classes_set {
109+
let pred_count = *predicted.get(&bits).unwrap_or(&0);
110+
let tp = *tp_map.get(&bits).unwrap_or(&0);
111+
let prec = if pred_count > 0 {
112+
tp as f64 / pred_count as f64
113+
} else {
114+
0.0
115+
};
116+
precision_sum += prec;
117+
}
118+
if classes == 0 {
119+
0.0
85120
} else {
86-
fp += 1;
121+
precision_sum / classes as f64
87122
}
88123
}
89-
90-
tp as f64 / (tp as f64 + fp as f64)
91124
}
92125
}
93126

@@ -114,7 +147,7 @@ mod tests {
114147
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
115148

116149
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
117-
assert!((score3 - 0.6666666666).abs() < 1e-8);
150+
assert!((score3 - 0.5).abs() < 1e-8);
118151
}
119152

120153
#[cfg_attr(
@@ -132,4 +165,36 @@ mod tests {
132165
assert!((score1 - 0.333333333).abs() < 1e-8);
133166
assert!((score2 - 1.0).abs() < 1e-8);
134167
}
168+
169+
#[cfg_attr(
170+
all(target_arch = "wasm32", not(target_os = "wasi")),
171+
wasm_bindgen_test::wasm_bindgen_test
172+
)]
173+
#[test]
174+
fn precision_multiclass_imbalanced() {
175+
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
176+
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
177+
178+
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
179+
let expected = (0.5 + 0.5 + 1.0) / 3.0;
180+
assert!((score - expected).abs() < 1e-8);
181+
}
182+
183+
#[cfg_attr(
184+
all(target_arch = "wasm32", not(target_os = "wasi")),
185+
wasm_bindgen_test::wasm_bindgen_test
186+
)]
187+
#[test]
188+
fn precision_multiclass_unpredicted_class() {
189+
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2., 3.];
190+
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2., 0.];
191+
192+
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
193+
// Class 0: pred=3, tp=1 -> 1/3 ≈0.333
194+
// Class 1: pred=2, tp=1 -> 0.5
195+
// Class 2: pred=2, tp=2 -> 1.0
196+
// Class 3: pred=0, tp=0 -> 0.0
197+
let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0;
198+
assert!((score - expected).abs() < 1e-8);
199+
}
135200
}

src/metrics/recall.rs

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ (original file line number | diff line number | diff line change)
44
//!
55
//! \\[recall = \frac{tp}{tp + fn}\\]
66
//!
7-
//! where tp (true positive) - correct result, fn (false negative) - missing result
7+
//! where tp (true positive) - correct result, fn (false negative) - missing result.
8+
//! For binary classification, this is recall for the positive class (assumed to be 1.0).
9+
//! For multiclass, this is macro-averaged recall (average of per-class recalls).
810
//!
911
//! Example:
1012
//!
@@ -20,8 +22,7 @@
2022
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
2123
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
2224
23-
use std::collections::HashSet;
24-
use std::convert::TryInto;
25+
use std::collections::{HashMap, HashSet};
2526
use std::marker::PhantomData;
2627

2728
#[cfg(feature = "serde")]
@@ -52,7 +53,7 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
5253
}
5354
}
5455
/// Calculated recall score
55-
/// * `y_true` - cround truth (correct) labels.
56+
/// * `y_true` - ground truth (correct) labels.
5657
/// * `y_pred` - predicted labels, as returned by a classifier.
5758
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
5859
if y_true.shape() != y_pred.shape() {
@@ -63,32 +64,57 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
6364
);
6465
}
6566

66-
let mut classes = HashSet::new();
67-
for i in 0..y_true.shape() {
68-
classes.insert(y_true.get(i).to_f64_bits());
67+
let n = y_true.shape();
68+
69+
let mut classes_set = HashSet::new();
70+
for i in 0..n {
71+
classes_set.insert(y_true.get(i).to_f64_bits());
6972
}
70-
let classes: i64 = classes.len().try_into().unwrap();
71-
72-
let mut tp = 0;
73-
let mut fne = 0;
74-
for i in 0..y_true.shape() {
75-
if y_pred.get(i) == y_true.get(i) {
76-
if classes == 2 {
77-
if *y_true.get(i) == T::one() {
73+
let classes: usize = classes_set.len();
74+
75+
if classes == 2 {
76+
// Binary case: recall for positive class (assumed T::one())
77+
let positive = T::one();
78+
let mut tp: usize = 0;
79+
let mut fn_count: usize = 0;
80+
for i in 0..n {
81+
let t = *y_true.get(i);
82+
let p = *y_pred.get(i);
83+
if p == t {
84+
if t == positive {
7885
tp += 1;
7986
}
80-
} else {
81-
tp += 1;
87+
} else if t == positive {
88+
fn_count += 1;
8289
}
83-
} else if classes == 2 {
84-
if *y_true.get(i) != T::one() {
85-
fne += 1;
90+
}
91+
if tp + fn_count == 0 {
92+
0.0
93+
} else {
94+
tp as f64 / (tp + fn_count) as f64
95+
}
96+
} else {
97+
// Multiclass case: macro-averaged recall
98+
let mut support: HashMap<u64, usize> = HashMap::new();
99+
let mut tp_map: HashMap<u64, usize> = HashMap::new();
100+
for i in 0..n {
101+
let t_bits = y_true.get(i).to_f64_bits();
102+
*support.entry(t_bits).or_insert(0) += 1;
103+
if *y_true.get(i) == *y_pred.get(i) {
104+
*tp_map.entry(t_bits).or_insert(0) += 1;
86105
}
106+
}
107+
let mut recall_sum = 0.0;
108+
for (&bits, &sup) in &support {
109+
let tp = *tp_map.get(&bits).unwrap_or(&0);
110+
recall_sum += tp as f64 / sup as f64;
111+
}
112+
if support.is_empty() {
113+
0.0
87114
} else {
88-
fne += 1;
115+
recall_sum / support.len() as f64
89116
}
90117
}
91-
tp as f64 / (tp as f64 + fne as f64)
92118
}
93119
}
94120

@@ -115,7 +141,7 @@ mod tests {
115141
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
116142

117143
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
118-
assert!((score3 - 0.5).abs() < 1e-8);
144+
assert!((score3 - (2.0 / 3.0)).abs() < 1e-8);
119145
}
120146

121147
#[cfg_attr(
@@ -133,4 +159,18 @@ mod tests {
133159
assert!((score1 - 0.333333333).abs() < 1e-8);
134160
assert!((score2 - 1.0).abs() < 1e-8);
135161
}
162+
163+
#[cfg_attr(
164+
all(target_arch = "wasm32", not(target_os = "wasi")),
165+
wasm_bindgen_test::wasm_bindgen_test
166+
)]
167+
#[test]
168+
fn recall_multiclass_imbalanced() {
169+
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
170+
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
171+
172+
let score: f64 = Recall::new().get_score(&y_true, &y_pred);
173+
let expected = (0.5 + 1.0 + (2.0 / 3.0)) / 3.0;
174+
assert!((score - expected).abs() < 1e-8);
175+
}
136176
}

0 commit comments

Comments
 (0)