Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 65 additions & 12 deletions ndarray-rand/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
//! that the items are not compatible (e.g. that a type doesn't implement a
//! necessary trait).

#![warn(missing_docs)]

use crate::rand::distr::{Distribution, Uniform};
use crate::rand::rngs::SmallRng;
use crate::rand::seq::index;
use crate::rand::{rng, Rng, SeedableRng};

use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{Array, ArrayRef, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData};
#[cfg(feature = "quickcheck")]
use quickcheck::{Arbitrary, Gen};
Expand All @@ -51,18 +53,15 @@ pub mod rand_distr
pub use rand_distr::*;
}

/// Constructors for n-dimensional arrays with random elements.
///
/// This trait extends ndarray’s `ArrayBase` and can not be implemented
/// for other types.
/// Extension trait for constructing n-dimensional arrays with random elements.
///
/// The default RNG is a fast automatically seeded rng (currently
/// [`rand::rngs::SmallRng`], seeded from [`rand::thread_rng`]).
/// [`rand::rngs::SmallRng`], seeded from [`rand::rng`]).
///
/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
/// low-quality random numbers, and reproducibility is not guaranteed. See its
/// documentation for information. You can select a different RNG with
/// [`.random_using()`](Self::random_using).
/// [`.random_using()`](RandomExt::random_using).
pub trait RandomExt<S, A, D>
where
S: RawData<Elem = A>,
Expand Down Expand Up @@ -124,6 +123,40 @@ where
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>;

/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
///
/// See [`RandomRefExt::sample_axis`] for additional information.
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;

/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
///
/// See [`RandomRefExt::sample_axis_using`] for additional information.
fn sample_axis_using<R>(
&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R,
) -> Array<A, D>
where
R: Rng + ?Sized,
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;
}

/// Extension trait for sampling from [`ArrayRef`] with random elements.
///
/// The default RNG is a fast, automatically seeded rng (currently
/// [`rand::rngs::SmallRng`], seeded from [`rand::rng`]).
///
/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
/// low-quality random numbers, and reproducibility is not guaranteed. See its
/// documentation for information. You can select a different RNG with
/// [`.sample_axis_using()`](RandomRefExt::sample_axis_using).
pub trait RandomRefExt<A, D>
where D: Dimension
{
/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
///
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
Expand Down Expand Up @@ -168,7 +201,6 @@ where
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;

/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
Expand Down Expand Up @@ -225,7 +257,6 @@ where
where
R: Rng + ?Sized,
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;
}

Expand Down Expand Up @@ -259,7 +290,7 @@ where
S: Data<Elem = A>,
D: RemoveAxis,
{
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
(**self).sample_axis(axis, n_samples, strategy)
}

fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
Expand All @@ -268,6 +299,27 @@ where
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis,
{
(**self).sample_axis_using(axis, n_samples, strategy, rng)
}
}

impl<A, D> RandomRefExt<A, D> for ArrayRef<A, D>
where D: Dimension
{
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
D: RemoveAxis,
{
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
}

fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
where
R: Rng + ?Sized,
A: Copy,
D: RemoveAxis,
{
let indices: Vec<_> = match strategy {
SamplingStrategy::WithReplacement => {
Expand All @@ -284,9 +336,10 @@ where
/// if lanes from the original array should only be sampled once (*without replacement*) or
/// multiple times (*with replacement*).
///
/// [`sample_axis`]: RandomExt::sample_axis
/// [`sample_axis_using`]: RandomExt::sample_axis_using
/// [`sample_axis`]: RandomRefExt::sample_axis
/// [`sample_axis_using`]: RandomRefExt::sample_axis_using
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum SamplingStrategy
{
WithReplacement,
Expand Down