From 6518d00170089832a04db47bb08786baf836a913 Mon Sep 17 00:00:00 2001 From: Mark Goddard Date: Tue, 20 Jun 2023 15:55:02 +0100 Subject: [PATCH 1/3] Add support for Gzip and Zlib compression This allows the S3 active storage server to work with data that was compressed using one of these algorithms before it was written to the object store. Compression is configured using the optional `compression` field in the API JSON request data, and if present should be set to "gzip" or "zlib". We are using the standard flate2 library for decompression. There may be more performant options to consider in future, but this works well as a first pass, and changing in future will not affect the API. --- Cargo.lock | 22 ++++++- Cargo.toml | 1 + README.md | 8 ++- scripts/client.py | 2 + scripts/upload_sample_data.py | 17 +++++- src/app.rs | 2 + src/array.rs | 5 ++ src/compression.rs | 104 ++++++++++++++++++++++++++++++++++ src/error.rs | 16 +++++- src/filter_pipeline.rs | 84 +++++++++++++++++++++++++++ src/lib.rs | 2 + src/metrics.rs | 2 +- src/models.rs | 46 ++++++++++++++- src/operation.rs | 2 + src/operations.rs | 8 +++ 15 files changed, 311 insertions(+), 10 deletions(-) create mode 100644 src/compression.rs create mode 100644 src/filter_pipeline.rs diff --git a/Cargo.lock b/Cargo.lock index feba1d5..8aaf0c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -460,7 +460,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.6.2", "object", "rustc-demangle", ] @@ -744,6 +744,16 @@ dependencies = [ "instant", ] +[[package]] +name = "flate2" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" +dependencies = [ + "crc32fast", + "miniz_oxide 0.7.1", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1184,6 +1194,15 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + [[package]] name = "mio" version = "0.8.8" @@ -1691,6 +1710,7 @@ dependencies = [ "axum-server", "clap", "expanduser", + "flate2", "http", "hyper", "lazy_static", diff --git a/Cargo.toml b/Cargo.toml index 059d11f..a44df91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ axum = { version = "0.6", features = ["headers"] } axum-server = { version = "0.4.7", features = ["tls-rustls"] } clap = { version = "4.2", features = ["derive", "env"] } expanduser = "1.2.2" +flate2 = "1.0" http = "*" hyper = { version = "0.14", features = ["full"] } lazy_static = "1.4" diff --git a/README.md b/README.md index 40514c6..6536638 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,11 @@ with a JSON payload of the form: "selection": [ [0, 19, 2], [1, 3, 1] - ] + ], + + // Algorithm used to compress the data + // - optional, defaults to no compression + "compression": "gzip|zlib" } ``` @@ -92,7 +96,7 @@ In particular, the following are known limitations which we intend to address: * Error handling and reporting is minimal * No support for missing data - * No support for compressed or encrypted objects + * No support for encrypted objects ## Running diff --git a/scripts/client.py b/scripts/client.py index 1d60999..f4e2692 100644 --- a/scripts/client.py +++ b/scripts/client.py @@ -36,6 +36,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--shape", type=str) parser.add_argument("--order", default="C") #, 
choices=["C", "F"]) allow invalid for testing parser.add_argument("--selection", type=str) + parser.add_argument("--compression", type=str) parser.add_argument("--show-response-headers", action=argparse.BooleanOptionalAction) return parser.parse_args() @@ -49,6 +50,7 @@ def build_request_data(args: argparse.Namespace) -> dict: 'offset': args.offset, 'size': args.size, 'order': args.order, + 'compression': args.compression, } if args.shape: request_data["shape"] = json.loads(args.shape) diff --git a/scripts/upload_sample_data.py b/scripts/upload_sample_data.py index 94fb78c..f8f8621 100644 --- a/scripts/upload_sample_data.py +++ b/scripts/upload_sample_data.py @@ -1,10 +1,13 @@ from enum import Enum +import gzip import numpy as np import pathlib import s3fs +import zlib NUM_ITEMS = 10 OBJECT_PREFIX = "data" +COMPRESSION_ALGS = [None, "gzip", "zlib"] #Use enum which also subclasses string type so that # auto-generated OpenAPI schema can determine allowed dtypes @@ -33,8 +36,16 @@ def n_bytes(self): pass # Create numpy arrays and upload to S3 as bytes -for d in AllowedDatatypes.__members__.keys(): - with s3_fs.open(bucket / f'{OBJECT_PREFIX}-{d}.dat', 'wb') as s3_file: - s3_file.write(np.arange(NUM_ITEMS, dtype=d).tobytes()) +for compression in COMPRESSION_ALGS: + compression_suffix = f"-{compression}" if compression else "" + for d in AllowedDatatypes.__members__.keys(): + obj_name = f'{OBJECT_PREFIX}-{d}{compression_suffix}.dat' + with s3_fs.open(bucket / obj_name, 'wb') as s3_file: + data = np.arange(NUM_ITEMS, dtype=d).tobytes() + if compression == "gzip": + data = gzip.compress(data) + elif compression == "zlib": + data = zlib.compress(data) + s3_file.write(data) print("Data upload successful. \nBucket contents:\n", s3_fs.ls(bucket)) diff --git a/src/app.rs b/src/app.rs index 8818e72..5b01199 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,6 +1,7 @@ //! Active Storage server API use crate::error::ActiveStorageError; +use crate::filter_pipeline; use crate::metrics::{metrics_handler, track_metrics}; use crate::models; use crate::operation; @@ -159,6 +160,7 @@ async fn operation_handler( ValidatedJson(request_data): ValidatedJson, ) -> Result { let data = download_object(&auth, &request_data).await?; + let data = filter_pipeline::filter_pipeline(&request_data, &data)?; T::execute(&request_data, &data) } diff --git a/src/array.rs b/src/array.rs index 4ffb599..bbe89ad 100644 --- a/src/array.rs +++ b/src/array.rs @@ -236,6 +236,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }, ); assert_eq!([42], shape.raw_dim().as_array_view().as_slice().unwrap()); @@ -255,6 +256,7 @@ mod tests { shape: Some(vec![1, 2, 3]), order: None, selection: None, + compression: None, }, ); assert_eq!( @@ -458,6 +460,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let bytes = Bytes::copy_from_slice(&data); let array = build_array::(&request_data, &bytes).unwrap(); @@ -477,6 +480,7 @@ mod tests { shape: Some(vec![2, 1]), order: None, selection: None, + compression: None, }; let bytes = Bytes::copy_from_slice(&data); let array = build_array::(&request_data, &bytes).unwrap(); @@ -496,6 +500,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let bytes = Bytes::copy_from_slice(&data); let array = build_array::(&request_data, &bytes).unwrap(); diff --git a/src/compression.rs b/src/compression.rs new file mode 100644 index 0000000..8e59db4 --- /dev/null +++ b/src/compression.rs @@ -0,0 +1,104 @@ +//! (De)compression support. 
+
+use crate::error::ActiveStorageError;
+use crate::models;
+
+use axum::body::Bytes;
+use flate2::read::{GzDecoder, ZlibDecoder};
+use std::io::Read;
+
+/// Decompresses some Bytes and returns the uncompressed data.
+///
+/// # Arguments
+///
+/// * `compression`: Compression algorithm
+/// * `data`: Compressed data [Bytes](axum::body::Bytes)
+pub fn decompress(
+    compression: models::Compression,
+    data: &Bytes,
+) -> Result<Bytes, ActiveStorageError> {
+    let mut decoder: Box<dyn Read> = match compression {
+        models::Compression::Gzip => Box::new(GzDecoder::<&[u8]>::new(data)),
+        models::Compression::Zlib => Box::new(ZlibDecoder::<&[u8]>::new(data)),
+    };
+    // The data returned by the S3 client does not have any alignment guarantees. In order to
+    // reinterpret the data as an array of numbers with a higher alignment than 1, we need to
+    // return the data in a Bytes object in which the underlying data has a higher alignment.
+    // For now we're hard-coding an alignment of 8 bytes, although this should depend on the
+    // data type, and potentially whether there are any SIMD requirements.
+    // Create an 8-byte aligned Vec.
+    // FIXME: The compressed length will not be enough to store the uncompressed data, and may
+    // result in a change in the underlying buffer to one that is not correctly aligned.
+    let mut buf = maligned::align_first::<u8, maligned::A8>(data.len());
+    decoder.read_to_end(&mut buf)?;
+    // Release any unnecessary capacity.
+    buf.shrink_to(0);
+    Ok(buf.into())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use flate2::read::{GzEncoder, ZlibEncoder};
+    use flate2::Compression;
+
+    fn compress_gzip() -> Vec<u8> {
+        // Adapted from the flate2 documentation.
+        let mut result = Vec::<u8>::new();
+        let input = b"hello world";
+        let mut deflater = GzEncoder::new(&input[..], Compression::fast());
+        deflater.read_to_end(&mut result).unwrap();
+        result
+    }
+
+    fn compress_zlib() -> Vec<u8> {
+        // Adapted from the flate2 documentation.
+        let mut result = Vec::<u8>::new();
+        let input = b"hello world";
+        let mut deflater = ZlibEncoder::new(&input[..], Compression::fast());
+        deflater.read_to_end(&mut result).unwrap();
+        result
+    }
+
+    #[test]
+    fn test_decompress_gzip() {
+        let compressed = compress_gzip();
+        let result = decompress(models::Compression::Gzip, &compressed.into()).unwrap();
+        assert_eq!(result, b"hello world".as_ref());
+        assert_eq!(result.as_ptr().align_offset(8), 0);
+    }
+
+    #[test]
+    fn test_decompress_zlib() {
+        let compressed = compress_zlib();
+        let result = decompress(models::Compression::Zlib, &compressed.into()).unwrap();
+        assert_eq!(result, b"hello world".as_ref());
+        assert_eq!(result.as_ptr().align_offset(8), 0);
+    }
+
+    #[test]
+    fn test_decompress_invalid_gzip() {
+        let invalid = b"invalid format";
+        let err = decompress(models::Compression::Gzip, &invalid.as_ref().into()).unwrap_err();
+        match err {
+            ActiveStorageError::Decompression(io_err) => {
+                assert_eq!(io_err.kind(), std::io::ErrorKind::InvalidInput);
+                assert_eq!(io_err.to_string(), "invalid gzip header");
+            }
+            err => panic!("unexpected error {}", err),
+        }
+    }
+
+    #[test]
+    fn test_decompress_invalid_zlib() {
+        let invalid = b"invalid format";
+        let err = decompress(models::Compression::Zlib, &invalid.as_ref().into()).unwrap_err();
+        match err {
+            ActiveStorageError::Decompression(io_err) => {
+                assert_eq!(io_err.kind(), std::io::ErrorKind::InvalidInput);
+                assert_eq!(io_err.to_string(), "corrupt deflate stream");
+            }
+            err => panic!("unexpected error {}", err),
+        }
+    }
+}
diff --git a/src/error.rs b/src/error.rs
index 02f0618..8f6b780 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -22,6 +22,10 @@ use tracing::{event, Level};
 /// Each variant may result in a different API error response.
 #[derive(Debug, Error)]
 pub enum ActiveStorageError {
+    /// Error decompressing data
+    #[error("failed to decompress data")]
+    Decompression(#[from] std::io::Error),
+
     /// Attempt to perform an invalid operation on an empty array or selection
     #[error("cannot perform {operation} on empty array or selection")]
     EmptyArray { operation: &'static str },
@@ -174,7 +178,8 @@ impl From<ActiveStorageError> for ErrorResponse {
     fn from(error: ActiveStorageError) -> Self {
         let response = match &error {
             // Bad request
-            ActiveStorageError::EmptyArray { operation: _ }
+            ActiveStorageError::Decompression(_)
+            | ActiveStorageError::EmptyArray { operation: _ }
             | ActiveStorageError::RequestDataJsonRejection(_)
             | ActiveStorageError::RequestDataValidation(_)
             | ActiveStorageError::ShapeInvalid(_) => Self::bad_request(&error),
@@ -309,6 +314,15 @@ mod tests {
         assert_eq!(caused_by, error_response.error.caused_by);
     }
 
+    #[tokio::test]
+    async fn decompression_error() {
+        let io_error = std::io::Error::new(std::io::ErrorKind::InvalidInput, "decompression error");
+        let error = ActiveStorageError::Decompression(io_error);
+        let message = "failed to decompress data";
+        let caused_by = Some(vec!["decompression error"]);
+        test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await;
+    }
+
     #[tokio::test]
     async fn empty_array_op_error() {
         let error = ActiveStorageError::EmptyArray { operation: "foo" };
diff --git a/src/filter_pipeline.rs b/src/filter_pipeline.rs
new file mode 100644
index 0000000..8b7f93e
--- /dev/null
+++ b/src/filter_pipeline.rs
@@ -0,0 +1,84 @@
+//! Compression and filter pipeline.
+
+use crate::compression;
+use crate::error::ActiveStorageError;
+use crate::models;
+
+use axum::body::Bytes;
+
+/// Returns data after applying a filter pipeline.
+///
+/// The pipeline is applied in the reverse of the order used when the data was written.
+///
+/// # Arguments
+///
+/// * `request_data`: RequestData object for the request
+/// * `data`: Data to apply filter pipeline to.
+pub fn filter_pipeline(
+    request_data: &models::RequestData,
+    data: &Bytes,
+) -> Result<Bytes, ActiveStorageError> {
+    if let Some(compression) = request_data.compression {
+        compression::decompress(compression, data)
+    } else {
+        Ok(data.clone())
+    }
+    // TODO: Defilter
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use flate2::read::GzEncoder;
+    use flate2::Compression;
+    use std::io::Read;
+    use url::Url;
+
+    fn compress_gzip(data: &[u8]) -> Bytes {
+        // Adapted from the flate2 documentation.
+        let mut result = Vec::<u8>::new();
+        let mut deflater = GzEncoder::new(data, Compression::fast());
+        deflater.read_to_end(&mut result).unwrap();
+        result.into()
+    }
+
+    #[test]
+    fn test_filter_pipeline_noop() {
+        let data = [1, 2, 3, 4];
+        let bytes = Bytes::copy_from_slice(&data);
+        let request_data = models::RequestData {
+            source: Url::parse("http://example.com").unwrap(),
+            bucket: "bar".to_string(),
+            object: "baz".to_string(),
+            dtype: models::DType::Int32,
+            offset: None,
+            size: None,
+            shape: None,
+            order: None,
+            selection: None,
+            compression: None,
+        };
+        let result = filter_pipeline(&request_data, &bytes).unwrap();
+        assert_eq!(data.as_ref(), result);
+    }
+
+    #[test]
+    fn test_filter_pipeline_gzip() {
+        let data = [1, 2, 3, 4];
+        let bytes = compress_gzip(data.as_ref());
+        let request_data = models::RequestData {
+            source: Url::parse("http://example.com").unwrap(),
+            bucket: "bar".to_string(),
+            object: "baz".to_string(),
+            dtype: models::DType::Int32,
+            offset: None,
+            size: None,
+            shape: None,
+            order: None,
+            selection: None,
+            compression: Some(models::Compression::Gzip),
+        };
+        let result = filter_pipeline(&request_data, &bytes).unwrap();
+        assert_eq!(data.as_ref(), result);
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 391b4f4..a4e89cc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -25,7 +25,9 @@ pub mod app;
 pub mod array;
 pub mod cli;
+pub mod compression;
 pub mod error;
+pub mod filter_pipeline;
 pub mod metrics;
 pub mod models;
 pub mod operation;
diff --git a/src/metrics.rs b/src/metrics.rs
index 20d4d38..ffdff2e 100644
--- a/src/metrics.rs
+++ b/src/metrics.rs
@@ -48,7 +48,7 @@ pub async fn metrics_handler() -> String {
         .encode(&prometheus::gather(), &mut buffer)
         .expect("could not encode gathered metrics into temporary buffer");
 
-    String::from_utf8(buffer).expect("could not convert metrics buffer into string")
+    String::from_utf8(buffer).expect("could not convert metrics buffer into string")
 }
 
 pub async fn track_metrics<B>(request: Request<B>, next: Next<B>) -> impl IntoResponse {
diff --git a/src/models.rs b/src/models.rs
index 088419d..8c5965c 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -85,6 +85,16 @@ impl Slice {
     }
 }
+/// Compression algorithm
+#[derive(Clone, Copy, Debug, Deserialize, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Compression {
+    /// Gzip
+    Gzip,
+    /// Zlib
+    Zlib,
+}
+
 /// Request data for operations
 #[derive(Debug, Deserialize, PartialEq, Validate)]
 #[serde(deny_unknown_fields)]
 pub struct RequestData {
@@ -118,6 +128,8 @@ pub struct RequestData {
     #[validate]
     #[validate(length(min = 1, message = "selection length must be greater than 0"))]
     pub selection: Option<Vec<Slice>>,
+    /// Compression filter name
+    pub compression: Option<Compression>,
 }
 
 /// Validate an array shape
@@ -218,6 +230,7 @@ mod tests {
             shape: None,
             order: None,
             selection: None,
+            compression: None,
         }
     }
 
@@ -232,6 +245,7 @@ mod
tests { shape: Some(vec![2, 5]), order: Some(Order::C), selection: Some(vec![Slice::new(1, 2, 3), Slice::new(4, 5, 6)]), + compression: Some(Compression::Gzip), } } @@ -315,6 +329,13 @@ mod tests { Token::U32(6), Token::SeqEnd, Token::SeqEnd, + Token::Str("compression"), + Token::Some, + Token::Enum { + name: "Compression", + }, + Token::Str("gzip"), + Token::Unit, Token::StructEnd, ], ); @@ -569,6 +590,27 @@ mod tests { request_data.selection = Some(vec![Slice::new(1, 2, 1)]); request_data.validate().unwrap() } + + #[test] + fn test_invalid_compression() { + assert_de_tokens_error::( + &[ + Token::Struct { + name: "RequestData", + len: 2, + }, + Token::Str("compression"), + Token::Some, + Token::Enum { + name: "Compression", + }, + Token::Str("foo"), + Token::StructEnd, + ], + "unknown variant `foo`, expected `gzip` or `zlib`", + ) + } + #[test] fn test_unknown_field() { assert_de_tokens_error::(&[ @@ -576,7 +618,7 @@ mod tests { Token::Str("foo"), Token::StructEnd ], - "unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `offset`, `size`, `shape`, `order`, `selection`" + "unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `offset`, `size`, `shape`, `order`, `selection`, `compression`" ) } @@ -591,7 +633,7 @@ mod tests { #[test] fn test_json_optional_fields() { - let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 4, "size": 8, "shape": [2, 5], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]]}"#; + let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 4, "size": 8, "shape": [2, 5], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]], "compression": "gzip"}"#; let request_data = serde_json::from_str::(json).unwrap(); assert_eq!(request_data, get_test_request_data_optional()); } diff --git a/src/operation.rs b/src/operation.rs index 36746f8..cb66305 100644 --- a/src/operation.rs +++ b/src/operation.rs @@ -119,6 +119,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let data = [1, 2, 3, 4]; let bytes = Bytes::copy_from_slice(&data); @@ -159,6 +160,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let data = [1, 2, 3, 4]; let bytes = Bytes::copy_from_slice(&data); diff --git a/src/operations.rs b/src/operations.rs index 730ffb3..0290b5d 100644 --- a/src/operations.rs +++ b/src/operations.rs @@ -213,6 +213,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let data: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8]; let bytes = Bytes::copy_from_slice(&data); @@ -239,6 +240,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; // data: // A u8 slice of 8 elements == a single i64 value @@ -268,6 +270,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let data = [1, 2, 3, 4, 5, 6, 7, 8]; let bytes = Bytes::copy_from_slice(&data); @@ -292,6 +295,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let data = [1, 2, 3, 4, 5, 6, 7, 8]; let bytes = Bytes::copy_from_slice(&data); @@ -316,6 +320,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let data = [1, 2, 3, 4, 5, 6, 7, 8]; let bytes = Bytes::copy_from_slice(&data); @@ -340,6 +345,7 @@ mod tests { shape: Some(vec![2, 1]), order: None, selection: None, + compression: None, }; let data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; let bytes = 
Bytes::copy_from_slice(&data); @@ -367,6 +373,7 @@ mod tests { models::Slice::new(0, 2, 1), models::Slice::new(1, 2, 1), ]), + compression: None, }; // 2x2 array, select second row of each column. // [[0x04030201, 0x08070605], [0x12111009, 0x16151413]] @@ -394,6 +401,7 @@ mod tests { shape: None, order: None, selection: None, + compression: None, }; let data = [1, 2, 3, 4, 5, 6, 7, 8]; let bytes = Bytes::copy_from_slice(&data); From 8c28d4624f73e3b9460506555857de11e1c9c83c Mon Sep 17 00:00:00 2001 From: Mark Goddard Date: Mon, 10 Jul 2023 10:42:40 +0100 Subject: [PATCH 2/3] Perform size validation after decompression When data is compressed, the size parameter refers to the size of the compressed data. Typically this is not equal to the size of the uncompressed data, so we can't validate it against the data type size. This change skips initial size/dtype validation when compression is applied, instead performing it once the data has been decompressed. It also adds an additional validation that the size matches the shape, when a shape has been specified. --- src/app.rs | 4 ++++ src/error.rs | 16 +++++++++++++- src/models.rs | 58 ++++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/app.rs b/src/app.rs index 5b01199..3fae026 100644 --- a/src/app.rs +++ b/src/app.rs @@ -161,6 +161,10 @@ async fn operation_handler( ) -> Result { let data = download_object(&auth, &request_data).await?; let data = filter_pipeline::filter_pipeline(&request_data, &data)?; + if request_data.compression.is_some() || request_data.size.is_none() { + // Validate the raw uncompressed data size now that we know it. + models::validate_raw_size(data.len(), request_data.dtype, &request_data.shape)?; + } T::execute(&request_data, &data) } diff --git a/src/error.rs b/src/error.rs index 8f6b780..9f44df8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -38,7 +38,11 @@ pub enum ActiveStorageError { #[error("request data is not valid")] RequestDataJsonRejection(#[from] JsonRejection), - /// Error validating RequestData + /// Error validating RequestData (single error) + #[error("request data is not valid")] + RequestDataValidationSingle(#[from] validator::ValidationError), + + /// Error validating RequestData (multiple errors) #[error("request data is not valid")] RequestDataValidation(#[from] validator::ValidationErrors), @@ -181,6 +185,7 @@ impl From for ErrorResponse { ActiveStorageError::Decompression(_) | ActiveStorageError::EmptyArray { operation: _ } | ActiveStorageError::RequestDataJsonRejection(_) + | ActiveStorageError::RequestDataValidationSingle(_) | ActiveStorageError::RequestDataValidation(_) | ActiveStorageError::ShapeInvalid(_) => Self::bad_request(&error), @@ -340,6 +345,15 @@ mod tests { .await; } + #[tokio::test] + async fn request_data_validation_single() { + let validation_error = validator::ValidationError::new("foo"); + let error = ActiveStorageError::RequestDataValidationSingle(validation_error); + let message = "request data is not valid"; + let caused_by = Some(vec!["Validation error: foo [{}]"]); + test_active_storage_error(error, StatusCode::BAD_REQUEST, message, caused_by).await; + } + #[tokio::test] async fn request_data_validation() { let mut validation_errors = validator::ValidationErrors::new(); diff --git a/src/models.rs b/src/models.rs index 8c5965c..fffbcf1 100644 --- a/src/models.rs +++ b/src/models.rs @@ -26,7 +26,7 @@ pub enum DType { impl DType { /// Returns the size of the associated type in bytes. 
-    fn size_of(self) -> usize {
+    pub fn size_of(self) -> usize {
         match self {
             Self::Int32 => std::mem::size_of::<i32>(),
             Self::Int64 => std::mem::size_of::<i64>(),
@@ -164,16 +164,47 @@ fn validate_shape_selection(
     Ok(())
 }
 
+/// Validate raw data size against data type and shape.
+///
+/// # Arguments
+///
+/// * `raw_size`: Raw (uncompressed) size of the data in bytes.
+/// * `dtype`: Data type
+/// * `shape`: Optional shape of the multi-dimensional array
+pub fn validate_raw_size(
+    raw_size: usize,
+    dtype: DType,
+    shape: &Option<Vec<usize>>,
+) -> Result<(), ValidationError> {
+    let dtype_size = dtype.size_of();
+    if let Some(shape) = shape {
+        let expected_size = shape.iter().product::<usize>() * dtype_size;
+        if raw_size != expected_size {
+            let mut error =
+                ValidationError::new("Raw data size must be equal to the product of shape indices and dtype size in bytes");
+            error.add_param("raw size".into(), &raw_size);
+            error.add_param("dtype size".into(), &dtype_size);
+            error.add_param("expected size".into(), &expected_size);
+            return Err(error);
+        }
+    } else if raw_size % dtype_size != 0 {
+        let mut error =
+            ValidationError::new("Raw data size must be a multiple of dtype size in bytes");
+        error.add_param("raw size".into(), &raw_size);
+        error.add_param("dtype size".into(), &dtype_size);
+        return Err(error);
+    }
+    Ok(())
+}
+
 /// Validate request data
 fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationError> {
     // Validation of multiple fields in RequestData.
     if let Some(size) = &request_data.size {
-        let dtype_size = request_data.dtype.size_of();
-        if size % dtype_size != 0 {
-            let mut error = ValidationError::new("Size must be a multiple of dtype size in bytes");
-            error.add_param("size".into(), &size);
-            error.add_param("dtype size".into(), &dtype_size);
-            return Err(error);
+        // If the data is compressed then the size refers to the size of the compressed data, so we
+        // can't validate it at this point.
+        if request_data.compression.is_none() {
+            validate_raw_size(*size, request_data.dtype, &request_data.shape)?;
         }
     };
     match (&request_data.shape, &request_data.selection) {
@@ -531,13 +562,24 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(expected = "Size must be a multiple of dtype size in bytes")]
+    #[should_panic(expected = "Raw data size must be a multiple of dtype size in bytes")]
     fn test_invalid_size_for_dtype() {
         let mut request_data = get_test_request_data();
         request_data.size = Some(1);
         request_data.validate().unwrap()
     }
 
+    #[test]
+    #[should_panic(
+        expected = "Raw data size must be equal to the product of shape indices and dtype size in bytes"
+    )]
+    fn test_invalid_size_for_shape() {
+        let mut request_data = get_test_request_data();
+        request_data.size = Some(4);
+        request_data.shape = Some(vec![1, 2]);
+        request_data.validate().unwrap()
+    }
+
     #[test]
     #[should_panic(expected = "Shape and selection must have the same length")]
     fn test_shape_selection_mismatch() {

From c3490bb03ad71876ac7061a1f6062c4833930e89 Mon Sep 17 00:00:00 2001
From: Mark Goddard
Date: Wed, 12 Jul 2023 16:06:15 +0100
Subject: [PATCH 3/3] compression: Use id as an internal enum tag

This allows us to more easily support compression algorithms that have
decompression parameters.
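As an illustration only (not part of this change), a hypothetical algorithm
whose decompression needs a parameter could later be added as a struct
variant, with the parameter carried alongside the `id` tag in the same JSON
object:

    // Sketch: the `Foo` variant and its `level` field are hypothetical and are
    // not added by this series; they show how an internally tagged enum can
    // carry per-algorithm decompression parameters.
    #[derive(Clone, Copy, Debug, Deserialize, PartialEq)]
    #[serde(rename_all = "lowercase")]
    #[serde(tag = "id")]
    pub enum Compression {
        Gzip,
        Zlib,
        // Would deserialize from: {"id": "foo", "level": 3}
        Foo { level: u32 },
    }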
--- src/models.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/models.rs b/src/models.rs index fffbcf1..19cd4da 100644 --- a/src/models.rs +++ b/src/models.rs @@ -88,6 +88,7 @@ impl Slice { /// Compression algorithm #[derive(Clone, Copy, Debug, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] +#[serde(tag = "id")] pub enum Compression { /// Gzip Gzip, @@ -362,11 +363,10 @@ mod tests { Token::SeqEnd, Token::Str("compression"), Token::Some, - Token::Enum { - name: "Compression", - }, + Token::Map { len: None }, + Token::Str("id"), Token::Str("gzip"), - Token::Unit, + Token::MapEnd, Token::StructEnd, ], ); @@ -643,11 +643,10 @@ mod tests { }, Token::Str("compression"), Token::Some, - Token::Enum { - name: "Compression", - }, + Token::Map { len: None }, + Token::Str("id"), Token::Str("foo"), - Token::StructEnd, + Token::MapEnd, ], "unknown variant `foo`, expected `gzip` or `zlib`", ) @@ -675,7 +674,7 @@ mod tests { #[test] fn test_json_optional_fields() { - let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 4, "size": 8, "shape": [2, 5], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]], "compression": "gzip"}"#; + let json = r#"{"source": "http://example.com", "bucket": "bar", "object": "baz", "dtype": "int32", "offset": 4, "size": 8, "shape": [2, 5], "order": "C", "selection": [[1, 2, 3], [4, 5, 6]], "compression": {"id": "gzip"}}"#; let request_data = serde_json::from_str::(json).unwrap(); assert_eq!(request_data, get_test_request_data_optional()); }
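
A request body that exercises the final format, adapted from the updated
test_json_optional_fields test (the source, bucket and object values are
placeholders taken from the test suite):

    {
      "source": "http://example.com",
      "bucket": "bar",
      "object": "baz",
      "dtype": "int32",
      "offset": 4,
      "size": 8,
      "shape": [2, 5],
      "order": "C",
      "selection": [[1, 2, 3], [4, 5, 6]],
      "compression": {"id": "gzip"}
    }

With compression set, `size` refers to the compressed size; the uncompressed
size is validated against `dtype` and `shape` after decompression (see patch 2).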