Skip to content

Commit

Permalink
respect header flag in csv type inference and writing
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 22, 2021
1 parent db08156 commit 5a1adce
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
34 changes: 29 additions & 5 deletions polars/polars-io/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
//! let mut file = File::create("example.csv").expect("could not create file");
//!
//! CsvWriter::new(&mut file)
//! .has_headers(true)
//! .has_header(true)
//! .with_delimiter(b',')
//! .finish(df)
//! }
Expand Down Expand Up @@ -60,6 +60,7 @@ pub struct CsvWriter<W: Write> {
writer_builder: write::WriterBuilder,
/// arrow specific options
options: write::SerializeOptions,
header: bool,
}

impl<W> SerWriter<W> for CsvWriter<W>
Expand All @@ -81,13 +82,16 @@ where
buffer,
writer_builder: write::WriterBuilder::new(),
options,
header: true,
}
}

fn finish(self, df: &DataFrame) -> Result<()> {
let mut writer = self.writer_builder.from_writer(self.buffer);
let iter = df.iter_record_batches();
write::write_header(&mut writer, &df.schema().to_arrow())?;
if self.header {
write::write_header(&mut writer, &df.schema().to_arrow())?;
}
for batch in iter {
write::write_batch(&mut writer, &batch, &self.options)?;
}
Expand All @@ -100,8 +104,8 @@ where
W: Write,
{
/// Set whether to write headers
pub fn has_headers(mut self, has_headers: bool) -> Self {
self.writer_builder.has_headers(has_headers);
pub fn has_header(mut self, has_header: bool) -> Self {
self.header = has_header;
self
}

Expand Down Expand Up @@ -595,11 +599,19 @@ mod test {
let df = create_df();

CsvWriter::new(&mut buf)
.has_headers(true)
.has_header(true)
.finish(&df)
.expect("csv written");
let csv = std::str::from_utf8(&buf).unwrap();
assert_eq!("days,temp\n0,22.1\n1,19.9\n2,7.0\n3,2.0\n4,3.0\n", csv);

let mut buf: Vec<u8> = Vec::new();
CsvWriter::new(&mut buf)
.has_header(false)
.finish(&df)
.expect("csv written");
let csv = std::str::from_utf8(&buf).unwrap();
assert_eq!("0,22.1\n1,19.9\n2,7.0\n3,2.0\n4,3.0\n", csv);
}

#[test]
Expand Down Expand Up @@ -1183,4 +1195,16 @@ linenum,last_name,first_name

Ok(())
}

#[test]
/// Checks that schema inference honours `has_header(false)`.
///
/// The first line of the input is textual in every column while the remaining
/// rows are numeric. With the header flag disabled, all four columns are
/// expected to be inferred as `Utf8`.
fn test_header_inference() -> Result<()> {
// The first line deliberately looks like a header row; the test reads it as data.
let csv = r#"not_a_header,really,even_if,it_looks_like_one
1,2,3,4
4,3,2,1
"#;
let file = Cursor::new(csv);
// Parse with the header flag turned off.
let df = CsvReader::new(file).has_header(false).finish()?;
// Every one of the four columns must come out as a string column.
assert_eq!(df.dtypes(), vec![DataType::Utf8; 4]);
Ok(())
}
}
4 changes: 4 additions & 0 deletions polars/polars-io/src/csv_core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ pub fn infer_file_schema(
} else {
return Err(PolarsError::NoData("empty csv".into()));
};
if !has_header {
// re-init lines so that the header is included in type inference.
lines = SplitLines::new(bytes, b'\n').skip(skip_rows);
}

let header_length = headers.len();
// keep track of inferred field types
Expand Down

0 comments on commit 5a1adce

Please sign in to comment.