Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix of issue #3057 Allow zero total count for multinomial RNG #3061

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stan/math/prim/prob/dirichlet_multinomial_rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ inline std::vector<int> dirichlet_multinomial_rng(
check_positive_finite(function, "prior size variable", alpha_ref);
check_nonnegative(function, "number of trials variables", N);

// special case N = 0 would lead to an exception thrown by multinomial_rng
// special case N = 0
if (N == 0) {
return std::vector<int>(alpha.size(), 0);
}
Expand Down
12 changes: 7 additions & 5 deletions stan/math/prim/prob/multinomial_logit_rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ namespace math {

/** \ingroup multivar_dists
* Return a draw from a Multinomial distribution given a
* a vector of unnormalized log probabilities and a pseudo-random
* number generator.
* vector of unnormalized log probabilities, a total count,
* and a pseudo-random number generator.
*
* @tparam RNG Type of pseudo-random number generator.
* @param beta Vector of unnormalized log probabilities.
* @param N Total count
* @param N Total count.
* @param rng Pseudo-random number generator.
* @return Multinomial random variate
* @return Multinomial random variate.
* @throw std::domain_error if any element of beta is not finite.
* @throw std::domain_error is N is less than 0.
*/
template <class RNG, typename T_beta,
require_eigen_col_vector_t<T_beta>* = nullptr>
Expand All @@ -29,7 +31,7 @@ inline std::vector<int> multinomial_logit_rng(const T_beta& beta, int N,
static constexpr const char* function = "multinomial_logit_rng";
const auto& beta_ref = to_ref(beta);
check_finite(function, "Log-probabilities parameter", beta_ref);
check_positive(function, "number of trials variables", N);
check_nonnegative(function, "number of trials variables", N);

plain_type_t<T_beta> theta = softmax(beta_ref);
std::vector<int> result(theta.size(), 0);
Expand Down
15 changes: 14 additions & 1 deletion stan/math/prim/prob/multinomial_rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,26 @@
namespace stan {
namespace math {

/** \ingroup multivar_dists
* Return a draw from a Multinomial distribution given a
* probability simplex, a total count, and a pseudo-random
* number generator.
*
* @tparam RNG Type of pseudo-random number generator.
* @param theta Vector of normalized probabilities.
* @param N Total count.
* @param rng Pseudo-random number generator.
* @return Multinomial random variate.
* @throw std::domain_error if theta is not a simplex.
* @throw std::domain_error is N is less than 0.
*/
template <class T_theta, class RNG,
require_eigen_col_vector_t<T_theta>* = nullptr>
inline std::vector<int> multinomial_rng(const T_theta& theta, int N, RNG& rng) {
static constexpr const char* function = "multinomial_rng";
const auto& theta_ref = to_ref(theta);
check_simplex(function, "Probabilities parameter", theta_ref);
check_positive(function, "number of trials variables", N);
check_nonnegative(function, "number of trials variables", N);

std::vector<int> result(theta.size(), 0);
double mass_left = 1.0;
Expand Down
13 changes: 13 additions & 0 deletions test/unit/math/prim/prob/multinomial_logit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
using Eigen::Dynamic;
using Eigen::Matrix;

TEST(ProbDistributionsMultinomialLogit, RNGZero) {
boost::random::mt19937 rng;
Matrix<double, Dynamic, 1> beta(3);
beta << 1.3, 0.1, -2.6;
// bug in 4.8.1: RNG does not allow a zero total count
EXPECT_NO_THROW(stan::math::multinomial_logit_rng(beta, 0, rng));
// when the total count is zero, the sample should be a zero array
std::vector<int> sample = stan::math::multinomial_logit_rng(beta, 0, rng);
for (int k : sample) {
EXPECT_EQ(0, k);
}
}

TEST(ProbDistributionsMultinomialLogit, RNGSize) {
boost::random::mt19937 rng;
Matrix<double, Dynamic, 1> beta(5);
Expand Down
15 changes: 15 additions & 0 deletions test/unit/math/prim/prob/multinomial_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@
#include <limits>
#include <vector>

TEST(ProbDistributionsMultinomial, RNGZero) {
using Eigen::Dynamic;
using Eigen::Matrix;
boost::random::mt19937 rng;
Matrix<double, Dynamic, 1> theta(3);
theta << 0.3, 0.1, 0.6;
// bug in 4.8.1: RNG does not allow a zero total count
EXPECT_NO_THROW(stan::math::multinomial_rng(theta, 0, rng));
// when the total count is zero, the sample should be a zero array
std::vector<int> sample = stan::math::multinomial_rng(theta, 0, rng);
for (int k : sample) {
EXPECT_EQ(0, k);
}
}

TEST(ProbDistributionsMultinomial, RNGSize) {
using Eigen::Dynamic;
using Eigen::Matrix;
Expand Down