Skip to content

Commit

Permalink
Allow nested structured records
Browse files Browse the repository at this point in the history
  • Loading branch information
potocpav committed Mar 13, 2018
1 parent f205d18 commit 87eefde
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 14 deletions.
68 changes: 59 additions & 9 deletions src/header.rs
Expand Up @@ -50,7 +50,9 @@ impl DType {
let shape_str = shape.iter().fold(String::new(), |o,n| o + &format!("{},", n));
format!("('{}', '{}', ({})), ", name, ty, shape_str)
},
Record(_) => unimplemented!("nested record dtypes")
ref record@Record(_) => {
format!("('{}', {}), ", name, record.descr())
},
}
)
.fold("[".to_string(), |o, n| o + &n) + "]",
Expand Down Expand Up @@ -78,20 +80,28 @@ fn convert_list_to_record_fields(values: &[Value]) -> Result<Vec<Field>> {
}

fn convert_tuple_to_record_field(tuple: &[Value]) -> Result<Field> {
use self::Value::String;
use self::Value::{String,List};

match tuple.len() {
2 | 3 => match (&tuple[0], &tuple[1]) {
(&String(ref name), &String(ref dtype)) =>
2 | 3 => match (&tuple[0], &tuple[1], tuple.get(2)) {
(&String(ref name), &String(ref dtype), ref shape) =>
Ok(Field { name: name.clone(), dtype: DType::Plain {
ty: dtype.clone(),
shape: if tuple.len() == 2 {
vec![]
shape: if let &Some(ref s) = shape {
convert_value_to_shape(s)?
} else {
convert_value_to_shape(&tuple[2])?
vec![]
}
} }),
_ => invalid_data("list entry must contain strings for id and dtype")
(&String(ref name), &List(ref list), None) =>
Ok(Field {
name: name.clone(),
dtype: DType::Record(convert_list_to_record_fields(list)?)
}),
(&String(_), &List(_), Some(_)) =>
invalid_data("nested arrays of Record types are not supported."),
_ =>
invalid_data("list entry must contain a string for id and a valid dtype")
},
_ => invalid_data("list entry must contain 2 or 3 items")
}
Expand Down Expand Up @@ -257,6 +267,22 @@ mod tests {
assert_eq!(dtype.descr(), "'>f8'");
}

#[test]
fn description_of_nested_record_dtype() {
let dtype = DType::Record(vec![
Field {
name: "parent".to_string(),
dtype: DType::Record(vec![
Field {
name: "child".to_string(),
dtype: DType::Plain { ty: "<i4".to_string(), shape: vec![] }
},
]),
}
]);
assert_eq!(dtype.descr(), "[('parent', [('child', '<i4'), ]), ]");
}

#[test]
fn converts_simple_description_to_record_dtype() {
let dtype = ">f8".to_string();
Expand All @@ -283,7 +309,7 @@ mod tests {
}

#[test]
fn record_description_with_onedimenional_field_shape_declaration() {
fn record_description_with_onedimensional_field_shape_declaration() {
let descr = parse("[('a', '>f8', (1,))]");
let expected_dtype = DType::Record(vec![
Field {
Expand All @@ -294,6 +320,30 @@ mod tests {
assert_eq!(DType::from_descr(descr).unwrap(), expected_dtype);
}

#[test]
fn record_description_with_nested_record_field() {
let descr = parse("[('parent', [('child', '<i4')])]");
let expected_dtype = DType::Record(vec![
Field {
name: "parent".to_string(),
dtype: DType::Record(vec![
Field {
name: "child".to_string(),
dtype: DType::Plain { ty: "<i4".to_string(), shape: vec![] }
},
]),
}
]);
assert_eq!(DType::from_descr(descr).unwrap(), expected_dtype);
}


#[test]
fn errors_on_nested_record_field_array() {
let descr = parse("[('parent', [('child', '<i4')], (2,))]");
assert!(DType::from_descr(descr).is_err());
}

#[test]
fn errors_on_value_variants_that_cannot_be_converted() {
let no_dtype = Value::Bool(false);
Expand Down
13 changes: 8 additions & 5 deletions tests/roundtrip.rs
Expand Up @@ -8,6 +8,12 @@ use std::io::{Read, Write};
use byteorder::{WriteBytesExt, LittleEndian};
use npy::{DType, Serializable};

#[derive(Serializable, Debug, PartialEq, Clone)]
struct Nested {
v1: f32,
v2: f32,
}

#[derive(Serializable, Debug, PartialEq, Clone)]
struct Array {
v_i8: i8,
Expand All @@ -23,6 +29,7 @@ struct Array {
v_arr_u32: [u32;7],
v_mat_u64: [[u64; 3]; 5],
vec: Vector5,
nested: Nested,
}

#[derive(Debug, PartialEq, Clone)]
Expand Down Expand Up @@ -79,6 +86,7 @@ fn roundtrip() {
v_arr_u32: [j,1+j,2+j,3+j,4+j,5+j,6+j],
v_mat_u64: [[k,1+k,2+k],[3+k,4+k,5+k],[6+k,7+k,8+k],[9+k,10+k,11+k],[12+k,13+k,14+k]],
vec: Vector5(vec![1,2,3,4,5]),
nested: Nested { v1: 10.0 * i as f32, v2: i as f32 },
};
arrays.push(a);
}
Expand Down Expand Up @@ -106,8 +114,3 @@ fn roundtrip_with_simple_dtype() {
let array_read = npy::NpyData::from_bytes(&buffer).unwrap().to_vec();
assert_eq!(array_written, array_read);
}

#[derive(Serializable, Debug, PartialEq, Clone)]
struct S {
s: [[[i8; 2]; 3]; 4],
}

0 comments on commit 87eefde

Please sign in to comment.