chore: make composition mapping internal
RomanBredehoft committed Jun 3, 2024
1 parent 287129b commit 2d2ed7a
Showing 8 changed files with 144 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -141,6 +141,7 @@ dmypy.json

# Experiments directory
playground/
.playground/

# File generated by benchmarks
progress.json
1 change: 1 addition & 0 deletions src/concrete/ml/onnx/convert.py
@@ -141,6 +141,7 @@ def get_equivalent_numpy_forward_from_torch(
use_tempfile: bool = output_onnx_file is None

arguments = list(inspect.signature(torch_module.forward).parameters)

# Export to ONNX
torch.onnx.export(
torch_module,
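The single added line collects the names of the `forward` method's parameters via `inspect.signature` before the ONNX export. A minimal standalone sketch of what that call yields (`TinyModel` is a made-up example, not part of the commit):

```python
import inspect

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical module used only to illustrate the added line."""

    def forward(self, x, y=None):
        return x if y is None else x + y


# On a bound method, inspect.signature omits 'self', so this yields the
# forward arguments in declaration order
arguments = list(inspect.signature(TinyModel().forward).parameters)
print(arguments)  # ['x', 'y']
```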
6 changes: 1 addition & 5 deletions src/concrete/ml/quantization/post_training.py
@@ -681,17 +681,14 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray):
node_results[output_name] = node_output[0]
constants.add(output_name)

def quantize_module(
self, *calibration_data: numpy.ndarray, composition_mapping: Optional[Dict] = None
) -> QuantizedModule:
def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
"""Quantize numpy module.
Following https://arxiv.org/abs/1712.05877 guidelines.
Args:
calibration_data (numpy.ndarray): Data that will be used to compute the bounds,
scales and zero point values for every quantized object.
Returns:
QuantizedModule: Quantized numpy module
@@ -711,7 +708,6 @@ def quantize_module(
),
quant_layers_dict=self.quant_ops_dict,
onnx_model=self.numpy_model.onnx_model,
composition_mapping=composition_mapping,
)

adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
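For context on the docstring's reference: https://arxiv.org/abs/1712.05877 quantizes real values through an affine mapping defined by a scale and a zero point. A standalone sketch of that scheme (illustrative only, not Concrete ML's actual quantizer code):

```python
import numpy


def quantize(x: numpy.ndarray, scale: float, zero_point: int) -> numpy.ndarray:
    # q = round(x / scale) + Z, following https://arxiv.org/abs/1712.05877
    return numpy.round(x / scale).astype(numpy.int64) + zero_point


def dequantize(q: numpy.ndarray, scale: float, zero_point: int) -> numpy.ndarray:
    # x ≈ scale * (q - Z), the approximate inverse mapping
    return scale * (q - zero_point)
```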
62 changes: 59 additions & 3 deletions src/concrete/ml/quantization/quantized_module.py
@@ -97,7 +97,6 @@ def __init__(
ordered_module_output_names: Optional[Iterable[str]] = None,
quant_layers_dict: Optional[Dict[str, Tuple[Tuple[str, ...], QuantizedOp]]] = None,
onnx_model: Optional[onnx.ModelProto] = None,
composition_mapping: Optional[Dict] = None,
):

all_or_none_params = [
@@ -140,8 +139,8 @@
else:
self.output_quantizers = []

# TODO: add check for inputs and outputs
self._composition_mapping = composition_mapping
# Input-output quantizer mapping for composition is not enabled at initialization
self._composition_mapping: Optional[Dict] = None

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
def set_reduce_sum_copy(self):
@@ -292,6 +291,61 @@ def _set_output_quantizers(self) -> List[UniformQuantizer]:
)
return output_quantizers

# Remove this once we handle the re-quantization step in post-training only
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4472
def _add_requant_for_composition(self, composition_mapping: Optional[Dict]):
"""Trigger a re-quantization step for outputs using an input-output mapping for quantizers.
Args:
composition_mapping (Optional[Dict]): Dictionary that maps output positions with input
positions in the case of composable circuits. Setting this parameter triggers a
re-quantization step at the end of the FHE circuit. This makes sure outputs are
de-quantized using their output quantizer and then re-quantized using their
associated input quantizer. Defaults to None.
Raises:
ValueError: If the mapping is not properly constructed: it must be a dictionary of
positive integers, mapping output positions to input positions, where positions
must not be greater than the model's number of outputs/inputs.
"""
if not isinstance(composition_mapping, Dict):
raise ValueError(
"Parameter 'composition_mapping' mus be a dictionary. Got "
f"{type(composition_mapping)}"
)

max_output_pos = len(self.output_quantizers) - 1
max_input_pos = len(self.input_quantizers) - 1

for output_position, input_position in composition_mapping.items():
if not isinstance(output_position, int) or output_position < 0:
raise ValueError(
"Output positions (keys) must be positive integers. Got "
f"{type(output_position)}"
)

if output_position > max_output_pos:
raise ValueError(
"Output positions (keys) must not be greater than the model's number of "
f"outputs. Expected position '{max_output_pos}' at most, but got "
f"'{output_position}'"
)

if not isinstance(input_position, int) or input_position < 0:
raise ValueError(
"Input positions (values) must be positive integers. Got "
f"{type(input_position)}"
)

if input_position > max_input_pos:
raise ValueError(
"Input positions (values) must not be greater than the model's number of "
f"inputs. Expected position '{max_input_pos}' at most, but got "
f"'{input_position}'"
)

self._composition_mapping = composition_mapping

@property
def onnx_model(self):
"""Get the ONNX model.
@@ -491,6 +545,8 @@ def _clear_forward(
elt.qvalues for elt in output_quantized_arrays if isinstance(elt, QuantizedArray)
)

# Remove this once we handle the re-quantization step in post-training only
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4472
if self._composition_mapping is not None:
q_results = tuple(
self.input_quantizers[input_i].quant(
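The truncated `_clear_forward` hunk applies the stored mapping at the end of the clear forward pass: each mapped output is de-quantized with its output quantizer, then re-quantized with the paired input quantizer, so a composed circuit can feed its outputs straight back as inputs. A hedged sketch of that step (function and variable names are made up, not the verbatim implementation):

```python
def requantize_outputs(q_results, composition_mapping, input_quantizers, output_quantizers):
    # For every output -> input pair in the mapping, align the output's integer
    # representation with the quantizer of the input it will be fed back into
    requantized = list(q_results)
    for output_i, input_i in composition_mapping.items():
        clear_value = output_quantizers[output_i].dequant(q_results[output_i])
        requantized[output_i] = input_quantizers[input_i].quant(clear_value)
    return tuple(requantized)
```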
2 changes: 1 addition & 1 deletion src/concrete/ml/sklearn/_fhe_training_utils.py
@@ -1,6 +1,6 @@
"""Utility functions for FHE training."""

from typing import List, Optional, Tuple
from typing import Tuple

import numpy
import torch
4 changes: 2 additions & 2 deletions src/concrete/ml/sklearn/linear_model.py
@@ -13,7 +13,7 @@
from ..common.utils import FheMode
from ..onnx.ops_impl import numpy_sigmoid
from ..quantization import QuantizedModule
from ..torch.compile import compile_torch_model
from ..torch.compile import _compile_torch_or_onnx_model
from ._fhe_training_utils import LogisticRegressionTraining
from .base import (
Data,
@@ -367,7 +367,7 @@ def _get_training_quantized_module(
print("Compiling training circuit ...")

start = time.time()
training_quantized_module = compile_torch_model(
training_quantized_module = _compile_torch_or_onnx_model(
trainer,
compile_set,
n_bits=self.n_bits_training,
33 changes: 21 additions & 12 deletions src/concrete/ml/torch/compile.py
@@ -75,7 +75,6 @@ def build_quantized_module(
n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
reduce_sum_copy=False,
composition_mapping: Optional[Dict] = None,
) -> QuantizedModule:
"""Build a quantized module from a Torch or ONNX model.
@@ -125,9 +124,7 @@
# FIXME: mismatch here. We traced with dummy_input_for_tracing which made some operator
# only work over shape of (1, ., .). For example, some reshape have newshape hardcoded based
# on the inputset we sent in the NumpyModule.
quantized_module = post_training_quant.quantize_module(
*inputset_as_numpy_tuple, composition_mapping=composition_mapping
)
quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)

# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
if reduce_sum_copy:
@@ -187,20 +184,30 @@ def _compile_torch_or_onnx_model(
for each input. By default all arguments will be encrypted.
reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
bit-width propagation
composition_mapping (Optional[Dict]): Dictionary that maps output positions with input
positions in the case of composable circuits. Setting this parameter triggers a
re-quantization step at the end of the FHE circuit. This makes sure outputs are
de-quantized using their output quantizer and then re-quantized using their associated
input quantizer. Defaults to None.
Returns:
QuantizedModule: The resulting compiled QuantizedModule.
Raises:
ValueError: If an input-output mapping ('composition_mapping') is set but composition is not
enabled at the Concrete level (in 'configuration').
"""
rounding_threshold_bits = process_rounding_threshold_bits(rounding_threshold_bits)

inputset_as_numpy_tuple = tuple(
convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
)

if composition_mapping is not None and not configuration.composable:
# Check that composition is enabled if an input-output mapping has been set
if composition_mapping is not None and (configuration is None or not configuration.composable):
raise ValueError(
"Please enable the composition feature in order to be able to take the mapping between "
"inputs and output quantizers into account."
"Composition must be enabled in 'configuration' in order to trigger a re-quantization "
"step on the circuit's outputs."
)

# Build the quantized module
@@ -211,7 +218,6 @@
n_bits=n_bits,
rounding_threshold_bits=rounding_threshold_bits,
reduce_sum_copy=reduce_sum_copy,
composition_mapping=composition_mapping,
)

# Check that p_error or global_p_error is not set in both the configuration and in the direct
@@ -232,6 +238,13 @@
# Find the right way to set parameters for compiler, depending on the way we want to default
p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

# If a mapping between input and output quantizers is set, add a re-quantization step at the
# end of the forward call. This is only useful for composable circuits in order to make sure
# that input and output quantizers match
if composition_mapping is not None:
# pylint: disable-next=protected-access
quantized_module._add_requant_for_composition(composition_mapping)

quantized_module.compile(
inputset_as_numpy_tuple,
configuration,
@@ -261,7 +274,6 @@
verbose: bool = False,
inputs_encryption_status: Optional[Sequence[str]] = None,
reduce_sum_copy: bool = False,
composition_mapping: Optional[Dict] = None,
) -> QuantizedModule:
"""Compile a torch module into an FHE equivalent.
@@ -328,11 +340,9 @@
verbose=verbose,
inputs_encryption_status=inputs_encryption_status,
reduce_sum_copy=reduce_sum_copy,
composition_mapping=composition_mapping,
)


# TODO: add 'composition_mapping' here as well
# pylint: disable-next=too-many-arguments
def compile_onnx_model(
onnx_model: onnx.ModelProto,
@@ -413,7 +423,6 @@
)


# TODO: add 'composition_mapping' here as well?
# pylint: disable-next=too-many-arguments
def compile_brevitas_qat_model(
torch_model: torch.nn.Module,
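Since `composition_mapping` is removed from the public compile functions, it now only reaches a model through the internal `_compile_torch_or_onnx_model` entry point, as the tests below exercise. A usage sketch under stated assumptions (`TinyNet` is made up, and `Configuration(composable=True)` assumes the Concrete Python option the tests toggle via `default_configuration.composable`):

```python
import torch
from concrete.fhe import Configuration

from concrete.ml.torch.compile import _compile_torch_or_onnx_model


class TinyNet(torch.nn.Module):
    """Made-up model whose output shape matches its input shape."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, x):
        return torch.relu(self.fc(x))


inputset = torch.randn(10, 5)

# Composition must be enabled in the configuration, otherwise the mapping raises a ValueError
configuration = Configuration(composable=True)

quantized_module = _compile_torch_or_onnx_model(
    TinyNet(),
    inputset,
    configuration=configuration,
    composition_mapping={0: 0},  # de-quantize output 0, re-quantize it with input 0's quantizer
)
```

Whether a given model then compiles depends on the circuit actually being composable (mapped outputs and inputs ending up with matching representations), which is exactly what the re-quantization step is meant to guarantee.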
58 changes: 58 additions & 0 deletions tests/torch/test_compile_torch.py
@@ -55,6 +55,7 @@
# packages/projects, disable the warning
# pylint: disable=ungrouped-imports
from concrete.ml.torch.compile import (
_compile_torch_or_onnx_model,
build_quantized_module,
compile_brevitas_qat_model,
compile_onnx_model,
@@ -1476,3 +1477,60 @@ def test_rounding_mode(rounding_method, expected_reinterpret, default_configuration):
), "Expected 'reinterpret_precision' found but 'round' should not be present."
else:
assert "reinterpret_precision" not in mlir, "Unexpected 'reinterpret_precision' found."


def test_composition_mapping_error_raise(default_configuration):
"""Test that using composition mappings in a wrong manner raises the proper errors."""
model = FCSmall(input_output=5, activation_function=nn.ReLU)
torch_inputset = torch.randn(10, 5)
composition_mapping = {0: 2}

with pytest.raises(ValueError, match="Composition must be enabled in 'configuration'.*"):
_compile_torch_or_onnx_model(
model,
torch_inputset,
configuration=default_configuration,
composition_mapping=composition_mapping,
)

default_configuration.composable = True

composition_mapping = {-1: 2}

with pytest.raises(ValueError, match=r"Output positions \(keys\) must be positive integers.*"):
_compile_torch_or_onnx_model(
model,
torch_inputset,
configuration=default_configuration,
composition_mapping=composition_mapping,
)

composition_mapping = {0: -2}

with pytest.raises(ValueError, match=r"Input positions \(values\) must be positive integers.*"):
_compile_torch_or_onnx_model(
model,
torch_inputset,
configuration=default_configuration,
composition_mapping=composition_mapping,
)

composition_mapping = {10: 2}

with pytest.raises(ValueError, match=r"Output positions \(keys\) must not be greater.*"):
_compile_torch_or_onnx_model(
model,
torch_inputset,
configuration=default_configuration,
composition_mapping=composition_mapping,
)

composition_mapping = {0: 20}

with pytest.raises(ValueError, match=r"Input positions \(values\) must not be greater.*"):
_compile_torch_or_onnx_model(
model,
torch_inputset,
configuration=default_configuration,
composition_mapping=composition_mapping,
)
