Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/cryptography/hazmat/asn1/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,10 @@ def decorator(cls: type[U]) -> type[U]:
inner = declarative_asn1.AnnotatedType(
rust_type, declarative_asn1.Annotation()
)
root = declarative_asn1.Type.ValueSet(cls, inner)
# Map from member value to member, used for O(1) lookups when
# decoding. This requires the member values to be hashable.
value_map = {member.value: member for member in members}
root = declarative_asn1.Type.ValueSet(cls, inner, value_map)

setattr(cls, "__asn1_root__", root)
return cls
Expand Down
19 changes: 6 additions & 13 deletions src/rust/src/declarative_asn1/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// for complete details.

use asn1::Parser;
use pyo3::types::{PyAnyMethods, PyListMethods, PyTypeMethods};
use pyo3::types::{PyAnyMethods, PyDictMethods, PyListMethods, PyTypeMethods};

use crate::asn1::big_byte_slice_to_py_int;
use crate::declarative_asn1::types::{
Expand Down Expand Up @@ -262,20 +262,13 @@ fn decode_value_set<'a>(
parser: &mut Parser<'a>,
cls: &pyo3::Py<pyo3::types::PyType>,
inner_type: &AnnotatedType,
value_map: &pyo3::Py<pyo3::types::PyDict>,
annotation: &Annotation,
) -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
let inner_ann_type = value_set_inner_type(py, inner_type, annotation)?;
let decoded = decode_annotated_type(py, parser, &inner_ann_type)?;
// NOTE: This is a linear scan over the members of the enum. If this
// ever becomes a performance problem, it could be replaced with a
// value -> member map stored in `Type::ValueSet` (keeping in mind
// that hash-based lookups won't work for the asn1 wrapper types,
// which implement `__eq__` but not `__hash__`).
for member in cls.bind(py).try_iter()? {
let member = member?;
if member.getattr(pyo3::intern!(py, "value"))?.eq(&decoded)? {
return Ok(member);
}
if let Some(member) = value_map.bind(py).get_item(&decoded)? {
return Ok(member);
}
Err(CryptographyError::Py(
pyo3::exceptions::PyValueError::new_err(format!(
Expand Down Expand Up @@ -452,8 +445,8 @@ pub(crate) fn decode_annotated_type<'a>(
))?
}
},
Type::ValueSet(cls, inner_type) => {
decode_value_set(py, parser, cls, inner_type.get(), annotation)?
Type::ValueSet(cls, inner_type, value_map) => {
decode_value_set(py, parser, cls, inner_type.get(), value_map, annotation)?
}
Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(),
Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(),
Expand Down
2 changes: 1 addition & 1 deletion src/rust/src/declarative_asn1/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
),
))
}
Type::ValueSet(cls, inner_type) => {
Type::ValueSet(cls, inner_type, _) => {
if !value.is_instance(cls.bind(py))? {
return Err(CryptographyError::Py(
pyo3::exceptions::PyTypeError::new_err(format!(
Expand Down
11 changes: 8 additions & 3 deletions src/rust/src/declarative_asn1/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@ pub enum Type {
/// a single underlying ASN.1 type).
/// The first element is the Python enum class, the second
/// element is the (already converted) underlying type of the
/// member values.
ValueSet(pyo3::Py<pyo3::types::PyType>, pyo3::Py<AnnotatedType>),
/// member values, and the third element is a map from member
/// value to enum member, used when decoding.
ValueSet(
pyo3::Py<pyo3::types::PyType>,
pyo3::Py<AnnotatedType>,
pyo3::Py<pyo3::types::PyDict>,
),

// Python types that we map to canonical ASN.1 types
//
Expand Down Expand Up @@ -686,7 +691,7 @@ pub(crate) fn is_tag_valid_for_type(
Type::Choice(variants) => variants.bind(py).into_iter().any(|v| {
is_tag_valid_for_variant(py, tag, v.cast::<Variant>().unwrap().get(), encoding)
}),
Type::ValueSet(_, t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding),
Type::ValueSet(_, t, _) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding),
Type::PyBool() => check_tag_with_encoding(bool::TAG, encoding, tag),
Type::PyInt() => check_tag_with_encoding(asn1::BigInt::TAG, encoding, tag),
Type::PyBytes() => {
Expand Down
6 changes: 5 additions & 1 deletion tests/hazmat/asn1/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,13 @@ def test_fields_of_variant_type(self) -> None:
choice = declarative_asn1.Type.Choice(my_list)
assert choice._0 is my_list

value_set = declarative_asn1.Type.ValueSet(type(None), ann_type)
my_value_map: dict = {}
value_set = declarative_asn1.Type.ValueSet(
type(None), ann_type, my_value_map
)
assert value_set._0 is type(None)
assert value_set._1 is ann_type
assert value_set._2 is my_value_map

def test_fields_of_variant_encoding(self) -> None:
from cryptography.hazmat.bindings._rust import declarative_asn1
Expand Down
Loading