Skip to content

Commit

Permalink
iter for fetch_all
Browse files Browse the repository at this point in the history
(cherry picked from commit 8a59465)
  • Loading branch information
suharev7 committed Feb 11, 2020
1 parent a4f66a6 commit cf08297
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 27 deletions.
10 changes: 10 additions & 0 deletions src/types/column/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::iter;

use crate::{
binary::Encoder,
errors::{Result, Error, FromSqlError},
types::{SqlType, Value, ValueRef},
};

Expand Down Expand Up @@ -64,6 +65,15 @@ impl ColumnData for ConcatColumnData {
fn clone_instance(&self) -> BoxColumnData {
unimplemented!()
}

unsafe fn get_internal(&self, pointers: &[*mut *const u8], level: u8) -> Result<()> {
if level == 0xff {
*pointers[0] = &self.data as *const Vec<ArcColumnData> as *mut u8;
Ok(())
} else {
Err(Error::FromSql(FromSqlError::UnsupportedOperation))
}
}
}

fn build_index<'a, I>(sizes: I) -> Vec<usize>
Expand Down
140 changes: 121 additions & 19 deletions src/types/column/iter.rs → src/types/column/iter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ use chrono_tz::Tz;

use crate::{
errors::{Error, FromSqlError, Result},
types::{column::StringPool, decimal::NoBits, Column, Decimal, Simple, SqlType},
types::{
column::{column_data::ArcColumnData, StringPool},
decimal::NoBits,
Column, ColumnType, Complex, Decimal, Simple, SqlType,
},
};

