diff --git a/src/types/column/concat.rs b/src/types/column/concat.rs index 6fae8d09..3a774be3 100644 --- a/src/types/column/concat.rs +++ b/src/types/column/concat.rs @@ -2,6 +2,7 @@ use std::iter; use crate::{ binary::Encoder, + errors::{Result, Error, FromSqlError}, types::{SqlType, Value, ValueRef}, }; @@ -64,6 +65,15 @@ impl ColumnData for ConcatColumnData { fn clone_instance(&self) -> BoxColumnData { unimplemented!() } + + unsafe fn get_internal(&self, pointers: &[*mut *const u8], level: u8) -> Result<()> { + if level == 0xff { + *pointers[0] = &self.data as *const Vec as *mut u8; + Ok(()) + } else { + Err(Error::FromSql(FromSqlError::UnsupportedOperation)) + } + } } fn build_index<'a, I>(sizes: I) -> Vec diff --git a/src/types/column/iter.rs b/src/types/column/iter/mod.rs similarity index 86% rename from src/types/column/iter.rs rename to src/types/column/iter/mod.rs index 98612fb5..29156dcf 100644 --- a/src/types/column/iter.rs +++ b/src/types/column/iter/mod.rs @@ -12,13 +12,17 @@ use chrono_tz::Tz; use crate::{ errors::{Error, FromSqlError, Result}, - types::{column::StringPool, decimal::NoBits, Column, Decimal, Simple, SqlType}, + types::{ + column::{column_data::ArcColumnData, StringPool}, + decimal::NoBits, + Column, ColumnType, Complex, Decimal, Simple, SqlType, + }, }; macro_rules! simple_num_iterable { ( $($t:ty: $k:ident),* ) => { $( - impl<'a> SimpleIterable<'a> for $t { + impl<'a> Iterable<'a, Simple> for $t { type Iter = slice::Iter<'a, $t>; fn iter(column: &'a Column, column_type: SqlType) -> Result { @@ -112,10 +116,10 @@ macro_rules! exact_size_iterator { }; } -pub trait SimpleIterable<'a> { - type Iter: Iterator + 'a; +pub trait Iterable<'a, K: ColumnType> { + type Iter: Iterator; - fn iter(column: &'a Column, column_type: SqlType) -> Result; + fn iter(column: &'a Column, column_type: SqlType) -> Result; } enum StringInnerIterator<'a> { @@ -384,8 +388,6 @@ impl<'a> ExactSizeIterator for UuidIterator<'a> { } } - - impl<'a> DateIterator<'a> { #[inline(always)] unsafe fn next_unchecked(&mut self) -> Date { @@ -519,7 +521,7 @@ impl<'a, I: Iterator> Iterator for ArrayIterator<'a, I> { impl<'a, I: Iterator> FusedIterator for ArrayIterator<'a, I> {} -impl<'a> SimpleIterable<'a> for Ipv4Addr { +impl<'a> Iterable<'a, Simple> for Ipv4Addr { type Iter = Ipv4Iterator<'a>; fn iter(column: &'a Column, _column_type: SqlType) -> Result { @@ -540,8 +542,7 @@ impl<'a> SimpleIterable<'a> for Ipv4Addr { } } - -impl<'a> SimpleIterable<'a> for Ipv6Addr { +impl<'a> Iterable<'a, Simple> for Ipv6Addr { type Iter = Ipv6Iterator<'a>; fn iter(column: &'a Column, _column_type: SqlType) -> Result { @@ -562,7 +563,7 @@ impl<'a> SimpleIterable<'a> for Ipv6Addr { } } -impl<'a> SimpleIterable<'a> for uuid::Uuid { +impl<'a> Iterable<'a, Simple> for uuid::Uuid { type Iter = UuidIterator<'a>; fn iter(column: &'a Column, _column_type: SqlType) -> Result { @@ -583,7 +584,7 @@ impl<'a> SimpleIterable<'a> for uuid::Uuid { } } -impl<'a> SimpleIterable<'a> for &[u8] { +impl<'a> Iterable<'a, Simple> for &[u8] { type Iter = StringIterator<'a>; fn iter(column: &'a Column, column_type: SqlType) -> Result { @@ -630,7 +631,7 @@ impl<'a> SimpleIterable<'a> for &[u8] { } } -impl<'a> SimpleIterable<'a> for Decimal { +impl<'a> Iterable<'a, Simple> for Decimal { type Iter = DecimalIterator<'a>; fn iter(column: &'a Column, column_type: SqlType) -> Result { @@ -677,7 +678,7 @@ impl<'a> SimpleIterable<'a> for Decimal { } } -impl<'a> SimpleIterable<'a> for DateTime { +impl<'a> Iterable<'a, Simple> for DateTime { type Iter = DateTimeIterator<'a>; fn iter(column: &'a Column, column_type: SqlType) -> Result { @@ -691,7 +692,7 @@ impl<'a> SimpleIterable<'a> for DateTime { } } -impl<'a> SimpleIterable<'a> for Date { +impl<'a> Iterable<'a, Simple> for Date { type Iter = DateIterator<'a>; fn iter(column: &'a Column, column_type: SqlType) -> Result { @@ -738,9 +739,9 @@ fn date_iter( Ok((ptr, end, *tz)) } -impl<'a, T> SimpleIterable<'a> for Option +impl<'a, T> Iterable<'a, Simple> for Option where - T: SimpleIterable<'a>, + T: Iterable<'a, Simple>, { type Iter = NullableIterator<'a, T::Iter>; @@ -775,9 +776,9 @@ where } } -impl<'a, T> SimpleIterable<'a> for Vec +impl<'a, T> Iterable<'a, Simple> for Vec where - T: SimpleIterable<'a>, + T: Iterable<'a, Simple>, { type Iter = ArrayIterator<'a, T::Iter>; @@ -810,3 +811,104 @@ where }) } } + +pub struct ComplexIterator<'a, T> +where + T: Iterable<'a, Simple>, +{ + column_type: SqlType, + + data: &'a Vec, + + current_index: usize, + current: Option<>::Iter>, + + _marker: marker::PhantomData, +} + +impl<'a, T> Iterator for ComplexIterator<'a, T> +where + T: Iterable<'a, Simple>, +{ + type Item = <>::Iter as Iterator>::Item; + + fn next(&mut self) -> Option { + if self.current_index == self.data.len() && self.current.is_none() { + return None; + } + + if self.current.is_none() { + let column: Column = Column { + name: String::new(), + data: self.data[self.current_index].clone(), + _marker: marker::PhantomData, + }; + + let iter = unsafe { T::iter(mem::transmute(&column), self.column_type) }.unwrap(); + + self.current = Some(iter); + self.current_index += 1; + } + + let ret = match self.current { + None => None, + Some(ref mut iter) => iter.next(), + }; + + match ret { + None => { + self.current = None; + self.next() + } + Some(r) => Some(r), + } + } +} + +impl<'a, T> Iterable<'a, Complex> for T +where + T: Iterable<'a, Simple> + 'a, +{ + type Iter = ComplexIterator<'a, T>; + + fn iter(column: &Column, column_type: SqlType) -> Result { + let data: &Vec = unsafe { + let mut data: *const Vec = ptr::null(); + + column.get_internal( + &[&mut data as *mut *const Vec as *mut *const u8], + 0xff, + )?; + + &*data + }; + + Ok(ComplexIterator { + column_type, + data, + + current_index: 0, + current: None, + + _marker: marker::PhantomData, + }) + } +} + +#[cfg(test)] +mod test { + use crate::types::Block; + + #[test] + fn test_complex_iter() { + let lo = Block::new().column("?", vec![1_u32, 2]); + let hi = Block::new().column("?", vec![3_u32, 4, 5]); + + let block = Block::concat(&[lo, hi]); + + let columns = block.columns()[0].iter::().unwrap(); + let actual: Vec<_> = columns.collect(); + + assert_eq!(actual, vec![&1_u32, &2, &3, &4, &5]) + } +} diff --git a/src/types/column/mod.rs b/src/types/column/mod.rs index c2cc5852..62569cc2 100644 --- a/src/types/column/mod.rs +++ b/src/types/column/mod.rs @@ -11,13 +11,13 @@ use crate::{ binary::{Encoder, ReadEx}, errors::{Error, FromSqlError, Result}, types::{ - column::iter::SimpleIterable, column::{ column_data::ArcColumnData, decimal::{DecimalAdapter, NullableDecimalAdapter}, fixed_string::{FixedStringAdapter, NullableFixedStringAdapter}, ip::{Ipv6, IpColumnData, Ipv4}, string::StringAdapter, + iter::Iterable, }, decimal::NoBits, SqlType, Value, ValueRef, @@ -37,7 +37,7 @@ mod decimal; mod factory; pub(crate) mod fixed_string; mod ip; -mod iter; +pub(crate) mod iter; mod list; mod nullable; mod numeric; @@ -139,7 +139,9 @@ impl Column { } } } +} +impl Column { /// Returns an iterator over the column. /// /// ### Example @@ -154,7 +156,7 @@ impl Column { /// # let pool = Pool::new(database_url); /// # let mut client = pool.get_handle().await?; /// let mut stream = client - /// .query("SELECT number as n1, number as n2, number as n3 FROM numbers(10000000)") + /// .query("SELECT number as n1, number as n2, number as n3 FROM numbers(100000000)") /// .stream_blocks(); /// /// let mut sum = 0; @@ -175,11 +177,8 @@ impl Column { /// # }); /// # ret.unwrap() /// ``` - pub fn iter<'a, T>(&'a self) -> Result - where - T: SimpleIterable<'a>, - { - T::iter(self, self.sql_type()) + pub fn iter<'a, T: Iterable<'a, K>>(&'a self) -> Result { + >::iter(self, self.sql_type()) } }