Skip to content
This repository has been archived by the owner on Jan 11, 2021. It is now read-only.

Add type-safe accessors for primitive types in Row #86

Merged
merged 10 commits into from
Apr 21, 2018
260 changes: 257 additions & 3 deletions src/record/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::fmt;
use basic::{LogicalType, Type as PhysicalType};
use chrono::{Local, TimeZone};
use data_type::{ByteArray, Int96};
use errors::{ParquetError, Result};

/// Macro as a shortcut to generate 'not yet implemented' panic error.
macro_rules! nyi {
Expand All @@ -41,14 +42,78 @@ pub struct Row {
fields: Vec<(String, Field)>
}

impl Row {
    /// Returns the number of fields in this row.
    pub fn len(&self) -> usize {
        self.fields.len()
    }

    /// Returns `true` if this row contains no fields.
    ///
    /// Companion to `len`, so callers can write `row.is_empty()` rather than
    /// `row.len() == 0`.
    pub fn is_empty(&self) -> bool {
        self.fields.is_empty()
    }
}

/// Trait for type-safe convenient access to fields within a Row.
///
/// Every accessor takes the zero-based position `i` of a field and tries to
/// return that field as the requested type. As implemented for `Row` in this
/// file, an error is returned when the field at that position holds a
/// different `Field` variant than the one the accessor expects.
pub trait RowAccessor {
    /// Returns the field at index `i` as a `bool` (`Field::Bool`).
    fn get_bool(&self, i: usize) -> Result<bool>;
    /// Returns the field at index `i` as an `i8` (`Field::Byte`).
    fn get_byte(&self, i: usize) -> Result<i8>;
    /// Returns the field at index `i` as an `i16` (`Field::Short`).
    fn get_short(&self, i: usize) -> Result<i16>;
    /// Returns the field at index `i` as an `i32` (`Field::Int`).
    fn get_int(&self, i: usize) -> Result<i32>;
    /// Returns the field at index `i` as an `i64` (`Field::Long`).
    fn get_long(&self, i: usize) -> Result<i64>;
    /// Returns the field at index `i` as an `f32` (`Field::Float`).
    fn get_float(&self, i: usize) -> Result<f32>;
    /// Returns the field at index `i` as an `f64` (`Field::Double`).
    fn get_double(&self, i: usize) -> Result<f64>;
    /// Returns the raw `u64` stored in the `Field::Timestamp` at index `i`.
    fn get_timestamp(&self, i: usize) -> Result<u64>;
    /// Returns a reference to the string stored at index `i` (`Field::Str`).
    fn get_string(&self, i: usize) -> Result<&String>;
    /// Returns a reference to the byte array stored at index `i` (`Field::Bytes`).
    fn get_bytes(&self, i: usize) -> Result<&ByteArray>;
    /// Returns a reference to the nested row stored at index `i` (`Field::Group`).
    fn get_group(&self, i: usize) -> Result<&Row>;
    /// Returns a reference to the list stored at index `i` (`Field::ListInternal`).
    fn get_list(&self, i: usize) -> Result<&List>;
    /// Returns a reference to the map stored at index `i` (`Field::MapInternal`).
    fn get_map(&self, i: usize) -> Result<&Map>;
}

/// Macro to generate type-safe get_xxx methods for primitive types e.g. get_bool, get_short
///
/// Arguments: the method name to generate, the `Field` variant it accepts, and
/// the (copyable) type returned by value.
///
/// The generated method returns an error, rather than panicking, when `i` is
/// past the end of the row, and also when the field at `i` is a different
/// variant.
macro_rules! row_primitive_accessor {
    ($METHOD:ident, $VARIANT:ident, $TY:ty) => {
        fn $METHOD(&self, i: usize) -> Result<$TY> {
            // Checked lookup: a bad index becomes an `Err` instead of a panic.
            match self.fields.get(i) {
                Some(&(_, Field::$VARIANT(v))) => Ok(v),
                Some(&(_, ref other)) => Err(general_err!("Cannot access {} as {}",
                    other.get_type_name(), stringify!($VARIANT))),
                None => Err(general_err!("Cannot access field {}: row only has {} fields",
                    i, self.fields.len()))
            }
        }
    }
}

/// Macro to generate type-safe get_xxx methods for reference types e.g. get_list, get_map
///
/// Arguments: the method name to generate, the `Field` variant it accepts, and
/// the type a reference to which is returned.
///
/// The generated method returns an error, rather than panicking, when `i` is
/// past the end of the row, and also when the field at `i` is a different
/// variant.
macro_rules! row_complex_accessor {
    ($METHOD:ident, $VARIANT:ident, $TY:ty) => {
        fn $METHOD(&self, i: usize) -> Result<&$TY> {
            // Checked lookup: a bad index becomes an `Err` instead of a panic.
            match self.fields.get(i) {
                Some(&(_, Field::$VARIANT(ref v))) => Ok(v),
                Some(&(_, ref other)) => Err(general_err!("Cannot access {} as {}",
                    other.get_type_name(), stringify!($VARIANT))),
                None => Err(general_err!("Cannot access field {}: row only has {} fields",
                    i, self.fields.len()))
            }
        }
    }
}

/// `RowAccessor` implementation for `Row`.
///
/// Every method is generated by one of the two accessor macros; the macro
/// arguments are (method name, `Field` variant to match, returned type).
impl RowAccessor for Row {
    // Accessors generated by `row_primitive_accessor!` return the matched value itself.
    row_primitive_accessor!(get_bool, Bool, bool);
    row_primitive_accessor!(get_byte, Byte, i8);
    row_primitive_accessor!(get_short, Short, i16);
    row_primitive_accessor!(get_int, Int, i32);
    row_primitive_accessor!(get_long, Long, i64);
    row_primitive_accessor!(get_float, Float, f32);
    row_primitive_accessor!(get_double, Double, f64);
    row_primitive_accessor!(get_timestamp, Timestamp, u64);
    // Accessors generated by `row_complex_accessor!` return a reference into the row.
    row_complex_accessor!(get_string, Str, String);
    row_complex_accessor!(get_bytes, Bytes, ByteArray);
    row_complex_accessor!(get_group, Group, Row);
    // The variant names carry an `Internal` suffix while the returned types do not.
    row_complex_accessor!(get_list, ListInternal, List);
    row_complex_accessor!(get_map, MapInternal, Map);
}

/// Builds a `Row` that owns the given `(name, value)` pairs, keeping them in
/// the order supplied.
#[inline]
pub fn make_row(fields: Vec<(String, Field)>) -> Row {
    Row { fields }
}

// TODO: implement `getXXX` for the remaining `Field` variants (e.g. `Date`)

impl fmt::Display for Row {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{{")?;
Expand All @@ -64,13 +129,19 @@ impl fmt::Display for Row {
}
}


/// `List` represents a list which contains an array of elements.
///
/// Each element is a `Field`, so a single list may hold different variants
/// (for example integers mixed with `Field::Null`).
#[derive(Clone, Debug, PartialEq)]
pub struct List {
// Elements of the list, in order.
elements: Vec<Field>
}

impl List {
    /// Returns the number of elements in this list.
    pub fn len(&self) -> usize {
        self.elements.len()
    }

    /// Returns `true` if this list contains no elements.
    ///
    /// Companion to `len`, so callers can write `list.is_empty()` rather than
    /// `list.len() == 0`.
    pub fn is_empty(&self) -> bool {
        self.elements.is_empty()
    }
}

/// Constructs a `List` from the list of `fields` and returns it.
#[inline]
pub fn make_list(elements: Vec<Field>) -> List {
Expand All @@ -86,6 +157,13 @@ pub struct Map {
entries: Vec<(Field, Field)>
}

impl Map {
    /// Returns the number of key-value entries in this map.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if this map contains no entries.
    ///
    /// Companion to `len`, so callers can write `map.is_empty()` rather than
    /// `map.len() == 0`.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}

/// Constructs a `Map` from the list of `entries` and returns it.
#[inline]
pub fn make_map(entries: Vec<(Field, Field)>) -> Map {
Expand Down Expand Up @@ -137,7 +215,39 @@ pub enum Field {
MapInternal(Map)
}


impl Field {
/// Returns the name of this field's variant as a static string.
///
/// The returned text is spelled exactly like the variant identifier; the row
/// accessors use it when building their "Cannot access ... as ..." errors.
fn get_type_name(&self) -> &'static str {
    match *self {
        Field::Null => "Null",
        Field::Bool(..) => "Bool",
        Field::Byte(..) => "Byte",
        Field::Short(..) => "Short",
        Field::Int(..) => "Int",
        Field::Long(..) => "Long",
        Field::Float(..) => "Float",
        Field::Double(..) => "Double",
        Field::Date(..) => "Date",
        Field::Str(..) => "Str",
        Field::Bytes(..) => "Bytes",
        Field::Timestamp(..) => "Timestamp",
        Field::Group(..) => "Group",
        Field::ListInternal(..) => "ListInternal",
        Field::MapInternal(..) => "MapInternal",
    }
}

/// Returns `true` if this field holds a primitive value.
///
/// `Group`, `ListInternal` and `MapInternal` are the nested variants and
/// yield `false`; every other variant, including `Null`, yields `true`.
pub fn is_primitive(&self) -> bool {
    match *self {
        Field::Group(..) | Field::ListInternal(..) | Field::MapInternal(..) => false,
        _ => true,
    }
}

/// Converts Parquet BOOLEAN type with logical type into `bool` value.
pub fn convert_bool(
_physical_type: PhysicalType,
Expand Down Expand Up @@ -482,4 +592,148 @@ mod tests {
]));
assert_eq!(format!("{}", row), "{1 -> 1.2, 2 -> 4.5, 3 -> 2.3}");
}

#[test]
fn test_is_primitive() {
    // Every non-nested variant must report itself as primitive.
    let primitives = vec![
        Field::Null,
        Field::Bool(true),
        Field::Bool(false),
        Field::Byte(1),
        Field::Short(2),
        Field::Int(3),
        Field::Long(4),
        Field::Float(5.0),
        Field::Float(5.1234),
        Field::Double(6.0),
        Field::Double(6.1234),
        Field::Str("abc".to_string()),
        Field::Bytes(ByteArray::from(vec![1, 2, 3])),
        Field::Timestamp(12345678),
        Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])),
    ];
    for field in &primitives {
        assert!(field.is_primitive());
    }

    // Nested variants must not.
    let group = Field::Group(make_row(vec![
        ("x".to_string(), Field::Null),
        ("Y".to_string(), Field::Int(2)),
        ("z".to_string(), Field::Float(3.1)),
        ("a".to_string(), Field::Str("abc".to_string())),
    ]));
    assert!(!group.is_primitive());

    let list = Field::ListInternal(make_list(vec![
        Field::Int(2),
        Field::Int(1),
        Field::Null,
        Field::Int(12),
    ]));
    assert!(!list.is_primitive());

    let map = Field::MapInternal(make_map(vec![
        (Field::Int(1), Field::Float(1.2)),
        (Field::Int(2), Field::Float(4.5)),
        (Field::Int(3), Field::Float(2.3)),
    ]));
    assert!(!map.is_primitive());
}

#[test]
fn test_row_primitive_accessors() {
    // One field per primitive variant; index 0 (Null) has no accessor.
    let fields = vec![
        ("a".to_string(), Field::Null),
        ("b".to_string(), Field::Bool(false)),
        ("c".to_string(), Field::Byte(3)),
        ("d".to_string(), Field::Short(4)),
        ("e".to_string(), Field::Int(5)),
        ("f".to_string(), Field::Long(6)),
        ("g".to_string(), Field::Float(7.1)),
        ("h".to_string(), Field::Double(8.1)),
        ("i".to_string(), Field::Str("abc".to_string())),
        ("j".to_string(), Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5]))),
    ];
    let row = make_row(fields);

    assert!(!row.get_bool(1).unwrap());
    assert_eq!(row.get_byte(2).unwrap(), 3);
    assert_eq!(row.get_short(3).unwrap(), 4);
    assert_eq!(row.get_int(4).unwrap(), 5);
    assert_eq!(row.get_long(5).unwrap(), 6);
    assert_eq!(row.get_float(6).unwrap(), 7.1);
    assert_eq!(row.get_double(7).unwrap(), 8.1);
    assert_eq!(row.get_string(8).unwrap().as_str(), "abc");
    assert_eq!(row.get_bytes(9).unwrap().len(), 5);
}

