
wiener_lpdf reverse-mode derivative wrt w disagrees with finite differences of Stan Math value #3322

@martonaronvarga


Summary

stan::math::wiener_lpdf appears to return an incorrect reverse-mode adjoint for the relative starting point parameter w.

The log-density value itself appears smooth and internally consistent, but the reverse-mode derivative w.adj() disagrees with central finite differences of Stan Math’s own scalar wiener_lpdf value function.

The mismatch appears in both:

stan::math::wiener_lpdf(y, a, t0, w, v, sv)

and

stan::math::wiener_lpdf(y, a, t0, w, v, sv, sw, st0)

The issue also persists when sv = 0, so it does not appear to be limited to the drift-variability correction.

Environment

Stan Math version: v5.2.0
Stan Math commit: 4c5180a0d9bf0e686fdd9e71ad609fb8ffbf3b27
Compiler: g++ (GCC) 15.2.0
Compile command: g++ -std=c++17 -O3 -D_REENTRANT \
  -I ~/cmdstan/stan/lib/stan_math/lib/eigen_3.4.0 \
  -I ~/cmdstan/stan/lib/stan_math/lib/boost_1.78.0 \
  -I ~/cmdstan/stan/lib/stan_math/lib/sundials_6.1.1/include \
  -I ~/cmdstan/stan/lib/stan_math/ \
  -I ~/cmdstan/stan/lib/stan_math/lib/tbb_2020.3/include \
  wiener_w_grad_bug.cpp -o wiener_w_grad_bug \
  -L ~/cmdstan/stan/lib/stan_math/lib/tbb/ \
  -ltbb
OS: NixOS 26.05 (Yarara)
CmdStan commit: 1ddbddeab633b9d0e4620d1d425325e49e3695c2

Minimal reproducer

#include <cmath>
#include <iomanip>
#include <iostream>
#include <stan/math.hpp>

using stan::math::var;

template <typename F>
double central_diff(F&& f, double x, double h) {
  return (f(x + h) - f(x - h)) / (2.0 * h);
}

double lp5_double(double y, double a, double t0, double w, double v,
                  double sv) {
  return stan::math::wiener_lpdf(y, a, t0, w, v, sv);
}

double lp_full_double(double y, double a, double t0, double w, double v,
                      double sv, double sw, double st0) {
  return stan::math::wiener_lpdf(y, a, t0, w, v, sv, sw, st0);
}

void ad_5param(double y, double a, double t0, double w, double v,
               double sv) {
  var yv = y;
  var av = a;
  var t0v = t0;
  var wv = w;
  var vv = v;
  var svv = sv;

  auto lp = stan::math::wiener_lpdf(yv, av, t0v, wv, vv, svv);
  lp.grad();

  std::cout << "AD_5PARAM"
            << " lp=" << lp.val()
            << " gy=" << yv.adj()
            << " ga=" << av.adj()
            << " gt0=" << t0v.adj()
            << " gw=" << wv.adj()
            << " gv=" << vv.adj()
            << " gsv=" << svv.adj()
            << "\n";

  stan::math::recover_memory();
}

void ad_full(double y, double a, double t0, double w, double v, double sv,
             double sw, double st0) {
  var yv = y;
  var av = a;
  var t0v = t0;
  var wv = w;
  var vv = v;
  var svv = sv;
  var swv = sw;
  var st0v = st0;

  auto lp = stan::math::wiener_lpdf(yv, av, t0v, wv, vv, svv, swv, st0v);
  lp.grad();

  std::cout << "AD_FULL"
            << " lp=" << lp.val()
            << " gy=" << yv.adj()
            << " ga=" << av.adj()
            << " gt0=" << t0v.adj()
            << " gw=" << wv.adj()
            << " gv=" << vv.adj()
            << " gsv=" << svv.adj()
            << " gsw=" << swv.adj()
            << " gst0=" << st0v.adj()
            << "\n";

  stan::math::recover_memory();
}

int main() {
  std::cout << std::setprecision(17);

  const double y = 6.0;
  const double a = 10.0;
  const double t0 = 0.01;
  const double w = 0.1;
  const double v = -3.0;
  const double sv = 0.2;
  const double sw = 0.1;
  const double st0 = 0.0;

  std::cout << "=== five parameter: sv = 0.2 ===\n";
  ad_5param(y, a, t0, w, v, sv);

  for (double h : {1e-3, 1e-4, 1e-5, 1e-6, 1e-7}) {
    const double fd = central_diff(
        [&](double ww) { return lp5_double(y, a, t0, ww, v, sv); },
        w,
        h);
    std::cout << "FD_5PARAM h=" << h << " gw_fd=" << fd << "\n";
  }

  for (double ww : {0.08, 0.09, 0.10, 0.11, 0.12}) {
    std::cout << "GRID_5PARAM w=" << ww
              << " lp=" << lp5_double(y, a, t0, ww, v, sv)
              << "\n";
  }

  std::cout << "\n=== five parameter: sv = 0 control ===\n";
  ad_5param(y, a, t0, w, v, 0.0);

  for (double h : {1e-3, 1e-4, 1e-5, 1e-6, 1e-7}) {
    const double fd = central_diff(
        [&](double ww) { return lp5_double(y, a, t0, ww, v, 0.0); },
        w,
        h);
    std::cout << "FD_5PARAM_SV0 h=" << h << " gw_fd=" << fd << "\n";
  }

  std::cout << "\n=== full parameter: sw = 0.1, st0 = 0 ===\n";
  ad_full(y, a, t0, w, v, sv, sw, st0);

  for (double h : {1e-3, 1e-4, 1e-5, 1e-6, 1e-7}) {
    const double fd = central_diff(
        [&](double ww) {
          return lp_full_double(y, a, t0, ww, v, sv, sw, st0);
        },
        w,
        h);
    std::cout << "FD_FULL h=" << h << " gw_fd=" << fd << "\n";
  }

  for (double ww : {0.08, 0.09, 0.10, 0.11, 0.12}) {
    std::cout << "GRID_FULL w=" << ww
              << " lp=" << lp_full_double(y, a, t0, ww, v, sv, sw, st0)
              << "\n";
  }

  return 0;
}

Observed output

=== five parameter: sv = 0.2 ===
AD_5PARAM lp=-50.539106208790059 gy=-1.4309088946356134 ga=-3.1387340804639763 gt0=1.4309088946356134 gw=-82.160098301170336 gv=21.757018393030009 gsv=93.707129083578479
FD_5PARAM h=0.001 gw_fd=36.633174261808676
FD_5PARAM h=0.0001 gw_fd=36.632911247345135
FD_5PARAM h=1.0000000000000001e-05 gw_fd=36.632908617661997
FD_5PARAM h=9.9999999999999995e-07 gw_fd=36.632908592793001
FD_5PARAM h=9.9999999999999995e-08 gw_fd=36.632908617661997
GRID_5PARAM w=0.080000000000000002 lp=-51.286630294961604
GRID_5PARAM w=0.089999999999999997 lp=-50.908794872521725
GRID_5PARAM w=0.10000000000000001 lp=-50.539106208790059
GRID_5PARAM w=0.11 lp=-50.175601258770968
GRID_5PARAM w=0.12 lp=-49.81692817570945

=== five parameter: sv = 0 control ===
AD_5PARAM lp=-62.167447556524714 gy=-3.6469225854648935 ga=-3.9219867426130568 gt0=3.6469225854648935 gw=-73.457290943958327 gv=26.969999999999999 gsv=0
FD_5PARAM_SV0 h=0.001 gw_fd=45.335981619022192
FD_5PARAM_SV0 h=0.0001 gw_fd=45.335718604526676
FD_5PARAM_SV0 h=1.0000000000000001e-05 gw_fd=45.335715974559314
FD_5PARAM_SV0 h=9.9999999999999995e-07 gw_fd=45.335715952887767
FD_5PARAM_SV0 h=9.9999999999999995e-08 gw_fd=45.33571594578234

=== full parameter: sw = 0.1, st0 = 0 ===
AD_FULL lp=-50.058755379483443 gy=-1.5006905642810147 ga=-3.0292867403094825 gt0=1.5006905642810147 gw=-34.586923958070955 gv=21.554056380238126 gsv=91.955143887665372 gsw=8.7239803669499629 gst0=0
FD_FULL h=0.001 gw_fd=35.7193384750083
FD_FULL h=0.0001 gw_fd=35.719188331704288
FD_FULL h=1.0000000000000001e-05 gw_fd=35.719186830363014
FD_FULL h=9.9999999999999995e-07 gw_fd=35.719186815441617
FD_FULL h=9.9999999999999995e-08 gw_fd=35.719186861626895
GRID_FULL w=0.080000000000000002 lp=-50.783094978471475
GRID_FULL w=0.089999999999999997 lp=-50.418236984588823
GRID_FULL w=0.10000000000000001 lp=-50.058755379483443
GRID_FULL w=0.11 lp=-49.703547959822139
GRID_FULL w=0.12 lp=-49.351836459209828

Expected behavior

The reverse-mode adjoint for w should agree with central finite differences of the same scalar value function.

For the 5-parameter case with sv = 0.2, the finite-difference derivative is stable around:

gw ~ +36.63290859

but reverse-mode reports:

gw = -82.160098301170336

For the sv = 0 control, the finite-difference derivative is stable around:

gw ~ +45.33571595

but reverse-mode reports:

gw = -73.457290943958327

For the full-parameter case with sw = 0.1, st0 = 0, the finite-difference derivative is stable around:

gw ~ +35.719186815

but reverse-mode reports:

gw = -34.586923958070955

Why this appears to be a bug

For a scalar differentiable log density

$$l(w) = \log p(y \mid a, t_0, w, v, s_v),$$

the reverse-mode adjoint should satisfy, for small $h$,

$$\texttt{w.adj()} \approx \frac{l(w + h) - l(w - h)}{2h}.$$

This is not a comparison against an external implementation, although that is how I discovered the issue: Stan's tests compare against reference values from WienR, and that reference implementation appears to be affected as well. The finite differences here are computed from Stan Math's own scalar wiener_lpdf value.
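A quick consistency check on the finite-difference side, using only the FD_5PARAM values printed above. Write $D(h)$ for the printed estimate at step $h$. The central difference has a truncation error of order $h^2$, so cutting $h$ by a factor of 10 should cut the error by roughly 100. Taking $D(10^{-5})$ as the reference:

$$\frac{l(w+h) - l(w-h)}{2h} = l'(w) + \frac{h^2}{6}\, l'''(w) + O(h^4),$$

$$D(10^{-3}) - D(10^{-5}) \approx 2.656 \times 10^{-4}, \qquad D(10^{-4}) - D(10^{-5}) \approx 2.630 \times 10^{-6}.$$

The ratio is about 101, as expected for an $O(h^2)$ error. This implies a truncation error of order $10^{-8}$ at $h = 10^{-5}$, whereas the reverse-mode adjoint differs from the finite-difference value by about 118.8.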

The value grids are locally monotone increasing in w, for example:

five-param:
w=0.08 lp=-51.286630294961604
w=0.10 lp=-50.539106208790059
w=0.12 lp=-49.81692817570945

and

full-param:
w=0.08 lp=-50.783094978471475
w=0.10 lp=-50.058755379483443
w=0.12 lp=-49.351836459209828

so the local derivative with respect to w should be positive. The finite differences are stable over several step sizes, but the reverse-mode adjoint is negative and differs substantially in magnitude.
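The GRID lines support the same conclusion without reference to the FD step sizes. Central secant slopes taken directly from the printed values at $w = 0.09$ and $w = 0.11$:

$$\frac{-50.17560126 - (-50.90879487)}{0.02} \approx 36.660 \quad \text{(five-param)}, \qquad \frac{-49.70354796 - (-50.41823698)}{0.02} \approx 35.734 \quad \text{(full-param)}.$$

Both are within 0.03 of the small-$h$ finite differences (36.633 and 35.719) and have the opposite sign from the reverse-mode adjoints (-82.160 and -34.587).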

The sv = 0 control suggests this is not solely caused by the drift-variability term.

Location

The likely bug path is internal::wiener5_grad_w in stan/math/prim/prob/wiener5_lpdf.hpp. In the source, the full overload delegates to the 5-parameter overload, wiener_lpdf(..., sv, precision_derivatives), when sw == 0 && st0 == 0; for the sw > 0 case, it integrates internal::wiener5_grad_w inside the T_w partial path to obtain the w partial. Both code paths therefore rely on internal::wiener5_grad_w, which is consistent with both overloads being affected.

So the fix should start at the 5-parameter w derivative, not in cubature or the full wrapper.

Related tests

The existing Wiener full test appears to compare gradients against constants generated externally, for example from WienR. This issue suggests that at least the w / beta adjoint should also be tested against finite differences of Stan Math's own value function.

A possible regression test would check w.adj() against central finite differences for:

stan::math::wiener_lpdf(y, a, t0, w, v, sv)

and

stan::math::wiener_lpdf(y, a, t0, w, v, sv, sw, st0)

at the parameter values above.
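A minimal, standalone sketch of the comparison such a regression test would make. Everything here is hypothetical: `adjoint_matches_fd`, `toy_check`, and the tolerances are illustrative choices, not Stan Math APIs, and a toy function with a known derivative stands in for `wiener_lpdf` so the sketch compiles without Stan Math. A real test would evaluate the double-valued `wiener_lpdf` for the finite difference and take `w.adj()` from the `var`-valued call, as the reproducer does.

```cpp
#include <cassert>
#include <cmath>
#include <functional>

// Central finite difference (same formula as central_diff in the reproducer).
double central_diff(const std::function<double(double)>& f, double x,
                    double h) {
  return (f(x + h) - f(x - h)) / (2.0 * h);
}

// Hypothetical helper, not a Stan Math API: true when an AD adjoint agrees
// with a finite-difference estimate within a mixed absolute/relative
// tolerance. The absolute term guards against derivatives near zero.
bool adjoint_matches_fd(double adjoint, double fd, double rel_tol = 1e-6,
                        double abs_tol = 1e-8) {
  return std::fabs(adjoint - fd) <= abs_tol + rel_tol * std::fabs(fd);
}

// Hypothetical toy check: l(w) = log(w) + 3w stands in for the log density,
// and its exact derivative 1/w + 3 stands in for w.adj(). In a Stan Math test,
// l would wrap the double-valued wiener_lpdf call and the adjoint would come
// from lp.grad() on the var-valued call, as in the reproducer.
bool toy_check(double w, double h) {
  auto l = [](double x) { return std::log(x) + 3.0 * x; };
  const double adjoint = 1.0 / w + 3.0;
  return adjoint_matches_fd(adjoint, central_diff(l, w, h));
}
```

With the observed 5-parameter numbers, `adjoint_matches_fd(-82.160098301170336, 36.632908617661997)` returns false, while the FD estimates at h = 1e-4 and h = 1e-5 agree with each other under the default tolerance.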

Additional note

For the full case with sw > 0 and st0 == 0, there is also a simple endpoint identity for the derivative with respect to the center w0 of the uniform starting-point variability interval:

$$F(w_0) = \frac{1}{s_w} \int_{w_0 - s_w / 2}^{w_0 + s_w / 2} f(u) \,du,$$

where $f(u)$ is the 5-parameter density evaluated at starting point $u$.
Therefore

$$l(w_0) = \log F(w_0)$$

has, by the Leibniz integral rule (only the integration limits depend on $w_0$), the derivative

$$\frac{\partial l}{\partial w_0} = \frac{f(w_0 + s_w / 2) - f(w_0 - s_w / 2)}{s_w F(w_0)}.$$

In an independent implementation, this identity reproduces the sign and magnitude of the finite-difference derivative, not the reverse-mode adjoint shown above.
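The endpoint identity is easy to check numerically. A minimal sketch, assuming a toy Gaussian bump `toy_f` in place of the 5-parameter Wiener density (a real check would use the exponentiated double-valued `wiener_lpdf` at starting point u); `simpson`, `F_avg`, `dlogF_identity`, and `dlogF_fd` are hypothetical names. It compares the identity with a central finite difference of log F, the same comparison the reverse-mode adjoint fails above.

```cpp
#include <cassert>
#include <cmath>
#include <functional>

// Composite Simpson's rule on [lo, hi] with n (even) subintervals.
double simpson(const std::function<double(double)>& f, double lo, double hi,
               int n) {
  const double h = (hi - lo) / n;
  double sum = f(lo) + f(hi);
  for (int i = 1; i < n; ++i) {
    sum += f(lo + i * h) * (i % 2 == 1 ? 4.0 : 2.0);
  }
  return sum * h / 3.0;
}

// Hypothetical positive density standing in for the 5-parameter f(u).
double toy_f(double u) { return std::exp(-(u - 0.3) * (u - 0.3) / 0.02); }

// F(w0) = (1 / sw) * integral of f over [w0 - sw/2, w0 + sw/2].
double F_avg(double w0, double sw) {
  return simpson(toy_f, w0 - sw / 2, w0 + sw / 2, 200) / sw;
}

// Endpoint identity:
// d/dw0 log F(w0) = (f(w0 + sw/2) - f(w0 - sw/2)) / (sw * F(w0)).
double dlogF_identity(double w0, double sw) {
  return (toy_f(w0 + sw / 2) - toy_f(w0 - sw / 2)) / (sw * F_avg(w0, sw));
}

// Central finite difference of log F(w0), for comparison.
double dlogF_fd(double w0, double sw, double h) {
  return (std::log(F_avg(w0 + h, sw)) - std::log(F_avg(w0 - h, sw))) /
         (2.0 * h);
}
```

For this toy f, `dlogF_identity(0.1, 0.1)` and `dlogF_fd(0.1, 0.1, 1e-5)` agree well within a relative tolerance of 1e-6. Simpson's rule is used only to keep the sketch short; it is not meant to mirror how Stan Math integrates over sw.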
