Skip to content

Commit

Permalink
feat: 🎸 support dtypes option for read_csv function (#64)
Browse files Browse the repository at this point in the history
Co-authored-by: 舒荣贵 <ronggui.shu@sci-inv.com>
  • Loading branch information
littledian and 舒荣贵 committed Apr 14, 2023
1 parent c3df322 commit a591ded
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
6 changes: 6 additions & 0 deletions __tests__/io.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ describe("read:csv", () => {
const maxRowCount = df.getColumn("rc").max();
expect(expectedMaxRowCount).toStrictEqual(maxRowCount);
});
// Checks that the `dtypes` option of `readCSV` overrides the dtype a column
// would otherwise get, and that omitting the option leaves it unchanged.
test("csv with dtypes", () => {
  // Request `calories` as Utf8 (presumably the column at index 1 — confirm against the fixture).
  const overridden = pl.readCSV(csvpath, { dtypes: { calories: pl.Utf8 } });
  const overriddenType = overridden.dtypes[1];
  expect(overriddenType.equals(pl.Utf8)).toBeTruthy();

  // Baseline: without `dtypes`, the same column comes back as Int64.
  const baseline = pl.readCSV(csvpath);
  const baselineType = baseline.dtypes[1];
  expect(baselineType.equals(pl.Int64)).toBeTruthy();
});
it.todo("can read from a stream");
});

Expand Down
2 changes: 1 addition & 1 deletion polars/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export interface ReadCsvOptions {
rechunk: boolean;
encoding: "utf8" | "utf8-lossy";
numThreads: number;
dtype: any;
dtypes: Record<string, DataType>;
lowMemory: boolean;
commentChar: string;
quotChar: string;
Expand Down
12 changes: 12 additions & 0 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use polars::io::RowCount;
use std::borrow::Borrow;
use std::fs::File;
use std::io::{BufReader, BufWriter, Cursor};
use std::collections::HashMap;

#[napi]
#[repr(transparent)]
Expand Down Expand Up @@ -59,6 +60,7 @@ pub struct ReadCsvOptions {
pub columns: Option<Vec<String>>,
pub encoding: String,
pub n_threads: Option<u32>,
pub dtypes: Option<HashMap<String, Wrap<DataType>>>,
pub null_values: Option<Wrap<NullValues>>,
pub path: Option<String>,
pub low_memory: bool,
Expand Down Expand Up @@ -98,6 +100,15 @@ pub fn read_csv(
"utf8-lossy" => CsvEncoding::LossyUtf8,
e => return Err(JsPolarsErr::Other(format!("encoding not {} not implemented.", e)).into()),
};

// Turn the optional per-column dtype overrides into a `Schema`.
// `Option::map` hands the closure an owned `HashMap`, so it is consumed with
// `into_iter()` and each `DataType` is moved out of its `Wrap` rather than
// cloned entry by entry.
let dtypes = options.dtypes.map(|map| {
    let fields = map
        .into_iter()
        .map(|(name, dtype)| Field::new(&name, dtype.0));
    Schema::from(fields)
});

let df = match path_or_buffer {
Either::A(path) => CsvReader::from_path(path)
.expect("unable to read file")
Expand All @@ -112,6 +123,7 @@ pub fn read_csv(
.with_encoding(encoding)
.with_columns(options.columns)
.with_n_threads(n_threads)
.with_dtypes(dtypes.as_ref())
.low_memory(options.low_memory)
.with_comment_char(comment_char)
.with_null_values(null_values)
Expand Down

0 comments on commit a591ded

Please sign in to comment.