Skip to content

Commit 1996c09

Browse files
committed
feat: add simd uniform
1 parent 39bd711 commit 1996c09

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ ndarray-stats = "0.6.0"
4848
ndrustfft = "0.5.0"
4949
num-complex = { version = "0.4.6", features = ["rand"] }
5050
ordered-float = "5.0.0"
51+
# orx-parallel = "2.3.0"
5152
plotly = { version = "0.10.0", features = ["plotly_ndarray"] }
5253
polars = { version = "0.43.1", features = ["lazy"], optional = true }
5354
prettytable-rs = "0.10.0"

src/stats/distr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub mod normal_inverse_gauss;
1515
pub mod pareto;
1616
pub mod poisson;
1717
pub mod studentt;
18+
pub mod uniform;
1819
pub mod weibull;
1920

2021
/// Fills a slice with random floating-point values in the range [0, 1).

src/stats/distr/uniform.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use std::cell::UnsafeCell;
2+
3+
use rand::Rng;
4+
use rand_distr::Distribution;
5+
use wide::f32x8;
6+
7+
use super::fill_f32_zero_one;
8+
9+
pub struct SimdUniform {
10+
low: f32,
11+
scale: f32, // = high - low
12+
buffer: UnsafeCell<[f32; 8]>,
13+
index: UnsafeCell<usize>,
14+
}
15+
16+
impl SimdUniform {
17+
pub fn new(low: f32, high: f32) -> Self {
18+
assert!(high > low, "SimdUniform: high must be greater than low");
19+
assert!(low.is_finite() && high.is_finite(), "bounds must be finite");
20+
Self {
21+
low,
22+
scale: high - low,
23+
buffer: UnsafeCell::new([0.0; 8]),
24+
index: UnsafeCell::new(8), // kényszerít első refill-t
25+
}
26+
}
27+
28+
pub fn unit() -> Self {
29+
Self::new(0.0, 1.0)
30+
}
31+
32+
#[inline]
33+
fn refill<R: Rng + ?Sized>(&self, rng: &mut R) {
34+
let mut u = [0.0f32; 8];
35+
fill_f32_zero_one(rng, &mut u);
36+
let u = f32x8::from(u);
37+
38+
let vals = f32x8::splat(self.low) + u * f32x8::splat(self.scale);
39+
40+
unsafe {
41+
*self.buffer.get() = vals.to_array();
42+
*self.index.get() = 0;
43+
}
44+
}
45+
}
46+
47+
impl Distribution<f32> for SimdUniform {
48+
#[inline]
49+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
50+
let index = unsafe { &mut *self.index.get() };
51+
if *index >= 8 {
52+
self.refill(rng);
53+
}
54+
let val = unsafe { (*self.buffer.get())[*index] };
55+
*index += 1;
56+
val
57+
}
58+
}

0 commit comments

Comments
 (0)