From 8bd0e0c88cdddc5449ef76e4ee47c00d3a2f59d6 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 18 Dec 2016 14:33:23 -0800 Subject: [PATCH] Rewrite the backend message parser The owned variant of the old parser incurred a lot of allocation overhead, particularly for row data messages. Rewrite it so that it behaves like the borrowed parser, except that the message can optionally own any dynamically sized data it contains. --- Cargo.toml | 1 + src/lib.rs | 1 + src/message/backend.rs | 757 ++++++++++++++++++++++++++++++ src/message/backend/borrowed.rs | 785 -------------------------------- src/message/backend/mod.rs | 98 ---- 5 files changed, 759 insertions(+), 883 deletions(-) create mode 100644 src/message/backend.rs delete mode 100644 src/message/backend/borrowed.rs delete mode 100644 src/message/backend/mod.rs diff --git a/Cargo.toml b/Cargo.toml index dfba0a0..53895dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,4 @@ byteorder = "0.5" fallible-iterator = "0.1" hex = "0.2" md5 = "0.2" +memchr = "0.1" diff --git a/src/lib.rs b/src/lib.rs index 269090c..0db837b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ extern crate byteorder; extern crate fallible_iterator; extern crate hex; extern crate md5; +extern crate memchr; use byteorder::{WriteBytesExt, BigEndian}; use std::io; diff --git a/src/message/backend.rs b/src/message/backend.rs new file mode 100644 index 0000000..26995ce --- /dev/null +++ b/src/message/backend.rs @@ -0,0 +1,757 @@ +#![allow(missing_docs)] + +use byteorder::{ReadBytesExt, BigEndian}; +use memchr::memchr; +use fallible_iterator::FallibleIterator; +use std::io::{self, Read}; +use std::marker::PhantomData; +use std::ops::Deref; +use std::str; + +use Oid; + +/// An enum representing Postgres backend messages. +pub enum Message { + AuthenticationCleartextPassword, + AuthenticationGss, + AuthenticationKerberosV5, + AuthenticationMd5Password(AuthenticationMd5PasswordBody), + AuthenticationOk, + AuthenticationScmCredential, + AuthenticationSspi, + BackendKeyData(BackendKeyDataBody), + BindComplete, + CloseComplete, + CommandComplete(CommandCompleteBody), + CopyData(CopyDataBody), + CopyDone, + CopyInResponse(CopyInResponseBody), + CopyOutResponse(CopyOutResponseBody), + DataRow(DataRowBody), + EmptyQueryResponse, + ErrorResponse(ErrorResponseBody), + NoData, + NoticeResponse(NoticeResponseBody), + NotificationResponse(NotificationResponseBody), + ParameterDescription(ParameterDescriptionBody), + ParameterStatus(ParameterStatusBody), + ParseComplete, + PortalSuspended, + ReadyForQuery(ReadyForQueryBody), + RowDescription(RowDescriptionBody), + #[doc(hidden)] + __ForExtensibility, +} + +impl<'a> Message<&'a [u8]> { + /// Attempts to parse a backend message from the buffer. + /// + /// This method is unfortunately difficult to use due to deficiencies in the compiler's borrow + /// checker. + #[inline] + pub fn parse(buf: &'a [u8]) -> io::Result> { + Message::parse_inner(buf) + } +} + +impl Message> { + /// Attempts to parse a backend message from the buffer. + /// + /// In contrast to `parse`, this method produces messages that do not reference the input, + /// buffer by copying any necessary portions internally. + #[inline] + pub fn parse_owned(buf: &[u8]) -> io::Result>> { + Message::parse_inner(buf) + } +} + +impl<'a, T> Message + where T: From<&'a [u8]> +{ + #[inline] + fn parse_inner(buf: &'a [u8]) -> io::Result> { + if buf.len() < 5 { + return Ok(ParseResult::Incomplete { required_size: None }); + } + + let mut r = buf; + let tag = r.read_u8().unwrap(); + // add a byte for the tag + let len = r.read_u32::().unwrap() as usize + 1; + + if buf.len() < len { + return Ok(ParseResult::Incomplete { required_size: Some(len) }); + } + + let mut buf = &buf[5..len]; + let message = match tag { + b'1' => Message::ParseComplete, + b'2' => Message::BindComplete, + b'3' => Message::CloseComplete, + b'A' => { + let process_id = try!(buf.read_i32::()); + let channel_end = try!(find_null(buf, 0)); + let message_end = try!(find_null(buf, channel_end + 1)); + let storage = buf[..message_end].into(); + buf = &buf[message_end + 1..]; + Message::NotificationResponse(NotificationResponseBody { + storage: storage, + process_id: process_id, + channel_end: channel_end, + }) + } + b'c' => Message::CopyDone, + b'C' => { + let tag_end = try!(find_null(buf, 0)); + let storage = buf[..tag_end].into(); + buf = &buf[tag_end + 1..]; + Message::CommandComplete(CommandCompleteBody { + storage: storage, + }) + } + b'd' => { + let storage = buf.into(); + buf = &[]; + Message::CopyData(CopyDataBody { storage: storage }) + } + b'D' => { + let len = try!(buf.read_u16::()); + let storage = buf.into(); + buf = &[]; + Message::DataRow(DataRowBody { + storage: storage, + len: len, + }) + } + b'E' => { + let storage = buf.into(); + buf = &[]; + Message::ErrorResponse(ErrorResponseBody { storage: storage }) + } + b'G' => { + let format = try!(buf.read_u8()); + let len = try!(buf.read_u16::()); + let storage = buf.into(); + buf = &[]; + Message::CopyInResponse(CopyInResponseBody { + format: format, + len: len, + storage: storage, + }) + } + b'H' => { + let format = try!(buf.read_u8()); + let len = try!(buf.read_u16::()); + let storage = buf.into(); + buf = &[]; + Message::CopyOutResponse(CopyOutResponseBody { + format: format, + len: len, + storage: storage, + }) + } + b'I' => Message::EmptyQueryResponse, + b'K' => { + let process_id = try!(buf.read_i32::()); + let secret_key = try!(buf.read_i32::()); + Message::BackendKeyData(BackendKeyDataBody { + process_id: process_id, + secret_key: secret_key, + _p: PhantomData, + }) + } + b'n' => Message::NoData, + b'N' => { + let storage = buf.into(); + buf = &[]; + Message::NoticeResponse(NoticeResponseBody { + storage: storage, + }) + } + b'R' => { + match try!(buf.read_i32::()) { + 0 => Message::AuthenticationOk, + 2 => Message::AuthenticationKerberosV5, + 3 => Message::AuthenticationCleartextPassword, + 5 => { + let mut salt = [0; 4]; + try!(buf.read_exact(&mut salt)); + Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { + salt: salt, + _p: PhantomData, + }) + } + 6 => Message::AuthenticationScmCredential, + 7 => Message::AuthenticationGss, + 9 => Message::AuthenticationSspi, + tag => { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + format!("unknown authentication tag `{}`", tag))); + } + } + } + b's' => Message::PortalSuspended, + b'S' => { + let name_end = try!(find_null(buf, 0)); + let value_end = try!(find_null(buf, name_end + 1)); + let storage = buf[0..value_end].into(); + buf = &buf[value_end + 1..]; + Message::ParameterStatus(ParameterStatusBody { + storage: storage, + name_end: name_end, + }) + } + b't' => { + let len = try!(buf.read_u16::()); + let storage = buf.into(); + buf = &[]; + Message::ParameterDescription(ParameterDescriptionBody { + storage: storage, + len: len, + }) + } + b'T' => { + let len = try!(buf.read_u16::()); + let storage = buf.into(); + buf = &[]; + Message::RowDescription(RowDescriptionBody { + storage: storage, + len: len, + }) + } + b'Z' => { + let status = try!(buf.read_u8()); + Message::ReadyForQuery(ReadyForQueryBody { + status: status, + _p: PhantomData, + }) + } + tag => { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + format!("unknown message tag `{}`", tag))); + } + }; + + if !buf.is_empty() { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); + } + + Ok(ParseResult::Complete { + message: message, + consumed: len, + }) + } +} + +/// The result of an attempted parse. +pub enum ParseResult { + /// The message was successfully parsed. + Complete { + /// The message. + message: Message, + /// The number of bytes of the input buffer consumed to parse this message. + consumed: usize, + }, + /// The buffer did not contain a full message. + Incomplete { + /// The number of total bytes required to parse a message, if known. + /// + /// This value is present if the input buffer contains at least 5 bytes. + required_size: Option, + } +} + +pub struct AuthenticationMd5PasswordBody { + salt: [u8; 4], + _p: PhantomData, +} + +impl AuthenticationMd5PasswordBody + where T: Deref +{ + #[inline] + pub fn salt(&self) -> [u8; 4] { + self.salt + } +} + +pub struct BackendKeyDataBody { + process_id: i32, + secret_key: i32, + _p: PhantomData, +} + +impl BackendKeyDataBody + where T: Deref +{ + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn secret_key(&self) -> i32 { + self.secret_key + } +} + +pub struct CommandCompleteBody { + storage: T, +} + +impl CommandCompleteBody + where T: Deref +{ + #[inline] + pub fn tag(&self) -> io::Result<&str> { + get_str(&self.storage) + } +} + +pub struct CopyDataBody { + storage: T, +} + +impl CopyDataBody + where T: Deref +{ + #[inline] + pub fn data(&self) -> &[u8] { + &self.storage + } +} + +pub struct CopyInResponseBody { + storage: T, + len: u16, + format: u8, +} + +impl CopyInResponseBody + where T: Deref +{ + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats<'a>(&'a self) -> ColumnFormats<'a> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + +pub struct ColumnFormats<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for ColumnFormats<'a> { + type Item = u16; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); + } + } + + self.remaining -= 1; + self.buf.read_u16::().map(Some) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct CopyOutResponseBody { + storage: T, + len: u16, + format: u8, +} + +impl CopyOutResponseBody + where T: Deref +{ + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats<'a>(&'a self) -> ColumnFormats<'a> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + +pub struct DataRowBody { + storage: T, + len: u16, +} + +impl DataRowBody + where T: Deref +{ + #[inline] + pub fn values<'a>(&'a self) -> DataRowValues<'a> { + DataRowValues { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct DataRowValues<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for DataRowValues<'a> { + type Item = Option<&'a [u8]>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); + } + } + + self.remaining -= 1; + let len = try!(self.buf.read_i32::()); + if len < 0 { + Ok(Some(None)) + } else { + let len = len as usize; + if self.buf.len() < len { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")); + } + let (head, tail) = self.buf.split_at(len); + self.buf = tail; + Ok(Some(Some(head))) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ErrorResponseBody { + storage: T, +} + +impl ErrorResponseBody + where T: Deref +{ + #[inline] + pub fn fields<'a>(&'a self) -> ErrorFields<'a> { + ErrorFields { + buf: &self.storage + } + } +} + +pub struct ErrorFields<'a> { + buf: &'a [u8], +} + +impl<'a> FallibleIterator for ErrorFields<'a> { + type Item = ErrorField<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + let type_ = try!(self.buf.read_u8()); + if type_ == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); + } + } + + let value_end = try!(find_null(self.buf, 0)); + let value = try!(get_str(&self.buf[..value_end])); + self.buf = &self.buf[value_end + 1..]; + + Ok(Some(ErrorField { + type_: type_, + value: value, + })) + } +} + +pub struct ErrorField<'a> { + type_: u8, + value: &'a str, +} + +impl<'a> ErrorField<'a> { + #[inline] + pub fn type_(&self) -> u8 { + self.type_ + } + + #[inline] + pub fn value(&self) -> &str { + self.value + } +} + +pub struct NoticeResponseBody { + storage: T, +} + +impl NoticeResponseBody + where T: Deref +{ + #[inline] + pub fn fields<'a>(&'a self) -> ErrorFields<'a> { + ErrorFields { + buf: &self.storage + } + } +} + +pub struct NotificationResponseBody { + storage: T, + process_id: i32, + channel_end: usize, +} + +impl NotificationResponseBody + where T: Deref +{ + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn channel(&self) -> io::Result<&str> { + get_str(&self.storage[..self.channel_end]) + } + + #[inline] + pub fn message(&self) -> io::Result<&str> { + get_str(&self.storage[self.channel_end + 1..]) + } +} + +pub struct ParameterDescriptionBody { + storage: T, + len: u16, +} + +impl ParameterDescriptionBody + where T: Deref +{ + #[inline] + pub fn parameters<'a>(&'a self) -> Parameters<'a> { + Parameters { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Parameters<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for Parameters<'a> { + type Item = Oid; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); + } + } + + self.remaining -= 1; + self.buf.read_u32::().map(Some) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ParameterStatusBody { + storage: T, + name_end: usize, +} + +impl ParameterStatusBody + where T: Deref +{ + #[inline] + pub fn name(&self) -> io::Result<&str> { + get_str(&self.storage[..self.name_end]) + } + + #[inline] + pub fn value(&self) -> io::Result<&str> { + get_str(&self.storage[self.name_end + 1..]) + } +} + +pub struct ReadyForQueryBody { + status: u8, + _p: PhantomData, +} + +impl ReadyForQueryBody + where T: Deref +{ + #[inline] + pub fn status(&self) -> u8 { + self.status + } +} + +pub struct RowDescriptionBody { + storage: T, + len: u16, +} + +impl RowDescriptionBody + where T: Deref +{ + #[inline] + pub fn fields<'a>(&'a self) -> Fields<'a> { + Fields { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Fields<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for Fields<'a> { + type Item = Field<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); + } + } + + self.remaining -= 1; + let name_end = try!(find_null(self.buf, 0)); + let name = try!(get_str(&self.buf[..name_end])); + self.buf = &self.buf[name_end + 1..]; + let table_oid = try!(self.buf.read_u32::()); + let column_id = try!(self.buf.read_i16::()); + let type_oid = try!(self.buf.read_u32::()); + let type_size = try!(self.buf.read_i16::()); + let type_modifier = try!(self.buf.read_i32::()); + let format = try!(self.buf.read_i16::()); + + Ok(Some(Field { + name: name, + table_oid: table_oid, + column_id: column_id, + type_oid: type_oid, + type_size: type_size, + type_modifier: type_modifier, + format: format, + })) + } +} + +pub struct Field<'a> { + name: &'a str, + table_oid: Oid, + column_id: i16, + type_oid: Oid, + type_size: i16, + type_modifier: i32, + format: i16, +} + +impl<'a> Field<'a> { + #[inline] + pub fn name(&self) -> &'a str { + self.name + } + + #[inline] + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + #[inline] + pub fn column_id(&self) -> i16 { + self.column_id + } + + #[inline] + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + #[inline] + pub fn type_size(&self) -> i16 { + self.type_size + } + + #[inline] + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } + + #[inline] + pub fn format(&self) -> i16 { + self.format + } +} + +#[inline] +fn find_null(buf: &[u8], start: usize) -> io::Result { + match memchr(0, &buf[start..]) { + Some(pos) => Ok(pos + start), + None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")) + } +} + +#[inline] +fn get_str(buf: &[u8]) -> io::Result<&str> { + str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) +} \ No newline at end of file diff --git a/src/message/backend/borrowed.rs b/src/message/backend/borrowed.rs deleted file mode 100644 index f8b5f62..0000000 --- a/src/message/backend/borrowed.rs +++ /dev/null @@ -1,785 +0,0 @@ -//! An allocation-free backend message parser. -//! -//! Due to borrow checker deficiencies, this parser is currently very -//! difficult to use in practice. -use byteorder::{ReadBytesExt, BigEndian}; -use fallible_iterator::FallibleIterator; -use std::io; -use std::marker::PhantomData; -use std::str; - -use Oid; -use message::backend::{self, ParseResult, RowDescriptionEntry}; - -macro_rules! check_empty { - ($buf:expr) => { - if !$buf.is_empty() { - return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); - } - } -} - -/// An enum representing Postgres backend messages. -pub enum Message<'a> { - AuthenticationCleartextPassword, - AuthenticationGss, - AuthenticationKerberosV5, - AuthenticationMd55Password(AuthenticationMd5PasswordBody<'a>), - AuthenticationOk, - AuthenticationScmCredential, - AuthenticationSspi, - BackendKeyData(BackendKeyDataBody<'a>), - BindComplete, - CloseComplete, - CommandComplete(CommandCompleteBody<'a>), - CopyData(CopyDataBody<'a>), - CopyDone, - CopyInResponse(CopyInResponseBody<'a>), - CopyOutResponse(CopyOutResponseBody<'a>), - DataRow(DataRowBody<'a>), - EmptyQueryResponse, - ErrorResponse(ErrorResponseBody<'a>), - NoData, - NoticeResponse(NoticeResponseBody<'a>), - NotificationResponse(NotificationResponseBody<'a>), - ParameterDescription(ParameterDescriptionBody<'a>), - ParameterStatus(ParameterStatusBody<'a>), - ParseComplete, - PortalSuspended, - ReadyForQuery(ReadyForQueryBody<'a>), - RowDescription(RowDescriptionBody<'a>), - #[doc(hidden)] - __ForExtensibility, -} - -impl<'a> Message<'a> { - /// Attempts to deserialize a backend message from the buffer. - pub fn parse(buf: &'a [u8]) -> io::Result>> { - if buf.len() < 5 { - return Ok(ParseResult::Incomplete { required_size: None }); - } - - let mut r = buf; - let tag = r.read_u8().unwrap(); - // add a byte for the tag - let len = r.read_u32::().unwrap() as usize + 1; - - if buf.len() < len { - return Ok(ParseResult::Incomplete { required_size: Some(len) }); - } - - let mut buf = &buf[5..len]; - let message = match tag { - b'1' => { - check_empty!(buf); - Message::ParseComplete - } - b'2' => { - check_empty!(buf); - Message::BindComplete - } - b'3' => { - check_empty!(buf); - Message::CloseComplete - } - b'A' => { - let process_id = try!(buf.read_i32::()); - let channel = try!(buf.read_cstr()); - let message = try!(buf.read_cstr()); - check_empty!(buf); - Message::NotificationResponse(NotificationResponseBody { - process_id: process_id, - channel: channel, - message: message, - }) - } - b'c' => { - check_empty!(buf); - Message::CopyDone - } - b'C' => { - let tag = try!(buf.read_cstr()); - check_empty!(buf); - Message::CommandComplete(CommandCompleteBody { tag: tag }) - } - b'd' => Message::CopyData(CopyDataBody { data: buf }), - b'D' => { - let len = try!(buf.read_u16::()); - Message::DataRow(DataRowBody { - len: len, - buf: buf, - }) - } - b'E' => Message::ErrorResponse(ErrorResponseBody(buf)), - b'G' => { - let format = try!(buf.read_u8()); - let len = try!(buf.read_u16::()); - Message::CopyInResponse(CopyInResponseBody { - format: format, - len: len, - buf: buf, - }) - } - b'H' => { - let format = try!(buf.read_u8()); - let len = try!(buf.read_u16::()); - Message::CopyOutResponse(CopyOutResponseBody { - format: format, - len: len, - buf: buf, - }) - } - b'I' => Message::EmptyQueryResponse, - b'K' => { - let process_id = try!(buf.read_i32::()); - let secret_key = try!(buf.read_i32::()); - check_empty!(buf); - Message::BackendKeyData(BackendKeyDataBody { - process_id: process_id, - secret_key: secret_key, - _p: PhantomData, - }) - } - b'n' => { - check_empty!(buf); - Message::NoData - } - b'N' => Message::NoticeResponse(NoticeResponseBody(buf)), - b'R' => { - match try!(buf.read_i32::()) { - 0 => { - check_empty!(buf); - Message::AuthenticationOk - } - 2 => { - check_empty!(buf); - Message::AuthenticationKerberosV5 - } - 3 => { - check_empty!(buf); - Message::AuthenticationCleartextPassword - } - 5 => { - if buf.len() != 4 { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - "invalid message length")); - } - let mut salt = [0; 4]; - salt.copy_from_slice(buf); - Message::AuthenticationMd55Password(AuthenticationMd5PasswordBody { - salt: salt, - _p: PhantomData, - }) - } - 6 => { - check_empty!(buf); - Message::AuthenticationScmCredential - } - 7 => { - check_empty!(buf); - Message::AuthenticationGss - } - 9 => { - check_empty!(buf); - Message::AuthenticationSspi - } - tag => { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - format!("unknown authentication tag `{}`", tag))); - } - } - } - b's' => { - check_empty!(buf); - Message::PortalSuspended - } - b'S' => { - let name = try!(buf.read_cstr()); - let value = try!(buf.read_cstr()); - check_empty!(buf); - Message::ParameterStatus(ParameterStatusBody { - name: name, - value: value, - }) - } - b't' => { - let len = try!(buf.read_u16::()); - Message::ParameterDescription(ParameterDescriptionBody { - len: len, - buf: buf, - }) - } - b'T' => { - let len = try!(buf.read_u16::()); - Message::RowDescription(RowDescriptionBody { - len: len, - buf: buf, - }) - } - b'Z' => { - let status = try!(buf.read_u8()); - check_empty!(buf); - Message::ReadyForQuery(ReadyForQueryBody { - status: status, - _p: PhantomData, - }) - } - tag => { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - format!("unknown message tag `{}`", tag))); - } - }; - - Ok(ParseResult::Complete { - message: message, - consumed: len, - }) - } - - /// Converts this message into an owned representation. - pub fn to_owned(&self) -> io::Result { - let ret = match *self { - Message::AuthenticationCleartextPassword => { - backend::Message::AuthenticationCleartextPassword - } - Message::AuthenticationGss => backend::Message::AuthenticationGSS, - Message::AuthenticationKerberosV5 => backend::Message::AuthenticationKerberosV5, - Message::AuthenticationMd55Password(ref body) => { - backend::Message::AuthenticationMD5Password { salt: body.salt() } - } - Message::AuthenticationOk => backend::Message::AuthenticationOk, - Message::AuthenticationScmCredential => backend::Message::AuthenticationSCMCredential, - Message::AuthenticationSspi => backend::Message::AuthenticationSSPI, - Message::BackendKeyData(ref body) => { - backend::Message::BackendKeyData { - process_id: body.process_id(), - secret_key: body.secret_key(), - } - } - Message::BindComplete => backend::Message::BindComplete, - Message::CloseComplete => backend::Message::CloseComplete, - Message::CommandComplete(ref body) => { - backend::Message::CommandComplete { tag: body.tag().to_owned() } - } - Message::CopyData(ref body) => { - backend::Message::CopyData { data: body.data().to_owned() } - } - Message::CopyDone => backend::Message::CopyDone, - Message::CopyInResponse(ref body) => { - backend::Message::CopyInResponse { - format: body.format(), - column_formats: try!(body.column_formats().collect()), - } - } - Message::CopyOutResponse(ref body) => { - backend::Message::CopyOutResponse { - format: body.format(), - column_formats: try!(body.column_formats().collect()), - } - } - Message::DataRow(ref body) => { - backend::Message::DataRow { - row: try!(body.values().map(|r| r.map(|d| d.to_owned())).collect()), - } - } - Message::EmptyQueryResponse => backend::Message::EmptyQueryResponse, - Message::ErrorResponse(ref body) => { - backend::Message::ErrorResponse { - fields: try!(body.fields() - .map(|f| (f.type_(), f.value().to_owned())) - .collect()), - } - } - Message::NoData => backend::Message::NoData, - Message::NoticeResponse(ref body) => { - backend::Message::NoticeResponse { - fields: try!(body.fields() - .map(|f| (f.type_(), f.value().to_owned())) - .collect()), - } - } - Message::NotificationResponse(ref body) => { - backend::Message::NotificationResponse { - process_id: body.process_id(), - channel: body.channel().to_owned(), - payload: body.message().to_owned(), - } - } - Message::ParameterDescription(ref body) => { - backend::Message::ParameterDescription { types: try!(body.parameters().collect()) } - } - Message::ParameterStatus(ref body) => { - backend::Message::ParameterStatus { - parameter: body.name().to_owned(), - value: body.value().to_owned(), - } - } - Message::ParseComplete => backend::Message::ParseComplete, - Message::PortalSuspended => backend::Message::PortalSuspended, - Message::ReadyForQuery(ref body) => { - backend::Message::ReadyForQuery { state: body.status() } - } - Message::RowDescription(ref body) => { - let fields = body.fields() - .map(|f| { - RowDescriptionEntry { - name: f.name().to_owned(), - table_oid: f.table_oid(), - column_id: f.column_id(), - type_oid: f.type_oid(), - type_size: f.type_size(), - type_modifier: f.type_modifier(), - format: f.format(), - } - }); - backend::Message::RowDescription { descriptions: try!(fields.collect()) } - } - Message::__ForExtensibility => backend::Message::__ForExtensibility, - }; - - Ok(ret) - } -} - -pub struct AuthenticationMd5PasswordBody<'a> { - salt: [u8; 4], - _p: PhantomData<&'a [u8]>, -} - -impl<'a> AuthenticationMd5PasswordBody<'a> { - #[inline] - pub fn salt(&self) -> [u8; 4] { - self.salt - } -} - -pub struct BackendKeyDataBody<'a> { - process_id: i32, - secret_key: i32, - _p: PhantomData<&'a [u8]>, -} - -impl<'a> BackendKeyDataBody<'a> { - #[inline] - pub fn process_id(&self) -> i32 { - self.process_id - } - - #[inline] - pub fn secret_key(&self) -> i32 { - self.secret_key - } -} - -pub struct CommandCompleteBody<'a> { - tag: &'a str, -} - -impl<'a> CommandCompleteBody<'a> { - #[inline] - pub fn tag(&self) -> &'a str { - self.tag - } -} - -pub struct CopyDataBody<'a> { - data: &'a [u8], -} - -impl<'a> CopyDataBody<'a> { - #[inline] - pub fn data(&self) -> &'a [u8] { - self.data - } -} - -pub struct CopyInResponseBody<'a> { - format: u8, - len: u16, - buf: &'a [u8], -} - -impl<'a> CopyInResponseBody<'a> { - #[inline] - pub fn format(&self) -> u8 { - self.format - } - - #[inline] - pub fn column_formats(&self) -> ColumnFormats<'a> { - ColumnFormats { - remaining: self.len, - buf: self.buf, - } - } -} - -pub struct ColumnFormats<'a> { - remaining: u16, - buf: &'a [u8], -} - -impl<'a> FallibleIterator for ColumnFormats<'a> { - type Item = u16; - type Error = io::Error; - - #[inline] - fn next(&mut self) -> Result, io::Error> { - if self.remaining == 0 { - check_empty!(self.buf); - return Ok(None); - } - self.remaining -= 1; - self.buf.read_u16::().map(Some).map_err(Into::into) - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - let len = self.remaining as usize; - (len, Some(len)) - } -} - -pub struct CopyOutResponseBody<'a> { - format: u8, - len: u16, - buf: &'a [u8], -} - -impl<'a> CopyOutResponseBody<'a> { - #[inline] - pub fn format(&self) -> u8 { - self.format - } - - #[inline] - pub fn column_formats(&self) -> ColumnFormats<'a> { - ColumnFormats { - remaining: self.len, - buf: self.buf, - } - } -} - -pub struct DataRowBody<'a> { - len: u16, - buf: &'a [u8], -} - -impl<'a> DataRowBody<'a> { - #[inline] - pub fn values(&self) -> DataRowValues<'a> { - DataRowValues { - remaining: self.len, - buf: self.buf, - } - } -} - -pub struct DataRowValues<'a> { - remaining: u16, - buf: &'a [u8], -} - -impl<'a> FallibleIterator for DataRowValues<'a> { - type Item = Option<&'a [u8]>; - type Error = io::Error; - - #[inline] - fn next(&mut self) -> Result>, io::Error> { - if self.remaining == 0 { - check_empty!(self.buf); - return Ok(None); - } - self.remaining -= 1; - - let len = try!(self.buf.read_i32::()); - if len < 0 { - Ok(Some(None)) - } else { - let len = len as usize; - if self.buf.len() < len { - return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF").into()); - } - let (head, tail) = self.buf.split_at(len); - self.buf = tail; - Ok(Some(Some(head))) - } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - let len = self.remaining as usize; - (len, Some(len)) - } -} - -pub struct ErrorResponseBody<'a>(&'a [u8]); - -impl<'a> ErrorResponseBody<'a> { - #[inline] - pub fn fields(&self) -> ErrorFields<'a> { - ErrorFields(self.0) - } -} - -pub struct ErrorFields<'a>(&'a [u8]); - -impl<'a> FallibleIterator for ErrorFields<'a> { - type Item = ErrorField<'a>; - type Error = io::Error; - - #[inline] - fn next(&mut self) -> Result>, io::Error> { - let type_ = try!(self.0.read_u8()); - if type_ == 0 { - check_empty!(self.0); - return Ok(None); - } - - let value = try!(self.0.read_cstr()); - - Ok(Some(ErrorField { - type_: type_, - value: value, - })) - } -} - -pub struct ErrorField<'a> { - type_: u8, - value: &'a str, -} - -impl<'a> ErrorField<'a> { - #[inline] - pub fn type_(&self) -> u8 { - self.type_ - } - - #[inline] - pub fn value(&self) -> &'a str { - self.value - } -} - -pub struct NoticeResponseBody<'a>(&'a [u8]); - -impl<'a> NoticeResponseBody<'a> { - #[inline] - pub fn fields(&self) -> ErrorFields<'a> { - ErrorFields(self.0) - } -} - -pub struct NotificationResponseBody<'a> { - process_id: i32, - channel: &'a str, - message: &'a str, -} - -impl<'a> NotificationResponseBody<'a> { - #[inline] - pub fn process_id(&self) -> i32 { - self.process_id - } - - #[inline] - pub fn channel(&self) -> &'a str { - self.channel - } - - #[inline] - pub fn message(&self) -> &'a str { - self.message - } -} - -pub struct ParameterDescriptionBody<'a> { - len: u16, - buf: &'a [u8], -} - -impl<'a> ParameterDescriptionBody<'a> { - #[inline] - pub fn parameters(&self) -> Parameters<'a> { - Parameters { - remaining: self.len, - buf: self.buf, - } - } -} - -pub struct Parameters<'a> { - remaining: u16, - buf: &'a [u8], -} - -impl<'a> FallibleIterator for Parameters<'a> { - type Item = Oid; - type Error = io::Error; - - #[inline] - fn next(&mut self) -> Result, io::Error> { - if self.remaining == 0 { - check_empty!(self.buf); - return Ok(None); - } - - self.remaining -= 1; - self.buf.read_u32::().map(Some).map_err(Into::into) - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - let len = self.remaining as usize; - (len, Some(len)) - } -} - -pub struct ParameterStatusBody<'a> { - name: &'a str, - value: &'a str, -} - -impl<'a> ParameterStatusBody<'a> { - #[inline] - pub fn name(&self) -> &'a str { - self.name - } - - #[inline] - pub fn value(&self) -> &'a str { - self.value - } -} - -pub struct ReadyForQueryBody<'a> { - status: u8, - _p: PhantomData<&'a [u8]>, -} - -impl<'a> ReadyForQueryBody<'a> { - #[inline] - pub fn status(&self) -> u8 { - self.status - } -} - -pub struct RowDescriptionBody<'a> { - len: u16, - buf: &'a [u8], -} - -impl<'a> RowDescriptionBody<'a> { - #[inline] - pub fn fields(&self) -> Fields<'a> { - Fields { - remaining: self.len, - buf: self.buf, - } - } -} - -pub struct Fields<'a> { - remaining: u16, - buf: &'a [u8], -} - -impl<'a> FallibleIterator for Fields<'a> { - type Item = Field<'a>; - type Error = io::Error; - - #[inline] - fn next(&mut self) -> Result>, io::Error> { - if self.remaining == 0 { - check_empty!(self.buf); - return Ok(None); - } - self.remaining -= 1; - - let name = try!(self.buf.read_cstr()); - let table_oid = try!(self.buf.read_u32::()); - let column_id = try!(self.buf.read_i16::()); - let type_oid = try!(self.buf.read_u32::()); - let type_size = try!(self.buf.read_i16::()); - let type_modifier = try!(self.buf.read_i32::()); - let format = try!(self.buf.read_i16::()); - - Ok(Some(Field { - name: name, - table_oid: table_oid, - column_id: column_id, - type_oid: type_oid, - type_size: type_size, - type_modifier: type_modifier, - format: format, - })) - } -} - -pub struct Field<'a> { - name: &'a str, - table_oid: Oid, - column_id: i16, - type_oid: Oid, - type_size: i16, - type_modifier: i32, - format: i16, -} - -impl<'a> Field<'a> { - #[inline] - pub fn name(&self) -> &'a str { - self.name - } - - #[inline] - pub fn table_oid(&self) -> Oid { - self.table_oid - } - - #[inline] - pub fn column_id(&self) -> i16 { - self.column_id - } - - #[inline] - pub fn type_oid(&self) -> Oid { - self.type_oid - } - - #[inline] - pub fn type_size(&self) -> i16 { - self.type_size - } - - #[inline] - pub fn type_modifier(&self) -> i32 { - self.type_modifier - } - - #[inline] - pub fn format(&self) -> i16 { - self.format - } -} - -trait ReadCStr<'a> { - fn read_cstr(&mut self) -> Result<&'a str, io::Error>; -} - -impl<'a> ReadCStr<'a> for &'a [u8] { - fn read_cstr(&mut self) -> Result<&'a str, io::Error> { - let end = match self.iter().position(|&b| b == 0) { - Some(end) => end, - None => { - return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")); - } - }; - let s = try!(str::from_utf8(&self[..end]) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))); - *self = &self[end + 1..]; - Ok(s) - } -} diff --git a/src/message/backend/mod.rs b/src/message/backend/mod.rs deleted file mode 100644 index cedd344..0000000 --- a/src/message/backend/mod.rs +++ /dev/null @@ -1,98 +0,0 @@ -//! Backend message deserialization. -#![allow(missing_docs)] - -use std::io; - -use Oid; - -pub mod borrowed; - -/// An enum representing Postgres backend messages. -pub enum Message { - AuthenticationCleartextPassword, - AuthenticationGSS, - AuthenticationKerberosV5, - AuthenticationMD5Password { salt: [u8; 4] }, - AuthenticationOk, - AuthenticationSCMCredential, - AuthenticationSSPI, - BackendKeyData { process_id: i32, secret_key: i32 }, - BindComplete, - CloseComplete, - CommandComplete { tag: String }, - CopyData { data: Vec }, - CopyDone, - CopyInResponse { - format: u8, - column_formats: Vec, - }, - CopyOutResponse { - format: u8, - column_formats: Vec, - }, - DataRow { row: Vec>> }, - EmptyQueryResponse, - ErrorResponse { fields: Vec<(u8, String)> }, - NoData, - NoticeResponse { fields: Vec<(u8, String)> }, - NotificationResponse { - process_id: i32, - channel: String, - payload: String, - }, - ParameterDescription { types: Vec }, - ParameterStatus { parameter: String, value: String }, - ParseComplete, - PortalSuspended, - ReadyForQuery { state: u8 }, - RowDescription { descriptions: Vec, }, - #[doc(hidden)] - __ForExtensibility, -} - -impl Message { - /// Attempts to deserialize a backend message from the buffer. - pub fn parse(buf: &[u8]) -> io::Result> { - match borrowed::Message::parse(buf) { - Ok(ParseResult::Complete { message, consumed }) => { - Ok(ParseResult::Complete { - message: try!(message.to_owned()), - consumed: consumed, - }) - } - Ok(ParseResult::Incomplete { required_size }) => { - Ok(ParseResult::Incomplete { required_size: required_size }) - } - Err(e) => Err(e), - } - } -} - -/// The result of an attempted parse. -pub enum ParseResult { - /// A message was successfully parsed. - Complete { - /// The message. - message: T, - /// The number of bytes of the input buffer consumed to parse this message. - consumed: usize, - }, - /// The buffer did not contain a full message. - Incomplete { - /// The number of total bytes required to parse a message, if known. - /// - /// This value is present iff the input buffer contains at least 5 - /// bytes. - required_size: Option, - }, -} - -pub struct RowDescriptionEntry { - pub name: String, - pub table_oid: Oid, - pub column_id: i16, - pub type_oid: Oid, - pub type_size: i16, - pub type_modifier: i32, - pub format: i16, -}