Skip to content

Commit

Permalink
csv: improve date/datetime/bool overwrite (#4247)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 4, 2022
1 parent 52120c6 commit 033881f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 19 deletions.
20 changes: 13 additions & 7 deletions polars/polars-io/src/csv/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ pub(crate) fn init_buffers(
ignore_errors,
)),
#[cfg(feature = "dtype-datetime")]
&DataType::Datetime(_, _) => Buffer::Datetime(DatetimeField::new(name, capacity)),
&DataType::Datetime(tu, _) => Buffer::Datetime {
buf: DatetimeField::new(name, capacity),
tu,
},
#[cfg(feature = "dtype-date")]
&DataType::Date => Buffer::Date(DatetimeField::new(name, capacity)),
other => {
Expand All @@ -425,7 +428,10 @@ pub(crate) enum Buffer {
/// Stores the Utf8 fields and the total string length seen for that column
Utf8(Utf8Field),
#[cfg(feature = "dtype-datetime")]
Datetime(DatetimeField<Int64Type>),
Datetime {
buf: DatetimeField<Int64Type>,
tu: TimeUnit,
},
#[cfg(feature = "dtype-date")]
Date(DatetimeField<Int32Type>),
}
Expand All @@ -441,11 +447,11 @@ impl Buffer {
Buffer::Float32(v) => v.finish().into_series(),
Buffer::Float64(v) => v.finish().into_series(),
#[cfg(feature = "dtype-datetime")]
Buffer::Datetime(v) => v
Buffer::Datetime { buf, tu } => buf
.builder
.finish()
.into_series()
.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
.cast(&DataType::Datetime(tu, None))
.unwrap(),
#[cfg(feature = "dtype-date")]
Buffer::Date(v) => v
Expand Down Expand Up @@ -518,7 +524,7 @@ impl Buffer {
v.validity.push(false);
}
#[cfg(feature = "dtype-datetime")]
Buffer::Datetime(v) => v.builder.append_null(),
Buffer::Datetime { buf, .. } => buf.builder.append_null(),
#[cfg(feature = "dtype-date")]
Buffer::Date(v) => v.builder.append_null(),
};
Expand All @@ -535,7 +541,7 @@ impl Buffer {
Buffer::Float64(_) => DataType::Float64,
Buffer::Utf8(_) => DataType::Utf8,
#[cfg(feature = "dtype-datetime")]
Buffer::Datetime(_) => DataType::Datetime(TimeUnit::Microseconds, None),
Buffer::Datetime { tu, .. } => DataType::Datetime(*tu, None),
#[cfg(feature = "dtype-date")]
Buffer::Date(_) => DataType::Date,
}
Expand Down Expand Up @@ -596,7 +602,7 @@ impl Buffer {
<Utf8Field as ParsedBuffer>::parse_bytes(buf, bytes, ignore_errors, needs_escaping)
}
#[cfg(feature = "dtype-datetime")]
Datetime(buf) => <DatetimeField<Int64Type> as ParsedBuffer>::parse_bytes(
Datetime { buf, .. } => <DatetimeField<Int64Type> as ParsedBuffer>::parse_bytes(
buf,
bytes,
ignore_errors,
Expand Down
18 changes: 7 additions & 11 deletions polars/polars-io/src/csv/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,30 +376,26 @@ where
#[allow(clippy::unnecessary_filter_map)]
let fields: Vec<_> = schema
.iter_fields()
.filter_map(|fld| {
.filter_map(|mut fld| {
use DataType::*;
match fld.data_type() {
// For categorical we first read as utf8 and later cast to categorical
#[cfg(feature = "dtype-categorical")]
Categorical(_) => {
to_cast_local.push(fld.clone());
Some(Field::new(fld.name(), DataType::Utf8))
}
Date | Datetime(_, _) => {
to_cast.push(fld);
// let inference decide the column type
None
fld.coerce(DataType::Utf8);
Some(fld)
}
Time => {
to_cast.push(fld);
// let inference decide the column type
None
}
Int8 | Int16 | UInt8 | UInt16 | Boolean => {
Int8 | Int16 | UInt8 | UInt16 => {
// We have not compiled these buffers, so we cast them later.
to_cast.push(fld);
// let inference decide the column type
None
to_cast.push(fld.clone());
fld.coerce(DataType::Int32);
Some(fld)
}
_ => Some(fld),
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-time/src/chunkedarray/utf8/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct DatetimeInfer<T> {
transform: fn(&str, &str) -> Option<T>,
transform_bytes: fn(&[u8], &[u8], u16) -> Option<T>,
fmt_len: u16,
logical_type: DataType,
pub logical_type: DataType,
}

impl TryFrom<Pattern> for DatetimeInfer<i64> {
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,12 @@ def test_skip_new_line_embedded_lines() -> None:
"d": ["Test A", "Test B \\n"],
"e\\n": ["\\n", "\\n"],
}


def test_csv_dtype_overwrite_bool() -> None:
    """Explicit Boolean dtypes passed to ``read_csv`` are applied to both columns.

    Column "a" is empty in every row, while column "b" holds the literal
    ``false`` in every row.
    """
    # No space after the comma in the header: a header of "a, b" would
    # presumably name the second column " b" (leading space kept), so the "b"
    # entry in `dtypes` would not match it and the override would go untested.
    csv = "a,b\n" + ",false\n" + ",false\n" + ",false"
    df = pl.read_csv(
        csv.encode(),
        dtypes={"a": pl.Boolean, "b": pl.Boolean},
    )
    # Guard against the override keys silently missing their columns.
    assert df.columns == ["a", "b"]
    assert df.dtypes == [pl.Boolean, pl.Boolean]

0 comments on commit 033881f

Please sign in to comment.