From 12811c3aba2400cf38a024d21f52b21487cd6cd0 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Fri, 25 Oct 2019 16:32:29 +0200 Subject: [PATCH] refactor: Simplify more by using async fn --- Makefile | 2 +- src/aio.rs | 42 +++++------ src/parser.rs | 190 +++++++++++++++++++++----------------------------- 3 files changed, 101 insertions(+), 133 deletions(-) diff --git a/Makefile b/Makefile index 103bb153a..a6b556114 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ test: @echo "====================================================================" @echo "Testing Connection Type TCP" @echo "====================================================================" - @REDISRS_SERVER_TYPE=tcp RUST_TEST_THREADS=1 cargo test --all-features + @REDISRS_SERVER_TYPE=tcp RUST_TEST_THREADS=1 cargo test --all-features -- --nocapture @echo "Testing Connection Type UNIX" @echo "====================================================================" @REDISRS_SERVER_TYPE=unix cargo test --test parser --test test_basic --test test_types --all-features diff --git a/src/aio.rs b/src/aio.rs index f444d83ac..38bd0b2e3 100644 --- a/src/aio.rs +++ b/src/aio.rs @@ -116,8 +116,9 @@ impl ActualConnection { pub async fn connect(connection_info: &ConnectionInfo) -> RedisResult { let con = match *connection_info.addr { ConnectionAddr::Tcp(ref host, port) => { - let socket_addr = match (&host[..], port).to_socket_addrs() { - Ok(mut socket_addrs) => match socket_addrs.next() { + let socket_addr = { + let mut socket_addrs = (&host[..], port).to_socket_addrs()?; + match socket_addrs.next() { Some(socket_addr) => socket_addr, None => { return Err(RedisError::from(( @@ -125,14 +126,14 @@ pub async fn connect(connection_info: &ConnectionInfo) -> RedisResult return Err(err.into()), + } }; TcpStream::connect(&socket_addr) .map_ok(|con| ActualConnection::Tcp(BufReader::new(BufWriter::new(con)))) .await? } + #[cfg(unix)] ConnectionAddr::Unix(ref path) => { UnixStream::connect(path) @@ -155,38 +156,33 @@ pub async fn connect(connection_info: &ConnectionInfo) -> RedisResult { - let mut cmd = cmd("AUTH"); - cmd.arg(&**passwd); - let x = cmd.query_async::<_, Value>(&mut rv).await; - match x { - Ok(Value::Okay) => (), - _ => { - fail!(( - ErrorKind::AuthenticationFailed, - "Password authentication failed" - )); - } + if let Some(passwd) = &connection_info.passwd { + let mut cmd = cmd("AUTH"); + cmd.arg(&**passwd); + match cmd.query_async::<_, Value>(&mut rv).await { + Ok(Value::Okay) => (), + _ => { + fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed" + )); } } - None => (), } if connection_info.db != 0 { let mut cmd = cmd("SELECT"); cmd.arg(connection_info.db); - let result = cmd.query_async::<_, Value>(&mut rv).await; - match result { - Ok(Value::Okay) => Ok(rv), + match cmd.query_async::<_, Value>(&mut rv).await { + Ok(Value::Okay) => (), _ => fail!(( ErrorKind::ResponseError, "Redis server refused to switch database" )), } - } else { - Ok(rv) } + + Ok(rv) } /// An async abstraction over connections. diff --git a/src/parser.rs b/src/parser.rs index ddc9f3d5f..69421a3de 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -5,11 +5,9 @@ use std::str; use crate::types::{make_extension_error, ErrorKind, RedisError, RedisResult, Value}; -use pin_project::pin_project; - use bytes::BytesMut; use combine::{combine_parse_partial, combine_parser_impl, parse_mode, parser}; -use futures::{prelude::*, ready, task, Poll}; +use futures::{future, task, Poll}; use tokio_codec::{Decoder, Encoder}; use tokio_io::{AsyncBufRead, AsyncRead}; @@ -185,120 +183,98 @@ impl Decoder for ValueCodec { } } -#[pin_project] -pub struct ValueFuture { - #[pin] - reader: Option, - state: AnySendPartialState, - // Intermediate storage for data we know that we need to parse a value but we haven't been able - // to parse completely yet - remaining: Vec, -} - -impl ValueFuture { - fn reader(&mut self) -> Pin<&mut R> - where - R: Unpin, - { - Pin::new(self.reader.as_mut().unwrap()) - } +// https://github.com/tokio-rs/tokio/pull/1687 +async fn fill_buf(reader: &mut R) -> io::Result<&[u8]> +where + R: AsyncBufRead + Unpin, +{ + let mut reader = Some(reader); + future::poll_fn(move |cx| match reader.take() { + Some(r) => match Pin::new(&mut *r).poll_fill_buf(cx) { + // SAFETY We either drop `self.reader` and return a slice with the lifetime of the + // reader or we return Pending/Err (neither which contains `'a`). + // In either case `poll_fill_buf` can not be called while it's contents are exposed + Poll::Ready(Ok(x)) => unsafe { return Ok(&*(x as *const _)).into() }, + Poll::Ready(Err(err)) => Err(err).into(), + Poll::Pending => { + reader = Some(r); + Poll::Pending + } + }, + None => panic!("fill_buf polled after completion"), + }) + .await } -impl Future for ValueFuture +/// Parses a redis value asynchronously. +pub async fn parse_redis_value_async(mut reader: R) -> RedisResult where R: AsyncBufRead + Unpin, { - type Output = RedisResult; - - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll { - loop { - assert!( - self.reader.is_some(), - "ValueFuture: poll called on completed future" - ); - let remaining_data = self.remaining.len(); - - let (opt, mut removed) = { - let self_ = &mut *self; - let buffer = match ready!(Pin::new(self_.reader.as_mut().unwrap()).poll_fill_buf(cx)) - { - Ok(buffer) => buffer, - Err(err) => return Err(err.into()).into(), - }; - if buffer.len() == 0 { - return Err((ErrorKind::ResponseError, "Could not read enough bytes").into()) - .into(); - } - let buffer = if !self_.remaining.is_empty() { - self_.remaining.extend(buffer); - &self_.remaining[..] - } else { - buffer - }; - let stream = combine::easy::Stream(combine::stream::PartialStream(buffer)); - match combine::stream::decode(value(), stream, &mut self_.state) { - Ok(x) => x, - Err(err) => { - let err = err - .map_position(|pos| pos.translate_position(buffer)) - .map_range(|range| format!("{:?}", range)) - .to_string(); - return Err(RedisError::from(( - ErrorKind::ResponseError, - "parse error", - err, - ))) - .into(); - } - } + let mut state = Default::default(); + let mut remaining = Vec::new(); + loop { + let remaining_data = remaining.len(); + + let (opt, mut removed) = { + let buffer = fill_buf(&mut reader).await?; + if buffer.len() == 0 { + return Err((ErrorKind::ResponseError, "Could not read enough bytes").into()); + } + let buffer = if !remaining.is_empty() { + remaining.extend(buffer); + &remaining[..] + } else { + buffer }; - if !self.remaining.is_empty() { - // Remove the data we have parsed and adjust `removed` to be the amount of data we - // consumed from `self.reader` - self.remaining.drain(..removed); - if removed >= remaining_data { - removed -= remaining_data; - } else { - removed = 0; + let stream = combine::easy::Stream(combine::stream::PartialStream(&buffer[..])); + match combine::stream::decode(value(), stream, &mut state) { + Ok(x) => x, + Err(err) => { + let err = err + .map_position(|pos| pos.translate_position(&buffer[..])) + .map_range(|range| format!("{:?}", range)) + .to_string(); + return Err(RedisError::from(( + ErrorKind::ResponseError, + "parse error", + err, + ))); } } + }; - match opt { - Some(value) => { - self.reader().consume(removed); - self.reader.take().unwrap(); - return Ok(value?).into(); - } - None => { - // We have not enough data to produce a Value but we know that all the data of - // the current buffer are necessary. Consume all the buffered data to ensure - // that the next iteration actually reads more data. - let buffer_len = { - let self_ = &mut *self; - let buffer = - ready!(Pin::new(self_.reader.as_mut().unwrap()).poll_fill_buf(cx))?; - if remaining_data == 0 { - self_.remaining.extend(&buffer[removed..]); - } - buffer.len() - }; - self.reader().consume(buffer_len); - } + if !remaining.is_empty() { + // Remove the data we have parsed and adjust `removed` to be the amount of data we + // consumed from `self.reader` + remaining.drain(..removed); + if removed >= remaining_data { + removed -= remaining_data; + } else { + removed = 0; } } - } -} -/// Parses a redis value asynchronously. -pub fn parse_redis_value_async(reader: R) -> impl Future> -where - R: AsyncRead + AsyncBufRead + Unpin, -{ - ValueFuture { - reader: Some(reader), - state: Default::default(), - remaining: Vec::new(), + match opt { + Some(value) => { + Pin::new(&mut reader).consume(removed); + return Ok(value?); + } + None => { + // We have not enough data to produce a Value but we know that all the data of + // the current buffer are necessary. Consume all the buffered data to ensure + // that the next iteration actually reads more data. + let buffer_len = { + let buffer = fill_buf(&mut reader).await?; + if remaining_data == 0 { + remaining.extend(&buffer[removed..]); + } + buffer.len() + }; + Pin::new(&mut reader).consume(buffer_len); + } + } } } @@ -352,11 +328,7 @@ impl<'a, T: BufRead> Parser { /// Parses synchronously into a single value from the reader. pub fn parse_value(&mut self) -> RedisResult { - let parser = ValueFuture { - reader: Some(BlockingWrapper(&mut self.reader)), - state: Default::default(), - remaining: Vec::new(), - }; + let parser = parse_redis_value_async(BlockingWrapper(&mut self.reader)); futures::executor::block_on(parser) } }