Skip to content

Commit

Permalink
Fix out of memory bug #1270 (#1286)
Browse files Browse the repository at this point in the history
  • Loading branch information
antimora committed Feb 12, 2024
1 parent 397bc02 commit 16541ea
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 18 deletions.
5 changes: 3 additions & 2 deletions burn-core/src/record/serde/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use serde::Deserialize;
#[derive(Debug, Clone)]
pub enum NestedValue {
/// The default value, which actually does not hold any value and it is used to indicate that
/// the value should be populated with the default value.
Default,
/// the value should be populated with the default value. It contains an optional string with
/// the originator field name.
Default(Option<String>),

/// A boolean value.
Bool(bool),
Expand Down
48 changes: 33 additions & 15 deletions burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::collections::HashMap;

use super::adapter::DefaultAdapter;
use super::data::NestedValue;
use super::{adapter::BurnModuleAdapter, error::Error};

Expand Down Expand Up @@ -78,7 +77,8 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
let map = if self.default_for_missing_fields {
let mut map = map;
for field in fields.iter().map(|s| s.to_string()) {
map.entry(field).or_insert(NestedValue::Default);
map.entry(field.clone())
.or_insert(NestedValue::Default(Some(field)));
}
map
} else {
Expand Down Expand Up @@ -420,11 +420,13 @@ where
T: DeserializeSeed<'de>,
{
match self.next_value.take() {
Some(NestedValue::Default) => seed.deserialize(DefaultDeserializer),
Some(NestedValue::Default(originator)) => {
seed.deserialize(DefaultDeserializer::new(originator))
}
Some(v) => seed.deserialize(
NestedValueWrapper::new(v, self.default_for_missing_fields).into_deserializer(),
),
None => seed.deserialize(DefaultDeserializer),
None => seed.deserialize(DefaultDeserializer::new(None)),
}
}
}
Expand Down Expand Up @@ -455,7 +457,18 @@ impl<A: BurnModuleAdapter> IntoDeserializer<'_, Error> for NestedValueWrapper<A>
}

/// A default deserializer that always returns the default value.
struct DefaultDeserializer;
struct DefaultDeserializer {
/// The originator field name (the top level missing field name)
originator_field_name: Option<String>,
}

impl DefaultDeserializer {
fn new(originator_field_name: Option<String>) -> Self {
Self {
originator_field_name,
}
}
}

impl<'de> serde::Deserializer<'de> for DefaultDeserializer {
type Error = Error;
Expand Down Expand Up @@ -581,20 +594,18 @@ impl<'de> serde::Deserializer<'de> for DefaultDeserializer {

fn deserialize_struct<V>(
self,
_name: &'static str,
name: &'static str,
_fields: &'static [&'static str],
visitor: V,
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let mut map: HashMap<String, NestedValue> = HashMap::new();

for field in _fields.iter().map(|s| s.to_string()) {
map.insert(field, NestedValue::Default);
}

visitor.visit_map(HashMapAccess::<DefaultAdapter>::new(map, true))
panic!(
"Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct",
self.originator_field_name.unwrap_or("UNKNOWN".to_string()),
name,
);
}

fn deserialize_tuple_struct<V>(
Expand Down Expand Up @@ -654,7 +665,14 @@ impl<'de> SeqAccess<'de> for DefaultSeqAccess {
where
T: DeserializeSeed<'de>,
{
seed.deserialize(DefaultDeserializer).map(Some)
match self.size {
Some(0) => Ok(None),
Some(ref mut size) => {
*size -= 1;
seed.deserialize(DefaultDeserializer::new(None)).map(Some)
}
None => Ok(None),
}
}

fn size_hint(&self) -> Option<usize> {
Expand Down
2 changes: 1 addition & 1 deletion burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl SerializerTrait for Serializer {
}

fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
Ok(NestedValue::Default)
Ok(NestedValue::Default(None))
}
fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
unimplemented!()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2,2))

def forward(self, x):
x = self.conv1(x)
return x


def main():
torch.set_printoptions(precision=8)
torch.manual_seed(1)
model = Model().to(torch.device("cpu"))
torch.save(model.state_dict(), "missing_module_field.pt")

if __name__ == '__main__':
main()
Binary file not shown.
30 changes: 30 additions & 0 deletions burn-import/pytorch-tests/tests/missing_module_field/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
do_not_exist_in_pt: Conv2d<B>,
}

#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;

use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

use super::*;

#[test]
#[should_panic(
expected = "Missing source values for the 'do_not_exist_in_pt' field of type 'Conv2dRecordItem'. Please verify the source data and ensure the field name is correct"
)]
fn should_fail_if_struct_field_is_missing() {
let device = Default::default();
let _record: NetRecord<Backend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(
"tests/missing_module_field/missing_module_field.pt".into(),
&device,
)
.expect("Should decode state successfully");
}
}
1 change: 1 addition & 0 deletions burn-import/pytorch-tests/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ cfg_if::cfg_if! {
mod key_remap_chained;
mod layer_norm;
mod linear;
mod missing_module_field;
}
}

0 comments on commit 16541ea

Please sign in to comment.