Skip to content

Commit

Permalink
fix: use output scales for percent tolerance (#299)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethan-crypto committed Jun 15, 2023
1 parent 982d7e0 commit 3dee133
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 43 deletions.
12 changes: 6 additions & 6 deletions src/circuit/ops/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl From<String> for CheckMode {
#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
pub enum Tolerance {
Abs { val: usize },
Percentage { val: f32, scale: usize },
Percentage { val: f32, scales: (usize, usize) },
}

impl Default for Tolerance {
Expand All @@ -89,7 +89,7 @@ impl FromStr for Tolerance {
if let Ok(val) = s.parse::<usize>() {
Ok(Tolerance::Abs { val })
} else if let Ok(val) = s.parse::<f32>() {
Ok(Tolerance::Percentage { val, scale: 1 })
Ok(Tolerance::Percentage { val, scales: (1,1) })
} else {
Err("Invalid tolerance value provided. It should be either an absolute value (usize) or a percentage (f32).".to_string())
}
Expand Down Expand Up @@ -127,8 +127,8 @@ impl IntoPy<PyObject> for Tolerance {
fn into_py(self, py: Python) -> PyObject {
match self {
Tolerance::Abs { val } => (String::from("abs"), val).to_object(py),
Tolerance::Percentage { val, scale } => {
(String::from("percentage"), val, scale).to_object(py)
Tolerance::Percentage { val, scales } => {
(String::from("percentage"), val, scales).to_object(py)
}
}
}
Expand All @@ -143,9 +143,9 @@ impl<'source> FromPyObject<'source> for Tolerance {
"abs" => Ok(Tolerance::Abs { val }),
_ => Err(PyValueError::new_err("Invalid value for Tolerance")),
}
} else if let Ok((mode, val, scale)) = ob.extract::<(String, f32, usize)>() {
} else if let Ok((mode, val, scales)) = ob.extract::<(String, f32, (usize, usize))>() {
match mode.to_lowercase().as_str() {
"percentage" => Ok(Tolerance::Percentage { val, scale }),
"percentage" => Ok(Tolerance::Percentage { val, scales }),
_ => Err(PyValueError::new_err("Invalid value for Tolerance")),
}
} else {
Expand Down
9 changes: 5 additions & 4 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,12 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
offset,
*val as i32,
)?,
Tolerance::Percentage { val, scale } => layouts::range_check_percent(
Tolerance::Percentage { val, scales } => layouts::range_check_percent(
config,
region,
values[..].try_into()?,
*scale,
scales.0,
scales.1,
offset,
*val,
)?,
Expand Down Expand Up @@ -166,8 +167,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
]
}
HybridOp::RangeCheck(tol) => match tol {
Tolerance::Percentage { val, scale } => {
let scale = scale.pow(2);
Tolerance::Percentage { val, scales } => {
let scale = scales.0 * scales.1;
vec![
LookupOp::Recip { scale },
LookupOp::GreaterThan {
Expand Down
7 changes: 4 additions & 3 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1973,15 +1973,16 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: Arc<Mutex<Option<&mut Region<F>>>>,
values: &[ValTensor<F>; 2],
scale: usize,
input_scale: usize,
output_scale: usize,
offset: &mut usize,
tol: f32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region.clone(), values, offset, BaseOp::Sub)?;

// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
let scale = scale.pow(2);
let scale = input_scale * output_scale;
let recip = nonlinearity(
config,
region.clone(),
Expand Down Expand Up @@ -2052,7 +2053,7 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
Tensor::new(Some(&values[1].get_int_evals()?), values[1].dims())?,
];
let ref_range_check_percent: Tensor<i128> =
tensor::ops::nonlinearities::range_check_percent(int_evals, scale, tol);
tensor::ops::nonlinearities::range_check_percent(int_evals, input_scale, output_scale, tol);
let output_int_evals = Tensor::new(Some(&sum.get_int_evals()?), values[0].dims())?;
assert_eq!(output_int_evals, ref_range_check_percent)
}
Expand Down
2 changes: 1 addition & 1 deletion src/circuit/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1755,7 +1755,7 @@ mod rangecheckpercent {
&mut 0,
Box::new(HybridOp::RangeCheck(Tolerance::Percentage {
val: RANGE,
scale: SCALE,
scales: (SCALE, SCALE),
})),
)
.map_err(|_| Error::Synthesis)
Expand Down
52 changes: 28 additions & 24 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,17 @@ impl Model {

// if we're using percentage tolerance, we need to add the necessary range check ops for it.
if let Tolerance::Percentage { val, .. } = run_args.tolerance {
let tolerance = Tolerance::Percentage {
val,
scale: scale_to_multiplier(run_args.scale) as usize,
};
let opkind: Box<dyn Op<Fp>> = Box::new(HybridOp::RangeCheck(tolerance));
lookup_ops.extend(opkind.required_lookups());
for scale in self.graph.get_output_scales(){
let tolerance = Tolerance::Percentage {
val,
scales:(
scale_to_multiplier(scale) as usize,
scale_to_multiplier(run_args.scale) as usize,
)
};
let opkind: Box<dyn Op<Fp>> = Box::new(HybridOp::RangeCheck(tolerance));
lookup_ops.extend(opkind.required_lookups());
}
}

let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup
Expand Down Expand Up @@ -647,17 +652,23 @@ impl Model {

match run_args.output_visibility {
Visibility::Public => {
let tolerance = match run_args.tolerance {
Tolerance::Percentage { val, .. } => Tolerance::Percentage {
val,
scale: scale_to_multiplier(run_args.scale) as usize,
},
_ => run_args.tolerance,
};
let output_scales = self.graph.get_output_scales();
let global_scale = scale_to_multiplier(run_args.scale) as usize;
let _ = outputs
.iter()
.enumerate()
.map(|(i, output)| {
.iter()
.enumerate()
.map(|(i, output)| {
let tolerance = match run_args.tolerance {
Tolerance::Percentage { val, .. } => Tolerance::Percentage {
val,
scales:
(
scale_to_multiplier(output_scales[i]) as usize,
global_scale
),
},
_ => run_args.tolerance,
};
let mut instance_offset = 0;
if self.visibility.input.is_public() {
instance_offset += inputs.len();
Expand Down Expand Up @@ -792,13 +803,6 @@ impl Model {

match run_args.output_visibility {
Visibility::Public => {
let tolerance = match run_args.tolerance {
Tolerance::Percentage { val, .. } => Tolerance::Percentage {
val,
scale: scale_to_multiplier(run_args.scale) as usize,
},
_ => run_args.tolerance,
};
let _ = outputs
.clone()
.into_iter()
Expand All @@ -808,7 +812,7 @@ impl Model {
Arc::new(Mutex::new(None)),
&[output.clone(), output],
&mut offset,
Box::new(HybridOp::RangeCheck(tolerance)),
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
)
.unwrap()
})
Expand Down
10 changes: 5 additions & 5 deletions src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2040,17 +2040,17 @@ pub mod nonlinearities {
/// Some(&[103, 204, 303, 404, 505, 607]),
/// &[2, 3],
/// ).unwrap();
/// let result = range_check_percent(&[x, y], 1024, 1.0); // 1% tolerance
/// let result = range_check_percent(&[x, y], 1024, 1024, 1.0); // 1% tolerance
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 0, 0, 0, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn range_check_percent(t: &[Tensor<i128>], scale: usize, tol: f32) -> Tensor<i128> {
pub fn range_check_percent(t: &[Tensor<i128>], input_scale: usize, output_scale: usize, tol: f32) -> Tensor<i128> {
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
let _double_scale = scale.pow(2);
let scale = input_scale*output_scale;
let diff: Tensor<i128> = sub(t).unwrap();
let recip = recip(&t[0], _double_scale as u32);
let recip = recip(&t[0], scale as u32);
let product = mult(&[diff, recip]).unwrap();
let _tol = ((tol / 100.0) * _double_scale as f32).round() as f64;
let _tol = ((tol / 100.0) * scale as f32).round() as f64;
let upper_bound = greater_than(&product, _tol);
let neg_product =
mult(&[product, Tensor::<i128>::new(Some(&[-1]), &[1]).unwrap()]).unwrap();
Expand Down

0 comments on commit 3dee133

Please sign in to comment.