From a5784dbc5cdea6f7581efadcccc4a1127478d019 Mon Sep 17 00:00:00 2001 From: Kyrian Obikwelu Date: Mon, 23 Mar 2026 10:57:27 +0100 Subject: [PATCH] refactor(rust): simplify copy, fill, and assign operations Refactor FFI array operations to use standardized view extraction helpers instead of wrapper methods. Removes redundant macros and simplifies the wrapper module by inlining operation logic directly in FFI functions. --- rust/src/ffi/array/copy.rs | 114 ++++++++++++++----- rust/src/ffi/indexing/assign.rs | 151 +++++++++++++++++++++---- rust/src/ffi/indexing/fill.rs | 64 +++++++---- rust/src/helpers/view.rs | 16 ++- rust/src/macros/core_macros.rs | 188 -------------------------------- rust/src/macros/view.rs | 36 +++++- rust/src/types/wrapper.rs | 45 -------- 7 files changed, 312 insertions(+), 302 deletions(-) diff --git a/rust/src/ffi/array/copy.rs b/rust/src/ffi/array/copy.rs index ca9c880..a3414f6 100644 --- a/rust/src/ffi/array/copy.rs +++ b/rust/src/ffi/array/copy.rs @@ -1,8 +1,14 @@ //! Array copy FFI function. -use crate::helpers::error::{self, ERR_GENERIC, SUCCESS}; +use crate::helpers::view::{ + extract_view_bool, extract_view_f32, extract_view_f64, extract_view_i16, extract_view_i32, + extract_view_i64, extract_view_i8, extract_view_u16, extract_view_u32, extract_view_u64, + extract_view_u8, +}; use crate::types::dtype::DType; -use crate::types::{ArrayMetadata, NdArrayHandle}; +use crate::types::{ArrayData, ArrayMetadata, NDArrayWrapper, NdArrayHandle}; +use parking_lot::RwLock; +use std::sync::Arc; /// Create a deep copy of an array view. #[no_mangle] @@ -12,36 +18,94 @@ pub unsafe extern "C" fn ndarray_copy( out_handle: *mut *mut NdArrayHandle, ) -> i32 { if handle.is_null() || meta.is_null() || out_handle.is_null() { - return ERR_GENERIC; + return crate::helpers::error::ERR_GENERIC; } crate::ffi_guard!({ let wrapper = NdArrayHandle::as_wrapper(handle as *mut _); let meta = &*meta; - let result = match wrapper.dtype { - DType::Int8 => wrapper.copy_view_i8(meta), - DType::Int16 => wrapper.copy_view_i16(meta), - DType::Int32 => wrapper.copy_view_i32(meta), - DType::Int64 => wrapper.copy_view_i64(meta), - DType::Uint8 => wrapper.copy_view_u8(meta), - DType::Uint16 => wrapper.copy_view_u16(meta), - DType::Uint32 => wrapper.copy_view_u32(meta), - DType::Uint64 => wrapper.copy_view_u64(meta), - DType::Float32 => wrapper.copy_view_f32(meta), - DType::Float64 => wrapper.copy_view_f64(meta), - DType::Bool => wrapper.copy_view_bool(meta), - }; - - match result { - Ok(new_wrapper) => { - *out_handle = NdArrayHandle::from_wrapper(Box::new(new_wrapper)); - SUCCESS + let new_wrapper = match wrapper.dtype { + DType::Int8 => { + let view = extract_view_i8(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int8(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int8, + } + } + DType::Int16 => { + let view = extract_view_i16(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int16(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int16, + } + } + DType::Int32 => { + let view = extract_view_i32(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int32(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int32, + } + } + DType::Int64 => { + let view = extract_view_i64(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Int64(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Int64, + } + } + DType::Uint8 => { + let view = extract_view_u8(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint8(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint8, + } + } + DType::Uint16 => { + let view = extract_view_u16(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint16(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint16, + } } - Err(e) => { - error::set_last_error(e); - ERR_GENERIC + DType::Uint32 => { + let view = extract_view_u32(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint32(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint32, + } } - } + DType::Uint64 => { + let view = extract_view_u64(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Uint64(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Uint64, + } + } + DType::Float32 => { + let view = extract_view_f32(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Float32(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Float32, + } + } + DType::Float64 => { + let view = extract_view_f64(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Float64(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Float64, + } + } + DType::Bool => { + let view = extract_view_bool(wrapper, meta).expect("Type mismatch"); + NDArrayWrapper { + data: ArrayData::Bool(Arc::new(RwLock::new(view.to_owned()))), + dtype: DType::Bool, + } + } + }; + + *out_handle = NdArrayHandle::from_wrapper(Box::new(new_wrapper)); + crate::helpers::error::SUCCESS }) } diff --git a/rust/src/ffi/indexing/assign.rs b/rust/src/ffi/indexing/assign.rs index f18d20c..f655775 100644 --- a/rust/src/ffi/indexing/assign.rs +++ b/rust/src/ffi/indexing/assign.rs @@ -2,7 +2,14 @@ //! //! Provides assign operations between strided array views. -use crate::helpers::error::{self, ERR_GENERIC, SUCCESS}; +use crate::helpers::error::{set_last_error, ERR_GENERIC, SUCCESS}; +use crate::helpers::view::{ + extract_view_bool, extract_view_f32, extract_view_f64, extract_view_i16, extract_view_i32, + extract_view_i64, extract_view_i8, extract_view_mut_bool, extract_view_mut_f32, + extract_view_mut_f64, extract_view_mut_i16, extract_view_mut_i32, extract_view_mut_i64, + extract_view_mut_i8, extract_view_mut_u16, extract_view_mut_u32, extract_view_mut_u64, + extract_view_mut_u8, extract_view_u16, extract_view_u32, extract_view_u64, extract_view_u8, +}; use crate::types::dtype::DType; use crate::types::{ArrayMetadata, NdArrayHandle}; @@ -31,26 +38,130 @@ pub unsafe extern "C" fn ndarray_assign( let dst_meta = &*dst_meta; let src_meta = &*src_meta; - let result = match dst_wrapper.dtype { - DType::Int8 => dst_wrapper.assign_slice_i8(dst_meta, src_wrapper, src_meta), - DType::Int16 => dst_wrapper.assign_slice_i16(dst_meta, src_wrapper, src_meta), - DType::Int32 => dst_wrapper.assign_slice_i32(dst_meta, src_wrapper, src_meta), - DType::Int64 => dst_wrapper.assign_slice_i64(dst_meta, src_wrapper, src_meta), - DType::Uint8 => dst_wrapper.assign_slice_u8(dst_meta, src_wrapper, src_meta), - DType::Uint16 => dst_wrapper.assign_slice_u16(dst_meta, src_wrapper, src_meta), - DType::Uint32 => dst_wrapper.assign_slice_u32(dst_meta, src_wrapper, src_meta), - DType::Uint64 => dst_wrapper.assign_slice_u64(dst_meta, src_wrapper, src_meta), - DType::Float32 => dst_wrapper.assign_slice_f32(dst_meta, src_wrapper, src_meta), - DType::Float64 => dst_wrapper.assign_slice_f64(dst_meta, src_wrapper, src_meta), - DType::Bool => dst_wrapper.assign_slice_bool(dst_meta, src_wrapper, src_meta), - }; - - match result { - Ok(()) => SUCCESS, - Err(e) => { - error::set_last_error(e); - ERR_GENERIC + // Validate dtypes match + if dst_wrapper.dtype != src_wrapper.dtype { + set_last_error(format!( + "DType mismatch in assign: dst={:?}, src={:?}", + dst_wrapper.dtype, src_wrapper.dtype + )); + return ERR_GENERIC; + } + + let is_same = dst_wrapper.is_same_array(src_wrapper); + + match dst_wrapper.dtype { + DType::Int8 => { + let src_view = extract_view_i8(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i8(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Int16 => { + let src_view = extract_view_i16(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i16(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Int32 => { + let src_view = extract_view_i32(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i32(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Int64 => { + let src_view = extract_view_i64(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_i64(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint8 => { + let src_view = extract_view_u8(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u8(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint16 => { + let src_view = extract_view_u16(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u16(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint32 => { + let src_view = extract_view_u32(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u32(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Uint64 => { + let src_view = extract_view_u64(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_u64(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Float32 => { + let src_view = extract_view_f32(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_f32(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Float64 => { + let src_view = extract_view_f64(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_f64(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } + } + DType::Bool => { + let src_view = extract_view_bool(src_wrapper, src_meta).expect("Type mismatch"); + let mut dst_view = + extract_view_mut_bool(dst_wrapper, dst_meta).expect("Type mismatch"); + if is_same { + dst_view.assign(&src_view.to_owned()); + } else { + dst_view.assign(&src_view); + } } } + + SUCCESS }) } diff --git a/rust/src/ffi/indexing/fill.rs b/rust/src/ffi/indexing/fill.rs index 262576a..661025f 100644 --- a/rust/src/ffi/indexing/fill.rs +++ b/rust/src/ffi/indexing/fill.rs @@ -4,7 +4,11 @@ use std::ffi::c_void; -use crate::helpers::error::{self, ERR_GENERIC, SUCCESS}; +use crate::helpers::view::{ + extract_view_mut_bool, extract_view_mut_f32, extract_view_mut_f64, extract_view_mut_i16, + extract_view_mut_i32, extract_view_mut_i64, extract_view_mut_i8, extract_view_mut_u16, + extract_view_mut_u32, extract_view_mut_u64, extract_view_mut_u8, +}; use crate::types::dtype::DType; use crate::types::{ArrayMetadata, NdArrayHandle}; @@ -21,66 +25,82 @@ pub unsafe extern "C" fn ndarray_fill( value: *const c_void, ) -> i32 { if handle.is_null() || value.is_null() || meta.is_null() { - return ERR_GENERIC; + return crate::helpers::error::ERR_GENERIC; } crate::ffi_guard!({ let wrapper = NdArrayHandle::as_wrapper(handle as *mut _); let meta = &*meta; - let result = match wrapper.dtype { + match wrapper.dtype { DType::Int8 => { let v = *(value as *const i8); - wrapper.fill_slice_i8(v, meta) + extract_view_mut_i8(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Int16 => { let v = *(value as *const i16); - wrapper.fill_slice_i16(v, meta) + extract_view_mut_i16(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Int32 => { let v = *(value as *const i32); - wrapper.fill_slice_i32(v, meta) + extract_view_mut_i32(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Int64 => { let v = *(value as *const i64); - wrapper.fill_slice_i64(v, meta) + extract_view_mut_i64(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint8 => { let v = *(value as *const u8); - wrapper.fill_slice_u8(v, meta) + extract_view_mut_u8(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint16 => { let v = *(value as *const u16); - wrapper.fill_slice_u16(v, meta) + extract_view_mut_u16(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint32 => { let v = *(value as *const u32); - wrapper.fill_slice_u32(v, meta) + extract_view_mut_u32(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Uint64 => { let v = *(value as *const u64); - wrapper.fill_slice_u64(v, meta) + extract_view_mut_u64(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Float32 => { let v = *(value as *const f32); - wrapper.fill_slice_f32(v, meta) + extract_view_mut_f32(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Float64 => { let v = *(value as *const f64); - wrapper.fill_slice_f64(v, meta) + extract_view_mut_f64(wrapper, meta) + .expect("Type mismatch") + .fill(v); } DType::Bool => { let v = *(value as *const u8); - wrapper.fill_slice_bool(v, meta) - } - }; - - match result { - Ok(()) => SUCCESS, - Err(e) => { - error::set_last_error(e); - ERR_GENERIC + extract_view_mut_bool(wrapper, meta) + .expect("Type mismatch") + .fill(v); } } + + crate::helpers::error::SUCCESS }) } diff --git a/rust/src/helpers/view.rs b/rust/src/helpers/view.rs index 3779e1e..e828b28 100644 --- a/rust/src/helpers/view.rs +++ b/rust/src/helpers/view.rs @@ -5,10 +5,11 @@ use crate::define_extract_view; use crate::define_extract_view_as; +use crate::define_extract_view_mut; use crate::types::ArrayData; use ndarray::ShapeBuilder; -// Generate `extract_view` functions for all types +// Generate `extract_view` functions for all types (immutable) define_extract_view!(extract_view_f64, ArrayData::Float64, f64); define_extract_view!(extract_view_f32, ArrayData::Float32, f32); define_extract_view!(extract_view_i64, ArrayData::Int64, i64); @@ -21,6 +22,19 @@ define_extract_view!(extract_view_u16, ArrayData::Uint16, u16); define_extract_view!(extract_view_u8, ArrayData::Uint8, u8); define_extract_view!(extract_view_bool, ArrayData::Bool, u8); +// Generate `extract_view_mut` functions for all types (mutable) +define_extract_view_mut!(extract_view_mut_f64, ArrayData::Float64, f64); +define_extract_view_mut!(extract_view_mut_f32, ArrayData::Float32, f32); +define_extract_view_mut!(extract_view_mut_i64, ArrayData::Int64, i64); +define_extract_view_mut!(extract_view_mut_i32, ArrayData::Int32, i32); +define_extract_view_mut!(extract_view_mut_i16, ArrayData::Int16, i16); +define_extract_view_mut!(extract_view_mut_i8, ArrayData::Int8, i8); +define_extract_view_mut!(extract_view_mut_u64, ArrayData::Uint64, u64); +define_extract_view_mut!(extract_view_mut_u32, ArrayData::Uint32, u32); +define_extract_view_mut!(extract_view_mut_u16, ArrayData::Uint16, u16); +define_extract_view_mut!(extract_view_mut_u8, ArrayData::Uint8, u8); +define_extract_view_mut!(extract_view_mut_bool, ArrayData::Bool, u8); + // Generate `extract_view_as` functions for all types define_extract_view_as!( extract_view_as_f64, diff --git a/rust/src/macros/core_macros.rs b/rust/src/macros/core_macros.rs index bf464dc..186fb44 100644 --- a/rust/src/macros/core_macros.rs +++ b/rust/src/macros/core_macros.rs @@ -151,191 +151,3 @@ macro_rules! impl_set_element { } }; } - -/// Generate fill_slice_* methods for NDArrayWrapper. -#[macro_export] -macro_rules! impl_fill_slice { - ($($method:ident, $type:ty, $variant:ident);* $(;)?) => { - impl $crate::types::NDArrayWrapper { - $( - pub unsafe fn $method( - &self, - value: $type, - meta: &$crate::types::ArrayMetadata, - ) -> Result<(), String> { - let shape = meta.shape_slice(); - let strides = meta.strides_slice(); - let offset = meta.offset; - - if let $crate::types::ArrayData::$variant(arr) = &self.data { - let mut guard = arr.write(); - let raw_ptr = guard.as_mut_ptr(); - - unsafe { - let ptr = raw_ptr.add(offset); - let strides_ix = ndarray::IxDyn(strides); - let mut view = ndarray::ArrayViewMut::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(shape), strides_ix), - ptr - ); - view.fill(value); - } - Ok(()) - } else { - Err(format!( - "Type mismatch: expected {}, got {:?}", - stringify!($variant), self.dtype - )) - } - } - )* - } - }; -} - -/// Generate assign_slice_* methods for NDArrayWrapper. -#[macro_export] -macro_rules! impl_assign_slice { - ($($method:ident, $type:ty, $variant:ident);* $(;)?) => { - impl $crate::types::NDArrayWrapper { - $( - pub unsafe fn $method( - &self, - meta: &$crate::types::ArrayMetadata, - src: &$crate::types::NDArrayWrapper, - src_meta: &$crate::types::ArrayMetadata, - ) -> Result<(), String> { - if let $crate::types::ArrayData::$variant(_) = &self.data { - } else { - return Err(format!( - "Type mismatch: expected {}, got {:?}", - stringify!($variant), self.dtype - )); - } - - if self.dtype != src.dtype { - return Err(format!( - "DType mismatch in assign: dst={:?}, src={:?}", - self.dtype, src.dtype - )); - } - - let is_same = self.is_same_array(src); - let src_shape = src_meta.shape_slice(); - let src_strides = src_meta.strides_slice(); - let src_offset = src_meta.offset; - let dst_shape = meta.shape_slice(); - let dst_strides = meta.strides_slice(); - let dst_offset = meta.offset; - - if is_same { - let temp_data: Vec<$type> = { - if let $crate::types::ArrayData::$variant(arr) = &src.data { - let guard = arr.read(); - let raw_ptr = guard.as_ptr(); - unsafe { - let ptr = raw_ptr.add(src_offset); - let strides_ix = ndarray::IxDyn(src_strides); - let view = ndarray::ArrayView::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(src_shape), strides_ix), - ptr - ); - view.iter().cloned().collect() - } - } else { - unreachable!("DType checked above"); - } - }; - - if let $crate::types::ArrayData::$variant(arr) = &self.data { - let mut guard = arr.write(); - let raw_ptr = guard.as_mut_ptr(); - unsafe { - let ptr = raw_ptr.add(dst_offset); - let strides_ix = ndarray::IxDyn(dst_strides); - let mut view = ndarray::ArrayViewMut::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(dst_shape), strides_ix), - ptr - ); - let temp_view = ndarray::ArrayView::from_shape(ndarray::IxDyn(dst_shape), &temp_data) - .map_err(|e| e.to_string())?; - view.assign(&temp_view); - } - } - } else { - if let ($crate::types::ArrayData::$variant(dst_arr), $crate::types::ArrayData::$variant(src_arr)) = (&self.data, &src.data) { - let src_guard = src_arr.read(); - let src_ptr = src_guard.as_ptr(); - let mut dst_guard = dst_arr.write(); - let dst_ptr = dst_guard.as_mut_ptr(); - - unsafe { - let s_ptr = src_ptr.add(src_offset); - let s_strides = ndarray::IxDyn(src_strides); - let src_view = ndarray::ArrayView::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(src_shape), s_strides), - s_ptr - ); - let d_ptr = dst_ptr.add(dst_offset); - let d_strides = ndarray::IxDyn(dst_strides); - let mut dst_view = ndarray::ArrayViewMut::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(dst_shape), d_strides), - d_ptr - ); - dst_view.assign(&src_view); - } - } else { - unreachable!("DType checked above"); - } - } - Ok(()) - } - )* - } - }; -} - -/// Generate copy_view_* methods for NDArrayWrapper. -#[macro_export] -macro_rules! impl_copy_view { - ($($method:ident, $type:ty, $variant:ident, $dtype:ident);* $(;)?) => { - impl $crate::types::NDArrayWrapper { - $( - pub unsafe fn $method( - &self, - meta: &$crate::types::ArrayMetadata, - ) -> Result<$crate::types::NDArrayWrapper, String> { - let shape = meta.shape_slice(); - let strides = meta.strides_slice(); - let offset = meta.offset; - - if let $crate::types::ArrayData::$variant(arr) = &self.data { - let guard = arr.read(); - let raw_ptr = guard.as_ptr(); - - unsafe { - let ptr = raw_ptr.add(offset); - let strides_ix = ndarray::IxDyn(strides); - let view = ndarray::ArrayView::from_shape_ptr( - ndarray::ShapeBuilder::strides(ndarray::IxDyn(shape), strides_ix), - ptr - ); - let new_arr = view.to_owned(); - Ok($crate::types::NDArrayWrapper { - data: $crate::types::ArrayData::$variant( - std::sync::Arc::new(parking_lot::RwLock::new(new_arr)) - ), - dtype: $crate::types::dtype::DType::$dtype, - }) - } - } else { - Err(format!( - "Type mismatch: expected {}, got {:?}", - stringify!($variant), self.dtype - )) - } - } - )* - } - }; -} diff --git a/rust/src/macros/view.rs b/rust/src/macros/view.rs index 56c8d68..320763b 100644 --- a/rust/src/macros/view.rs +++ b/rust/src/macros/view.rs @@ -1,6 +1,6 @@ //! Macros to generate `extract_view` and `extract_view_as` functions for each type. -/// Macro to generate `extract_view` functions for each type +/// Macro to generate `extract_view` functions for each type (immutable) #[macro_export] macro_rules! define_extract_view { ($name:ident, $variant:path, $type:ty) => { @@ -33,6 +33,40 @@ macro_rules! define_extract_view { }; } +/// Macro to generate `extract_view_mut` functions for each type (mutable) +#[macro_export] +macro_rules! define_extract_view_mut { + ($name:ident, $variant:path, $type:ty) => { + /// Extract a mutable view of the specific type from the wrapper. + /// + /// # Safety + /// The caller must ensure the ArrayMetadata is valid and that there are + /// no other readers or writers to this array. + pub unsafe fn $name<'a>( + wrapper: &'a crate::types::NDArrayWrapper, + meta: &'a crate::types::ArrayMetadata, + ) -> Option> { + let offset = meta.offset; + let shape = meta.shape_slice(); + let strides = meta.strides_slice(); + match &wrapper.data { + $variant(arr) => { + let mut guard = arr.write(); + let ptr = guard.as_mut_ptr(); + let view_ptr = ptr.add(offset); + let strides_ix = ndarray::IxDyn(strides); + ndarray::ArrayViewMutD::<$type>::from_shape_ptr( + ndarray::IxDyn(shape).strides(strides_ix), + view_ptr, + ) + .into() + } + _ => None, + } + } + }; +} + /// Macro to generate `extract_view_as` functions. #[macro_export] macro_rules! define_extract_view_as { diff --git a/rust/src/types/wrapper.rs b/rust/src/types/wrapper.rs index 6977f41..bce78f2 100644 --- a/rust/src/types/wrapper.rs +++ b/rust/src/types/wrapper.rs @@ -103,51 +103,6 @@ crate::impl_set_element!( set_element_bool, u8, Bool ); -// Generate fill_slice_* methods for view filling -crate::impl_fill_slice!( - fill_slice_i8, i8, Int8; - fill_slice_i16, i16, Int16; - fill_slice_i32, i32, Int32; - fill_slice_i64, i64, Int64; - fill_slice_u8, u8, Uint8; - fill_slice_u16, u16, Uint16; - fill_slice_u32, u32, Uint32; - fill_slice_u64, u64, Uint64; - fill_slice_f32, f32, Float32; - fill_slice_f64, f64, Float64; - fill_slice_bool, u8, Bool -); - -// Generate assign_slice_* methods for array assignment -crate::impl_assign_slice!( - assign_slice_i8, i8, Int8; - assign_slice_i16, i16, Int16; - assign_slice_i32, i32, Int32; - assign_slice_i64, i64, Int64; - assign_slice_u8, u8, Uint8; - assign_slice_u16, u16, Uint16; - assign_slice_u32, u32, Uint32; - assign_slice_u64, u64, Uint64; - assign_slice_f32, f32, Float32; - assign_slice_f64, f64, Float64; - assign_slice_bool, u8, Bool -); - -// Generate copy_view_* methods for creating owned arrays from views -crate::impl_copy_view!( - copy_view_i8, i8, Int8, Int8; - copy_view_i16, i16, Int16, Int16; - copy_view_i32, i32, Int32, Int32; - copy_view_i64, i64, Int64, Int64; - copy_view_u8, u8, Uint8, Uint8; - copy_view_u16, u16, Uint16, Uint16; - copy_view_u32, u32, Uint32, Uint32; - copy_view_u64, u64, Uint64, Uint64; - copy_view_f32, f32, Float32, Float32; - copy_view_f64, f64, Float64, Float64; - copy_view_bool, u8, Bool, Bool -); - #[cfg(test)] mod tests { use super::*;