diff --git a/stan/math/prim/prob.hpp b/stan/math/prim/prob.hpp index 2fce0b34e1d..5b302630f32 100644 --- a/stan/math/prim/prob.hpp +++ b/stan/math/prim/prob.hpp @@ -71,6 +71,8 @@ #include #include #include +#include +#include #include #include #include diff --git a/stan/math/prim/prob/ddm_lcdf.hpp b/stan/math/prim/prob/ddm_lcdf.hpp new file mode 100644 index 00000000000..66ffee27ff0 --- /dev/null +++ b/stan/math/prim/prob/ddm_lcdf.hpp @@ -0,0 +1,393 @@ +#ifndef STAN_MATH_PRIM_PROB_DDM_LCDF_HPP +#define STAN_MATH_PRIM_PROB_DDM_LCDF_HPP + +#include +#include +#include +#include +#include +#include +#include + +// Open the Namespace +namespace stan { +namespace math { +using stan::return_type_t; + +/** + * The log of the first passage time distribution function for a + * (Ratcliff, 1978) drift diffusion model with intrinsic trial-trial variability + * for the given response time \f$rt\f$, response \f$response\f$, boundary + * separation \f$a\f$, mean drift rate across trials \f$v\f$, non-decision + * time \f$t0\f$, relative bias \f$w\f$, standard deviation of drift rate + * across trials \f$sv\f$. + * + * @tparam T_rt type of parameter `rt` + * @tparam T_response type of parameter `response` + * @tparam T_a type of parameter `a` + * @tparam T_v type of parameter `v` + * @tparam T_t0 type of parameter `t0` + * @tparam T_w type of parameter `w` + * @tparam T_sv type of parameter `sv` + * + * @param rt The response time; rt >= 0. + * @param response The response; response in {1, 2}. + * @param a The threshold separation; a > 0. + * @param v The mean drift rate across trials. + * @param t0 The non-decision time; t0 >= 0. + * @param w The relative a priori bias; 0 < w < 1. + * @param sv The standard deviation of drift rate across trials; sv >= 0. + * @return The log of the Wiener first passage time density of + * the specified arguments. + */ +template +return_type_t ddm_lcdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + using T_return_type + = return_type_t; + using stan::ref_type_t; + using stan::scalar_seq_view; + using stan::math::include_summand; + using stan::math::invalid_argument; + using stan::math::throw_domain_error; + using std::erf; + using std::exp; + using std::isfinite; + using std::isnan; + using std::log; + using std::max; + using std::sqrt; + using std::vector; + using T_rt_ref = ref_type_t; + using T_response_ref = ref_type_t; + using T_a_ref = ref_type_t; + using T_v_ref = ref_type_t; + using T_t0_ref = ref_type_t; + using T_w_ref = ref_type_t; + using T_sv_ref = ref_type_t; + + // Constants + static const char* function = "ddm_lcdf"; + static const double ERR_TOL = 0.000001; // error tolerance for PDF approx + static const double PI_CONST = 3.14159265358979323846; // define pi like C++ + static const double SQRT_2PI = sqrt(2 * PI_CONST); + static const double SQRT_2PI_INV = 1 / SQRT_2PI; + static const double SQRT_2_INV_NEG = -1 / sqrt(2); + + // Convert Inputs + T_rt_ref rt_ref = rt; + T_response_ref response_ref = response; + T_a_ref a_ref = a; + T_v_ref v_ref = v; + T_t0_ref t0_ref = t0; + T_w_ref w_ref = w; + T_sv_ref sv_ref = sv; + scalar_seq_view rt_vec(rt_ref); + scalar_seq_view response_vec(response_ref); + scalar_seq_view a_vec(a_ref); + scalar_seq_view v_vec(v_ref); + scalar_seq_view t0_vec(t0_ref); + scalar_seq_view w_vec(w_ref); + scalar_seq_view sv_vec(sv_ref); + + // Parameter Checks + size_t Nrt = rt_vec.size(); + size_t Nres = response_vec.size(); + size_t Na = a_vec.size(); + size_t Nv = v_vec.size(); + size_t Nt0 = t0_vec.size(); + size_t Nw = w_vec.size(); + size_t Nsv = sv_vec.size(); + size_t Nmax = max({Nrt, Nres, Na, Nv, Nt0, Nw, Nsv}); + vector out(Nmax); // initialize output-checking vector + + if (Nrt + < 1) { // rt, invalid inputs will be handled in calculation of the CDF + return 0; + } + + if (Nres < 1) { // response + return 0; + } else { + for (size_t i = 0; i < Nres; i++) { + if (response_vec[i] == 1) { // lower + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 1; + } + } else if (response_vec[i] == 2) { // upper + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 2; + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "response", response_vec[i], " = ", + ", but it must be either 1 (lower) or 2 (upper)"); + } + } + } + + if (Na < 1) { // a + return 0; + } else { + for (size_t i = 0; i < Na; i++) { + if (a_vec[i] > 0) { + if (isfinite(a_vec[i])) { + continue; + } else { // a = Inf implies PDF = log(0) and CDF problems + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be positive and finite"); + } + } + } + + if (Nv < 1) { // v + return 0; + } else { + for (size_t i = 0; i < Nv; i++) { + if (isfinite(v_vec[i])) { + continue; + } else { // NaN, NA, Inf, -Inf are not finite + throw_domain_error(function, "v (drift rate)", v_vec[i], " = ", + ", but it must be finite"); + } + } + } + + if (Nt0 < 1) { // t0 + return 0; + } else { + for (size_t i = 0; i < Nt0; i++) { + if (t0_vec[i] >= 0) { + if (isfinite(t0_vec[i])) { // this could also be handled in calculation + // of CDF + continue; + } else { // t0 = Inf implies rt - t0 < 0 implies CDF = log(0) + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], + " = ", ", but it must be finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], " = ", + ", but it must be positive and finite"); + } + } + } + + if (Nw < 1) { // w + return 0; + } else { + for (size_t i = 0; i < Nw; i++) { + if (w_vec[i] > 0 && w_vec[i] < 1) { + continue; + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "w (relative a priori bias)", w_vec[i], + " = ", ", but it must be that 0 < w < 1"); + } + } + } + + if (Nsv < 1) { // sv + return 0; + } else { + for (size_t i = 0; i < Nsv; i++) { + if (sv_vec[i] >= 0) { + if (isfinite(sv_vec[i])) { + continue; + } else { // sv = Inf implies PDF = log(0) and CDF problems + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be positive and finite"); + } + } + } + + if (!include_summand::value) { + return 0; + } + + // Calculate log(CDF) + T_return_type lp(0.0); + double t, a_i, v_i, w_i, sv_i; + for (size_t i = 0; i < Nmax; i++) { + // Check Parameter Values + t = rt_vec[i % Nrt] + - t0_vec[i % Nt0]; // response time minus non-decision time + if (t > 0) { // sort response and calculate density + a_i = a_vec[i % Na]; + sv_i = sv_vec[i % Nsv]; + if (out[i] == 1) { // response is "lower" so use unchanged parameters + v_i = v_vec[i % Nv]; + w_i = w_vec[i % Nw]; + } else { // response is "upper" so use alternate parameters + v_i = -v_vec[i % Nv]; + w_i = 1 - w_vec[i % Nw]; + } + + if (t > 32) { // approximation for t = +Infinity + t = 32; + } + + // Calculate sum multiplier + double mult = (sv_i * sv_i * a_i * a_i * w_i * w_i - 2 * v_i * a_i * w_i + - v_i * v_i * t) + / (2 + 2 * sv_i * sv_i * t); + + // Scale error so it is valid inside the sum (not logged) + double exp_err = ERR_TOL * exp(-mult); + + // Calculate sum + double sum = 0; + double gamma = v_i - sv_i * sv_i * a_i * w_i; + double lambda = 1 + sv_i * sv_i * t; + double rho = sqrt(t * lambda); + + int j = 0; + double rj = a_i * j + a_i * w_i; + double m1 = (lambda * rj - gamma * t) / rho; + double m2 = (lambda * rj + gamma * t) / rho; + double mills_1, mills_2; + if (m1 < 6.5) { + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1 * m1) + * (1 + erf(SQRT_2_INV_NEG * m1)); + } else { + double m1sq = m1 * m1; + mills_1 = (1 - 1 / (m1sq + 2) + 1 / ((m1sq + 2) * (m1sq + 4)) + - 5 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6)) + + 9 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8)) + - 129 + / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) + * (m1sq + 10))) + / m1; + } + if (m2 < 6.5) { + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2 * m2) + * (1 + erf(SQRT_2_INV_NEG * m2)); + } else { + double m2sq = m2 * m2; + mills_2 = (1 - 1 / (m2sq + 2) + 1 / ((m2sq + 2) * (m2sq + 4)) + - 5 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6)) + + 9 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8)) + - 129 + / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) + * (m2sq + 10))) + / m2; + } + double term + = SQRT_2PI_INV * exp(-0.5 * rj * rj / t) * (mills_1 + mills_2); + sum += term; + + while (term > exp_err) { + if (j > 1000) { + // maybe include a warning here? + break; + } + j++; + rj = a_i * j + a_i * (1 - w_i); + m1 = (lambda * rj - gamma * t) / rho; + m2 = (lambda * rj + gamma * t) / rho; + if (m1 < 6.5) { + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1 * m1) + * (1 + erf(SQRT_2_INV_NEG * m1)); + } else { + double m1sq = m1 * m1; + mills_1 = (1 - 1 / (m1sq + 2) + 1 / ((m1sq + 2) * (m1sq + 4)) + - 5 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6)) + + 9 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8)) + - 129 + / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) + * (m1sq + 10))) + / m1; + } + if (m2 < 6.5) { + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2 * m2) + * (1 + erf(SQRT_2_INV_NEG * m2)); + } else { + double m2sq = m2 * m2; + mills_2 = (1 - 1 / (m2sq + 2) + 1 / ((m2sq + 2) * (m2sq + 4)) + - 5 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6)) + + 9 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8)) + - 129 + / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) + * (m2sq + 10))) + / m2; + } + term = SQRT_2PI_INV * exp(-0.5 * rj * rj / t) * (mills_1 + mills_2); + sum -= term; + + if (term <= exp_err) + break; + + j++; + rj = a_i * j + a_i * w_i; + m1 = (lambda * rj - gamma * t) / rho; + m2 = (lambda * rj + gamma * t) / rho; + if (m1 < 6.5) { + mills_1 = SQRT_2PI * 0.5 * exp(0.5 * m1 * m1) + * (1 + erf(SQRT_2_INV_NEG * m1)); + } else { + double m1sq = m1 * m1; + mills_1 = (1 - 1 / (m1sq + 2) + 1 / ((m1sq + 2) * (m1sq + 4)) + - 5 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6)) + + 9 / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8)) + - 129 + / ((m1sq + 2) * (m1sq + 4) * (m1sq + 6) * (m1sq + 8) + * (m1sq + 10))) + / m1; + } + if (m2 < 6.5) { + mills_2 = SQRT_2PI * 0.5 * exp(0.5 * m2 * m2) + * (1 + erf(SQRT_2_INV_NEG * m2)); + } else { + double m2sq = m2 * m2; + mills_2 = (1 - 1 / (m2sq + 2) + 1 / ((m2sq + 2) * (m2sq + 4)) + - 5 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6)) + + 9 / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8)) + - 129 + / ((m2sq + 2) * (m2sq + 4) * (m2sq + 6) * (m2sq + 8) + * (m2sq + 10))) + / m2; + } + term = SQRT_2PI_INV * exp(-0.5 * rj * rj / t) * (mills_1 + mills_2); + sum += term; + } + + // Add sum and multiplier to lp + if (sum >= 0) { // if result is negative, treat as 0 and don't add to lp + lp += mult + log(sum); + } + } else { // {NaN, NA} evaluate to FALSE + if (isnan(t)) { + throw_domain_error(function, "rt (response time)", rt_vec[i % Nrt], + "is a NaN and = ", ", but this value is invalid"); + } else { + throw_domain_error( + function, "rt (response time)", t0_vec[i % Nt0], + "is not greater than t0 = ", ", but it must be that rt - t0 > 0"); + } + } + } + + return lp; +} + +template +inline return_type_t ddm_lcdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + return ddm_lcdf(rt, response, a, v, t0, w, sv); +} + +} // namespace math +} // namespace stan +#endif diff --git a/stan/math/prim/prob/ddm_lpdf.hpp b/stan/math/prim/prob/ddm_lpdf.hpp new file mode 100644 index 00000000000..93b5f22d49d --- /dev/null +++ b/stan/math/prim/prob/ddm_lpdf.hpp @@ -0,0 +1,382 @@ +#ifndef STAN_MATH_PRIM_PROB_DDM_LPDF_HPP +#define STAN_MATH_PRIM_PROB_DDM_LPDF_HPP + +#define _USE_MATH_DEFINES +#include +#include +#include +#include +#include +#include +#include + +// Open the Namespace +namespace stan { +namespace math { +using stan::return_type_t; + +/** + * The log of the first passage time density function for a (Ratcliff, 1978) + * drift diffusion model with intrinsic trial-trial variability + * for the given response time \f$rt\f$, response \f$response\f$, boundary + * separation \f$a\f$, mean drift rate across trials \f$v\f$, non-decision + * time \f$t0\f$, relative bias \f$w\f$, and standard deviation of drift rate + * across trials \f$sv\f$. + * + * @tparam T_rt type of parameter `rt` + * @tparam T_response type of parameter `response` + * @tparam T_a type of parameter `a` + * @tparam T_v type of parameter `v` + * @tparam T_t0 type of parameter `t0` + * @tparam T_w type of parameter `w` + * @tparam T_sv type of parameter `sv` + * + * @param rt The response time; rt >= 0. + * @param response The response; response in {1, 2}. + * @param a The threshold separation; a > 0. + * @param v The mean drift rate across trials. + * @param t0 The non-decision time; t0 >= 0. + * @param w The relative a priori bias; 0 < w < 1. + * @param sv The standard deviation of drift rate across trials; sv >= 0. + * @return The log of the Wiener first passage time density of + * the specified arguments. + */ +template +return_type_t ddm_lpdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + using T_return_type + = return_type_t; + using stan::ref_type_t; + using stan::scalar_seq_view; + using stan::math::include_summand; + using stan::math::invalid_argument; + using stan::math::throw_domain_error; + using std::ceil; + using std::exp; + using std::isfinite; + using std::isnan; + using std::log; + using std::max; + using std::sqrt; + using std::vector; + using T_rt_ref = ref_type_t; + using T_response_ref = ref_type_t; + using T_a_ref = ref_type_t; + using T_v_ref = ref_type_t; + using T_t0_ref = ref_type_t; + using T_w_ref = ref_type_t; + using T_sv_ref = ref_type_t; + + // Constants + static const char* function = "ddm_lpdf"; + static const int max_terms_large = 1; // heuristic for switching mechanism + static const double ERR_TOL = 0.000001; // error tolerance for PDF approx + static const double SV_THRESH = 0; // threshold for using variable drift rate + static const double LOG_PI = log(M_PI); + static const double LOG_2PI_2 = 0.5 * log(2 * M_PI); + + // Convert Inputs + T_rt_ref rt_ref = rt; + T_response_ref response_ref = response; + T_a_ref a_ref = a; + T_v_ref v_ref = v; + T_t0_ref t0_ref = t0; + T_w_ref w_ref = w; + T_sv_ref sv_ref = sv; + scalar_seq_view rt_vec(rt_ref); + scalar_seq_view response_vec(response_ref); + scalar_seq_view a_vec(a_ref); + scalar_seq_view v_vec(v_ref); + scalar_seq_view t0_vec(t0_ref); + scalar_seq_view w_vec(w_ref); + scalar_seq_view sv_vec(sv_ref); + + // Parameter Checks + size_t Nrt = rt_vec.size(); + size_t Nres = response_vec.size(); + size_t Na = a_vec.size(); + size_t Nv = v_vec.size(); + size_t Nt0 = t0_vec.size(); + size_t Nw = w_vec.size(); + size_t Nsv = sv_vec.size(); + size_t Nmax = max({Nrt, Nres, Na, Nv, Nt0, Nw, Nsv}); + vector out(Nmax); // initialize output-checking vector + + if (Nrt + < 1) { // rt, invalid inputs will be handled in calculation of the pdf + return 0; + } + + if (Nres < 1) { // response + return 0; + } else { + for (size_t i = 0; i < Nres; i++) { + if (response_vec[i] == 1) { // lower + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 1; + } + } else if (response_vec[i] == 2) { // upper + for (size_t j = i; j < Nmax; j += Nres) { + out[j] = 2; + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "response", response_vec[i], " = ", + ", but it must be either 1 (lower) or 2 (upper)"); + } + } + } + + if (Na < 1) { // a + return 0; + } else { + for (size_t i = 0; i < Na; i++) { + if (a_vec[i] > 0) { + if (isfinite(a_vec[i])) { + continue; + } else { // a = Inf implies PDF = log(0) + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be positive and finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "a (threshold separation)", a_vec[i], + " = ", ", but it must be positive and finite"); + } + } + } + + if (Nv < 1) { // v + return 0; + } else { + for (size_t i = 0; i < Nv; i++) { + if (isfinite(v_vec[i])) { + continue; + } else { // NaN, NA, Inf, -Inf are not finite + throw_domain_error(function, "v (drift rate)", v_vec[i], " = ", + ", but it must be finite"); + } + } + } + + if (Nt0 < 1) { // t0 + return 0; + } else { + for (size_t i = 0; i < Nt0; i++) { + if (t0_vec[i] >= 0) { + if (isfinite( + t0_vec[i])) { // this could also be handled in calculate_pdf() + continue; + } else { // t0 = Inf implies rt - t0 < 0 implies PDF = log(0) + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], + " = ", ", but it must be positive and finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "t0 (non-decision time)", t0_vec[i], " = ", + ", but it must be positive and finite"); + } + } + } + + if (Nw < 1) { // w + return 0; + } else { + for (size_t i = 0; i < Nw; i++) { + if (w_vec[i] > 0 && w_vec[i] < 1) { + continue; + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error(function, "w (relative a priori bias)", w_vec[i], + " = ", ", but it must be that 0 < w < 1"); + } + } + } + + if (Nsv < 1) { // sv + return 0; + } else { + for (size_t i = 0; i < Nsv; i++) { + if (sv_vec[i] >= 0) { + if (isfinite(sv_vec[i])) { + continue; + } else { // sv = Inf implies PDF = log(0) + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be positive and finite"); + } + } else { // {NaN, NA} evaluate to FALSE + throw_domain_error( + function, "sv (standard deviation of drift rate across trials)", + sv_vec[i], " = ", ", but it must be positive and finite"); + } + } + } + + if (!include_summand::value) { + return 0; + } + + // Calculate log(PDF) + T_return_type lp(0.0); + double t, a_i, v_i, w_i, sv_i; + for (size_t i = 0; i < Nmax; i++) { + // Check Parameter Values + t = rt_vec[i % Nrt] + - t0_vec[i % Nt0]; // response time minus non-decision time + if (t > 0 && std::isfinite(t)) { // sort response and calculate density + a_i = a_vec[i % Na]; + sv_i = sv_vec[i % Nsv]; + if (out[i] == 1) { // response is "lower" so use unchanged parameters + v_i = v_vec[i % Nv]; + w_i = w_vec[i % Nw]; + } else { // response is "upper" so use alternate parameters + v_i = -v_vec[i % Nv]; + w_i = 1 - w_vec[i % Nw]; + } + + // Approximate log(PDF) + double mult; + // Check large time + if (sv_i <= SV_THRESH) { // no sv + mult = -v_i * a_i * w_i - v_i * v_i * t / 2 - 2 * log(a_i); + } else { // sv + mult = (sv_i * sv_i * a_i * a_i * w_i * w_i - 2 * v_i * a_i * w_i + - v_i * v_i * t) + / (2 + 2 * sv_i * sv_i * t) + - 0.5 * log(1 + sv_i * sv_i * t) - 2 * log(a_i); + } + int kl; + double exp_err = ERR_TOL * exp(-mult); + double taa = t / (a_i * a_i); + double bc = 1 / (M_PI * sqrt(taa)); // boundary conditions + if (bc > INT_MAX) + return INT_MAX; + if (exp_err * M_PI * taa < 1) { // error threshold is low enough + double kl_tmp + = sqrt(-2 * log(M_PI * taa * exp_err) / (M_PI * M_PI * taa)); + if (kl_tmp > INT_MAX) { + kl = INT_MAX; + } else { + kl = ceil(max(kl_tmp, bc)); // ensure boundary conditions are met + } + } else { + kl = ceil(bc); // else set to boundary condition + } + + // Compare kl (large time) to max_terms_large (small time) + if (kl <= max_terms_large) { // use large time + double gamma = -0.5 * M_PI * M_PI * taa; + double sum = 0.0; + for (size_t j = 1; j <= kl; j++) { + sum += j * sin(j * w_i * M_PI) * exp(gamma * j * j); + } + if (sum >= 0) { // if result is negative, don't add to lp + lp += LOG_PI + mult + log(sum); + } + } else { // use small time + if (sv_i <= SV_THRESH) { // no sv + mult = log(a_i) - LOG_2PI_2 - 1.5 * log(t) - v_i * a_i * w_i + - v_i * v_i * t / 2; + } else { // sv + mult = log(a_i) - 1.5 * log(t) - LOG_2PI_2 + - 0.5 * log(1 + sv_i * sv_i * t) + + (sv_i * sv_i * a_i * a_i * w_i * w_i - 2 * v_i * a_i * w_i + - v_i * v_i * t) + / (2 + 2 * sv_i * sv_i * t); + } + exp_err = ERR_TOL / exp(mult); + size_t minterms + = sqrt(t) / a_i - w_i; // min number of terms, truncates toward 0 + double gamma = -1 / (2 * taa); + double sum = w_i * exp(gamma * w_i * w_i); // initialize with j=0 term + double term, rj; + size_t j = 0; + if (minterms % 2) { // minterms is odd (and at least 1) + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj * rj); + sum -= term; + while (j < minterms) { + j++; + rj = j + w_i; + sum += rj * exp(gamma * rj * rj); + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj * rj); + sum -= term; + } + j++; + rj = j + w_i; // j is now even + term = rj * exp(gamma * rj * rj); + sum += term; + while (term > exp_err) { + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj * rj); + sum -= term; + if (term <= exp_err) + break; + j++; + rj = j + w_i; + term = rj * exp(gamma * rj * rj); + sum += term; + } + } else { // minterms is even (and at least 0) + while (j < minterms) { // j is currently 0 + j++; + rj = j + 1 - w_i; + sum -= rj * exp(gamma * rj * rj); + j++; + rj = j + w_i; + term = rj * exp(gamma * rj * rj); + sum += term; + } + j++; + rj = j + 1 - w_i; // j is now odd + term = rj * exp(gamma * rj * rj); + sum -= term; + while (term > exp_err) { + j++; + rj = j + w_i; + term = rj * exp(gamma * rj * rj); + sum += term; + if (term <= exp_err) + break; + j++; + rj = j + 1 - w_i; + term = rj * exp(gamma * rj * rj); + sum -= term; + } + } + if (sum >= 0) { // if result is negative, don't add to lp + lp += mult + log(sum); + } + } + } else { // {NaN, NA} evaluate to FALSE + if (isnan(t)) { + throw_domain_error( + function, "rt (response time)", rt_vec[i % Nrt], + "is a NaN and = ", ", but rt must be positive and finite"); + } else { + throw_domain_error( + function, "rt (response time)", t0_vec[i % Nt0], + "is not greater than t0 = ", + ", but it must be that rt - t0 is positive and finite"); + } + } + } + + return lp; +} + +template +inline return_type_t ddm_lpdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv) { + return ddm_lpdf(rt, response, a, v, t0, w, sv); +} + +} // namespace math +} // namespace stan +#endif diff --git a/test/prob/ddm/ddm_cdf_log_test.hpp b/test/prob/ddm/ddm_cdf_log_test.hpp new file mode 100644 index 00000000000..96985c081d1 --- /dev/null +++ b/test/prob/ddm/ddm_cdf_log_test.hpp @@ -0,0 +1,196 @@ +// Arguments: Doubles, Ints, Doubles, Doubles, Doubles, Doubles, Doubles +#include + +using stan::math::INFTY; +using std::vector; + +class AgradCdfDdm : public AgradCdfTest { + public: + void valid_values(vector >& parameters, vector& cdf) { + vector param(7); + + // each expected log_prob is calculated with the R package `fddm` as follows + // fddm::pfddm(rt, response, a, v, t0, w, sv, log = TRUE) + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.3189645693469165); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 2; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-1.318964569346917); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 2.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.4105356025317656); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = 1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-1.318964569346917); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.5; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.403297055469638); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.2; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.08206670817568271); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 1.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.3658772238413681); // expected log_prob + } + + void invalid_values(vector& index, vector& value) { + // rt + index.push_back(0U); + value.push_back(0.0); + + index.push_back(0U); + value.push_back(-1.0); + + index.push_back(0U); + value.push_back(INFTY); + + index.push_back(0U); + value.push_back(-INFTY); + + // response + index.push_back(1U); + value.push_back(0); + + index.push_back(1U); + value.push_back(3); + + index.push_back(1U); + value.push_back(-1); + + index.push_back(1U); + value.push_back(INFTY); + + index.push_back(1U); + value.push_back(-INFTY); + + // a + index.push_back(2U); + value.push_back(0.0); + + index.push_back(2U); + value.push_back(-1.0); + + index.push_back(2U); + value.push_back(INFTY); + + index.push_back(2U); + value.push_back(-INFTY); + + // v + index.push_back(3U); + value.push_back(INFTY); + + index.push_back(3U); + value.push_back(-INFTY); + + // t0 + index.push_back(4U); + value.push_back(-1); + + index.push_back(4U); + value.push_back(INFTY); + + index.push_back(4U); + value.push_back(-INFTY); + + // w + index.push_back(5U); + value.push_back(-0.1); + + index.push_back(5U); + value.push_back(0.0); + + index.push_back(5U); + value.push_back(1.0); + + index.push_back(5U); + value.push_back(1.1); + + index.push_back(5U); + value.push_back(INFTY); + + index.push_back(5U); + value.push_back(-INFTY); + + // sv + index.push_back(6U); + value.push_back(-1.0); + + index.push_back(6U); + value.push_back(INFTY); + + index.push_back(6U); + value.push_back(-INFTY); + } + + bool has_upper_bound() { return true; } + + double upper_bound() { return 0.0; } + + template + stan::return_type_t cdf( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { + return stan::math::ddm_lcdf(rt, response, a, v, t0, w, sv); + } + + template + stan::return_type_t + cdf_function(const T_rt& rt, const T_response& response, const T_a& a, + const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv, + const T7&) { + return stan::math::ddm_lcdf(rt, response, a, v, t0, w, sv); + } +}; diff --git a/test/prob/ddm/ddm_test.hpp b/test/prob/ddm/ddm_test.hpp new file mode 100644 index 00000000000..8e26229ce3c --- /dev/null +++ b/test/prob/ddm/ddm_test.hpp @@ -0,0 +1,202 @@ +// Arguments: Doubles, Ints, Doubles, Doubles, Doubles, Doubles, Doubles +#include + +using stan::math::INFTY; +using std::vector; + +class AgradDistributionDdm : public AgradDistributionTest { + public: + void valid_values(vector >& parameters, + vector& log_prob) { + vector param(7); + + // each expected log_prob is calculated with the R package `fddm` as follows + // fddm::dfddm(rt, response, a, v, t0, w, sv, log = TRUE) + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-3.790072391414288); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 2; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.790072391414288); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 2.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-0.9754202070046213); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = 1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.790072391414288); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.5; // t0 + param[5] = 0.5; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-1.072671222447106); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.2; // w + param[6] = 0.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.621465563321226); // expected log_prob + + param[0] = 1.0; // rt + param[1] = 1; // response + param[2] = 1.0; // a + param[3] = -1.0; // v + param[4] = 0.0; // t0 + param[5] = 0.5; // w + param[6] = 1.0; // sv + parameters.push_back(param); + log_prob.push_back(-4.07414598169426); // expected log_prob + } + + void invalid_values(vector& index, vector& value) { + // rt + index.push_back(0U); + value.push_back(0.0); + + index.push_back(0U); + value.push_back(-1.0); + + index.push_back(0U); + value.push_back(INFTY); + + index.push_back(0U); + value.push_back(-INFTY); + + // response + index.push_back(1U); + value.push_back(0); + + index.push_back(1U); + value.push_back(3); + + index.push_back(1U); + value.push_back(-1); + + index.push_back(1U); + value.push_back(INFTY); + + index.push_back(1U); + value.push_back(-INFTY); + + // a + index.push_back(2U); + value.push_back(0.0); + + index.push_back(2U); + value.push_back(-1.0); + + index.push_back(2U); + value.push_back(INFTY); + + index.push_back(2U); + value.push_back(-INFTY); + + // v + index.push_back(3U); + value.push_back(INFTY); + + index.push_back(3U); + value.push_back(-INFTY); + + // t0 + index.push_back(4U); + value.push_back(-1); + + index.push_back(4U); + value.push_back(INFTY); + + index.push_back(4U); + value.push_back(-INFTY); + + // w + index.push_back(5U); + value.push_back(-0.1); + + index.push_back(5U); + value.push_back(0.0); + + index.push_back(5U); + value.push_back(1.0); + + index.push_back(5U); + value.push_back(1.1); + + index.push_back(5U); + value.push_back(INFTY); + + index.push_back(5U); + value.push_back(-INFTY); + + // sv + index.push_back(6U); + value.push_back(-1.0); + + index.push_back(6U); + value.push_back(INFTY); + + index.push_back(6U); + value.push_back(-INFTY); + } + + template + stan::return_type_t log_prob( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { + return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); + } + + template + stan::return_type_t log_prob( + const T_rt& rt, const T_response& response, const T_a& a, const T_v& v, + const T_t0& t0, const T_w& w, const T_sv& sv, const T7&) { + return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); + } + + template + stan::return_type_t + log_prob_function(const T_rt& rt, const T_response& response, const T_a& a, + const T_v& v, const T_t0& t0, const T_w& w, const T_sv& sv, + const T7&) { + return stan::math::ddm_lpdf(rt, response, a, v, t0, w, sv); + } +}; diff --git a/test/unit/math/prim/prob/ddm_cdf_log_test.cpp b/test/unit/math/prim/prob/ddm_cdf_log_test.cpp new file mode 100644 index 00000000000..f00522522f7 --- /dev/null +++ b/test/unit/math/prim/prob/ddm_cdf_log_test.cpp @@ -0,0 +1,90 @@ +#include +#include +#include + +TEST(ProbDdm, ddm_lcdf_matches_known_cdf) { + using stan::math::ddm_lcdf; + using std::exp; + using std::vector; + + // Note: + // 1. Each expected log_prob is calculated with the R package `fddm` using + // fddm::pfddm(rt, response, a, v, t0, w, sv, log = TRUE) + // 2. We must define a large error tolerance because we must check on the log + // scale, and since the sum of the log PDFs is very negative (~ -1800), + // exponentiating this very negative value would result in rounding to + // zero. This value of error tolerance is slightly arbitrary, but it is + // more useful than comparing zero to zero due to rounding issues. + + static const double pfddm_output = -972.3893812; + static const double err_tol = 1.0; + static const vector rt{0.1, 1, 10.0}; + static const int response = 1; // "lower" threshold + static const vector a{0.5, 1.0, 5.0}; + static const vector v{-2.0, 0.0, 2.0}; + static const double t0 = 0.0001; + static const vector w{0.2, 0.5, 0.8}; + static const vector sv{0.0, 0.5, 1.0, 1.5}; + + static const int n_rt = rt.size(); + static const int n_a = a.size(); + static const int n_v = v.size(); + static const int n_w = w.size(); + static const int n_sv = sv.size(); + static const int n_max = n_rt * n_a * n_v * n_w * n_sv; + vector rt_vec(n_max); + vector a_vec(n_max); + vector v_vec(n_max); + vector t0_vec(n_max, t0); + vector w_vec(n_max); + vector sv_vec(n_max); + + for (int i = 0; i < n_rt; i++) { + for (int j = i; j < n_max; j += n_rt) { + rt_vec[j] = rt[i]; + } + } + for (int i = 0; i < n_a; i++) { + for (int j = i; j < n_max; j += n_a) { + a_vec[j] = a[i]; + } + } + for (int i = 0; i < n_v; i++) { + for (int j = i; j < n_max; j += n_v) { + v_vec[j] = v[i]; + } + } + for (int i = 0; i < n_w; i++) { + for (int j = i; j < n_max; j += n_w) { + w_vec[j] = w[i]; + } + } + for (int i = 0; i < n_sv; i++) { + for (int j = i; j < n_max; j += n_sv) { + sv_vec[j] = sv[i]; + } + } + + EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), pfddm_output, + err_tol); + EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR((ddm_lcdf(rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, err_tol); + EXPECT_NEAR((ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, err_tol); + EXPECT_NEAR( + (ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lcdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR( + (ddm_lcdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + pfddm_output, err_tol); +} diff --git a/test/unit/math/prim/prob/ddm_test.cpp b/test/unit/math/prim/prob/ddm_test.cpp new file mode 100644 index 00000000000..875522ffe14 --- /dev/null +++ b/test/unit/math/prim/prob/ddm_test.cpp @@ -0,0 +1,324 @@ +#include +#include +#include +#include + +// ddm_lpdf(rt, response, a, v, t0, w, sv) +// wiener_lpdf(y, alpha, tau, beta, delta) +// rt <- y +// response <- 2 (wiener_lpdf() always uses the "upper" threshold) +// a <- alpha +// v <- delta +// t0 <- tau +// w <- beta +// sv is not included in wiener_lpdf() +// alpha -> a +// tau -> t0 +// beta -> w +// delta -> v +// Note: `response` and `sv` are not included in wiener_lpdf() + +// Check invalid arguments + +// rt +TEST(mathPrimScalProbDdmScal, invalid_rt) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(0, 2, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(-1, 2, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(INFTY, 2, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(-INFTY, 2, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(NAN, 2, 1, -1, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_rt) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector rt{1, 0}; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = -1; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = INFTY; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); + rt[1] = NAN; + EXPECT_THROW(ddm_lpdf(rt, 2, 1, -1, 0, 0.5, 0), std::domain_error); +} + +// response +TEST(mathPrimScalProbDdmScal, invalid_response) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 0, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 3, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, -1, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, INFTY, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, -INFTY, 1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, NAN, 1, -1, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_response) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector response{2, 0}; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = 3; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = -1; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); + response[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, response, 1, -1, 0, 0.5, 0), std::domain_error); +} + +// a +TEST(mathPrimScalProbDdmScal, invalid_a) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 0, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, -1, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, INFTY, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, -INFTY, -1, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, NAN, -1, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_a) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector a{1, 0}; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = -1; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); + a[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, a, -1, 0, 0.5, 0), std::domain_error); +} + +// v +TEST(mathPrimScalProbDdmScal, invalid_v) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, INFTY, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -INFTY, 0, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, NAN, 0, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_v) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector v{1, INFTY}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, v, 0, 0.5, 0), std::domain_error); + v[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, v, 0, 0.5, 0), std::domain_error); + v[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, v, 0, 0.5, 0), std::domain_error); +} + +// t0 +TEST(mathPrimScalProbDdmScal, invalid_t0) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, -1, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, INFTY, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, -INFTY, 0.5, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, NAN, 0.5, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_t0) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector t0{1, -1}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); + t0[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); + t0[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); + t0[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, t0, 0.5, 0), std::domain_error); +} + +// w +TEST(mathPrimScalProbDdmScal, invalid_w) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, -0.1, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 1, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 1.1, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, INFTY, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, -INFTY, 0), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, NAN, 0), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_w) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector w{1, -0.1}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = 0; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = 1; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = 1.1; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); + w[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, w, 0), std::domain_error); +} + +// sv +TEST(mathPrimScalProbDdmScal, invalid_sv) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, -1), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, INFTY), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, -INFTY), std::domain_error); + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, NAN), std::domain_error); +} +TEST(mathPrimScalProbDdmMat, invalid_sv) { + using stan::math::ddm_lpdf; + using stan::math::INFTY; + std::vector sv{1, -1}; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); + sv[1] = INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); + sv[1] = -INFTY; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); + sv[1] = NAN; + EXPECT_THROW(ddm_lpdf(1, 2, 1, -1, 0, 0.5, sv), std::domain_error); +} + +TEST(ProbDdm, ddm_lpdf_matches_wiener_lpdf) { + using stan::math::ddm_lpdf; + using stan::math::wiener_lpdf; + using std::exp; + using std::vector; + // Notes: + // 1. define error tolerance for PDF approximations, use double tolerance to + // allow for convergence (of the truncated infinite sum) from above and below + // 2. wiener_lpdf() only uses the "upper" threshold in the DDM, but the + // "lower" threshold maps v -> -v and w -> 1-w, and this is covered in the + // parameter values defined below + double err_tol = 2 * 0.000001; + vector rt{0.1, 1, 10.0}; + int response = 2; // wiener_lpdf() always uses the "upper" threshold + vector a{0.5, 1.0, 5.0}; + vector v{-2.0, 0.0, 2.0}; + double t0 = 0.0001; // t0 (i.e., tau) needs to be > 0 for wiener_lpdf() + vector w{0.2, 0.5, 0.8}; + double sv + = 0.0; // sv is not included in wiener_lpdf(), and thus it must be 0 + + int n_rt = rt.size(); + int n_a = a.size(); + int n_v = v.size(); + int n_w = w.size(); + int n_max = n_rt * n_a * n_v * n_w; + vector rt_vec(n_max); + vector a_vec(n_max); + vector v_vec(n_max); + vector t0_vec(n_max, t0); + vector w_vec(n_max); + + for (int i = 0; i < n_rt; i++) { + for (int j = i; j < n_max; j += n_rt) { + rt_vec[j] = rt[i]; + } + } + for (int i = 0; i < n_a; i++) { + for (int j = i; j < n_max; j += n_a) { + a_vec[j] = a[i]; + } + } + for (int i = 0; i < n_v; i++) { + for (int j = i; j < n_max; j += n_v) { + v_vec[j] = v[i]; + } + } + for (int i = 0; i < n_w; i++) { + for (int j = i; j < n_max; j += n_w) { + w_vec[j] = w[i]; + } + } + + // The PDF approximation error tolerance is based on the non-log version of + // the PDF. To account for this, we compare the exponentiated version of the + // *_lpdf results. + EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), err_tol); + EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR(exp(ddm_lpdf(rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf(rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR(exp(ddm_lpdf, int, vector, vector, + vector, vector, double>( + rt, response, a, v, t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >( + rt_vec, a_vec, t0_vec, w_vec, v_vec)), + err_tol); + EXPECT_NEAR( + exp(ddm_lpdf, int, vector, vector, + vector, vector, double>(rt, response, a, v, + t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >(rt_vec, a_vec, t0_vec, + w_vec, v_vec)), + err_tol); + EXPECT_NEAR( + exp(ddm_lpdf, int, vector, vector, + vector, vector, double>(rt, response, a, v, + t0_vec, w, sv)), + exp(wiener_lpdf, vector, vector, + vector, vector >(rt_vec, a_vec, t0_vec, + w_vec, v_vec)), + err_tol); + + // check with variable drift rate (against results from R package `fddm`) + // Notes: + // 1. We must redefine the error tolerance because we must check on the log + // scale as the sum of the log PDFs is very negative (~ -1800), and + // exponentiating this very negative value would result in rounding to zero. + // This value of error tolerance is slightly arbitrary, but it is more useful + // than comparing zero to zero due to rounding issues. + err_tol = 1.0; + static const double dfddm_output = -1800.6359154; + vector sv_vals{0.0, 0.5, 1.0, 1.5}; + int n_sv = sv_vals.size(); + n_max *= n_sv; + vector sv_vec(n_max); + for (int i = 0; i < n_sv; i++) { + for (int j = i; j < n_max; j += n_sv) { + sv_vec[j] = sv_vals[i]; + } + } + + EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), dfddm_output, + err_tol); + EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR((ddm_lpdf(rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, err_tol); + EXPECT_NEAR((ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, err_tol); + EXPECT_NEAR( + (ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + 0, // true makes ddm_lpdf() and wiener_lpdf() evaluate to 0 + err_tol); + EXPECT_NEAR( + (ddm_lpdf, int, vector, vector, + vector, vector, vector >( + rt, response, a, v, t0_vec, w, sv_vec)), + dfddm_output, err_tol); +}