diff --git a/rust/src/ffi/array/copy.rs b/rust/src/ffi/array/copy.rs index ca9c880..a3414f6 100644 --- a/rust/src/ffi/array/copy.rs +++ b/rust/src/ffi/array/copy.rs @@ -1,8 +1,14 @@ //! Array copy FFI function. -use crate::helpers::error::{self, ERR_GENERIC, SUCCESS}; +use crate::helpers::view::{ + extract_view_bool, extract_view_f32, extract_view_f64, extract_view_i16, extract_view_i32, + extract_view_i64, extract_view_i8, extract_view_u16, extract_view_u32, extract_view_u64, + extract_view_u8, +}; use crate::types::dtype::DType; -use crate::types::{ArrayMetadata, NdArrayHandle}; +use crate::types::{ArrayData, ArrayMetadata, NDArrayWrapper, NdArrayHandle}; +use parking_lot::RwLock; +use std::sync::Arc; /// Create a deep copy of an array view. #[no_mangle] @@ -12,36 +18,94 @@ pub unsafe extern "C" fn ndarray_copy( out_handle: *mut *mut NdArrayHandle, ) -> i32 { if handle.is_null() || meta.is_null() || out_handle.is_null() { - return ERR_GENERIC; + return crate::helpers::error::ERR_GENERIC; } crate::ffi_guard!({ let wrapper = NdArrayHandle::as_wrapper(handle as *mut _); let meta = &*meta; - let result = match wrapper.dtype { - DType::Int8 => wrapper.copy_view_i8(meta), - DType::Int16 => wrapper.copy_view_i16(meta), - DType::Int32 => wrapper.copy_view_i32(meta), - DType::Int64 => wrapper.copy_view_i64(meta), - DType::Uint8 => wrapper.copy_view_u8(meta), - DType::Uint16 => wrapper.copy_view_u16(meta), - DType::Uint32 => wrapper.copy_view_u32(meta), - DType::Uint64 => wrapper.copy_view_u64(meta), - DType::Float32 => wrapper.copy_view_f32(meta), - DType::Float64 => wrapper.copy_view_f64(meta), - DType::Bool => wrapper.copy_view_bool(meta), - }; - - match result { - Ok(new_wrapper) => { - *out_handle = NdArrayHandle::from_wrapper(Box::new(new_wrapper)); - SUCCESS + let new_wrapper = match wrapper.dtype { + DType::Int8 => { + let view = extract_view_i8(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int8(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int8, + } + } + DType::Int16 => { + let view = extract_view_i16(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int16(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int16, + } + } + DType::Int32 => { + let view = extract_view_i32(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int32(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int32, + } + } + DType::Int64 => { + let view = extract_view_i64(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int64(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int64, + } + } + DType::Uint8 => { + let view = extract_view_u8(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint8(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint8, + } + } + DType::Uint16 => { + let view = extract_view_u16(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint16(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint16, + } } - Err(e) => { - error::set_last_error(e); - ERR_GENERIC + DType::Uint32 => { + let view = extract_view_u32(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint32(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint32, + } } - } + DType::Uint64 => { + let view = extract_view_u64(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint64(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint64, + } + } + DType::Float32 => { + let view = extract_view_f32(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Float32(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Float32, + } + } + DType::Float64 => { + let view = extract_view_f64(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Float64(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Float64, + } + } + DType::Bool => { + let view = extract_view_bool(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Bool(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Bool, + } + } + }; + + *out_handle = NdArrayHandle::from_wrapper(Box::new(new_wrapper)); + crate::helpers::error::SUCCESS }) } diff --git a/rust/src/ffi/indexing/assign.rs b/rust/src/ffi/indexing/assign.rs index f18d20c..f655775 100644 --- a/rust/src/ffi/indexing/assign.rs +++ b/rust/src/ffi/indexing/assign.rs @@ -2,7 +2,14 @@ //! //! Provides assign operations between strided array views. -use crate::helpers::error::{self, ERR_GENERIC, SUCCESS}; +use crate::helpers::error::{set_last_error, ERR_GENERIC, SUCCESS}; +use crate::helpers::view::{ + extract_view_bool, extract_view_f32, extract_view_f64, extract_view_i16, extract_view_i32, + extract_view_i64, extract_view_i8, extract_view_mut_bool, extract_view_mut_f32, + extract_view_mut_f64, extract_view_mut_i16, extract_view_mut_i32, extract_view_mut_i64, + extract_view_mut_i8, extract_view_mut_u16, extract_view_mut_u32, extract_view_mut_u64, + extract_view_mut_u8, extract_view_u16, extract_view_u32, extract_view_u64, extract_view_u8, +}; use crate::types::dtype::DType; use crate::types::{ArrayMetadata, NdArrayHandle}; @@ -31,26 +38,130 @@ pub unsafe extern "C" fn ndarray_assign( let dst_meta = &*dst_meta; let src_meta = &*src_meta; - let result = match dst_wrapper.dtype { - DType::Int8 => dst_wrapper.assign_slice_i8(dst_meta, src_wrapper, src_meta), - DType::Int16 => dst_wrapper.assign_slice_i16(dst_meta, src_wrapper, src_meta), - DType::Int32 => dst_wrapper.assign_slice_i32(dst_meta, src_wrapper, src_meta), - DType::Int64 => dst_wrapper.assign_slice_i64(dst_meta, src_wrapper, src_meta), - DType::Uint8 => dst_wrapper.assign_slice_u8(dst_meta, src_wrapper, src_meta), - DType::Uint16 => dst_wrapper.assign_slice_u16(dst_meta, src_wrapper, src_meta), - DType::Uint32 => dst_wrapper.assign_slice_u32(dst_meta, src_wrapper, src_meta), - DType::Uint64 => dst_wrapper.assign_slice_u64(dst_meta, src_wrapper, src_meta), - DType::Float32 => dst_wrapper.assign_slice_f32(dst_meta, src_wrapper, src_meta), - DType::Float64 => dst_wrapper.assign_slice_f64(dst_meta, src_wrapper, src_meta), - DType::Bool => dst_wrapper.assign_slice_bool(dst_meta, src_wrapper, src_meta), - }; - - match result { - Ok(()) => SUCCESS, - Err(e) => { - error::set_last_error(e); - ERR_GENERIC + // Validate dtypes match + if dst_wrapper.dtype != src_wrapper.dtype { + set_last_error(format!( + "DType mismatch in assign: dst={:?}, src={:?}", + dst_wrapper.dtype, src_wrapper.dtype + )); + return ERR_GENERIC; + } + + let is_same = dst_wrapper.is_same_array(src_wrapper); + + match dst_wrapper.dtype { + DType::Int8 => { + let src_view = extract_view_i8(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i8(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Int16 => { + let src_view = extract_view_i16(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i16(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Int32 => { + let src_view = extract_view_i32(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i32(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Int64 => { + let src_view = extract_view_i64(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i64(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint8 => { + let src_view = extract_view_u8(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u8(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint16 => { + let src_view = extract_view_u16(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u16(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint32 => { + let src_view = extract_view_u32(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u32(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint64 => { + let src_view = extract_view_u64(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u64(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Float32 => { + let src_view = extract_view_f32(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_f32(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Float64 => { + let src_view = extract_view_f64(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_f64(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Bool => { + let src_view = extract_view_bool(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_bool(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } } } + + SUCCESS }) } diff --git a/rust/src/ffi/indexing/fill.rs b/rust/src/ffi/indexing/fill.rs index 262576a..661025f 100644 --- a/rust/src/ffi/indexing/fill.rs +++ b/rust/src/ffi/indexing/fill.rs @@ -4,7 +4,11 @@ use std::ffi::c_void; -use crate::helpers::error::{self, ERR_GENERIC, SUCCESS}; +use crate::helpers::view::{ + extract_view_mut_bool, extract_view_mut_f32, extract_view_mut_f64, extract_view_mut_i16, + extract_view_mut_i32, extract_view_mut_i64, extract_view_mut_i8, extract_view_mut_u16, + extract_view_mut_u32, extract_view_mut_u64, extract_view_mut_u8, +}; use crate::types::dtype::DType; use crate::types::{ArrayMetadata, NdArrayHandle}; @@ -21,66 +25,82 @@ pub unsafe extern "C" fn ndarray_fill( value: *const c_void, ) -> i32 { if handle.is_null() || value.is_null() || meta.is_null() { - return ERR_GENERIC; + return crate::helpers::error::ERR_GENERIC; } crate::ffi_guard!({ let wrapper = NdArrayHandle::as_wrapper(handle as *mut _); let meta = &*meta; - let result = match wrapper.dtype { + match wrapper.dtype { DType::Int8 => { let v = *(value as *const i8); - wrapper.fill_slice_i8(v, meta) + extract_view_mut_i8(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Int16 => { let v = *(value as *const i16); - wrapper.fill_slice_i16(v, meta) + extract_view_mut_i16(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Int32 => { let v = *(value as *const i32); - wrapper.fill_slice_i32(v, meta) + extract_view_mut_i32(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Int64 => { let v = *(value as *const i64); - wrapper.fill_slice_i64(v, meta) + extract_view_mut_i64(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint8 => { let v = *(value as *const u8); - wrapper.fill_slice_u8(v, meta) + extract_view_mut_u8(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint16 => { let v = *(value as *const u16); - wrapper.fill_slice_u16(v, meta) + extract_view_mut_u16(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint32 => { let v = *(value as *const u32); - wrapper.fill_slice_u32(v, meta) + extract_view_mut_u32(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint64 => { let v = *(value as *const u64); - wrapper.fill_slice_u64(v, meta) + extract_view_mut_u64(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Float32 => { let v = *(value as *const f32); - wrapper.fill_slice_f32(v, meta) + extract_view_mut_f32(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Float64 => { let v = *(value as *const f64); - wrapper.fill_slice_f64(v, meta) + extract_view_mut_f64(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Bool => { let v = *(value as *const u8); - wrapper.fill_slice_bool(v, meta) - } - }; - - match result { - Ok(()) => SUCCESS, - Err(e) => { - error::set_last_error(e); - ERR_GENERIC + extract_view_mut_bool(wrapper, meta) + .expect("Type mismatch") + .fill(v); } } + + crate::helpers::error::SUCCESS }) } diff --git a/rust/src/helpers/view.rs b/rust/src/helpers/view.rs index 3779e1e..e828b28 100644 --- a/rust/src/helpers/view.rs +++ b/rust/src/helpers/view.rs @@ -5,10 +5,11 @@ use crate::define_extract_view; use crate::define_extract_view_as; +use crate::define_extract_view_mut; use crate::types::ArrayData; use ndarray::ShapeBuilder; -// Generate `extract_view` functions for all types +// Generate `extract_view` functions for all types (immutable) define_extract_view!(extract_view_f64, ArrayData::Float64, f64); define_extract_view!(extract_view_f32, ArrayData::Float32, f32); define_extract_view!(extract_view_i64, ArrayData::Int64, i64); @@ -21,6 +22,19 @@ define_extract_view!(extract_view_u16, ArrayData::Uint16, u16); define_extract_view!(extract_view_u8, ArrayData::Uint8, u8); define_extract_view!(extract_view_bool, ArrayData::Bool, u8); +// Generate `extract_view_mut` functions for all types (mutable) +define_extract_view_mut!(extract_view_mut_f64, ArrayData::Float64, f64); +define_extract_view_mut!(extract_view_mut_f32, ArrayData::Float32, f32); +define_extract_view_mut!(extract_view_mut_i64, ArrayData::Int64, i64); +define_extract_view_mut!(extract_view_mut_i32, ArrayData::Int32, i32); +define_extract_view_mut!(extract_view_mut_i16, ArrayData::Int16, i16); +define_extract_view_mut!(extract_view_mut_i8, ArrayData::Int8, i8); +define_extract_view_mut!(extract_view_mut_u64, ArrayData::Uint64, u64); +define_extract_view_mut!(extract_view_mut_u32, ArrayData::Uint32, u32); +define_extract_view_mut!(extract_view_mut_u16, ArrayData::Uint16, u16); +define_extract_view_mut!(extract_view_mut_u8, ArrayData::Uint8, u8); +define_extract_view_mut!(extract_view_mut_bool, ArrayData::Bool, u8); + // Generate `extract_view_as` functions for all types define_extract_view_as!( extract_view_as_f64, diff --git a/rust/src/macros/core_macros.rs b/rust/src/macros/core_macros.rs index bf464dc..186fb44 100644 --- a/rust/src/macros/core_macros.rs +++ b/rust/src/macros/core_macros.rs @@ -151,191 +151,3 @@ macro_rules! impl_set_element { } }; } - -/// Generate fill_slice_* methods for NDArrayWrapper. -#[macro_export] -macro_rules! impl_fill_slice { - ($($method:ident, $type:ty, $variant:ident);* $(;)?) => { - impl $crate::types::NDArrayWrapper { - $( - pub unsafe fn $method( - &self, - value: $type, - meta: &$crate::types::ArrayMetadata, - ) -> Result<(), String> { - let shape = meta.shape_slice(); - let strides = meta.strides_slice(); - let offset = meta.offset; - - if let $crate::types::ArrayData::$variant(arr) = &self.data { - let mut guard = arr.write(); - let raw_ptr = guard.as_mut_ptr(); - - unsafe { - let ptr = raw_ptr.add(offset); - let strides_ix = ndarray::IxDyn(strides); - let mut view = ndarray::ArrayViewMut::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(shape), strides_ix), - ptr - ); - view.fill(value); - } - Ok(()) - } else { - Err(format!( - "Type mismatch: expected {}, got {:?}", - stringify!($variant), self.dtype - )) - } - } - )* - } - }; -} - -/// Generate assign_slice_* methods for NDArrayWrapper. -#[macro_export] -macro_rules! impl_assign_slice { - ($($method:ident, $type:ty, $variant:ident);* $(;)?) => { - impl $crate::types::NDArrayWrapper { - $( - pub unsafe fn $method( - &self, - meta: &$crate::types::ArrayMetadata, - src: &$crate::types::NDArrayWrapper, - src_meta: &$crate::types::ArrayMetadata, - ) -> Result<(), String> { - if let $crate::types::ArrayData::$variant(_) = &self.data { - } else { - return Err(format!( - "Type mismatch: expected {}, got {:?}", - stringify!($variant), self.dtype - )); - } - - if self.dtype != src.dtype { - return Err(format!( - "DType mismatch in assign: dst={:?}, src={:?}", - self.dtype, src.dtype - )); - } - - let is_same = self.is_same_array(src); - let src_shape = src_meta.shape_slice(); - let src_strides = src_meta.strides_slice(); - let src_offset = src_meta.offset; - let dst_shape = meta.shape_slice(); - let dst_strides = meta.strides_slice(); - let dst_offset = meta.offset; - - if is_same { - let temp_data: Vec<$type> = { - if let $crate::types::ArrayData::$variant(arr) = &src.data { - let guard = arr.read(); - let raw_ptr = guard.as_ptr(); - unsafe { - let ptr = raw_ptr.add(src_offset); - let strides_ix = ndarray::IxDyn(src_strides); - let view = ndarray::ArrayView::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(src_shape), strides_ix), - ptr - ); - view.iter().cloned().collect() - } - } else { - unreachable!("DType checked above"); - } - }; - - if let $crate::types::ArrayData::$variant(arr) = &self.data { - let mut guard = arr.write(); - let raw_ptr = guard.as_mut_ptr(); - unsafe { - let ptr = raw_ptr.add(dst_offset); - let strides_ix = ndarray::IxDyn(dst_strides); - let mut view = ndarray::ArrayViewMut::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(dst_shape), strides_ix), - ptr - ); - let temp_view = ndarray::ArrayView::from_shape(ndarray::IxDyn(dst_shape), &temp_data) - .map_err(|e| e.to_string())?; - view.assign(&temp_view); - } - } - } else { - if let ($crate::types::ArrayData::$variant(dst_arr), $crate::types::ArrayData::$variant(src_arr)) = (&self.data, &src.data) { - let src_guard = src_arr.read(); - let src_ptr = src_guard.as_ptr(); - let mut dst_guard = dst_arr.write(); - let dst_ptr = dst_guard.as_mut_ptr(); - - unsafe { - let s_ptr = src_ptr.add(src_offset); - let s_strides = ndarray::IxDyn(src_strides); - let src_view = ndarray::ArrayView::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(src_shape), s_strides), - s_ptr - ); - let d_ptr = dst_ptr.add(dst_offset); - let d_strides = ndarray::IxDyn(dst_strides); - let mut dst_view = ndarray::ArrayViewMut::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(dst_shape), d_strides), - d_ptr - ); - dst_view.assign(&src_view); - } - } else { - unreachable!("DType checked above"); - } - } - Ok(()) - } - )* - } - }; -} - -/// Generate copy_view_* methods for NDArrayWrapper. -#[macro_export] -macro_rules! impl_copy_view { - ($($method:ident, $type:ty, $variant:ident, $dtype:ident);* $(;)?) => { - impl $crate::types::NDArrayWrapper { - $( - pub unsafe fn $method( - &self, - meta: &$crate::types::ArrayMetadata, - ) -> Result<$crate::types::NDArrayWrapper, String> { - let shape = meta.shape_slice(); - let strides = meta.strides_slice(); - let offset = meta.offset; - - if let $crate::types::ArrayData::$variant(arr) = &self.data { - let guard = arr.read(); - let raw_ptr = guard.as_ptr(); - - unsafe { - let ptr = raw_ptr.add(offset); - let strides_ix = ndarray::IxDyn(strides); - let view = ndarray::ArrayView::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(shape), strides_ix), - ptr - ); - let new_arr = view.to_owned(); - Ok($crate::types::NDArrayWrapper { - data: $crate::types::ArrayData::$variant( - std::sync::Arc::new(parking_lot::RwLock::new(new_arr)) - ), - dtype: $crate::types::dtype::DType::$dtype, - }) - } - } else { - Err(format!( - "Type mismatch: expected {}, got {:?}", - stringify!($variant), self.dtype - )) - } - } - )* - } - }; -} diff --git a/rust/src/macros/view.rs b/rust/src/macros/view.rs index 56c8d68..320763b 100644 --- a/rust/src/macros/view.rs +++ b/rust/src/macros/view.rs @@ -1,6 +1,6 @@ //! Macros to generate `extract_view` and `extract_view_as` functions for each type. -/// Macro to generate `extract_view` functions for each type +/// Macro to generate `extract_view` functions for each type (immutable) #[macro_export] macro_rules! define_extract_view { ($name:ident, $variant:path, $type:ty) => { @@ -33,6 +33,40 @@ macro_rules! define_extract_view { }; } +/// Macro to generate `extract_view_mut` functions for each type (mutable) +#[macro_export] +macro_rules! define_extract_view_mut { + ($name:ident, $variant:path, $type:ty) => { + /// Extract a mutable view of the specific type from the wrapper. + /// + /// # Safety + /// The caller must ensure the ArrayMetadata is valid and that there are + /// no other readers or writers to this array. + pub unsafe fn $name<'a>( + wrapper: &'a crate::types::NDArrayWrapper, + meta: &'a crate::types::ArrayMetadata, + ) -> Option> { + let offset = meta.offset; + let shape = meta.shape_slice(); + let strides = meta.strides_slice(); + match &wrapper.data { + $variant(arr) => { + let mut guard = arr.write(); + let ptr = guard.as_mut_ptr(); + let view_ptr = ptr.add(offset); + let strides_ix = ndarray::IxDyn(strides); + ndarray::ArrayViewMutD::<$type>::from_shape_ptr( + ndarray::IxDyn(shape).strides(strides_ix), + view_ptr, + ) + .into() + } + _ => None, + } + } + }; +} + /// Macro to generate `extract_view_as` functions. #[macro_export] macro_rules! define_extract_view_as { diff --git a/rust/src/types/wrapper.rs b/rust/src/types/wrapper.rs index 6977f41..bce78f2 100644 --- a/rust/src/types/wrapper.rs +++ b/rust/src/types/wrapper.rs @@ -103,51 +103,6 @@ crate::impl_set_element!( set_element_bool, u8, Bool ); -// Generate fill_slice_* methods for view filling -crate::impl_fill_slice!( - fill_slice_i8, i8, Int8; - fill_slice_i16, i16, Int16; - fill_slice_i32, i32, Int32; - fill_slice_i64, i64, Int64; - fill_slice_u8, u8, Uint8; - fill_slice_u16, u16, Uint16; - fill_slice_u32, u32, Uint32; - fill_slice_u64, u64, Uint64; - fill_slice_f32, f32, Float32; - fill_slice_f64, f64, Float64; - fill_slice_bool, u8, Bool -); - -// Generate assign_slice_* methods for array assignment -crate::impl_assign_slice!( - assign_slice_i8, i8, Int8; - assign_slice_i16, i16, Int16; - assign_slice_i32, i32, Int32; - assign_slice_i64, i64, Int64; - assign_slice_u8, u8, Uint8; - assign_slice_u16, u16, Uint16; - assign_slice_u32, u32, Uint32; - assign_slice_u64, u64, Uint64; - assign_slice_f32, f32, Float32; - assign_slice_f64, f64, Float64; - assign_slice_bool, u8, Bool -); - -// Generate copy_view_* methods for creating owned arrays from views -crate::impl_copy_view!( - copy_view_i8, i8, Int8, Int8; - copy_view_i16, i16, Int16, Int16; - copy_view_i32, i32, Int32, Int32; - copy_view_i64, i64, Int64, Int64; - copy_view_u8, u8, Uint8, Uint8; - copy_view_u16, u16, Uint16, Uint16; - copy_view_u32, u32, Uint32, Uint32; - copy_view_u64, u64, Uint64, Uint64; - copy_view_f32, f32, Float32, Float32; - copy_view_f64, f64, Float64, Float64; - copy_view_bool, u8, Bool, Bool -); - #[cfg(test)] mod tests { use super::*;