diff --git a/stan/math/prim/prob/dirichlet_multinomial_rng.hpp b/stan/math/prim/prob/dirichlet_multinomial_rng.hpp index 7934001e2d7..2dd9051a1ad 100644 --- a/stan/math/prim/prob/dirichlet_multinomial_rng.hpp +++ b/stan/math/prim/prob/dirichlet_multinomial_rng.hpp @@ -41,7 +41,7 @@ inline std::vector dirichlet_multinomial_rng( check_positive_finite(function, "prior size variable", alpha_ref); check_nonnegative(function, "number of trials variables", N); - // special case N = 0 would lead to an exception thrown by multinomial_rng + // special case N = 0 if (N == 0) { return std::vector(alpha.size(), 0); } diff --git a/stan/math/prim/prob/multinomial_logit_rng.hpp b/stan/math/prim/prob/multinomial_logit_rng.hpp index 969c007bd1f..70b85b6f44c 100644 --- a/stan/math/prim/prob/multinomial_logit_rng.hpp +++ b/stan/math/prim/prob/multinomial_logit_rng.hpp @@ -13,14 +13,16 @@ namespace math { /** \ingroup multivar_dists * Return a draw from a Multinomial distribution given a - * a vector of unnormalized log probabilities and a pseudo-random - * number generator. + * vector of unnormalized log probabilities, a total count, + * and a pseudo-random number generator. * * @tparam RNG Type of pseudo-random number generator. * @param beta Vector of unnormalized log probabilities. - * @param N Total count + * @param N Total count. * @param rng Pseudo-random number generator. - * @return Multinomial random variate + * @return Multinomial random variate. + * @throw std::domain_error if any element of beta is not finite. + * @throw std::domain_error is N is less than 0. */ template * = nullptr> @@ -29,7 +31,7 @@ inline std::vector multinomial_logit_rng(const T_beta& beta, int N, static constexpr const char* function = "multinomial_logit_rng"; const auto& beta_ref = to_ref(beta); check_finite(function, "Log-probabilities parameter", beta_ref); - check_positive(function, "number of trials variables", N); + check_nonnegative(function, "number of trials variables", N); plain_type_t theta = softmax(beta_ref); std::vector result(theta.size(), 0); diff --git a/stan/math/prim/prob/multinomial_rng.hpp b/stan/math/prim/prob/multinomial_rng.hpp index 86edf66beb6..8e1610c60b4 100644 --- a/stan/math/prim/prob/multinomial_rng.hpp +++ b/stan/math/prim/prob/multinomial_rng.hpp @@ -10,13 +10,26 @@ namespace stan { namespace math { +/** \ingroup multivar_dists + * Return a draw from a Multinomial distribution given a + * probability simplex, a total count, and a pseudo-random + * number generator. + * + * @tparam RNG Type of pseudo-random number generator. + * @param theta Vector of normalized probabilities. + * @param N Total count. + * @param rng Pseudo-random number generator. + * @return Multinomial random variate. + * @throw std::domain_error if theta is not a simplex. + * @throw std::domain_error is N is less than 0. + */ template * = nullptr> inline std::vector multinomial_rng(const T_theta& theta, int N, RNG& rng) { static constexpr const char* function = "multinomial_rng"; const auto& theta_ref = to_ref(theta); check_simplex(function, "Probabilities parameter", theta_ref); - check_positive(function, "number of trials variables", N); + check_nonnegative(function, "number of trials variables", N); std::vector result(theta.size(), 0); double mass_left = 1.0; diff --git a/test/unit/math/prim/prob/multinomial_logit_test.cpp b/test/unit/math/prim/prob/multinomial_logit_test.cpp index 19988208587..ea587254e9a 100644 --- a/test/unit/math/prim/prob/multinomial_logit_test.cpp +++ b/test/unit/math/prim/prob/multinomial_logit_test.cpp @@ -8,6 +8,19 @@ using Eigen::Dynamic; using Eigen::Matrix; +TEST(ProbDistributionsMultinomialLogit, RNGZero) { + boost::random::mt19937 rng; + Matrix beta(3); + beta << 1.3, 0.1, -2.6; + // bug in 4.8.1: RNG does not allow a zero total count + EXPECT_NO_THROW(stan::math::multinomial_logit_rng(beta, 0, rng)); + // when the total count is zero, the sample should be a zero array + std::vector sample = stan::math::multinomial_logit_rng(beta, 0, rng); + for (int k : sample) { + EXPECT_EQ(0, k); + } +} + TEST(ProbDistributionsMultinomialLogit, RNGSize) { boost::random::mt19937 rng; Matrix beta(5); diff --git a/test/unit/math/prim/prob/multinomial_test.cpp b/test/unit/math/prim/prob/multinomial_test.cpp index 7821264f05e..62e1dbd5a58 100644 --- a/test/unit/math/prim/prob/multinomial_test.cpp +++ b/test/unit/math/prim/prob/multinomial_test.cpp @@ -5,6 +5,21 @@ #include #include +TEST(ProbDistributionsMultinomial, RNGZero) { + using Eigen::Dynamic; + using Eigen::Matrix; + boost::random::mt19937 rng; + Matrix theta(3); + theta << 0.3, 0.1, 0.6; + // bug in 4.8.1: RNG does not allow a zero total count + EXPECT_NO_THROW(stan::math::multinomial_rng(theta, 0, rng)); + // when the total count is zero, the sample should be a zero array + std::vector sample = stan::math::multinomial_rng(theta, 0, rng); + for (int k : sample) { + EXPECT_EQ(0, k); + } +} + TEST(ProbDistributionsMultinomial, RNGSize) { using Eigen::Dynamic; using Eigen::Matrix;