diff --git a/src/models.rs b/src/models.rs index 478130d..2bfd46c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -3,7 +3,7 @@ use strum_macros::Display; use url::Url; use validator::{Validate, ValidationError}; -#[derive(Debug, Deserialize, Display, PartialEq)] +#[derive(Clone, Copy, Debug, Deserialize, Display, PartialEq)] #[serde(rename_all = "lowercase")] pub enum DType { Int32, @@ -14,6 +14,20 @@ pub enum DType { Float64, } +impl DType { + /// Returns the size of the associated type in bytes. + fn size_of(self) -> usize { + match self { + Self::Int32 => std::mem::size_of::(), + Self::Int64 => std::mem::size_of::(), + Self::Uint32 => std::mem::size_of::(), + Self::Uint64 => std::mem::size_of::(), + Self::Float32 => std::mem::size_of::(), + Self::Float64 => std::mem::size_of::(), + } + } +} + #[derive(Debug, Deserialize, PartialEq)] pub enum Order { C, @@ -67,15 +81,28 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr // Validation of multiple fields in RequestData. // TODO: More validation of shape & selection vs. size // TODO: More validation that selection fits in shape - if let Some(shape) = &request_data.shape { - if let Some(selection) = &request_data.selection { + if let Some(size) = &request_data.size { + if (*size as usize) % request_data.dtype.size_of() != 0 { + return Err(ValidationError::new( + "Size must be a multiple of dtype size in bytes", + )); + } + }; + match (&request_data.shape, &request_data.selection) { + (Some(shape), Some(selection)) => { if shape.len() != selection.len() { return Err(ValidationError::new( "Shape and selection must have the same length", )); } } - } + (None, Some(_)) => { + return Err(ValidationError::new( + "Selection requires shape to be specified", + )); + } + _ => (), + }; Ok(()) } @@ -121,8 +148,8 @@ mod tests { bucket: "bar".to_string(), object: "baz".to_string(), dtype: DType::Int32, - offset: Some(1), - size: Some(2), + offset: Some(4), + size: Some(8), shape: Some(vec![1, 2]), order: Some(Order::C), selection: Some(vec![ @@ -191,10 +218,10 @@ mod tests { Token::Unit, Token::Str("offset"), Token::Some, - Token::U32(1), + Token::U32(4), Token::Str("size"), Token::Some, - Token::U32(2), + Token::U32(8), Token::Str("shape"), Token::Some, Token::Seq { len: Some(2) }, @@ -405,6 +432,14 @@ mod tests { request_data.validate().unwrap() } + #[test] + #[should_panic(expected = "Size must be a multiple of dtype size in bytes")] + fn test_invalid_size_for_dtype() { + let mut request_data = get_test_request_data(); + request_data.size = Some(1); + request_data.validate().unwrap() + } + #[test] #[should_panic(expected = "Shape and selection must have the same length")] fn test_shape_selection_mismatch() { @@ -418,6 +453,17 @@ mod tests { request_data.validate().unwrap() } + #[test] + #[should_panic(expected = "Selection requires shape to be specified")] + fn test_selection_without_shape() { + let mut request_data = get_test_request_data(); + request_data.selection = Some(vec![Slice { + start: 1, + end: 2, + stride: 1, + }]); + request_data.validate().unwrap() + } #[test] fn test_unknown_field() { assert_de_tokens_error::(&[ @@ -440,7 +486,7 @@ mod tests { #[test] fn test_json_optional_fields() { - let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 1, "size": 2, "shape": [1, 2], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]]}"#; + let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 4, "size": 8, "shape": [1, 2], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]]}"#; let request_data = serde_json::from_str::(json).unwrap(); assert_eq!(request_data, get_test_request_data_optional()); }