Skip to content

Commit

Permalink
Series/Expr::reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 25, 2021
1 parent 7a37efb commit fb57767
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 0 deletions.
83 changes: 83 additions & 0 deletions polars/polars-core/src/series/ops/to_list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::chunked_array::builder::get_list_builder;
use crate::prelude::*;
use std::borrow::Cow;

impl Series {
/// Convert the values of this Series to a ListChunked with a length of 1,
Expand All @@ -21,6 +23,67 @@ impl Series {
vec![Arc::new(arr)],
))
}

/// Reshape this Series into the shape described by `dims`.
///
/// A List Series is first flattened with `explode`; any other Series is used as is.
///
/// * One dimension, `[len]`, yields a flat Series.
/// * Two dimensions, `[rows, cols]`, yield a List Series with `rows` entries,
///   each holding `cols` consecutive values.
///
/// A single dimension may be given as `-1`; it is then inferred from the number
/// of values and the remaining dimensions.
///
/// # Errors
///
/// Returns `PolarsError::ValueError` when
/// * `dims` is empty,
/// * more than two dimensions are requested (not yet supported),
/// * a `-1` dimension cannot be inferred (the other dimensions multiply to
///   zero, to a negative number, or overflow),
/// * any dimension is still negative after inference,
/// * the product of the dimensions differs from the number of values.
pub fn reshape(&self, dims: &[i64]) -> Result<Series> {
    let s = if let DataType::List(_) = self.dtype() {
        Cow::Owned(self.explode()?)
    } else {
        Cow::Borrowed(self)
    };
    let s_ref = s.as_ref();
    let len = s_ref.len();

    // Single place that builds the shape mismatch error.
    let shape_err = |dims: &[i64]| {
        PolarsError::ValueError(
            format!("cannot reshape len {} into shape {:?}", len, dims).into(),
        )
    };

    if dims.is_empty() {
        return Err(PolarsError::ValueError(
            "dimensions cannot be empty".into(),
        ));
    }

    let mut dims = dims.to_vec();
    if let Some(idx) = dims.iter().position(|&d| d == -1) {
        // Product of all dimensions except the one that must be inferred.
        let known = dims
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != idx)
            .try_fold(1i64, |acc, (_, &d)| acc.checked_mul(d));

        match known {
            // Only a strictly positive product can be divided safely.
            Some(product) if product > 0 => dims[idx] = len as i64 / product,
            _ => return Err(shape_err(&dims)),
        }
    }

    // Reject negative sizes explicitly: a sign-cancelling product or an
    // `as usize` cast would otherwise let them slip through the length check.
    if dims.iter().any(|&d| d < 0) {
        return Err(shape_err(&dims));
    }

    let prod = dims
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d as usize));
    if prod != Some(len) {
        return Err(shape_err(&dims));
    }

    match dims.as_slice() {
        &[n] => Ok(s_ref.slice(0, n as usize)),
        &[rows, cols] => {
            let rows = rows as usize;
            let cols = cols as usize;

            let mut builder = get_list_builder(s_ref.dtype(), len, rows, self.name());

            // Cut the flat values into `rows` consecutive windows of `cols` values.
            let mut offset = 0i64;
            for _ in 0..rows {
                let row = s_ref.slice(offset, cols);
                builder.append_series(&row);
                offset += cols as i64;
            }
            Ok(builder.finish().into_series())
        }
        _ => Err(PolarsError::ValueError(
            "more than two dimensions not yet supported".into(),
        )),
    }
}
}

#[cfg(test)]
Expand All @@ -41,4 +104,24 @@ mod test {

Ok(())
}

#[test]
fn test_reshape() -> Result<()> {
    let s = Series::new("a", &[1, 2, 3, 4]);

    // (requested shape, expected number of rows in the resulting List Series)
    let cases: [(&[i64], usize); 5] = [
        (&[-1, 1], 4),
        (&[4, 1], 4),
        (&[2, 2], 2),
        (&[-1, 2], 2),
        (&[2, -1], 2),
    ];

    for (shape, expected_rows) in cases {
        let reshaped = s.reshape(shape)?;
        assert_eq!(reshaped.len(), expected_rows);
        assert!(matches!(reshaped.dtype(), DataType::List(_)));

        // flattening again must give back all four values
        let flat = reshaped.explode()?;
        assert_eq!(flat.len(), 4);
    }

    Ok(())
}
}
26 changes: 26 additions & 0 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1731,6 +1731,32 @@ impl Expr {
GetOutput::same_type(),
)
}

