Skip to content

Commit

Permalink
fix[python]: make view memory safe (#4594)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 28, 2022
1 parent 2bed2e2 commit c3142cc
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 141 deletions.
5 changes: 3 additions & 2 deletions polars/polars-arrow/src/kernels/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use arrow::{datatypes::DataType, types::NativeType};

use crate::array::default_arrays::FromData;
use crate::error::{PolarsError, Result};
use crate::index::IdxSize;
use crate::kernels::BinaryMaskedSliceIterator;
use crate::trusted_len::PushUnchecked;

Expand Down Expand Up @@ -75,15 +76,15 @@ pub fn set_at_idx_no_null<T, I>(
) -> Result<PrimitiveArray<T>>
where
T: NativeType,
I: IntoIterator<Item = usize>,
I: IntoIterator<Item = IdxSize>,
{
let mut buf = Vec::with_capacity(array.len());
buf.extend_from_slice(array.values().as_slice());
let mut_slice = buf.as_mut_slice();

idx.into_iter().try_for_each::<_, Result<_>>(|idx| {
let val = mut_slice
.get_mut(idx)
.get_mut(idx as usize)
.ok_or_else(|| PolarsError::ComputeError("idx is out of bounds".into()))?;
*val = set_value;
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ pub trait ChunkSet<'a, A, B> {
///
/// assert_eq!(Vec::from(&new), &[Some(10), Some(10), Some(3)]);
/// ```
fn set_at_idx<I: IntoIterator<Item = usize>>(
fn set_at_idx<I: IntoIterator<Item = IdxSize>>(
&'a self,
idx: I,
opt_value: Option<A>,
Expand All @@ -257,7 +257,7 @@ pub trait ChunkSet<'a, A, B> {
///
/// assert_eq!(Vec::from(&new), &[Some(-4), Some(-3), Some(3)]);
/// ```
fn set_at_idx_with<I: IntoIterator<Item = usize>, F>(&'a self, idx: I, f: F) -> Result<Self>
fn set_at_idx_with<I: IntoIterator<Item = IdxSize>, F>(&'a self, idx: I, f: F) -> Result<Self>
where
Self: Sized,
F: Fn(Option<A>) -> Option<B>;
Expand Down
30 changes: 15 additions & 15 deletions polars/polars-core/src/chunked_array/ops/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ macro_rules! impl_set_at_idx_with {
let mut ca_iter = $self.into_iter().enumerate();

while let Some(current_idx) = idx_iter.next() {
if current_idx > $self.len() {
if current_idx as usize > $self.len() {
return Err(PolarsError::ComputeError(
format!(
"index: {} outside of ChunkedArray with length: {}",
Expand All @@ -23,7 +23,7 @@ macro_rules! impl_set_at_idx_with {
));
}
while let Some((cnt_idx, opt_val)) = ca_iter.next() {
if cnt_idx == current_idx {
if cnt_idx == current_idx as usize {
$builder.append_option($f(opt_val));
break;
} else {
Expand Down Expand Up @@ -55,7 +55,7 @@ impl<'a, T> ChunkSet<'a, T::Native, T::Native> for ChunkedArray<T>
where
T: PolarsNumericType,
{
fn set_at_idx<I: IntoIterator<Item = usize>>(
fn set_at_idx<I: IntoIterator<Item = IdxSize>>(
&'a self,
idx: I,
value: Option<T::Native>,
Expand All @@ -78,7 +78,7 @@ where
let data = av.as_mut_slice();

idx.into_iter().try_for_each::<_, Result<_>>(|idx| {
let val = data.get_mut(idx).ok_or_else(|| {
let val = data.get_mut(idx as usize).ok_or_else(|| {
PolarsError::ComputeError(
format!("{} out of bounds on array of length: {}", idx, self.len())
.into(),
Expand All @@ -94,7 +94,7 @@ where
self.set_at_idx_with(idx, |_| value)
}

fn set_at_idx_with<I: IntoIterator<Item = usize>, F>(&'a self, idx: I, f: F) -> Result<Self>
fn set_at_idx_with<I: IntoIterator<Item = IdxSize>, F>(&'a self, idx: I, f: F) -> Result<Self>
where
F: Fn(Option<T::Native>) -> Option<T::Native>,
{
Expand Down Expand Up @@ -136,15 +136,15 @@ where
}

impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
fn set_at_idx<I: IntoIterator<Item = usize>>(
fn set_at_idx<I: IntoIterator<Item = IdxSize>>(
&'a self,
idx: I,
value: Option<bool>,
) -> Result<Self> {
self.set_at_idx_with(idx, |_| value)
}

fn set_at_idx_with<I: IntoIterator<Item = usize>, F>(&'a self, idx: I, f: F) -> Result<Self>
fn set_at_idx_with<I: IntoIterator<Item = IdxSize>, F>(&'a self, idx: I, f: F) -> Result<Self>
where
F: Fn(Option<bool>) -> Option<bool>,
{
Expand All @@ -161,14 +161,14 @@ impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
}

for i in idx {
let input = if validity.get(i) {
Some(values.get(i))
let input = if validity.get(i as usize) {
Some(values.get(i as usize))
} else {
None
};
match f(input) {
None => validity.set(i, false),
Some(v) => values.set(i, v),
None => validity.set(i as usize, false),
Some(v) => values.set(i as usize, v),
}
}
let arr = BooleanArray::from_data_default(values.into(), Some(validity.into()));
Expand All @@ -194,7 +194,7 @@ impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
}

impl<'a> ChunkSet<'a, &'a str, String> for Utf8Chunked {
fn set_at_idx<I: IntoIterator<Item = usize>>(
fn set_at_idx<I: IntoIterator<Item = IdxSize>>(
&'a self,
idx: I,
opt_value: Option<&'a str>,
Expand All @@ -207,7 +207,7 @@ impl<'a> ChunkSet<'a, &'a str, String> for Utf8Chunked {
let mut builder = Utf8ChunkedBuilder::new(self.name(), self.len(), self.get_values_size());

for current_idx in idx_iter {
if current_idx > self.len() {
if current_idx as usize > self.len() {
return Err(PolarsError::ComputeError(
format!(
"index: {} outside of ChunkedArray with length: {}",
Expand All @@ -218,7 +218,7 @@ impl<'a> ChunkSet<'a, &'a str, String> for Utf8Chunked {
));
}
for (cnt_idx, opt_val_self) in &mut ca_iter {
if cnt_idx == current_idx {
if cnt_idx == current_idx as usize {
builder.append_option(opt_value);
break;
} else {
Expand All @@ -235,7 +235,7 @@ impl<'a> ChunkSet<'a, &'a str, String> for Utf8Chunked {
Ok(ca)
}

fn set_at_idx_with<I: IntoIterator<Item = usize>, F>(&'a self, idx: I, f: F) -> Result<Self>
fn set_at_idx_with<I: IntoIterator<Item = IdxSize>, F>(&'a self, idx: I, f: F) -> Result<Self>
where
Self: Sized,
F: Fn(Option<&'a str>) -> Option<String>,
Expand Down
59 changes: 59 additions & 0 deletions py-polars/polars/internals/series/_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import ctypes
from typing import Any

from polars import internals as pli

try:
import numpy as np

_NUMPY_AVAILABLE = True
except ImportError:
_NUMPY_AVAILABLE = False


# https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
class SeriesView(np.ndarray):  # type: ignore[type-arg]
    """
    ndarray subclass that carries a reference to the Series it was created from.

    The reference is stored on the ``owned_series`` attribute and is handed on
    to arrays derived from an instance through ``__array_finalize__``.

    NOTE(review): the base class ``np.ndarray`` is evaluated when this class
    statement runs, so importing this module raises ``NameError`` if the
    guarded ``import numpy`` at the top of the file failed.
    """

    def __new__(
        cls, input_array: np.ndarray[Any, Any], owned_series: pli.Series
    ) -> SeriesView:
        # Re-type the already existing ndarray as this subclass, then attach
        # the Series reference to the resulting object.
        view = input_array.view(cls)
        view.owned_series = owned_series
        return view

    def __array_finalize__(self, obj: Any) -> None:
        # obj is None when there is no source array to inherit from; in that
        # case the attribute is deliberately left unset.
        if obj is not None:
            # Take over the reference of the source array; a source without
            # the attribute (e.g. a plain ndarray) yields None.
            self.owned_series = getattr(obj, "owned_series", None)


# https://stackoverflow.com/questions/4355524/getting-data-from-ctypes-array-into-numpy
def _ptr_to_numpy(ptr: int, len: int, ptr_type: Any) -> np.ndarray[Any, Any]:
    """
    Wrap a raw memory address in a numpy array.

    Parameters
    ----------
    ptr
        Address of the first value as an integer (a C/Rust pointer cast to usize).
    len
        Number of values in the memory block.
    ptr_type
        ctypes type of a single value, e.g. ``ctypes.c_float`` for f32.

    Returns
    -------
    One-dimensional numpy array of ``len`` elements over the memory at ``ptr``.

    Raises
    ------
    ImportError
        If numpy could not be imported.
    """
    if not _NUMPY_AVAILABLE:
        raise ImportError("'numpy' is required for this functionality.")
    shape = (len,)
    # Turn the integer address into a typed ctypes pointer that numpy can wrap.
    typed_ptr = ctypes.cast(ptr, ctypes.POINTER(ptr_type))
    return np.ctypeslib.as_array(typed_ptr, shape)
75 changes: 14 additions & 61 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
sequence_to_pyseries,
series_to_pyseries,
)
from polars.internals.series._numpy import SeriesView, _ptr_to_numpy
from polars.internals.series.categorical import CatNameSpace
from polars.internals.series.datetime import DateTimeNameSpace
from polars.internals.series.list import ListNameSpace
Expand All @@ -52,7 +53,6 @@
from polars.utils import (
_date_to_pl_date,
_datetime_to_pl_timestamp,
_ptr_to_numpy,
deprecated_alias,
is_bool_sequence,
is_int_sequence,
Expand Down Expand Up @@ -956,11 +956,7 @@ def std(self, ddof: int = 1) -> float | None:
"""
if not self.is_numeric():
return None
if ddof == 1:
return self.to_frame().select(pli.col(self.name).std()).to_series()[0]
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
return np.std(self.drop_nulls().view(), ddof=ddof)
return self.to_frame().select(pli.col(self.name).std(ddof)).to_series()[0]

def var(self, ddof: int = 1) -> float | None:
"""
Expand All @@ -982,11 +978,7 @@ def var(self, ddof: int = 1) -> float | None:
"""
if not self.is_numeric():
return None
if ddof == 1:
return self.to_frame().select(pli.col(self.name).var()).to_series()[0]
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
return np.var(self.drop_nulls().view(), ddof=ddof)
return self.to_frame().select(pli.col(self.name).var(ddof)).to_series()[0]

def median(self) -> float:
"""
Expand Down Expand Up @@ -2271,24 +2263,12 @@ def is_utf8(self) -> bool:
"""
return self.dtype is Utf8

def view(self, ignore_nulls: bool = False) -> np.ndarray[Any, Any]:
def view(self, ignore_nulls: bool = False) -> SeriesView:
"""
Get a view into this Series data with a numpy array. This operation doesn't
clone data, but does not include missing values. Don't use this unless you know
what you are doing.
.. warning::
This function can lead to undefined behavior in the following cases:
Returns a view to a piece of memory that is already dropped:
>>> pl.Series([1, 3, 5]).sort().view() # doctest: +IGNORE_RESULT
Sums invalid data that is missing:
>>> pl.Series([1, 2, None]).view().sum() # doctest: +SKIP
"""
if not ignore_nulls:
assert not self.has_validity()
Expand All @@ -2297,7 +2277,7 @@ def view(self, ignore_nulls: bool = False) -> np.ndarray[Any, Any]:
ptr = self._s.as_single_ptr()
array = _ptr_to_numpy(ptr, self.len(), ptr_type)
array.setflags(write=False)
return array
return SeriesView(array, self)

def __array__(self, dtype: Any = None) -> np.ndarray[Any, Any]:
if dtype:
Expand Down Expand Up @@ -2524,43 +2504,16 @@ def set_at_idx(
if len(idx) == 0:
return self

if self.is_numeric() or self.is_datelike():
idx = Series("", idx)
if isinstance(value, (int, float, bool)) or (value is None):
value = Series("", [value])

# if we need to set more than a single value, we extend it
if len(idx) > 0:
value = value.extend_constant(value[0], len(idx) - 1)
elif not isinstance(value, Series):
value = Series("", value)
self._s.set_at_idx(idx._s, value._s)
return self

# the set_at_idx function expects a np.array of dtype u32
f = get_ffi_func("set_at_idx_<>", self.dtype, self._s)
if f is None:
raise ValueError(
"could not find the FFI function needed to set at idx for series"
f" {self._s}"
)
if isinstance(idx, Series):
# make sure the dtype matches
idx = idx.cast(get_idx_type())
idx_array = idx.view()
elif _NUMPY_AVAILABLE and isinstance(idx, np.ndarray):
if not idx.data.c_contiguous:
idx_array = np.ascontiguousarray(idx, dtype=np.uint32)
else:
idx_array = idx
if idx_array.dtype != np.uint32:
idx_array = np.array(idx_array, np.uint32)
else:
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
idx_array = np.array(idx, dtype=np.uint32)
idx = Series("", idx)
if isinstance(value, (int, float, bool, str)) or (value is None):
value = Series("", [value])

self._s = f(idx_array, value)
# if we need to set more than a single value, we extend it
if len(idx) > 0:
value = value.extend_constant(value[0], len(idx) - 1)
elif not isinstance(value, Series):
value = Series("", value)
self._s.set_at_idx(idx._s, value._s)
return self

def cleared(self) -> Series:
Expand Down

0 comments on commit c3142cc

Please sign in to comment.