Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 28 additions & 34 deletions rust/src/ffi/generators/geomspace.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Create numbers spaced geometrically from start to stop.

use ndarray::{ArrayD, IxDyn};
use ndarray::{Array, ArrayD};
use parking_lot::RwLock;
use std::sync::Arc;

use crate::helpers::error::{set_last_error, ERR_DTYPE, ERR_GENERIC, SUCCESS};
use crate::types::dtype::DType;
use crate::types::{ArrayData, NDArrayWrapper, NdArrayHandle};

Expand All @@ -20,56 +21,49 @@ pub unsafe extern "C" fn ndarray_geomspace(
out_handle: *mut *mut NdArrayHandle,
) -> i32 {
if out_handle.is_null() || num == 0 {
return crate::helpers::error::ERR_GENERIC;
return ERR_GENERIC;
}

crate::ffi_guard!({
let dtype_enum = match DType::from_u8(dtype) {
Some(d) => d,
None => return crate::helpers::error::ERR_DTYPE,
None => return ERR_DTYPE,
};

// Validate: start and stop must have same sign and neither can be zero
if start == 0.0 || stop == 0.0 || (start > 0.0) != (stop > 0.0) {
return crate::helpers::error::ERR_GENERIC;
}

match dtype_enum {
let result_wrapper = match dtype_enum {
DType::Float32 => {
let data: Vec<f32> = if num == 1 {
vec![start as f32]
} else {
let ratio = (stop / start).powf(1.0 / ((num - 1) as f64));
(0..num)
.map(|i| (start * ratio.powi(i as i32)) as f32)
.collect()
let arr: ArrayD<f32> = match Array::geomspace(start as f32, stop as f32, num) {
Some(arr1) => arr1.into_dyn(),
None => {
set_last_error("geomspace requires start and stop to have the same sign and be non-zero".to_string());
return ERR_GENERIC;
}
};
let arr = ArrayD::<f32>::from_shape_vec(IxDyn(&[num]), data)
.expect("Shape mismatch should not happen");
let wrapper = NDArrayWrapper {
NDArrayWrapper {
data: ArrayData::Float32(Arc::new(RwLock::new(arr))),
dtype: DType::Float32,
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper));
}
}
DType::Float64 => {
let data: Vec<f64> = if num == 1 {
vec![start]
} else {
let ratio = (stop / start).powf(1.0 / ((num - 1) as f64));
(0..num).map(|i| start * ratio.powi(i as i32)).collect()
let arr: ArrayD<f64> = match Array::geomspace(start, stop, num) {
Some(arr1) => arr1.into_dyn(),
None => {
set_last_error("geomspace requires start and stop to have the same sign and be non-zero".to_string());
return ERR_GENERIC;
}
};
let arr = ArrayD::<f64>::from_shape_vec(IxDyn(&[num]), data)
.expect("Shape mismatch should not happen");
let wrapper = NDArrayWrapper {
NDArrayWrapper {
data: ArrayData::Float64(Arc::new(RwLock::new(arr))),
dtype: DType::Float64,
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper));
}
}
_ => return crate::helpers::error::ERR_DTYPE,
}
_ => {
set_last_error("geomspace() requires float type (Float64 or Float32)".to_string());
return ERR_DTYPE;
}
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(result_wrapper));

crate::helpers::error::SUCCESS
SUCCESS
})
}
65 changes: 30 additions & 35 deletions rust/src/ffi/generators/linspace.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Create evenly spaced numbers over a specified interval.

use ndarray::{ArrayD, IxDyn};
use ndarray::{Array, ArrayD};
use parking_lot::RwLock;
use std::sync::Arc;

use crate::helpers::error::{set_last_error, ERR_DTYPE, ERR_GENERIC, SUCCESS};
use crate::types::dtype::DType;
use crate::types::{ArrayData, NDArrayWrapper, NdArrayHandle};