/// Reshape the Series produced by this expression, see `Series::reshape`.
///
/// The output field keeps the input name. Its dtype is the inner dtype of the
/// input (or the input dtype itself when it has no inner dtype) if exactly one
/// dimension is given, and a `List` of that dtype otherwise.
pub fn reshape(self, dims: &[i64]) -> Self {
    let dims = dims.to_vec();
    let flat = dims.len() == 1;

    let output_type = GetOutput::map_field(move |fld| {
        let dtype = fld
            .data_type()
            .inner_dtype()
            .unwrap_or_else(|| fld.data_type())
            .clone();

        if flat {
            Field::new(fld.name(), dtype)
        } else {
            Field::new(fld.name(), DataType::List(Box::new(dtype)))
        }
    });

    self.apply(move |s| s.reshape(&dims), output_type)
}
}

/// Create a Column Expression based on a column name.
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ Manipulation/ selection
Expr.lower_bound
Expr.upper_bound
Expr.str_concat
Expr.reshape

Column names
------------
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ Manipulation/ selection
Series.interpolate
Series.clip
Series.str_concat
Series.reshape

Various
--------
Expand Down
18 changes: 18 additions & 0 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,24 @@ def arctan(self) -> "Expr":
"""
return np.arctan(self) # type: ignore

def reshape(self, dims: tp.Tuple[int, ...]) -> "Expr":
    """
    Reshape this Expr into the given dimensions.

    A single dimension ``(len,)`` gives a flat series, two dimensions
    ``(rows, cols)`` give a List series. A dimension given as -1 is inferred.

    Parameters
    ----------
    dims
        Tuple of the dimension sizes

    Returns
    -------
    Expr
    """
    reshaped = self._pyexpr.reshape(dims)
    return wrap_expr(reshaped)


class ExprListNameSpace:
"""
Expand Down
18 changes: 18 additions & 0 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2998,6 +2998,24 @@ def str_concat(self, delimiter: str = "-") -> "Series": # type: ignore
pli.col(self.name).str_concat(delimiter) # type: ignore
)[self.name]

def reshape(self, dims: tp.Tuple[int, ...]) -> "Series":
    """
    Reshape this Series into the given dimensions.

    A single dimension ``(len,)`` gives a flat series, two dimensions
    ``(rows, cols)`` give a List series. A dimension given as -1 is inferred.

    Parameters
    ----------
    dims
        Tuple of the dimension sizes

    Returns
    -------
    Series
    """
    reshaped = self._s.reshape(dims)
    return wrap_s(reshaped)


class StringNameSpace:
"""
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,10 @@ impl PyExpr {
}
}, GetOutput::same_type()).into()
}

/// Build a new expression that reshapes the output of this one to `dims`.
pub fn reshape(&self, dims: Vec<i64>) -> Self {
    // `Expr::reshape` consumes the expression, so work on a clone.
    let expr = self.inner.clone();
    let reshaped = expr.reshape(&dims);
    reshaped.into()
}
}

impl From<dsl::Expr> for PyExpr {
Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,11 @@ impl PySeries {
let out = self.series.abs().map_err(PyPolarsEr::from)?;
Ok(out.into())
}

/// Reshape the wrapped Series to `dims`, converting a failure into a Python error.
pub fn reshape(&self, dims: Vec<i64>) -> PyResult<Self> {
    match self.series.reshape(&dims) {
        Ok(reshaped) => Ok(reshaped.into()),
        Err(err) => Err(PyPolarsEr::from(err).into()),
    }
}
}

macro_rules! impl_ufuncs {
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,3 +966,22 @@ def test_compare_series_value_exact_mismatch() -> None:
AssertionError, match="Series are different\n\nExact value mismatch"
):
testing.assert_series_equal(srs1, srs2, check_exact=True)


def test_reshape() -> None:
    s = pl.Series("a", [1, 2, 3, 4])

    # every way of asking for two columns must give the same 2x2 result
    two_by_two = pl.Series("a", [[1, 2], [3, 4]])
    for dims in [(-1, 2), (2, 2), (2, -1)]:
        assert s.reshape(dims).series_equal(two_by_two)

    four_by_one = pl.Series("a", [[1], [2], [3], [4]])
    assert s.reshape((-1, 1)).series_equal(four_by_one)

    # test lazy_dispatch
    lazy_out = pl.select(pl.lit(s).reshape((-1, 1))).to_series()
    assert lazy_out.series_equal(four_by_one)

0 comments on commit fb57767

Please sign in to comment.