Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# Ensure these are present (in case we are not using PEP-518 compatible build
# system).
import setuptools_scm
import toml

scalib_features = ["pyo3/abi3"]

Expand Down
5 changes: 4 additions & 1 deletion src/scalib/attacks/sascagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,15 @@ def set_table(self, table, values):
)
self.tables_[table] = values

def run_bp(self, it):
def run_bp(self, it, progress=False):
r"""Runs belief propagation algorithm on the current state of the graph.

Parameters
----------
it : int
Number of iterations of belief propagation.
progress: bool
Show a progress bar (default: False).
"""
if self.solved_:
raise Exception("Cannot run bp twice on a graph.")
Expand All @@ -241,6 +243,7 @@ def run_bp(self, it):
self.edge_,
self.nc_,
self.n_,
progress,
)
self.solved_ = True

Expand Down
14 changes: 13 additions & 1 deletion src/scalib_ext/scalib-py/src/belief_propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ pub fn run_bp(
nc: usize,
// number of copies in the graph (n_runs)
n: usize,
// show a progress bar
progress: bool,
) -> PyResult<()> {
// map all python functions to rust ones + generate the mapping in vec_functs_id
let functions_rust: Vec<Func> = functions
Expand All @@ -104,8 +106,18 @@ pub fn run_bp(
.map(|x| to_var(x.downcast::<PyDict>().unwrap()))
.collect();

scalib::belief_propagation::run_bp(&functions_rust, &mut variables_rust, it, edge, nc, n)
py.allow_threads(|| {
scalib::belief_propagation::run_bp(
&functions_rust,
&mut variables_rust,
it,
edge,
nc,
n,
progress,
)
.unwrap();
});

variables_rust
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/scalib_ext/scalib-py/src/lda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl LDA {
/// x : traces with shape (n,ns)
/// return prs with shape (n,nc). Every row corresponds to one probability distribution
fn predict_proba<'py>(
&mut self,
&self,
py: Python<'py>,
x: PyReadonlyArray2<i16>,
) -> PyResult<&'py PyArray2<f64>> {
Expand Down
37 changes: 22 additions & 15 deletions src/scalib_ext/scalib-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> {
vertex: usize,
nc: usize,
n: usize,
progress: bool,
) -> PyResult<()> {
belief_propagation::run_bp(py, functions, variables, it, vertex, nc, n)
belief_propagation::run_bp(py, functions, variables, it, vertex, nc, n, progress)
}

#[pyfn(m, "partial_cp")]
Expand All @@ -69,39 +70,45 @@ fn _scalib_ext(_py: Python, m: &PyModule) -> PyResult<()> {

#[pyfn(m, "rank_accuracy")]
fn rank_accuracy(
py: Python,
costs: Vec<Vec<f64>>,
key: Vec<usize>,
acc: f64,
merge: Option<usize>,
method: String,
max_nb_bin: usize,
) -> PyResult<(f64, f64, f64)> {
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin);
match res {
Ok(res) => Ok((res.min, res.est, res.max)),
Err(s) => {
panic!("{}", s);
py.allow_threads(|| {
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
let res = res.rank_accuracy(&costs, &key, acc, merge, max_nb_bin);
match res {
Ok(res) => Ok((res.min, res.est, res.max)),
Err(s) => {
panic!("{}", s);
}
}
}
})
}

#[pyfn(m, "rank_nbin")]
fn rank_nbin(
py: Python,
costs: Vec<Vec<f64>>,
key: Vec<usize>,
nb_bin: usize,
merge: Option<usize>,
method: String,
) -> PyResult<(f64, f64, f64)> {
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
let res = res.rank_nbin(&costs, &key, nb_bin, merge);
match res {
Ok(res) => Ok((res.min, res.est, res.max)),
Err(s) => {
panic!("{}", s);
py.allow_threads(|| {
let res = str2method(&method).unwrap_or_else(|s| panic!("{}", s));
let res = res.rank_nbin(&costs, &key, nb_bin, merge);
match res {
Ok(res) => Ok((res.min, res.est, res.max)),
Err(s) => {
panic!("{}", s);
}
}
}
})
}

Ok(())
Expand Down
29 changes: 20 additions & 9 deletions src/scalib_ext/scalib/src/belief_propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ pub fn run_bp(
nc: usize,
// number of copies in the graph (n_runs)
n: usize,
// show a progress bar
progress: bool,
) -> Result<(), ()> {
// Scratch array containing all the edge's messages.
let mut edges: Vec<Array2<f64>> = vec![Array2::<f64>::ones((n, nc)); edge];
Expand Down Expand Up @@ -386,15 +388,7 @@ pub fn run_bp(
}
}

// loading bar
let pb = ProgressBar::new(it as u64);
pb.set_style(ProgressStyle::default_spinner().template(
"{msg} {spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] ({pos}/{len}, ETA {eta})",
)
.on_finish(ProgressFinish::AndClear));
pb.set_message("Calculating BP...");

for _ in (0..it).progress_with(pb) {
let mut bp_iter = || {
// This is a technique for runtime borrow-checking: we take reference on all the edges
// at once, put them into options, then extract the references out of the options, one
// at a time and out-of-order.
Expand Down Expand Up @@ -422,6 +416,23 @@ pub fn run_bp(
})
.collect();
update_variables(&mut edge_for_var, variables);
};

if progress {
// loading bar
let pb = ProgressBar::new(it as u64);
pb.set_style(ProgressStyle::default_spinner().template(
"{msg} {spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] ({pos}/{len}, ETA {eta})",
)
.on_finish(ProgressFinish::AndClear));
pb.set_message("Calculating BP...");
for _ in (0..it).progress_with(pb) {
bp_iter();
}
} else {
for _ in 0..it {
bp_iter();
}
}

Ok(())
Expand Down