Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 89 additions & 25 deletions rust/src/ffi/array/copy.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
//! Array copy FFI function.

use crate::helpers::error::{self, ERR_GENERIC, SUCCESS};
use crate::helpers::view::{
extract_view_bool, extract_view_f32, extract_view_f64, extract_view_i16, extract_view_i32,
extract_view_i64, extract_view_i8, extract_view_u16, extract_view_u32, extract_view_u64,
extract_view_u8,
};
use crate::types::dtype::DType;
use crate::types::{ArrayMetadata, NdArrayHandle};
use crate::types::{ArrayData, ArrayMetadata, NDArrayWrapper, NdArrayHandle};
use parking_lot::RwLock;
use std::sync::Arc;

/// Create a deep copy of an array view.
#[no_mangle]
Expand All @@ -12,36 +18,94 @@ pub unsafe extern "C" fn ndarray_copy(
out_handle: *mut *mut NdArrayHandle,
) -> i32 {
if handle.is_null() || meta.is_null() || out_handle.is_null() {
return ERR_GENERIC;
return crate::helpers::error::ERR_GENERIC;
}

crate::ffi_guard!({
let wrapper = NdArrayHandle::as_wrapper(handle as *mut _);
let meta = &*meta;

let result = match wrapper.dtype {
DType::Int8 => wrapper.copy_view_i8(meta),
DType::Int16 => wrapper.copy_view_i16(meta),
DType::Int32 => wrapper.copy_view_i32(meta),
DType::Int64 => wrapper.copy_view_i64(meta),
DType::Uint8 => wrapper.copy_view_u8(meta),
DType::Uint16 => wrapper.copy_view_u16(meta),
DType::Uint32 => wrapper.copy_view_u32(meta),
DType::Uint64 => wrapper.copy_view_u64(meta),
DType::Float32 => wrapper.copy_view_f32(meta),
DType::Float64 => wrapper.copy_view_f64(meta),
DType::Bool => wrapper.copy_view_bool(meta),
};

match result {
Ok(new_wrapper) => {
*out_handle = NdArrayHandle::from_wrapper(Box::new(new_wrapper));
SUCCESS
let new_wrapper = match wrapper.dtype {
DType::Int8 => {
let view = extract_view_i8(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Int8(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Int8,
}
}
DType::Int16 => {
let view = extract_view_i16(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Int16(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Int16,
}
}
DType::Int32 => {
let view = extract_view_i32(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Int32(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Int32,
}
}
DType::Int64 => {
let view = extract_view_i64(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Int64(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Int64,
}
}
DType::Uint8 => {
let view = extract_view_u8(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Uint8(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Uint8,
}
}
DType::Uint16 => {
let view = extract_view_u16(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Uint16(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Uint16,
}
}
Err(e) => {
error::set_last_error(e);
ERR_GENERIC
DType::Uint32 => {
let view = extract_view_u32(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Uint32(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Uint32,
}
}
}
DType::Uint64 => {
let view = extract_view_u64(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Uint64(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Uint64,
}
}
DType::Float32 => {
let view = extract_view_f32(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Float32(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Float32,
}
}
DType::Float64 => {
let view = extract_view_f64(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Float64(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Float64,
}
}
DType::Bool => {
let view = extract_view_bool(wrapper, meta).expect("Type mismatch");
NDArrayWrapper {
data: ArrayData::Bool(Arc::new(RwLock::new(view.to_owned()))),
dtype: DType::Bool,
}
}
};

*out_handle = NdArrayHandle::from_wrapper(Box::new(new_wrapper));
crate::helpers::error::SUCCESS
})
}
151 changes: 131 additions & 20 deletions rust/src/ffi/indexing/assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
//!
//! Provides assign operations between strided array views.

use crate::helpers::error::{self, ERR_GENERIC, SUCCESS};
use crate::helpers::error::{set_last_error, ERR_GENERIC, SUCCESS};
use crate::helpers::view::{
extract_view_bool, extract_view_f32, extract_view_f64, extract_view_i16, extract_view_i32,
extract_view_i64, extract_view_i8, extract_view_mut_bool, extract_view_mut_f32,
extract_view_mut_f64, extract_view_mut_i16, extract_view_mut_i32, extract_view_mut_i64,
extract_view_mut_i8, extract_view_mut_u16, extract_view_mut_u32, extract_view_mut_u64,
extract_view_mut_u8, extract_view_u16, extract_view_u32, extract_view_u64, extract_view_u8,
};
use crate::types::dtype::DType;
use crate::types::{ArrayMetadata, NdArrayHandle};

Expand Down Expand Up @@ -31,26 +38,130 @@ pub unsafe extern "C" fn ndarray_assign(
let dst_meta = &*dst_meta;
let src_meta = &*src_meta;

let result = match dst_wrapper.dtype {
DType::Int8 => dst_wrapper.assign_slice_i8(dst_meta, src_wrapper, src_meta),
DType::Int16 => dst_wrapper.assign_slice_i16(dst_meta, src_wrapper, src_meta),
DType::Int32 => dst_wrapper.assign_slice_i32(dst_meta, src_wrapper, src_meta),
DType::Int64 => dst_wrapper.assign_slice_i64(dst_meta, src_wrapper, src_meta),
DType::Uint8 => dst_wrapper.assign_slice_u8(dst_meta, src_wrapper, src_meta),
DType::Uint16 => dst_wrapper.assign_slice_u16(dst_meta, src_wrapper, src_meta),
DType::Uint32 => dst_wrapper.assign_slice_u32(dst_meta, src_wrapper, src_meta),
DType::Uint64 => dst_wrapper.assign_slice_u64(dst_meta, src_wrapper, src_meta),
DType::Float32 => dst_wrapper.assign_slice_f32(dst_meta, src_wrapper, src_meta),
DType::Float64 => dst_wrapper.assign_slice_f64(dst_meta, src_wrapper, src_meta),
DType::Bool => dst_wrapper.assign_slice_bool(dst_meta, src_wrapper, src_meta),
};

match result {
Ok(()) => SUCCESS,
Err(e) => {
error::set_last_error(e);
ERR_GENERIC
// Validate dtypes match
if dst_wrapper.dtype != src_wrapper.dtype {
set_last_error(format!(
"DType mismatch in assign: dst={:?}, src={:?}",
dst_wrapper.dtype, src_wrapper.dtype
));
return ERR_GENERIC;
}

let is_same = dst_wrapper.is_same_array(src_wrapper);

match dst_wrapper.dtype {
DType::Int8 => {
let src_view = extract_view_i8(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_i8(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Int16 => {
let src_view = extract_view_i16(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_i16(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Int32 => {
let src_view = extract_view_i32(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_i32(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Int64 => {
let src_view = extract_view_i64(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_i64(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Uint8 => {
let src_view = extract_view_u8(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_u8(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Uint16 => {
let src_view = extract_view_u16(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_u16(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Uint32 => {
let src_view = extract_view_u32(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_u32(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Uint64 => {
let src_view = extract_view_u64(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_u64(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Float32 => {
let src_view = extract_view_f32(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_f32(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Float64 => {
let src_view = extract_view_f64(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_f64(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
DType::Bool => {
let src_view = extract_view_bool(src_wrapper, src_meta).expect("Type mismatch");
let mut dst_view =
extract_view_mut_bool(dst_wrapper, dst_meta).expect("Type mismatch");
if is_same {
dst_view.assign(&src_view.to_owned());
} else {
dst_view.assign(&src_view);
}
}
}

SUCCESS
})
}
Loading
Loading