Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix record nested value de/serialization #1751

Merged
merged 3 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions crates/burn-core/src/record/serde/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ pub enum NestedValue {

/// A vector of nested values (typically used for vector of structs or numbers)
Vec(Vec<NestedValue>),

/// A vector of 16-bit unsigned integer values.
U16s(Vec<u16>),

/// A vector of 32-bit floating point values.
F32s(Vec<f32>),
}
impl NestedValue {
/// Get the nested value as a map.
Expand Down Expand Up @@ -313,19 +319,27 @@ fn cleanup_empty_maps(current: &mut NestedValue) {
}
}

fn write_vec_truncated<T: core::fmt::Debug>(
vec: &[T],
f: &mut core::fmt::Formatter,
) -> fmt::Result {
write!(f, "Vec([")?;
for (i, v) in vec.iter().take(3).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:?}", v)?;
}
write!(f, ", ...] len={})", vec.len())
}

impl fmt::Debug for NestedValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
NestedValue::Vec(vec) if vec.len() > 3 => {
write!(f, "Vec([")?;
for (i, v) in vec.iter().take(3).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:?}", v)?;
}
write!(f, ", ...] len={})", vec.len())
}
// Truncate values for vector
NestedValue::Vec(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
// Handle other variants as usual
NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(),
NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(),
Expand All @@ -339,6 +353,8 @@ impl fmt::Debug for NestedValue {
NestedValue::U64(val) => f.debug_tuple("U64").field(val).finish(),
NestedValue::Map(map) => f.debug_map().entries(map.iter()).finish(),
NestedValue::Vec(vec) => f.debug_list().entries(vec.iter()).finish(),
NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(),
NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(),
}
}
}
127 changes: 112 additions & 15 deletions crates/burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,27 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
where
V: Visitor<'de>,
{
if let Some(NestedValue::Vec(vec)) = self.value {
visitor.visit_seq(VecSeqAccess::<A>::new(vec, self.default_for_missing_fields))
if let Some(value) = self.value {
match value {
NestedValue::Vec(_) => visitor.visit_seq(VecSeqAccess::<A, NestedValue>::new(
value,
self.default_for_missing_fields,
)),
NestedValue::U16s(_) => visitor.visit_seq(VecSeqAccess::<A, u16>::new(
value,
self.default_for_missing_fields,
)),
NestedValue::F32s(_) => visitor.visit_seq(VecSeqAccess::<A, f32>::new(
value,
self.default_for_missing_fields,
)),
laggui marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines +295 to +302
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we add more element types such as bf16, f16, f64, u64, u32?

Copy link
Member Author

@laggui laggui May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we're gonna have to (at some point anyway). I started with the two types for my use cases (BF16 for Llama and F32 for all other models like the ResNet family for testing purposes). They're probably the most common too.

Wanted to get a review of the implementation before going further. As mentioned by @antimora currently this doesn't scale very well to add a concrete implementation for each type, but it's the easiest solution I came up with for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see another way: we need to compiler to know the size of the vector of elements at compile time! Maybe we could read the vector as bytes instead and cast them later on?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm maybe. I think for now I'll stick to the current method to wrap up this PR and in the future we could refactor this if needed. At the same time, it's not like the number of types will not be manageable.. so "not scaling" isn't necessarily an issue right now.

_ => Err(de::Error::custom(format!(
"Expected Vec but got {:?}",
value
))),
}
} else {
Err(de::Error::custom(format!(
"Expected Vec but got {:?}",
self.value
)))
Err(de::Error::custom("Expected Vec but got None"))
}
}

Expand Down Expand Up @@ -385,23 +399,56 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
}

