Skip to content

Avoid copies in copy_in #451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions postgres-protocol/src/message/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#![allow(missing_docs)]

use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;
Expand Down Expand Up @@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
})
}

pub struct CopyData<T> {
buf: T,
len: i32,
}

impl<T> CopyData<T>
where
T: Buf,
{
pub fn new<U>(buf: U) -> io::Result<CopyData<T>>
where
U: IntoBuf<Buf = T>,
{
let buf = buf.into_buf();

let len = buf
.remaining()
.checked_add(4)
.and_then(|l| i32::try_from(l).ok())
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
})?;

Ok(CopyData { buf, len })
}

pub fn write(self, out: &mut BytesMut) {
out.reserve(self.len as usize + 1);
out.put_u8(b'd');
out.put_i32_be(self.len);
out.put(self.buf);
}
}

#[inline]
pub fn copy_done(buf: &mut Vec<u8>) {
buf.push(b'c');
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ pub struct CopyIn<S>(pub(crate) proto::CopyInFuture<S>)
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>;

impl<S> Future for CopyIn<S>
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ impl Client {
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
// FIXME error type?
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
Expand Down
28 changes: 18 additions & 10 deletions tokio-postgres/src/proto/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::sync::{Arc, Weak};
use tokio_io::{AsyncRead, AsyncWrite};

use crate::proto::bind::BindFuture;
use crate::proto::codec::FrontendMessage;
use crate::proto::connection::{Request, RequestMessages};
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
use crate::proto::copy_out::CopyOutStream;
Expand Down Expand Up @@ -185,8 +186,12 @@ impl Client {
if let Ok(ref mut buf) = buf {
frontend::sync(buf);
}
let pending =
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
let pending = PendingRequest(buf.map(|m| {
(
RequestMessages::Single(FrontendMessage::Raw(m)),
self.0.idle.guard(),
)
}));
BindFuture::new(self.clone(), pending, name, statement.clone())
}

Expand All @@ -208,12 +213,12 @@ impl Client {
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
let (mut sender, receiver) = mpsc::channel(1);
let pending = PendingRequest(self.excecute_message(statement, params).map(|data| {
match sender.start_send(CopyMessage { data, done: false }) {
match sender.start_send(CopyMessage::Message(data)) {
Ok(AsyncSink::Ready) => {}
_ => unreachable!("channel should have capacity"),
}
Expand Down Expand Up @@ -278,7 +283,7 @@ impl Client {
frontend::sync(&mut buf);
let (sender, _) = mpsc::channel(0);
let _ = self.0.sender.unbounded_send(Request {
messages: RequestMessages::Single(buf),
messages: RequestMessages::Single(FrontendMessage::Raw(buf)),
sender,
idle: None,
});
Expand Down Expand Up @@ -326,20 +331,23 @@ impl Client {
&self,
statement: &Statement,
params: &[&dyn ToSql],
) -> Result<Vec<u8>, Error> {
) -> Result<FrontendMessage, Error> {
let mut buf = self.bind_message(statement, "", params)?;
frontend::execute("", 0, &mut buf).map_err(Error::parse)?;
frontend::sync(&mut buf);
Ok(buf)
Ok(FrontendMessage::Raw(buf))
}

fn pending<F>(&self, messages: F) -> PendingRequest
where
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
{
let mut buf = vec![];
PendingRequest(
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
)
PendingRequest(messages(&mut buf).map(|()| {
(
RequestMessages::Single(FrontendMessage::Raw(buf)),
self.0.idle.guard(),
)
}))
}
}
18 changes: 14 additions & 4 deletions tokio-postgres/src/proto/codec.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
use bytes::BytesMut;
use bytes::{Buf, BytesMut};
use postgres_protocol::message::backend;
use postgres_protocol::message::frontend::CopyData;
use std::io;
use tokio_codec::{Decoder, Encoder};

pub enum FrontendMessage {
Raw(Vec<u8>),
CopyData(CopyData<Box<dyn Buf + Send>>),
}

pub struct PostgresCodec;

impl Encoder for PostgresCodec {
type Item = Vec<u8>;
type Item = FrontendMessage;
type Error = io::Error;

fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
dst.extend_from_slice(&item);
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
FrontendMessage::CopyData(data) => data.write(dst),
}

Ok(())
}
}
Expand Down
12 changes: 6 additions & 6 deletions tokio-postgres/src/proto/connect_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::collections::HashMap;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};

use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
use crate::proto::{Client, Connection, FrontendMessage, MaybeTlsStream, PostgresCodec, TlsFuture};
use crate::tls::ChannelBinding;
use crate::{Config, Error, TlsConnect};

Expand Down Expand Up @@ -111,7 +111,7 @@ where
let stream = Framed::new(stream, PostgresCodec);

transition!(SendingStartup {
future: stream.send(buf),
future: stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
channel_binding,
Expand Down Expand Up @@ -156,7 +156,7 @@ where
let mut buf = vec![];
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
})
Expand All @@ -178,7 +178,7 @@ where
let mut buf = vec![];
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
})
Expand Down Expand Up @@ -235,7 +235,7 @@ where
.map_err(Error::encode)?;

transition!(SendingSasl {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
scram,
config: state.config,
idx: state.idx,
Expand Down Expand Up @@ -293,7 +293,7 @@ where
let mut buf = vec![];
frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?;
transition!(SendingSasl {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
scram: state.scram,
config: state.config,
idx: state.idx,
Expand Down
8 changes: 4 additions & 4 deletions tokio-postgres/src/proto/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ use std::io;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};

use crate::proto::codec::PostgresCodec;
use crate::proto::codec::{FrontendMessage, PostgresCodec};
use crate::proto::copy_in::CopyInReceiver;
use crate::proto::idle::IdleGuard;
use crate::{AsyncMessage, Notification};
use crate::{DbError, Error};

pub enum RequestMessages {
Single(Vec<u8>),
Single(FrontendMessage),
CopyIn {
receiver: CopyInReceiver,
pending_message: Option<Vec<u8>>,
pending_message: Option<FrontendMessage>,
},
}

Expand Down Expand Up @@ -188,7 +188,7 @@ where
self.state = State::Terminating;
let mut request = vec![];
frontend::terminate(&mut request);
RequestMessages::Single(request)
RequestMessages::Single(FrontendMessage::Raw(request))
}
Async::Ready(None) => {
trace!(
Expand Down
Loading