diff --git a/.circleci/config.yml b/.circleci/config.yml index acaec4a32..f3dae7101 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -22,7 +22,7 @@ version: 2 jobs: build: docker: - - image: rust:1.40.0 + - image: rust:1.41.0 environment: RUSTFLAGS: -D warnings - image: sfackler/rust-postgres-test:6 diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index d23715cde..d0cf11004 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -35,7 +35,7 @@ fallible-iterator = "0.2" futures = "0.3" tokio-postgres = { version = "0.5.3", path = "../tokio-postgres" } -tokio = { version = "0.2", features = ["rt-core"] } +tokio = { version = "0.2", features = ["rt-core", "time"] } log = "0.4" [dev-dependencies] diff --git a/postgres/src/binary_copy.rs b/postgres/src/binary_copy.rs index 7828cb599..259347195 100644 --- a/postgres/src/binary_copy.rs +++ b/postgres/src/binary_copy.rs @@ -1,7 +1,8 @@ //! Utilities for working with the PostgreSQL binary copy format. +use crate::connection::ConnectionRef; use crate::types::{ToSql, Type}; -use crate::{CopyInWriter, CopyOutReader, Error, Rt}; +use crate::{CopyInWriter, CopyOutReader, Error}; use fallible_iterator::FallibleIterator; use futures::StreamExt; use std::pin::Pin; @@ -13,7 +14,7 @@ use tokio_postgres::binary_copy::{self, BinaryCopyOutStream}; /// /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. pub struct BinaryCopyInWriter<'a> { - runtime: Rt<'a>, + connection: ConnectionRef<'a>, sink: Pin>, } @@ -26,7 +27,7 @@ impl<'a> BinaryCopyInWriter<'a> { .expect("writer has already been written to"); BinaryCopyInWriter { - runtime: writer.runtime, + connection: writer.connection, sink: Box::pin(binary_copy::BinaryCopyInWriter::new(stream, types)), } } @@ -37,7 +38,7 @@ impl<'a> BinaryCopyInWriter<'a> { /// /// Panics if the number of values provided does not match the number expected. pub fn write(&mut self, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> { - self.runtime.block_on(self.sink.as_mut().write(values)) + self.connection.block_on(self.sink.as_mut().write(values)) } /// A maximally-flexible version of `write`. @@ -50,20 +51,21 @@ impl<'a> BinaryCopyInWriter<'a> { I: IntoIterator, I::IntoIter: ExactSizeIterator, { - self.runtime.block_on(self.sink.as_mut().write_raw(values)) + self.connection + .block_on(self.sink.as_mut().write_raw(values)) } /// Completes the copy, returning the number of rows added. /// /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted. pub fn finish(mut self) -> Result { - self.runtime.block_on(self.sink.as_mut().finish()) + self.connection.block_on(self.sink.as_mut().finish()) } } /// An iterator of rows deserialized from the PostgreSQL binary copy format. pub struct BinaryCopyOutIter<'a> { - runtime: Rt<'a>, + connection: ConnectionRef<'a>, stream: Pin>, } @@ -76,7 +78,7 @@ impl<'a> BinaryCopyOutIter<'a> { .expect("reader has already been read from"); BinaryCopyOutIter { - runtime: reader.runtime, + connection: reader.connection, stream: Box::pin(BinaryCopyOutStream::new(stream, types)), } } @@ -87,6 +89,8 @@ impl FallibleIterator for BinaryCopyOutIter<'_> { type Error = Error; fn next(&mut self) -> Result, Error> { - self.runtime.block_on(self.stream.next()).transpose() + let stream = &mut self.stream; + self.connection + .block_on(async { stream.next().await.transpose() }) } } diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 0a3a51e1b..a0c61b33d 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -1,45 +1,21 @@ +use crate::connection::Connection; use crate::{ - CancelToken, Config, CopyInWriter, CopyOutReader, RowIter, Statement, ToStatement, Transaction, - TransactionBuilder, + CancelToken, Config, CopyInWriter, CopyOutReader, Notifications, RowIter, Statement, + ToStatement, Transaction, TransactionBuilder, }; -use std::ops::{Deref, DerefMut}; -use tokio::runtime::Runtime; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::{Error, Row, SimpleQueryMessage, Socket}; -pub(crate) struct Rt<'a>(pub &'a mut Runtime); - -// no-op impl to extend the borrow until drop -impl Drop for Rt<'_> { - fn drop(&mut self) {} -} - -impl Deref for Rt<'_> { - type Target = Runtime; - - #[inline] - fn deref(&self) -> &Runtime { - self.0 - } -} - -impl DerefMut for Rt<'_> { - #[inline] - fn deref_mut(&mut self) -> &mut Runtime { - self.0 - } -} - /// A synchronous PostgreSQL client. pub struct Client { - runtime: Runtime, + connection: Connection, client: tokio_postgres::Client, } impl Client { - pub(crate) fn new(runtime: Runtime, client: tokio_postgres::Client) -> Client { - Client { runtime, client } + pub(crate) fn new(connection: Connection, client: tokio_postgres::Client) -> Client { + Client { connection, client } } /// A convenience function which parses a configuration string into a `Config` and then connects to the database. @@ -62,10 +38,6 @@ impl Client { Config::new() } - fn rt(&mut self) -> Rt<'_> { - Rt(&mut self.runtime) - } - /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list @@ -104,7 +76,7 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.execute(query, params)) + self.connection.block_on(self.client.execute(query, params)) } /// Executes a statement, returning the resulting rows. @@ -140,7 +112,7 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.query(query, params)) + self.connection.block_on(self.client.query(query, params)) } /// Executes a statement which returns a single row, returning it. @@ -177,7 +149,8 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.query_one(query, params)) + self.connection + .block_on(self.client.query_one(query, params)) } /// Executes a statement which returns zero or one rows, returning it. @@ -223,7 +196,8 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.query_opt(query, params)) + self.connection + .block_on(self.client.query_opt(query, params)) } /// A maximally-flexible version of `query`. @@ -289,9 +263,9 @@ impl Client { I::IntoIter: ExactSizeIterator, { let stream = self - .runtime + .connection .block_on(self.client.query_raw(query, params))?; - Ok(RowIter::new(self.rt(), stream)) + Ok(RowIter::new(self.connection.as_ref(), stream)) } /// Creates a new prepared statement. @@ -318,7 +292,7 @@ impl Client { /// # } /// ``` pub fn prepare(&mut self, query: &str) -> Result { - self.runtime.block_on(self.client.prepare(query)) + self.connection.block_on(self.client.prepare(query)) } /// Like `prepare`, but allows the types of query parameters to be explicitly specified. @@ -349,7 +323,7 @@ impl Client { /// # } /// ``` pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { - self.runtime + self.connection .block_on(self.client.prepare_typed(query, types)) } @@ -380,8 +354,8 @@ impl Client { where T: ?Sized + ToStatement, { - let sink = self.runtime.block_on(self.client.copy_in(query))?; - Ok(CopyInWriter::new(self.rt(), sink)) + let sink = self.connection.block_on(self.client.copy_in(query))?; + Ok(CopyInWriter::new(self.connection.as_ref(), sink)) } /// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data. @@ -408,8 +382,8 @@ impl Client { where T: ?Sized + ToStatement, { - let stream = self.runtime.block_on(self.client.copy_out(query))?; - Ok(CopyOutReader::new(self.rt(), stream)) + let stream = self.connection.block_on(self.client.copy_out(query))?; + Ok(CopyOutReader::new(self.connection.as_ref(), stream)) } /// Executes a sequence of SQL statements using the simple query protocol. @@ -428,7 +402,7 @@ impl Client { /// functionality to safely imbed that data in the request. Do not form statements via string concatenation and pass /// them to this method! pub fn simple_query(&mut self, query: &str) -> Result, Error> { - self.runtime.block_on(self.client.simple_query(query)) + self.connection.block_on(self.client.simple_query(query)) } /// Executes a sequence of SQL statements using the simple query protocol. @@ -442,7 +416,7 @@ impl Client { /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { - self.runtime.block_on(self.client.batch_execute(query)) + self.connection.block_on(self.client.batch_execute(query)) } /// Begins a new database transaction. @@ -466,8 +440,8 @@ impl Client { /// # } /// ``` pub fn transaction(&mut self) -> Result, Error> { - let transaction = self.runtime.block_on(self.client.transaction())?; - Ok(Transaction::new(&mut self.runtime, transaction)) + let transaction = self.connection.block_on(self.client.transaction())?; + Ok(Transaction::new(self.connection.as_ref(), transaction)) } /// Returns a builder for a transaction with custom settings. @@ -494,7 +468,14 @@ impl Client { /// # } /// ``` pub fn build_transaction(&mut self) -> TransactionBuilder<'_> { - TransactionBuilder::new(&mut self.runtime, self.client.build_transaction()) + TransactionBuilder::new(self.connection.as_ref(), self.client.build_transaction()) + } + + /// Returns a structure providing access to asynchronous notifications. + /// + /// Use the `LISTEN` command to register this connection for notifications. + pub fn notifications(&mut self) -> Notifications<'_> { + Notifications::new(self.connection.as_ref()) } /// Constructs a cancellation token that can later be used to request @@ -516,7 +497,7 @@ impl Client { /// thread::spawn(move || { /// // Abort the query after 5s. /// thread::sleep(Duration::from_secs(5)); - /// cancel_token.cancel_query(NoTls); + /// let _ = cancel_token.cancel_query(NoTls); /// }); /// /// match client.simple_query("SELECT long_running_query()") { diff --git a/postgres/src/config.rs b/postgres/src/config.rs index f6b151a8e..b344efdd2 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -2,9 +2,8 @@ //! //! Requires the `runtime` Cargo feature (enabled by default). +use crate::connection::Connection; use crate::Client; -use futures::FutureExt; -use log::error; use std::fmt; use std::path::Path; use std::str::FromStr; @@ -324,15 +323,8 @@ impl Config { let (client, connection) = runtime.block_on(self.config.connect(tls))?; - // FIXME don't spawn this so error reporting is less weird. - let connection = connection.map(|r| { - if let Err(e) = r { - error!("postgres connection error: {}", e) - } - }); - runtime.spawn(connection); - - Ok(Client::new(runtime, client)) + let connection = Connection::new(runtime, connection); + Ok(Client::new(connection, client)) } } diff --git a/postgres/src/connection.rs b/postgres/src/connection.rs new file mode 100644 index 000000000..acea5eca7 --- /dev/null +++ b/postgres/src/connection.rs @@ -0,0 +1,129 @@ +use crate::{Error, Notification}; +use futures::future; +use futures::{pin_mut, Stream}; +use log::info; +use std::collections::VecDeque; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::runtime::Runtime; +use tokio_postgres::AsyncMessage; + +pub struct Connection { + runtime: Runtime, + connection: Pin> + Send>>, + notifications: VecDeque, +} + +impl Connection { + pub fn new(runtime: Runtime, connection: tokio_postgres::Connection) -> Connection + where + S: AsyncRead + AsyncWrite + Unpin + 'static + Send, + T: AsyncRead + AsyncWrite + Unpin + 'static + Send, + { + Connection { + runtime, + connection: Box::pin(ConnectionStream { connection }), + notifications: VecDeque::new(), + } + } + + pub fn as_ref(&mut self) -> ConnectionRef<'_> { + ConnectionRef { connection: self } + } + + pub fn enter(&self, f: F) -> T + where + F: FnOnce() -> T, + { + self.runtime.enter(f) + } + + pub fn block_on(&mut self, future: F) -> Result + where + F: Future>, + { + pin_mut!(future); + self.poll_block_on(|cx, _, _| future.as_mut().poll(cx)) + } + + pub fn poll_block_on(&mut self, mut f: F) -> Result + where + F: FnMut(&mut Context<'_>, &mut VecDeque, bool) -> Poll>, + { + let connection = &mut self.connection; + let notifications = &mut self.notifications; + self.runtime.block_on({ + future::poll_fn(|cx| { + let done = loop { + match connection.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(AsyncMessage::Notification(notification)))) => { + notifications.push_back(notification); + } + Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => { + info!("{}: {}", notice.severity(), notice.message()); + } + Poll::Ready(Some(Ok(_))) => {} + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => break true, + Poll::Pending => break false, + } + }; + + f(cx, notifications, done) + }) + }) + } + + pub fn notifications(&self) -> &VecDeque { + &self.notifications + } + + pub fn notifications_mut(&mut self) -> &mut VecDeque { + &mut self.notifications + } +} + +pub struct ConnectionRef<'a> { + connection: &'a mut Connection, +} + +// no-op impl to extend the borrow until drop +impl Drop for ConnectionRef<'_> { + #[inline] + fn drop(&mut self) {} +} + +impl Deref for ConnectionRef<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.connection + } +} + +impl DerefMut for ConnectionRef<'_> { + #[inline] + fn deref_mut(&mut self) -> &mut Connection { + self.connection + } +} + +struct ConnectionStream { + connection: tokio_postgres::Connection, +} + +impl Stream for ConnectionStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.connection.poll_message(cx) + } +} diff --git a/postgres/src/copy_in_writer.rs b/postgres/src/copy_in_writer.rs index fc11818ab..c996ed857 100644 --- a/postgres/src/copy_in_writer.rs +++ b/postgres/src/copy_in_writer.rs @@ -1,5 +1,5 @@ +use crate::connection::ConnectionRef; use crate::lazy_pin::LazyPin; -use crate::Rt; use bytes::{Bytes, BytesMut}; use futures::SinkExt; use std::io; @@ -10,15 +10,15 @@ use tokio_postgres::{CopyInSink, Error}; /// /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. pub struct CopyInWriter<'a> { - pub(crate) runtime: Rt<'a>, + pub(crate) connection: ConnectionRef<'a>, pub(crate) sink: LazyPin>, buf: BytesMut, } impl<'a> CopyInWriter<'a> { - pub(crate) fn new(runtime: Rt<'a>, sink: CopyInSink) -> CopyInWriter<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, sink: CopyInSink) -> CopyInWriter<'a> { CopyInWriter { - runtime, + connection, sink: LazyPin::new(sink), buf: BytesMut::new(), } @@ -29,7 +29,7 @@ impl<'a> CopyInWriter<'a> { /// If this is not called, the copy will be aborted. pub fn finish(mut self) -> Result { self.flush_inner()?; - self.runtime.block_on(self.sink.pinned().finish()) + self.connection.block_on(self.sink.pinned().finish()) } fn flush_inner(&mut self) -> Result<(), Error> { @@ -37,7 +37,7 @@ impl<'a> CopyInWriter<'a> { return Ok(()); } - self.runtime + self.connection .block_on(self.sink.pinned().send(self.buf.split().freeze())) } } diff --git a/postgres/src/copy_out_reader.rs b/postgres/src/copy_out_reader.rs index 9091e2200..a205d1a1a 100644 --- a/postgres/src/copy_out_reader.rs +++ b/postgres/src/copy_out_reader.rs @@ -1,5 +1,5 @@ +use crate::connection::ConnectionRef; use crate::lazy_pin::LazyPin; -use crate::Rt; use bytes::{Buf, Bytes}; use futures::StreamExt; use std::io::{self, BufRead, Read}; @@ -7,15 +7,15 @@ use tokio_postgres::CopyOutStream; /// The reader returned by the `copy_out` method. pub struct CopyOutReader<'a> { - pub(crate) runtime: Rt<'a>, + pub(crate) connection: ConnectionRef<'a>, pub(crate) stream: LazyPin, cur: Bytes, } impl<'a> CopyOutReader<'a> { - pub(crate) fn new(runtime: Rt<'a>, stream: CopyOutStream) -> CopyOutReader<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, stream: CopyOutStream) -> CopyOutReader<'a> { CopyOutReader { - runtime, + connection, stream: LazyPin::new(stream), cur: Bytes::new(), } @@ -35,10 +35,14 @@ impl Read for CopyOutReader<'_> { impl BufRead for CopyOutReader<'_> { fn fill_buf(&mut self) -> io::Result<&[u8]> { if !self.cur.has_remaining() { - match self.runtime.block_on(self.stream.pinned().next()) { - Some(Ok(cur)) => self.cur = cur, - Some(Err(e)) => return Err(io::Error::new(io::ErrorKind::Other, e)), - None => {} + let mut stream = self.stream.pinned(); + match self + .connection + .block_on({ async { stream.next().await.transpose() } }) + { + Ok(Some(cur)) => self.cur = cur, + Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + Ok(None) => {} }; } diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index 2b2dcec38..80380a87e 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -65,8 +65,8 @@ pub use fallible_iterator; pub use tokio_postgres::{ - error, row, tls, types, Column, IsolationLevel, Portal, SimpleQueryMessage, Socket, Statement, - ToStatement, + error, row, tls, types, Column, IsolationLevel, Notification, Portal, SimpleQueryMessage, + Socket, Statement, ToStatement, }; pub use crate::cancel_token::CancelToken; @@ -77,6 +77,8 @@ pub use crate::copy_out_reader::CopyOutReader; #[doc(no_inline)] pub use crate::error::Error; pub use crate::generic_client::GenericClient; +#[doc(inline)] +pub use crate::notifications::Notifications; #[doc(no_inline)] pub use crate::row::{Row, SimpleQueryRow}; pub use crate::row_iter::RowIter; @@ -89,10 +91,12 @@ pub mod binary_copy; mod cancel_token; mod client; pub mod config; +mod connection; mod copy_in_writer; mod copy_out_reader; mod generic_client; mod lazy_pin; +pub mod notifications; mod row_iter; mod transaction; mod transaction_builder; diff --git a/postgres/src/notifications.rs b/postgres/src/notifications.rs new file mode 100644 index 000000000..e8c681548 --- /dev/null +++ b/postgres/src/notifications.rs @@ -0,0 +1,161 @@ +//! Asynchronous notifications. + +use crate::connection::ConnectionRef; +use crate::{Error, Notification}; +use fallible_iterator::FallibleIterator; +use futures::{ready, FutureExt}; +use std::task::Poll; +use std::time::Duration; +use tokio::time::{self, Delay, Instant}; + +/// Notifications from a PostgreSQL backend. +pub struct Notifications<'a> { + connection: ConnectionRef<'a>, +} + +impl<'a> Notifications<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>) -> Notifications<'a> { + Notifications { connection } + } + + /// Returns the number of already buffered pending notifications. + pub fn len(&self) -> usize { + self.connection.notifications().len() + } + + /// Determines if there are any already buffered pending notifications. + pub fn is_empty(&self) -> bool { + self.connection.notifications().is_empty() + } + + /// Returns a nonblocking iterator over notifications. + /// + /// If there are no already buffered pending notifications, this iterator will poll the connection but will not + /// block waiting on notifications over the network. A return value of `None` either indicates that there are no + /// pending notifications or that the server has disconnected. + /// + /// # Note + /// + /// This iterator may start returning `Some` after previously returning `None` if more notifications are received. + pub fn iter(&mut self) -> Iter<'_> { + Iter { + connection: self.connection.as_ref(), + } + } + + /// Returns a blocking iterator over notifications. + /// + /// If there are no already buffered pending notifications, this iterator will block indefinitely waiting on the + /// PostgreSQL backend server to send one. It will only return `None` if the server has disconnected. + pub fn blocking_iter(&mut self) -> BlockingIter<'_> { + BlockingIter { + connection: self.connection.as_ref(), + } + } + + /// Returns an iterator over notifications which blocks a limited amount of time. + /// + /// If there are no already buffered pending notifications, this iterator will block waiting on the PostgreSQL + /// backend server to send one up to the provided timeout. A return value of `None` either indicates that there are + /// no pending notifications or that the server has disconnected. + /// + /// # Note + /// + /// This iterator may start returning `Some` after previously returning `None` if more notifications are received. + pub fn timeout_iter(&mut self, timeout: Duration) -> TimeoutIter<'_> { + TimeoutIter { + delay: self.connection.enter(|| time::delay_for(timeout)), + timeout, + connection: self.connection.as_ref(), + } + } +} + +/// A nonblocking iterator over pending notifications. +pub struct Iter<'a> { + connection: ConnectionRef<'a>, +} + +impl<'a> FallibleIterator for Iter<'a> { + type Item = Notification; + type Error = Error; + + fn next(&mut self) -> Result, Self::Error> { + if let Some(notification) = self.connection.notifications_mut().pop_front() { + return Ok(Some(notification)); + } + + self.connection + .poll_block_on(|_, notifications, _| Poll::Ready(Ok(notifications.pop_front()))) + } + + fn size_hint(&self) -> (usize, Option) { + (self.connection.notifications().len(), None) + } +} + +/// A blocking iterator over pending notifications. +pub struct BlockingIter<'a> { + connection: ConnectionRef<'a>, +} + +impl<'a> FallibleIterator for BlockingIter<'a> { + type Item = Notification; + type Error = Error; + + fn next(&mut self) -> Result, Self::Error> { + if let Some(notification) = self.connection.notifications_mut().pop_front() { + return Ok(Some(notification)); + } + + self.connection + .poll_block_on(|_, notifications, done| match notifications.pop_front() { + Some(notification) => Poll::Ready(Ok(Some(notification))), + None if done => Poll::Ready(Ok(None)), + None => Poll::Pending, + }) + } + + fn size_hint(&self) -> (usize, Option) { + (self.connection.notifications().len(), None) + } +} + +/// A time-limited blocking iterator over pending notifications. +pub struct TimeoutIter<'a> { + connection: ConnectionRef<'a>, + delay: Delay, + timeout: Duration, +} + +impl<'a> FallibleIterator for TimeoutIter<'a> { + type Item = Notification; + type Error = Error; + + fn next(&mut self) -> Result, Self::Error> { + if let Some(notification) = self.connection.notifications_mut().pop_front() { + self.delay.reset(Instant::now() + self.timeout); + return Ok(Some(notification)); + } + + let delay = &mut self.delay; + let timeout = self.timeout; + self.connection.poll_block_on(|cx, notifications, done| { + match notifications.pop_front() { + Some(notification) => { + delay.reset(Instant::now() + timeout); + return Poll::Ready(Ok(Some(notification))); + } + None if done => return Poll::Ready(Ok(None)), + None => {} + } + + ready!(delay.poll_unpin(cx)); + Poll::Ready(Ok(None)) + }) + } + + fn size_hint(&self) -> (usize, Option) { + (self.connection.notifications().len(), None) + } +} diff --git a/postgres/src/row_iter.rs b/postgres/src/row_iter.rs index 4be5f3477..3cd41b900 100644 --- a/postgres/src/row_iter.rs +++ b/postgres/src/row_iter.rs @@ -1,4 +1,4 @@ -use crate::Rt; +use crate::connection::ConnectionRef; use fallible_iterator::FallibleIterator; use futures::StreamExt; use std::pin::Pin; @@ -6,19 +6,14 @@ use tokio_postgres::{Error, Row, RowStream}; /// The iterator returned by `query_raw`. pub struct RowIter<'a> { - runtime: Rt<'a>, + connection: ConnectionRef<'a>, it: Pin>, } -// no-op impl to extend the borrow until drop -impl Drop for RowIter<'_> { - fn drop(&mut self) {} -} - impl<'a> RowIter<'a> { - pub(crate) fn new(runtime: Rt<'a>, stream: RowStream) -> RowIter<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, stream: RowStream) -> RowIter<'a> { RowIter { - runtime, + connection, it: Box::pin(stream), } } @@ -29,6 +24,8 @@ impl FallibleIterator for RowIter<'_> { type Error = Error; fn next(&mut self) -> Result, Error> { - self.runtime.block_on(self.it.next()).transpose() + let it = &mut self.it; + self.connection + .block_on(async { it.next().await.transpose() }) } } diff --git a/postgres/src/test.rs b/postgres/src/test.rs index 449aac012..9edde8e32 100644 --- a/postgres/src/test.rs +++ b/postgres/src/test.rs @@ -309,3 +309,93 @@ fn cancel_query() { cancel_thread.join().unwrap(); } + +#[test] +fn notifications_iter() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute( + "\ + LISTEN notifications_iter; + NOTIFY notifications_iter, 'hello'; + NOTIFY notifications_iter, 'world'; + ", + ) + .unwrap(); + + let notifications = client.notifications().iter().collect::>().unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].payload(), "hello"); + assert_eq!(notifications[1].payload(), "world"); +} + +#[test] +fn notifications_blocking_iter() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute( + "\ + LISTEN notifications_blocking_iter; + NOTIFY notifications_blocking_iter, 'hello'; + ", + ) + .unwrap(); + + thread::spawn(|| { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + thread::sleep(Duration::from_secs(1)); + client + .batch_execute("NOTIFY notifications_blocking_iter, 'world'") + .unwrap(); + }); + + let notifications = client + .notifications() + .blocking_iter() + .take(2) + .collect::>() + .unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].payload(), "hello"); + assert_eq!(notifications[1].payload(), "world"); +} + +#[test] +fn notifications_timeout_iter() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute( + "\ + LISTEN notifications_timeout_iter; + NOTIFY notifications_timeout_iter, 'hello'; + ", + ) + .unwrap(); + + thread::spawn(|| { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + thread::sleep(Duration::from_secs(1)); + client + .batch_execute("NOTIFY notifications_timeout_iter, 'world'") + .unwrap(); + + thread::sleep(Duration::from_secs(10)); + client + .batch_execute("NOTIFY notifications_timeout_iter, '!'") + .unwrap(); + }); + + let notifications = client + .notifications() + .timeout_iter(Duration::from_secs(2)) + .collect::>() + .unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].payload(), "hello"); + assert_eq!(notifications[1].payload(), "world"); +} diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index e5b3682f0..25bfff578 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -1,7 +1,5 @@ -use crate::{ - CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Rt, Statement, ToStatement, -}; -use tokio::runtime::Runtime; +use crate::connection::ConnectionRef; +use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement}; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::{Error, Row, SimpleQueryMessage}; @@ -10,45 +8,41 @@ use tokio_postgres::{Error, Row, SimpleQueryMessage}; /// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made /// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints. pub struct Transaction<'a> { - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, transaction: tokio_postgres::Transaction<'a>, } impl<'a> Transaction<'a> { pub(crate) fn new( - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, transaction: tokio_postgres::Transaction<'a>, ) -> Transaction<'a> { Transaction { - runtime, + connection, transaction, } } - fn rt(&mut self) -> Rt<'_> { - Rt(self.runtime) - } - /// Consumes the transaction, committing all changes made within it. - pub fn commit(self) -> Result<(), Error> { - self.runtime.block_on(self.transaction.commit()) + pub fn commit(mut self) -> Result<(), Error> { + self.connection.block_on(self.transaction.commit()) } /// Rolls the transaction back, discarding all changes made within it. /// /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. - pub fn rollback(self) -> Result<(), Error> { - self.runtime.block_on(self.transaction.rollback()) + pub fn rollback(mut self) -> Result<(), Error> { + self.connection.block_on(self.transaction.rollback()) } /// Like `Client::prepare`. pub fn prepare(&mut self, query: &str) -> Result { - self.runtime.block_on(self.transaction.prepare(query)) + self.connection.block_on(self.transaction.prepare(query)) } /// Like `Client::prepare_typed`. pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { - self.runtime + self.connection .block_on(self.transaction.prepare_typed(query, types)) } @@ -57,7 +51,7 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime + self.connection .block_on(self.transaction.execute(query, params)) } @@ -66,7 +60,8 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.transaction.query(query, params)) + self.connection + .block_on(self.transaction.query(query, params)) } /// Like `Client::query_one`. @@ -74,7 +69,7 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime + self.connection .block_on(self.transaction.query_one(query, params)) } @@ -87,7 +82,7 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime + self.connection .block_on(self.transaction.query_opt(query, params)) } @@ -99,9 +94,9 @@ impl<'a> Transaction<'a> { I::IntoIter: ExactSizeIterator, { let stream = self - .runtime + .connection .block_on(self.transaction.query_raw(query, params))?; - Ok(RowIter::new(self.rt(), stream)) + Ok(RowIter::new(self.connection.as_ref(), stream)) } /// Binds parameters to a statement, creating a "portal". @@ -118,7 +113,8 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.transaction.bind(query, params)) + self.connection + .block_on(self.transaction.bind(query, params)) } /// Continues execution of a portal, returning the next set of rows. @@ -126,7 +122,7 @@ impl<'a> Transaction<'a> { /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned. pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result, Error> { - self.runtime + self.connection .block_on(self.transaction.query_portal(portal, max_rows)) } @@ -137,9 +133,9 @@ impl<'a> Transaction<'a> { max_rows: i32, ) -> Result, Error> { let stream = self - .runtime + .connection .block_on(self.transaction.query_portal_raw(portal, max_rows))?; - Ok(RowIter::new(self.rt(), stream)) + Ok(RowIter::new(self.connection.as_ref(), stream)) } /// Like `Client::copy_in`. @@ -147,8 +143,8 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - let sink = self.runtime.block_on(self.transaction.copy_in(query))?; - Ok(CopyInWriter::new(self.rt(), sink)) + let sink = self.connection.block_on(self.transaction.copy_in(query))?; + Ok(CopyInWriter::new(self.connection.as_ref(), sink)) } /// Like `Client::copy_out`. @@ -156,18 +152,20 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - let stream = self.runtime.block_on(self.transaction.copy_out(query))?; - Ok(CopyOutReader::new(self.rt(), stream)) + let stream = self.connection.block_on(self.transaction.copy_out(query))?; + Ok(CopyOutReader::new(self.connection.as_ref(), stream)) } /// Like `Client::simple_query`. pub fn simple_query(&mut self, query: &str) -> Result, Error> { - self.runtime.block_on(self.transaction.simple_query(query)) + self.connection + .block_on(self.transaction.simple_query(query)) } /// Like `Client::batch_execute`. pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { - self.runtime.block_on(self.transaction.batch_execute(query)) + self.connection + .block_on(self.transaction.batch_execute(query)) } /// Like `Client::cancel_token`. @@ -177,9 +175,9 @@ impl<'a> Transaction<'a> { /// Like `Client::transaction`. pub fn transaction(&mut self) -> Result, Error> { - let transaction = self.runtime.block_on(self.transaction.transaction())?; + let transaction = self.connection.block_on(self.transaction.transaction())?; Ok(Transaction { - runtime: self.runtime, + connection: self.connection.as_ref(), transaction, }) } diff --git a/postgres/src/transaction_builder.rs b/postgres/src/transaction_builder.rs index d87d1a128..e0f8a56e8 100644 --- a/postgres/src/transaction_builder.rs +++ b/postgres/src/transaction_builder.rs @@ -1,18 +1,21 @@ +use crate::connection::ConnectionRef; use crate::{Error, IsolationLevel, Transaction}; -use tokio::runtime::Runtime; /// A builder for database transactions. pub struct TransactionBuilder<'a> { - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, builder: tokio_postgres::TransactionBuilder<'a>, } impl<'a> TransactionBuilder<'a> { pub(crate) fn new( - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, builder: tokio_postgres::TransactionBuilder<'a>, ) -> TransactionBuilder<'a> { - TransactionBuilder { runtime, builder } + TransactionBuilder { + connection, + builder, + } } /// Sets the isolation level of the transaction. @@ -40,8 +43,8 @@ impl<'a> TransactionBuilder<'a> { /// Begins the transaction. /// /// The transaction will roll back by default - use the `commit` method to commit it. - pub fn start(self) -> Result, Error> { - let transaction = self.runtime.block_on(self.builder.start())?; - Ok(Transaction::new(self.runtime, transaction)) + pub fn start(mut self) -> Result, Error> { + let transaction = self.connection.block_on(self.builder.start())?; + Ok(Transaction::new(self.connection, transaction)) } } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 738601159..b01037edc 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -559,11 +559,9 @@ async fn copy_out() { .copy_out(&stmt) .await .unwrap() - .try_fold(BytesMut::new(), |mut buf, chunk| { - async move { - buf.extend_from_slice(&chunk); - Ok(buf) - } + .try_fold(BytesMut::new(), |mut buf, chunk| async move { + buf.extend_from_slice(&chunk); + Ok(buf) }) .await .unwrap();