From b3cc16c3f2ffcfbeb7305a61cca2a040dcd15ed5 Mon Sep 17 00:00:00 2001 From: kendalfoster Date: Sun, 19 Sep 2021 21:50:52 -0400 Subject: [PATCH 1/2] fixes #2584 --- stan/math/prim/prob.hpp | 2 + stan/math/prim/prob/ddm_lcdf.hpp | 399 ++++++++++++++++++ stan/math/prim/prob/ddm_lpdf.hpp | 378 +++++++++++++++++ test/prob/ddm/ddm_cdf_log_test.hpp | 197 +++++++++ test/prob/ddm/ddm_test.hpp | 202 +++++++++ test/unit/math/prim/prob/ddm_cdf_log_test.cpp | 97 +++++ test/unit/math/prim/prob/ddm_test.cpp | 333 +++++++++++++++ 7 files changed, 1608 insertions(+) create mode 100644 stan/math/prim/prob/ddm_lcdf.hpp create mode 100644 stan/math/prim/prob/ddm_lpdf.hpp create mode 100644 test/prob/ddm/ddm_cdf_log_test.hpp create mode 100644 test/prob/ddm/ddm_test.hpp create mode 100644 test/unit/math/prim/prob/ddm_cdf_log_test.cpp create mode 100644 test/unit/math/prim/prob/ddm_test.cpp diff --git a/stan/math/prim/prob.hpp b/stan/math/prim/prob.hpp index 2fce0b34e1d..5b302630f32 100644 --- a/stan/math/prim/prob.hpp +++ b/stan/math/prim/prob.hpp @@ -71,6 +71,8 @@ #include #include #include +#include +#include #include #include #include diff --git a/stan/math/prim/prob/ddm_lcdf.hpp b/stan/math/prim/prob/ddm_lcdf.hpp new file mode 100644 index 00000000000..066f53970c3 --- /dev/null +++ b/stan/math/prim/prob/ddm_lcdf.hpp @@ -0,0 +1,399 @@ +#ifndef STAN_MATH_PRIM_PROB_DDM_LCDF_HPP +#define STAN_MATH_PRIM_PROB_DDM_LCDF_HPP + +#include +#include +#include +#include +#include +#include +#include + + +// Open the Namespace +namespace stan { +namespace math { +using stan::return_type_t; + +/** + * The log of the first passage time distribution function for a + * (Ratcliff, 1978) drift diffusion model with intrinsic trial-trial variability + * for the given response time \f$rt\f$, response \f$response\f$, boundary + * separation \f$a\f$, mean drift rate across trials \f$v\f$, non-decision + * time \f$t0\f$, relative bias \f$w\f$, standard deviation of drift rate + * across trials \f$sv\f$. + * + * @tparam T_rt type of parameter `rt` + * @tparam T_response type of parameter `response` + * @tparam T_a type of parameter `a` + * @tparam T_v type of parameter `v` + * @tparam T_t0 type of parameter `t0` + * @tparam T_w type of parameter `w` + * @tparam T_sv type of parameter `sv` + * + * @param rt The response time; rt >= 0. + * @param response The response; response in {1, 2}. + * @param a The threshold separation; a > 0. + * @param v The mean drift rate across trials. + * @param t0 The non-decision time; t0 >= 0. + * @param w The relative a priori bias; 0 < w < 1. + * @param sv The standard deviation of drift rate across trials; sv >= 0. + * @return The log of the Wiener first passage time density of + * the specified arguments. + */ +template +return_type_t ddm_lcdf( + const T_rt& rt, const T_response& response, const T_a& a, + const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv) { + using T_return_type = return_type_t; + using std::vector; + using std::log; + using std::exp; + using std::erf; + using std::sqrt; + using std::isfinite; + using std::isnan; + using std::max; + using stan::ref_type_t; + using stan::scalar_seq_view; + using stan::math::throw_domain_error; + using stan::math::invalid_argument; + using stan::math::include_summand; + using T_rt_ref = ref_type_t; + using T_response_ref = ref_type_t; + using T_a_ref = ref_type_t; + using T_v_ref = ref_type_t; + using T_t0_ref = ref_type_t; + using T_w_ref = ref_type_t; + using T_sv_ref = ref_type_t; + + // Constants + static const char* function = "ddm_lcdf"; + static const double ERR_TOL = 0.000001; // error tolerance for PDF approx + static const double PI_CONST = 3.14159265358979323846; // define pi like C++ + static const double SQRT_2PI = sqrt(2 * PI_CONST); + static const double SQRT_2PI_INV = 1 / SQRT_2PI; + static const double SQRT_2_INV_NEG = -1 / sqrt(2); + + // Convert Inputs + T_rt_ref rt_ref = rt; + T_response_ref response_ref = response; + T_a_ref a_ref = a; + T_v_ref v_ref = v; + T_t0_ref t0_ref = t0; + T_w_ref w_ref = w; + T_sv_ref sv_ref = sv; + scalar_seq_view rt_vec(rt_ref); + scalar_seq_view response_vec(response_ref); + scalar_seq_view a_vec(a_ref); + scalar_seq_view v_vec(v_ref); + scalar_seq_view t0_vec(t0_ref); + scalar_seq_view w_vec(w_ref); + scalar_seq_view sv_vec(sv_ref); + + // Parameter Checks + size_t Nrt = rt_vec.size(); + size_t Nres = response_vec.size(); + size_t Na = a_vec.size(); + size_t Nv = v_vec.size(); + size_t Nt0 = t0_vec.size(); + size_t Nw = w_vec.size(); + size_t Nsv = sv_vec.size(); + size_t Nmax = max({Nrt, Nres, Na, Nv, Nt0, Nw, Nsv}); + vector out(Nmax); // initialize output-checking vector + + if (Nrt < 1) { // rt, invalid inputs will be handled in calculation of the CDF + return 0; + } + + if (Nres < 1) { // response + return 0; + } else { + for (size_t i = 0; i < Nres; i++) { + if (response_vec[i] == 1) { // lower + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 1; + } + } else if (response_vec[i] == 2) { // upper + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 2; + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "response", response_vec[i], " = ", + ", but it must be either 1 (lower) or 2 (upper)"); + } + } + } + + if (Na < 1) { // a + return 0; + } else { + for (size_t i = 0; i < Na; i++) { + if (a_vec[i] > 0) { + if (isfinite(a_vec[i])) { + continue; + } else { // a = Inf implies PDF = log(0) and CDF problems + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be positive and finite"); + } + } + } + + if (Nv < 1) { // v + return 0; + } else { + for (size_t i = 0; i < Nv; i++) { + if (isfinite(v_vec[i])) { + continue; + } else { // NaN, NA, Inf, -Inf are not finite + throw_domain_error(function, "v (drift rate)", v_vec[i], " = ", + ", but it must be finite"); + } + } + } + + if (Nt0 < 1) { // t0 + return 0; + } else { + for (size_t i = 0; i < Nt0; i++) { + if (t0_vec[i] >= 0) { + if (isfinite(t0_vec[i])) { // this could also be handled in calculation of CDF + continue; + } else { // t0 = Inf implies rt - t0 < 0 implies CDF = log(0) + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], + " = ", ", but it must be finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], + " = ", ", but it must be positive and finite"); + } + } + } + + if (Nw < 1) { // w + return 0; + } else { + for (size_t i = 0; i < Nw; i++) { + if (w_vec[i] > 0 && w_vec[i] < 1) { + continue; + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "w (relative a priori bias)", w_vec[i], + " = ", ", but it must be that 0 < w < 1"); + } + } + } + + if (Nsv < 1) { // sv + return 0; + } else { + for (size_t i = 0; i < Nsv; i++) { + if (sv_vec[i] >= 0) { + if (isfinite(sv_vec[i])) { + continue; + } else { // sv = Inf implies PDF = log(0) and CDF problems + throw_domain_error(function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", + ", but it must be positive and finite"); + } + } + } + + if (!include_summand::value) { + return 0; + } + + + // Calculate log(CDF) + T_return_type lp(0.0); + double t, a_i, v_i, w_i, sv_i; + for (size_t i = 0; i < Nmax; i++) { + + // Check Parameter Values + t = rt_vec[i % Nrt] - t0_vec[i % Nt0]; // response time minus non-decision time + if (t > 0) { // sort response and calculate density + a_i = a_vec[i % Na]; + sv_i = sv_vec[i % Nsv]; + if (out[i] == 1) { // response is "lower" so use unchanged parameters + v_i = v_vec[i % Nv]; + w_i = w_vec[i % Nw]; + } else { // response is "upper" so use alternate parameters + v_i = -v_vec[i % Nv]; + w_i = 1 - w_vec[i % Nw]; + } + + if (t > 32) { // approximation for t = +Infinity + t = 32; + } + + // Calculate sum multiplier + double mult = (sv_i*sv_i * a_i*a_i * w_i*w_i - + 2 * v_i * a_i * w_i - v_i*v_i * t) / + (2 + 2 * sv_i*sv_i * t); + + // Scale error so it is valid inside the sum (not logged) + double exp_err = ERR_TOL * exp(-mult); + + // Calculate sum + double sum = 0; + double gamma = v_i - sv_i*sv_i * a_i * w_i; + double lambda = 1 + sv_i*sv_i * t; + double rho = sqrt(t * lambda); + + int j = 0; + double rj = a_i * j + a_i * w_i; + double m1 = (lambda * rj - gamma * t) / rho; + double m2 = (lambda * rj + gamma * t) / rho; + double mills_1, mills_2; + if (m1 < 6.5) { + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1*m1) * + (1 + erf(SQRT_2_INV_NEG * m1)); + } else { + double m1sq = m1*m1; + mills_1 = ( 1 - + 1 / (m1sq + 2) + + 1 / ( (m1sq + 2) * (m1sq + 4) ) - + 5 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) ) + + 9 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) ) - + 129 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) * + (m1sq + 10) ) + ) / m1; + } + if (m2 < 6.5) { + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2*m2) * + (1 + erf(SQRT_2_INV_NEG * m2)); + } else { + double m2sq = m2*m2; + mills_2 = ( 1 - + 1 / (m2sq + 2) + + 1 / ( (m2sq + 2) * (m2sq + 4) ) - + 5 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) ) + + 9 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) ) - + 129 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) * + (m2sq + 10) ) + ) / m2; + } + double term = SQRT_2PI_INV * exp(-0.5 * rj*rj / t) * (mills_1 + mills_2); + sum += term; + + while (term > exp_err) { + if (j > 1000) { + // maybe include a warning here? + break; + } + j++; + rj = a_i * j + a_i * (1 - w_i); + m1 = (lambda * rj - gamma * t) / rho; + m2 = (lambda * rj + gamma * t) / rho; + if (m1 < 6.5) { + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1*m1) * + (1 + erf(SQRT_2_INV_NEG * m1)); + } else { + double m1sq = m1*m1; + mills_1 = ( 1 - + 1 / (m1sq + 2) + + 1 / ( (m1sq + 2) * (m1sq + 4) ) - + 5 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) ) + + 9 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * + (m1sq + 8) ) - + 129 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * + (m1sq + 8) * (m1sq + 10) ) + ) / m1; + } + if (m2 < 6.5) { + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2*m2) * + (1 + erf(SQRT_2_INV_NEG * m2)); + } else { + double m2sq = m2*m2; + mills_2 = ( 1 - + 1 / (m2sq + 2) + + 1 / ( (m2sq + 2) * (m2sq + 4) ) - + 5 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) ) + + 9 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * + (m2sq + 8) ) - + 129 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * + (m2sq + 8) * (m2sq + 10) ) + ) / m2; + } + term = SQRT_2PI_INV * exp(-0.5 * rj*rj / t) * (mills_1 + mills_2); + sum -= term; + + if (term <= exp_err) break; + + j++; + rj = a_i * j + a_i * w_i; + m1 = (lambda * rj - gamma * t) / rho; + m2 = (lambda * rj + gamma * t) / rho; + if (m1 < 6.5) { + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1*m1) * + (1 + erf(SQRT_2_INV_NEG * m1)); + } else { + double m1sq = m1*m1; + mills_1 = ( 1 - + 1 / (m1sq + 2) + + 1 / ( (m1sq + 2) * (m1sq + 4) ) - + 5 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) ) + + 9 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * + (m1sq + 8) ) - + 129 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * + (m1sq + 8) * (m1sq + 10) ) + ) / m1; + } + if (m2 < 6.5) { + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2*m2) * + (1 + erf(SQRT_2_INV_NEG * m2)); + } else { + double m2sq = m2*m2; + mills_2 = ( 1 - + 1 / (m2sq + 2) + + 1 / ( (m2sq + 2) * (m2sq + 4) ) - + 5 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) ) + + 9 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * + (m2sq + 8) ) - + 129 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * + (m2sq + 8) * (m2sq + 10) ) + ) / m2; + } + term = SQRT_2PI_INV * exp(-0.5 * rj*rj / t) * (mills_1 + mills_2); + sum += term; + } + + // Add sum and multiplier to lp + if (sum >= 0) { // if result is negative, treat as 0 and don't add to lp + lp += mult + log(sum); + } + } else { // {NaN, NA} evaluate to FALSE + if (isnan(t)) { + throw_domain_error(function, "rt (response time)", rt_vec[i % Nrt], + "is a NaN and = ", ", but this value is invalid"); + } else { + throw_domain_error(function, "rt (response time)", t0_vec[i % Nt0], + "is not greater than t0 = ", ", but it must be that rt - t0 > 0"); + } + } + } + + return lp; +} + +template +inline return_type_t ddm_lcdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + return ddm_lcdf(rt, response, a, v, t0, w, sv); +} + +} +} // Close namespace +#endif diff --git a/stan/math/prim/prob/ddm_lpdf.hpp b/stan/math/prim/prob/ddm_lpdf.hpp new file mode 100644 index 00000000000..b3d78c46b7a --- /dev/null +++ b/stan/math/prim/prob/ddm_lpdf.hpp @@ -0,0 +1,378 @@ +#ifndef STAN_MATH_PRIM_PROB_DDM_LPDF_HPP +#define STAN_MATH_PRIM_PROB_DDM_LPDF_HPP + +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include + + +// Open the Namespace +namespace stan { +namespace math { +using stan::return_type_t; + +/** + * The log of the first passage time density function for a (Ratcliff, 1978) + * drift diffusion model with intrinsic trial-trial variability + * for the given response time \f$rt\f$, response \f$response\f$, boundary + * separation \f$a\f$, mean drift rate across trials \f$v\f$, non-decision + * time \f$t0\f$, relative bias \f$w\f$, and standard deviation of drift rate + * across trials \f$sv\f$. + * + * @tparam T_rt type of parameter `rt` + * @tparam T_response type of parameter `response` + * @tparam T_a type of parameter `a` + * @tparam T_v type of parameter `v` + * @tparam T_t0 type of parameter `t0` + * @tparam T_w type of parameter `w` + * @tparam T_sv type of parameter `sv` + * + * @param rt The response time; rt >= 0. + * @param response The response; response in {1, 2}. + * @param a The threshold separation; a > 0. + * @param v The mean drift rate across trials. + * @param t0 The non-decision time; t0 >= 0. + * @param w The relative a priori bias; 0 < w < 1. + * @param sv The standard deviation of drift rate across trials; sv >= 0. + * @return The log of the Wiener first passage time density of + * the specified arguments. + */ +template +return_type_t ddm_lpdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + using T_return_type = return_type_t; + using std::vector; + using std::log; + using std::exp; + using std::sqrt; + using std::max; + using std::ceil; + using std::isfinite; + using std::isnan; + using stan::ref_type_t; + using stan::scalar_seq_view; + using stan::math::throw_domain_error; + using stan::math::invalid_argument; + using stan::math::include_summand; + using T_rt_ref = ref_type_t; + using T_response_ref = ref_type_t; + using T_a_ref = ref_type_t; + using T_v_ref = ref_type_t; + using T_t0_ref = ref_type_t; + using T_w_ref = ref_type_t; + using T_sv_ref = ref_type_t; + + // Constants + static const char* function = "ddm_lpdf"; + static const int max_terms_large = 1; // heuristic for switching mechanism + static const double ERR_TOL = 0.000001; // error tolerance for PDF approx + static const double SV_THRESH = 0; // threshold for using variable drift rate + static const double LOG_PI = log(M_PI); + static const double LOG_2PI_2 = 0.5 * log(2 * M_PI); + + // Convert Inputs + T_rt_ref rt_ref = rt; + T_response_ref response_ref = response; + T_a_ref a_ref = a; + T_v_ref v_ref = v; + T_t0_ref t0_ref = t0; + T_w_ref w_ref = w; + T_sv_ref sv_ref = sv; + scalar_seq_view rt_vec(rt_ref); + scalar_seq_view response_vec(response_ref); + scalar_seq_view a_vec(a_ref); + scalar_seq_view v_vec(v_ref); + scalar_seq_view t0_vec(t0_ref); + scalar_seq_view w_vec(w_ref); + scalar_seq_view sv_vec(sv_ref); + + // Parameter Checks + size_t Nrt = rt_vec.size(); + size_t Nres = response_vec.size(); + size_t Na = a_vec.size(); + size_t Nv = v_vec.size(); + size_t Nt0 = t0_vec.size(); + size_t Nw = w_vec.size(); + size_t Nsv = sv_vec.size(); + size_t Nmax = max({Nrt, Nres, Na, Nv, Nt0, Nw, Nsv}); + vector out(Nmax); // initialize output-checking vector + + if (Nrt < 1) { // rt, invalid inputs will be handled in calculation of the pdf + return 0; + } + + if (Nres < 1) { // response + return 0; + } else { + for (size_t i = 0; i < Nres; i++) { + if (response_vec[i] == 1) { // lower + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 1; + } + } else if (response_vec[i] == 2) { // upper + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 2; + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "response", response_vec[i], " = ", + ", but it must be either 1 (lower) or 2 (upper)"); + } + } + } + + if (Na < 1) { // a + return 0; + } else { + for (size_t i = 0; i < Na; i++) { + if (a_vec[i] > 0) { + if (isfinite(a_vec[i])) { + continue; + } else { // a = Inf implies PDF = log(0) + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be positive and finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be positive and finite"); + } + } + } + + if (Nv < 1) { // v + return 0; + } else { + for (size_t i = 0; i < Nv; i++) { + if (isfinite(v_vec[i])) { + continue; + } else { // NaN, NA, Inf, -Inf are not finite + throw_domain_error(function, "v (drift rate)", v_vec[i], " = ", + ", but it must be finite"); + } + } + } + + if (Nt0 < 1) { // t0 + return 0; + } else { + for (size_t i = 0; i < Nt0; i++) { + if (t0_vec[i] >= 0) { + if (isfinite(t0_vec[i])) { // this could also be handled in calculate_pdf() + continue; + } else { // t0 = Inf implies rt - t0 < 0 implies PDF = log(0) + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], + " = ", ", but it must be positive and finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], + " = ", ", but it must be positive and finite"); + } + } + } + + if (Nw < 1) { // w + return 0; + } else { + for (size_t i = 0; i < Nw; i++) { + if (w_vec[i] > 0 && w_vec[i] < 1) { + continue; + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "w (relative a priori bias)", w_vec[i], + " = ", ", but it must be that 0 < w < 1"); + } + } + } + + if (Nsv < 1) { // sv + return 0; + } else { + for (size_t i = 0; i < Nsv; i++) { + if (sv_vec[i] >= 0) { + if (isfinite(sv_vec[i])) { + continue; + } else { // sv = Inf implies PDF = log(0) + throw_domain_error(function, + "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", + ", but it must be positive and finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, + "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", + ", but it must be positive and finite"); + } + } + } + + if (!include_summand::value) { + return 0; + } + + + // Calculate log(PDF) + T_return_type lp(0.0); + double t, a_i, v_i, w_i, sv_i; + for (size_t i = 0; i < Nmax; i++) { + + // Check Parameter Values + t = rt_vec[i % Nrt] - t0_vec[i % Nt0]; // response time minus non-decision time + if (t > 0 && std::isfinite(t)) { // sort response and calculate density + a_i = a_vec[i % Na]; + sv_i = sv_vec[i % Nsv]; + if (out[i] == 1) { // response is "lower" so use unchanged parameters + v_i = v_vec[i % Nv]; + w_i = w_vec[i % Nw]; + } else { // response is "upper" so use alternate parameters + v_i = -v_vec[i % Nv]; + w_i = 1 - w_vec[i % Nw]; + } + + // Approximate log(PDF) + double mult; + // Check large time + if (sv_i <= SV_THRESH) { // no sv + mult = - v_i * a_i * w_i - v_i*v_i * t / 2 - 2 * log(a_i); + } else { // sv + mult = (sv_i*sv_i * a_i*a_i * w_i*w_i - 2 * v_i * a_i * w_i + - v_i*v_i * t) / (2 + 2 * sv_i*sv_i * t) + - 0.5 * log(1 + sv_i*sv_i * t) - 2 * log(a_i); + } + int kl; + double exp_err = ERR_TOL * exp(-mult); + double taa = t / (a_i*a_i); + double bc = 1 / (M_PI * sqrt(taa)); // boundary conditions + if (bc > INT_MAX) return INT_MAX; + if (exp_err * M_PI * taa < 1) { // error threshold is low enough + double kl_tmp = sqrt(-2 * log(M_PI * taa * exp_err) + / (M_PI*M_PI * taa)); + if (kl_tmp > INT_MAX) { + kl = INT_MAX; + } else { + kl = ceil(max(kl_tmp, bc)); // ensure boundary conditions are met + } + } else { + kl = ceil(bc); // else set to boundary condition + } + + // Compare kl (large time) to max_terms_large (small time) + if (kl <= max_terms_large) { // use large time + double gamma = -0.5 * M_PI*M_PI * taa; + double sum = 0.0; + for (size_t j = 1; j <= kl; j++) { + sum += j * sin(j * w_i * M_PI) * exp(gamma * j*j); + } + if (sum >= 0) { // if result is negative, don't add to lp + lp += LOG_PI + mult + log(sum); + } + } else { // use small time + if (sv_i <= SV_THRESH) { // no sv + mult = log(a_i) - LOG_2PI_2 - 1.5 * log(t) + - v_i * a_i * w_i - v_i*v_i * t / 2; + } else { // sv + mult = log(a_i) - 1.5 * log(t) - LOG_2PI_2 + - 0.5 * log(1 + sv_i*sv_i * t) + + (sv_i*sv_i * a_i*a_i * w_i*w_i - 2 * v_i * a_i * w_i + - v_i*v_i * t) / (2 + 2 * sv_i*sv_i * t); + } + exp_err = ERR_TOL / exp(mult); + size_t minterms = sqrt(t)/a_i - w_i; // min number of terms, truncates toward 0 + double gamma = -1 / (2 * taa); + double sum = w_i * exp(gamma * w_i*w_i); // initialize with j=0 term + double term, rj; + size_t j = 0; + if (minterms % 2) { // minterms is odd (and at least 1) + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj*rj); + sum -= term; + while (j < minterms) { + j++; + rj = j + w_i; + sum += rj * exp(gamma * rj*rj); + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj*rj); + sum -= term; + } + j++; + rj = j + w_i; // j is now even + term = rj * exp(gamma * rj*rj); + sum += term; + while (term > exp_err) { + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj*rj); + sum -= term; + if (term <= exp_err) break; + j++; + rj = j + w_i; + term = rj * exp(gamma * rj*rj); + sum += term; + } + } else { // minterms is even (and at least 0) + while (j < minterms) { // j is currently 0 + j++; + rj = j + 1 - w_i; + sum -= rj * exp(gamma * rj*rj); + j++; + rj = j + w_i; + term = rj * exp(gamma * rj*rj); + sum += term; + } + j++; + rj = j + 1 - w_i; // j is now odd + term = rj * exp(gamma * rj*rj); + sum -= term; + while (term > exp_err) { + j++; + rj = j + w_i; + term = rj * exp(gamma * rj*rj); + sum += term; + if (term <= exp_err) break; + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj*rj); + sum -= term; + } + } + if (sum >= 0) { // if result is negative, don't add to lp + lp += mult + log(sum); + } + } + } else { // {NaN, NA} evaluate to FALSE + if (isnan(t)) { + throw_domain_error(function, "rt (response time)", rt_vec[i % Nrt], + "is a NaN and = ", + ", but rt must be positive and finite"); + } else { + throw_domain_error(function, + "rt (response time)", t0_vec[i % Nt0], + "is not greater than t0 = ", + ", but it must be that rt - t0 is positive and finite"); + } + } + } + + return lp; +} + +template +inline return_type_t ddm_lpdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + return ddm_lpdf(rt, response, a, v, t0, w, sv); +} + +} +} // Close namespace +#endif diff --git a/test/prob/ddm/ddm_cdf_log_test.hpp b/test/prob/ddm/ddm_cdf_log_test.hpp new file mode 100644 index 00000000000..c28572beb8b --- /dev/null +++ b/test/prob/ddm/ddm_cdf_log_test.hpp @@ -0,0 +1,197 @@ +// Arguments: Doubles, Ints, Doubles, Doubles, Doubles, Doubles, Doubles +#include + +using std::vector; +using stan::math::INFTY; + +class AgradCdfDdm : public AgradCdfTest { +public: + void valid_values(vector >& parameters, + vector& cdf) { + vector param(7); + + // each expected log_prob is calculated with the R package `fddm` as follows + // fddm::pfddm(rt, response, a, v, t0, w, sv, log = TRUE) + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.3189645693469165); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 2; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-1.318964569346917); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 2.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.4105356025317656); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = 1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-1.318964569346917); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.5; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.403297055469638); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.2; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.08206670817568271); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 1.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.3658772238413681); // expected log_prob + } + + void invalid_values(vector& index, vector& value) { + // rt + index.push_back(0U); + value.push_back(0.0); + + index.push_back(0U); + value.push_back(-1.0); + + index.push_back(0U); + value.push_back(INFTY); + + index.push_back(0U); + value.push_back(-INFTY); + + // response + index.push_back(1U); + value.push_back(0); + + index.push_back(1U); + value.push_back(3); + + index.push_back(1U); + value.push_back(-1); + + index.push_back(1U); + value.push_back(INFTY); + + index.push_back(1U); + value.push_back(-INFTY); + + // a + index.push_back(2U); + value.push_back(0.0); + + index.push_back(2U); + value.push_back(-1.0); + + index.push_back(2U); + value.push_back(INFTY); + + index.push_back(2U); + value.push_back(-INFTY); + + // v + index.push_back(3U); + value.push_back(INFTY); + + index.push_back(3U); + value.push_back(-INFTY); + + // t0 + index.push_back(4U); + value.push_back(-1); + + index.push_back(4U); + value.push_back(INFTY); + + index.push_back(4U); + value.push_back(-INFTY); + + // w + index.push_back(5U); + value.push_back(-0.1); + + index.push_back(5U); + value.push_back(0.0); + + index.push_back(5U); + value.push_back(1.0); + + index.push_back(5U); + value.push_back(1.1); + + index.push_back(5U); + value.push_back(INFTY); + + index.push_back(5U); + value.push_back(-INFTY); + + // sv + index.push_back(6U); + value.push_back(-1.0); + + index.push_back(6U); + value.push_back(INFTY); + + index.push_back(6U); + value.push_back(-INFTY); + } + + bool has_upper_bound() { return true; } + + double upper_bound() { return 0.0; } + + template + stan::return_type_t cdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { + return stan::math::ddm_lcdf(rt, response, a, v, t0, w, sv); + } + + template + stan::return_type_t + cdf_function(const T_rt& rt, const T_response& response, const T_a& a, + const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv, + const T7&) { + return stan::math::ddm_lcdf(rt, response, a, v, t0, w, sv); + } +}; diff --git a/test/prob/ddm/ddm_test.hpp b/test/prob/ddm/ddm_test.hpp new file mode 100644 index 00000000000..17d7cb7d239 --- /dev/null +++ b/test/prob/ddm/ddm_test.hpp @@ -0,0 +1,202 @@ +// Arguments: Doubles, Ints, Doubles, Doubles, Doubles, Doubles, Doubles +#include + +using std::vector; +using stan::math::INFTY; + +class AgradDistributionDdm : public AgradDistributionTest { +public: + void valid_values(vector >& parameters, + vector& log_prob) { + vector param(7); + + // each expected log_prob is calculated with the R package `fddm` as follows + // fddm::dfddm(rt, response, a, v, t0, w, sv, log = TRUE) + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-3.790072391414288); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 2; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.790072391414288); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 2.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.9754202070046213); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = 1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.790072391414288); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.5; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-1.072671222447106); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.2; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.621465563321226); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 1.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.07414598169426); // expected log_prob + } + + void invalid_values(vector& index, vector& value) { + // rt + index.push_back(0U); + value.push_back(0.0); + + index.push_back(0U); + value.push_back(-1.0); + + index.push_back(0U); + value.push_back(INFTY); + + index.push_back(0U); + value.push_back(-INFTY); + + // response + index.push_back(1U); + value.push_back(0); + + index.push_back(1U); + value.push_back(3); + + index.push_back(1U); + value.push_back(-1); + + index.push_back(1U); + value.push_back(INFTY); + + index.push_back(1U); + value.push_back(-INFTY); + + // a + index.push_back(2U); + value.push_back(0.0); + + index.push_back(2U); + value.push_back(-1.0); + + index.push_back(2U); + value.push_back(INFTY); + + index.push_back(2U); + value.push_back(-INFTY); + + // v + index.push_back(3U); + value.push_back(INFTY); + + index.push_back(3U); + value.push_back(-INFTY); + + // t0 + index.push_back(4U); + value.push_back(-1); + + index.push_back(4U); + value.push_back(INFTY); + + index.push_back(4U); + value.push_back(-INFTY); + + // w + index.push_back(5U); + value.push_back(-0.1); + + index.push_back(5U); + value.push_back(0.0); + + index.push_back(5U); + value.push_back(1.0); + + index.push_back(5U); + value.push_back(1.1); + + index.push_back(5U); + value.push_back(INFTY); + + index.push_back(5U); + value.push_back(-INFTY); + + // sv + index.push_back(6U); + value.push_back(-1.0); + + index.push_back(6U); + value.push_back(INFTY); + + index.push_back(6U); + value.push_back(-INFTY); + } + + template + stan::return_type_t log_prob( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { + return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); + } + + template + stan::return_type_t log_prob( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { + return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); + } + + template + stan::return_type_t + log_prob_function(const T_rt& rt, const T_response& response, const T_a& a, + const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv, + const T7&) { + return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); + } +}; diff --git a/test/unit/math/prim/prob/ddm_cdf_log_test.cpp b/test/unit/math/prim/prob/ddm_cdf_log_test.cpp new file mode 100644 index 00000000000..863476f8ac2 --- /dev/null +++ b/test/unit/math/prim/prob/ddm_cdf_log_test.cpp @@ -0,0 +1,97 @@ +#include +#include +#include + + +TEST(ProbDdm, ddm_lcdf_matches_known_cdf) { + using std::vector; + using std::exp; + using stan::math::ddm_lcdf; + + // Note: + // 1. Each expected log_prob is calculated with the R package `fddm` using + // fddm::pfddm(rt, response, a, v, t0, w, sv, log = TRUE) + // 2. We must define a large error tolerance because we must check on the log + // scale, and since the sum of the log PDFs is very negative (~ -1800), + // exponentiating this very negative value would result in rounding to + // zero. This value of error tolerance is slightly arbitrary, but it is + // more useful than comparing zero to zero due to rounding issues. + + static const double pfddm_output = -972.3893812; + static const double err_tol = 1.0; + static const vector rt{0.1, 1, 10.0}; + static const int response = 1; // "lower" threshold + static const vector a{0.5, 1.0, 5.0}; + static const vector v{-2.0, 0.0, 2.0}; + static const double t0 = 0.0001; + static const vector w{0.2, 0.5, 0.8}; + static const vector sv{0.0, 0.5, 1.0, 1.5}; + + static const int n_rt = rt.size(); + static const int n_a = a.size(); + static const int n_v = v.size(); + static const int n_w = w.size(); + static const int n_sv = sv.size(); + static const int n_max = n_rt * n_a * n_v * n_w * n_sv; + vector rt_vec(n_max); + vector a_vec(n_max); + vector v_vec(n_max); + vector t0_vec(n_max, t0); + vector w_vec(n_max); + vector sv_vec(n_max); + + for (int i = 0; i < n_rt; i++) { + for (int j = i; j < n_max; j += n_rt) { + rt_vec[j] = rt[i]; + } + } + for (int i = 0; i < n_a; i++) { + for (int j = i; j < n_max; j += n_a) { + a_vec[j] = a[i]; + } + } + for (int i = 0; i < n_v; i++) { + for (int j = i; j < n_max; j += n_v) { + v_vec[j] = v[i]; + } + } + for (int i = 0; i < n_w; i++) { + for (int j = i; j < n_max; j += n_w) { + w_vec[j] = w[i]; + } + } + for (int i = 0; i < n_sv; i++) { + for (int j = i; j < n_max; j += n_sv) { + sv_vec[j] = sv[i]; + } + } + + + EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, + err_tol); + EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, + err_tol); + EXPECT_NEAR( + (ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, + err_tol); + EXPECT_NEAR( + (ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR( + (ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, + err_tol); +} diff --git a/test/unit/math/prim/prob/ddm_test.cpp b/test/unit/math/prim/prob/ddm_test.cpp new file mode 100644 index 00000000000..e1790c4f152 --- /dev/null +++ b/test/unit/math/prim/prob/ddm_test.cpp @@ -0,0 +1,333 @@ +#include +#include +#include +#include + +// ddm_lpdf(rt, response, a, v, t0, w, sv) +// wiener_lpdf(y, alpha, tau, beta, delta) +// rt <- y +// response <- 2 (wiener_lpdf() always uses the "upper" threshold) +// a <- alpha +// v <- delta +// t0 <- tau +// w <- beta +// sv is not included in wiener_lpdf() +// alpha -> a +// tau -> t0 +// beta -> w +// delta -> v +// Note: `response` and `sv` are not included in wiener_lpdf() + + +// Check invalid arguments + +// rt +TEST(mathPrimScalProbDdmScal, invalid_rt) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(0, 2, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(-1, 2, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(INFTY, 2, 1, -1, 0, 0.5, 0), + std::domain_error); + EXPECT_THROW(ddm_lpdf(-INFTY, 2, 1, -1, 0, 0.5, 0), + std::domain_error); + EXPECT_THROW(ddm_lpdf(NAN, 2, 1, -1, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_rt) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector rt{1, 0}; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = -1; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = INFTY; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = NAN; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); +} + +// response +TEST(mathPrimScalProbDdmScal, invalid_response) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 0, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 3, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, -1, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, INFTY, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, -INFTY, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, NAN, 1, -1, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_response) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector response{2, 0}; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = 3; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = -1; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); +} + +// a +TEST(mathPrimScalProbDdmScal, invalid_a) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 0, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, -1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, INFTY, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, -INFTY, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, NAN, -1, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_a) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector a{1, 0}; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = -1; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); +} + +// v +TEST(mathPrimScalProbDdmScal, invalid_v) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, INFTY, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -INFTY, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, NAN, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_v) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector v{1, INFTY}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, v, 0, 0.5, 0), std::domain_error); + v[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, v, 0, 0.5, 0), std::domain_error); + v[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, v, 0, 0.5, 0), std::domain_error); +} + +// t0 +TEST(mathPrimScalProbDdmScal, invalid_t0) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, -1, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, INFTY, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, -INFTY, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, NAN, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_t0) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector t0{1, -1}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); + t0[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); + t0[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); + t0[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); +} + +// w +TEST(mathPrimScalProbDdmScal, invalid_w) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, -0.1, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 1, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 1.1, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, INFTY, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, -INFTY, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, NAN, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_w) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector w{1, -0.1}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = 0; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = 1; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = 1.1; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); +} + +// sv +TEST(mathPrimScalProbDdmScal, invalid_sv) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, -1), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, INFTY), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, -INFTY), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, NAN), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_sv) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector sv{1, -1}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); + sv[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); + sv[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); + sv[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); +} + +TEST(ProbDdm, ddm_lpdf_matches_wiener_lpdf) { + using std::vector; + using std::exp; + using stan::math::ddm_lpdf; + using stan::math::wiener_lpdf; + // Notes: + // 1. define error tolerance for PDF approximations, use double tolerance to + // allow for convergence (of the truncated infinite sum) from above and below + // 2. wiener_lpdf() only uses the "upper" threshold in the DDM, but the + // "lower" threshold maps v -> -v and w -> 1-w, and this is covered in the + // parameter values defined below + double err_tol = 2 * 0.000001; + vector rt{0.1, 1, 10.0}; + int response = 2; // wiener_lpdf() always uses the "upper" threshold + vector a{0.5, 1.0, 5.0}; + vector v{-2.0, 0.0, 2.0}; + double t0 = 0.0001; // t0 (i.e., tau) needs to be > 0 for wiener_lpdf() + vector w{0.2, 0.5, 0.8}; + double sv = 0.0; // sv is not included in wiener_lpdf(), and thus it must be 0 + + int n_rt = rt.size(); + int n_a = a.size(); + int n_v = v.size(); + int n_w = w.size(); + int n_max = n_rt * n_a * n_v * n_w; + vector rt_vec(n_max); + vector a_vec(n_max); + vector v_vec(n_max); + vector t0_vec(n_max, t0); + vector w_vec(n_max); + + for (int i = 0; i < n_rt; i++) { + for (int j = i; j < n_max; j += n_rt) { + rt_vec[j] = rt[i]; + } + } + for (int i = 0; i < n_a; i++) { + for (int j = i; j < n_max; j += n_a) { + a_vec[j] = a[i]; + } + } + for (int i = 0; i < n_v; i++) { + for (int j = i; j < n_max; j += n_v) { + v_vec[j] = v[i]; + } + } + for (int i = 0; i < n_w; i++) { + for (int j = i; j < n_max; j += n_w) { + w_vec[j] = w[i]; + } + } + + // The PDF approximation error tolerance is based on the non-log version of + // the PDF. To account for this, we compare the exponentiated version of the + // *_lpdf results. + EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR( + exp(ddm_lpdf, int, vector, vector, + vector, vector, double>( + rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >( + rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR( + exp(ddm_lpdf, int, vector, vector, + vector, vector, double>( + rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >( + rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR( + exp(ddm_lpdf, int, vector, vector, + vector, vector, double>( + rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >( + rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + + // check with variable drift rate (against results from R package `fddm`) + // Notes: + // 1. We must redefine the error tolerance because we must check on the log + // scale as the sum of the log PDFs is very negative (~ -1800), and + // exponentiating this very negative value would result in rounding to zero. + // This value of error tolerance is slightly arbitrary, but it is more useful + // than comparing zero to zero due to rounding issues. + err_tol = 1.0; + static const double dfddm_output = -1800.6359154; + vector sv_vals{0.0, 0.5, 1.0, 1.5}; + int n_sv = sv_vals.size(); + n_max *= n_sv; + vector sv_vec(n_max); + for (int i = 0; i < n_sv; i++) { + for (int j = i; j < n_max; j += n_sv) { + sv_vec[j] = sv_vals[i]; + } + } + + EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, + err_tol); + EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, + err_tol); + EXPECT_NEAR( + (ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, + err_tol); + EXPECT_NEAR( + (ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR( + (ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, + err_tol); +} From 9e595d01bba3eccb1d37d44a9e9ed1a50986e7a9 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Mon, 20 Sep 2021 02:09:42 +0000 Subject: [PATCH 2/2] [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) --- stan/math/prim/prob/ddm_lcdf.hpp | 318 +++++++++--------- stan/math/prim/prob/ddm_lpdf.hpp | 238 ++++++------- test/prob/ddm/ddm_cdf_log_test.hpp | 99 +++--- test/prob/ddm/ddm_test.hpp | 94 +++--- test/unit/math/prim/prob/ddm_cdf_log_test.cpp | 55 ++- test/unit/math/prim/prob/ddm_test.cpp | 111 +++--- 6 files changed, 448 insertions(+), 467 deletions(-) diff --git a/stan/math/prim/prob/ddm_lcdf.hpp b/stan/math/prim/prob/ddm_lcdf.hpp index 066f53970c3..66ffee27ff0 100644 --- a/stan/math/prim/prob/ddm_lcdf.hpp +++ b/stan/math/prim/prob/ddm_lcdf.hpp @@ -9,7 +9,6 @@ #include #include - // Open the Namespace namespace stan { namespace math { @@ -44,23 +43,23 @@ using stan::return_type_t; template return_type_t ddm_lcdf( - const T_rt& rt, const T_response& response, const T_a& a, - const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv) { - using T_return_type = return_type_t; - using std::vector; - using std::log; - using std::exp; + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + using T_return_type + = return_type_t; + using stan::ref_type_t; + using stan::scalar_seq_view; + using stan::math::include_summand; + using stan::math::invalid_argument; + using stan::math::throw_domain_error; using std::erf; - using std::sqrt; + using std::exp; using std::isfinite; using std::isnan; + using std::log; using std::max; - using stan::ref_type_t; - using stan::scalar_seq_view; - using stan::math::throw_domain_error; - using stan::math::invalid_argument; - using stan::math::include_summand; + using std::sqrt; + using std::vector; using T_rt_ref = ref_type_t; using T_response_ref = ref_type_t; using T_a_ref = ref_type_t; @@ -71,8 +70,8 @@ return_type_t ddm_lcdf( // Constants static const char* function = "ddm_lcdf"; - static const double ERR_TOL = 0.000001; // error tolerance for PDF approx - static const double PI_CONST = 3.14159265358979323846; // define pi like C++ + static const double ERR_TOL = 0.000001; // error tolerance for PDF approx + static const double PI_CONST = 3.14159265358979323846; // define pi like C++ static const double SQRT_2PI = sqrt(2 * PI_CONST); static const double SQRT_2PI_INV = 1 / SQRT_2PI; static const double SQRT_2_INV_NEG = -1 / sqrt(2); @@ -94,198 +93,199 @@ return_type_t ddm_lcdf( scalar_seq_view sv_vec(sv_ref); // Parameter Checks - size_t Nrt = rt_vec.size(); + size_t Nrt = rt_vec.size(); size_t Nres = response_vec.size(); - size_t Na = a_vec.size(); - size_t Nv = v_vec.size(); - size_t Nt0 = t0_vec.size(); - size_t Nw = w_vec.size(); - size_t Nsv = sv_vec.size(); + size_t Na = a_vec.size(); + size_t Nv = v_vec.size(); + size_t Nt0 = t0_vec.size(); + size_t Nw = w_vec.size(); + size_t Nsv = sv_vec.size(); size_t Nmax = max({Nrt, Nres, Na, Nv, Nt0, Nw, Nsv}); - vector out(Nmax); // initialize output-checking vector + vector out(Nmax); // initialize output-checking vector - if (Nrt < 1) { // rt, invalid inputs will be handled in calculation of the CDF - return 0; + if (Nrt + < 1) { // rt, invalid inputs will be handled in calculation of the CDF + return 0; } - if (Nres < 1) { // response + if (Nres < 1) { // response return 0; } else { for (size_t i = 0; i < Nres; i++) { - if (response_vec[i] == 1) { // lower + if (response_vec[i] == 1) { // lower for (size_t j = i; j < Nmax; j += Nres) { out[j] = 1; } - } else if (response_vec[i] == 2) { // upper + } else if (response_vec[i] == 2) { // upper for (size_t j = i; j < Nmax; j += Nres) { out[j] = 2; } - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE throw_domain_error(function, "response", response_vec[i], " = ", ", but it must be either 1 (lower) or 2 (upper)"); } } } - if (Na < 1) { // a + if (Na < 1) { // a return 0; } else { for (size_t i = 0; i < Na; i++) { if (a_vec[i] > 0) { if (isfinite(a_vec[i])) { continue; - } else { // a = Inf implies PDF = log(0) and CDF problems + } else { // a = Inf implies PDF = log(0) and CDF problems throw_domain_error(function, "a (threshold separation)", a_vec[i], " = ", ", but it must be finite"); } - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE throw_domain_error(function, "a (threshold separation)", a_vec[i], " = ", ", but it must be positive and finite"); } } } - if (Nv < 1) { // v + if (Nv < 1) { // v return 0; } else { for (size_t i = 0; i < Nv; i++) { if (isfinite(v_vec[i])) { continue; - } else { // NaN, NA, Inf, -Inf are not finite + } else { // NaN, NA, Inf, -Inf are not finite throw_domain_error(function, "v (drift rate)", v_vec[i], " = ", ", but it must be finite"); } } } - if (Nt0 < 1) { // t0 + if (Nt0 < 1) { // t0 return 0; } else { for (size_t i = 0; i < Nt0; i++) { if (t0_vec[i] >= 0) { - if (isfinite(t0_vec[i])) { // this could also be handled in calculation of CDF + if (isfinite(t0_vec[i])) { // this could also be handled in calculation + // of CDF continue; - } else { // t0 = Inf implies rt - t0 < 0 implies CDF = log(0) + } else { // t0 = Inf implies rt - t0 < 0 implies CDF = log(0) throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], " = ", ", but it must be finite"); } - } else { // {NaN, NA} evaluate to FALSE - throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], - " = ", ", but it must be positive and finite"); + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], " = ", + ", but it must be positive and finite"); } } } - if (Nw < 1) { // w + if (Nw < 1) { // w return 0; } else { for (size_t i = 0; i < Nw; i++) { if (w_vec[i] > 0 && w_vec[i] < 1) { continue; - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE throw_domain_error(function, "w (relative a priori bias)", w_vec[i], " = ", ", but it must be that 0 < w < 1"); } } } - if (Nsv < 1) { // sv - return 0; - } else { - for (size_t i = 0; i < Nsv; i++) { - if (sv_vec[i] >= 0) { - if (isfinite(sv_vec[i])) { - continue; - } else { // sv = Inf implies PDF = log(0) and CDF problems - throw_domain_error(function, "sv (standard deviation of drift rate across trials)", - sv_vec[i], " = ", ", but it must be finite"); - } - } else { // {NaN, NA} evaluate to FALSE - throw_domain_error(function, "sv (standard deviation of drift rate across trials)", - sv_vec[i], " = ", - ", but it must be positive and finite"); + if (Nsv < 1) { // sv + return 0; + } else { + for (size_t i = 0; i < Nsv; i++) { + if (sv_vec[i] >= 0) { + if (isfinite(sv_vec[i])) { + continue; + } else { // sv = Inf implies PDF = log(0) and CDF problems + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be finite"); } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be positive and finite"); } } + } if (!include_summand::value) { + T_sv>::value) { return 0; } - // Calculate log(CDF) T_return_type lp(0.0); double t, a_i, v_i, w_i, sv_i; for (size_t i = 0; i < Nmax; i++) { - // Check Parameter Values - t = rt_vec[i % Nrt] - t0_vec[i % Nt0]; // response time minus non-decision time - if (t > 0) { // sort response and calculate density + t = rt_vec[i % Nrt] + - t0_vec[i % Nt0]; // response time minus non-decision time + if (t > 0) { // sort response and calculate density a_i = a_vec[i % Na]; sv_i = sv_vec[i % Nsv]; - if (out[i] == 1) { // response is "lower" so use unchanged parameters + if (out[i] == 1) { // response is "lower" so use unchanged parameters v_i = v_vec[i % Nv]; w_i = w_vec[i % Nw]; - } else { // response is "upper" so use alternate parameters + } else { // response is "upper" so use alternate parameters v_i = -v_vec[i % Nv]; w_i = 1 - w_vec[i % Nw]; } - - if (t > 32) { // approximation for t = +Infinity + + if (t > 32) { // approximation for t = +Infinity t = 32; } - + // Calculate sum multiplier - double mult = (sv_i*sv_i * a_i*a_i * w_i*w_i - - 2 * v_i * a_i * w_i - v_i*v_i * t) / - (2 + 2 * sv_i*sv_i * t); - + double mult = (sv_i * sv_i * a_i * a_i * w_i * w_i - 2 * v_i * a_i * w_i + - v_i * v_i * t) + / (2 + 2 * sv_i * sv_i * t); + // Scale error so it is valid inside the sum (not logged) double exp_err = ERR_TOL * exp(-mult); - + // Calculate sum double sum = 0; - double gamma = v_i - sv_i*sv_i * a_i * w_i; - double lambda = 1 + sv_i*sv_i * t; + double gamma = v_i - sv_i * sv_i * a_i * w_i; + double lambda = 1 + sv_i * sv_i * t; double rho = sqrt(t * lambda); - + int j = 0; double rj = a_i * j + a_i * w_i; double m1 = (lambda * rj - gamma * t) / rho; double m2 = (lambda * rj + gamma * t) / rho; double mills_1, mills_2; if (m1 < 6.5) { - mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1*m1) * - (1 + erf(SQRT_2_INV_NEG * m1)); + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1 * m1) + * (1 + erf(SQRT_2_INV_NEG * m1)); } else { - double m1sq = m1*m1; - mills_1 = ( 1 - - 1 / (m1sq + 2) + - 1 / ( (m1sq + 2) * (m1sq + 4) ) - - 5 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) ) + - 9 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) ) - - 129 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) * - (m1sq + 10) ) - ) / m1; + double m1sq = m1 * m1; + mills_1 = (1 - 1 / (m1sq + 2) + 1 / ((m1sq + 2) * (m1sq + 4)) + - 5 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6)) + + 9 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8)) + - 129 + / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) + * (m1sq + 10))) + / m1; } if (m2 < 6.5) { - mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2*m2) * - (1 + erf(SQRT_2_INV_NEG * m2)); + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2 * m2) + * (1 + erf(SQRT_2_INV_NEG * m2)); } else { - double m2sq = m2*m2; - mills_2 = ( 1 - - 1 / (m2sq + 2) + - 1 / ( (m2sq + 2) * (m2sq + 4) ) - - 5 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) ) + - 9 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) ) - - 129 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) * - (m2sq + 10) ) - ) / m2; + double m2sq = m2 * m2; + mills_2 = (1 - 1 / (m2sq + 2) + 1 / ((m2sq + 2) * (m2sq + 4)) + - 5 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6)) + + 9 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8)) + - 129 + / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) + * (m2sq + 10))) + / m2; } - double term = SQRT_2PI_INV * exp(-0.5 * rj*rj / t) * (mills_1 + mills_2); + double term + = SQRT_2PI_INV * exp(-0.5 * rj * rj / t) * (mills_1 + mills_2); sum += term; - + while (term > exp_err) { if (j > 1000) { // maybe include a warning here? @@ -296,89 +296,83 @@ return_type_t ddm_lcdf( m1 = (lambda * rj - gamma * t) / rho; m2 = (lambda * rj + gamma * t) / rho; if (m1 < 6.5) { - mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1*m1) * - (1 + erf(SQRT_2_INV_NEG * m1)); + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1 * m1) + * (1 + erf(SQRT_2_INV_NEG * m1)); } else { - double m1sq = m1*m1; - mills_1 = ( 1 - - 1 / (m1sq + 2) + - 1 / ( (m1sq + 2) * (m1sq + 4) ) - - 5 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) ) + - 9 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * - (m1sq + 8) ) - - 129 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * - (m1sq + 8) * (m1sq + 10) ) - ) / m1; + double m1sq = m1 * m1; + mills_1 = (1 - 1 / (m1sq + 2) + 1 / ((m1sq + 2) * (m1sq + 4)) + - 5 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6)) + + 9 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8)) + - 129 + / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) + * (m1sq + 10))) + / m1; } if (m2 < 6.5) { - mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2*m2) * - (1 + erf(SQRT_2_INV_NEG * m2)); + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2 * m2) + * (1 + erf(SQRT_2_INV_NEG * m2)); } else { - double m2sq = m2*m2; - mills_2 = ( 1 - - 1 / (m2sq + 2) + - 1 / ( (m2sq + 2) * (m2sq + 4) ) - - 5 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) ) + - 9 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * - (m2sq + 8) ) - - 129 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * - (m2sq + 8) * (m2sq + 10) ) - ) / m2; + double m2sq = m2 * m2; + mills_2 = (1 - 1 / (m2sq + 2) + 1 / ((m2sq + 2) * (m2sq + 4)) + - 5 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6)) + + 9 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8)) + - 129 + / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) + * (m2sq + 10))) + / m2; } - term = SQRT_2PI_INV * exp(-0.5 * rj*rj / t) * (mills_1 + mills_2); + term = SQRT_2PI_INV * exp(-0.5 * rj * rj / t) * (mills_1 + mills_2); sum -= term; - - if (term <= exp_err) break; - + + if (term <= exp_err) + break; + j++; rj = a_i * j + a_i * w_i; m1 = (lambda * rj - gamma * t) / rho; m2 = (lambda * rj + gamma * t) / rho; if (m1 < 6.5) { - mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1*m1) * - (1 + erf(SQRT_2_INV_NEG * m1)); + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1 * m1) + * (1 + erf(SQRT_2_INV_NEG * m1)); } else { - double m1sq = m1*m1; - mills_1 = ( 1 - - 1 / (m1sq + 2) + - 1 / ( (m1sq + 2) * (m1sq + 4) ) - - 5 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) ) + - 9 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * - (m1sq + 8) ) - - 129 / ( (m1sq + 2) * (m1sq + 4) * (m1sq + 6) * - (m1sq + 8) * (m1sq + 10) ) - ) / m1; + double m1sq = m1 * m1; + mills_1 = (1 - 1 / (m1sq + 2) + 1 / ((m1sq + 2) * (m1sq + 4)) + - 5 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6)) + + 9 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8)) + - 129 + / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) + * (m1sq + 10))) + / m1; } if (m2 < 6.5) { - mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2*m2) * - (1 + erf(SQRT_2_INV_NEG * m2)); + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2 * m2) + * (1 + erf(SQRT_2_INV_NEG * m2)); } else { - double m2sq = m2*m2; - mills_2 = ( 1 - - 1 / (m2sq + 2) + - 1 / ( (m2sq + 2) * (m2sq + 4) ) - - 5 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) ) + - 9 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * - (m2sq + 8) ) - - 129 / ( (m2sq + 2) * (m2sq + 4) * (m2sq + 6) * - (m2sq + 8) * (m2sq + 10) ) - ) / m2; + double m2sq = m2 * m2; + mills_2 = (1 - 1 / (m2sq + 2) + 1 / ((m2sq + 2) * (m2sq + 4)) + - 5 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6)) + + 9 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8)) + - 129 + / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) + * (m2sq + 10))) + / m2; } - term = SQRT_2PI_INV * exp(-0.5 * rj*rj / t) * (mills_1 + mills_2); + term = SQRT_2PI_INV * exp(-0.5 * rj * rj / t) * (mills_1 + mills_2); sum += term; } - + // Add sum and multiplier to lp - if (sum >= 0) { // if result is negative, treat as 0 and don't add to lp + if (sum >= 0) { // if result is negative, treat as 0 and don't add to lp lp += mult + log(sum); } - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE if (isnan(t)) { throw_domain_error(function, "rt (response time)", rt_vec[i % Nrt], - "is a NaN and = ", ", but this value is invalid"); + "is a NaN and = ", ", but this value is invalid"); } else { - throw_domain_error(function, "rt (response time)", t0_vec[i % Nt0], - "is not greater than t0 = ", ", but it must be that rt - t0 > 0"); + throw_domain_error( + function, "rt (response time)", t0_vec[i % Nt0], + "is not greater than t0 = ", ", but it must be that rt - t0 > 0"); } } } @@ -386,14 +380,14 @@ return_type_t ddm_lcdf( return lp; } -template +template inline return_type_t ddm_lcdf( const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv) { return ddm_lcdf(rt, response, a, v, t0, w, sv); } -} -} // Close namespace +} // namespace math +} // namespace stan #endif diff --git a/stan/math/prim/prob/ddm_lpdf.hpp b/stan/math/prim/prob/ddm_lpdf.hpp index b3d78c46b7a..93b5f22d49d 100644 --- a/stan/math/prim/prob/ddm_lpdf.hpp +++ b/stan/math/prim/prob/ddm_lpdf.hpp @@ -10,7 +10,6 @@ #include #include - // Open the Namespace namespace stan { namespace math { @@ -47,21 +46,21 @@ template ddm_lpdf( const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv) { - using T_return_type = return_type_t; - using std::vector; - using std::log; - using std::exp; - using std::sqrt; - using std::max; - using std::ceil; - using std::isfinite; - using std::isnan; + using T_return_type + = return_type_t; using stan::ref_type_t; using stan::scalar_seq_view; - using stan::math::throw_domain_error; - using stan::math::invalid_argument; using stan::math::include_summand; + using stan::math::invalid_argument; + using stan::math::throw_domain_error; + using std::ceil; + using std::exp; + using std::isfinite; + using std::isnan; + using std::log; + using std::max; + using std::sqrt; + using std::vector; using T_rt_ref = ref_type_t; using T_response_ref = ref_type_t; using T_a_ref = ref_type_t; @@ -72,9 +71,9 @@ return_type_t ddm_lpdf( // Constants static const char* function = "ddm_lpdf"; - static const int max_terms_large = 1; // heuristic for switching mechanism - static const double ERR_TOL = 0.000001; // error tolerance for PDF approx - static const double SV_THRESH = 0; // threshold for using variable drift rate + static const int max_terms_large = 1; // heuristic for switching mechanism + static const double ERR_TOL = 0.000001; // error tolerance for PDF approx + static const double SV_THRESH = 0; // threshold for using variable drift rate static const double LOG_PI = log(M_PI); static const double LOG_2PI_2 = 0.5 * log(2 * M_PI); @@ -95,119 +94,119 @@ return_type_t ddm_lpdf( scalar_seq_view sv_vec(sv_ref); // Parameter Checks - size_t Nrt = rt_vec.size(); + size_t Nrt = rt_vec.size(); size_t Nres = response_vec.size(); - size_t Na = a_vec.size(); - size_t Nv = v_vec.size(); - size_t Nt0 = t0_vec.size(); - size_t Nw = w_vec.size(); - size_t Nsv = sv_vec.size(); + size_t Na = a_vec.size(); + size_t Nv = v_vec.size(); + size_t Nt0 = t0_vec.size(); + size_t Nw = w_vec.size(); + size_t Nsv = sv_vec.size(); size_t Nmax = max({Nrt, Nres, Na, Nv, Nt0, Nw, Nsv}); - vector out(Nmax); // initialize output-checking vector + vector out(Nmax); // initialize output-checking vector - if (Nrt < 1) { // rt, invalid inputs will be handled in calculation of the pdf + if (Nrt + < 1) { // rt, invalid inputs will be handled in calculation of the pdf return 0; } - if (Nres < 1) { // response + if (Nres < 1) { // response return 0; } else { for (size_t i = 0; i < Nres; i++) { - if (response_vec[i] == 1) { // lower + if (response_vec[i] == 1) { // lower for (size_t j = i; j < Nmax; j += Nres) { out[j] = 1; } - } else if (response_vec[i] == 2) { // upper + } else if (response_vec[i] == 2) { // upper for (size_t j = i; j < Nmax; j += Nres) { out[j] = 2; } - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE throw_domain_error(function, "response", response_vec[i], " = ", ", but it must be either 1 (lower) or 2 (upper)"); } } } - if (Na < 1) { // a + if (Na < 1) { // a return 0; } else { for (size_t i = 0; i < Na; i++) { if (a_vec[i] > 0) { if (isfinite(a_vec[i])) { continue; - } else { // a = Inf implies PDF = log(0) + } else { // a = Inf implies PDF = log(0) throw_domain_error(function, "a (threshold separation)", a_vec[i], " = ", ", but it must be positive and finite"); } - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE throw_domain_error(function, "a (threshold separation)", a_vec[i], " = ", ", but it must be positive and finite"); } } } - if (Nv < 1) { // v + if (Nv < 1) { // v return 0; } else { for (size_t i = 0; i < Nv; i++) { if (isfinite(v_vec[i])) { continue; - } else { // NaN, NA, Inf, -Inf are not finite + } else { // NaN, NA, Inf, -Inf are not finite throw_domain_error(function, "v (drift rate)", v_vec[i], " = ", ", but it must be finite"); } } } - if (Nt0 < 1) { // t0 + if (Nt0 < 1) { // t0 return 0; } else { for (size_t i = 0; i < Nt0; i++) { if (t0_vec[i] >= 0) { - if (isfinite(t0_vec[i])) { // this could also be handled in calculate_pdf() + if (isfinite( + t0_vec[i])) { // this could also be handled in calculate_pdf() continue; - } else { // t0 = Inf implies rt - t0 < 0 implies PDF = log(0) + } else { // t0 = Inf implies rt - t0 < 0 implies PDF = log(0) throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], " = ", ", but it must be positive and finite"); } - } else { // {NaN, NA} evaluate to FALSE - throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], - " = ", ", but it must be positive and finite"); + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], " = ", + ", but it must be positive and finite"); } } } - if (Nw < 1) { // w + if (Nw < 1) { // w return 0; } else { for (size_t i = 0; i < Nw; i++) { if (w_vec[i] > 0 && w_vec[i] < 1) { continue; - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE throw_domain_error(function, "w (relative a priori bias)", w_vec[i], " = ", ", but it must be that 0 < w < 1"); } } } - if (Nsv < 1) { // sv + if (Nsv < 1) { // sv return 0; } else { for (size_t i = 0; i < Nsv; i++) { if (sv_vec[i] >= 0) { if (isfinite(sv_vec[i])) { continue; - } else { // sv = Inf implies PDF = log(0) - throw_domain_error(function, - "sv (standard deviation of drift rate across trials)", - sv_vec[i], " = ", - ", but it must be positive and finite"); + } else { // sv = Inf implies PDF = log(0) + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be positive and finite"); } - } else { // {NaN, NA} evaluate to FALSE - throw_domain_error(function, - "sv (standard deviation of drift rate across trials)", - sv_vec[i], " = ", - ", but it must be positive and finite"); + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be positive and finite"); } } } @@ -217,21 +216,20 @@ return_type_t ddm_lpdf( return 0; } - // Calculate log(PDF) T_return_type lp(0.0); double t, a_i, v_i, w_i, sv_i; for (size_t i = 0; i < Nmax; i++) { - // Check Parameter Values - t = rt_vec[i % Nrt] - t0_vec[i % Nt0]; // response time minus non-decision time - if (t > 0 && std::isfinite(t)) { // sort response and calculate density + t = rt_vec[i % Nrt] + - t0_vec[i % Nt0]; // response time minus non-decision time + if (t > 0 && std::isfinite(t)) { // sort response and calculate density a_i = a_vec[i % Na]; sv_i = sv_vec[i % Nsv]; - if (out[i] == 1) { // response is "lower" so use unchanged parameters + if (out[i] == 1) { // response is "lower" so use unchanged parameters v_i = v_vec[i % Nv]; w_i = w_vec[i % Nw]; - } else { // response is "upper" so use alternate parameters + } else { // response is "upper" so use alternate parameters v_i = -v_vec[i % Nv]; w_i = 1 - w_vec[i % Nw]; } @@ -239,125 +237,131 @@ return_type_t ddm_lpdf( // Approximate log(PDF) double mult; // Check large time - if (sv_i <= SV_THRESH) { // no sv - mult = - v_i * a_i * w_i - v_i*v_i * t / 2 - 2 * log(a_i); - } else { // sv - mult = (sv_i*sv_i * a_i*a_i * w_i*w_i - 2 * v_i * a_i * w_i - - v_i*v_i * t) / (2 + 2 * sv_i*sv_i * t) - - 0.5 * log(1 + sv_i*sv_i * t) - 2 * log(a_i); + if (sv_i <= SV_THRESH) { // no sv + mult = -v_i * a_i * w_i - v_i * v_i * t / 2 - 2 * log(a_i); + } else { // sv + mult = (sv_i * sv_i * a_i * a_i * w_i * w_i - 2 * v_i * a_i * w_i + - v_i * v_i * t) + / (2 + 2 * sv_i * sv_i * t) + - 0.5 * log(1 + sv_i * sv_i * t) - 2 * log(a_i); } int kl; double exp_err = ERR_TOL * exp(-mult); - double taa = t / (a_i*a_i); - double bc = 1 / (M_PI * sqrt(taa)); // boundary conditions - if (bc > INT_MAX) return INT_MAX; - if (exp_err * M_PI * taa < 1) { // error threshold is low enough - double kl_tmp = sqrt(-2 * log(M_PI * taa * exp_err) - / (M_PI*M_PI * taa)); + double taa = t / (a_i * a_i); + double bc = 1 / (M_PI * sqrt(taa)); // boundary conditions + if (bc > INT_MAX) + return INT_MAX; + if (exp_err * M_PI * taa < 1) { // error threshold is low enough + double kl_tmp + = sqrt(-2 * log(M_PI * taa * exp_err) / (M_PI * M_PI * taa)); if (kl_tmp > INT_MAX) { kl = INT_MAX; } else { - kl = ceil(max(kl_tmp, bc)); // ensure boundary conditions are met + kl = ceil(max(kl_tmp, bc)); // ensure boundary conditions are met } } else { - kl = ceil(bc); // else set to boundary condition + kl = ceil(bc); // else set to boundary condition } // Compare kl (large time) to max_terms_large (small time) - if (kl <= max_terms_large) { // use large time - double gamma = -0.5 * M_PI*M_PI * taa; + if (kl <= max_terms_large) { // use large time + double gamma = -0.5 * M_PI * M_PI * taa; double sum = 0.0; for (size_t j = 1; j <= kl; j++) { - sum += j * sin(j * w_i * M_PI) * exp(gamma * j*j); + sum += j * sin(j * w_i * M_PI) * exp(gamma * j * j); } - if (sum >= 0) { // if result is negative, don't add to lp + if (sum >= 0) { // if result is negative, don't add to lp lp += LOG_PI + mult + log(sum); } - } else { // use small time - if (sv_i <= SV_THRESH) { // no sv - mult = log(a_i) - LOG_2PI_2 - 1.5 * log(t) - - v_i * a_i * w_i - v_i*v_i * t / 2; - } else { // sv + } else { // use small time + if (sv_i <= SV_THRESH) { // no sv + mult = log(a_i) - LOG_2PI_2 - 1.5 * log(t) - v_i * a_i * w_i + - v_i * v_i * t / 2; + } else { // sv mult = log(a_i) - 1.5 * log(t) - LOG_2PI_2 - - 0.5 * log(1 + sv_i*sv_i * t) - + (sv_i*sv_i * a_i*a_i * w_i*w_i - 2 * v_i * a_i * w_i - - v_i*v_i * t) / (2 + 2 * sv_i*sv_i * t); + - 0.5 * log(1 + sv_i * sv_i * t) + + (sv_i * sv_i * a_i * a_i * w_i * w_i - 2 * v_i * a_i * w_i + - v_i * v_i * t) + / (2 + 2 * sv_i * sv_i * t); } exp_err = ERR_TOL / exp(mult); - size_t minterms = sqrt(t)/a_i - w_i; // min number of terms, truncates toward 0 + size_t minterms + = sqrt(t) / a_i - w_i; // min number of terms, truncates toward 0 double gamma = -1 / (2 * taa); - double sum = w_i * exp(gamma * w_i*w_i); // initialize with j=0 term + double sum = w_i * exp(gamma * w_i * w_i); // initialize with j=0 term double term, rj; size_t j = 0; - if (minterms % 2) { // minterms is odd (and at least 1) + if (minterms % 2) { // minterms is odd (and at least 1) j++; rj = j + 1 - w_i; - term = rj * exp(gamma * rj*rj); + term = rj * exp(gamma * rj * rj); sum -= term; while (j < minterms) { j++; rj = j + w_i; - sum += rj * exp(gamma * rj*rj); + sum += rj * exp(gamma * rj * rj); j++; rj = j + 1 - w_i; - term = rj * exp(gamma * rj*rj); + term = rj * exp(gamma * rj * rj); sum -= term; } j++; - rj = j + w_i; // j is now even - term = rj * exp(gamma * rj*rj); + rj = j + w_i; // j is now even + term = rj * exp(gamma * rj * rj); sum += term; while (term > exp_err) { j++; rj = j + 1 - w_i; - term = rj * exp(gamma * rj*rj); + term = rj * exp(gamma * rj * rj); sum -= term; - if (term <= exp_err) break; + if (term <= exp_err) + break; j++; rj = j + w_i; - term = rj * exp(gamma * rj*rj); + term = rj * exp(gamma * rj * rj); sum += term; } - } else { // minterms is even (and at least 0) - while (j < minterms) { // j is currently 0 + } else { // minterms is even (and at least 0) + while (j < minterms) { // j is currently 0 j++; rj = j + 1 - w_i; - sum -= rj * exp(gamma * rj*rj); + sum -= rj * exp(gamma * rj * rj); j++; rj = j + w_i; - term = rj * exp(gamma * rj*rj); + term = rj * exp(gamma * rj * rj); sum += term; } j++; - rj = j + 1 - w_i; // j is now odd - term = rj * exp(gamma * rj*rj); + rj = j + 1 - w_i; // j is now odd + term = rj * exp(gamma * rj * rj); sum -= term; while (term > exp_err) { j++; rj = j + w_i; - term = rj * exp(gamma * rj*rj); + term = rj * exp(gamma * rj * rj); sum += term; - if (term <= exp_err) break; + if (term <= exp_err) + break; j++; rj = j + 1 - w_i; - term = rj * exp(gamma * rj*rj); + term = rj * exp(gamma * rj * rj); sum -= term; } } - if (sum >= 0) { // if result is negative, don't add to lp + if (sum >= 0) { // if result is negative, don't add to lp lp += mult + log(sum); } } - } else { // {NaN, NA} evaluate to FALSE + } else { // {NaN, NA} evaluate to FALSE if (isnan(t)) { - throw_domain_error(function, "rt (response time)", rt_vec[i % Nrt], - "is a NaN and = ", - ", but rt must be positive and finite"); + throw_domain_error( + function, "rt (response time)", rt_vec[i % Nrt], + "is a NaN and = ", ", but rt must be positive and finite"); } else { - throw_domain_error(function, - "rt (response time)", t0_vec[i % Nt0], - "is not greater than t0 = ", - ", but it must be that rt - t0 is positive and finite"); + throw_domain_error( + function, "rt (response time)", t0_vec[i % Nt0], + "is not greater than t0 = ", + ", but it must be that rt - t0 is positive and finite"); } } } @@ -365,14 +369,14 @@ return_type_t ddm_lpdf( return lp; } -template +template inline return_type_t ddm_lpdf( const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv) { return ddm_lpdf(rt, response, a, v, t0, w, sv); } -} -} // Close namespace +} // namespace math +} // namespace stan #endif diff --git a/test/prob/ddm/ddm_cdf_log_test.hpp b/test/prob/ddm/ddm_cdf_log_test.hpp index c28572beb8b..96985c081d1 100644 --- a/test/prob/ddm/ddm_cdf_log_test.hpp +++ b/test/prob/ddm/ddm_cdf_log_test.hpp @@ -1,18 +1,17 @@ // Arguments: Doubles, Ints, Doubles, Doubles, Doubles, Doubles, Doubles #include -using std::vector; using stan::math::INFTY; +using std::vector; class AgradCdfDdm : public AgradCdfTest { -public: - void valid_values(vector >& parameters, - vector& cdf) { + public: + void valid_values(vector >& parameters, vector& cdf) { vector param(7); - + // each expected log_prob is calculated with the R package `fddm` as follows // fddm::pfddm(rt, response, a, v, t0, w, sv, log = TRUE) - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -22,7 +21,7 @@ class AgradCdfDdm : public AgradCdfTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-0.3189645693469165); // expected log_prob - + param[0] = 1.0; // rt param[1] = 2; // response param[2] = 1.0; // a @@ -32,7 +31,7 @@ class AgradCdfDdm : public AgradCdfTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-1.318964569346917); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 2.0; // a @@ -42,17 +41,17 @@ class AgradCdfDdm : public AgradCdfTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-0.4105356025317656); // expected log_prob - - param[0] = 1.0; // rt - param[1] = 1; // response - param[2] = 1.0; // a - param[3] = 1.0; // v - param[4] = 0.0; // t0 - param[5] = 0.5; // w - param[6] = 0.0; // sv + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = 1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-1.318964569346917); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -62,7 +61,7 @@ class AgradCdfDdm : public AgradCdfTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-0.403297055469638); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -72,7 +71,7 @@ class AgradCdfDdm : public AgradCdfTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-0.08206670817568271); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -83,101 +82,101 @@ class AgradCdfDdm : public AgradCdfTest { parameters.push_back(param); log_prob.push_back(-0.3658772238413681); // expected log_prob } - + void invalid_values(vector& index, vector& value) { // rt index.push_back(0U); value.push_back(0.0); - + index.push_back(0U); value.push_back(-1.0); - + index.push_back(0U); value.push_back(INFTY); - + index.push_back(0U); value.push_back(-INFTY); - + // response index.push_back(1U); value.push_back(0); - + index.push_back(1U); value.push_back(3); - + index.push_back(1U); value.push_back(-1); - + index.push_back(1U); value.push_back(INFTY); - + index.push_back(1U); value.push_back(-INFTY); - + // a index.push_back(2U); value.push_back(0.0); - + index.push_back(2U); value.push_back(-1.0); - + index.push_back(2U); value.push_back(INFTY); - + index.push_back(2U); value.push_back(-INFTY); - + // v index.push_back(3U); value.push_back(INFTY); - + index.push_back(3U); value.push_back(-INFTY); - + // t0 index.push_back(4U); value.push_back(-1); - + index.push_back(4U); value.push_back(INFTY); - + index.push_back(4U); value.push_back(-INFTY); - + // w index.push_back(5U); value.push_back(-0.1); - + index.push_back(5U); value.push_back(0.0); - + index.push_back(5U); value.push_back(1.0); - + index.push_back(5U); value.push_back(1.1); - + index.push_back(5U); value.push_back(INFTY); - + index.push_back(5U); value.push_back(-INFTY); - + // sv index.push_back(6U); value.push_back(-1.0); - + index.push_back(6U); value.push_back(INFTY); - + index.push_back(6U); value.push_back(-INFTY); } - + bool has_upper_bound() { return true; } - + double upper_bound() { return 0.0; } - + template stan::return_type_t cdf( @@ -185,7 +184,7 @@ class AgradCdfDdm : public AgradCdfTest { const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { return stan::math::ddm_lcdf(rt, response, a, v, t0, w, sv); } - + template stan::return_type_t diff --git a/test/prob/ddm/ddm_test.hpp b/test/prob/ddm/ddm_test.hpp index 17d7cb7d239..8e26229ce3c 100644 --- a/test/prob/ddm/ddm_test.hpp +++ b/test/prob/ddm/ddm_test.hpp @@ -1,18 +1,18 @@ // Arguments: Doubles, Ints, Doubles, Doubles, Doubles, Doubles, Doubles #include -using std::vector; using stan::math::INFTY; +using std::vector; class AgradDistributionDdm : public AgradDistributionTest { -public: + public: void valid_values(vector >& parameters, vector& log_prob) { vector param(7); - + // each expected log_prob is calculated with the R package `fddm` as follows // fddm::dfddm(rt, response, a, v, t0, w, sv, log = TRUE) - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -22,7 +22,7 @@ class AgradDistributionDdm : public AgradDistributionTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-3.790072391414288); // expected log_prob - + param[0] = 1.0; // rt param[1] = 2; // response param[2] = 1.0; // a @@ -32,7 +32,7 @@ class AgradDistributionDdm : public AgradDistributionTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-4.790072391414288); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 2.0; // a @@ -42,17 +42,17 @@ class AgradDistributionDdm : public AgradDistributionTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-0.9754202070046213); // expected log_prob - - param[0] = 1.0; // rt - param[1] = 1; // response - param[2] = 1.0; // a - param[3] = 1.0; // v - param[4] = 0.0; // t0 - param[5] = 0.5; // w - param[6] = 0.0; // sv + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = 1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-4.790072391414288); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -62,7 +62,7 @@ class AgradDistributionDdm : public AgradDistributionTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-1.072671222447106); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -72,7 +72,7 @@ class AgradDistributionDdm : public AgradDistributionTest { param[6] = 0.0; // sv parameters.push_back(param); log_prob.push_back(-4.621465563321226); // expected log_prob - + param[0] = 1.0; // rt param[1] = 1; // response param[2] = 1.0; // a @@ -83,97 +83,97 @@ class AgradDistributionDdm : public AgradDistributionTest { parameters.push_back(param); log_prob.push_back(-4.07414598169426); // expected log_prob } - + void invalid_values(vector& index, vector& value) { // rt index.push_back(0U); value.push_back(0.0); - + index.push_back(0U); value.push_back(-1.0); - + index.push_back(0U); value.push_back(INFTY); - + index.push_back(0U); value.push_back(-INFTY); - + // response index.push_back(1U); value.push_back(0); - + index.push_back(1U); value.push_back(3); - + index.push_back(1U); value.push_back(-1); - + index.push_back(1U); value.push_back(INFTY); - + index.push_back(1U); value.push_back(-INFTY); - + // a index.push_back(2U); value.push_back(0.0); - + index.push_back(2U); value.push_back(-1.0); - + index.push_back(2U); value.push_back(INFTY); - + index.push_back(2U); value.push_back(-INFTY); - + // v index.push_back(3U); value.push_back(INFTY); - + index.push_back(3U); value.push_back(-INFTY); - + // t0 index.push_back(4U); value.push_back(-1); - + index.push_back(4U); value.push_back(INFTY); - + index.push_back(4U); value.push_back(-INFTY); - + // w index.push_back(5U); value.push_back(-0.1); - + index.push_back(5U); value.push_back(0.0); - + index.push_back(5U); value.push_back(1.0); - + index.push_back(5U); value.push_back(1.1); - + index.push_back(5U); value.push_back(INFTY); - + index.push_back(5U); value.push_back(-INFTY); - + // sv index.push_back(6U); value.push_back(-1.0); - + index.push_back(6U); value.push_back(INFTY); - + index.push_back(6U); value.push_back(-INFTY); } - + template stan::return_type_t log_prob( @@ -181,7 +181,7 @@ class AgradDistributionDdm : public AgradDistributionTest { const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); } - + template @@ -190,7 +190,7 @@ class AgradDistributionDdm : public AgradDistributionTest { const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); } - + template stan::return_type_t diff --git a/test/unit/math/prim/prob/ddm_cdf_log_test.cpp b/test/unit/math/prim/prob/ddm_cdf_log_test.cpp index 863476f8ac2..f00522522f7 100644 --- a/test/unit/math/prim/prob/ddm_cdf_log_test.cpp +++ b/test/unit/math/prim/prob/ddm_cdf_log_test.cpp @@ -2,12 +2,11 @@ #include #include - TEST(ProbDdm, ddm_lcdf_matches_known_cdf) { - using std::vector; - using std::exp; using stan::math::ddm_lcdf; - + using std::exp; + using std::vector; + // Note: // 1. Each expected log_prob is calculated with the R package `fddm` using // fddm::pfddm(rt, response, a, v, t0, w, sv, log = TRUE) @@ -16,17 +15,17 @@ TEST(ProbDdm, ddm_lcdf_matches_known_cdf) { // exponentiating this very negative value would result in rounding to // zero. This value of error tolerance is slightly arbitrary, but it is // more useful than comparing zero to zero due to rounding issues. - + static const double pfddm_output = -972.3893812; static const double err_tol = 1.0; static const vector rt{0.1, 1, 10.0}; - static const int response = 1; // "lower" threshold + static const int response = 1; // "lower" threshold static const vector a{0.5, 1.0, 5.0}; static const vector v{-2.0, 0.0, 2.0}; static const double t0 = 0.0001; static const vector w{0.2, 0.5, 0.8}; static const vector sv{0.0, 0.5, 1.0, 1.5}; - + static const int n_rt = rt.size(); static const int n_a = a.size(); static const int n_v = v.size(); @@ -39,7 +38,7 @@ TEST(ProbDdm, ddm_lcdf_matches_known_cdf) { vector t0_vec(n_max, t0); vector w_vec(n_max); vector sv_vec(n_max); - + for (int i = 0; i < n_rt; i++) { for (int j = i; j < n_max; j += n_rt) { rt_vec[j] = rt[i]; @@ -65,33 +64,27 @@ TEST(ProbDdm, ddm_lcdf_matches_known_cdf) { sv_vec[j] = sv[i]; } } - - - EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), - pfddm_output, + + EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), pfddm_output, err_tol); EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), - 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 + 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 err_tol); EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), - pfddm_output, - err_tol); - EXPECT_NEAR( - (ddm_lcdf, int, vector, vector, - vector, vector, vector >( - rt, response, a, v, t0_vec, w, sv_vec)), - pfddm_output, - err_tol); + pfddm_output, err_tol); + EXPECT_NEAR((ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, err_tol); EXPECT_NEAR( - (ddm_lcdf, int, vector, vector, - vector, vector, vector >( - rt, response, a, v, t0_vec, w, sv_vec)), - 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 - err_tol); + (ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 + err_tol); EXPECT_NEAR( - (ddm_lcdf, int, vector, vector, - vector, vector, vector >( - rt, response, a, v, t0_vec, w, sv_vec)), - pfddm_output, - err_tol); + (ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, err_tol); } diff --git a/test/unit/math/prim/prob/ddm_test.cpp b/test/unit/math/prim/prob/ddm_test.cpp index e1790c4f152..875522ffe14 100644 --- a/test/unit/math/prim/prob/ddm_test.cpp +++ b/test/unit/math/prim/prob/ddm_test.cpp @@ -18,7 +18,6 @@ // delta -> v // Note: `response` and `sv` are not included in wiener_lpdf() - // Check invalid arguments // rt @@ -27,10 +26,8 @@ TEST(mathPrimScalProbDdmScal, invalid_rt) { using stan::math::INFTY; EXPECT_THROW(ddm_lpdf(0, 2, 1, -1, 0, 0.5, 0), std::domain_error); EXPECT_THROW(ddm_lpdf(-1, 2, 1, -1, 0, 0.5, 0), std::domain_error); - EXPECT_THROW(ddm_lpdf(INFTY, 2, 1, -1, 0, 0.5, 0), - std::domain_error); - EXPECT_THROW(ddm_lpdf(-INFTY, 2, 1, -1, 0, 0.5, 0), - std::domain_error); + EXPECT_THROW(ddm_lpdf(INFTY, 2, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(-INFTY, 2, 1, -1, 0, 0.5, 0), std::domain_error); EXPECT_THROW(ddm_lpdf(NAN, 2, 1, -1, 0, 0.5, 0), std::domain_error); } TEST(mathPrimScalProbDdmMat, invalid_rt) { @@ -196,10 +193,10 @@ TEST(mathPrimScalProbDdmMat, invalid_sv) { } TEST(ProbDdm, ddm_lpdf_matches_wiener_lpdf) { - using std::vector; - using std::exp; using stan::math::ddm_lpdf; using stan::math::wiener_lpdf; + using std::exp; + using std::vector; // Notes: // 1. define error tolerance for PDF approximations, use double tolerance to // allow for convergence (of the truncated infinite sum) from above and below @@ -208,13 +205,14 @@ TEST(ProbDdm, ddm_lpdf_matches_wiener_lpdf) { // parameter values defined below double err_tol = 2 * 0.000001; vector rt{0.1, 1, 10.0}; - int response = 2; // wiener_lpdf() always uses the "upper" threshold + int response = 2; // wiener_lpdf() always uses the "upper" threshold vector a{0.5, 1.0, 5.0}; vector v{-2.0, 0.0, 2.0}; - double t0 = 0.0001; // t0 (i.e., tau) needs to be > 0 for wiener_lpdf() + double t0 = 0.0001; // t0 (i.e., tau) needs to be > 0 for wiener_lpdf() vector w{0.2, 0.5, 0.8}; - double sv = 0.0; // sv is not included in wiener_lpdf(), and thus it must be 0 - + double sv + = 0.0; // sv is not included in wiener_lpdf(), and thus it must be 0 + int n_rt = rt.size(); int n_a = a.size(); int n_v = v.size(); @@ -225,7 +223,7 @@ TEST(ProbDdm, ddm_lpdf_matches_wiener_lpdf) { vector v_vec(n_max); vector t0_vec(n_max, t0); vector w_vec(n_max); - + for (int i = 0; i < n_rt; i++) { for (int j = i; j < n_max; j += n_rt) { rt_vec[j] = rt[i]; @@ -246,44 +244,42 @@ TEST(ProbDdm, ddm_lpdf_matches_wiener_lpdf) { w_vec[j] = w[i]; } } - + // The PDF approximation error tolerance is based on the non-log version of // the PDF. To account for this, we compare the exponentiated version of the // *_lpdf results. EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), - exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), - err_tol); + exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), err_tol); EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), err_tol); EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), err_tol); + EXPECT_NEAR(exp(ddm_lpdf, int, vector, vector, + vector, vector, double>( + rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >( + rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); EXPECT_NEAR( - exp(ddm_lpdf, int, vector, vector, - vector, vector, double>( - rt, response, a, v, t0_vec, w, sv)), - exp(wiener_lpdf, vector, vector, - vector, vector >( - rt_vec, a_vec, t0_vec, w_vec, v_vec)), - err_tol); - EXPECT_NEAR( - exp(ddm_lpdf, int, vector, vector, - vector, vector, double>( - rt, response, a, v, t0_vec, w, sv)), - exp(wiener_lpdf, vector, vector, - vector, vector >( - rt_vec, a_vec, t0_vec, w_vec, v_vec)), - err_tol); + exp(ddm_lpdf, int, vector, vector, + vector, vector, double>(rt, response, a, v, + t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >(rt_vec, a_vec, t0_vec, + w_vec, v_vec)), + err_tol); EXPECT_NEAR( - exp(ddm_lpdf, int, vector, vector, - vector, vector, double>( - rt, response, a, v, t0_vec, w, sv)), - exp(wiener_lpdf, vector, vector, - vector, vector >( - rt_vec, a_vec, t0_vec, w_vec, v_vec)), - err_tol); - + exp(ddm_lpdf, int, vector, vector, + vector, vector, double>(rt, response, a, v, + t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >(rt_vec, a_vec, t0_vec, + w_vec, v_vec)), + err_tol); + // check with variable drift rate (against results from R package `fddm`) // Notes: // 1. We must redefine the error tolerance because we must check on the log @@ -302,32 +298,27 @@ TEST(ProbDdm, ddm_lpdf_matches_wiener_lpdf) { sv_vec[j] = sv_vals[i]; } } - - EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), - dfddm_output, + + EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), dfddm_output, err_tol); EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), - 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 + 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 err_tol); EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), - dfddm_output, - err_tol); - EXPECT_NEAR( - (ddm_lpdf, int, vector, vector, - vector, vector, vector >( - rt, response, a, v, t0_vec, w, sv_vec)), - dfddm_output, - err_tol); + dfddm_output, err_tol); + EXPECT_NEAR((ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, err_tol); EXPECT_NEAR( - (ddm_lpdf, int, vector, vector, - vector, vector, vector >( - rt, response, a, v, t0_vec, w, sv_vec)), - 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 - err_tol); + (ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 + err_tol); EXPECT_NEAR( - (ddm_lpdf, int, vector, vector, - vector, vector, vector >( - rt, response, a, v, t0_vec, w, sv_vec)), - dfddm_output, - err_tol); + (ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, err_tol); }