Utility functions that can be reused by other pieces of code in the module.
- SUPPORTED_FLOAT_TYPES
- SUPPORTED_INT_TYPES
- SUPPORTED_TYPES
- MAX_BITWIDTH_BACKWARD_COMPATIBLE
- USE_OLD_VL
- QUANT_ROUND_LIKE_ROUND_PBS
replace_invalid_arg_name_chars(arg_name: str) → str
Sanitize arg_name by replacing invalid characters with _.
This does not check that the starting character of arg_name is valid.
Args:
arg_name
(str): the arg name to sanitize.
Returns:
str
: the sanitized arg name, with only chars in _VALID_ARG_CHARS.
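A minimal sketch of this kind of sanitization follows. It assumes that the valid characters are ASCII letters, digits and _; the module's real _VALID_ARG_CHARS is not listed here and may differ. sanitize_arg_name is a hypothetical name, not the module's function.

```python
import string

# Assumption: valid characters are ASCII letters, digits and "_".
# The module's real _VALID_ARG_CHARS may be defined differently.
_VALID_ARG_CHARS = frozenset(string.ascii_letters + string.digits + "_")


def sanitize_arg_name(arg_name: str) -> str:
    """Hypothetical re-implementation: replace each invalid character with "_"."""
    # Like the documented function, the first character is not checked,
    # so a name such as "1st-arg" keeps its leading digit.
    return "".join(char if char in _VALID_ARG_CHARS else "_" for char in arg_name)


print(sanitize_arg_name("input.0"))  # input_0
print(sanitize_arg_name("1st-arg"))  # 1st_arg
```

The second call shows why the note about the starting character matters: the result is still not a valid Python identifier.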
generate_proxy_function(
function_to_proxy: Callable,
desired_functions_arg_names: Iterable[str]
) → Tuple[Callable, Dict[str, str]]
Generate a proxy function for a function accepting only *args type arguments.
This returns a function compiled at runtime whose arguments are the sanitized versions of the names passed in desired_functions_arg_names.
Args:
function_to_proxy
(Callable): the function, defined like def f(*args), for which to return a function like f_proxy(arg_1, arg_2) for any number of arguments.
desired_functions_arg_names
(Iterable[str]): the argument names to use. These names are sanitized, and the mapping from each original argument name to its sanitized version is returned in a dictionary. Only the sanitized names work when calling the proxy function.
Returns:
Tuple[Callable, Dict[str, str]]
: the proxy function and the mapping of the original arg name to the new and sanitized arg names.
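The idea of a runtime-compiled proxy can be sketched with exec: build the source of a function with named parameters that forwards them positionally to f(*args). This is an illustration under stated assumptions, not the module's implementation; make_proxy_function is a hypothetical name, and sanitizing with re.sub(r"\W", "_", ...) is an assumed rule.

```python
import re
from typing import Callable, Dict, Iterable, Tuple


def make_proxy_function(
    function_to_proxy: Callable, desired_arg_names: Iterable[str]
) -> Tuple[Callable, Dict[str, str]]:
    """Hypothetical sketch: build f_proxy(a, b, ...) that forwards to f(*args)."""
    # Assumption: sanitize by replacing every non-word character with "_"
    mapping = {name: re.sub(r"\W", "_", name) for name in desired_arg_names}
    params = ", ".join(mapping.values())
    # e.g. "def proxy(x_0, y_1):\n    return function_to_proxy(x_0, y_1)\n"
    source = f"def proxy({params}):\n    return function_to_proxy({params})\n"
    namespace = {"function_to_proxy": function_to_proxy}
    exec(source, namespace)  # compile and define the proxy at runtime
    return namespace["proxy"], mapping


def add_all(*args):
    return sum(args)


proxy, mapping = make_proxy_function(add_all, ["x.0", "y.1"])
print(mapping)              # {'x.0': 'x_0', 'y.1': 'y_1'}
print(proxy(x_0=1, y_1=2))  # 3
```

In this sketch, names that sanitize to the same string, or that start with a digit, make exec raise a SyntaxError, so a fuller version would need to check for those cases.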
get_onnx_opset_version(onnx_model: ModelProto) → int
Return the ONNX opset_version.
Args:
onnx_model
(onnx.ModelProto): the model.
Returns:
int
: the ONNX opset version of the model
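One plausible way to read the opset version is to scan onnx.ModelProto.opset_import for the default domain ("" or "ai.onnx"). The module may simply read the first entry instead. To keep the sketch runnable without onnx installed, SimpleNamespace objects stand in for the model and its OperatorSetIdProto entries. default_opset_version is a hypothetical name.

```python
from types import SimpleNamespace


def default_opset_version(model) -> int:
    """Return the opset version declared for the default ONNX domain."""
    for opset in model.opset_import:
        # The default operator set uses the empty domain ("" or, equivalently, "ai.onnx")
        if opset.domain in ("", "ai.onnx"):
            return int(opset.version)
    raise ValueError("The model does not import the default ONNX operator set")


# Stand-ins mimicking onnx.ModelProto.opset_import, so no onnx install is needed
fake_model = SimpleNamespace(
    opset_import=[
        SimpleNamespace(domain="com.microsoft", version=1),
        SimpleNamespace(domain="", version=14),
    ]
)
print(default_opset_version(fake_model))  # 14
```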
manage_parameters_for_pbs_errors(
p_error: Optional[float] = None,
global_p_error: Optional[float] = None
)
Return (p_error, global_p_error) that we want to give to Concrete.
The returned (p_error, global_p_error) depend on the user's parameters and on the way we want to manage defaults in Concrete ML, which may differ from the way defaults are managed in Concrete.
Principle:
- if none are set, we set global_p_error to a default value of our choice
- if both are set, we raise an error
- if one is set, we use it and forward it to Concrete
Note that global_p_error is currently set to 0 in the FHE simulation mode.
Args:
p_error
(Optional[float]): probability of error of a single PBS.
global_p_error
(Optional[float]): probability of error of the full circuit.
Returns:
(p_error, global_p_error)
: parameters to give to the compiler
Raises:
ValueError
: if both parameters are set (unlike in Concrete-Python)
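The three rules of the principle above can be sketched as follows. DEFAULT_GLOBAL_P_ERROR is a hypothetical placeholder, since the default Concrete ML actually chooses is not stated here, and the sketch does not model the FHE simulation-mode behavior. pick_pbs_error_parameters is a hypothetical name.

```python
from typing import Optional, Tuple

# Hypothetical default; the value Concrete ML actually picks is not given here.
DEFAULT_GLOBAL_P_ERROR = 0.01


def pick_pbs_error_parameters(
    p_error: Optional[float] = None, global_p_error: Optional[float] = None
) -> Tuple[Optional[float], Optional[float]]:
    """Apply the three rules: none set, both set, exactly one set."""
    if p_error is not None and global_p_error is not None:
        raise ValueError("Only one of p_error and global_p_error can be set")
    if p_error is None and global_p_error is None:
        # Neither given: fall back to our own default for the whole circuit
        return None, DEFAULT_GLOBAL_P_ERROR
    # Exactly one given: forward it unchanged
    return p_error, global_p_error


print(pick_pbs_error_parameters())              # (None, 0.01)
print(pick_pbs_error_parameters(p_error=1e-4))  # (0.0001, None)
```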
check_there_is_no_p_error_options_in_configuration(configuration)
Check the user did not set p_error or global_p_error in configuration.
Setting them there would be dangerous, since we pass them as direct arguments in our calls to Concrete-Python.
Args:
configuration
: Configuration object to use during compilation
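A minimal sketch of such a guard is shown below. It assumes the configuration exposes p_error and global_p_error as attributes that are None when unset. DummyConfiguration is a hypothetical stand-in for the Concrete configuration object, and ensure_no_p_error_in_configuration is a hypothetical name. The module may use assertions rather than ValueError.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DummyConfiguration:
    """Hypothetical stand-in for a Concrete configuration object."""

    p_error: Optional[float] = None
    global_p_error: Optional[float] = None


def ensure_no_p_error_in_configuration(configuration: Any) -> None:
    """Raise if p_error or global_p_error is set on the configuration."""
    if configuration is None:
        return
    for name in ("p_error", "global_p_error"):
        # getattr with a default tolerates objects that lack the attribute
        if getattr(configuration, name, None) is not None:
            raise ValueError(
                f"Don't set {name} in the configuration; pass it as a direct argument instead."
            )


ensure_no_p_error_in_configuration(DummyConfiguration())  # passes silently
```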
get_model_class(model_class)
Return the class of the model (instantiated or not), which can be a partial() instance.
Args:
model_class
: The model, which can be a partial() instance.
Returns: The model's class.
is_model_class_in_a_list(model_class, a_list)
Indicate if a model class, which can be a partial() instance, is an element of a_list.
Args:
model_class
: The model, which can be a partial() instance.
a_list
: The list to look in.
Returns: Whether the model's class is in the list.
get_model_name(model_class)
Return the name of the model, which can be a partial() instance.
Args:
model_class
: The model, which can be a partial() instance.
Returns: the model's name.
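The three helpers above share one step: unwrapping a functools.partial (through its .func attribute) and mapping an instance to its type. The sketch below shows that step under these assumptions. The function names and LinearModel are hypothetical.

```python
import functools
import inspect


def model_class_of(model_class):
    """Return the underlying class for a class, an instance, or a functools.partial."""
    if isinstance(model_class, functools.partial):
        model_class = model_class.func
    # An instantiated model is mapped to its type
    return model_class if inspect.isclass(model_class) else type(model_class)


def model_class_in_list(model_class, a_list) -> bool:
    return model_class_of(model_class) in a_list


def model_name_of(model_class) -> str:
    return model_class_of(model_class).__name__


class LinearModel:  # hypothetical model class
    def __init__(self, n_bits=8):
        self.n_bits = n_bits


PartialLinear = functools.partial(LinearModel, n_bits=4)
print(model_class_of(PartialLinear) is LinearModel)       # True
print(model_class_of(LinearModel()) is LinearModel)       # True
print(model_class_in_list(PartialLinear, [LinearModel]))  # True
print(model_name_of(PartialLinear))                       # LinearModel
```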
is_classifier_or_partial_classifier(model_class)
Indicate if the model class represents a classifier.
Args:
model_class
: The model class, which can be a functools partial class.
Returns:
bool
: If the model class represents a classifier.
is_regressor_or_partial_regressor(model_class)
Indicate if the model class represents a regressor.
Args:
model_class
: The model class, which can be a functools partial class.
Returns:
bool
: If the model class represents a regressor.
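A sketch of both checks is shown below. It assumes the model classes follow the _estimator_type class attribute that scikit-learn's classifier and regressor mixins have historically set. The module may instead rely on scikit-learn's own helpers. The function and class names are hypothetical.

```python
import functools


def _unwrap(model_class):
    # Look through a functools.partial wrapper to the wrapped class
    return model_class.func if isinstance(model_class, functools.partial) else model_class


def looks_like_classifier(model_class) -> bool:
    return getattr(_unwrap(model_class), "_estimator_type", None) == "classifier"


def looks_like_regressor(model_class) -> bool:
    return getattr(_unwrap(model_class), "_estimator_type", None) == "regressor"


class MyClassifier:  # hypothetical classes following the scikit-learn convention
    _estimator_type = "classifier"


class MyRegressor:
    _estimator_type = "regressor"


print(looks_like_classifier(functools.partial(MyClassifier)))  # True
print(looks_like_regressor(MyClassifier))                      # False
```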
is_pandas_dataframe(input_container: Any) → bool
Indicate if the input container is a Pandas DataFrame.
This function is inspired by Scikit-Learn's test validation tools and avoids having to add and import Pandas as an additional dependency of the project. See https://github.com/scikit-learn/scikit-learn/blob/98cf537f5/sklearn/utils/validation.py#L629
Args:
input_container
(Any): The input container to consider
Returns:
bool
: If the input container is a DataFrame
is_pandas_series(input_container: Any) → bool
Indicate if the input container is a Pandas Series.
This function is inspired by Scikit-Learn's test validation tools and avoids having to add and import Pandas as an additional dependency of the project. See https://github.com/scikit-learn/scikit-learn/blob/98cf537f5/sklearn/utils/validation.py#L629
Args:
input_container
(Any): The input container to consider
Returns:
bool
: If the input container is a Series
is_pandas_type(input_container: Any) → bool
Indicate if the input container is a Pandas DataFrame or Series.
Args:
input_container
(Any): The input container to consider
Returns:
bool
: If the input container is a DataFrame or a Series
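One way to detect pandas objects without importing pandas is to check sys.modules. If pandas was never imported, no object can be a DataFrame or Series. This is a swapped-in technique shown for illustration and is not necessarily what the linked scikit-learn code does. The function names are hypothetical.

```python
import sys
from typing import Any


def _loaded_pandas():
    # Return the pandas module only if some code has already imported it
    return sys.modules.get("pandas")


def is_dataframe_no_import(obj: Any) -> bool:
    pd = _loaded_pandas()
    return pd is not None and isinstance(obj, pd.DataFrame)


def is_series_no_import(obj: Any) -> bool:
    pd = _loaded_pandas()
    return pd is not None and isinstance(obj, pd.Series)


def is_pandas_object_no_import(obj: Any) -> bool:
    return is_dataframe_no_import(obj) or is_series_no_import(obj)


print(is_pandas_object_no_import([1, 2, 3]))  # False
```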
check_dtype_and_cast(
values: Any,
expected_dtype: str,
error_information: Optional[str] = ''
)
Convert any allowed type into an array and cast it if required.
If the type of values doesn't match any supported type or the expected dtype, a ValueError is raised.
Args:
values
(Any): The values to consider
expected_dtype
(str): The expected dtype, either "float32" or "int64"
error_information
(str): Additional information to put in front of the error message when raising a ValueError. Defaults to an empty string.
Returns:
(Union[numpy.ndarray, torch.utils.data.dataset.Subset])
: The values with proper dtype.
Raises:
ValueError
: If the values' dtype doesn't match the expected one or casting is not possible.
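The sketch below illustrates "convert, then cast if required" with NumPy under one assumed rule: casting is only allowed within the same dtype family. Any float is cast to float32, and any integer that numpy.can_cast safely converts is cast to int64. The module's actual rules, and its handling of torch Subset objects, may differ. cast_to_expected_dtype is a hypothetical name.

```python
import numpy as np


def cast_to_expected_dtype(values, expected_dtype: str, error_information: str = ""):
    """Hypothetical sketch: convert to an array and cast within the same dtype family."""
    array = np.asarray(values)
    if expected_dtype == "float32" and array.dtype.kind == "f":
        # Any float precision (e.g. float64) is cast down to float32
        return array.astype(np.float32)
    if (
        expected_dtype == "int64"
        and array.dtype.kind in "iu"
        and np.can_cast(array.dtype, np.int64)  # rejects uint64, which could overflow
    ):
        return array.astype(np.int64)
    raise ValueError(
        f"{error_information}Values of dtype {array.dtype} cannot be cast to {expected_dtype}."
    )


print(cast_to_expected_dtype([0.5, 1.5], "float32").dtype)  # float32
```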
compute_bits_precision(x: ndarray) → int
Compute the number of bits required to represent x.
Args:
x
(numpy.ndarray): Integer data
Returns:
int
: the number of bits required to represent x
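The sketch below shows one common convention for this calculation. Non-negative data is treated as unsigned, and data with negative values as signed two's complement, where b bits cover [-2**(b-1), 2**(b-1) - 1]. The convention used by the module is not stated here and may differ (for example, it could be based on the value range). bits_needed is a hypothetical name.

```python
import numpy as np


def bits_needed(x: np.ndarray) -> int:
    """Hypothetical sketch: bits to hold every integer in x."""
    lo, hi = int(x.min()), int(x.max())
    if lo >= 0:
        # Unsigned: enough bits for the largest value, with at least 1 bit
        return max(1, hi.bit_length())
    # Signed two's complement: one sign bit plus the magnitude bits
    return max(hi.bit_length(), (-lo - 1).bit_length()) + 1


print(bits_needed(np.array([0, 255])))  # 8
print(bits_needed(np.array([-8, 7])))   # 4
```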
is_brevitas_model(model: Module) → bool
Check if a model is a Brevitas type.
Args:
model
: PyTorch model.
Returns:
bool
: True if model is a Brevitas network.
to_tuple(x: Any) → tuple
Make the input a tuple if it is not already the case.
Args:
x
(Any): The input to consider. It can already be a tuple.
Returns:
tuple
: The input as a tuple.
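A minimal sketch of this behavior follows, assuming that only actual tuples are left unchanged and that every other input, including a list, is wrapped into a one-element tuple. as_tuple is a hypothetical name.

```python
from typing import Any


def as_tuple(x: Any) -> tuple:
    """Hypothetical sketch: wrap x in a tuple unless it already is one."""
    return x if isinstance(x, tuple) else (x,)


print(as_tuple(3))       # (3,)
print(as_tuple((1, 2)))  # (1, 2)
print(as_tuple([1, 2]))  # ([1, 2],)
```

Wrapping a list rather than converting it keeps a single list-valued input as one element. This matters when the tuple is later unpacked as positional arguments.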
all_values_are_integers(*values: Any) → bool
Indicate if all unpacked values are of a supported integer dtype.
Args:
*values (Any)
: The values to consider.
Returns:
bool
: Whether all values are supported integers or not.
all_values_are_floats(*values: Any) → bool
Indicate if all unpacked values are of a supported float dtype.
Args:
*values (Any)
: The values to consider.
Returns:
bool
: Whether all values are supported floating points or not.
all_values_are_of_dtype(*values: Any, dtypes: Union[str, List[str]]) → bool
Indicate if all unpacked values are of the specified dtype(s).
Args:
*values (Any)
: The values to consider.
dtypes
(Union[str, List[str]]): The dtype(s) to consider.
Returns:
bool
: Whether all values are of the specified dtype(s) or not.
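The three dtype checks above can be sketched with NumPy. The generic check accepts one dtype name or a list of names, and the integer and float variants call it with fixed lists. INT_DTYPES and FLOAT_DTYPES are assumed stand-ins for SUPPORTED_INT_TYPES and SUPPORTED_FLOAT_TYPES, whose exact contents are not listed in this reference. The function names are hypothetical.

```python
import numpy as np

# Assumption: stand-ins for SUPPORTED_INT_TYPES / SUPPORTED_FLOAT_TYPES
INT_DTYPES = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
FLOAT_DTYPES = ["float16", "float32", "float64"]


def values_have_dtype(*values, dtypes) -> bool:
    """Hypothetical sketch: True if every value's dtype is one of dtypes."""
    if isinstance(dtypes, str):
        dtypes = [dtypes]
    allowed = [np.dtype(dtype) for dtype in dtypes]
    # np.asarray lets plain lists and scalars be checked as well as arrays
    return all(np.asarray(value).dtype in allowed for value in values)


def values_are_integers(*values) -> bool:
    return values_have_dtype(*values, dtypes=INT_DTYPES)


def values_are_floats(*values) -> bool:
    return values_have_dtype(*values, dtypes=FLOAT_DTYPES)


print(values_are_integers(np.array([1], dtype=np.int32), np.array([2], dtype=np.uint8)))  # True
print(values_are_floats(np.array([1.0]), np.array([1])))  # False
```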
array_allclose_and_same_shape(
a,
b,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False
) → bool
Check if two numpy arrays are equal within the given tolerances and have the same shape.
Args:
a
(numpy.ndarray): The first input array
b
(numpy.ndarray): The second input array
rtol
(float): The relative tolerance parameter
atol
(float): The absolute tolerance parameter
equal_nan
(bool): Whether to compare NaNs as equal. If True, NaNs in a are considered equal to NaNs in b.
Returns:
bool
: True if the arrays have the same shape and all elements are equal within the specified tolerances, False otherwise.
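The shape check matters because numpy.allclose broadcasts its inputs, so arrays of shapes (3,) and (1, 3) can compare as close. A minimal sketch that adds the explicit shape comparison follows. allclose_same_shape is a hypothetical name.

```python
import numpy as np


def allclose_same_shape(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool:
    """Hypothetical sketch: np.allclose plus a strict shape comparison."""
    a, b = np.asarray(a), np.asarray(b)
    # np.allclose alone would broadcast (3,) against (1, 3) and may return True
    if a.shape != b.shape:
        return False
    return bool(np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))


print(np.allclose(np.ones(3), np.ones((1, 3))))          # True
print(allclose_same_shape(np.ones(3), np.ones((1, 3))))  # False
```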
process_rounding_threshold_bits(rounding_threshold_bits)
Check and process the rounding_threshold_bits parameter.
Args:
rounding_threshold_bits
(Union[None, int, Dict[str, Union[str, int]]]): Defines precision rounding for model accumulators. Accepts None, an int, or a dict. The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE) and 'n_bits' ('auto' or int)
Returns:
Dict[str, Union[str, int]]
: Processed rounding_threshold_bits dictionary.
Raises:
NotImplementedError
: If 'auto' rounding is specified but not implemented.
ValueError
: If an invalid type or value is provided for rounding_threshold_bits.
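The normalization described above can be sketched as follows. It assumes that None is passed through unchanged, that an int becomes {"n_bits": int}, and that the method defaults to EXACT when omitted. These defaults are assumptions, not confirmed behavior. EXACT and APPROXIMATE are hypothetical string stand-ins for fhe.Exactness.EXACT and fhe.Exactness.APPROXIMATE, and normalize_rounding_bits is a hypothetical name.

```python
from typing import Dict, Optional, Union

# Hypothetical string stand-ins for fhe.Exactness.EXACT / fhe.Exactness.APPROXIMATE
EXACT, APPROXIMATE = "EXACT", "APPROXIMATE"


def normalize_rounding_bits(rounding) -> Optional[Dict[str, Union[str, int]]]:
    """Sketch: turn None / int / dict into a {'n_bits': ..., 'method': ...} dict."""
    if rounding is None:
        return None  # assumption: None means "no rounding" and is passed through
    if isinstance(rounding, bool) or not isinstance(rounding, (int, dict)):
        raise ValueError(f"Invalid type for rounding_threshold_bits: {type(rounding).__name__}")
    if isinstance(rounding, int):
        rounding = {"n_bits": rounding}
    unknown_keys = set(rounding) - {"n_bits", "method"}
    if unknown_keys:
        raise ValueError(f"Unknown keys in rounding_threshold_bits: {sorted(unknown_keys)}")
    n_bits = rounding.get("n_bits")
    if n_bits == "auto":
        raise NotImplementedError("Automatic rounding is not implemented in this sketch")
    if isinstance(n_bits, bool) or not isinstance(n_bits, int) or n_bits < 1:
        raise ValueError(f"n_bits must be a positive int, got {n_bits!r}")
    # Assumption: the method defaults to EXACT when it is not given
    method = rounding.get("method", EXACT)
    if method not in (EXACT, APPROXIMATE):
        raise ValueError(f"Unknown rounding method: {method!r}")
    return {"n_bits": n_bits, "method": method}


print(normalize_rounding_bits(6))  # {'n_bits': 6, 'method': 'EXACT'}
```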
FheMode
Enum representing the execution mode.
This enum inherits from str so that a string parameter can easily be compared to its equivalent Enum attribute.
Examples:
>>> fhe_disable = FheMode.DISABLE
>>> fhe_disable == "disable"
True
>>> fhe_disable == "execute"
False
>>> FheMode.is_valid("simulate")
True
>>> FheMode.is_valid(FheMode.EXECUTE)
True
>>> FheMode.is_valid("predict_in_fhe")
False
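A minimal sketch of a str-based enum that reproduces the examples above is shown below. ExecutionMode is a hypothetical name that mirrors FheMode. The member values are taken from the examples, and is_valid compares the input against the member values with ==. The real FheMode may define more members or validate differently.

```python
import enum


class ExecutionMode(str, enum.Enum):
    """Hypothetical mirror of FheMode: members compare equal to their string values."""

    DISABLE = "disable"
    SIMULATE = "simulate"
    EXECUTE = "execute"

    @classmethod
    def is_valid(cls, mode) -> bool:
        # List membership uses ==, so both plain strings and members are accepted
        return mode in [member.value for member in cls]


print(ExecutionMode.DISABLE == "disable")          # True
print(ExecutionMode.is_valid("predict_in_fhe"))    # False
```

Mixing in str is what makes the comparison with a plain string succeed, because str.__eq__ is used for the members.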