Skip to content
This repository has been archived by the owner on Jan 11, 2021. It is now read-only.

Add type-safe accessors for primitive types in Row #86

Merged
merged 10 commits into from
Apr 21, 2018
260 changes: 257 additions & 3 deletions src/record/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::fmt;
use basic::{LogicalType, Type as PhysicalType};
use chrono::{Local, TimeZone};
use data_type::{ByteArray, Int96};
use errors::{ParquetError, Result};

/// Macro as a shortcut to generate 'not yet implemented' panic error.
macro_rules! nyi {
Expand All @@ -41,14 +42,78 @@ pub struct Row {
fields: Vec<(String, Field)>
}

impl Row {
/// Get the number of fields in this row
pub fn len(&self) -> usize {
self.fields.len()
}
}

/// Trait for type-safe convenient access to fields within a Row
trait RowAccessor {
Copy link
Owner

@sunchao sunchao Apr 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be public and be exported in mod.rs, otherwise these methods will not be available.

fn get_bool(&self, i: usize) -> Result<bool>;
fn get_byte(&self, i: usize) -> Result<i8>;
fn get_short(&self, i: usize) -> Result<i16>;
fn get_int(&self, i: usize) -> Result<i32>;
fn get_long(&self, i: usize) -> Result<i64>;
fn get_float(&self, i: usize) -> Result<f32>;
fn get_double(&self, i: usize) -> Result<f64>;
fn get_timestamp(&self, i: usize) -> Result<u64>;
fn get_string(&self, i: usize) -> Result<&String>;
fn get_bytes(&self, i: usize) -> Result<&ByteArray>;
fn get_group(&self, i: usize) -> Result<&Row>;
fn get_list(&self, i: usize) -> Result<&List>;
fn get_map(&self, i: usize) -> Result<&Map>;
}

/// Macro to generate type-safe get_xxx methods for primitive types e.g. get_bool, get_short
macro_rules! row_primitive_accessor {
($METHOD:ident, $VARIANT:ident, $TY:ty) => {
fn $METHOD(&self, i: usize) -> Result<$TY> {
match self.fields[i].1 {
Field::$VARIANT(v) => Ok(v),
_ => Err(general_err!("Cannot access {} as {}",
self.fields[i].1.get_type_name(), stringify!($VARIANT)))
}
}
}
}

/// Macro to generate type-safe get_xxx methods for reference types e.g. get_list, get_map
macro_rules! row_complex_accessor {
($METHOD:ident, $VARIANT:ident, $TY:ty) => {
fn $METHOD(&self, i: usize) -> Result<&$TY> {
match self.fields[i].1 {
Field::$VARIANT(ref v) => Ok(v),
_ => Err(general_err!("Cannot access {} as {}",
self.fields[i].1.get_type_name(), stringify!($VARIANT)))
}
}
}
}

impl RowAccessor for Row {
row_primitive_accessor!(get_bool, Bool, bool);
row_primitive_accessor!(get_byte, Byte, i8);
row_primitive_accessor!(get_short, Short, i16);
row_primitive_accessor!(get_int, Int, i32);
row_primitive_accessor!(get_long, Long, i64);
row_primitive_accessor!(get_float, Float, f32);
row_primitive_accessor!(get_double, Double, f64);
row_primitive_accessor!(get_timestamp, Timestamp, u64);
row_complex_accessor!(get_string, Str, String);
row_complex_accessor!(get_bytes, Bytes, ByteArray);
row_complex_accessor!(get_group, Group, Row);
row_complex_accessor!(get_list, ListInternal, List);
row_complex_accessor!(get_map, MapInternal, Map);
}

/// Constructs a `Row` from the list of `fields` and returns it.
#[inline]
pub fn make_row(fields: Vec<(String, Field)>) -> Row {
Row { fields: fields }
}

// TODO: implement `getXXX` for different `Field`s

impl fmt::Display for Row {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{{")?;
Expand All @@ -64,13 +129,19 @@ impl fmt::Display for Row {
}
}


/// `List` represents a list which contains an array of elements.
#[derive(Clone, Debug, PartialEq)]
pub struct List {
elements: Vec<Field>
}

impl List {
/// Get the number of fields in this row
pub fn len(&self) -> usize {
self.elements.len()
}
}

/// Constructs a `List` from the list of `fields` and returns it.
#[inline]
pub fn make_list(elements: Vec<Field>) -> List {
Expand All @@ -86,6 +157,13 @@ pub struct Map {
entries: Vec<(Field, Field)>
}

impl Map {
/// Get the number of fields in this row
pub fn len(&self) -> usize {
self.entries.len()
}
}

/// Constructs a `Map` from the list of `entries` and returns it.
#[inline]
pub fn make_map(entries: Vec<(Field, Field)>) -> Map {
Expand Down Expand Up @@ -137,7 +215,39 @@ pub enum Field {
MapInternal(Map)
}


impl Field {
/// Get the type name
fn get_type_name(&self) -> &'static str {
match *self {
Field::Null => "Null",
Field::Bool(_) => "Bool",
Field::Byte(_) => "Byte",
Field::Short(_) => "Short",
Field::Int(_) => "Int",
Field::Long(_) => "Long",
Field::Float(_) => "Float",
Field::Double(_) => "Double",
Field::Date(_) => "Date",
Field::Str(_) => "Str",
Field::Bytes(_) => "Bytes",
Field::Timestamp(_) => "Timestamp",
Field::Group(_) => "Group",
Field::ListInternal(_) => "ListInternal",
Field::MapInternal(_) => "MapInternal",
}
}

/// Determines if this Row represents a primitive value
pub fn is_primitive(&self) -> bool {
match *self {
Field::Group(_) => false,
Field::ListInternal(_) => false,
Field::MapInternal(_) => false,
_ => true
}
}

/// Converts Parquet BOOLEAN type with logical type into `bool` value.
pub fn convert_bool(
_physical_type: PhysicalType,
Expand Down Expand Up @@ -482,4 +592,148 @@ mod tests {
]));
assert_eq!(format!("{}", row), "{1 -> 1.2, 2 -> 4.5, 3 -> 2.3}");
}

#[test]
fn test_is_primitive() {
// primitives
assert!(Field::Null.is_primitive());
assert!(Field::Bool(true).is_primitive());
assert!(Field::Bool(false).is_primitive());
assert!(Field::Byte(1).is_primitive());
assert!(Field::Short(2).is_primitive());
assert!(Field::Int(3).is_primitive());
assert!(Field::Long(4).is_primitive());
assert!(Field::Float(5.0).is_primitive());
assert!(Field::Float(5.1234).is_primitive());
assert!(Field::Double(6.0).is_primitive());
assert!(Field::Double(6.1234).is_primitive());
assert!(Field::Str("abc".to_string()).is_primitive());
assert!(Field::Bytes(ByteArray::from(vec![1, 2, 3])).is_primitive());
assert!(Field::Timestamp(12345678).is_primitive());

let value = ByteArray::from(vec![1, 2, 3, 4, 5]);
assert!(Field::Bytes(value).is_primitive());

// complex types
assert_eq!(false, Field::Group(make_row(vec![
("x".to_string(), Field::Null),
("Y".to_string(), Field::Int(2)),
("z".to_string(), Field::Float(3.1)),
("a".to_string(), Field::Str("abc".to_string()))
])).is_primitive());

assert_eq!(false, Field::ListInternal(make_list(vec![
Field::Int(2),
Field::Int(1),
Field::Null,
Field::Int(12)
])).is_primitive());

assert_eq!(false, Field::MapInternal(make_map(vec![
(Field::Int(1), Field::Float(1.2)),
(Field::Int(2), Field::Float(4.5)),
(Field::Int(3), Field::Float(2.3))
])).is_primitive());
}

#[test]
fn test_row_primitive_accessors() {
// primitives
let row = make_row(vec![
("a".to_string(), Field::Null),
("b".to_string(), Field::Bool(false)),
("c".to_string(), Field::Byte(3)),
("d".to_string(), Field::Short(4)),
("e".to_string(), Field::Int(5)),
("f".to_string(), Field::Long(6)),
("g".to_string(), Field::Float(7.1)),
("h".to_string(), Field::Double(8.1)),
("i".to_string(), Field::Str("abc".to_string())),
("j".to_string(), Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])))
]);

assert_eq!(false, row.get_bool(1).unwrap());
assert_eq!(3, row.get_byte(2).unwrap());
assert_eq!(4, row.get_short(3).unwrap());
assert_eq!(5, row.get_int(4).unwrap());
assert_eq!(6, row.get_long(5).unwrap());
assert_eq!(7.1, row.get_float(6).unwrap());
assert_eq!(8.1, row.get_double(7).unwrap());
assert!("abc".to_string().eq(row.get_string(8).unwrap()));
assert_eq!(5, row.get_bytes(9).unwrap().len());
}