Expand All @@ -21,59 +22,53 @@ pub unsafe extern "C" fn ndarray_linspace(
out_handle: *mut *mut NdArrayHandle,
) -> i32 {
if out_handle.is_null() || num == 0 {
return crate::helpers::error::ERR_GENERIC;
return ERR_GENERIC;
}

crate::ffi_guard!({
let dtype_enum = match DType::from_u8(dtype) {
Some(d) => d,
None => return crate::helpers::error::ERR_DTYPE,
None => return ERR_DTYPE,
};

match dtype_enum {
let result_wrapper = match dtype_enum {
DType::Float32 => {
let data: Vec<f32> = if num == 1 {
vec![start as f32]
let adjusted_stop = if endpoint {
stop as f32
} else {
let step = (stop - start)
/ (if endpoint {
(num - 1) as f64
} else {
num as f64
});
(0..num).map(|i| (start + step * i as f64) as f32).collect()
let step = ((stop - start) / (num as f64)) as f32;
(start as f32) + step * ((num - 1) as f32)
};
let arr = ArrayD::<f32>::from_shape_vec(IxDyn(&[num]), data)
.expect("Shape mismatch should not happen");
let wrapper = NDArrayWrapper {

let arr: ArrayD<f32> = Array::linspace(start as f32, adjusted_stop, num).into_dyn();

NDArrayWrapper {
data: ArrayData::Float32(Arc::new(RwLock::new(arr))),
dtype: DType::Float32,
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper));
}
}
DType::Float64 => {
let data: Vec<f64> = if num == 1 {
vec![start]
let adjusted_stop = if endpoint {
stop
} else {
let step = (stop - start)
/ (if endpoint {
(num - 1) as f64
} else {
num as f64
});
(0..num).map(|i| start + step * i as f64).collect()
let step = (stop - start) / (num as f64);
start + step * ((num - 1) as f64)
};
let arr = ArrayD::<f64>::from_shape_vec(IxDyn(&[num]), data)
.expect("Shape mismatch should not happen");
let wrapper = NDArrayWrapper {

let arr: ArrayD<f64> = Array::linspace(start, adjusted_stop, num).into_dyn();

NDArrayWrapper {
data: ArrayData::Float64(Arc::new(RwLock::new(arr))),
dtype: DType::Float64,
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper));
}
}
_ => return crate::helpers::error::ERR_DTYPE,
}
_ => {
set_last_error("linspace() requires float type (Float64 or Float32)".to_string());
return ERR_DTYPE;
}
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(result_wrapper));

crate::helpers::error::SUCCESS
SUCCESS
})
}
52 changes: 19 additions & 33 deletions rust/src/ffi/generators/logspace.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Create numbers spaced evenly on a log scale.

use ndarray::{ArrayD, IxDyn};
use ndarray::{Array, ArrayD};
use parking_lot::RwLock;
use std::sync::Arc;

use crate::helpers::error::{set_last_error, ERR_DTYPE, ERR_GENERIC, SUCCESS};
use crate::types::dtype::DType;
use crate::types::{ArrayData, NDArrayWrapper, NdArrayHandle};

Expand All @@ -21,53 +22,38 @@ pub unsafe extern "C" fn ndarray_logspace(
out_handle: *mut *mut NdArrayHandle,
) -> i32 {
if out_handle.is_null() || num == 0 {
return crate::helpers::error::ERR_GENERIC;
return ERR_GENERIC;
}

crate::ffi_guard!({
let dtype_enum = match DType::from_u8(dtype) {
Some(d) => d,
None => return crate::helpers::error::ERR_DTYPE,
None => return ERR_DTYPE,
};

match dtype_enum {
let result_wrapper = match dtype_enum {
DType::Float32 => {
let data: Vec<f32> = if num == 1 {
vec![base.powf(start) as f32]
} else {
let step = (stop - start) / ((num - 1) as f64);
(0..num)
.map(|i| base.powf(start + step * i as f64) as f32)
.collect()
};
let arr = ArrayD::<f32>::from_shape_vec(IxDyn(&[num]), data)
.expect("Shape mismatch should not happen");
let wrapper = NDArrayWrapper {
let arr: ArrayD<f32> =
Array::logspace(base as f32, start as f32, stop as f32, num).into_dyn();
NDArrayWrapper {
data: ArrayData::Float32(Arc::new(RwLock::new(arr))),
dtype: DType::Float32,
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper));
}
}
DType::Float64 => {
let data: Vec<f64> = if num == 1 {
vec![base.powf(start)]
} else {
let step = (stop - start) / ((num - 1) as f64);
(0..num)
.map(|i| base.powf(start + step * i as f64))
.collect()
};
let arr = ArrayD::<f64>::from_shape_vec(IxDyn(&[num]), data)
.expect("Shape mismatch should not happen");
let wrapper = NDArrayWrapper {
let arr: ArrayD<f64> = Array::logspace(base, start, stop, num).into_dyn();
NDArrayWrapper {
data: ArrayData::Float64(Arc::new(RwLock::new(arr))),
dtype: DType::Float64,
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(wrapper));
}
}
_ => return crate::helpers::error::ERR_DTYPE,
}
_ => {
set_last_error("logspace() requires float type (Float64 or Float32)".to_string());
return ERR_DTYPE;
}
};
*out_handle = NdArrayHandle::from_wrapper(Box::new(result_wrapper));

crate::helpers::error::SUCCESS
SUCCESS
})
}
Loading