diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a1cd98a2..d39aff24e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,21 +12,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: stable components: rustfmt - profile: minimal - override: true - - name: cargo fmt -- --check - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check + - run: cargo fmt --all --check test: name: Test @@ -50,30 +43,21 @@ jobs: - build: compression features: "--features compression" - steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust || 'stable' }} - profile: minimal - override: true - name: Test - uses: actions-rs/cargo@v1 - with: - command: test - args: ${{ matrix.features }} + run: cargo test ${{ matrix.features }} - name: Test all benches if: matrix.benches - uses: actions-rs/cargo@v1 - with: - command: test - args: --benches ${{ matrix.features }} + run: cargo test --benches ${{ matrix.features }} doc: name: Build docs @@ -81,17 +65,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - override: true + uses: dtolnay/rust-toolchain@nightly - name: cargo doc - uses: actions-rs/cargo@v1 - with: - command: rustdoc - args: -- -D broken_intra_doc_links + run: cargo rustdoc -- -D broken_intra_doc_links diff --git a/CHANGELOG.md b/CHANGELOG.md index dbf80b879..81ebd9e59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +### v0.3.4 (March 31, 2023) + +- **Fixes**: + - `multipart::Part` data is now streamed instead of buffered. + - Update dependency used for `multipart` filters. + ### v0.3.3 (September 27, 2022) - **Fixes**: diff --git a/Cargo.toml b/Cargo.toml index 485de8e8a..cd20d3a13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "warp" -version = "0.3.3" # don't forget to update html_root_url +version = "0.3.4" # don't forget to update html_root_url description = "serve the web at warp speeds" authors = ["Sean McArthur "] license = "MIT" @@ -27,7 +27,7 @@ hyper = { version = "0.14.19", features = ["stream", "server", "http1", "http2", log = "0.4" mime = "0.3" mime_guess = "2.0.0" -multipart = { version = "0.18", default-features = false, features = ["server"], optional = true } +multer = { version = "2.1.0", optional = true } scoped-tls = "1.0" serde = "1.0" serde_json = "1.0" @@ -37,24 +37,25 @@ tokio-stream = "0.1.1" tokio-util = { version = "0.7", features = ["io"] } tracing = { version = "0.1.21", default-features = false, features = ["log", "std"] } tower-service = "0.3" -tokio-tungstenite = { version = "0.17", optional = true } +tokio-tungstenite = { version = "0.18", optional = true } percent-encoding = "2.1" pin-project = "1.0" tokio-rustls = { version = "0.23", optional = true } -rustls-pemfile = "0.2" +rustls-pemfile = "1.0" [dev-dependencies] pretty_env_logger = "0.4" -tracing-subscriber = "0.2.7" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-log = "0.1" serde_derive = "1.0" handlebars = "4.0" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } tokio-stream = { version = "0.1.1", features = ["net"] } -listenfd = "0.3" +listenfd = "1.0" [features] default = ["multipart", "websocket"] +multipart = ["multer"] websocket = ["tokio-tungstenite"] tls = ["tokio-rustls"] @@ -96,3 +97,8 @@ required-features = ["websocket"] [[example]] name = "query_string" + + +[[example]] +name = "multipart" +required-features = ["multipart"] diff --git a/examples/multipart.rs b/examples/multipart.rs new file mode 100644 index 000000000..40548d1e2 --- /dev/null +++ b/examples/multipart.rs @@ -0,0 +1,28 @@ +use futures_util::TryStreamExt; +use warp::multipart::FormData; +use warp::Buf; +use warp::Filter; + +#[tokio::main] +async fn main() { + // Running curl -F file=@.gitignore 'localhost:3030/' should print [("file", ".gitignore", "\n/target\n**/*.rs.bk\nCargo.lock\n.idea/\nwarp.iml\n")] + let route = warp::multipart::form().and_then(|form: FormData| async move { + let field_names: Vec<_> = form + .and_then(|mut field| async move { + let contents = + String::from_utf8_lossy(field.data().await.unwrap().unwrap().chunk()) + .to_string(); + Ok(( + field.name().to_string(), + field.filename().unwrap().to_string(), + contents, + )) + }) + .try_collect() + .await + .unwrap(); + + Ok::<_, warp::Rejection>(format!("{:?}", field_names)) + }); + warp::serve(route).run(([127, 0, 0, 1], 3030)).await; +} diff --git a/examples/todos.rs b/examples/todos.rs index 904d604e8..ee5c3865a 100644 --- a/examples/todos.rs +++ b/examples/todos.rs @@ -38,7 +38,7 @@ mod filters { /// The 4 TODOs filters combined. pub fn todos( db: Db, - ) -> impl Filter + Clone { + ) -> impl Filter + Clone { todos_list(db.clone()) .or(todos_create(db.clone())) .or(todos_update(db.clone())) @@ -48,7 +48,7 @@ mod filters { /// GET /todos?offset=3&limit=5 pub fn todos_list( db: Db, - ) -> impl Filter + Clone { + ) -> impl Filter + Clone { warp::path!("todos") .and(warp::get()) .and(warp::query::()) @@ -59,7 +59,7 @@ mod filters { /// POST /todos with JSON body pub fn todos_create( db: Db, - ) -> impl Filter + Clone { + ) -> impl Filter + Clone { warp::path!("todos") .and(warp::post()) .and(json_body()) @@ -70,7 +70,7 @@ mod filters { /// PUT /todos/:id with JSON body pub fn todos_update( db: Db, - ) -> impl Filter + Clone { + ) -> impl Filter + Clone { warp::path!("todos" / u64) .and(warp::put()) .and(json_body()) @@ -81,7 +81,7 @@ mod filters { /// DELETE /todos/:id pub fn todos_delete( db: Db, - ) -> impl Filter + Clone { + ) -> impl Filter + Clone { // We'll make one of our endpoints admin-only to show how authentication filters are used let admin_only = warp::header::exact("authorization", "Bearer admin"); diff --git a/src/filters/multipart.rs b/src/filters/multipart.rs index ef2ec9268..434c2a165 100644 --- a/src/filters/multipart.rs +++ b/src/filters/multipart.rs @@ -2,17 +2,19 @@ //! //! Filters that extract a multipart body for a route. -use std::fmt; +use std::error::Error as StdError; +use std::fmt::{Display, Formatter}; use std::future::Future; -use std::io::{Cursor, Read}; use std::pin::Pin; use std::task::{Context, Poll}; +use std::{fmt, io}; use bytes::{Buf, Bytes}; use futures_util::{future, Stream}; use headers::ContentType; +use hyper::Body; use mime::Mime; -use multipart::server::Multipart; +use multer::{Field as PartInner, Multipart as FormDataInner}; use crate::filter::{Filter, FilterBase, Internal}; use crate::reject::{self, Rejection}; @@ -32,17 +34,14 @@ pub struct FormOptions { /// /// Extracted with a `warp::multipart::form` filter. pub struct FormData { - inner: Multipart>, + inner: FormDataInner<'static>, } /// A single "part" of a multipart/form-data body. /// /// Yielded from the `FormData` stream. pub struct Part { - name: String, - filename: Option, - content_type: Option, - data: Option>, + part: PartInner<'static>, } /// Create a `Filter` to extract a `multipart/form-data` body from a request. @@ -86,9 +85,12 @@ impl FilterBase for FormOptions { let filt = super::body::content_length_limit(self.max_length) .and(boundary) - .and(super::body::bytes()) - .map(|boundary, body| FormData { - inner: Multipart::with_body(Cursor::new(body), boundary), + .and(super::body::body()) + .map(|boundary: String, body| { + let body = BodyIoError(body); + FormData { + inner: FormDataInner::new(body, &boundary), + } }); let fut = filt.filter(Internal); @@ -108,23 +110,18 @@ impl fmt::Debug for FormData { impl Stream for FormData { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - match (*self).inner.read_entry() { - Ok(Some(mut field)) => { - let mut data = Vec::new(); - field - .data - .read_to_end(&mut data) - .map_err(crate::Error::new)?; - Poll::Ready(Some(Ok(Part { - name: field.headers.name.to_string(), - filename: field.headers.filename, - content_type: field.headers.content_type.map(|m| m.to_string()), - data: Some(data), - }))) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.inner.poll_next_field(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(Some(part))) => { + if part.name().is_some() { + Poll::Ready(Some(Ok(Part { part }))) + } else { + Poll::Ready(Some(Err(crate::Error::new(MultipartFieldMissingName)))) + } } - Ok(None) => Poll::Ready(None), - Err(e) => Poll::Ready(Some(Err(crate::Error::new(e)))), + Poll::Ready(Ok(None)) => Poll::Ready(None), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(crate::Error::new(err)))), } } } @@ -134,22 +131,23 @@ impl Stream for FormData { impl Part { /// Get the name of this part. pub fn name(&self) -> &str { - &self.name + self.part.name().expect("checked for name previously") } /// Get the filename of this part, if present. pub fn filename(&self) -> Option<&str> { - self.filename.as_deref() + self.part.file_name() } /// Get the content-type of this part, if present. pub fn content_type(&self) -> Option<&str> { - self.content_type.as_deref() + let content_type = self.part.content_type(); + content_type.map(|t| t.type_().as_str()) } /// Asynchronously get some of the data for this `Part`. pub async fn data(&mut self) -> Option> { - self.take_data() + future::poll_fn(|cx| self.poll_next(cx)).await } /// Convert this `Part` into a `Stream` of `Buf`s. @@ -157,21 +155,26 @@ impl Part { PartStream(self) } - fn take_data(&mut self) -> Option> { - self.data.take().map(|vec| Ok(vec.into())) + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + match Pin::new(&mut self.part).poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(Ok(bytes))) => Poll::Ready(Some(Ok(bytes))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(crate::Error::new(err)))), + } } } impl fmt::Debug for Part { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut builder = f.debug_struct("Part"); - builder.field("name", &self.name); + builder.field("name", &self.part.name()); - if let Some(ref filename) = self.filename { + if let Some(ref filename) = self.part.file_name() { builder.field("filename", filename); } - if let Some(ref mime) = self.content_type { + if let Some(ref mime) = self.part.content_type() { builder.field("content_type", mime); } @@ -184,7 +187,36 @@ struct PartStream(Part); impl Stream for PartStream { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(self.0.take_data()) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_next(cx) } } + +struct BodyIoError(Body); + +impl Stream for BodyIoError { + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.0).poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(Ok(bytes))) => Poll::Ready(Some(Ok(bytes))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(err))) => { + Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, err)))) + } + } + } +} + +/// An error used when a multipart field is missing a name. +#[derive(Debug)] +struct MultipartFieldMissingName; + +impl Display for MultipartFieldMissingName { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Multipart field is missing a name") + } +} + +impl StdError for MultipartFieldMissingName {} diff --git a/src/lib.rs b/src/lib.rs index a965b1d1b..3f33a1260 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![doc(html_root_url = "https://docs.rs/warp/0.3.3")] +#![doc(html_root_url = "https://docs.rs/warp/0.3.4")] #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![deny(rust_2018_idioms)] diff --git a/tests/tracing.rs b/tests/tracing.rs index 0ae88fbc3..cf87feda3 100644 --- a/tests/tracing.rs +++ b/tests/tracing.rs @@ -3,7 +3,7 @@ use warp::Filter; #[tokio::test] async fn uses_tracing() { // Setup a log subscriber (responsible to print to output) - let subscriber = tracing_subscriber::fmt::Subscriber::builder() + let subscriber = tracing_subscriber::fmt() .with_env_filter("trace") .without_time() .finish(); diff --git a/tests/ws.rs b/tests/ws.rs index d5b60356e..4e57e7f70 100644 --- a/tests/ws.rs +++ b/tests/ws.rs @@ -275,7 +275,7 @@ async fn ws_with_query() { } // Websocket filter that echoes all messages back. -fn ws_echo() -> impl Filter + Copy { +fn ws_echo() -> impl Filter + Copy { warp::ws().map(|ws: warp::ws::Ws| { ws.on_upgrade(|websocket| { // Just echo all messages back...