Skip to content

Commit

Permalink
Fix record nested value de/serialization (#1751)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed May 22, 2024
1 parent 6137d42 commit 550086a
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 29 deletions.
36 changes: 26 additions & 10 deletions crates/burn-core/src/record/serde/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ pub enum NestedValue {

/// A vector of nested values (typically used for vector of structs or numbers)
Vec(Vec<NestedValue>),

/// A vector of 16-bit unsigned integer values.
U16s(Vec<u16>),

/// A vector of 32-bit floating point values.
F32s(Vec<f32>),
}
impl NestedValue {
/// Get the nested value as a map.
Expand Down Expand Up @@ -313,19 +319,27 @@ fn cleanup_empty_maps(current: &mut NestedValue) {
}
}

fn write_vec_truncated<T: core::fmt::Debug>(
vec: &[T],
f: &mut core::fmt::Formatter,
) -> fmt::Result {
write!(f, "Vec([")?;
for (i, v) in vec.iter().take(3).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:?}", v)?;
}
write!(f, ", ...] len={})", vec.len())
}

impl fmt::Debug for NestedValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
NestedValue::Vec(vec) if vec.len() > 3 => {
write!(f, "Vec([")?;
for (i, v) in vec.iter().take(3).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:?}", v)?;
}
write!(f, ", ...] len={})", vec.len())
}
// Truncate values for vector
NestedValue::Vec(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
// Handle other variants as usual
NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(),
NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(),
Expand All @@ -339,6 +353,8 @@ impl fmt::Debug for NestedValue {
NestedValue::U64(val) => f.debug_tuple("U64").field(val).finish(),
NestedValue::Map(map) => f.debug_map().entries(map.iter()).finish(),
NestedValue::Vec(vec) => f.debug_list().entries(vec.iter()).finish(),
NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(),
NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(),
}
}
}
127 changes: 112 additions & 15 deletions crates/burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,27 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
where
V: Visitor<'de>,
{
if let Some(NestedValue::Vec(vec)) = self.value {
visitor.visit_seq(VecSeqAccess::<A>::new(vec, self.default_for_missing_fields))
if let Some(value) = self.value {
match value {
NestedValue::Vec(_) => visitor.visit_seq(VecSeqAccess::<A, NestedValue>::new(
value,
self.default_for_missing_fields,
)),
NestedValue::U16s(_) => visitor.visit_seq(VecSeqAccess::<A, u16>::new(
value,
self.default_for_missing_fields,
)),
NestedValue::F32s(_) => visitor.visit_seq(VecSeqAccess::<A, f32>::new(
value,
self.default_for_missing_fields,
)),
_ => Err(de::Error::custom(format!(
"Expected Vec but got {:?}",
value
))),
}
} else {
Err(de::Error::custom(format!(
"Expected Vec but got {:?}",
self.value
)))
Err(de::Error::custom("Expected Vec but got None"))
}
}

Expand Down Expand Up @@ -385,23 +399,56 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
}