/// A sequence access for a vector in the nested value data structure.
struct VecSeqAccess<A: BurnModuleAdapter> {
iter: std::vec::IntoIter<NestedValue>,
struct VecSeqAccess<A: BurnModuleAdapter, I> {
iter: Box<dyn Iterator<Item = I>>,
default_for_missing_fields: bool,
phantom: std::marker::PhantomData<A>,
}

impl<A: BurnModuleAdapter> VecSeqAccess<A> {
fn new(vec: Vec<NestedValue>, default_for_missing_fields: bool) -> Self {
VecSeqAccess {
iter: vec.into_iter(),
default_for_missing_fields,
phantom: std::marker::PhantomData,
// Concrete implementation for `Vec<NestedValue>`
impl<A: BurnModuleAdapter> VecSeqAccess<A, NestedValue> {
fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {
match vec {
NestedValue::Vec(v) => VecSeqAccess {
iter: Box::new(v.into_iter()),
default_for_missing_fields,
phantom: std::marker::PhantomData,
},
_ => panic!("Invalid vec sequence"),
}
}
}

// Concrete implementation for `Vec<u16>`
impl<A: BurnModuleAdapter> VecSeqAccess<A, u16> {
fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {
match vec {
NestedValue::U16s(v) => VecSeqAccess {
iter: Box::new(v.into_iter()),
default_for_missing_fields,
phantom: std::marker::PhantomData,
},
_ => panic!("Invalid vec sequence"),
}
}
}

impl<'de, A> SeqAccess<'de> for VecSeqAccess<A>
// Concrete implementation for `Vec<f32>`
impl<A: BurnModuleAdapter> VecSeqAccess<A, f32> {
fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {
match vec {
NestedValue::F32s(v) => VecSeqAccess {
iter: Box::new(v.into_iter()),
default_for_missing_fields,
phantom: std::marker::PhantomData,
},
_ => panic!("Invalid vec sequence"),
}
}
}

// Concrete implementation for `Vec<NestedValue>`
impl<'de, A> SeqAccess<'de> for VecSeqAccess<A, NestedValue>
where
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
Expand All @@ -424,6 +471,56 @@ where
}
}

// Concrete implementation for `Vec<u16>`
impl<'de, A> SeqAccess<'de> for VecSeqAccess<A, u16>
where
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
{
type Error = Error;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
let item = match self.iter.next() {
Some(v) => v,
None => return Ok(None),
};

seed.deserialize(
NestedValueWrapper::<A>::new(NestedValue::U16(item), self.default_for_missing_fields)
.into_deserializer(),
)
.map(Some)
}
}

// Concrete implementation for `Vec<f32>`
impl<'de, A> SeqAccess<'de> for VecSeqAccess<A, f32>
where
NestedValueWrapper<A>: IntoDeserializer<'de, Error>,
A: BurnModuleAdapter,
{
type Error = Error;

fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>,
{
let item = match self.iter.next() {
Some(v) => v,
None => return Ok(None),
};

seed.deserialize(
NestedValueWrapper::<A>::new(NestedValue::F32(item), self.default_for_missing_fields)
.into_deserializer(),
)
.map(Some)
}
}

/// A map access for a map in the nested value data structure.
struct HashMapAccess<A: BurnModuleAdapter> {
iter: std::collections::hash_map::IntoIter<String, NestedValue>,
Expand Down
26 changes: 22 additions & 4 deletions crates/burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use serde::{
/// the actual serialization of modules (although it could be used for that as well if all
/// primitive types are implemented).
pub struct Serializer {
// The state of the serialization process
/// The state of the serialization process
state: Option<NestedValue>,
}

Expand Down Expand Up @@ -254,11 +254,30 @@ impl SerializeSeq for Serializer {
Some(NestedValue::Vec(ref mut vec)) => {
vec.push(serialized_value); // Inserting into the state
}
Some(NestedValue::U16s(ref mut vec)) => {
if let NestedValue::U16(val) = serialized_value {
vec.push(val);
} else {
panic!("Invalid value type encountered");
}
}
Some(NestedValue::F32s(ref mut vec)) => {
if let NestedValue::F32(val) = serialized_value {
vec.push(val);
} else {
panic!("Invalid value type encountered");
}
}
Some(_) => {
panic!("Invalid state encountered");
}
None => {
self.state = Some(NestedValue::Vec(vec![serialized_value]));
let val = match serialized_value {
NestedValue::U16(val) => NestedValue::U16s(vec![val]),
NestedValue::F32(val) => NestedValue::F32s(vec![val]),
_ => NestedValue::Vec(vec![serialized_value]),
};
self.state = Some(val);
}
}

Expand Down Expand Up @@ -331,7 +350,6 @@ mod tests {
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(serialized_str.len(), 135);
}

#[test]
fn test_param_serde() {
type Backend = burn_ndarray::NdArray<f32>;
Expand All @@ -352,6 +370,6 @@ mod tests {

// Compare the lengths of expected and actual serialized strings because
// the order of the fields is not guaranteed for HashMaps.
assert_eq!(serialized_str.len(), 149);
assert_eq!(serialized_str.len(), 134);
}
}
Loading