diff --git a/src/cryptography/hazmat/asn1/asn1.py b/src/cryptography/hazmat/asn1/asn1.py index 8c01f74985c6..d77581fcc600 100644 --- a/src/cryptography/hazmat/asn1/asn1.py +++ b/src/cryptography/hazmat/asn1/asn1.py @@ -463,7 +463,10 @@ def decorator(cls: type[U]) -> type[U]: inner = declarative_asn1.AnnotatedType( rust_type, declarative_asn1.Annotation() ) - root = declarative_asn1.Type.ValueSet(cls, inner) + # Map from member value to member, used for O(1) lookups when + # decoding. This requires the member values to be hashable. + value_map = {member.value: member for member in members} + root = declarative_asn1.Type.ValueSet(cls, inner, value_map) setattr(cls, "__asn1_root__", root) return cls diff --git a/src/rust/src/declarative_asn1/decode.rs b/src/rust/src/declarative_asn1/decode.rs index a328045faa2d..a20c83fc4ed7 100644 --- a/src/rust/src/declarative_asn1/decode.rs +++ b/src/rust/src/declarative_asn1/decode.rs @@ -3,7 +3,7 @@ // for complete details. use asn1::Parser; -use pyo3::types::{PyAnyMethods, PyListMethods, PyTypeMethods}; +use pyo3::types::{PyAnyMethods, PyDictMethods, PyListMethods, PyTypeMethods}; use crate::asn1::big_byte_slice_to_py_int; use crate::declarative_asn1::types::{ @@ -262,20 +262,13 @@ fn decode_value_set<'a>( parser: &mut Parser<'a>, cls: &pyo3::Py, inner_type: &AnnotatedType, + value_map: &pyo3::Py, annotation: &Annotation, ) -> ParseResult> { let inner_ann_type = value_set_inner_type(py, inner_type, annotation)?; let decoded = decode_annotated_type(py, parser, &inner_ann_type)?; - // NOTE: This is a linear scan over the members of the enum. If this - // ever becomes a performance problem, it could be replaced with a - // value -> member map stored in `Type::ValueSet` (keeping in mind - // that hash-based lookups won't work for the asn1 wrapper types, - // which implement `__eq__` but not `__hash__`). - for member in cls.bind(py).try_iter()? { - let member = member?; - if member.getattr(pyo3::intern!(py, "value"))?.eq(&decoded)? { - return Ok(member); - } + if let Some(member) = value_map.bind(py).get_item(&decoded)? { + return Ok(member); } Err(CryptographyError::Py( pyo3::exceptions::PyValueError::new_err(format!( @@ -452,8 +445,8 @@ pub(crate) fn decode_annotated_type<'a>( ))? } }, - Type::ValueSet(cls, inner_type) => { - decode_value_set(py, parser, cls, inner_type.get(), annotation)? + Type::ValueSet(cls, inner_type, value_map) => { + decode_value_set(py, parser, cls, inner_type.get(), value_map, annotation)? } Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(), Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(), diff --git a/src/rust/src/declarative_asn1/encode.rs b/src/rust/src/declarative_asn1/encode.rs index 139978fc173b..669189b00a6e 100644 --- a/src/rust/src/declarative_asn1/encode.rs +++ b/src/rust/src/declarative_asn1/encode.rs @@ -181,7 +181,7 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> { ), )) } - Type::ValueSet(cls, inner_type) => { + Type::ValueSet(cls, inner_type, _) => { if !value.is_instance(cls.bind(py))? { return Err(CryptographyError::Py( pyo3::exceptions::PyTypeError::new_err(format!( diff --git a/src/rust/src/declarative_asn1/types.rs b/src/rust/src/declarative_asn1/types.rs index 9eb6f470c36e..d296e493e941 100644 --- a/src/rust/src/declarative_asn1/types.rs +++ b/src/rust/src/declarative_asn1/types.rs @@ -38,8 +38,13 @@ pub enum Type { /// a single underlying ASN.1 type). /// The first element is the Python enum class, the second /// element is the (already converted) underlying type of the - /// member values. - ValueSet(pyo3::Py, pyo3::Py), + /// member values, and the third element is a map from member + /// value to enum member, used when decoding. + ValueSet( + pyo3::Py, + pyo3::Py, + pyo3::Py, + ), // Python types that we map to canonical ASN.1 types // @@ -686,7 +691,7 @@ pub(crate) fn is_tag_valid_for_type( Type::Choice(variants) => variants.bind(py).into_iter().any(|v| { is_tag_valid_for_variant(py, tag, v.cast::().unwrap().get(), encoding) }), - Type::ValueSet(_, t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding), + Type::ValueSet(_, t, _) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding), Type::PyBool() => check_tag_with_encoding(bool::TAG, encoding, tag), Type::PyInt() => check_tag_with_encoding(asn1::BigInt::TAG, encoding, tag), Type::PyBytes() => { diff --git a/tests/hazmat/asn1/test_api.py b/tests/hazmat/asn1/test_api.py index b9fe418d6931..46023518c83e 100644 --- a/tests/hazmat/asn1/test_api.py +++ b/tests/hazmat/asn1/test_api.py @@ -428,9 +428,13 @@ def test_fields_of_variant_type(self) -> None: choice = declarative_asn1.Type.Choice(my_list) assert choice._0 is my_list - value_set = declarative_asn1.Type.ValueSet(type(None), ann_type) + my_value_map: dict = {} + value_set = declarative_asn1.Type.ValueSet( + type(None), ann_type, my_value_map + ) assert value_set._0 is type(None) assert value_set._1 is ann_type + assert value_set._2 is my_value_map def test_fields_of_variant_encoding(self) -> None: from cryptography.hazmat.bindings._rust import declarative_asn1