Skip to content

Commit

Permalink
csv: allow multiple null values
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 14, 2022
1 parent 5367833 commit 0e7e79f
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 38 deletions.
26 changes: 13 additions & 13 deletions polars/polars-io/src/csv/parser.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::buffer::*;
use crate::csv::read::NullValuesCompiled;
use num::traits::Pow;
use polars_core::prelude::*;

Expand Down Expand Up @@ -373,13 +374,13 @@ fn skip_this_line(bytes: &[u8], quote: Option<u8>) -> &[u8] {
/// * `buffers` - Parsed output will be written to these buffers. Except for UTF8 data. The offsets of the
/// fields are written to the buffers. The UTF8 data will be parsed later.
#[allow(clippy::too_many_arguments)]
pub(crate) fn parse_lines(
pub(super) fn parse_lines(
mut bytes: &[u8],
offset: usize,
delimiter: u8,
comment_char: Option<u8>,
quote_char: Option<u8>,
null_values: Option<&Vec<String>>,
null_values: Option<&NullValuesCompiled>,
projection: &[usize],
buffers: &mut [Buffer],
ignore_parser_errors: bool,
Expand Down Expand Up @@ -465,17 +466,16 @@ pub(crate) fn parse_lines(
let mut add_null = false;

// if we have a null values argument, check whether this field equals a null value
if let Some(null_values) = &null_values {
if let Some(null_value) = null_values.get(processed_fields) {
let field = if needs_escaping && !field.is_empty() {
&field[1..field.len() - 1]
} else {
field
};
if field == null_value.as_bytes() {
add_null = true;
}
}
if let Some(null_values) = null_values {
let field = if needs_escaping && !field.is_empty() {
&field[1..field.len() - 1]
} else {
field
};

// safety:
// `processed_fields` is in bounds
add_null = unsafe { null_values.is_null(field, processed_fields) }
}
if add_null {
buf.add_null()
Expand Down
56 changes: 45 additions & 11 deletions polars/polars-io/src/csv/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,63 @@ pub enum CsvEncoding {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NullValues {
/// A single value that's used for all columns
AllColumns(String),
/// A different null value per column
Columns(Vec<String>),
AllColumnsSingle(String),
/// Multiple values that are used for all columns
AllColumns(Vec<String>),
/// Tuples that map column names to null value of that column
Named(Vec<(String, String)>),
}

/// Null-value specification pre-processed (via `NullValues::compile`) into a
/// form that the parser can compare against raw fields cheaply.
pub(super) enum NullValuesCompiled {
    /// A single value that's used for all columns
    AllColumnsSingle(String),
    /// Multiple null values that are null for all columns
    AllColumns(Vec<String>),
    /// A different null value per column, computed from `NullValues::Named`
    Columns(Vec<String>),
}

impl NullValuesCompiled {
    /// Restrict the per-column null values to the projected columns, in
    /// projection order. Only the `Columns` variant carries per-column state,
    /// so the other variants are left untouched.
    ///
    /// The strings are moved out with `mem::take` rather than cloned.
    /// NOTE(review): `nv[*i]` panics if a projection index is out of bounds,
    /// and a duplicated index would yield an empty string the second time —
    /// callers are presumably expected to pass valid, unique column indices.
    pub(super) fn apply_projection(&mut self, projections: &[usize]) {
        if let Self::Columns(nv) = self {
            let nv = projections
                .iter()
                .map(|i| std::mem::take(&mut nv[*i]))
                .collect::<Vec<_>>();

            *self = NullValuesCompiled::Columns(nv);
        }
    }

    /// Returns `true` if the raw `field` bytes match a configured null value
    /// for the column at position `index`.
    ///
    /// # Safety
    /// The caller must ensure that `index` is in bounds
    pub(super) unsafe fn is_null(&self, field: &[u8], index: usize) -> bool {
        use NullValuesCompiled::*;
        match self {
            AllColumnsSingle(v) => v.as_bytes() == field,
            AllColumns(v) => v.iter().any(|v| v.as_bytes() == field),
            Columns(v) => {
                // `index` being in bounds is the caller's obligation (see
                // `# Safety`); checked here only in debug builds.
                debug_assert!(index < v.len());
                v.get_unchecked(index).as_bytes() == field
            }
        }
    }
}

impl NullValues {
/// Use the schema and the null values to produce a null value for every column.
pub(crate) fn process(self, schema: &Schema) -> Result<Vec<String>> {
let out = match self {
NullValues::Columns(v) => v,
NullValues::AllColumns(v) => (0..schema.len()).map(|_| v.clone()).collect(),
pub(super) fn compile(self, schema: &Schema) -> Result<NullValuesCompiled> {
Ok(match self {
NullValues::AllColumnsSingle(v) => NullValuesCompiled::AllColumnsSingle(v),
NullValues::AllColumns(v) => NullValuesCompiled::AllColumns(v),
NullValues::Named(v) => {
let mut null_values = vec!["".to_string(); schema.len()];
for (name, null_value) in v {
let i = schema.try_index_of(&name)?;
null_values[i] = null_value;
}
null_values
NullValuesCompiled::Columns(null_values)
}
};
Ok(out)
})
}
}

Expand Down
13 changes: 6 additions & 7 deletions polars/polars-io/src/csv/read_impl.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::aggregations::ScanAggregation;
use crate::csv::read::NullValuesCompiled;
use crate::csv::utils::*;
use crate::csv::{buffer::*, parser::*};
use crate::csv::{CsvEncoding, NullValues};
Expand Down Expand Up @@ -77,7 +78,7 @@ pub(crate) struct CoreReader<'a> {
low_memory: bool,
comment_char: Option<u8>,
quote_char: Option<u8>,
null_values: Option<Vec<String>>,
null_values: Option<NullValuesCompiled>,
predicate: Option<Arc<dyn PhysicalIoExpr>>,
aggregate: Option<&'a [ScanAggregation]>,
to_cast: &'a [Field],
Expand Down Expand Up @@ -237,7 +238,7 @@ impl<'a> CoreReader<'a> {
}

// create a null value for every column
let mut null_values = null_values.map(|nv| nv.process(&schema)).transpose()?;
let mut null_values = null_values.map(|nv| nv.compile(&schema)).transpose()?;

if let Some(cols) = columns {
let mut prj = Vec::with_capacity(cols.len());
Expand All @@ -247,11 +248,9 @@ impl<'a> CoreReader<'a> {
}

// update null values with projection
null_values = null_values.map(|mut nv| {
prj.iter()
.map(|i| std::mem::take(&mut nv[*i]))
.collect::<Vec<_>>()
});
if let Some(nv) = null_values.as_mut() {
nv.apply_projection(&prj);
}

projection = Some(prj);
}
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-io/src/csv/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,12 @@ pub fn infer_file_schema(
None => {
column_types[i].insert(infer_field_schema(&s, parse_dates));
}
Some(NullValues::Columns(names)) => {
if !names.iter().any(|name| name == s.as_ref()) {
Some(NullValues::AllColumns(names)) => {
if !names.iter().any(|nv| nv == s.as_ref()) {
column_types[i].insert(infer_field_schema(&s, parse_dates));
}
}
Some(NullValues::AllColumns(name)) => {
Some(NullValues::AllColumnsSingle(name)) => {
if s.as_ref() != name {
column_types[i].insert(infer_field_schema(&s, parse_dates));
}
Expand Down
4 changes: 2 additions & 2 deletions polars/tests/it/io/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ null-value,b,bar,
let file = Cursor::new(csv);
let df = CsvReader::new(file)
.has_header(false)
.with_null_values(NullValues::AllColumns("null-value".to_string()).into())
.with_null_values(NullValues::AllColumnsSingle("null-value".to_string()).into())
.finish()?;
assert!(df.get_columns()[0].null_count() > 0);
Ok(())
Expand Down Expand Up @@ -773,7 +773,7 @@ fn test_null_values_infer_schema() -> Result<()> {
5,6"#;
let file = Cursor::new(csv);
let df = CsvReader::new(file)
.with_null_values(Some(NullValues::AllColumns("NA".into())))
.with_null_values(Some(NullValues::AllColumnsSingle("NA".into())))
.finish()?;
let expected = &[DataType::Int64, DataType::Int64];
assert_eq!(df.dtypes(), expected);
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ impl<'a> FromPyObject<'a> for Wrap<Utf8Chunked> {
impl<'a> FromPyObject<'a> for Wrap<NullValues> {
fn extract(ob: &'a PyAny) -> PyResult<Self> {
if let Ok(s) = ob.extract::<String>() {
Ok(Wrap(NullValues::AllColumns(s)))
Ok(Wrap(NullValues::AllColumnsSingle(s)))
} else if let Ok(s) = ob.extract::<Vec<String>>() {
Ok(Wrap(NullValues::Columns(s)))
Ok(Wrap(NullValues::AllColumns(s)))
} else if let Ok(s) = ob.extract::<Vec<(String, String)>>() {
Ok(Wrap(NullValues::Named(s)))
} else {
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,3 +549,26 @@ def test_csv_whitepsace_delimiter_at_end_do_not_skip() -> None:
"column_5": [None],
"column_6": [None],
}


def test_csv_multiple_null_values() -> None:
    # Round-trip a frame through CSV and check that every value listed in
    # `null_values` is read back as null, while other strings (including the
    # empty string) are preserved.
    source = pl.DataFrame(
        {
            "a": [1, 2, None, 4],
            "b": ["2022-01-01", "__NA__", "", "NA"],
        }
    )

    buffer = io.BytesIO()
    source.write_csv(buffer)
    buffer.seek(0)
    result = pl.read_csv(buffer, null_values=["__NA__", "NA"])

    expected = pl.DataFrame(
        {
            "a": [1, 2, None, 4],
            "b": ["2022-01-01", None, "", None],
        }
    )
    assert result.frame_equal(expected)

0 comments on commit 0e7e79f

Please sign in to comment.