Skip to content

Commit

Permalink
remove unnecessary downcasting of generic ChunkedArray<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 5, 2020
1 parent b89d917 commit a5505ab
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 223 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Polars is written to be performant. Below are some comparisons with the (also ve
let mask = s.eq(1);
let valid = [true, false, false].iter();

assert_eq!(Vec::from(mask.bool().unwrap()), &[Some(true), Some(false), Some(false)]);
assert_eq!(Vec::from(mask), &[Some(true), Some(false), Some(false)]);
```

## Temporal data types
Expand Down
223 changes: 2 additions & 221 deletions polars/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,227 +417,8 @@ where
}
}

/// Downcast generic `ChunkedArray<T>` to u8.
pub fn u8(self) -> Result<UInt8Chunked> {
match T::get_data_type() {
ArrowDataType::UInt8 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to u16.
pub fn u16(self) -> Result<UInt16Chunked> {
match T::get_data_type() {
ArrowDataType::UInt16 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to u32.
pub fn u32(self) -> Result<UInt32Chunked> {
match T::get_data_type() {
ArrowDataType::UInt32 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to u64.
pub fn u64(self) -> Result<UInt64Chunked> {
match T::get_data_type() {
ArrowDataType::UInt64 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to i8.
pub fn i8(self) -> Result<Int8Chunked> {
match T::get_data_type() {
ArrowDataType::Int8 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to i16.
pub fn i16(self) -> Result<Int16Chunked> {
match T::get_data_type() {
ArrowDataType::Int16 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to i32.
pub fn i32(self) -> Result<Int32Chunked> {
match T::get_data_type() {
ArrowDataType::Int32 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to i64.
pub fn i64(self) -> Result<Int64Chunked> {
match T::get_data_type() {
ArrowDataType::Int64 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to f32.
pub fn f32(self) -> Result<Float32Chunked> {
match T::get_data_type() {
ArrowDataType::Float32 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to f64.
pub fn f64(self) -> Result<Float64Chunked> {
match T::get_data_type() {
ArrowDataType::Float64 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to bool.
pub fn bool(self) -> Result<BooleanChunked> {
match T::get_data_type() {
ArrowDataType::Boolean => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to UTF-8 encoded string.
pub fn utf8(self) -> Result<Utf8Chunked> {
match T::get_data_type() {
ArrowDataType::Utf8 => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to date32.
pub fn date32(self) -> Result<Date32Chunked> {
match T::get_data_type() {
ArrowDataType::Date32(DateUnit::Day) => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to date32.
pub fn date64(self) -> Result<Date64Chunked> {
match T::get_data_type() {
ArrowDataType::Date64(DateUnit::Millisecond) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}
// TODO: insert types

/// Downcast generic `ChunkedArray<T>` to time32 with milliseconds unit.
pub fn time32_second(self) -> Result<Time32SecondChunked> {
match T::get_data_type() {
ArrowDataType::Time32(TimeUnit::Second) => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to time32 with milliseconds unit.
pub fn time32_millisecond(self) -> Result<Time32MillisecondChunked> {
match T::get_data_type() {
ArrowDataType::Time32(TimeUnit::Millisecond) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to time64 with nanoseconds unit.
pub fn time64_nanosecond(self) -> Result<Time64NanosecondChunked> {
match T::get_data_type() {
ArrowDataType::Time64(TimeUnit::Nanosecond) => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to time64 with microseconds unit.
pub fn time64_microsecond(self) -> Result<Time64MicrosecondChunked> {
match T::get_data_type() {
ArrowDataType::Time64(TimeUnit::Microsecond) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to duration with nanoseconds unit.
pub fn duration_nanosecond(self) -> Result<DurationNanosecondChunked> {
match T::get_data_type() {
ArrowDataType::Duration(TimeUnit::Nanosecond) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to duration with microseconds unit.
pub fn duration_microsecond(self) -> Result<DurationMicrosecondChunked> {
match T::get_data_type() {
ArrowDataType::Duration(TimeUnit::Microsecond) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to duration with seconds unit.
pub fn duration_second(self) -> Result<DurationSecondChunked> {
match T::get_data_type() {
ArrowDataType::Duration(TimeUnit::Second) => unsafe { Ok(std::mem::transmute(self)) },
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to timestamp with nanoseconds unit.
pub fn timestamp_nanosecond(self) -> Result<TimestampNanosecondChunked> {
match T::get_data_type() {
ArrowDataType::Timestamp(TimeUnit::Nanosecond, _) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to timestamp with microseconds unit.
pub fn timestamp_microsecond(self) -> Result<TimestampMicrosecondChunked> {
match T::get_data_type() {
ArrowDataType::Timestamp(TimeUnit::Microsecond, _) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to timestamp with milliseconds unit.
pub fn timestamp_millisecond(self) -> Result<TimestampMillisecondChunked> {
match T::get_data_type() {
ArrowDataType::Timestamp(TimeUnit::Millisecond, _) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Downcast generic `ChunkedArray<T>` to timestamp with seconds unit.
pub fn timestamp_second(self) -> Result<TimestampSecondChunked> {
match T::get_data_type() {
ArrowDataType::Timestamp(TimeUnit::Second, _) => unsafe {
Ok(std::mem::transmute(self))
},
_ => Err(PolarsError::DataTypeMisMatch),
}
}

/// Get a single value. Beware this is slow.
pub fn get_any(&self, index: usize) -> AnyType {
/// Get a single value. Beware this is slow. (only used for formatting)
pub(crate) fn get_any(&self, index: usize) -> AnyType {
let (chunk_idx, idx) = self.index_to_chunked_index(index);
let arr = &self.chunks[chunk_idx];

Expand Down
2 changes: 1 addition & 1 deletion polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
//! let mask = s.eq(1);
//! let valid = [true, false, false].iter();
//!
//! assert_eq!(Vec::from(mask.bool().unwrap()), &[Some(true), Some(false), Some(false)]);
//! assert_eq!(Vec::from(mask), &[Some(true), Some(false), Some(false)]);
//! ```
//!
//! ## Temporal data types
Expand Down

0 comments on commit a5505ab

Please sign in to comment.