-
Notifications
You must be signed in to change notification settings - Fork 20
Add type-safe accessors for primitive types in Row #86
Changes from 7 commits
abcf257
198a1e3
d09907e
acf66fa
d1aeaa5
910f118
15533c2
e5c9b6e
89b04f9
76d4ab0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ use std::fmt; | |
use basic::{LogicalType, Type as PhysicalType}; | ||
use chrono::{Local, TimeZone}; | ||
use data_type::{ByteArray, Int96}; | ||
use errors::{ParquetError, Result}; | ||
|
||
/// Macro as a shortcut to generate 'not yet implemented' panic error. | ||
macro_rules! nyi { | ||
|
@@ -41,6 +42,79 @@ pub struct Row { | |
fields: Vec<(String, Field)> | ||
} | ||
|
||
impl Row { | ||
/// Get then number of fields in this row | ||
pub fn len(&self) -> usize { | ||
self.fields.len() | ||
} | ||
} | ||
|
||
/// Trait for type-safe convenient access to fields within a Row | ||
trait RowAccessor { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be public and be exported in |
||
fn get_bool(&self, i: usize) -> Result<bool>; | ||
fn get_byte(&self, i: usize) -> Result<i8>; | ||
fn get_short(&self, i: usize) -> Result<i16>; | ||
fn get_int(&self, i: usize) -> Result<i32>; | ||
fn get_long(&self, i: usize) -> Result<i64>; | ||
fn get_float(&self, i: usize) -> Result<f32>; | ||
fn get_double(&self, i: usize) -> Result<f64>; | ||
fn get_timestamp(&self, i: usize) -> Result<u64>; | ||
fn get_string(&self, i: usize) -> Result<&String>; | ||
fn get_bytes(&self, i: usize) -> Result<&ByteArray>; | ||
fn get_group(&self, i: usize) -> Result<&Row>; | ||
fn get_list(&self, i: usize) -> Result<&List>; | ||
fn get_map(&self, i: usize) -> Result<&Map>; | ||
} | ||
|
||
/// Macro to generate type-safe get_xxx methods for primitive types e.g. get_bool, get_short | ||
macro_rules! row_primitive_accessor { | ||
($METHOD:ident, $VARIANT:ident, $TY:ty) => { | ||
|
||
fn $METHOD(&self, i: usize) -> Result<$TY> { | ||
match self.fields[i].1 { | ||
Field::$VARIANT(v) => Ok(v), | ||
_ => Err(general_err!("Cannot access {} as {}", | ||
self.fields[i].1.get_type_name(), stringify!($VARIANT))) | ||
} | ||
} | ||
|
||
} | ||
} | ||
|
||
/// Macro to generate type-safe get_xxx methods for reference types e.g. get_list, get_map | ||
macro_rules! row_complex_accessor { | ||
($METHOD:ident, $VARIANT:ident, $TY:ty) => { | ||
|
||
fn $METHOD(&self, i: usize) -> Result<&$TY> { | ||
match self.fields[i].1 { | ||
Field::$VARIANT(ref v) => Ok(v), | ||
_ => Err(general_err!("Cannot access {} as {}", | ||
self.fields[i].1.get_type_name(), stringify!($VARIANT))) | ||
} | ||
} | ||
|
||
} | ||
} | ||
|
||
impl RowAccessor for Row { | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you remove empty line? |
||
row_primitive_accessor!(get_bool, Bool, bool); | ||
row_primitive_accessor!(get_byte, Byte, i8); | ||
row_primitive_accessor!(get_short, Short, i16); | ||
row_primitive_accessor!(get_int, Int, i32); | ||
row_primitive_accessor!(get_long, Long, i64); | ||
row_primitive_accessor!(get_float, Float, f32); | ||
row_primitive_accessor!(get_double, Double, f64); | ||
row_primitive_accessor!(get_timestamp, Timestamp, u64); | ||
|
||
row_complex_accessor!(get_string, Str, String); | ||
row_complex_accessor!(get_bytes, Bytes, ByteArray); | ||
row_complex_accessor!(get_group, Group, Row); | ||
row_complex_accessor!(get_list, ListInternal, List); | ||
row_complex_accessor!(get_map, MapInternal, Map); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you remove empty line? |
||
} | ||
|
||
/// Constructs a `Row` from the list of `fields` and returns it. | ||
#[inline] | ||
pub fn make_row(fields: Vec<(String, Field)>) -> Row { | ||
|
@@ -71,6 +145,13 @@ pub struct List { | |
elements: Vec<Field> | ||
} | ||
|
||
impl List { | ||
/// Get then number of fields in this row | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should it be |
||
pub fn len(&self) -> usize { | ||
self.elements.len() | ||
} | ||
} | ||
|
||
/// Constructs a `List` from the list of `fields` and returns it. | ||
#[inline] | ||
pub fn make_list(elements: Vec<Field>) -> List { | ||
|
@@ -86,6 +167,13 @@ pub struct Map { | |
entries: Vec<(Field, Field)> | ||
} | ||
|
||
impl Map { | ||
/// Get then number of fields in this row | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
pub fn len(&self) -> usize { | ||
self.entries.len() | ||
} | ||
} | ||
|
||
/// Constructs a `Map` from the list of `entries` and returns it. | ||
#[inline] | ||
pub fn make_map(entries: Vec<(Field, Field)>) -> Map { | ||
|
@@ -137,7 +225,39 @@ pub enum Field { | |
MapInternal(Map) | ||
} | ||
|
||
|
||
impl Field { | ||
/// Get the type name | ||
fn get_type_name(&self) -> &'static str { | ||
match *self { | ||
Field::Null => "Null", | ||
Field::Bool(_) => "Bool", | ||
Field::Byte(_) => "Byte", | ||
Field::Short(_) => "Short", | ||
Field::Int(_) => "Int", | ||
Field::Long(_) => "Long", | ||
Field::Float(_) => "Float", | ||
Field::Double(_) => "Double", | ||
Field::Date(_) => "Date", | ||
Field::Str(_) => "Str", | ||
Field::Bytes(_) => "Bytes", | ||
Field::Timestamp(_) => "Timestamp", | ||
Field::Group(_) => "Group", | ||
Field::ListInternal(_) => "ListInternal", | ||
Field::MapInternal(_) => "MapInternal", | ||
} | ||
} | ||
|
||
/// Determines if this Row represents a primitive value | ||
pub fn is_primitive(&self) -> bool { | ||
match *self { | ||
Field::Group(_) => false, | ||
Field::ListInternal(_) => false, | ||
Field::MapInternal(_) => false, | ||
_ => true | ||
} | ||
} | ||
|
||
/// Converts Parquet BOOLEAN type with logical type into `bool` value. | ||
pub fn convert_bool( | ||
_physical_type: PhysicalType, | ||
|
@@ -482,4 +602,148 @@ mod tests { | |
])); | ||
assert_eq!(format!("{}", row), "{1 -> 1.2, 2 -> 4.5, 3 -> 2.3}"); | ||
} | ||
|
||
#[test] | ||
fn test_is_primitive() { | ||
// primitives | ||
assert!(Field::Null.is_primitive()); | ||
assert!(Field::Bool(true).is_primitive()); | ||
assert!(Field::Bool(false).is_primitive()); | ||
assert!(Field::Byte(1).is_primitive()); | ||
assert!(Field::Short(2).is_primitive()); | ||
assert!(Field::Int(3).is_primitive()); | ||
assert!(Field::Long(4).is_primitive()); | ||
assert!(Field::Float(5.0).is_primitive()); | ||
assert!(Field::Float(5.1234).is_primitive()); | ||
assert!(Field::Double(6.0).is_primitive()); | ||
assert!(Field::Double(6.1234).is_primitive()); | ||
assert!(Field::Str("abc".to_string()).is_primitive()); | ||
assert!(Field::Bytes(ByteArray::from(vec![1, 2, 3])).is_primitive()); | ||
assert!(Field::Timestamp(12345678).is_primitive()); | ||
|
||
let value = ByteArray::from(vec![1, 2, 3, 4, 5]); | ||
assert!(Field::Bytes(value).is_primitive()); | ||
|
||
// complex types | ||
assert_eq!(false, Field::Group(make_row(vec![ | ||
("x".to_string(), Field::Null), | ||
("Y".to_string(), Field::Int(2)), | ||
("z".to_string(), Field::Float(3.1)), | ||
("a".to_string(), Field::Str("abc".to_string())) | ||
])).is_primitive()); | ||
|
||
assert_eq!(false, Field::ListInternal(make_list(vec![ | ||
Field::Int(2), | ||
Field::Int(1), | ||
Field::Null, | ||
Field::Int(12) | ||
])).is_primitive()); | ||
|
||
assert_eq!(false, Field::MapInternal(make_map(vec![ | ||
(Field::Int(1), Field::Float(1.2)), | ||
(Field::Int(2), Field::Float(4.5)), | ||
(Field::Int(3), Field::Float(2.3)) | ||
])).is_primitive()); | ||
} | ||
|
||
#[test] | ||
fn test_row_primitive_accessors() { | ||
// primitives | ||
let row = make_row(vec![ | ||
("a".to_string(), Field::Null), | ||
("b".to_string(), Field::Bool(false)), | ||
("c".to_string(), Field::Byte(3)), | ||
("d".to_string(), Field::Short(4)), | ||
("e".to_string(), Field::Int(5)), | ||
("f".to_string(), Field::Long(6)), | ||
("g".to_string(), Field::Float(7.1)), | ||
("h".to_string(), Field::Double(8.1)), | ||
("i".to_string(), Field::Str("abc".to_string())), | ||
("j".to_string(), Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5]))) | ||
]); | ||
|
||
assert_eq!(false, row.get_bool(1).unwrap()); | ||
assert_eq!(3, row.get_byte(2).unwrap()); | ||
assert_eq!(4, row.get_short(3).unwrap()); | ||
assert_eq!(5, row.get_int(4).unwrap()); | ||
assert_eq!(6, row.get_long(5).unwrap()); | ||
assert_eq!(7.1, row.get_float(6).unwrap()); | ||
assert_eq!(8.1, row.get_double(7).unwrap()); | ||
assert!("abc".to_string().eq(row.get_string(8).unwrap())); | ||
assert_eq!(5, row.get_bytes(9).unwrap().len()); | ||
} | ||
|
||
#[test] | ||
fn test_row_primitive_invalid_accessors() { | ||
// primitives | ||
let row = make_row(vec![ | ||
("a".to_string(), Field::Null), | ||
("b".to_string(), Field::Bool(false)), | ||
("c".to_string(), Field::Byte(3)), | ||
("d".to_string(), Field::Short(4)), | ||
("e".to_string(), Field::Int(5)), | ||
("f".to_string(), Field::Long(6)), | ||
("g".to_string(), Field::Float(7.1)), | ||
("h".to_string(), Field::Double(8.1)), | ||
("i".to_string(), Field::Str("abc".to_string())), | ||
("j".to_string(), Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5]))) | ||
]); | ||
|
||
for i in 0..10 { | ||
assert!(row.get_group(i).is_err()); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_row_complex_accessors() { | ||
let row = make_row(vec![ | ||
("a".to_string(), Field::Group(make_row(vec![ | ||
("x".to_string(), Field::Null), | ||
("Y".to_string(), Field::Int(2)) | ||
]))), | ||
("b".to_string(), Field::ListInternal(make_list(vec![ | ||
Field::Int(2), | ||
Field::Int(1), | ||
Field::Null, | ||
Field::Int(12) | ||
]))), | ||
("c".to_string(), Field::MapInternal(make_map(vec![ | ||
(Field::Int(1), Field::Float(1.2)), | ||
(Field::Int(2), Field::Float(4.5)), | ||
(Field::Int(3), Field::Float(2.3)) | ||
]))) | ||
]); | ||
|
||
assert_eq!(2, row.get_group(0).unwrap().len()); | ||
assert_eq!(4, row.get_list(1).unwrap().len()); | ||
assert_eq!(3, row.get_map(2).unwrap().len()); | ||
} | ||
|
||
#[test] | ||
fn test_row_complex_invalid_accessors() { | ||
let row = make_row(vec![ | ||
("a".to_string(), Field::Group(make_row(vec![ | ||
("x".to_string(), Field::Null), | ||
("Y".to_string(), Field::Int(2)) | ||
]))), | ||
("b".to_string(), Field::ListInternal(make_list(vec![ | ||
Field::Int(2), | ||
Field::Int(1), | ||
Field::Null, | ||
Field::Int(12) | ||
]))), | ||
("c".to_string(), Field::MapInternal(make_map(vec![ | ||
(Field::Int(1), Field::Float(1.2)), | ||
(Field::Int(2), Field::Float(4.5)), | ||
(Field::Int(3), Field::Float(2.3)) | ||
]))) | ||
]); | ||
|
||
assert_eq!(ParquetError::General("Cannot access Group as Float".to_string()), | ||
row.get_float(0).unwrap_err()); | ||
assert_eq!(ParquetError::General("Cannot access ListInternal as Float".to_string()), | ||
row.get_float(1).unwrap_err()); | ||
assert_eq!(ParquetError::General("Cannot access MapInternal as Float".to_string()), | ||
row.get_float(2).unwrap_err()); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it be
the
?