Skip to content

Commit

Permalink
Make the default notifications iterator read nonblocking
Browse files Browse the repository at this point in the history
It is always super confusing as to when a notification that's been sent
to the client will actually show up in the old version of this iterator,
so it's best to have it see if there's anything waiting in the TCP
buffer.

Closes #149
  • Loading branch information
sfackler committed Dec 27, 2015
1 parent 278ee1c commit bb837bd
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 13 deletions.
18 changes: 18 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,24 @@ impl InnerConnection {
}
}

fn read_message_with_notification_nonblocking(&mut self)
-> std::io::Result<Option<BackendMessage>> {
debug_assert!(!self.desynchronized);
loop {
match try_desync!(self, self.stream.read_message_nonblocking()) {
Some(NoticeResponse { fields }) => {
if let Ok(err) = DbError::new_raw(fields) {
self.notice_handler.handle_notice(err);
}
}
Some(ParameterStatus { parameter, value }) => {
self.parameters.insert(parameter, value);
}
val => return Ok(val),
}
}
}

fn read_message(&mut self) -> std_io::Result<BackendMessage> {
loop {
match try!(self.read_message_with_notification()) {
Expand Down
24 changes: 22 additions & 2 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};

use types::Oid;
use util;
use priv_io::ReadTimeout;
use priv_io::StreamOptions;

use self::BackendMessage::*;
use self::FrontendMessage::*;
Expand Down Expand Up @@ -287,10 +287,12 @@ pub trait ReadMessage {

fn read_message_timeout(&mut self, timeout: Duration) -> io::Result<Option<BackendMessage>>;

fn read_message_nonblocking(&mut self) -> io::Result<Option<BackendMessage>>;

fn finish_read_message(&mut self, ident: u8) -> io::Result<BackendMessage>;
}

impl<R: BufRead + ReadTimeout> ReadMessage for R {
impl<R: BufRead + StreamOptions> ReadMessage for R {
fn read_message(&mut self) -> io::Result<BackendMessage> {
let ident = try!(self.read_u8());
self.finish_read_message(ident)
Expand All @@ -314,6 +316,24 @@ impl<R: BufRead + ReadTimeout> ReadMessage for R {
}
}

fn read_message_nonblocking(&mut self) -> io::Result<Option<BackendMessage>> {
try!(self.set_nonblocking(true));
let ident = self.read_u8();
try!(self.set_nonblocking(false));

match ident {
Ok(ident) => self.finish_read_message(ident).map(Some),
Err(e) => {
let e: io::Error = e.into();
if e.kind() == io::ErrorKind::WouldBlock {
Ok(None)
} else {
Err(e)
}
}
}
}

fn finish_read_message(&mut self, ident: u8) -> io::Result<BackendMessage> {
// subtract size of length value
let len = try!(self.read_u32::<BigEndian>()) - mem::size_of::<u32>() as u32;
Expand Down
28 changes: 22 additions & 6 deletions src/notification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ impl<'conn> Notifications<'conn> {
/// # Note
///
/// This iterator may start returning `Some` after previously returning
/// `None` if more notifications are received. However, those notifications
/// will not be registered until the connection is used in some way.
/// `None` if more notifications are received.
pub fn iter<'a>(&'a self) -> Iter<'a> {
Iter { conn: self.conn }
}
Expand Down Expand Up @@ -72,7 +71,7 @@ impl<'conn> Notifications<'conn> {
}

impl<'a, 'conn> IntoIterator for &'a Notifications<'conn> {
type Item = Notification;
type Item = Result<Notification>;
type IntoIter = Iter<'a>;

fn into_iter(self) -> Iter<'a> {
Expand All @@ -92,10 +91,27 @@ pub struct Iter<'a> {
}

impl<'a> Iterator for Iter<'a> {
type Item = Notification;
type Item = Result<Notification>;

fn next(&mut self) -> Option<Result<Notification>> {
let mut conn = self.conn.conn.borrow_mut();

fn next(&mut self) -> Option<Notification> {
self.conn.conn.borrow_mut().notifications.pop_front()
if let Some(notification) = conn.notifications.pop_front() {
return Some(Ok(notification));
}

match conn.read_message_with_notification_nonblocking() {
Ok(Some(NotificationResponse { pid, channel, payload })) => {
Some(Ok(Notification {
pid: pid,
channel: channel,
payload: payload,
}))
}
Ok(None) => None,
Err(err) => Some(Err(Error::Io(err))),
_ => unreachable!(),
}
}
}

Expand Down
13 changes: 11 additions & 2 deletions src/priv_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ use message::FrontendMessage::SslRequest;
const DEFAULT_PORT: u16 = 5432;

#[doc(hidden)]
pub trait ReadTimeout {
pub trait StreamOptions {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>;
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()>;
}

impl ReadTimeout for BufStream<Box<StreamWrapper>> {
impl StreamOptions for BufStream<Box<StreamWrapper>> {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
match self.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => {
Expand All @@ -36,6 +37,14 @@ impl ReadTimeout for BufStream<Box<StreamWrapper>> {
InternalStream::Unix(ref s) => s.set_read_timeout(timeout),
}
}

fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
match self.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => s.set_nonblocking(nonblock),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref s) => s.set_nonblocking(nonblock),
}
}
}

/// A connection to the Postgres server.
Expand Down
6 changes: 3 additions & 3 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,20 +575,20 @@ fn test_notification_iterator_some() {
pid: 0,
channel: "test_notification_iterator_one_channel".to_string(),
payload: "hello".to_string()
}, it.next().unwrap());
}, it.next().unwrap().unwrap());
check_notification(Notification {
pid: 0,
channel: "test_notification_iterator_one_channel2".to_string(),
payload: "world".to_string()
}, it.next().unwrap());
}, it.next().unwrap().unwrap());
assert!(it.next().is_none());

or_panic!(conn.execute("NOTIFY test_notification_iterator_one_channel, '!'", &[]));
check_notification(Notification {
pid: 0,
channel: "test_notification_iterator_one_channel".to_string(),
payload: "!".to_string()
}, it.next().unwrap());
}, it.next().unwrap().unwrap());
assert!(it.next().is_none());
}

Expand Down

0 comments on commit bb837bd

Please sign in to comment.