Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1"
axum = "0.6"
hyper = { version = "0.14", features = ["full"] }
mime = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "*"
strum = { version = "0.24", features = ["derive"] }
strum_macros = "0.24"
thiserror = "1.0"
tokio = { version = "1.24", features = ["full"] }
tower = "0.4"
tower-http = { version = "0.3", features = ["trace", "validate-request"] }
url = { version = "2", features = ["serde"] }
validator = { version = "0.16", features = ["derive"] }

[dev-dependencies]
regex = "1"
serde_test = "1.0"
97 changes: 84 additions & 13 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,40 @@
use crate::models;
use crate::validated_json::ValidatedJson;

use axum::{
extract::Json,
body::Body,
http::header,
http::Request,
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Router,
};

use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use tower_http::validate_request::ValidateRequestHeaderLayer;

static HEADER_DTYPE: header::HeaderName = header::HeaderName::from_static("x-activestorage-dtype");
static HEADER_SHAPE: header::HeaderName = header::HeaderName::from_static("x-activestorage-shape");

impl IntoResponse for models::Response {
fn into_response(self) -> Response {
(
[
(
&header::CONTENT_TYPE,
mime::APPLICATION_OCTET_STREAM.to_string(),
),
(&HEADER_DTYPE, self.dtype.to_string().to_lowercase()),
(&HEADER_SHAPE, serde_json::to_string(&self.shape).unwrap()),
],
self.result,
)
.into_response()
}
}

pub fn router() -> Router {
fn v1() -> Router {
Router::new()
Expand All @@ -15,6 +44,20 @@ pub fn router() -> Router {
.route("/min", post(min))
.route("/select", post(select))
.route("/sum", post(sum))
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(ValidateRequestHeaderLayer::custom(
// Validate that an authorization header has been provided.
|request: &mut Request<Body>| {
if request.headers().contains_key(header::AUTHORIZATION) {
Ok(())
} else {
Err(StatusCode::UNAUTHORIZED.into_response())
}
},
)),
)
}

Router::new()
Expand All @@ -26,26 +69,54 @@ async fn schema() -> &'static str {
"Hello, world!"
}

async fn count(Json(request_data): Json<models::RequestData>) -> String {
format!("Hello, {}!", request_data.source)
async fn count(
ValidatedJson(request_data): ValidatedJson<models::RequestData>,
) -> models::Response {
models::Response::new(
request_data.source.to_string(),
models::DType::Int32,
vec![],
)
}

async fn max(Json(request_data): Json<models::RequestData>) -> String {
format!("Hello, {}!", request_data.source)
async fn max(ValidatedJson(request_data): ValidatedJson<models::RequestData>) -> models::Response {
models::Response::new(
request_data.source.to_string(),
models::DType::Int32,
vec![],
)
}

async fn mean(Json(request_data): Json<models::RequestData>) -> String {
format!("Hello, {}!", request_data.source)
async fn mean(ValidatedJson(request_data): ValidatedJson<models::RequestData>) -> models::Response {
models::Response::new(
request_data.source.to_string(),
models::DType::Int32,
vec![],
)
}

async fn min(Json(request_data): Json<models::RequestData>) -> String {
format!("Hello, {}!", request_data.source)
async fn min(ValidatedJson(request_data): ValidatedJson<models::RequestData>) -> models::Response {
models::Response::new(
request_data.source.to_string(),
models::DType::Int32,
vec![],
)
}

async fn select(Json(request_data): Json<models::RequestData>) -> String {
format!("Hello, {}!", request_data.source)
async fn select(
ValidatedJson(request_data): ValidatedJson<models::RequestData>,
) -> models::Response {
models::Response::new(
request_data.source.to_string(),
models::DType::Int32,
vec![],
)
}

async fn sum(Json(request_data): Json<models::RequestData>) -> String {
format!("Hello, {}!", request_data.source)
async fn sum(ValidatedJson(request_data): ValidatedJson<models::RequestData>) -> models::Response {
models::Response::new(
request_data.source.to_string(),
models::DType::Int32,
vec![],
)
}
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use tokio::signal;

mod app;
mod models;
mod validated_json;

#[tokio::main]
async fn main() {
Expand Down
44 changes: 31 additions & 13 deletions src/models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use serde::{Deserialize, Serialize};
use strum_macros::Display;
use url::Url;
use validator::{Validate, ValidationError};

#[derive(Debug, Deserialize, PartialEq)]
#[derive(Debug, Deserialize, Display, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DType {
Int32,
Expand All @@ -27,7 +28,7 @@ pub enum Order {
pub struct Slice {
pub start: u32,
pub end: u32,
#[validate(range(min = 1))]
#[validate(range(min = 1, message = "stride must be greater than 0"))]
pub stride: u32,
}

Expand All @@ -37,19 +38,19 @@ pub struct Slice {
pub struct RequestData {
// TODO: Investigate using lifetimes to enable zero-copy: https://serde.rs/lifetimes.html
pub source: Url,
#[validate(length(min = 1))]
#[validate(length(min = 1, message = "bucket must not be empty"))]
pub bucket: String,
#[validate(length(min = 1))]
#[validate(length(min = 1, message = "object must not be empty"))]
pub object: String,
pub dtype: DType,
pub offset: Option<u32>,
#[validate(range(min = 1))]
#[validate(range(min = 1, message = "size must be greater than 0"))]
pub size: Option<u32>,
#[validate(length(min = 1))]
#[validate(length(min = 1, message = "shape length must be greater than 0"))]
pub shape: Option<Vec<u32>>,
pub order: Option<Order>,
#[validate]
#[validate(length(min = 1))]
#[validate(length(min = 1, message = "selection length must be greater than 0"))]
pub selection: Option<Vec<Slice>>,
}

Expand Down Expand Up @@ -78,6 +79,23 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
Ok(())
}

/// Response containing the result of a computation and associated metadata.
pub struct Response {
pub result: String,
pub dtype: DType,
pub shape: Vec<u32>,
}

impl Response {
pub fn new(result: String, dtype: DType, shape: Vec<u32>) -> Response {
Response {
result,
dtype,
shape,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -255,7 +273,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "bucket")]
#[should_panic(expected = "bucket must not be empty")]
fn test_invalid_bucket() {
let mut request_data = get_test_request_data();
request_data.bucket = "".to_string();
Expand All @@ -281,7 +299,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "object")]
#[should_panic(expected = "object must not be empty")]
fn test_invalid_object() {
let mut request_data = get_test_request_data();
request_data.object = "".to_string();
Expand Down Expand Up @@ -322,15 +340,15 @@ mod tests {
}

#[test]
#[should_panic(expected = "size")]
#[should_panic(expected = "size must be greater than 0")]
fn test_invalid_size() {
let mut request_data = get_test_request_data();
request_data.size = Some(0);
request_data.validate().unwrap()
}

#[test]
#[should_panic(expected = "shape")]
#[should_panic(expected = "shape length must be greater than 0")]
fn test_invalid_shape() {
let mut request_data = get_test_request_data();
request_data.shape = Some(vec![]);
Expand All @@ -356,15 +374,15 @@ mod tests {
}

#[test]
#[should_panic(expected = "selection")]
#[should_panic(expected = "selection length must be greater than 0")]
fn test_invalid_selection() {
let mut request_data = get_test_request_data();
request_data.selection = Some(vec![]);
request_data.validate().unwrap()
}

#[test]
#[should_panic(expected = "selection")]
#[should_panic(expected = "stride must be greater than 0")]
fn test_invalid_selection2() {
let mut request_data = get_test_request_data();
request_data.selection = Some(vec![Slice {
Expand Down
Loading