diff --git a/__tests__/io.test.ts b/__tests__/io.test.ts index daea0c1d..8b084dcb 100644 --- a/__tests__/io.test.ts +++ b/__tests__/io.test.ts @@ -76,6 +76,12 @@ describe("read:csv", () => { const maxRowCount = df.getColumn("rc").max(); expect(expectedMaxRowCount).toStrictEqual(maxRowCount); }); + test("csv with dtypes", () => { + const df = pl.readCSV(csvpath, { dtypes: { calories: pl.Utf8 }}); + expect(df.dtypes[1].equals(pl.Utf8)).toBeTruthy(); + const df2 = pl.readCSV(csvpath); + expect(df2.dtypes[1].equals(pl.Int64)).toBeTruthy(); + }); it.todo("can read from a stream"); }); diff --git a/polars/io.ts b/polars/io.ts index 2be3aa3a..b478a943 100644 --- a/polars/io.ts +++ b/polars/io.ts @@ -20,7 +20,7 @@ export interface ReadCsvOptions { rechunk: boolean; encoding: "utf8" | "utf8-lossy"; numThreads: number; - dtype: any; + dtypes: Record; lowMemory: boolean; commentChar: string; quotChar: string; diff --git a/src/dataframe.rs b/src/dataframe.rs index ac2f6a55..18ad8f55 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -8,6 +8,7 @@ use polars::io::RowCount; use std::borrow::Borrow; use std::fs::File; use std::io::{BufReader, BufWriter, Cursor}; +use std::collections::HashMap; #[napi] #[repr(transparent)] @@ -59,6 +60,7 @@ pub struct ReadCsvOptions { pub columns: Option>, pub encoding: String, pub n_threads: Option, + pub dtypes: Option>>, pub null_values: Option>, pub path: Option, pub low_memory: bool, @@ -98,6 +100,15 @@ pub fn read_csv( "utf8-lossy" => CsvEncoding::LossyUtf8, e => return Err(JsPolarsErr::Other(format!("encoding not {} not implemented.", e)).into()), }; + + let dtypes = options.dtypes.map(|map| { + let fields = map.iter().map(|(key, val)| { + let value = val.clone().0; + Field::new(key, value) + }); + Schema::from(fields) + }); + let df = match path_or_buffer { Either::A(path) => CsvReader::from_path(path) .expect("unable to read file") @@ -112,6 +123,7 @@ pub fn read_csv( .with_encoding(encoding) .with_columns(options.columns) .with_n_threads(n_threads) + .with_dtypes(dtypes.as_ref()) .low_memory(options.low_memory) .with_comment_char(comment_char) .with_null_values(null_values)