diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index 8ee2cda75..d155695aa 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -29,12 +29,14 @@ //! that the items are not compatible (e.g. that a type doesn't implement a //! necessary trait). +#![warn(missing_docs)] + use crate::rand::distr::{Distribution, Uniform}; use crate::rand::rngs::SmallRng; use crate::rand::seq::index; use crate::rand::{rng, Rng, SeedableRng}; -use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder}; +use ndarray::{Array, ArrayRef, Axis, RemoveAxis, ShapeBuilder}; use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData}; #[cfg(feature = "quickcheck")] use quickcheck::{Arbitrary, Gen}; @@ -51,18 +53,15 @@ pub mod rand_distr pub use rand_distr::*; } -/// Constructors for n-dimensional arrays with random elements. -/// -/// This trait extends ndarray’s `ArrayBase` and can not be implemented -/// for other types. +/// Extension trait for constructing n-dimensional arrays with random elements. /// /// The default RNG is a fast automatically seeded rng (currently -/// [`rand::rngs::SmallRng`], seeded from [`rand::thread_rng`]). +/// [`rand::rngs::SmallRng`], seeded from [`rand::rng`]). /// /// Note that `SmallRng` is cheap to initialize and fast, but it may generate /// low-quality random numbers, and reproducibility is not guaranteed. See its /// documentation for information. You can select a different RNG with -/// [`.random_using()`](Self::random_using). +/// [`.random_using()`](RandomExt::random_using). pub trait RandomExt where S: RawData, @@ -124,6 +123,40 @@ where S: DataOwned, Sh: ShapeBuilder; + /// Sample `n_samples` lanes slicing along `axis` using the default RNG. + /// + /// See [`RandomRefExt::sample_axis`] for additional information. + fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array + where + A: Copy, + S: Data, + D: RemoveAxis; + + /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`. + /// + /// See [`RandomRefExt::sample_axis_using`] for additional information. + fn sample_axis_using( + &self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R, + ) -> Array + where + R: Rng + ?Sized, + A: Copy, + S: Data, + D: RemoveAxis; +} + +/// Extension trait for sampling from [`ArrayRef`] with random elements. +/// +/// The default RNG is a fast, automatically seeded rng (currently +/// [`rand::rngs::SmallRng`], seeded from [`rand::rng`]). +/// +/// Note that `SmallRng` is cheap to initialize and fast, but it may generate +/// low-quality random numbers, and reproducibility is not guaranteed. See its +/// documentation for information. You can select a different RNG with +/// [`.sample_axis_using()`](RandomRefExt::sample_axis_using). +pub trait RandomRefExt +where D: Dimension +{ /// Sample `n_samples` lanes slicing along `axis` using the default RNG. /// /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once. @@ -168,7 +201,6 @@ where fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array where A: Copy, - S: Data, D: RemoveAxis; /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`. @@ -225,7 +257,6 @@ where where R: Rng + ?Sized, A: Copy, - S: Data, D: RemoveAxis; } @@ -259,7 +290,7 @@ where S: Data, D: RemoveAxis, { - self.sample_axis_using(axis, n_samples, strategy, &mut get_rng()) + (**self).sample_axis(axis, n_samples, strategy) } fn sample_axis_using(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array @@ -268,6 +299,27 @@ where A: Copy, S: Data, D: RemoveAxis, + { + (**self).sample_axis_using(axis, n_samples, strategy, rng) + } +} + +impl RandomRefExt for ArrayRef +where D: Dimension +{ + fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array + where + A: Copy, + D: RemoveAxis, + { + self.sample_axis_using(axis, n_samples, strategy, &mut get_rng()) + } + + fn sample_axis_using(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array + where + R: Rng + ?Sized, + A: Copy, + D: RemoveAxis, { let indices: Vec<_> = match strategy { SamplingStrategy::WithReplacement => { @@ -284,9 +336,10 @@ where /// if lanes from the original array should only be sampled once (*without replacement*) or /// multiple times (*with replacement*). /// -/// [`sample_axis`]: RandomExt::sample_axis -/// [`sample_axis_using`]: RandomExt::sample_axis_using +/// [`sample_axis`]: RandomRefExt::sample_axis +/// [`sample_axis_using`]: RandomRefExt::sample_axis_using #[derive(Debug, Clone)] +#[allow(missing_docs)] pub enum SamplingStrategy { WithReplacement,