Skip to content

Commit

Permalink
python fix concat (#3743)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 19, 2022
1 parent 8bbff3a commit 62f6a41
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
11 changes: 5 additions & 6 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ use mimalloc::MiMalloc;
use polars::functions::{diag_concat_df, hor_concat_df};
use polars::prelude::Null;
use polars_core::datatypes::TimeUnit;
use polars_core::frame::row::Row;
use polars_core::prelude::DataFrame;
use polars_core::prelude::IntoSeries;
use polars_core::POOL;
Expand Down Expand Up @@ -266,29 +265,29 @@ fn py_duration(

#[pyfunction]
fn concat_df(dfs: &PyAny, py: Python) -> PyResult<PyDataFrame> {
use polars_core::utils::rayon::prelude::*;
use polars_core::{error::Result, utils::rayon::prelude::*};

let (seq, _len) = get_pyseq(dfs)?;
let mut iter = seq.iter()?;
let first = iter.next().unwrap()?;

let first_rdf = get_df(first)?;
let schema = first_rdf.schema();
let identity_df = first_rdf.slice(0, 0);

let mut rdfs: Vec<polars_core::error::Result<DataFrame>> = vec![Ok(first_rdf)];
let mut rdfs: Vec<Result<DataFrame>> = vec![Ok(first_rdf)];

for item in iter {
let rdf = get_df(item?)?;
rdfs.push(Ok(rdf));
}

let identity = || DataFrame::from_rows_and_schema(&[Row::default()], &schema);
let identity = || Ok(identity_df.clone());

let df = py
.allow_threads(|| {
polars_core::POOL.install(|| {
rdfs.into_par_iter()
.fold(identity, |acc, df| {
.fold(identity, |acc: Result<DataFrame>, df| {
let mut acc = acc?;
acc.vstack_mut(&df?)?;
Ok(acc)
Expand Down
27 changes: 27 additions & 0 deletions py-polars/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,30 @@ def demean_dot() -> pl.Expr:
]
)
).to_dict(False) == {"key": ["a"], "demean_dot": [0.0]}


def test_dtype_concat_3735() -> None:
for dt in [
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Float32,
pl.Float64,
]:
d1 = pl.DataFrame(
[
pl.Series("val", [1, 2], dtype=dt),
]
)
d2 = pl.DataFrame(
[
pl.Series("val", [3, 4], dtype=dt),
]
)
df = pl.concat([d1, d2])
assert df.shape == (4, 1)

0 comments on commit 62f6a41

Please sign in to comment.