diff --git a/src/de/enum_.rs b/src/de/enum_.rs index 7e516878e..9012873a5 100644 --- a/src/de/enum_.rs +++ b/src/de/enum_.rs @@ -53,3 +53,56 @@ impl<'de, 'a> de::VariantAccess<'de> for UnitVariantAccess<'a, 'de> { Err(Error::InvalidType) } } + +pub(crate) struct VariantAccess<'a, 'b> { + de: &'a mut Deserializer<'b>, +} + +impl<'a, 'b> VariantAccess<'a, 'b> { + pub(crate) fn new(de: &'a mut Deserializer<'b>) -> Self { + VariantAccess { de } + } +} + +impl<'a, 'de> de::EnumAccess<'de> for VariantAccess<'a, 'de> { + type Error = Error; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self)> + where + V: de::DeserializeSeed<'de>, + { + let variant = seed.deserialize(&mut *self.de)?; + self.de.parse_object_colon()?; + Ok((variant, self)) + } +} + +impl<'de, 'a> de::VariantAccess<'de> for VariantAccess<'a, 'de> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + de::Deserialize::deserialize(self.de) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + seed.deserialize(self.de) + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + de::Deserializer::deserialize_seq(self.de, visitor) + } + + fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result + where + V: de::Visitor<'de>, + { + de::Deserializer::deserialize_struct(self.de, "", fields, visitor) + } +} diff --git a/src/de/mod.rs b/src/de/mod.rs index c838c881a..6edb4b7b4 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -5,7 +5,7 @@ use core::{fmt, str}; use serde::de::{self, Visitor}; -use self::enum_::UnitVariantAccess; +use self::enum_::{UnitVariantAccess, VariantAccess}; use self::map::MapAccess; use self::seq::SeqAccess; @@ -471,28 +471,40 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { } } - /// Unsupported. Use a more specific deserialize_* method - fn deserialize_unit(self, _visitor: V) -> Result + fn deserialize_unit(self, visitor: V) -> Result where V: Visitor<'de>, { - unreachable!() + let peek = match self.parse_whitespace() { + Some(b) => b, + None => { + return Err(Error::EofWhileParsingValue); + } + }; + + match peek { + b'n' => { + self.eat_char(); + self.parse_ident(b"ull")?; + visitor.visit_unit() + } + _ => Err(Error::InvalidType), + } } - /// Unsupported. Use a more specific deserialize_* method - fn deserialize_unit_struct(self, _name: &'static str, _visitor: V) -> Result + fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result where V: Visitor<'de>, { - unreachable!() + self.deserialize_unit(visitor) } /// Unsupported. We can’t parse newtypes because we don’t know the underlying type. - fn deserialize_newtype_struct(self, _name: &'static str, _visitor: V) -> Result + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result where V: Visitor<'de>, { - unreachable!() + visitor.visit_newtype_struct(self) } fn deserialize_seq(self, visitor: V) -> Result @@ -573,6 +585,17 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> { { match self.parse_whitespace().ok_or(Error::EofWhileParsingValue)? { b'"' => visitor.visit_enum(UnitVariantAccess::new(self)), + b'{' => { + self.eat_char(); + let value = visitor.visit_enum(VariantAccess::new(self))?; + match self.parse_whitespace().ok_or(Error::EofWhileParsingValue)? { + b'}' => { + self.eat_char(); + Ok(value) + } + _ => Err(Error::ExpectedSomeValue), + } + }, _ => Err(Error::ExpectedSomeValue), } } @@ -878,6 +901,41 @@ mod tests { assert!(crate::from_str::(r#"{ "temperature": -1 }"#).is_err()); } + #[test] + fn test_unit() { + assert_eq!(crate::from_str::<()>(r#"null"#).unwrap(), ()); + } + + #[test] + fn newtype_struct() { + #[derive(Deserialize, Debug, PartialEq)] + struct A(pub u32); + + assert_eq!(crate::from_str::(r#"54"#).unwrap(), A(54)); + } + + #[test] + fn test_newtype_variant() { + #[derive(Deserialize, Debug, PartialEq)] + enum A { + A(u32), + } + let a = A::A(54); + let x = crate::from_str::(r#"{"A":54}"#); + assert_eq!(x, Ok(a)); + } + + #[test] + fn test_struct_variant() { + #[derive(Deserialize, Debug, PartialEq)] + enum A { + A { x: u32, y: u16 }, + } + let a = A::A { x: 54, y: 720 }; + let x = crate::from_str::(r#"{"A": {"x":54,"y":720 } }"#); + assert_eq!(x, Ok(a)); + } + #[test] #[cfg(not(feature = "custom-error-messages"))] fn struct_tuple() { diff --git a/src/ser/mod.rs b/src/ser/mod.rs index bd4556608..bfcc0d82e 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -3,12 +3,13 @@ use core::{fmt, fmt::Write}; use serde::ser; +use serde::ser::SerializeStruct as _; use heapless::{consts::*, String, Vec}; use self::map::SerializeMap; use self::seq::SerializeSeq; -use self::struct_::SerializeStruct; +use self::struct_::{SerializeStruct, SerializeStructVariant}; mod map; mod seq; @@ -146,7 +147,7 @@ where type SerializeTupleVariant = Unreachable; type SerializeMap = SerializeMap<'a, B>; type SerializeStruct = SerializeStruct<'a, B>; - type SerializeStructVariant = Unreachable; + type SerializeStructVariant = SerializeStructVariant<'a, B>; fn serialize_bool(self, v: bool) -> Result { if v { @@ -234,11 +235,11 @@ where } fn serialize_unit(self) -> Result { - unreachable!() + self.serialize_none() } fn serialize_unit_struct(self, _name: &'static str) -> Result { - unreachable!() + self.serialize_unit() } fn serialize_unit_variant( @@ -253,25 +254,29 @@ where fn serialize_newtype_struct( self, _name: &'static str, - _value: &T, + value: &T, ) -> Result where T: ser::Serialize, { - unreachable!() + value.serialize(self) } fn serialize_newtype_variant( - self, + mut self, _name: &'static str, _variant_index: u32, - _variant: &'static str, - _value: &T, + variant: &'static str, + value: &T, ) -> Result where T: ser::Serialize, { - unreachable!() + self.buf.push(b'{')?; + let mut s = SerializeStruct::new(&mut self); + s.serialize_field(variant, value)?; + s.end()?; + Ok(()) } fn serialize_seq(self, _len: Option) -> Result { @@ -318,10 +323,14 @@ where self, _name: &'static str, _variant_index: u32, - _variant: &'static str, + variant: &'static str, _len: usize, ) -> Result { - unreachable!() + self.buf.extend_from_slice(b"{\"")?; + self.buf.extend_from_slice(variant.as_bytes())?; + self.buf.extend_from_slice(b"\":{")?; + + Ok(SerializeStructVariant::new(self)) } fn collect_str(self, _value: &T) -> Result @@ -414,22 +423,6 @@ impl ser::SerializeMap for Unreachable { } } -impl ser::SerializeStructVariant for Unreachable { - type Ok = (); - type Error = Error; - - fn serialize_field(&mut self, _key: &'static str, _value: &T) -> Result<()> - where - T: ser::Serialize, - { - unreachable!() - } - - fn end(self) -> Result { - unreachable!() - } -} - #[cfg(test)] mod tests { use serde_derive::Serialize; @@ -597,4 +590,40 @@ mod tests { r#"{"a":true,"b":false}"# ); } + + #[test] + fn test_unit() { + let a = (); + assert_eq!(&*crate::to_string::(&a).unwrap(), r#"null"#); + } + + #[test] + fn test_newtype_struct() { + #[derive(Serialize)] + struct A(pub u32); + let a = A(54); + assert_eq!(&*crate::to_string::(&a).unwrap(), r#"54"#); + } + + #[test] + fn test_newtype_variant() { + #[derive(Serialize)] + enum A { + A(u32), + } + let a = A::A(54); + + assert_eq!(&*crate::to_string::(&a).unwrap(), r#"{"A":54}"#); + } + + #[test] + fn test_struct_variant() { + #[derive(Serialize)] + enum A { + A { x: u32, y: u16 }, + } + let a = A::A { x: 54, y: 720 }; + + assert_eq!(&*crate::to_string::(&a).unwrap(), r#"{"A":{"x":54,"y":720}}"#); + } } diff --git a/src/ser/struct_.rs b/src/ser/struct_.rs index b76826766..8f7b0bf07 100644 --- a/src/ser/struct_.rs +++ b/src/ser/struct_.rs @@ -52,3 +52,52 @@ where Ok(()) } } + +pub struct SerializeStructVariant<'a, B> +where + B: ArrayLength, +{ + ser: &'a mut Serializer, + first: bool, +} + +impl<'a, B> SerializeStructVariant<'a, B> +where + B: ArrayLength, +{ + pub(crate) fn new(ser: &'a mut Serializer) -> Self { + SerializeStructVariant { ser, first: true } + } +} + +impl<'a, B> ser::SerializeStructVariant for SerializeStructVariant<'a, B> +where + B: ArrayLength, +{ + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ser::Serialize, + { + // XXX if `value` is `None` we not produce any output for this field + if !self.first { + self.ser.buf.push(b',')?; + } + self.first = false; + + self.ser.buf.push(b'"')?; + self.ser.buf.extend_from_slice(key.as_bytes())?; + self.ser.buf.extend_from_slice(b"\":")?; + + value.serialize(&mut *self.ser)?; + + Ok(()) + } + + fn end(self) -> Result { + self.ser.buf.extend_from_slice(b"}}")?; + Ok(()) + } +}