Skip to content

Commit

Permalink
feat: add Decimal type (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Jun 12, 2024
1 parent 0707fb9 commit 9f1e2f1
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 22 deletions.
12 changes: 12 additions & 0 deletions __tests__/io.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ describe("parquet", () => {
const df = pl.scanParquet(parquetpath, { nRows: 4 }).collectSync();
expect(df.shape).toEqual({ height: 4, width: 4 });
});

test("writeParquet with decimals", async () => {
  // One Decimal column alongside ordinary numeric and string columns.
  const decimalCol = pl.Series("decimal", [1n, 2n, 3n], pl.Decimal());
  const u32Col = pl.Series("u32", [1, 2, 3], pl.UInt32);
  const strCol = pl.Series("str", ["a", "b", "c"]);
  const original = pl.DataFrame([decimalCol, u32Col, strCol]);

  // Write to an in-memory parquet buffer, read it back, and compare frames.
  const roundTripped = pl.readParquet(original.writeParquet());
  expect(roundTripped).toFrameEqual(original);
});
});

describe("ipc", () => {
Expand Down
27 changes: 26 additions & 1 deletion __tests__/series.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/* eslint-disable newline-per-chained-call */
import pl from "@polars";
import { InvalidOperationError } from "../polars/error";
import Chance from "chance";

describe("from lists", () => {
Expand Down Expand Up @@ -186,6 +185,32 @@ describe("typedArrays", () => {
const actual = pl.Series(float64Array).toTypedArray();
expect(JSON.stringify(actual)).toEqual(JSON.stringify(float64Array));
});

test("decimal", () => {
  const expected = [1n, 2n, 3n];
  const expectedDtype = pl.Decimal(10, 2);
  const actual = pl.Series("", expected, expectedDtype);
  // The requested Decimal dtype must be reported back unchanged.
  expect(actual.dtype).toEqual(expectedDtype);
  // Collecting Decimal values into JS must fail with the explanatory message.
  // Asserting the throw directly (instead of try/catch) makes the test fail
  // if toArray() ever returns normally, rather than passing with no assertion run.
  expect(() => actual.toArray()).toThrow(
    "Decimal is not a supported type in javascript, please convert to string or number before collecting to js",
  );
});

test("fixed list", () => {
  const rows = [
    [1, 2, 3],
    [4, 5, 6],
  ];
  const dtype = pl.FixedSizeList(pl.Float32, 3);
  const series = pl.Series("", rows, dtype);

  // The dtype must survive construction and the values must round-trip to JS.
  expect(series.dtype).toEqual(dtype);
  expect(series.toArray()).toEqual(rows);
});
});
describe("series", () => {
const chance = new Chance();
Expand Down
43 changes: 39 additions & 4 deletions polars/datatypes/datatype.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ export abstract class DataType {
return new Categorical();
}

/**
 * Decimal type
 * @param precision optional precision, passed through to the `Decimal` constructor
 * @param scale optional scale, passed through to the `Decimal` constructor
 */
public static Decimal(precision?: number, scale?: number): DataType {
  return new Decimal(precision, scale);
}

/**
* Calendar date and time type
* @param timeUnit any of 'ms' | 'ns' | 'us'
Expand Down Expand Up @@ -186,6 +191,39 @@ export class String extends DataType {}

export class Categorical extends DataType {}

/**
 * Decimal type, optionally parameterised by a precision and a scale.
 * A value that is not supplied is stored as `null`.
 */
export class Decimal extends DataType {
  private precision: number | null;
  private scale: number | null;

  constructor(precision?: number, scale?: number) {
    super();
    this.precision = precision ?? null;
    this.scale = scale ?? null;
  }

  /** `[precision, scale]`, in that order. */
  override get inner() {
    return [this.precision, this.scale];
  }

  /** True only when `other` has the same variant and identical precision and scale. */
  override equals(other: DataType): boolean {
    if (other.variant !== this.variant) {
      return false;
    }
    const that = other as Decimal;
    return this.precision === that.precision && this.scale === that.scale;
  }

  /** Serialises as `{ [identity]: { [variant]: { precision, scale } } }`. */
  override toJSON() {
    const { precision, scale } = this;
    return {
      [this.identity]: {
        [this.variant]: { precision, scale },
      },
    };
  }
}

/**
* Datetime type
*/
Expand Down Expand Up @@ -234,10 +272,6 @@ export class FixedSizeList extends DataType {
super();
}

override get variant() {
return "FixedSizeList";
}

override get inner(): [DataType, number] {
return [this.__inner, this.listSize];
}
Expand Down Expand Up @@ -349,6 +383,7 @@ export namespace DataType {
export type Object = import(".").Object_;
export type Null = import(".").Null;
export type Struct = import(".").Struct;
export type Decimal = import(".").Decimal;
/**
* deserializes a datatype from the serde output of rust polars `DataType`
* @param dtype dtype object
Expand Down
5 changes: 4 additions & 1 deletion polars/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Decimal } from "./datatypes/datatype";
import * as series from "./series";
import * as df from "./dataframe";
import { DataType, Field as _field } from "./datatypes";
export { DataType } from "./datatypes";
export * from "./datatypes";
import * as func from "./functions";
import * as io from "./io";
import * as cfg from "./cfg";
Expand Down Expand Up @@ -111,6 +112,7 @@ export namespace pl {
export type Object = import("./datatypes").Object_;
export type Null = import("./datatypes").Null;
export type Struct = import("./datatypes").Struct;
export type Decimal = import("./datatypes").Decimal;

export const Categorical = DataType.Categorical;
export const Int8 = DataType.Int8;
Expand All @@ -137,6 +139,7 @@ export namespace pl {
export const Object = DataType.Object;
export const Null = DataType.Null;
export const Struct = DataType.Struct;
export const Decimal = DataType.Decimal;

/**
* Run SQL queries against DataFrame/LazyFrame data.
Expand Down
7 changes: 7 additions & 0 deletions polars/internals/construction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ export function arrayToJsSeries(

return df.toStruct(name);
}

if (dtype?.variant === "Decimal") {
if (typeof firstValue !== "bigint") {
throw new Error("Decimal type can only be constructed from BigInt");
}
return pli.JsSeries.newAnyvalue(name, values, dtype, strict);
}
if (firstValue instanceof Date) {
series = pli.JsSeries.newOptDate(name, values, strict);
} else {
Expand Down
40 changes: 39 additions & 1 deletion src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl ToSeries for JsUnknown {
Series::new("", v)
}
}

impl ToNapiValue for Wrap<&Series> {
unsafe fn to_napi_value(napi_env: sys::napi_env, val: Self) -> napi::Result<sys::napi_value> {
let s = val.0;
Expand Down Expand Up @@ -101,6 +102,7 @@ impl ToNapiValue for Wrap<&Series> {
}
}
}

impl<'a> ToNapiValue for Wrap<AnyValue<'a>> {
unsafe fn to_napi_value(env: sys::napi_env, val: Self) -> Result<sys::napi_value> {
match val.0 {
Expand Down Expand Up @@ -152,7 +154,16 @@ impl<'a> ToNapiValue for Wrap<AnyValue<'a>> {
AnyValue::Time(v) => i64::to_napi_value(env, v),
AnyValue::List(ser) => Wrap::<&Series>::to_napi_value(env, Wrap(&ser)),
ref av @ AnyValue::Struct(_, _, flds) => struct_dict(env, av._iter_struct_av(), flds),
_ => todo!(),
AnyValue::Array(ser, _) => Wrap::<&Series>::to_napi_value(env, Wrap(&ser)),
AnyValue::Enum(_, _, _) => todo!(),
AnyValue::Object(_) => todo!(),
AnyValue::ObjectOwned(_) => todo!(),
AnyValue::StructOwned(_) => todo!(),
AnyValue::Binary(_) => todo!(),
AnyValue::BinaryOwned(_) => todo!(),
AnyValue::Decimal(_, _) => {
Err(napi::Error::from_reason("Decimal is not a supported type in javascript, please convert to string or number before collecting to js"))
}
}
}
}
Expand Down Expand Up @@ -679,6 +690,12 @@ impl FromNapiValue for Wrap<DataType> {
}
DataType::Struct(fldvec)
}
"Decimal" => {
let inner = obj.get::<_, Array>("inner")?.unwrap(); // [precision, scale]
let precision = inner.get::<Option<i32>>(0)?.unwrap().map(|x| x as usize);
let scale = inner.get::<Option<i32>>(1)?.unwrap().map(|x| x as usize);
DataType::Decimal(precision, scale)
}
tp => panic!("Type {} not implemented in str_to_polarstype", tp),
};
Ok(Wrap(dtype))
Expand Down Expand Up @@ -963,6 +980,27 @@ impl ToNapiValue for Wrap<DataType> {

Object::to_napi_value(env, obj)
}
DataType::Array(dtype, size) => {
let env_ctx = Env::from_raw(env);
let mut obj = env_ctx.create_object()?;
let wrapped = Wrap(*dtype);
let mut inner_arr = env_ctx.create_array(2)?;
inner_arr.set(0, wrapped)?;
inner_arr.set(1, size as u32)?;
obj.set("variant", "FixedSizeList")?;
obj.set("inner", inner_arr)?;
Object::to_napi_value(env, obj)
}
DataType::Decimal(precision, scale) => {
let env_ctx = Env::from_raw(env);
let mut obj = env_ctx.create_object()?;
let mut inner_arr = env_ctx.create_array(2)?;
inner_arr.set(0, precision.map(|p| p as u32))?;
inner_arr.set(1, scale.map(|s| s as u32))?;
obj.set("variant", "Decimal")?;
obj.set("inner", inner_arr)?;
Object::to_napi_value(env, obj)
}
_ => {
todo!()
}
Expand Down
10 changes: 6 additions & 4 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ pub struct ReadCsvOptions {
fn mmap_reader_to_df<'a>(
csv: impl MmapBytesReader + 'a,
options: ReadCsvOptions,
) -> napi::Result<JsDataFrame>
{
) -> napi::Result<JsDataFrame> {
let null_values = options.null_values.map(|w| w.0);
let row_count = options.row_count.map(RowIndex::from);
let projection = options
Expand Down Expand Up @@ -598,12 +597,15 @@ impl JsDataFrame {

let df = self
.df
.join(&other.df, left_on, right_on,
.join(
&other.df,
left_on,
right_on,
JoinArgs {
how: how,
suffix: suffix,
..Default::default()
}
},
)
.map_err(JsPolarsErr::from)?;
Ok(JsDataFrame::new(df))
Expand Down
24 changes: 15 additions & 9 deletions src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,16 +955,22 @@ impl JsExpr {
.into()
}
#[napi(catch_unwind)]
pub fn replace(&self, old: &JsExpr, new: &JsExpr, default: Option<&JsExpr>, return_dtype: Option<Wrap<DataType>>) -> JsExpr {
pub fn replace(
&self,
old: &JsExpr,
new: &JsExpr,
default: Option<&JsExpr>,
return_dtype: Option<Wrap<DataType>>,
) -> JsExpr {
self.inner
.clone()
.replace(
old.inner.clone(),
new.inner.clone(),
default.map(|e| e.inner.clone()),
return_dtype.map(|dt| dt.0),
)
.into()
.clone()
.replace(
old.inner.clone(),
new.inner.clone(),
default.map(|e| e.inner.clone()),
return_dtype.map(|dt| dt.0),
)
.into()
}
#[napi(catch_unwind)]
pub fn year(&self) -> JsExpr {
Expand Down
27 changes: 25 additions & 2 deletions src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,22 @@ impl JsSeries {
.into_series()
.into())
}

/// Factory: builds a series named `name` from `values` using the requested `dtype`.
///
/// Each wrapped value is unwrapped and the batch is handed to
/// `Series::from_any_values_and_dtype`; `strict` is forwarded to that call unchanged.
/// A failure there is converted through `JsPolarsErr` and returned as the error.
#[napi(factory, catch_unwind)]
pub fn new_anyvalue(
    name: String,
    values: Vec<Wrap<AnyValue>>,
    dtype: Wrap<DataType>,
    strict: bool,
) -> napi::Result<JsSeries> {
    let any_values: Vec<_> = values.into_iter().map(|wrapped| wrapped.0).collect();

    let built = Series::from_any_values_and_dtype(&name, &any_values, &dtype.0, strict);
    match built {
        Ok(series) => Ok(series.into()),
        Err(err) => Err(JsPolarsErr::from(err).into()),
    }
}

#[napi(factory, catch_unwind)]
pub fn new_list(name: String, values: Array, dtype: Wrap<DataType>) -> napi::Result<JsSeries> {
use crate::list_construction::js_arr_to_list;
Expand Down Expand Up @@ -993,8 +1009,15 @@ impl JsSeries {
// Ok(ca.into_series().into())
// }
#[napi(catch_unwind)]
pub fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> napi::Result<JsDataFrame> {
let df = self.series.to_dummies(separator, drop_first).map_err(JsPolarsErr::from)?;
pub fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
) -> napi::Result<JsDataFrame> {
let df = self
.series
.to_dummies(separator, drop_first)
.map_err(JsPolarsErr::from)?;
Ok(df.into())
}
#[napi(catch_unwind)]
Expand Down

0 comments on commit 9f1e2f1

Please sign in to comment.