feat: use composition for non-interactive encrypted training #660
Conversation
Force-pushed from 62e5ad5 to ce9fb57
Force-pushed from a420bef to 2ffb370
Force-pushed from 332ee71 to afca943
    for output_i, input_i in self._composition_mapping.items()
)

if len(q_results) == 1:
Can you assert on the shape here? The input/output shapes should match.
I've added the checks in the new `_add_requant_for_composition` method (name to be confirmed).
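For reference, a minimal sketch of the kind of shape check being discussed (the helper name, error message and shape dictionaries below are illustrative assumptions, not the actual implementation):

```python
# Hypothetical sketch: each output that is fed back into an input through the
# composition mapping must have the same shape as that input.
def check_composition_shapes(input_shapes, output_shapes, composition_mapping):
    """Assert that mapped output/input positions have matching shapes."""
    for output_pos, input_pos in composition_mapping.items():
        if output_shapes[output_pos] != input_shapes[input_pos]:
            raise ValueError(
                f"Output {output_pos} has shape {output_shapes[output_pos]} but is "
                f"mapped to input {input_pos} with shape {input_shapes[input_pos]}"
            )

# Toy usage: output 0 (e.g. the updated weights) is fed back into input 1
check_composition_shapes(
    input_shapes={0: (10, 4), 1: (4, 1)},
    output_shapes={0: (4, 1)},
    composition_mapping={0: 1},
)
```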
Force-pushed from 2d2ed7a to 454f8fa
@@ -290,6 +293,61 @@ def _set_output_quantizers(self) -> List[UniformQuantizer]:
)
return output_quantizers

# Remove this once we handle the re-quantization step in post-training only
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4472
def _add_requant_for_composition(self, composition_mapping: Optional[Dict]):
New (private) method for the quantized module (avoids adding a parameter to the init and thus keeps things really internal).
max_output_pos = len(self.output_quantizers) - 1
max_input_pos = len(self.input_quantizers) - 1

for output_position, input_position in composition_mapping.items():
Make sure the mapping is of the form `{0: 1, 3: 2}`.
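A rough, self-contained sketch of what such a validation could look like (the bounds and error messages are assumptions based on the snippet above):

```python
# Hypothetical sketch: the mapping must be a dict of in-range integer positions,
# e.g. {0: 1, 3: 2} meaning "output 0 feeds input 1, output 3 feeds input 2".
def check_composition_mapping(composition_mapping, max_output_pos, max_input_pos):
    if not isinstance(composition_mapping, dict):
        raise ValueError("composition_mapping must be a dictionary")
    for output_position, input_position in composition_mapping.items():
        if not isinstance(output_position, int) or not 0 <= output_position <= max_output_pos:
            raise ValueError(f"Invalid output position: {output_position}")
        if not isinstance(input_position, int) or not 0 <= input_position <= max_input_pos:
            raise ValueError(f"Invalid input position: {input_position}")

check_composition_mapping({0: 1, 3: 2}, max_output_pos=3, max_input_pos=2)
```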
# Ignore [arg-type] check from mypy as it is not able to see that the input to `quant`
# cannot be None
q_x = tuple(
These are needed to match how Concrete Python works with encrypt, i.e. `encrypt(None, x) = None, x_enc`, since we do not encrypt all inputs at the same time with composition.
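As a hedged illustration of that None-placeholder pattern (the toy quantizer below is not the real one):

```python
import numpy

# Hypothetical sketch: only quantize the positions that are provided, keeping None
# placeholders so the call mirrors Concrete's encrypt(None, x) -> (None, x_enc) style.
def quantize_selected(values, quantize_fn):
    """Quantize non-None entries, pass None entries through untouched."""
    return tuple(None if value is None else quantize_fn(value) for value in values)

# Toy usage: only the second input is quantized, the first stays None
q_x = quantize_selected(
    (None, numpy.array([0.5, -1.0])),
    quantize_fn=lambda v: numpy.round(v * 128).astype(numpy.int64),
)
print(q_x)  # (None, array([  64, -128]))
```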
# Similarly, we only quantize the weight and bias values using the third and fourth
# position parameter
_, _, q_weights, q_bias = self.training_quantized_module.quantize_input(
@@ -181,16 +184,32 @@ def _compile_torch_or_onnx_model(
    for each input. By default all arguments will be encrypted.
reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
    bit-width propagation
composition_mapping (Optional[Dict]): Dictionary that maps output positions with input
Adding this new parameter to the private function `_compile_torch_or_onnx_model` instead of the other public ones, to keep things internal.
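A heavily abridged, self-contained sketch of that public/private split (the helper names and the placeholder module below are stand-ins, not the real Concrete ML code):

```python
from typing import Dict, Optional

# Hypothetical sketch: the public compile entry point keeps its signature unchanged,
# while only the private helper knows about the composition mapping.
def _compile_torch_or_onnx_model_sketch(model, composition_mapping: Optional[Dict] = None):
    quantized_module = {"model": model, "requant": None}  # placeholder for the real module
    if composition_mapping is not None:
        quantized_module["requant"] = dict(composition_mapping)
    return quantized_module

def compile_torch_sketch(model):
    """Public API: no composition_mapping parameter exposed here."""
    return _compile_torch_or_onnx_model_sketch(model, composition_mapping=None)

def _compile_training_sketch(model):
    """Internal training path: passes the mapping, e.g. updated weights -> weight input."""
    return _compile_torch_or_onnx_model_sketch(model, composition_mapping={0: 1})

print(_compile_training_sketch(model=None)["requant"])  # {0: 1}
```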
# If a mapping between input and output quantizers is set, add a re-quantization step at the
# end of the forward call. This is only useful for composable circuits in order to make sure
# that input and output quantizers match
if composition_mapping is not None:
This is where we decide whether or not to add the re-quantization step.
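Conceptually, the added step re-expresses each mapped output on the corresponding input quantizer's grid; here is a tiny, hypothetical illustration with uniform quantizers (not the actual implementation, which stays in the quantized domain):

```python
import numpy

# Toy uniform quantizer: v ≈ scale * (v_q - zero_point)
def dequantize(v_q, scale, zero_point):
    return scale * (v_q - zero_point)

def quantize(v, scale, zero_point):
    return numpy.round(v / scale + zero_point).astype(numpy.int64)

# Output 0 was produced with the output quantizer's parameters, but it is fed back
# as input 1, so it must be re-expressed with the input quantizer's parameters.
output_scale, output_zp = 0.05, 3
input_scale, input_zp = 0.02, 0

q_out = numpy.array([10, 23, -4])
q_requant = quantize(dequantize(q_out, output_scale, output_zp), input_scale, input_zp)
print(q_requant)  # integers on the input quantizer's grid
```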
# Additionally, there is no point in computing the following in case of a partial fit,
# as it only represents a single iteration
if self.early_stopping and not is_partial_fit:
    weights_float, bias_float = self._decrypt_dequantize_training_output(
We keep early stopping possible with composition by adding this decrypt/dequantize step here (since this is only for development, we believe it's not really an issue to do that).
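A hedged illustration of that development-time shortcut (the callables, tolerance and patience below are made up):

```python
import numpy

# Hypothetical sketch: decrypt/dequantize the updated parameters after each encrypted
# iteration and stop when the loss stops improving. Only intended for development,
# since a real deployment would not decrypt intermediate results.
def training_loop_sketch(run_encrypted_iteration, compute_loss, max_iters=100, patience=3):
    best_loss, stalled = numpy.inf, 0
    weights = bias = None
    for _ in range(max_iters):
        weights, bias = run_encrypted_iteration()  # decrypt + dequantize happens inside
        loss = compute_loss(weights, bias)
        if loss < best_loss - 1e-6:
            best_loss, stalled = loss, 0
        else:
            stalled += 1
            if stalled >= patience:
                break  # early stopping
    return weights, bias

# Toy usage with dummy callables
weights, bias = training_loop_sketch(
    run_encrypted_iteration=lambda: (numpy.zeros(3), 0.0),
    compute_loss=lambda w, b: 0.5,
)
```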
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4477
# We should also rename the input arguments to remove the `serialized` part, as we now accept
# both serialized and deserialized input values
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4476
def run(
We now allow both serialized and deserialized inputs (this avoids having to deserialize + serialize at each server call with composition).
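A minimal sketch of how accepting both forms might look (the helper and the toy "deserializer" are assumptions, not the real server code):

```python
# Hypothetical sketch: normalize inputs so `run` accepts both serialized (bytes) and
# already-deserialized encrypted values. With composition, one call's outputs are fed
# straight back in, so skipping serialize/deserialize between calls saves work.
def normalize_encrypted_inputs(inputs, deserialize_value):
    """Deserialize entries given as bytes; pass already-deserialized values through."""
    return tuple(
        deserialize_value(value) if isinstance(value, bytes) else value
        for value in inputs
    )

# Toy usage: here "deserialization" is just bytes -> int, for illustration
values = normalize_encrypted_inputs(
    (b"\x05", 42),
    deserialize_value=lambda raw: int.from_bytes(raw, "little"),
)
print(values)  # (5, 42)
```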
@@ -357,97 +388,78 @@ def get_serialized_evaluation_keys(self) -> bytes:
     return self.client.evaluation_keys.serialize()

 def quantize_encrypt_serialize(
-    self, x: Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]
-) -> Union[bytes, Tuple[bytes, ...]]:
+    self, *x: Optional[numpy.ndarray]
We now allow unpacking. This is not a breaking change, since tuple support was only added by @jfrery recently. This is mainly to make things more coherent with the other methods and with Concrete.
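For example, client-side calls move from passing a single (possibly tuple) argument to unpacking the arguments. A toy stand-in showing the signature change (not the real FHEModelClient):

```python
from typing import Optional, Tuple
import numpy

class _ClientSketch:
    """Toy stand-in illustrating the signature change, not the real FHEModelClient."""

    # Before: a single argument, optionally a tuple
    def quantize_encrypt_serialize_old(self, x) -> Tuple[bytes, ...]:
        x = x if isinstance(x, tuple) else (x,)
        return tuple(arr.tobytes() for arr in x)

    # After: variadic arguments, with None entries passed through untouched
    def quantize_encrypt_serialize(self, *x: Optional[numpy.ndarray]):
        return tuple(None if arr is None else arr.tobytes() for arr in x)

client = _ClientSketch()
result = client.quantize_encrypt_serialize(None, numpy.array([1, 2, 3]))
print(result[0] is None, isinstance(result[1], bytes))  # True True
```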
 def deserialize_decrypt(
-    self, serialized_encrypted_quantized_result: Union[bytes, Tuple[bytes, ...]]
+    self, *serialized_encrypted_quantized_result: Optional[bytes]
same
 def deserialize_decrypt_dequantize(
-    self, serialized_encrypted_quantized_result: Union[bytes, Tuple[bytes, ...]]
-) -> numpy.ndarray:
+    self, *serialized_encrypted_quantized_result: Optional[bytes]
same
Force-pushed from 4cc0545 to c54b458
Force-pushed from a8eaab9 to 9f352b7
I did the analysis -> #660 (comment). All in all, it looks like convergence isn't impacted. Good to go!
Force-pushed from 7a52745 to 1037569
Force-pushed from 156fb4d to 9328d57
Coverage passed ✅
Very well done! Thanks!
All good, thanks!
It's still a WIP.
Closes https://github.com/zama-ai/concrete-ml-internal/issues/4374
Closes https://github.com/zama-ai/concrete-ml-internal/issues/4485 - analysis of the new training approach