From c21e0b49f8b5e2ad96a630e4074f3e30757d18f4 Mon Sep 17 00:00:00 2001 From: Kyrian Obikwelu Date: Sun, 22 Mar 2026 20:47:19 +0100 Subject: [PATCH] refactor(rust): use ndarray's built-in generator functions Replace manual implementations of linspace, logspace, and geomspace with ndarray's Array::linspace, Array::logspace, and Array::geomspace methods.This leverages ndarray's optimized and well-tested implementations while simplifying our FFI code. --- rust/src/ffi/generators/geomspace.rs | 62 ++++++++++++-------------- rust/src/ffi/generators/linspace.rs | 65 +++++++++++++--------------- rust/src/ffi/generators/logspace.rs | 52 ++++++++-------------- 3 files changed, 77 insertions(+), 102 deletions(-) diff --git a/rust/src/ffi/generators/geomspace.rs b/rust/src/ffi/generators/geomspace.rs index 6211ce0..8b7514f 100644 --- a/rust/src/ffi/generators/geomspace.rs +++ b/rust/src/ffi/generators/geomspace.rs @@ -1,9 +1,10 @@ //! Create numbers spaced geometrically from start to stop. -use ndarray::{ArrayD, IxDyn}; +use ndarray::{Array, ArrayD}; use parking_lot::RwLock; use std::sync::Arc; +use crate::helpers::error::{set_last_error, ERR_DTYPE, ERR_GENERIC, SUCCESS}; use crate::types::dtype::DType; use crate::types::{ArrayData, NDArrayWrapper, NdArrayHandle}; @@ -20,56 +21,49 @@ pub unsafe extern "C" fn ndarray_geomspace( out_handle: *mut *mut NdArrayHandle, ) -> i32 { if out_handle.is_null() || num == 0 { - return crate::helpers::error::ERR_GENERIC; + return ERR_GENERIC; } crate::ffi_guard!({ let dtype_enum = match DType::from_u8(dtype) { Some(d) => d, - None => return crate::helpers::error::ERR_DTYPE, + None => return ERR_DTYPE, }; - // Validate: start and stop must have same sign and neither can be zero - if start == 0.0 || stop == 0.0 || (start > 0.0) != (stop > 0.0) { - return crate::helpers::error::ERR_GENERIC; - } - - match dtype_enum { + let result_wrapper = match dtype_enum { DType::Float32 => { - let data: Vec = if num == 1 { - vec![start as f32] - } else { - let ratio = (stop / start).powf(1.0 / ((num - 1) as f64)); - (0..num) - .map(|i| (start * ratio.powi(i as i32)) as f32) - .collect() + let arr: ArrayD = match Array::geomspace(start as f32, stop as f32, num) { + Some(arr1) => arr1.into_dyn(), + None => { + set_last_error("geomspace requires start and stop to have the same sign and be non-zero".to_string()); + return ERR_GENERIC; + } }; - let arr = ArrayD::::from_shape_vec(IxDyn(&[num]), data) - .expect("Shape mismatch should not happen"); - let wrapper = NDArrayWrapper { + NDArrayWrapper { data: ArrayData::Float32(Arc::new(RwLock::new(arr))), dtype: DType::Float32, - }; - *out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper)); + } } DType::Float64 => { - let data: Vec = if num == 1 { - vec![start] - } else { - let ratio = (stop / start).powf(1.0 / ((num - 1) as f64)); - (0..num).map(|i| start * ratio.powi(i as i32)).collect() + let arr: ArrayD = match Array::geomspace(start, stop, num) { + Some(arr1) => arr1.into_dyn(), + None => { + set_last_error("geomspace requires start and stop to have the same sign and be non-zero".to_string()); + return ERR_GENERIC; + } }; - let arr = ArrayD::::from_shape_vec(IxDyn(&[num]), data) - .expect("Shape mismatch should not happen"); - let wrapper = NDArrayWrapper { + NDArrayWrapper { data: ArrayData::Float64(Arc::new(RwLock::new(arr))), dtype: DType::Float64, - }; - *out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper)); + } } - _ => return crate::helpers::error::ERR_DTYPE, - } + _ => { + set_last_error("geomspace() requires float type (Float64 or Float32)".to_string()); + return ERR_DTYPE; + } + }; + *out_handle = NdArrayHandle::from_wrapper(Box::new(result_wrapper)); - crate::helpers::error::SUCCESS + SUCCESS }) } diff --git a/rust/src/ffi/generators/linspace.rs b/rust/src/ffi/generators/linspace.rs index ab5ba37..6c93148 100644 --- a/rust/src/ffi/generators/linspace.rs +++ b/rust/src/ffi/generators/linspace.rs @@ -1,9 +1,10 @@ //! Create evenly spaced numbers over a specified interval. -use ndarray::{ArrayD, IxDyn}; +use ndarray::{Array, ArrayD}; use parking_lot::RwLock; use std::sync::Arc; +use crate::helpers::error::{set_last_error, ERR_DTYPE, ERR_GENERIC, SUCCESS}; use crate::types::dtype::DType; use crate::types::{ArrayData, NDArrayWrapper, NdArrayHandle}; @@ -21,59 +22,53 @@ pub unsafe extern "C" fn ndarray_linspace( out_handle: *mut *mut NdArrayHandle, ) -> i32 { if out_handle.is_null() || num == 0 { - return crate::helpers::error::ERR_GENERIC; + return ERR_GENERIC; } crate::ffi_guard!({ let dtype_enum = match DType::from_u8(dtype) { Some(d) => d, - None => return crate::helpers::error::ERR_DTYPE, + None => return ERR_DTYPE, }; - match dtype_enum { + let result_wrapper = match dtype_enum { DType::Float32 => { - let data: Vec = if num == 1 { - vec![start as f32] + let adjusted_stop = if endpoint { + stop as f32 } else { - let step = (stop - start) - / (if endpoint { - (num - 1) as f64 - } else { - num as f64 - }); - (0..num).map(|i| (start + step * i as f64) as f32).collect() + let step = ((stop - start) / (num as f64)) as f32; + (start as f32) + step * ((num - 1) as f32) }; - let arr = ArrayD::::from_shape_vec(IxDyn(&[num]), data) - .expect("Shape mismatch should not happen"); - let wrapper = NDArrayWrapper { + + let arr: ArrayD = Array::linspace(start as f32, adjusted_stop, num).into_dyn(); + + NDArrayWrapper { data: ArrayData::Float32(Arc::new(RwLock::new(arr))), dtype: DType::Float32, - }; - *out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper)); + } } DType::Float64 => { - let data: Vec = if num == 1 { - vec![start] + let adjusted_stop = if endpoint { + stop } else { - let step = (stop - start) - / (if endpoint { - (num - 1) as f64 - } else { - num as f64 - }); - (0..num).map(|i| start + step * i as f64).collect() + let step = (stop - start) / (num as f64); + start + step * ((num - 1) as f64) }; - let arr = ArrayD::::from_shape_vec(IxDyn(&[num]), data) - .expect("Shape mismatch should not happen"); - let wrapper = NDArrayWrapper { + + let arr: ArrayD = Array::linspace(start, adjusted_stop, num).into_dyn(); + + NDArrayWrapper { data: ArrayData::Float64(Arc::new(RwLock::new(arr))), dtype: DType::Float64, - }; - *out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper)); + } } - _ => return crate::helpers::error::ERR_DTYPE, - } + _ => { + set_last_error("linspace() requires float type (Float64 or Float32)".to_string()); + return ERR_DTYPE; + } + }; + *out_handle = NdArrayHandle::from_wrapper(Box::new(result_wrapper)); - crate::helpers::error::SUCCESS + SUCCESS }) } diff --git a/rust/src/ffi/generators/logspace.rs b/rust/src/ffi/generators/logspace.rs index 53d2ee9..846d292 100644 --- a/rust/src/ffi/generators/logspace.rs +++ b/rust/src/ffi/generators/logspace.rs @@ -1,9 +1,10 @@ //! Create numbers spaced evenly on a log scale. -use ndarray::{ArrayD, IxDyn}; +use ndarray::{Array, ArrayD}; use parking_lot::RwLock; use std::sync::Arc; +use crate::helpers::error::{set_last_error, ERR_DTYPE, ERR_GENERIC, SUCCESS}; use crate::types::dtype::DType; use crate::types::{ArrayData, NDArrayWrapper, NdArrayHandle}; @@ -21,53 +22,38 @@ pub unsafe extern "C" fn ndarray_logspace( out_handle: *mut *mut NdArrayHandle, ) -> i32 { if out_handle.is_null() || num == 0 { - return crate::helpers::error::ERR_GENERIC; + return ERR_GENERIC; } crate::ffi_guard!({ let dtype_enum = match DType::from_u8(dtype) { Some(d) => d, - None => return crate::helpers::error::ERR_DTYPE, + None => return ERR_DTYPE, }; - match dtype_enum { + let result_wrapper = match dtype_enum { DType::Float32 => { - let data: Vec = if num == 1 { - vec![base.powf(start) as f32] - } else { - let step = (stop - start) / ((num - 1) as f64); - (0..num) - .map(|i| base.powf(start + step * i as f64) as f32) - .collect() - }; - let arr = ArrayD::::from_shape_vec(IxDyn(&[num]), data) - .expect("Shape mismatch should not happen"); - let wrapper = NDArrayWrapper { + let arr: ArrayD = + Array::logspace(base as f32, start as f32, stop as f32, num).into_dyn(); + NDArrayWrapper { data: ArrayData::Float32(Arc::new(RwLock::new(arr))), dtype: DType::Float32, - }; - *out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper)); + } } DType::Float64 => { - let data: Vec = if num == 1 { - vec![base.powf(start)] - } else { - let step = (stop - start) / ((num - 1) as f64); - (0..num) - .map(|i| base.powf(start + step * i as f64)) - .collect() - }; - let arr = ArrayD::::from_shape_vec(IxDyn(&[num]), data) - .expect("Shape mismatch should not happen"); - let wrapper = NDArrayWrapper { + let arr: ArrayD = Array::logspace(base, start, stop, num).into_dyn(); + NDArrayWrapper { data: ArrayData::Float64(Arc::new(RwLock::new(arr))), dtype: DType::Float64, - }; - *out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper)); + } } - _ => return crate::helpers::error::ERR_DTYPE, - } + _ => { + set_last_error("logspace() requires float type (Float64 or Float32)".to_string()); + return ERR_DTYPE; + } + }; + *out_handle = NdArrayHandle::from_wrapper(Box::new(result_wrapper)); - crate::helpers::error::SUCCESS + SUCCESS }) }