|
| 1 | +use std::cell::UnsafeCell; |
| 2 | + |
| 3 | +use rand::Rng; |
| 4 | +use rand_distr::Distribution; |
| 5 | +use wide::f32x8; |
| 6 | + |
| 7 | +use super::fill_f32_zero_one; |
| 8 | + |
/// Uniform `f32` sampler that produces its values in batches of eight.
///
/// Each refill draws eight unit-interval values, maps them to
/// `low + u * scale` with a single `f32x8` operation, and stores them in
/// `buffer`; `sample` then hands them out one at a time.
///
/// The cached batch and cursor live in `UnsafeCell`s because
/// `Distribution::sample` only receives `&self`. As a consequence the type
/// is `!Sync` and must not be shared between threads.
#[derive(Debug)]
pub struct SimdUniform {
    /// Lower bound of the output range.
    low: f32,
    /// Width of the output range, i.e. `high - low`.
    scale: f32,
    /// Most recently generated batch of eight samples.
    buffer: UnsafeCell<[f32; 8]>,
    /// Position of the next unread sample in `buffer`; `8` means "empty".
    index: UnsafeCell<usize>,
}
| 15 | + |
| 16 | +impl SimdUniform { |
| 17 | + pub fn new(low: f32, high: f32) -> Self { |
| 18 | + assert!(high > low, "SimdUniform: high must be greater than low"); |
| 19 | + assert!(low.is_finite() && high.is_finite(), "bounds must be finite"); |
| 20 | + Self { |
| 21 | + low, |
| 22 | + scale: high - low, |
| 23 | + buffer: UnsafeCell::new([0.0; 8]), |
| 24 | + index: UnsafeCell::new(8), // kényszerít első refill-t |
| 25 | + } |
| 26 | + } |
| 27 | + |
| 28 | + pub fn unit() -> Self { |
| 29 | + Self::new(0.0, 1.0) |
| 30 | + } |
| 31 | + |
| 32 | + #[inline] |
| 33 | + fn refill<R: Rng + ?Sized>(&self, rng: &mut R) { |
| 34 | + let mut u = [0.0f32; 8]; |
| 35 | + fill_f32_zero_one(rng, &mut u); |
| 36 | + let u = f32x8::from(u); |
| 37 | + |
| 38 | + let vals = f32x8::splat(self.low) + u * f32x8::splat(self.scale); |
| 39 | + |
| 40 | + unsafe { |
| 41 | + *self.buffer.get() = vals.to_array(); |
| 42 | + *self.index.get() = 0; |
| 43 | + } |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +impl Distribution<f32> for SimdUniform { |
| 48 | + #[inline] |
| 49 | + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 { |
| 50 | + let index = unsafe { &mut *self.index.get() }; |
| 51 | + if *index >= 8 { |
| 52 | + self.refill(rng); |
| 53 | + } |
| 54 | + let val = unsafe { (*self.buffer.get())[*index] }; |
| 55 | + *index += 1; |
| 56 | + val |
| 57 | + } |
| 58 | +} |
0 commit comments