diff --git a/Cargo.toml b/Cargo.toml
index d3bff76..6b93aab 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,14 +6,21 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
+async-trait = "0.1"
axum = "0.6"
hyper = { version = "0.14", features = ["full"] }
+mime = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "*"
+strum = { version = "0.24", features = ["derive"] }
+strum_macros = "0.24"
+thiserror = "1.0"
tokio = { version = "1.24", features = ["full"] }
tower = "0.4"
+tower-http = { version = "0.3", features = ["trace", "validate-request"] }
url = { version = "2", features = ["serde"] }
validator = { version = "0.16", features = ["derive"] }
[dev-dependencies]
+regex = "1"
serde_test = "1.0"
diff --git a/src/app.rs b/src/app.rs
index bc66dea..60be4e5 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -1,11 +1,40 @@
use crate::models;
+use crate::validated_json::ValidatedJson;
use axum::{
- extract::Json,
+ body::Body,
+ http::header,
+ http::Request,
+ http::StatusCode,
+ response::{IntoResponse, Response},
routing::{get, post},
Router,
};
+use tower::ServiceBuilder;
+use tower_http::trace::TraceLayer;
+use tower_http::validate_request::ValidateRequestHeaderLayer;
+
+static HEADER_DTYPE: header::HeaderName = header::HeaderName::from_static("x-activestorage-dtype");
+static HEADER_SHAPE: header::HeaderName = header::HeaderName::from_static("x-activestorage-shape");
+
+impl IntoResponse for models::Response {
+ fn into_response(self) -> Response {
+ (
+ [
+ (
+ &header::CONTENT_TYPE,
+ mime::APPLICATION_OCTET_STREAM.to_string(),
+ ),
+ (&HEADER_DTYPE, self.dtype.to_string().to_lowercase()),
+ (&HEADER_SHAPE, serde_json::to_string(&self.shape).unwrap()),
+ ],
+ self.result,
+ )
+ .into_response()
+ }
+}
+
pub fn router() -> Router {
fn v1() -> Router {
Router::new()
@@ -15,6 +44,20 @@ pub fn router() -> Router {
.route("/min", post(min))
.route("/select", post(select))
.route("/sum", post(sum))
+ .layer(
+ ServiceBuilder::new()
+ .layer(TraceLayer::new_for_http())
+ .layer(ValidateRequestHeaderLayer::custom(
+ // Validate that an authorization header has been provided.
+ |request: &mut Request
| {
+ if request.headers().contains_key(header::AUTHORIZATION) {
+ Ok(())
+ } else {
+ Err(StatusCode::UNAUTHORIZED.into_response())
+ }
+ },
+ )),
+ )
}
Router::new()
@@ -26,26 +69,54 @@ async fn schema() -> &'static str {
"Hello, world!"
}
-async fn count(Json(request_data): Json) -> String {
- format!("Hello, {}!", request_data.source)
+async fn count(
+ ValidatedJson(request_data): ValidatedJson,
+) -> models::Response {
+ models::Response::new(
+ request_data.source.to_string(),
+ models::DType::Int32,
+ vec![],
+ )
}
-async fn max(Json(request_data): Json) -> String {
- format!("Hello, {}!", request_data.source)
+async fn max(ValidatedJson(request_data): ValidatedJson) -> models::Response {
+ models::Response::new(
+ request_data.source.to_string(),
+ models::DType::Int32,
+ vec![],
+ )
}
-async fn mean(Json(request_data): Json) -> String {
- format!("Hello, {}!", request_data.source)
+async fn mean(ValidatedJson(request_data): ValidatedJson) -> models::Response {
+ models::Response::new(
+ request_data.source.to_string(),
+ models::DType::Int32,
+ vec![],
+ )
}
-async fn min(Json(request_data): Json) -> String {
- format!("Hello, {}!", request_data.source)
+async fn min(ValidatedJson(request_data): ValidatedJson) -> models::Response {
+ models::Response::new(
+ request_data.source.to_string(),
+ models::DType::Int32,
+ vec![],
+ )
}
-async fn select(Json(request_data): Json) -> String {
- format!("Hello, {}!", request_data.source)
+async fn select(
+ ValidatedJson(request_data): ValidatedJson,
+) -> models::Response {
+ models::Response::new(
+ request_data.source.to_string(),
+ models::DType::Int32,
+ vec![],
+ )
}
-async fn sum(Json(request_data): Json) -> String {
- format!("Hello, {}!", request_data.source)
+async fn sum(ValidatedJson(request_data): ValidatedJson) -> models::Response {
+ models::Response::new(
+ request_data.source.to_string(),
+ models::DType::Int32,
+ vec![],
+ )
}
diff --git a/src/main.rs b/src/main.rs
index 3fbbcaf..f0a8b83 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,6 +2,7 @@ use tokio::signal;
mod app;
mod models;
+mod validated_json;
#[tokio::main]
async fn main() {
diff --git a/src/models.rs b/src/models.rs
index 5dc2a7b..478130d 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -1,8 +1,9 @@
use serde::{Deserialize, Serialize};
+use strum_macros::Display;
use url::Url;
use validator::{Validate, ValidationError};
-#[derive(Debug, Deserialize, PartialEq)]
+#[derive(Debug, Deserialize, Display, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DType {
Int32,
@@ -27,7 +28,7 @@ pub enum Order {
pub struct Slice {
pub start: u32,
pub end: u32,
- #[validate(range(min = 1))]
+ #[validate(range(min = 1, message = "stride must be greater than 0"))]
pub stride: u32,
}
@@ -37,19 +38,19 @@ pub struct Slice {
pub struct RequestData {
// TODO: Investigate using lifetimes to enable zero-copy: https://serde.rs/lifetimes.html
pub source: Url,
- #[validate(length(min = 1))]
+ #[validate(length(min = 1, message = "bucket must not be empty"))]
pub bucket: String,
- #[validate(length(min = 1))]
+ #[validate(length(min = 1, message = "object must not be empty"))]
pub object: String,
pub dtype: DType,
pub offset: Option,
- #[validate(range(min = 1))]
+ #[validate(range(min = 1, message = "size must be greater than 0"))]
pub size: Option,
- #[validate(length(min = 1))]
+ #[validate(length(min = 1, message = "shape length must be greater than 0"))]
pub shape: Option>,
pub order: Option,
#[validate]
- #[validate(length(min = 1))]
+ #[validate(length(min = 1, message = "selection length must be greater than 0"))]
pub selection: Option>,
}
@@ -78,6 +79,23 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
Ok(())
}
+/// Response containing the result of a computation and associated metadata.
+pub struct Response {
+ pub result: String,
+ pub dtype: DType,
+ pub shape: Vec,
+}
+
+impl Response {
+ pub fn new(result: String, dtype: DType, shape: Vec) -> Response {
+ Response {
+ result,
+ dtype,
+ shape,
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -255,7 +273,7 @@ mod tests {
}
#[test]
- #[should_panic(expected = "bucket")]
+ #[should_panic(expected = "bucket must not be empty")]
fn test_invalid_bucket() {
let mut request_data = get_test_request_data();
request_data.bucket = "".to_string();
@@ -281,7 +299,7 @@ mod tests {
}
#[test]
- #[should_panic(expected = "object")]
+ #[should_panic(expected = "object must not be empty")]
fn test_invalid_object() {
let mut request_data = get_test_request_data();
request_data.object = "".to_string();
@@ -322,7 +340,7 @@ mod tests {
}
#[test]
- #[should_panic(expected = "size")]
+ #[should_panic(expected = "size must be greater than 0")]
fn test_invalid_size() {
let mut request_data = get_test_request_data();
request_data.size = Some(0);
@@ -330,7 +348,7 @@ mod tests {
}
#[test]
- #[should_panic(expected = "shape")]
+ #[should_panic(expected = "shape length must be greater than 0")]
fn test_invalid_shape() {
let mut request_data = get_test_request_data();
request_data.shape = Some(vec![]);
@@ -356,7 +374,7 @@ mod tests {
}
#[test]
- #[should_panic(expected = "selection")]
+ #[should_panic(expected = "selection length must be greater than 0")]
fn test_invalid_selection() {
let mut request_data = get_test_request_data();
request_data.selection = Some(vec![]);
@@ -364,7 +382,7 @@ mod tests {
}
#[test]
- #[should_panic(expected = "selection")]
+ #[should_panic(expected = "stride must be greater than 0")]
fn test_invalid_selection2() {
let mut request_data = get_test_request_data();
request_data.selection = Some(vec![Slice {
diff --git a/src/validated_json.rs b/src/validated_json.rs
new file mode 100644
index 0000000..3dd0ec3
--- /dev/null
+++ b/src/validated_json.rs
@@ -0,0 +1,172 @@
+use async_trait::async_trait;
+use axum::{
+ extract::{rejection::JsonRejection, FromRequest, Json},
+ http::{Request, StatusCode},
+ response::{IntoResponse, Response},
+};
+use serde::de::DeserializeOwned;
+use thiserror::Error;
+use validator::Validate;
+
+/// An axum extractor based on the Json extractor that also performs validation using the validator
+/// crate.
+#[derive(Debug, Clone, Copy, Default)]
+pub struct ValidatedJson(pub T);
+
+#[async_trait]
+impl FromRequest for ValidatedJson
+where
+ T: DeserializeOwned + Validate,
+ S: Send + Sync,
+ Json: FromRequest,
+ B: Send + 'static,
+{
+ type Rejection = JsonError;
+
+ async fn from_request(req: Request, state: &S) -> Result {
+ let Json(value) = Json::::from_request(req, state).await?;
+ value.validate()?;
+ Ok(ValidatedJson(value))
+ }
+}
+
+#[derive(Debug, Error)]
+pub enum JsonError {
+ #[error(transparent)]
+ ValidationError(#[from] validator::ValidationErrors),
+
+ #[error(transparent)]
+ AxumJsonRejection(#[from] JsonRejection),
+}
+
+impl IntoResponse for JsonError {
+ fn into_response(self) -> Response {
+ match self {
+ JsonError::ValidationError(_) => {
+ let message = format!("Input validation error: [{}]", self).replace('\n', ", ");
+ (StatusCode::BAD_REQUEST, message)
+ }
+ JsonError::AxumJsonRejection(_) => (StatusCode::BAD_REQUEST, self.to_string()),
+ }
+ .into_response()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ // https://github.com/tokio-rs/axum/blob/main/examples/testing/src/main.rs
+
+ use super::*;
+ use axum::{
+ body::Body,
+ http::{self, Request, StatusCode},
+ routing::post,
+ Router,
+ };
+ use regex::Regex;
+ use serde::{Deserialize, Serialize};
+ use tower::ServiceExt; // for `oneshot` and `ready`
+
+ #[derive(Deserialize, Validate, Serialize)]
+ struct TestPayload {
+ #[validate(length(min = 1, max = 3))]
+ pub foo: String,
+ pub bar: Option,
+ }
+
+ // Handler function that accepts a ValidatedJson extractor.
+ async fn test_handler(ValidatedJson(payload): ValidatedJson) -> String {
+ format!("foo: {} bar: {:?}", payload.foo, payload.bar)
+ }
+
+ // Build a router and make a oneshot request.
+ async fn request(body: Body) -> Response {
+ Router::new()
+ .route("/", post(test_handler))
+ .oneshot(
+ Request::builder()
+ .method(http::Method::POST)
+ .uri("/")
+ .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
+ .body(body)
+ .unwrap(),
+ )
+ .await
+ .unwrap()
+ }
+
+ // Jump through the hoops to get the body as a string.
+ async fn body_string(response: Response) -> String {
+ String::from_utf8(
+ hyper::body::to_bytes(response.into_body())
+ .await
+ .unwrap()
+ .to_vec(),
+ )
+ .unwrap()
+ }
+
+ #[tokio::test]
+ async fn ok() {
+ let body = Body::from("{\"foo\": \"abc\", \"bar\": 123}");
+ let response = request(body).await;
+
+ assert_eq!(response.status(), StatusCode::OK);
+
+ let body = body_string(response).await;
+ assert_eq!(&body[..], "foo: abc bar: Some(123)");
+ }
+
+ #[tokio::test]
+ async fn invalid_json() {
+ let body = Body::from("{\"");
+ let response = request(body).await;
+
+ assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+
+ let body = body_string(response).await;
+ assert_eq!(&body[..], "Failed to parse the request body as JSON");
+ }
+
+ #[tokio::test]
+ async fn invalid_foo_type() {
+ let body = Body::from("{\"foo\": 123}");
+ let response = request(body).await;
+
+ assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+
+ let body = body_string(response).await;
+ assert_eq!(
+ &body[..],
+ "Failed to deserialize the JSON body into the target type"
+ );
+ }
+
+ #[tokio::test]
+ async fn invalid_foo_too_short() {
+ let body = Body::from("{\"foo\": \"\"}");
+ let response = request(body).await;
+
+ assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+
+ let body = body_string(response).await;
+ let re =
+ Regex::new(r"Input validation error: \[foo: Validation error: length \[\{.*\}\]\]")
+ .unwrap();
+ assert!(re.is_match(&body[..]))
+ }
+
+ #[tokio::test]
+ async fn invalid_foo_too_long() {
+ let body = Body::from("{\"foo\": \"abcd\"}");
+ let response = request(body).await;
+
+ assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+
+ let body = body_string(response).await;
+ let re =
+ Regex::new(r"Input validation error: \[foo: Validation error: length \[\{.*\}\]\]")
+ .unwrap();
+ assert!(re.is_match(&body[..]))
+ }
+}