diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index b04e48109..077302a59 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -323,6 +323,19 @@ pub trait Labels { fn labels(&self) -> Vec { self.label_set().into_iter().flatten().collect() } + + fn combined_labels(&self, other: Vec) -> Vec { + let mut combined = self.labels(); + combined.extend(other.clone()); + + combined + .iter() + .map(|x| x) + .collect::>() + .into_iter() + .map(|x| x.clone()) + .collect() + } } #[cfg(test)] diff --git a/src/metrics_classification.rs b/src/metrics_classification.rs index 3cf2f30f5..74d053dd7 100644 --- a/src/metrics_classification.rs +++ b/src/metrics_classification.rs @@ -290,7 +290,7 @@ where return Err(Error::MismatchedShapes(targets.len(), ground_truth.len())); } - let classes = self.labels(); + let classes = self.combined_labels(ground_truth.labels()); let indices = map_prediction_to_idx( targets.as_slice().unwrap(), @@ -636,6 +636,19 @@ mod tests { ); } + #[test] + fn test_division_by_zero_cm() { + let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]); + let predicted = Array1::from(vec![0, 0, 0, 0, 0, 0]); + let labels = array![0, 1]; + + let x = predicted.confusion_matrix(ground_truth).unwrap(); + + let f1 = x.f1_score(); + + assert!(f1.is_nan()); + } + #[test] fn test_roc_curve() { let predicted = ArrayView1::from(&[0.1, 0.3, 0.5, 0.7, 0.8, 0.9]).mapv(Pr::new);