Skip to content

Commit

Permalink
feat: add rounding feature on cml trees
Browse files Browse the repository at this point in the history
- adding new tree-only functions: get_equivalent_numpy_forward_from_onnx_tree and execute_onnx_with_numpy_trees
- adding new comparison operators in ops_impl.py that use rounding feature
- computing the LSB in the first and second stage manually using ONNX

closes zama-ai/concrete-ml-internal#4157
  • Loading branch information
kcelia committed Dec 13, 2023
1 parent 4f67883 commit 064eb82
Show file tree
Hide file tree
Showing 11 changed files with 654 additions and 43 deletions.
90 changes: 79 additions & 11 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
"""ONNX conversion related code."""

import tempfile
import warnings
from pathlib import Path
from typing import Callable, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import numpy
import onnx
import onnxoptimizer
import torch
from onnx import checker, helper

from .onnx_utils import IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, get_op_type
from .onnx_utils import (
IMPLEMENTED_ONNX_OPS,
execute_onnx_with_numpy,
execute_onnx_with_numpy_trees,
get_op_type,
)

OPSET_VERSION_FOR_ONNX_EXPORT = 14

Expand Down Expand Up @@ -145,7 +151,7 @@ def get_equivalent_numpy_forward_from_torch(
output_onnx_file_path.unlink()

equivalent_numpy_forward, equivalent_onnx_model = get_equivalent_numpy_forward_from_onnx(
equivalent_onnx_model, check_model=True
equivalent_onnx_model
)
with output_onnx_file_path.open("wb") as file:
file.write(equivalent_onnx_model.SerializeToString())
Expand All @@ -156,10 +162,7 @@ def get_equivalent_numpy_forward_from_torch(
)


def get_equivalent_numpy_forward_from_onnx(
onnx_model: onnx.ModelProto,
check_model: bool = True,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
def preprocess_onnx_model(onnx_model: onnx.ModelProto, check_model: bool) -> onnx.ModelProto:
"""Get the numpy equivalent forward of the provided ONNX model.
Args:
Expand All @@ -173,11 +176,20 @@ def get_equivalent_numpy_forward_from_onnx(
model to numpy.
Returns:
Callable[..., Tuple[numpy.ndarray, ...]]: The function that will execute
the equivalent numpy function.
onnx.ModelProto: The preprocessed ONNX model.
"""
if check_model:
checker.check_model(onnx_model)

# All onnx models should be checked, "check_model" parameter must be removed
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4157
if not check_model: # pragma: no cover
warnings.simplefilter("always")
warnings.warn(
"`check_model` parameter should always be set to True, to ensure proper onnx model "
"verification and avoid bypassing essential onnx model validation checks.",
category=UserWarning,
stacklevel=2,
)

checker.check_model(onnx_model)

# Optimize ONNX graph
Expand All @@ -192,6 +204,7 @@ def get_equivalent_numpy_forward_from_onnx(
]
equivalent_onnx_model = onnxoptimizer.optimize(onnx_model, onnx_passes)
checker.check_model(equivalent_onnx_model)

# Custom optimization
# ONNX optimizer does not optimize Mat-Mult + Bias pattern into GEMM if the input isn't a matrix
# We manually do the optimization for this case
Expand All @@ -208,7 +221,62 @@ def get_equivalent_numpy_forward_from_onnx(
f"Available ONNX operators: {', '.join(sorted(IMPLEMENTED_ONNX_OPS))}"
)

return equivalent_onnx_model


def get_equivalent_numpy_forward_from_onnx(
    onnx_model: onnx.ModelProto,
    check_model: bool = True,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
    """Build a numpy forward function equivalent to the provided ONNX model.

    Args:
        onnx_model (onnx.ModelProto): the ONNX model for which to get the equivalent numpy
            forward.
        check_model (bool): forwarded as-is to `preprocess_onnx_model`. Defaults to True.

    Returns:
        Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]: the function that
            executes the preprocessed graph with numpy, and the preprocessed ONNX model
            itself.
    """
    preprocessed_model = preprocess_onnx_model(onnx_model, check_model)

    def numpy_forward(*args):
        # The graph is read from the preprocessed model at call time, not at creation time
        return execute_onnx_with_numpy(preprocessed_model.graph, *args)

    return numpy_forward, preprocessed_model


def get_equivalent_numpy_forward_from_onnx_tree(
    onnx_model: onnx.ModelProto,
    check_model: bool = True,
    lsbs_to_remove_for_trees: Optional[Tuple[int, int]] = None,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
    """Build a numpy forward function for an ONNX model, for tree-based models only.

    Args:
        onnx_model (onnx.ModelProto): the ONNX model for which to get the equivalent numpy
            forward.
        check_model (bool): forwarded as-is to `preprocess_onnx_model`. Defaults to True.
        lsbs_to_remove_for_trees (Optional[Tuple[int, int]]): numbers of least significant bits
            to remove during the tree traversal. The first value is used for the first
            comparison (either "Less" or "LessOrEqual"), the second one for the "Equal"
            comparison. Defaults to None, in which case the value is simply handed over to
            `execute_onnx_with_numpy_trees` unchanged.

    Returns:
        Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]: the function that
            executes the preprocessed graph with numpy, and the preprocessed ONNX model
            itself.
    """
    preprocessed_model = preprocess_onnx_model(onnx_model, check_model)

    def numpy_forward_for_trees(*args):
        # Both the graph and the LSB configuration are captured by the closure
        return execute_onnx_with_numpy_trees(
            preprocessed_model.graph, lsbs_to_remove_for_trees, *args
        )

    return numpy_forward_for_trees, preprocessed_model
48 changes: 47 additions & 1 deletion src/concrete/ml/onnx/onnx_impl_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Utility functions for onnx operator implementations."""

from typing import Tuple, Union
from typing import Callable, Tuple, Union

import numpy
from concrete.fhe import conv as cnp_conv
from concrete.fhe import ones as cnp_ones
from concrete.fhe import round_bit_pattern
from concrete.fhe.tracing import Tracer

from ..common.debugging import assert_true

# Type of the comparison callback that `rounded_comparison` applies to the rounded value of
# `x - y`
ComparisonOperationType = Callable[[int], bool]


def numpy_onnx_pad(
x: numpy.ndarray,
Expand Down Expand Up @@ -225,3 +228,46 @@ def onnx_avgpool_compute_norm_const(
norm_const = float(numpy.prod(numpy.array(kernel_shape)))

return norm_const


# This function needs to be updated when the truncate feature is released.
# The following changes should be made:
# - Remove the `half` term
# - Replace `round_bit_pattern` with `truncate_bit_pattern`
# - Potentially replace `lsbs_to_remove` with `auto_truncate`
# - Adjust the typing
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4143
def rounded_comparison(
    x: numpy.ndarray, y: numpy.ndarray, lsbs_to_remove: int, operation: ComparisonOperationType
) -> Tuple[bool]:
    """Compare `x` and `y` through a rounded version of `x - y`.

    The difference `x - y` is passed to `round_bit_pattern`, which rounds an integer to the
    closest multiple of `2**lsbs_to_remove`. The parameter `lsbs_to_remove` of
    `round_bit_pattern` can either be an integer specifying the number of LSBs to remove, or an
    `AutoRounder` object that determines the required number of LSBs based on the specified
    number of MSBs to retain. Here, the number of LSBs is computed manually by the caller and
    must be given as an integer.

    Args:
        x (numpy.ndarray): Input tensor
        y (numpy.ndarray): Input tensor
        lsbs_to_remove (int): Number of the least significant bits to remove. Must be a
            non-negative integer; 0 means that no bit is removed.
        operation (ComparisonOperationType): Comparison callback applied to the rounded
            difference. It is expected to implement `<`, `<=` or `==`.

    Returns:
        Tuple[bool]: One-element tuple holding the result of `operation` on the rounded
            difference.

    Raises:
        ValueError: If `lsbs_to_remove` is negative.
    """

    assert isinstance(lsbs_to_remove, int)

    if lsbs_to_remove < 0:
        raise ValueError(
            f"`lsbs_to_remove` must be a non-negative integer, got {lsbs_to_remove}."
        )

    # Workaround: in this context, `round_bit_pattern` is used as a truncate operation.
    # Consequently, we subtract a term, called `half`, that will subsequently be re-added during
    # the `round_bit_pattern` process.
    # When no bit is removed there is nothing to compensate: `1 << -1` would raise a
    # "negative shift count" error, so `half` is simply 0 in that case.
    # NOTE(review): assumes `round_bit_pattern(..., lsbs_to_remove=0)` leaves its input
    # unchanged - confirm against the Concrete version in use.
    half = 1 << (lsbs_to_remove - 1) if lsbs_to_remove > 0 else 0

    # To determine if 'x' 'operation' 'y' holds, 'operation' is evaluated on the rounded 'x - y'
    rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=lsbs_to_remove)

    return (operation(rounded_subtraction),)
72 changes: 68 additions & 4 deletions src/concrete/ml/onnx/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@

# Original file:
# https://github.com/google/jax/blob/f6d329b2d9b5f83c6a59e5739aa1ca8d4d1ffa1c/examples/onnx2xla.py


from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import numpy
import onnx
Expand Down Expand Up @@ -297,6 +295,9 @@
numpy_transpose,
numpy_unsqueeze,
numpy_where,
rounded_numpy_equal_for_trees,
rounded_numpy_less_for_trees,
rounded_numpy_less_or_equal_for_trees,
)

ATTR_TYPES = dict(onnx.AttributeProto.AttributeType.items())
Expand Down Expand Up @@ -406,14 +407,20 @@
"Less": numpy_less,
"LessOrEqual": numpy_less_or_equal,
}
# All numpy operators used for tree-based models that support auto rounding
# Keys are ONNX op types. `execute_onnx_with_numpy_trees` uses these implementations instead of
# the regular ones when `lsbs_to_remove_for_trees` is not None, and passes them an extra
# `lsbs_to_remove_for_trees` keyword argument
ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL = {
    "Less": rounded_numpy_less_for_trees,
    "Equal": rounded_numpy_equal_for_trees,
    "LessOrEqual": rounded_numpy_less_or_equal_for_trees,
}


# All numpy operators used in QuantizedOps
ONNX_OPS_TO_NUMPY_IMPL.update(ONNX_COMPARISON_OPS_TO_NUMPY_IMPL_FLOAT)

# All numpy operators used for tree-based models
ONNX_OPS_TO_NUMPY_IMPL_BOOL = {**ONNX_OPS_TO_NUMPY_IMPL, **ONNX_COMPARISON_OPS_TO_NUMPY_IMPL_BOOL}


IMPLEMENTED_ONNX_OPS = set(ONNX_OPS_TO_NUMPY_IMPL.keys())


Expand Down Expand Up @@ -465,6 +472,63 @@ def execute_onnx_with_numpy(
curr_inputs = (node_results[input_name] for input_name in node.input)
attributes = {attribute.name: get_attribute(attribute) for attribute in node.attribute}
outputs = ONNX_OPS_TO_NUMPY_IMPL_BOOL[node.op_type](*curr_inputs, **attributes)
node_results.update(zip(node.output, outputs))

return tuple(node_results[output.name] for output in graph.output)


def execute_onnx_with_numpy_trees(
    graph: onnx.GraphProto,
    lsbs_to_remove_for_trees: Optional[Tuple[int, int]],
    *inputs: numpy.ndarray,
) -> Tuple[numpy.ndarray, ...]:
    """Run the provided ONNX graph on the given inputs, for tree-based models only.

    Args:
        graph (onnx.GraphProto): The ONNX graph to execute.
        lsbs_to_remove_for_trees (Optional[Tuple[int, int]]): This parameter is exclusively used
            for optimizing tree-based models. It contains the numbers of least significant bits
            to remove during the tree traversal, where the first value refers to the first
            comparison (either "Less" or "LessOrEqual"), while the second value refers to the
            "Equal" comparison operation. If None, the graph is executed with
            `execute_onnx_with_numpy`.
        *inputs: The inputs of the graph.

    Returns:
        Tuple[numpy.ndarray]: The result of the graph's execution.
    """

    # Without any rounding configuration, fall back to the standard execution
    if lsbs_to_remove_for_trees is None:
        return execute_onnx_with_numpy(graph, *inputs)

    op_type: Callable[..., Tuple[numpy.ndarray[Any, Any], ...]]

    # Seed the results with the graph inputs first, then with the initializers, so that an
    # initializer takes precedence over an input carrying the same name
    node_results: Dict[str, numpy.ndarray] = {
        graph_input.name: input_value for graph_input, input_value in zip(graph.input, inputs)
    }
    for initializer in graph.initializer:
        node_results[initializer.name] = numpy_helper.to_array(initializer)

    for node in graph.node:
        # Kept lazy on purpose: the inputs are only looked up when the operator is called
        node_inputs = (node_results[name] for name in node.input)
        node_attributes = {attr.name: get_attribute(attr) for attr in node.attribute}

        rounded_impl = ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL.get(node.op_type)

        if rounded_impl is None:
            op_type = ONNX_OPS_TO_NUMPY_IMPL_BOOL[node.op_type]
        else:
            # The first LSB value refers to `Less` or `LessOrEqual` comparisons
            # The second LSB value refers to the `Equal` comparison
            stage = 1 if node.op_type == "Equal" else 0
            node_attributes["lsbs_to_remove_for_trees"] = lsbs_to_remove_for_trees[stage]

            # Relevant comparison nodes go through their rounded numpy implementation
            op_type = rounded_impl

        node_outputs = op_type(*node_inputs, **node_attributes)
        node_results.update(zip(node.output, node_outputs))

    return tuple(node_results[output.name] for output in graph.output)
Expand Down
Loading

0 comments on commit 064eb82

Please sign in to comment.