From d26899115b8c1f986ddf4064ae5ddf42baf68a41 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 9 May 2024 10:21:42 -0400 Subject: [PATCH 1/3] Add Vec and Vec NestedValue variants --- crates/burn-core/src/record/serde/data.rs | 36 ++++++++++++++++------- crates/burn-core/src/record/serde/de.rs | 18 ++++++++---- crates/burn-core/src/record/serde/ser.rs | 23 +++++++++++++-- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/crates/burn-core/src/record/serde/data.rs b/crates/burn-core/src/record/serde/data.rs index b869be5c4f..16bbde7f01 100644 --- a/crates/burn-core/src/record/serde/data.rs +++ b/crates/burn-core/src/record/serde/data.rs @@ -54,6 +54,12 @@ pub enum NestedValue { /// A vector of nested values (typically used for vector of structs or numbers) Vec(Vec), + + /// A vector of 16-bit unsigned integer values. + U16s(Vec), + + /// A vector of 32-bit floating point values. + F32s(Vec), } impl NestedValue { /// Get the nested value as a map. @@ -313,19 +319,27 @@ fn cleanup_empty_maps(current: &mut NestedValue) { } } +fn write_vec_truncated( + vec: &[T], + f: &mut core::fmt::Formatter, +) -> fmt::Result { + write!(f, "Vec([")?; + for (i, v) in vec.iter().take(3).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", v)?; + } + write!(f, ", ...] len={})", vec.len()) +} + impl fmt::Debug for NestedValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - NestedValue::Vec(vec) if vec.len() > 3 => { - write!(f, "Vec([")?; - for (i, v) in vec.iter().take(3).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{:?}", v)?; - } - write!(f, ", ...] len={})", vec.len()) - } + // Truncate values for vector + NestedValue::Vec(vec) if vec.len() > 3 => write_vec_truncated(vec, f), + NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), + NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), // Handle other variants as usual NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(), NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(), @@ -339,6 +353,8 @@ impl fmt::Debug for NestedValue { NestedValue::U64(val) => f.debug_tuple("U64").field(val).finish(), NestedValue::Map(map) => f.debug_map().entries(map.iter()).finish(), NestedValue::Vec(vec) => f.debug_list().entries(vec.iter()).finish(), + NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(), + NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(), } } } diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs index 5a09b3bde2..3627f61eb1 100644 --- a/crates/burn-core/src/record/serde/de.rs +++ b/crates/burn-core/src/record/serde/de.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use super::data::NestedValue; use super::{adapter::BurnModuleAdapter, error::Error}; +use serde::de::value::SeqDeserializer; use serde::de::{EnumAccess, VariantAccess}; use serde::{ de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor}, @@ -286,13 +287,20 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { where V: Visitor<'de>, { - if let Some(NestedValue::Vec(vec)) = self.value { - visitor.visit_seq(VecSeqAccess::::new(vec, self.default_for_missing_fields)) - } else { - Err(de::Error::custom(format!( + match self.value { + Some(NestedValue::Vec(vec)) => { + visitor.visit_seq(VecSeqAccess::::new(vec, self.default_for_missing_fields)) + } + Some(NestedValue::U16s(vec)) => { + visitor.visit_seq(SeqDeserializer::new(vec.into_iter())) + } + Some(NestedValue::F32s(vec)) => { + visitor.visit_seq(SeqDeserializer::new(vec.into_iter())) + } + _ => Err(de::Error::custom(format!( "Expected Vec but got {:?}", self.value - ))) + ))), } } diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs index 9c16d09a7a..b79feaae93 100644 --- a/crates/burn-core/src/record/serde/ser.rs +++ b/crates/burn-core/src/record/serde/ser.rs @@ -16,7 +16,7 @@ use serde::{ /// the actual serialization of modules (although it could be used for that as well if all /// primitive types are implemented). pub struct Serializer { - // The state of the serialization process + /// The state of the serialization process state: Option, } @@ -254,11 +254,30 @@ impl SerializeSeq for Serializer { Some(NestedValue::Vec(ref mut vec)) => { vec.push(serialized_value); // Inserting into the state } + Some(NestedValue::U16s(ref mut vec)) => { + if let NestedValue::U16(val) = serialized_value { + vec.push(val); + } else { + panic!("Invalid value type encountered"); + } + } + Some(NestedValue::F32s(ref mut vec)) => { + if let NestedValue::F32(val) = serialized_value { + vec.push(val); + } else { + panic!("Invalid value type encountered"); + } + } Some(_) => { panic!("Invalid state encountered"); } None => { - self.state = Some(NestedValue::Vec(vec![serialized_value])); + let val = match serialized_value { + NestedValue::U16(val) => NestedValue::U16s(vec![val]), + NestedValue::F32(val) => NestedValue::F32s(vec![val]), + _ => NestedValue::Vec(vec![serialized_value]), + }; + self.state = Some(val); } } From 70375935e57d278b6ba5e9ae4c2a2c40589ee25a Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 9 May 2024 13:33:07 -0400 Subject: [PATCH 2/3] Add Vec and Vec support for VecSeqAccess (fix f16/bf16 deserialize) --- crates/burn-core/src/record/serde/de.rs | 135 ++++++++++++++++++++---- 1 file changed, 112 insertions(+), 23 deletions(-) diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs index 3627f61eb1..0e5c52b5ee 100644 --- a/crates/burn-core/src/record/serde/de.rs +++ b/crates/burn-core/src/record/serde/de.rs @@ -4,7 +4,6 @@ use std::collections::HashMap; use super::data::NestedValue; use super::{adapter::BurnModuleAdapter, error::Error}; -use serde::de::value::SeqDeserializer; use serde::de::{EnumAccess, VariantAccess}; use serde::{ de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor}, @@ -287,20 +286,27 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { where V: Visitor<'de>, { - match self.value { - Some(NestedValue::Vec(vec)) => { - visitor.visit_seq(VecSeqAccess::::new(vec, self.default_for_missing_fields)) - } - Some(NestedValue::U16s(vec)) => { - visitor.visit_seq(SeqDeserializer::new(vec.into_iter())) - } - Some(NestedValue::F32s(vec)) => { - visitor.visit_seq(SeqDeserializer::new(vec.into_iter())) + if let Some(value) = self.value { + match value { + NestedValue::Vec(_) => visitor.visit_seq(VecSeqAccess::::new( + value, + self.default_for_missing_fields, + )), + NestedValue::U16s(_) => visitor.visit_seq(VecSeqAccess::::new( + value, + self.default_for_missing_fields, + )), + NestedValue::F32s(_) => visitor.visit_seq(VecSeqAccess::::new( + value, + self.default_for_missing_fields, + )), + _ => Err(de::Error::custom(format!( + "Expected Vec but got {:?}", + value + ))), } - _ => Err(de::Error::custom(format!( - "Expected Vec but got {:?}", - self.value - ))), + } else { + Err(de::Error::custom("Expected Vec but got None")) } } @@ -393,23 +399,56 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { } /// A sequence access for a vector in the nested value data structure. -struct VecSeqAccess { - iter: std::vec::IntoIter, +struct VecSeqAccess { + iter: Box>, default_for_missing_fields: bool, phantom: std::marker::PhantomData, } -impl VecSeqAccess { - fn new(vec: Vec, default_for_missing_fields: bool) -> Self { - VecSeqAccess { - iter: vec.into_iter(), - default_for_missing_fields, - phantom: std::marker::PhantomData, +// Concrete implementation for `Vec` +impl VecSeqAccess { + fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self { + match vec { + NestedValue::Vec(v) => VecSeqAccess { + iter: Box::new(v.into_iter()), + default_for_missing_fields, + phantom: std::marker::PhantomData, + }, + _ => panic!("Invalid vec sequence"), } } } -impl<'de, A> SeqAccess<'de> for VecSeqAccess +// Concrete implementation for `Vec` +impl VecSeqAccess { + fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self { + match vec { + NestedValue::U16s(v) => VecSeqAccess { + iter: Box::new(v.into_iter()), + default_for_missing_fields, + phantom: std::marker::PhantomData, + }, + _ => panic!("Invalid vec sequence"), + } + } +} + +// Concrete implementation for `Vec` +impl VecSeqAccess { + fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self { + match vec { + NestedValue::F32s(v) => VecSeqAccess { + iter: Box::new(v.into_iter()), + default_for_missing_fields, + phantom: std::marker::PhantomData, + }, + _ => panic!("Invalid vec sequence"), + } + } +} + +// Concrete implementation for `Vec` +impl<'de, A> SeqAccess<'de> for VecSeqAccess where NestedValueWrapper: IntoDeserializer<'de, Error>, A: BurnModuleAdapter, @@ -432,6 +471,56 @@ where } } +// Concrete implementation for `Vec` +impl<'de, A> SeqAccess<'de> for VecSeqAccess +where + NestedValueWrapper: IntoDeserializer<'de, Error>, + A: BurnModuleAdapter, +{ + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + let item = match self.iter.next() { + Some(v) => v, + None => return Ok(None), + }; + + seed.deserialize( + NestedValueWrapper::::new(NestedValue::U16(item), self.default_for_missing_fields) + .into_deserializer(), + ) + .map(Some) + } +} + +// Concrete implementation for `Vec` +impl<'de, A> SeqAccess<'de> for VecSeqAccess +where + NestedValueWrapper: IntoDeserializer<'de, Error>, + A: BurnModuleAdapter, +{ + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + let item = match self.iter.next() { + Some(v) => v, + None => return Ok(None), + }; + + seed.deserialize( + NestedValueWrapper::::new(NestedValue::F32(item), self.default_for_missing_fields) + .into_deserializer(), + ) + .map(Some) + } +} + /// A map access for a map in the nested value data structure. struct HashMapAccess { iter: std::collections::hash_map::IntoIter, From 40b3c71cd3d8521c22dab29438c65cec89fea95c Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 9 May 2024 15:31:53 -0400 Subject: [PATCH 3/3] Fix test_param_serde test --- crates/burn-core/src/record/serde/ser.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs index b79feaae93..effeeab7e1 100644 --- a/crates/burn-core/src/record/serde/ser.rs +++ b/crates/burn-core/src/record/serde/ser.rs @@ -350,7 +350,6 @@ mod tests { // the order of the fields is not guaranteed for HashMaps. assert_eq!(serialized_str.len(), 135); } - #[test] fn test_param_serde() { type Backend = burn_ndarray::NdArray; @@ -371,6 +370,6 @@ mod tests { // Compare the lengths of expected and actual serialized strings because // the order of the fields is not guaranteed for HashMaps. - assert_eq!(serialized_str.len(), 149); + assert_eq!(serialized_str.len(), 134); } }