Skip to content

Commit

Permalink
Hardcode error type in IntoVisitor (#41)
Browse files Browse the repository at this point in the history
* hard-code Error type that IntoVisitor expects

* simplify some bounds and add a dev note

* cargo fmt

* Add adapter to make it easy to write visitors whose errors convert nicely

* remove now-unneeded Error::from's

* add a quick sanity check test for VisitorWithCrateError

* fix no-std tests

* IntoVisitor => DecodeAsType on a couple of tests to undo prev change
  • Loading branch information
jsdw authored Nov 10, 2023
1 parent 6fc4aea commit 2c59157
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 54 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
/Cargo.lock
.DS_Store
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"scale-decode-derive",
"testing/no_std",
]
resolver = "2"

[workspace.package]
version = "0.9.0"
Expand Down
7 changes: 7 additions & 0 deletions scale-decode/src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ impl From<DecodeError> for Error {
}
}

impl From<codec::Error> for Error {
fn from(err: codec::Error) -> Error {
let err: DecodeError = err.into();
Error::new(err.into())
}
}

/// The underlying nature of the error.
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum ErrorKind {
Expand Down
53 changes: 13 additions & 40 deletions scale-decode/src/impls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ macro_rules! impl_decode_seq_via_collect {
impl <$generic> Visitor for BasicVisitor<$ty<$generic>>
where
$generic: IntoVisitor,
Error: From<<$generic::Visitor as Visitor>::Error>,
$( $($where)* )?
{
type Value<'scale, 'info> = $ty<$generic>;
Expand Down Expand Up @@ -306,11 +305,7 @@ macro_rules! array_method_impl {
Ok(arr)
}};
}
impl<const N: usize, T> Visitor for BasicVisitor<[T; N]>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<const N: usize, T: IntoVisitor> Visitor for BasicVisitor<[T; N]> {
type Value<'scale, 'info> = [T; N];
type Error = Error;

Expand All @@ -331,22 +326,14 @@ where

visit_single_field_composite_tuple_impls!();
}
impl<const N: usize, T> IntoVisitor for [T; N]
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<const N: usize, T: IntoVisitor> IntoVisitor for [T; N] {
type Visitor = BasicVisitor<[T; N]>;
fn into_visitor() -> Self::Visitor {
BasicVisitor { _marker: core::marker::PhantomData }
}
}

impl<T> Visitor for BasicVisitor<BTreeMap<String, T>>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<T: IntoVisitor> Visitor for BasicVisitor<BTreeMap<String, T>> {
type Error = Error;
type Value<'scale, 'info> = BTreeMap<String, T>;

Expand All @@ -365,19 +352,15 @@ where
// Decode the value now that we have a valid name.
let Some(val) = value.decode_item(T::into_visitor()) else { break };
// Save to the map.
let val = val.map_err(|e| Error::from(e).at_field(key.to_owned()))?;
let val = val.map_err(|e| e.at_field(key.to_owned()))?;
map.insert(key.to_owned(), val);
}
Ok(map)
}
}
impl_into_visitor!(BTreeMap<String, T>);

impl<T> Visitor for BasicVisitor<Option<T>>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<T: IntoVisitor> Visitor for BasicVisitor<Option<T>> {
type Error = Error;
type Value<'scale, 'info> = Option<T>;

Expand All @@ -391,7 +374,7 @@ where
.fields()
.decode_item(T::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_variant("Some"))?
.map_err(|e| e.at_variant("Some"))?
.expect("checked for 1 field already so should be ok");
Ok(Some(val))
} else if value.name() == "None" && value.fields().remaining() == 0 {
Expand All @@ -407,13 +390,7 @@ where
}
impl_into_visitor!(Option<T>);

