Skip to content

Commit

Permalink
refactor: Simplify more by using async fn
Browse files Browse the repository at this point in the history
  • Loading branch information
Marwes committed Oct 25, 2019
1 parent 143916a commit 12811c3
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 133 deletions.
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -5,7 +5,7 @@ test:
@echo "===================================================================="
@echo "Testing Connection Type TCP"
@echo "===================================================================="
@REDISRS_SERVER_TYPE=tcp RUST_TEST_THREADS=1 cargo test --all-features
@REDISRS_SERVER_TYPE=tcp RUST_TEST_THREADS=1 cargo test --all-features -- --nocapture
@echo "Testing Connection Type UNIX"
@echo "===================================================================="
@REDISRS_SERVER_TYPE=unix cargo test --test parser --test test_basic --test test_types --all-features
Expand Down
42 changes: 19 additions & 23 deletions src/aio.rs
Expand Up @@ -116,23 +116,24 @@ impl ActualConnection {
pub async fn connect(connection_info: &ConnectionInfo) -> RedisResult<Connection> {
let con = match *connection_info.addr {
ConnectionAddr::Tcp(ref host, port) => {
let socket_addr = match (&host[..], port).to_socket_addrs() {
Ok(mut socket_addrs) => match socket_addrs.next() {
let socket_addr = {
let mut socket_addrs = (&host[..], port).to_socket_addrs()?;
match socket_addrs.next() {
Some(socket_addr) => socket_addr,
None => {
return Err(RedisError::from((
ErrorKind::InvalidClientConfig,
"No address found for host",
)));
}
},
Err(err) => return Err(err.into()),
}
};

TcpStream::connect(&socket_addr)
.map_ok(|con| ActualConnection::Tcp(BufReader::new(BufWriter::new(con))))
.await?
}

#[cfg(unix)]
ConnectionAddr::Unix(ref path) => {
UnixStream::connect(path)
Expand All @@ -155,38 +156,33 @@ pub async fn connect(connection_info: &ConnectionInfo) -> RedisResult<Connection
db: connection_info.db,
};

match connection_info.passwd {
Some(ref passwd) => {
let mut cmd = cmd("AUTH");
cmd.arg(&**passwd);
let x = cmd.query_async::<_, Value>(&mut rv).await;
match x {
Ok(Value::Okay) => (),
_ => {
fail!((
ErrorKind::AuthenticationFailed,
"Password authentication failed"
));
}
if let Some(passwd) = &connection_info.passwd {
let mut cmd = cmd("AUTH");
cmd.arg(&**passwd);
match cmd.query_async::<_, Value>(&mut rv).await {
Ok(Value::Okay) => (),
_ => {
fail!((
ErrorKind::AuthenticationFailed,
"Password authentication failed"
));
}
}
None => (),
}

if connection_info.db != 0 {
let mut cmd = cmd("SELECT");
cmd.arg(connection_info.db);
let result = cmd.query_async::<_, Value>(&mut rv).await;
match result {
Ok(Value::Okay) => Ok(rv),
match cmd.query_async::<_, Value>(&mut rv).await {
Ok(Value::Okay) => (),
_ => fail!((
ErrorKind::ResponseError,
"Redis server refused to switch database"
)),
}
} else {
Ok(rv)
}

Ok(rv)
}

/// An async abstraction over connections.
Expand Down
190 changes: 81 additions & 109 deletions src/parser.rs
Expand Up @@ -5,11 +5,9 @@ use std::str;

use crate::types::{make_extension_error, ErrorKind, RedisError, RedisResult, Value};

use pin_project::pin_project;

use bytes::BytesMut;
use combine::{combine_parse_partial, combine_parser_impl, parse_mode, parser};
use futures::{prelude::*, ready, task, Poll};
use futures::{future, task, Poll};
use tokio_codec::{Decoder, Encoder};
use tokio_io::{AsyncBufRead, AsyncRead};

Expand Down Expand Up @@ -185,120 +183,98 @@ impl Decoder for ValueCodec {
}
}

#[pin_project]
pub struct ValueFuture<R> {
#[pin]
reader: Option<R>,
state: AnySendPartialState,
// Intermediate storage for data we know that we need to parse a value but we haven't been able
// to parse completely yet
remaining: Vec<u8>,
}

impl<R> ValueFuture<R> {
fn reader(&mut self) -> Pin<&mut R>
where
R: Unpin,
{
Pin::new(self.reader.as_mut().unwrap())
}
// https://github.com/tokio-rs/tokio/pull/1687
async fn fill_buf<R>(reader: &mut R) -> io::Result<&[u8]>
where
R: AsyncBufRead + Unpin,
{
let mut reader = Some(reader);
future::poll_fn(move |cx| match reader.take() {
Some(r) => match Pin::new(&mut *r).poll_fill_buf(cx) {
// SAFETY We either drop `self.reader` and return a slice with the lifetime of the
// reader or we return Pending/Err (neither which contains `'a`).
// In either case `poll_fill_buf` can not be called while it's contents are exposed
Poll::Ready(Ok(x)) => unsafe { return Ok(&*(x as *const _)).into() },
Poll::Ready(Err(err)) => Err(err).into(),
Poll::Pending => {
reader = Some(r);
Poll::Pending
}
},
None => panic!("fill_buf polled after completion"),
})
.await
}

impl<R> Future for ValueFuture<R>
/// Parses a redis value asynchronously.
pub async fn parse_redis_value_async<R>(mut reader: R) -> RedisResult<Value>
where
R: AsyncBufRead + Unpin,
{
type Output = RedisResult<Value>;

fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
loop {
assert!(
self.reader.is_some(),
"ValueFuture: poll called on completed future"
);
let remaining_data = self.remaining.len();

let (opt, mut removed) = {
let self_ = &mut *self;
let buffer = match ready!(Pin::new(self_.reader.as_mut().unwrap()).poll_fill_buf(cx))
{
Ok(buffer) => buffer,
Err(err) => return Err(err.into()).into(),
};
if buffer.len() == 0 {
return Err((ErrorKind::ResponseError, "Could not read enough bytes").into())
.into();
}
let buffer = if !self_.remaining.is_empty() {
self_.remaining.extend(buffer);
&self_.remaining[..]
} else {
buffer
};
let stream = combine::easy::Stream(combine::stream::PartialStream(buffer));
match combine::stream::decode(value(), stream, &mut self_.state) {
Ok(x) => x,
Err(err) => {
let err = err
.map_position(|pos| pos.translate_position(buffer))
.map_range(|range| format!("{:?}", range))
.to_string();
return Err(RedisError::from((
ErrorKind::ResponseError,
"parse error",
err,
)))
.into();
}
}
let mut state = Default::default();
let mut remaining = Vec::new();
loop {
let remaining_data = remaining.len();

let (opt, mut removed) = {
let buffer = fill_buf(&mut reader).await?;
if buffer.len() == 0 {
return Err((ErrorKind::ResponseError, "Could not read enough bytes").into());
}
let buffer = if !remaining.is_empty() {
remaining.extend(buffer);
&remaining[..]
} else {
buffer
};

if !self.remaining.is_empty() {
// Remove the data we have parsed and adjust `removed` to be the amount of data we
// consumed from `self.reader`
self.remaining.drain(..removed);
if removed >= remaining_data {
removed -= remaining_data;
} else {
removed = 0;
let stream = combine::easy::Stream(combine::stream::PartialStream(&buffer[..]));
match combine::stream::decode(value(), stream, &mut state) {
Ok(x) => x,
Err(err) => {
let err = err
.map_position(|pos| pos.translate_position(&buffer[..]))
.map_range(|range| format!("{:?}", range))
.to_string();
return Err(RedisError::from((
ErrorKind::ResponseError,
"parse error",
err,
)));
}
}
};

match opt {
Some(value) => {
self.reader().consume(removed);
self.reader.take().unwrap();
return Ok(value?).into();
}
None => {
// We have not enough data to produce a Value but we know that all the data of
// the current buffer are necessary. Consume all the buffered data to ensure
// that the next iteration actually reads more data.
let buffer_len = {
let self_ = &mut *self;
let buffer =
ready!(Pin::new(self_.reader.as_mut().unwrap()).poll_fill_buf(cx))?;
if remaining_data == 0 {
self_.remaining.extend(&buffer[removed..]);
}
buffer.len()
};
self.reader().consume(buffer_len);
}
if !remaining.is_empty() {
// Remove the data we have parsed and adjust `removed` to be the amount of data we
// consumed from `self.reader`
remaining.drain(..removed);
if removed >= remaining_data {
removed -= remaining_data;
} else {
removed = 0;
}
}
}
}

/// Parses a redis value asynchronously.
pub fn parse_redis_value_async<R>(reader: R) -> impl Future<Output = RedisResult<Value>>
where
R: AsyncRead + AsyncBufRead + Unpin,
{
ValueFuture {
reader: Some(reader),
state: Default::default(),
remaining: Vec::new(),
match opt {
Some(value) => {
Pin::new(&mut reader).consume(removed);
return Ok(value?);
}
None => {
// We have not enough data to produce a Value but we know that all the data of
// the current buffer are necessary. Consume all the buffered data to ensure
// that the next iteration actually reads more data.
let buffer_len = {
let buffer = fill_buf(&mut reader).await?;
if remaining_data == 0 {
remaining.extend(&buffer[removed..]);
}
buffer.len()
};
Pin::new(&mut reader).consume(buffer_len);
}
}
}
}

Expand Down Expand Up @@ -352,11 +328,7 @@ impl<'a, T: BufRead> Parser<T> {

/// Parses synchronously into a single value from the reader.
pub fn parse_value(&mut self) -> RedisResult<Value> {
let parser = ValueFuture {
reader: Some(BlockingWrapper(&mut self.reader)),
state: Default::default(),
remaining: Vec::new(),
};
let parser = parse_redis_value_async(BlockingWrapper(&mut self.reader));
futures::executor::block_on(parser)
}
}
Expand Down

0 comments on commit 12811c3

Please sign in to comment.