44//!
55//! \\[precision = \frac{tp}{tp + fp}\\]
66//!
7- //! where tp (true positive) - correct result, fp (false positive) - unexpected result
7+ //! where tp (true positive) - correct result, fp (false positive) - unexpected result.
8+ //! For binary classification, this is precision for the positive class (assumed to be 1.0).
9+ //! For multiclass, this is macro-averaged precision (average of per-class precisions).
810//!
911//! Example:
1012//!
1921//!
2022//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
2123//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
22- use std:: collections:: HashSet ;
24+
25+ use std:: collections:: { HashMap , HashSet } ;
2326use std:: marker:: PhantomData ;
2427
2528#[ cfg( feature = "serde" ) ]
@@ -61,33 +64,63 @@ impl<T: RealNumber> Metrics<T> for Precision<T> {
6164 ) ;
6265 }
6366
64- let mut classes = HashSet :: new ( ) ;
65- for i in 0 ..y_true. shape ( ) {
66- classes. insert ( y_true. get ( i) . to_f64_bits ( ) ) ;
67+ let n = y_true. shape ( ) ;
68+
69+ let mut classes_set: HashSet < u64 > = HashSet :: new ( ) ;
70+ for i in 0 ..n {
71+ classes_set. insert ( y_true. get ( i) . to_f64_bits ( ) ) ;
6772 }
68- let classes = classes. len ( ) ;
69-
70- let mut tp = 0 ;
71- let mut fp = 0 ;
72- for i in 0 ..y_true. shape ( ) {
73- if y_pred. get ( i) == y_true. get ( i) {
74- if classes == 2 {
75- if * y_true. get ( i) == T :: one ( ) {
73+ let classes: usize = classes_set. len ( ) ;
74+
75+ if classes == 2 {
76+ // Binary case: precision for positive class (assumed T::one())
77+ let positive = T :: one ( ) ;
78+ let mut tp: usize = 0 ;
79+ let mut fp_count: usize = 0 ;
80+ for i in 0 ..n {
81+ let t = * y_true. get ( i) ;
82+ let p = * y_pred. get ( i) ;
83+ if p == t {
84+ if t == positive {
7685 tp += 1 ;
7786 }
78- } else {
79- tp += 1 ;
87+ } else if t != positive {
88+ fp_count += 1 ;
8089 }
81- } else if classes == 2 {
82- if * y_true. get ( i) == T :: one ( ) {
83- fp += 1 ;
90+ }
91+ if tp + fp_count == 0 {
92+ 0.0
93+ } else {
94+ tp as f64 / ( tp + fp_count) as f64
95+ }
96+ } else {
97+ // Multiclass case: macro-averaged precision
98+ let mut predicted: HashMap < u64 , usize > = HashMap :: new ( ) ;
99+ let mut tp_map: HashMap < u64 , usize > = HashMap :: new ( ) ;
100+ for i in 0 ..n {
101+ let p_bits = y_pred. get ( i) . to_f64_bits ( ) ;
102+ * predicted. entry ( p_bits) . or_insert ( 0 ) += 1 ;
103+ if * y_true. get ( i) == * y_pred. get ( i) {
104+ * tp_map. entry ( p_bits) . or_insert ( 0 ) += 1 ;
84105 }
106+ }
107+ let mut precision_sum = 0.0 ;
108+ for & bits in & classes_set {
109+ let pred_count = * predicted. get ( & bits) . unwrap_or ( & 0 ) ;
110+ let tp = * tp_map. get ( & bits) . unwrap_or ( & 0 ) ;
111+ let prec = if pred_count > 0 {
112+ tp as f64 / pred_count as f64
113+ } else {
114+ 0.0
115+ } ;
116+ precision_sum += prec;
117+ }
118+ if classes == 0 {
119+ 0.0
85120 } else {
86- fp += 1 ;
121+ precision_sum / classes as f64
87122 }
88123 }
89-
90- tp as f64 / ( tp as f64 + fp as f64 )
91124 }
92125}
93126
@@ -114,7 +147,7 @@ mod tests {
114147 let y_pred: Vec < f64 > = vec ! [ 0. , 0. , 1. , 1. , 1. , 1. ] ;
115148
116149 let score3: f64 = Precision :: new ( ) . get_score ( & y_true, & y_pred) ;
117- assert ! ( ( score3 - 0.6666666666 ) . abs( ) < 1e-8 ) ;
150+ assert ! ( ( score3 - 0.5 ) . abs( ) < 1e-8 ) ;
118151 }
119152
120153 #[ cfg_attr(
@@ -132,4 +165,36 @@ mod tests {
132165 assert ! ( ( score1 - 0.333333333 ) . abs( ) < 1e-8 ) ;
133166 assert ! ( ( score2 - 1.0 ) . abs( ) < 1e-8 ) ;
134167 }
168+
169+ #[ cfg_attr(
170+ all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
171+ wasm_bindgen_test:: wasm_bindgen_test
172+ ) ]
173+ #[ test]
174+ fn precision_multiclass_imbalanced ( ) {
175+ let y_true: Vec < f64 > = vec ! [ 0. , 0. , 1. , 2. , 2. , 2. ] ;
176+ let y_pred: Vec < f64 > = vec ! [ 0. , 1. , 1. , 2. , 0. , 2. ] ;
177+
178+ let score: f64 = Precision :: new ( ) . get_score ( & y_true, & y_pred) ;
179+ let expected = ( 0.5 + 0.5 + 1.0 ) / 3.0 ;
180+ assert ! ( ( score - expected) . abs( ) < 1e-8 ) ;
181+ }
182+
183+ #[ cfg_attr(
184+ all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
185+ wasm_bindgen_test:: wasm_bindgen_test
186+ ) ]
187+ #[ test]
188+ fn precision_multiclass_unpredicted_class ( ) {
189+ let y_true: Vec < f64 > = vec ! [ 0. , 0. , 1. , 2. , 2. , 2. , 3. ] ;
190+ let y_pred: Vec < f64 > = vec ! [ 0. , 1. , 1. , 2. , 0. , 2. , 0. ] ;
191+
192+ let score: f64 = Precision :: new ( ) . get_score ( & y_true, & y_pred) ;
193+ // Class 0: pred=3, tp=1 -> 1/3 ≈0.333
194+ // Class 1: pred=2, tp=1 -> 0.5
195+ // Class 2: pred=2, tp=2 -> 1.0
196+ // Class 3: pred=0, tp=0 -> 0.0
197+ let expected = ( 1.0 / 3.0 + 0.5 + 1.0 + 0.0 ) / 4.0 ;
198+ assert ! ( ( score - expected) . abs( ) < 1e-8 ) ;
199+ }
135200}
0 commit comments