macro_rules! simple_num_iterable {
( $($t:ty: $k:ident),* ) => {
$(
impl<'a> SimpleIterable<'a> for $t {
impl<'a> Iterable<'a, Simple> for $t {
type Iter = slice::Iter<'a, $t>;

fn iter(column: &'a Column<Simple>, column_type: SqlType) -> Result<Self::Iter> {
Expand Down Expand Up @@ -112,10 +116,10 @@ macro_rules! exact_size_iterator {
};
}

pub trait SimpleIterable<'a> {
type Iter: Iterator + 'a;
pub trait Iterable<'a, K: ColumnType> {
type Iter: Iterator;

fn iter(column: &'a Column<Simple>, column_type: SqlType) -> Result<Self::Iter>;
fn iter(column: &'a Column<K>, column_type: SqlType) -> Result<Self::Iter>;
}

enum StringInnerIterator<'a> {
Expand Down Expand Up @@ -384,8 +388,6 @@ impl<'a> ExactSizeIterator for UuidIterator<'a> {
}
}



impl<'a> DateIterator<'a> {
#[inline(always)]
unsafe fn next_unchecked(&mut self) -> Date<Tz> {
Expand Down Expand Up @@ -519,7 +521,7 @@ impl<'a, I: Iterator> Iterator for ArrayIterator<'a, I> {

impl<'a, I: Iterator> FusedIterator for ArrayIterator<'a, I> {}

impl<'a> SimpleIterable<'a> for Ipv4Addr {
impl<'a> Iterable<'a, Simple> for Ipv4Addr {
type Iter = Ipv4Iterator<'a>;

fn iter(column: &'a Column<Simple>, _column_type: SqlType) -> Result<Self::Iter> {
Expand All @@ -540,8 +542,7 @@ impl<'a> SimpleIterable<'a> for Ipv4Addr {
}
}


impl<'a> SimpleIterable<'a> for Ipv6Addr {
impl<'a> Iterable<'a, Simple> for Ipv6Addr {
type Iter = Ipv6Iterator<'a>;

fn iter(column: &'a Column<Simple>, _column_type: SqlType) -> Result<Self::Iter> {
Expand All @@ -562,7 +563,7 @@ impl<'a> SimpleIterable<'a> for Ipv6Addr {
}
}

impl<'a> SimpleIterable<'a> for uuid::Uuid {
impl<'a> Iterable<'a, Simple> for uuid::Uuid {
type Iter = UuidIterator<'a>;

fn iter(column: &'a Column<Simple>, _column_type: SqlType) -> Result<Self::Iter> {
Expand All @@ -583,7 +584,7 @@ impl<'a> SimpleIterable<'a> for uuid::Uuid {
}
}

impl<'a> SimpleIterable<'a> for &[u8] {
impl<'a> Iterable<'a, Simple> for &[u8] {
type Iter = StringIterator<'a>;

fn iter(column: &'a Column<Simple>, column_type: SqlType) -> Result<Self::Iter> {
Expand Down Expand Up @@ -630,7 +631,7 @@ impl<'a> SimpleIterable<'a> for &[u8] {
}
}

impl<'a> SimpleIterable<'a> for Decimal {
impl<'a> Iterable<'a, Simple> for Decimal {
type Iter = DecimalIterator<'a>;

fn iter(column: &'a Column<Simple>, column_type: SqlType) -> Result<Self::Iter> {
Expand Down Expand Up @@ -677,7 +678,7 @@ impl<'a> SimpleIterable<'a> for Decimal {
}
}

impl<'a> SimpleIterable<'a> for DateTime<Tz> {
impl<'a> Iterable<'a, Simple> for DateTime<Tz> {
type Iter = DateTimeIterator<'a>;

fn iter(column: &'a Column<Simple>, column_type: SqlType) -> Result<Self::Iter> {
Expand All @@ -691,7 +692,7 @@ impl<'a> SimpleIterable<'a> for DateTime<Tz> {
}
}

impl<'a> SimpleIterable<'a> for Date<Tz> {
impl<'a> Iterable<'a, Simple> for Date<Tz> {
type Iter = DateIterator<'a>;

fn iter(column: &'a Column<Simple>, column_type: SqlType) -> Result<Self::Iter> {
Expand Down Expand Up @@ -738,9 +739,9 @@ fn date_iter<T>(
Ok((ptr, end, *tz))
}

impl<'a, T> SimpleIterable<'a> for Option<T>
impl<'a, T> Iterable<'a, Simple> for Option<T>
where
T: SimpleIterable<'a>,
T: Iterable<'a, Simple>,
{
type Iter = NullableIterator<'a, T::Iter>;

Expand Down Expand Up @@ -775,9 +776,9 @@ where
}
}

impl<'a, T> SimpleIterable<'a> for Vec<T>
impl<'a, T> Iterable<'a, Simple> for Vec<T>
where
T: SimpleIterable<'a>,
T: Iterable<'a, Simple>,
{
type Iter = ArrayIterator<'a, T::Iter>;

Expand Down Expand Up @@ -810,3 +811,104 @@ where
})
}
}

pub struct ComplexIterator<'a, T>
where
T: Iterable<'a, Simple>,
{
column_type: SqlType,

data: &'a Vec<ArcColumnData>,

current_index: usize,
current: Option<<T as Iterable<'a, Simple>>::Iter>,

_marker: marker::PhantomData<T>,
}

impl<'a, T> Iterator for ComplexIterator<'a, T>
where
T: Iterable<'a, Simple>,
{
type Item = <<T as Iterable<'a, Simple>>::Iter as Iterator>::Item;

fn next(&mut self) -> Option<Self::Item> {
if self.current_index == self.data.len() && self.current.is_none() {
return None;
}

if self.current.is_none() {
let column: Column<Simple> = Column {
name: String::new(),
data: self.data[self.current_index].clone(),
_marker: marker::PhantomData,
};

let iter = unsafe { T::iter(mem::transmute(&column), self.column_type) }.unwrap();

self.current = Some(iter);
self.current_index += 1;
}

let ret = match self.current {
None => None,
Some(ref mut iter) => iter.next(),
};

match ret {
None => {
self.current = None;
self.next()
}
Some(r) => Some(r),
}
}
}

impl<'a, T> Iterable<'a, Complex> for T
where
T: Iterable<'a, Simple> + 'a,
{
type Iter = ComplexIterator<'a, T>;

fn iter(column: &Column<Complex>, column_type: SqlType) -> Result<Self::Iter> {
let data: &Vec<ArcColumnData> = unsafe {
let mut data: *const Vec<ArcColumnData> = ptr::null();

column.get_internal(
&[&mut data as *mut *const Vec<ArcColumnData> as *mut *const u8],
0xff,
)?;

&*data
};

Ok(ComplexIterator {
column_type,
data,

current_index: 0,
current: None,

_marker: marker::PhantomData,
})
}
}

#[cfg(test)]
mod test {
use crate::types::Block;

#[test]
fn test_complex_iter() {
let lo = Block::new().column("?", vec![1_u32, 2]);
let hi = Block::new().column("?", vec![3_u32, 4, 5]);

let block = Block::concat(&[lo, hi]);

let columns = block.columns()[0].iter::<u32>().unwrap();
let actual: Vec<_> = columns.collect();

assert_eq!(actual, vec![&1_u32, &2, &3, &4, &5])
}
}
15 changes: 7 additions & 8 deletions src/types/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use crate::{
binary::{Encoder, ReadEx},
errors::{Error, FromSqlError, Result},
types::{
column::iter::SimpleIterable,
column::{
column_data::ArcColumnData,
decimal::{DecimalAdapter, NullableDecimalAdapter},
fixed_string::{FixedStringAdapter, NullableFixedStringAdapter},
ip::{Ipv6, IpColumnData, Ipv4},
string::StringAdapter,
iter::Iterable,
},
decimal::NoBits,
SqlType, Value, ValueRef,
Expand All @@ -37,7 +37,7 @@ mod decimal;
mod factory;
pub(crate) mod fixed_string;
mod ip;
mod iter;
pub(crate) mod iter;
mod list;
mod nullable;
mod numeric;
Expand Down Expand Up @@ -139,7 +139,9 @@ impl Column<Simple> {
}
}
}
}

impl<K: ColumnType> Column<K> {
/// Returns an iterator over the column.
///
/// ### Example
Expand All @@ -154,7 +156,7 @@ impl Column<Simple> {
/// # let pool = Pool::new(database_url);
/// # let mut client = pool.get_handle().await?;
/// let mut stream = client
/// .query("SELECT number as n1, number as n2, number as n3 FROM numbers(10000000)")
/// .query("SELECT number as n1, number as n2, number as n3 FROM numbers(100000000)")
/// .stream_blocks();
///
/// let mut sum = 0;
Expand All @@ -175,11 +177,8 @@ impl Column<Simple> {
/// # });
/// # ret.unwrap()
/// ```
pub fn iter<'a, T>(&'a self) -> Result<T::Iter>
where
T: SimpleIterable<'a>,
{
T::iter(self, self.sql_type())
pub fn iter<'a, T: Iterable<'a, K>>(&'a self) -> Result<T::Iter> {
<T as Iterable<'a, K>>::iter(self, self.sql_type())
}
}

Expand Down

0 comments on commit cf08297

Please sign in to comment.