diff --git a/postgres-protocol/src/types/mod.rs b/postgres-protocol/src/types/mod.rs index 621c01cc2..a5999e34f 100644 --- a/postgres-protocol/src/types/mod.rs +++ b/postgres-protocol/src/types/mod.rs @@ -1045,3 +1045,79 @@ impl Inet { self.netmask } } + +/// A fallible iterator over the fields of a composite type. +pub struct CompositeTypeRanges<'a> { + buf: &'a [u8], + len: usize, + remaining: u16, +} + +impl<'a> CompositeTypeRanges<'a> { + /// Returns a fallible iterator over the fields of the composite type. + #[inline] + pub fn new(buf: &'a [u8], len: usize, remaining: u16) -> CompositeTypeRanges<'a> { + CompositeTypeRanges { + buf, + len, + remaining, + } + } +} + +impl<'a> FallibleIterator for CompositeTypeRanges<'a> { + type Item = Option>; + type Error = std::io::Error; + + #[inline] + fn next(&mut self) -> std::io::Result>>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid buffer length: compositetyperanges is not empty", + )); + } + } + + self.remaining -= 1; + + // Binary format of a composite type: + // [for each field] + // + // [if value is NULL] + // <-1: 4 bytes> + // [else] + // + // bytes> + // [end if] + // [end for] + // https://www.postgresql.org/message-id/16CCB2D3-197E-4D9F-BC6F-9B123EA0D40D%40phlo.org + // https://github.com/postgres/postgres/blob/29e321cdd63ea48fd0223447d58f4742ad729eb0/src/backend/utils/adt/rowtypes.c#L736 + + let _oid = self.buf.read_i32::()?; + let len = self.buf.read_i32::()?; + if len < 0 { + Ok(Some(None)) + } else { + let len = len as usize; + if self.buf.len() < len { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )); + } + let base = self.len - self.buf.len(); + self.buf = &self.buf[len as usize..]; + Ok(Some(Some(base..base + len))) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 0d8aa8436..51b40a8da 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -110,7 +110,7 @@ pub use crate::error::Error; pub use crate::generic_client::GenericClient; pub use crate::portal::Portal; pub use crate::query::RowStream; -pub use crate::row::{Row, SimpleQueryRow}; +pub use crate::row::{CompositeType, Row, SimpleQueryRow}; pub use crate::simple_query::SimpleQueryStream; #[cfg(feature = "runtime")] pub use crate::socket::Socket; diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index 03c7635b2..9bd4b06bb 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -2,10 +2,12 @@ use crate::row::sealed::{AsName, Sealed}; use crate::statement::Column; -use crate::types::{FromSql, Type, WrongType}; +use crate::types::{Field, FromSql, Kind, Type, WrongType}; use crate::{Error, Statement}; +use byteorder::{BigEndian, ByteOrder}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_protocol::types::CompositeTypeRanges; use std::fmt; use std::ops::Range; use std::str; @@ -31,6 +33,12 @@ impl AsName for String { } } +impl AsName for Field { + fn as_name(&self) -> &str { + self.name() + } +} + /// A trait implemented by types that can index into columns of a row. /// /// This cannot be implemented outside of this crate. @@ -175,6 +183,125 @@ impl Row { } } +/// A PostgreSQL composite type. +/// Fields of a type can be accessed using `CompositeType::get` and `CompositeType::try_get` methods. +pub struct CompositeType<'a> { + type_: Type, + body: &'a [u8], + ranges: Vec>>, +} + +impl<'a> FromSql<'a> for CompositeType<'a> { + fn from_sql( + type_: &Type, + body: &'a [u8], + ) -> Result, Box> { + match *type_.kind() { + Kind::Composite(_) => { + let fields: &[Field] = composite_type_fields(&type_); + if body.len() < 4 { + let message = format!("invalid composite type body length: {}", body.len()); + return Err(message.into()); + } + let num_fields: i32 = BigEndian::read_i32(&body[0..4]); + if num_fields as usize != fields.len() { + let message = + format!("invalid field count: {} vs {}", num_fields, fields.len()); + return Err(message.into()); + } + let ranges = CompositeTypeRanges::new(&body[4..], body.len(), num_fields as u16) + .collect() + .map_err(Error::parse)?; + Ok(CompositeType { + type_: type_.clone(), + body, + ranges, + }) + } + _ => Err(format!("expected composite type, got {}", type_).into()), + } + } + fn accepts(ty: &Type) -> bool { + match *ty.kind() { + Kind::Composite(_) => true, + _ => false, + } + } +} + +fn composite_type_fields(type_: &Type) -> &[Field] { + match type_.kind() { + Kind::Composite(ref fields) => fields, + _ => unreachable!(), + } +} + +impl<'a> CompositeType<'a> { + /// Returns information about the fields of the composite type. + pub fn fields(&self) -> &[Field] { + composite_type_fields(&self.type_) + } + + /// Determines if the composite contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of fields of the composite type. + pub fn len(&self) -> usize { + self.fields().len() + } + + /// Deserializes a value from the composite type. + /// + /// The value can be specified either by its numeric index, or by its field name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + pub fn get<'b, I, T>(&'b self, idx: I) -> T + where + I: RowIndex + fmt::Display, + T: FromSql<'b>, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `CompositeType::get`, but returns a `Result` rather than panicking. + pub fn try_get<'b, I, T>(&'b self, idx: I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'b>, + { + self.get_inner(&idx) + } + + fn get_inner<'b, I, T>(&'b self, idx: &I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'b>, + { + let idx = match idx.__idx(self.fields()) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let ty = self.fields()[idx].type_(); + if !T::accepts(ty) { + return Err(Error::from_sql( + Box::new(WrongType::new::(ty.clone())), + idx, + )); + } + + let buf = self.ranges[idx].clone().map(|r| &self.body[r]); + FromSql::from_sql_nullable(ty, buf).map_err(|e| Error::from_sql(e, idx)) + } +} + /// A row of data returned from the database by a simple query. pub struct SimpleQueryRow { columns: Arc<[String]>, diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 92f1edce6..b9c3e7653 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -13,7 +13,8 @@ use tokio_postgres::error::SqlState; use tokio_postgres::tls::{NoTls, NoTlsStream}; use tokio_postgres::types::{Kind, Type}; use tokio_postgres::{ - AsyncMessage, Client, Config, Connection, Error, IsolationLevel, SimpleQueryMessage, + AsyncMessage, Client, CompositeType, Config, Connection, Error, IsolationLevel, + SimpleQueryMessage, }; mod binary_copy; @@ -762,3 +763,55 @@ async fn query_opt() { .err() .unwrap(); } + +#[tokio::test] +async fn composite_type() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TYPE pg_temp.message AS ( + id INTEGER, + content TEXT, + link TEXT + ); + CREATE TYPE pg_temp.person AS ( + id INTEGER, + name TEXT, + messages message[], + email TEXT + ); + + ", + ) + .await + .unwrap(); + + let row = client + .query_one( + "select (123,'alice',ARRAY[(1,'message1',NULL)::message,(2,'message2',NULL)::message],NULL)::person", + &[], + ) + .await + .unwrap(); + + let person: CompositeType<'_> = row.get(0); + + assert_eq!(person.get::<_, Option>("id"), Some(123)); + assert_eq!(person.get::<_, Option<&str>>("name"), Some("alice")); + assert_eq!(person.get::<_, Option<&str>>("email"), None); + + let messages: Vec> = person.get("messages"); + + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].get::<_, Option>("id"), Some(1)); + assert_eq!( + messages[0].get::<_, Option<&str>>("content"), + Some("message1") + ); + assert_eq!(messages[0].get::<_, Option<&str>>("link"), None); + assert_eq!(messages[1].get::<_, Option>(0), Some(2)); + assert_eq!(messages[1].get::<_, Option<&str>>(1), Some("message2")); + assert_eq!(messages[1].get::<_, Option<&str>>(2), None); +}