diff --git a/src/scalib_ext/scalib-py/src/belief_propagation.rs b/src/scalib_ext/scalib-py/src/belief_propagation.rs index b20580c8..a8dd65df 100644 --- a/src/scalib_ext/scalib-py/src/belief_propagation.rs +++ b/src/scalib_ext/scalib-py/src/belief_propagation.rs @@ -106,16 +106,18 @@ pub fn run_bp( .map(|x| to_var(x.downcast::().unwrap())) .collect(); - scalib::belief_propagation::run_bp( - &functions_rust, - &mut variables_rust, - it, - edge, - nc, - n, - progress, - ) - .unwrap(); + py.allow_threads(|| { + scalib::belief_propagation::run_bp( + &functions_rust, + &mut variables_rust, + it, + edge, + nc, + n, + progress, + ) + .unwrap(); + }); variables_rust .iter() diff --git a/src/scalib_ext/scalib-py/src/lda.rs b/src/scalib_ext/scalib-py/src/lda.rs index 4128cec7..9501d200 100644 --- a/src/scalib_ext/scalib-py/src/lda.rs +++ b/src/scalib_ext/scalib-py/src/lda.rs @@ -96,7 +96,7 @@ impl LDA { /// x : traces with shape (n,ns) /// return prs with shape (n,nc). Every row corresponds to one probability distribution fn predict_proba<'py>( - &mut self, + &self, py: Python<'py>, x: PyReadonlyArray2, ) -> PyResult<&'py PyArray2> { diff --git a/src/scalib_ext/scalib-py/src/lib.rs b/src/scalib_ext/scalib-py/src/lib.rs index 8f8cfa85..9fc5c108 100644 --- a/src/scalib_ext/scalib-py/src/lib.rs +++ b/src/scalib_ext/scalib-py/src/lib.rs @@ -70,6 +70,7 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> { #[pyfn(m, "rank_accuracy")] fn rank_accuracy( + py: Python, costs: Vec>, key: Vec, acc: f64, @@ -77,32 +78,37 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> { method: String, max_nb_bin: usize, ) -> PyResult<(f64, f64, f64)> { - let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); - let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin); - match res { - Ok(res) => Ok((res.min, res.est, res.max)), - Err(s) => { - panic!("{}", s); + py.allow_threads(|| { + let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); + let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin); + match res { + Ok(res) => Ok((res.min, res.est, res.max)), + Err(s) => { + panic!("{}", s); + } } - } + }) } #[pyfn(m, "rank_nbin")] fn rank_nbin( + py: Python, costs: Vec>, key: Vec, nb_bin: usize, merge: Option, method: String, ) -> PyResult<(f64, f64, f64)> { - let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); - let res = res.rank_nbin(&costs, &key, nb_bin, merge); - match res { - Ok(res) => Ok((res.min, res.est, res.max)), - Err(s) => { - panic!("{}", s); + py.allow_threads(|| { + let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s)); + let res = res.rank_nbin(&costs, &key, nb_bin, merge); + match res { + Ok(res) => Ok((res.min, res.est, res.max)), + Err(s) => { + panic!("{}", s); + } } - } + }) } Ok(())