#[test]
fn test_row_primitive_invalid_accessors() {
    // A row made only of primitive fields: none of them may be read as a group.
    let fields = vec![
        ("a".to_string(), Field::Null),
        ("b".to_string(), Field::Bool(false)),
        ("c".to_string(), Field::Byte(3)),
        ("d".to_string(), Field::Short(4)),
        ("e".to_string(), Field::Int(5)),
        ("f".to_string(), Field::Long(6)),
        ("g".to_string(), Field::Float(7.1)),
        ("h".to_string(), Field::Double(8.1)),
        ("i".to_string(), Field::Str("abc".to_string())),
        ("j".to_string(), Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5]))),
    ];
    let row = make_row(fields);

    // The row holds ten fields, so this visits indices 0 through 9.
    let field_count = row.len();
    let mut index = 0;
    while index < field_count {
        assert!(row.get_group(index).is_err());
        index += 1;
    }
}

#[test]
fn test_row_complex_accessors() {
    // Build each nested value separately, then assemble the row.
    let group = make_row(vec![
        ("x".to_string(), Field::Null),
        ("Y".to_string(), Field::Int(2)),
    ]);
    let list = make_list(vec![
        Field::Int(2),
        Field::Int(1),
        Field::Null,
        Field::Int(12),
    ]);
    let map = make_map(vec![
        (Field::Int(1), Field::Float(1.2)),
        (Field::Int(2), Field::Float(4.5)),
        (Field::Int(3), Field::Float(2.3)),
    ]);
    let row = make_row(vec![
        ("a".to_string(), Field::Group(group)),
        ("b".to_string(), Field::ListInternal(list)),
        ("c".to_string(), Field::MapInternal(map)),
    ]);

    let group_len = row.get_group(0).unwrap().len();
    let list_len = row.get_list(1).unwrap().len();
    let map_len = row.get_map(2).unwrap().len();

    assert_eq!(group_len, 2);
    assert_eq!(list_len, 4);
    assert_eq!(map_len, 3);
}

#[test]
fn test_row_complex_invalid_accessors() {
    // Build each nested value separately, then assemble the row.
    let group = make_row(vec![
        ("x".to_string(), Field::Null),
        ("Y".to_string(), Field::Int(2)),
    ]);
    let list = make_list(vec![
        Field::Int(2),
        Field::Int(1),
        Field::Null,
        Field::Int(12),
    ]);
    let map = make_map(vec![
        (Field::Int(1), Field::Float(1.2)),
        (Field::Int(2), Field::Float(4.5)),
        (Field::Int(3), Field::Float(2.3)),
    ]);
    let row = make_row(vec![
        ("a".to_string(), Field::Group(group)),
        ("b".to_string(), Field::ListInternal(list)),
        ("c".to_string(), Field::MapInternal(map)),
    ]);

    // Reading a nested field as a float must fail and name the actual variant.
    let expected_messages = [
        "Cannot access Group as Float",
        "Cannot access ListInternal as Float",
        "Cannot access MapInternal as Float",
    ];
    for (index, message) in expected_messages.iter().enumerate() {
        assert_eq!(ParquetError::General(message.to_string()),
            row.get_float(index).unwrap_err());
    }
}
}