impl<T, E> Visitor for BasicVisitor<Result<T, E>>
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
E: IntoVisitor,
Error: From<<E::Visitor as Visitor>::Error>,
{
impl<T: IntoVisitor, E: IntoVisitor> Visitor for BasicVisitor<Result<T, E>> {
type Error = Error;
type Value<'scale, 'info> = Result<T, E>;

Expand All @@ -427,15 +404,15 @@ where
.fields()
.decode_item(T::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_variant("Ok"))?
.map_err(|e| e.at_variant("Ok"))?
.expect("checked for 1 field already so should be ok");
Ok(Ok(val))
} else if value.name() == "Err" && value.fields().remaining() == 1 {
let val = value
.fields()
.decode_item(E::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_variant("Err"))?
.map_err(|e| e.at_variant("Err"))?
.expect("checked for 1 field already so should be ok");
Ok(Err(val))
} else {
Expand Down Expand Up @@ -541,7 +518,7 @@ macro_rules! tuple_method_impl {
let v = $value
.decode_item($t::into_visitor())
.transpose()
.map_err(|e| Error::from(e).at_idx(idx))?
.map_err(|e| e.at_idx(idx))?
.expect("length already checked via .remaining()");
idx += 1;
v
Expand Down Expand Up @@ -593,7 +570,6 @@ macro_rules! impl_decode_tuple {
impl < $($t),* > Visitor for BasicVisitor<($($t,)*)>
where $(
$t: IntoVisitor,
Error: From<<$t::Visitor as Visitor>::Error>,
)*
{
type Value<'scale, 'info> = ($($t,)*);
Expand Down Expand Up @@ -621,7 +597,7 @@ macro_rules! impl_decode_tuple {

// We can turn this tuple into a visitor which knows how to decode it:
impl < $($t),* > IntoVisitor for ($($t,)*)
where $( $t: IntoVisitor, Error: From<<$t::Visitor as Visitor>::Error>, )*
where $( $t: IntoVisitor, )*
{
type Visitor = BasicVisitor<($($t,)*)>;
fn into_visitor() -> Self::Visitor {
Expand All @@ -631,7 +607,7 @@ macro_rules! impl_decode_tuple {

// We can decode given a list of fields (just delegate to the visitor impl:
impl < $($t),* > DecodeAsFields for ($($t,)*)
where $( $t: IntoVisitor, Error: From<<$t::Visitor as Visitor>::Error>, )*
where $( $t: IntoVisitor, )*
{
fn decode_as_fields<'info>(input: &mut &[u8], fields: &mut dyn FieldIter<'info>, types: &'info scale_info::PortableRegistry) -> Result<Self, Error> {
let mut composite = crate::visitor::types::Composite::new(input, crate::EMPTY_SCALE_INFO_PATH, fields, types, false);
Expand Down Expand Up @@ -676,14 +652,11 @@ fn decode_items_using<'a, 'scale, 'info, D: DecodeItemIterator<'scale, 'info>, T
) -> impl Iterator<Item = Result<T, Error>> + 'a
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
D: DecodeItemIterator<'scale, 'info>,
{
let mut idx = 0;
core::iter::from_fn(move || {
let item = decoder
.decode_item(T::into_visitor())
.map(|res| res.map_err(|e| Error::from(e).at_idx(idx)));
let item = decoder.decode_item(T::into_visitor()).map(|res| res.map_err(|e| e.at_idx(idx)));
idx += 1;
item
})
Expand Down
19 changes: 10 additions & 9 deletions scale-decode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ use alloc::vec::Vec;
/// This trait is implemented for any type `T` where `T` implements [`IntoVisitor`] and the errors returned
/// from this [`Visitor`] can be converted into [`Error`]. It's essentially a convenience wrapper around
/// [`visitor::decode_with_visitor`] that mirrors `scale-encode`'s `EncodeAsType`.
pub trait DecodeAsType: Sized {
pub trait DecodeAsType: Sized + IntoVisitor {
/// Given some input bytes, a `type_id`, and type registry, attempt to decode said bytes into
/// `Self`. Implementations should modify the `&mut` reference to the bytes such that any bytes
/// not used in the course of decoding are still pointed to after decoding is complete.
Expand All @@ -192,11 +192,7 @@ pub trait DecodeAsType: Sized {
) -> Result<Self, Error>;
}

impl<T> DecodeAsType for T
where
T: IntoVisitor,
Error: From<<T::Visitor as Visitor>::Error>,
{
impl<T: Sized + IntoVisitor> DecodeAsType for T {
fn decode_as_type_maybe_compact(
input: &mut &[u8],
type_id: u32,
Expand Down Expand Up @@ -267,11 +263,16 @@ pub trait FieldIter<'a>: Iterator<Item = Field<'a>> {}
impl<'a, T> FieldIter<'a> for T where T: Iterator<Item = Field<'a>> {}

/// This trait can be implemented on any type that has an associated [`Visitor`] responsible for decoding
/// SCALE encoded bytes to it. If you implement this on some type and the [`Visitor`] that you return has
/// an error type that converts into [`Error`], then you'll also get a [`DecodeAsType`] implementation for free.
/// SCALE encoded bytes to it whose error type is [`Error`]. Anything that implements this trait gets a
/// [`DecodeAsType`] implementation for free.
// Dev note: This used to allow for any Error type that could be converted into `scale_decode::Error`.
// The problem with this is that the `DecodeAsType` trait became tricky to use in some contexts, because it
// didn't automatically imply so much. Realistically, being stricter here shouldn't matter too much; derive
// impls all use `scale_decode::Error` anyway, and manual impls can just manually convert into the error
// rather than rely on auto conversion, if they care about also being able to impl `DecodeAsType`.
pub trait IntoVisitor {
/// The visitor type used to decode SCALE encoded bytes to `Self`.
type Visitor: for<'scale, 'info> visitor::Visitor<Value<'scale, 'info> = Self>;
type Visitor: for<'scale, 'info> visitor::Visitor<Value<'scale, 'info> = Self, Error = Error>;
/// A means of obtaining this visitor.
fn into_visitor() -> Self::Visitor;
}
Expand Down
110 changes: 105 additions & 5 deletions scale-decode/src/visitor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,26 @@ pub enum DecodeAsTypeResult<V, R> {
Decoded(R),
}

impl<V, R> DecodeAsTypeResult<V, R> {
/// If we have a [`DecodeAsTypeResult::Decoded`], the function provided will
/// map this decoded result to whatever it returns.
pub fn map_decoded<T, F: FnOnce(R) -> T>(self, f: F) -> DecodeAsTypeResult<V, T> {
match self {
DecodeAsTypeResult::Skipped(s) => DecodeAsTypeResult::Skipped(s),
DecodeAsTypeResult::Decoded(r) => DecodeAsTypeResult::Decoded(f(r)),
}
}

/// If we have a [`DecodeAsTypeResult::Skipped`], the function provided will
/// map this skipped value to whatever it returns.
pub fn map_skipped<T, F: FnOnce(V) -> T>(self, f: F) -> DecodeAsTypeResult<T, R> {
match self {
DecodeAsTypeResult::Skipped(s) => DecodeAsTypeResult::Skipped(f(s)),
DecodeAsTypeResult::Decoded(r) => DecodeAsTypeResult::Decoded(r),
}
}
}

/// This is implemented for visitor related types which have a `decode_item` method,
/// and allows you to generically talk about decoding unnamed items.
pub trait DecodeItemIterator<'scale, 'info> {
Expand Down Expand Up @@ -358,6 +378,34 @@ impl Visitor for IgnoreVisitor {
}
}

/// Some [`Visitor`] implementations may want to return an error type other than [`crate::Error`], which means
/// that they would not be automatically compatible with [`crate::IntoVisitor`], which requires visitors that do return
/// [`crate::Error`] errors.
///
/// As long as the error type of the visitor implementation can be converted into [`crate::Error`] via [`Into`],
/// the visitor implementation can be wrapped in this [`VisitorWithCrateError`] struct to make it work with
/// [`crate::IntoVisitor`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct VisitorWithCrateError<V>(pub V);

impl<V: Visitor> Visitor for VisitorWithCrateError<V>
where
V::Error: Into<crate::Error>,
{
type Value<'scale, 'info> = V::Value<'scale, 'info>;
type Error = crate::Error;

fn unchecked_decode_as_type<'scale, 'info>(
self,
input: &mut &'scale [u8],
type_id: TypeId,
types: &'info scale_info::PortableRegistry,
) -> DecodeAsTypeResult<Self, Result<Self::Value<'scale, 'info>, Self::Error>> {
let res = decode_with_visitor(input, type_id.0, types, self.0).map_err(Into::into);
DecodeAsTypeResult::Decoded(res)
}
}

#[cfg(test)]
mod test {
use crate::visitor::TypeId;
Expand Down Expand Up @@ -397,6 +445,7 @@ mod test {
BitSequence(scale_bits::Bits),
}

#[derive(Clone, Copy)]
struct ValueVisitor;
impl Visitor for ValueVisitor {
type Value<'scale, 'info> = Value;
Expand Down Expand Up @@ -595,22 +644,69 @@ mod test {

/// This just tests that if we try to decode some values we've encoded using a visitor
/// which just ignores everything by default, that we'll consume all of the bytes.
fn encode_decode_check_explicit_info<Ty: scale_info::TypeInfo + 'static, T: Encode>(
fn encode_decode_check_explicit_info<
Ty: scale_info::TypeInfo + 'static,
T: Encode,
V: for<'s, 'i> Visitor<Value<'s, 'i> = Value, Error = E>,
E: core::fmt::Debug,
>(
val: T,
expected: Value,
visitor: V,
) {
let encoded = val.encode();
let (id, types) = make_type::<Ty>();
let bytes = &mut &*encoded;
let val = decode_with_visitor(bytes, id, &types, ValueVisitor)
.expect("decoding should not error");
let val =
decode_with_visitor(bytes, id, &types, visitor).expect("decoding should not error");

assert_eq!(bytes.len(), 0, "Decoding should consume all bytes");
assert_eq!(val, expected);
}

fn encode_decode_check_with_visitor<
T: Encode + scale_info::TypeInfo + 'static,
V: for<'s, 'i> Visitor<Value<'s, 'i> = Value, Error = E>,
E: core::fmt::Debug,
>(
val: T,
expected: Value,
visitor: V,
) {
encode_decode_check_explicit_info::<T, T, _, _>(val, expected, visitor);
}

fn encode_decode_check<T: Encode + scale_info::TypeInfo + 'static>(val: T, expected: Value) {
encode_decode_check_explicit_info::<T, T>(val, expected);
encode_decode_check_explicit_info::<T, T, _, _>(val, expected, ValueVisitor);
}

#[test]
fn decode_with_root_error_wrapper_works() {
use crate::visitor::VisitorWithCrateError;
let visitor = VisitorWithCrateError(ValueVisitor);

encode_decode_check_with_visitor(123u8, Value::U8(123), visitor);
encode_decode_check_with_visitor(123u16, Value::U16(123), visitor);
encode_decode_check_with_visitor(123u32, Value::U32(123), visitor);
encode_decode_check_with_visitor(123u64, Value::U64(123), visitor);
encode_decode_check_with_visitor(123u128, Value::U128(123), visitor);
encode_decode_check_with_visitor(
"Hello there",
Value::Str("Hello there".to_owned()),
visitor,
);

#[derive(Encode, scale_info::TypeInfo)]
struct Unnamed(bool, String, Vec<u8>);
encode_decode_check_with_visitor(
Unnamed(true, "James".into(), vec![1, 2, 3]),
Value::Composite(vec![
(String::new(), Value::Bool(true)),
(String::new(), Value::Str("James".to_string())),
(String::new(), Value::Sequence(vec![Value::U8(1), Value::U8(2), Value::U8(3)])),
]),
visitor,
);
}

#[test]
Expand All @@ -627,7 +723,11 @@ mod test {
encode_decode_check(codec::Compact(123u128), Value::U128(123));
encode_decode_check(true, Value::Bool(true));
encode_decode_check(false, Value::Bool(false));
encode_decode_check_explicit_info::<char, _>('c' as u32, Value::Char('c'));
encode_decode_check_explicit_info::<char, _, _, _>(
'c' as u32,
Value::Char('c'),
ValueVisitor,
);
encode_decode_check("Hello there", Value::Str("Hello there".to_owned()));
encode_decode_check("Hello there".to_string(), Value::Str("Hello there".to_owned()));
}
Expand Down

0 comments on commit 2c59157

Please sign in to comment.