#[test]
fn test_row_primitive_invalid_accessors() {
// primitives
let row = make_row(vec![
("a".to_string(), Field::Null),
("b".to_string(), Field::Bool(false)),
("c".to_string(), Field::Byte(3)),
("d".to_string(), Field::Short(4)),
("e".to_string(), Field::Int(5)),
("f".to_string(), Field::Long(6)),
("g".to_string(), Field::Float(7.1)),
("h".to_string(), Field::Double(8.1)),
("i".to_string(), Field::Str("abc".to_string())),
("j".to_string(), Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])))
]);

for i in 0..10 {
assert!(row.get_group(i).is_err());
}
}

#[test]
fn test_row_complex_accessors() {
let row = make_row(vec![
("a".to_string(), Field::Group(make_row(vec![
("x".to_string(), Field::Null),
("Y".to_string(), Field::Int(2))
]))),
("b".to_string(), Field::ListInternal(make_list(vec![
Field::Int(2),
Field::Int(1),
Field::Null,
Field::Int(12)
]))),
("c".to_string(), Field::MapInternal(make_map(vec![
(Field::Int(1), Field::Float(1.2)),
(Field::Int(2), Field::Float(4.5)),
(Field::Int(3), Field::Float(2.3))
])))
]);

assert_eq!(2, row.get_group(0).unwrap().len());
assert_eq!(4, row.get_list(1).unwrap().len());
assert_eq!(3, row.get_map(2).unwrap().len());
}

#[test]
fn test_row_complex_invalid_accessors() {
let row = make_row(vec![
("a".to_string(), Field::Group(make_row(vec![
("x".to_string(), Field::Null),
("Y".to_string(), Field::Int(2))
]))),
("b".to_string(), Field::ListInternal(make_list(vec![
Field::Int(2),
Field::Int(1),
Field::Null,
Field::Int(12)
]))),
("c".to_string(), Field::MapInternal(make_map(vec![
(Field::Int(1), Field::Float(1.2)),
(Field::Int(2), Field::Float(4.5)),
(Field::Int(3), Field::Float(2.3))
])))
]);

assert_eq!(ParquetError::General("Cannot access Group as Float".to_string()),
row.get_float(0).unwrap_err());
assert_eq!(ParquetError::General("Cannot access ListInternal as Float".to_string()),
row.get_float(1).unwrap_err());
assert_eq!(ParquetError::General("Cannot access MapInternal as Float".to_string()),
row.get_float(2).unwrap_err());
}
}