From 0228c2599dee91bfe2c5ffa09615b5372bfc4d27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20Cassiers?= Date: Fri, 26 Nov 2021 16:24:57 +0100 Subject: [PATCH] Make rank estimation and probability estimation for LDA run in parallel --- .../scalib-py/src/belief_propagation.rs | 22 ++++++------ src/scalib_ext/scalib-py/src/lda.rs | 2 +- src/scalib_ext/scalib-py/src/lib.rs | 34 +++++++++++-------- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/scalib_ext/scalib-py/src/belief_propagation.rs b/src/scalib_ext/scalib-py/src/belief_propagation.rs index b20580c8..a8dd65df 100644 --- a/src/scalib_ext/scalib-py/src/belief_propagation.rs +++ b/src/scalib_ext/scalib-py/src/belief_propagation.rs @@ -106,16 +106,18 @@ pub fn run_bp( .map(|x| to_var(x.downcast::().unwrap())) .collect(); - scalib::belief_propagation::run_bp( - &functions_rust, - &mut variables_rust, - it, - edge, - nc, - n, - progress, - ) - .unwrap(); + py.allow_threads(|| { + scalib::belief_propagation::run_bp( + &functions_rust, + &mut variables_rust, + it, + edge, + nc, + n, + progress, + ) + .unwrap(); + }); variables_rust .iter() diff --git a/src/scalib_ext/scalib-py/src/lda.rs b/src/scalib_ext/scalib-py/src/lda.rs index 4128cec7..9501d200 100644 --- a/src/scalib_ext/scalib-py/src/lda.rs +++ b/src/scalib_ext/scalib-py/src/lda.rs @@ -96,7 +96,7 @@ impl LDA { /// x : traces with shape (n,ns) /// return prs with shape (n,nc). Every row corresponds to one probability distribution fn predict_proba<'py>( - &mut self, + &self, py: Python<'py>, x: PyReadonlyArray2, ) -> PyResult<&'py PyArray2> { diff --git a/src/scalib_ext/scalib-py/src/lib.rs b/src/scalib_ext/scalib-py/src/lib.rs index 8f8cfa85..9fc5c108 100644 --- a/src/scalib_ext/scalib-py/src/lib.rs +++ b/src/scalib_ext/scalib-py/src/lib.rs @@ -70,6 +70,7 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> { #[pyfn(m, "rank_accuracy")] fn rank_accuracy( + py: Python, costs: Vec>, key: Vec, acc: f64, @@ -77,32 +78,37 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> { method: String, max_nb_bin: usize, ) -> PyResult<(f64, f64, f64)> { - let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); - let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin); - match res { - Ok(res) => Ok((res.min, res.est, res.max)), - Err(s) => { - panic!("{}", s); + py.allow_threads(|| { + let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); + let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin); + match res { + Ok(res) => Ok((res.min, res.est, res.max)), + Err(s) => { + panic!("{}", s); + } } - } + }) } #[pyfn(m, "rank_nbin")] fn rank_nbin( + py: Python, costs: Vec>, key: Vec, nb_bin: usize, merge: Option, method: String, ) -> PyResult<(f64, f64, f64)> { - let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); - let res = res.rank_nbin(&costs, &key, nb_bin, merge); - match res { - Ok(res) => Ok((res.min, res.est, res.max)), - Err(s) => { - panic!("{}", s); + py.allow_threads(|| { + let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); + let res = res.rank_nbin(&costs, &key, nb_bin, merge); + match res { + Ok(res) => Ok((res.min, res.est, res.max)), + Err(s) => { + panic!("{}", s); + } } - } + }) } Ok(())