/// A sequence access for a vector in the nested value data structure.
struct VecSeqAccess<A: BurnModuleAdapter> {
iter: std::vec::IntoIter<NestedValue>,
struct VecSeqAccess<A: BurnModuleAdapter, I> {
iter: Box<dyn Iterator<Item = I>>,
default_for_missing_fields: bool,
phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> VecSeqAccess<A> {
fn new(vec: Vec<NestedValue>, default_for_missing_fields: bool) -> Self {
VecSeqAccess {
iter: vec.into_iter(),
default_for_missing_fields,
phantom: std::marker::PhantomData,
// Concrete implementation for `Vec<NestedValue>`
impl<A: BurnModuleAdapter> VecSeqAccess<A, NestedValue> {
fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {
match vec {
NestedValue::Vec(v) => VecSeqAccess {
iter: Box::new(v.into_iter()),
default_for_missing_fields,
phantom: std::marker::PhantomData,
},
_ => panic!("Invalid vec sequence"),
}
}
}

// Concrete implementation for `Vec<u16>`
impl<A: BurnModuleAdapter> VecSeqAccess<A, u16> {
fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {
match vec {
NestedValue::U16s(v) => VecSeqAccess {
iter: Box::new(v.into_iter()),
default_for_missing_fields,
phantom: std::marker::PhantomData,
},
_ => panic!("Invalid vec sequence"),
}
}
}

impl<'de, A> SeqAccess<'de> for VecSeqAccess<A>
// Concrete implementation for `Vec<f32>`
impl<A: BurnModuleAdapter> VecSeqAccess<A, f32> {
fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {
match vec {
NestedValue::F32s(v) => VecSeqAccess {
iter: Box::new(v.into_iter()),
default_for_missing_fields,
phantom: std::marker::PhantomData,
},
_ => panic!("Invalid vec sequence"),
}
}
}

// Concrete implementation for `Vec<NestedValue>`
impl<'de, A> SeqAccess<'de> for VecSeqAccess<A, NestedValue>
where
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
Expand All @@ -424,6 +471,56 @@ where
}
}

// Concrete implementation for `Vec<u16>`
impl<'de, A> SeqAccess<'de> for VecSeqAccess<A, u16>
where
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
{
type Error = Error;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
let item = match self.iter.next() {
Some(v) => v,
None => return Ok(None),
};

seed.deserialize(
NestedValueWrapper::<A>::new(NestedValue::U16(item), self.default_for_missing_fields)
.into_deserializer(),
)
.map(Some)
}
}

// Concrete implementation for `Vec<f32>`
impl<'de, A> SeqAccess<'de> for VecSeqAccess<A, f32>
where
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
{
type Error = Error;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
let item = match self.iter.next() {
Some(v) => v,
None => return Ok(None),
};

seed.deserialize(
NestedValueWrapper::<A>::new(NestedValue::F32(item), self.default_for_missing_fields)
.into_deserializer(),
)
.map(Some)
}
}

/// A map access for a map in the nested value data structure.
struct HashMapAccess<A: BurnModuleAdapter> {
iter: std::collections::hash_map::IntoIter<String, NestedValue>,
Expand Down
26 changes: 22 additions & 4 deletions crates/burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use serde::{
/// the actual serialization of modules (although it could be used for that as well if all
/// primitive types are implemented).
pub struct Serializer {
// The state of the serialization process
/// The state of the serialization process
state: Option<NestedValue>,
}

Expand Down Expand Up @@ -254,11 +254,30 @@ impl SerializeSeq for Serializer {
Some(NestedValue::Vec(ref mut vec)) => {
vec.push(serialized_value); // Inserting into the state
}
Some(NestedValue::U16s(ref mut vec)) => {
if let NestedValue::U16(val) = serialized_value {
vec.push(val);
} else {
panic!("Invalid value type encountered");
}
}
Some(NestedValue::F32s(ref mut vec)) => {
if let NestedValue::F32(val) = serialized_value {
vec.push(val);
} else {
panic!("Invalid value type encountered");
}
}
Some(_) => {
panic!("Invalid state encountered");
}
None => {
self.state = Some(NestedValue::Vec(vec![serialized_value]));
let val = match serialized_value {
NestedValue::U16(val) => NestedValue::U16s(vec![val]),
NestedValue::F32(val) => NestedValue::F32s(vec![val]),
_ => NestedValue::Vec(vec![serialized_value]),
};
self.state = Some(val);
}
}

Expand Down Expand Up @@ -331,7 +350,6 @@ mod tests {
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(serialized_str.len(), 135);
}

#[test]
fn test_param_serde() {
type Backend = burn_ndarray::NdArray<f32>;
Expand All @@ -352,6 +370,6 @@ mod tests {

// Compare the lengths of expected and actual serialized strings because
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(serialized_str.len(), 149);
assert_eq!(serialized_str.len(), 134);
}
}

0 comments on commit 550086a

Please sign in to comment.