Skip to content

Commit

Permalink
Use 'Arc<Mutex>'. Add dynamic exclusivity check.
Browse files Browse the repository at this point in the history
This partially reverts dd7015f while also eliminating the potential for
deadlock by returning an error instead.
  • Loading branch information
SergioBenitez committed May 16, 2021
1 parent 50e15ee commit 2758e77
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 229 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ httparse = "1.3"
mime = "0.3"
encoding_rs = "0.8"
derive_more = "0.99"
spin = { version = "0.9", default-features = false, features = ["spin_mutex"] }

serde = { version = "1.0", optional = true }
serde_json = { version = "1.0", optional = true }
Expand Down
10 changes: 5 additions & 5 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ use futures_util::stream::Stream;

use crate::constants;

pub(crate) struct StreamBuffer {
pub(crate) struct StreamBuffer<'r> {
pub(crate) eof: bool,
pub(crate) buf: BytesMut,
pub(crate) stream: Pin<Box<dyn Stream<Item = Result<Bytes, crate::Error>> + Send>>,
pub(crate) stream: Pin<Box<dyn Stream<Item = Result<Bytes, crate::Error>> + Send + 'r>>,
pub(crate) whole_stream_size_limit: u64,
pub(crate) stream_size_counter: u64,
}

impl StreamBuffer {
impl<'r> StreamBuffer<'r> {
pub fn new<S>(stream: S, whole_stream_size_limit: u64) -> Self
where
S: Stream<Item = Result<Bytes, crate::Error>> + Send + 'static,
S: Stream<Item = Result<Bytes, crate::Error>> + Send + 'r,
{
StreamBuffer {
eof: false,
Expand Down Expand Up @@ -162,7 +162,7 @@ impl StreamBuffer {
}
}

impl fmt::Debug for StreamBuffer {
impl fmt::Debug for StreamBuffer<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StreamBuffer").finish()
}
Expand Down
4 changes: 2 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ pub enum Error {
StreamReadFailed(BoxError),

/// Failed to lock the multipart shared state for any changes.
#[display(fmt = "failed to lock multipart state: {}", _0)]
LockFailure(BoxError),
#[display(fmt = "failed to lock multipart state")]
LockFailure,

/// The `Content-Type` header is not `multipart/form-data`.
#[display(fmt = "Content-Type is not multipart/form-data")]
Expand Down
76 changes: 35 additions & 41 deletions src/field.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use bytes::{Bytes, BytesMut};
Expand All @@ -8,10 +9,11 @@ use futures_util::stream::{Stream, TryStreamExt};
use http::header::HeaderMap;
#[cfg(feature = "json")]
use serde::de::DeserializeOwned;
use spin::mutex::spin::SpinMutex as Mutex;

use crate::content_disposition::ContentDisposition;
use crate::helpers;
use crate::state::{MultipartState, StreamingStage};
use crate::multipart::{MultipartState, StreamingStage};
use crate::{helpers, Error};

/// A single field in a multipart stream.
///
Expand Down Expand Up @@ -43,66 +45,51 @@ use crate::state::{MultipartState, StreamingStage};
///
/// [`Multipart`]: crate::Multipart
#[derive(Debug)]
pub struct Field<'a> {
state: &'a mut MultipartState,
pub struct Field<'r> {
state: Arc<Mutex<MultipartState<'r>>>,
done: bool,
data: FieldData,
}

/// Owned field data. This is used by `Multipart` to extend the lifetime of a
/// `Field`.
#[derive(Debug)]
pub(crate) struct FieldData {
headers: HeaderMap,
content_disposition: ContentDisposition,
content_type: Option<mime::Mime>,
idx: usize,
}

impl FieldData {
pub(crate) fn new(headers: HeaderMap, idx: usize, content_disposition: ContentDisposition) -> Self {
impl<'r> Field<'r> {
pub(crate) fn new(
state: Arc<Mutex<MultipartState<'r>>>,
headers: HeaderMap,
idx: usize,
content_disposition: ContentDisposition,
) -> Self {
let content_type = helpers::parse_content_type(&headers);

FieldData {
Field {
state,
headers,
content_disposition,
content_type,
idx,
}
}

pub(crate) fn name(&self) -> Option<&str> {
self.content_disposition.field_name.as_deref()
}
}

impl<'a> Field<'a> {
pub(crate) fn from_data(state: &'a mut MultipartState, data: FieldData) -> Self {
Self {
state,
done: false,
data,
}
}

/// The field name found in the [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header.
pub fn name(&self) -> Option<&str> {
self.data.content_disposition.field_name.as_deref()
self.content_disposition.field_name.as_deref()
}

/// The file name found in the [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header.
pub fn file_name(&self) -> Option<&str> {
self.data.content_disposition.file_name.as_deref()
self.content_disposition.file_name.as_deref()
}

/// Get the content type of the field.
pub fn content_type(&self) -> Option<&mime::Mime> {
self.data.content_type.as_ref()
self.content_type.as_ref()
}

/// Get a map of headers as [`HeaderMap`].
pub fn headers(&self) -> &HeaderMap {
&self.data.headers
&self.headers
}

/// Get the full data of the field as [`Bytes`].
Expand Down Expand Up @@ -324,42 +311,49 @@ impl<'a> Field<'a> {
/// # tokio::runtime::Runtime::new().unwrap().block_on(run());
/// ```
pub fn index(&self) -> usize {
self.data.idx
self.idx
}
}

impl Stream for Field<'_> {
type Item = Result<Bytes, crate::Error>;
type Item = Result<Bytes, Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}

let state = &mut *self.state;
debug_assert!(self.state.try_lock().is_some(), "expected exlusive lock");
let state = self.state.clone();
let mut lock = match state.try_lock() {
Some(lock) => lock,
None => return Poll::Ready(Some(Err(Error::LockFailure))),
};

let state = &mut *lock;
if let Err(err) = state.buffer.poll_stream(cx) {
return Poll::Ready(Some(Err(crate::Error::StreamReadFailed(err.into()))));
}

match state
.buffer
.read_field_data(state.boundary.as_str(), state.curr_field_name.as_deref())
.read_field_data(&state.boundary, state.curr_field_name.as_deref())
{
Ok(Some((done, bytes))) => {
self.state.curr_field_size_counter += bytes.len() as u64;
state.curr_field_size_counter += bytes.len() as u64;

if self.state.curr_field_size_counter > self.state.curr_field_size_limit {
if state.curr_field_size_counter > state.curr_field_size_limit {
return Poll::Ready(Some(Err(crate::Error::FieldSizeExceeded {
limit: self.state.curr_field_size_limit,
field_name: self.state.curr_field_name.clone(),
limit: state.curr_field_size_limit,
field_name: state.curr_field_name.clone(),
})));
}

if done {
state.stage = StreamingStage::ReadingBoundary;
self.done = true;
self.state.stage = StreamingStage::ReadingBoundary;
}

Poll::Ready(Some(Ok(bytes)))
}
Ok(None) => Poll::Pending,
Expand Down
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,9 @@ mod field;
mod helpers;
mod multipart;
mod size_limit;
mod state;

/// A Result type often returned from methods that can have `multer` errors.
pub type Result<T> = std::result::Result<T, Error>;
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Parses the `Content-Type` header to extract the boundary value.
///
Expand Down
Loading

0 comments on commit 2758e77

Please sign in to comment.