diff --git a/Cargo.toml b/Cargo.toml index d3bff76..6b93aab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,14 +6,21 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = "0.1" axum = "0.6" hyper = { version = "0.14", features = ["full"] } +mime = "0.3" serde = { version = "1.0", features = ["derive"] } serde_json = "*" +strum = { version = "0.24", features = ["derive"] } +strum_macros = "0.24" +thiserror = "1.0" tokio = { version = "1.24", features = ["full"] } tower = "0.4" +tower-http = { version = "0.3", features = ["trace", "validate-request"] } url = { version = "2", features = ["serde"] } validator = { version = "0.16", features = ["derive"] } [dev-dependencies] +regex = "1" serde_test = "1.0" diff --git a/src/app.rs b/src/app.rs index bc66dea..60be4e5 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,11 +1,40 @@ use crate::models; +use crate::validated_json::ValidatedJson; use axum::{ - extract::Json, + body::Body, + http::header, + http::Request, + http::StatusCode, + response::{IntoResponse, Response}, routing::{get, post}, Router, }; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; +use tower_http::validate_request::ValidateRequestHeaderLayer; + +static HEADER_DTYPE: header::HeaderName = header::HeaderName::from_static("x-activestorage-dtype"); +static HEADER_SHAPE: header::HeaderName = header::HeaderName::from_static("x-activestorage-shape"); + +impl IntoResponse for models::Response { + fn into_response(self) -> Response { + ( + [ + ( + &header::CONTENT_TYPE, + mime::APPLICATION_OCTET_STREAM.to_string(), + ), + (&HEADER_DTYPE, self.dtype.to_string().to_lowercase()), + (&HEADER_SHAPE, serde_json::to_string(&self.shape).unwrap()), + ], + self.result, + ) + .into_response() + } +} + pub fn router() -> Router { fn v1() -> Router { Router::new() @@ -15,6 +44,20 @@ pub fn router() -> Router { .route("/min", post(min)) .route("/select", post(select)) .route("/sum", post(sum)) + .layer( + ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(ValidateRequestHeaderLayer::custom( + // Validate that an authorization header has been provided. + |request: &mut Request| { + if request.headers().contains_key(header::AUTHORIZATION) { + Ok(()) + } else { + Err(StatusCode::UNAUTHORIZED.into_response()) + } + }, + )), + ) } Router::new() @@ -26,26 +69,54 @@ async fn schema() -> &'static str { "Hello, world!" } -async fn count(Json(request_data): Json) -> String { - format!("Hello, {}!", request_data.source) +async fn count( + ValidatedJson(request_data): ValidatedJson, +) -> models::Response { + models::Response::new( + request_data.source.to_string(), + models::DType::Int32, + vec![], + ) } -async fn max(Json(request_data): Json) -> String { - format!("Hello, {}!", request_data.source) +async fn max(ValidatedJson(request_data): ValidatedJson) -> models::Response { + models::Response::new( + request_data.source.to_string(), + models::DType::Int32, + vec![], + ) } -async fn mean(Json(request_data): Json) -> String { - format!("Hello, {}!", request_data.source) +async fn mean(ValidatedJson(request_data): ValidatedJson) -> models::Response { + models::Response::new( + request_data.source.to_string(), + models::DType::Int32, + vec![], + ) } -async fn min(Json(request_data): Json) -> String { - format!("Hello, {}!", request_data.source) +async fn min(ValidatedJson(request_data): ValidatedJson) -> models::Response { + models::Response::new( + request_data.source.to_string(), + models::DType::Int32, + vec![], + ) } -async fn select(Json(request_data): Json) -> String { - format!("Hello, {}!", request_data.source) +async fn select( + ValidatedJson(request_data): ValidatedJson, +) -> models::Response { + models::Response::new( + request_data.source.to_string(), + models::DType::Int32, + vec![], + ) } -async fn sum(Json(request_data): Json) -> String { - format!("Hello, {}!", request_data.source) +async fn sum(ValidatedJson(request_data): ValidatedJson) -> models::Response { + models::Response::new( + request_data.source.to_string(), + models::DType::Int32, + vec![], + ) } diff --git a/src/main.rs b/src/main.rs index 3fbbcaf..f0a8b83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use tokio::signal; mod app; mod models; +mod validated_json; #[tokio::main] async fn main() { diff --git a/src/models.rs b/src/models.rs index 5dc2a7b..478130d 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,8 +1,9 @@ use serde::{Deserialize, Serialize}; +use strum_macros::Display; use url::Url; use validator::{Validate, ValidationError}; -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Debug, Deserialize, Display, PartialEq)] #[serde(rename_all = "lowercase")] pub enum DType { Int32, @@ -27,7 +28,7 @@ pub enum Order { pub struct Slice { pub start: u32, pub end: u32, - #[validate(range(min = 1))] + #[validate(range(min = 1, message = "stride must be greater than 0"))] pub stride: u32, } @@ -37,19 +38,19 @@ pub struct Slice { pub struct RequestData { // TODO: Investigate using lifetimes to enable zero-copy: https://serde.rs/lifetimes.html pub source: Url, - #[validate(length(min = 1))] + #[validate(length(min = 1, message = "bucket must not be empty"))] pub bucket: String, - #[validate(length(min = 1))] + #[validate(length(min = 1, message = "object must not be empty"))] pub object: String, pub dtype: DType, pub offset: Option, - #[validate(range(min = 1))] + #[validate(range(min = 1, message = "size must be greater than 0"))] pub size: Option, - #[validate(length(min = 1))] + #[validate(length(min = 1, message = "shape length must be greater than 0"))] pub shape: Option>, pub order: Option, #[validate] - #[validate(length(min = 1))] + #[validate(length(min = 1, message = "selection length must be greater than 0"))] pub selection: Option>, } @@ -78,6 +79,23 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr Ok(()) } +/// Response containing the result of a computation and associated metadata. +pub struct Response { + pub result: String, + pub dtype: DType, + pub shape: Vec, +} + +impl Response { + pub fn new(result: String, dtype: DType, shape: Vec) -> Response { + Response { + result, + dtype, + shape, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -255,7 +273,7 @@ mod tests { } #[test] - #[should_panic(expected = "bucket")] + #[should_panic(expected = "bucket must not be empty")] fn test_invalid_bucket() { let mut request_data = get_test_request_data(); request_data.bucket = "".to_string(); @@ -281,7 +299,7 @@ mod tests { } #[test] - #[should_panic(expected = "object")] + #[should_panic(expected = "object must not be empty")] fn test_invalid_object() { let mut request_data = get_test_request_data(); request_data.object = "".to_string(); @@ -322,7 +340,7 @@ mod tests { } #[test] - #[should_panic(expected = "size")] + #[should_panic(expected = "size must be greater than 0")] fn test_invalid_size() { let mut request_data = get_test_request_data(); request_data.size = Some(0); @@ -330,7 +348,7 @@ mod tests { } #[test] - #[should_panic(expected = "shape")] + #[should_panic(expected = "shape length must be greater than 0")] fn test_invalid_shape() { let mut request_data = get_test_request_data(); request_data.shape = Some(vec![]); @@ -356,7 +374,7 @@ mod tests { } #[test] - #[should_panic(expected = "selection")] + #[should_panic(expected = "selection length must be greater than 0")] fn test_invalid_selection() { let mut request_data = get_test_request_data(); request_data.selection = Some(vec![]); @@ -364,7 +382,7 @@ mod tests { } #[test] - #[should_panic(expected = "selection")] + #[should_panic(expected = "stride must be greater than 0")] fn test_invalid_selection2() { let mut request_data = get_test_request_data(); request_data.selection = Some(vec![Slice { diff --git a/src/validated_json.rs b/src/validated_json.rs new file mode 100644 index 0000000..3dd0ec3 --- /dev/null +++ b/src/validated_json.rs @@ -0,0 +1,172 @@ +use async_trait::async_trait; +use axum::{ + extract::{rejection::JsonRejection, FromRequest, Json}, + http::{Request, StatusCode}, + response::{IntoResponse, Response}, +}; +use serde::de::DeserializeOwned; +use thiserror::Error; +use validator::Validate; + +/// An axum extractor based on the Json extractor that also performs validation using the validator +/// crate. +#[derive(Debug, Clone, Copy, Default)] +pub struct ValidatedJson(pub T); + +#[async_trait] +impl FromRequest for ValidatedJson +where + T: DeserializeOwned + Validate, + S: Send + Sync, + Json: FromRequest, + B: Send + 'static, +{ + type Rejection = JsonError; + + async fn from_request(req: Request, state: &S) -> Result { + let Json(value) = Json::::from_request(req, state).await?; + value.validate()?; + Ok(ValidatedJson(value)) + } +} + +#[derive(Debug, Error)] +pub enum JsonError { + #[error(transparent)] + ValidationError(#[from] validator::ValidationErrors), + + #[error(transparent)] + AxumJsonRejection(#[from] JsonRejection), +} + +impl IntoResponse for JsonError { + fn into_response(self) -> Response { + match self { + JsonError::ValidationError(_) => { + let message = format!("Input validation error: [{}]", self).replace('\n', ", "); + (StatusCode::BAD_REQUEST, message) + } + JsonError::AxumJsonRejection(_) => (StatusCode::BAD_REQUEST, self.to_string()), + } + .into_response() + } +} + +#[cfg(test)] +mod tests { + // https://github.com/tokio-rs/axum/blob/main/examples/testing/src/main.rs + + use super::*; + use axum::{ + body::Body, + http::{self, Request, StatusCode}, + routing::post, + Router, + }; + use regex::Regex; + use serde::{Deserialize, Serialize}; + use tower::ServiceExt; // for `oneshot` and `ready` + + #[derive(Deserialize, Validate, Serialize)] + struct TestPayload { + #[validate(length(min = 1, max = 3))] + pub foo: String, + pub bar: Option, + } + + // Handler function that accepts a ValidatedJson extractor. + async fn test_handler(ValidatedJson(payload): ValidatedJson) -> String { + format!("foo: {} bar: {:?}", payload.foo, payload.bar) + } + + // Build a router and make a oneshot request. + async fn request(body: Body) -> Response { + Router::new() + .route("/", post(test_handler)) + .oneshot( + Request::builder() + .method(http::Method::POST) + .uri("/") + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(body) + .unwrap(), + ) + .await + .unwrap() + } + + // Jump through the hoops to get the body as a string. + async fn body_string(response: Response) -> String { + String::from_utf8( + hyper::body::to_bytes(response.into_body()) + .await + .unwrap() + .to_vec(), + ) + .unwrap() + } + + #[tokio::test] + async fn ok() { + let body = Body::from("{\"foo\": \"abc\", \"bar\": 123}"); + let response = request(body).await; + + assert_eq!(response.status(), StatusCode::OK); + + let body = body_string(response).await; + assert_eq!(&body[..], "foo: abc bar: Some(123)"); + } + + #[tokio::test] + async fn invalid_json() { + let body = Body::from("{\""); + let response = request(body).await; + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let body = body_string(response).await; + assert_eq!(&body[..], "Failed to parse the request body as JSON"); + } + + #[tokio::test] + async fn invalid_foo_type() { + let body = Body::from("{\"foo\": 123}"); + let response = request(body).await; + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let body = body_string(response).await; + assert_eq!( + &body[..], + "Failed to deserialize the JSON body into the target type" + ); + } + + #[tokio::test] + async fn invalid_foo_too_short() { + let body = Body::from("{\"foo\": \"\"}"); + let response = request(body).await; + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let body = body_string(response).await; + let re = + Regex::new(r"Input validation error: \[foo: Validation error: length \[\{.*\}\]\]") + .unwrap(); + assert!(re.is_match(&body[..])) + } + + #[tokio::test] + async fn invalid_foo_too_long() { + let body = Body::from("{\"foo\": \"abcd\"}"); + let response = request(body).await; + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let body = body_string(response).await; + let re = + Regex::new(r"Input validation error: \[foo: Validation error: length \[\{.*\}\]\]") + .unwrap(); + assert!(re.is_match(&body[..])) + } +}