Functionality necessary to nest Stan models in other PPLs #169

Open
sethaxen opened this issue Sep 11, 2023 · 13 comments
Labels
enhancement New feature or request question Further information is requested

Comments

@sethaxen
Contributor

Someone asked me if it would be possible to nest an existing Stan model within a model defined in another PPL using BridgeStan. Currently, the major limitation is that we have no way to autodiff through the constraining transformation. In general, we would also need to be able to compute each of the following separately:

  • the log-density wrt constrained parameters
  • the gradient of the above
  • the constraining transform
  • the unconstraining transform
  • the pushforward/pullback (AKA jvp and vjp) of the constraining transform
  • the logdetjac correction wrt unconstrained parameters
  • the gradient of the above

As far as I can tell, only the two transforms are currently part of the API. The available log-density and gradient are only wrt unconstrained parameters, the Jacobian adjustment is only available as part of the density calculation, and no AD primitives are available for the transforms.
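To make the list concrete, here is a minimal plain-Python sketch for a hypothetical one-parameter model (`real<lower=0> lambda; lambda ~ exponential(1);`). None of these function names are BridgeStan or Stan API; they just give each bullet its own function and show how the unconstrained log density and gradient that samplers use are assembled from the separate pieces.

```python
import math

# Hypothetical toy model (illustration only, not BridgeStan's API):
#   parameters { real<lower=0> lambda; }
#   model { lambda ~ exponential(1); }
# For <lower=0>, the constraining transform is x = exp(y).

def log_density_constrained(x):
    # log Exponential(x | rate=1) = log(1) - x
    return -x

def grad_log_density_constrained(x):
    return -1.0

def constrain(y):
    return math.exp(y)

def unconstrain(x):
    return math.log(x)

def jvp_constrain(y, dy):
    # Pushforward: J(y) * dy, with J(y) = dx/dy = exp(y)
    return math.exp(y) * dy

def vjp_constrain(y, dx):
    # Pullback: J(y)^T * dx (same as the JVP in the scalar case)
    return math.exp(y) * dx

def log_det_jacobian(y):
    # log |dx/dy| = log(exp(y)) = y
    return y

def grad_log_det_jacobian(y):
    return 1.0

def log_density_unconstrained(y):
    # Constrained density plus the Jacobian adjustment
    return log_density_constrained(constrain(y)) + log_det_jacobian(y)

def grad_log_density_unconstrained(y):
    # Chain rule: pull back the constrained gradient, then add the logdetjac gradient
    x = constrain(y)
    return vjp_constrain(y, grad_log_density_constrained(x)) + grad_log_det_jacobian(y)

y = 0.5
print(log_density_unconstrained(y))       # equals 0.5 - exp(0.5)
print(grad_log_density_unconstrained(y))  # equals 1 - exp(0.5)
```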

This is purely exploratory at this stage, but I wonder if it would be feasible and of interest to include the missing functionality in the API.

@WardBrian
Collaborator

I think all of these would require upstream changes in the Stan library and stanc compiler.

Are all of those quantities well defined for a generic Stan model? The gradient w.r.t. constrained parameters seems particularly problematic for constraints like the simplex, which have a different number of dimensions on the unconstrained scale.

@bob-carpenter
Collaborator

The ones with checks are available right now as functions in the Stan math library or in the generated code for a model.

  • the log-density wrt constrained parameters
  • the gradient of the above
  • the constraining transform
  • the unconstraining transform
  • the pushforward/pullback (AKA jvp and vjp) of the constraining transform
  • the logdetjac correction wrt unconstrained parameters
  • the gradient of the above

As @WardBrian notes, doing the other ones would require modifying how we do code generation in the transpiler.

We can pull out log density w.r.t. constrained parameters and gradient pretty easily. During code generation, we separately generate (a) code for the constraining transform and Jacobian determinants, followed by (b) code for the constrained log density and gradients.

I'm not sure what exactly you want with jvp (Jacobian-vector product?). We can calculate a Jacobian and do vector-Jacobian products explicitly, but we don't have the constraining transforms coded that way, so we can't do this very efficiently compared to an optimal implementation. Instead, they do the transform and add the log Jacobian determinant terms to the target density, then we do autodiff. Making this efficient would require recoding the transforms.

@sethaxen
Contributor Author

> Are all of those quantities well defined for a generic Stan model? The gradient w.r.t. constrained parameters seems particularly problematic for constraints like the simplex, which have a different number of dimensions on the unconstrained scale.

The magic of AD is that at any given point in the program, you are computing either the partial derivative of a real scalar output with respect to a real scalar intermediate (reverse-mode) or computing the partial derivative of a real scalar intermediate with respect to a real scalar input (forward-mode). It doesn't matter whether those intermediates have constraints or not; with application of the chain rule everything just works.
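The dimension-changing simplex case can be checked numerically. This is a minimal sketch assuming a hypothetical simplex transform (pin the last logit to zero and apply softmax; this is not necessarily the transform Stan uses). Its Jacobian is K×(K−1), yet pulling a gradient with respect to the K constrained entries back to the K−1 unconstrained ones is well defined, and the component pointing off the simplex is discarded.

```python
import math

# Hypothetical simplex transform (illustration only, not necessarily Stan's):
# y has K-1 entries, z = [y_1, ..., y_{K-1}, 0], and x = softmax(z) has K entries.

def constrain_simplex(y):
    z = list(y) + [0.0]
    m = max(z)                      # shift for numerical stability
    e = [math.exp(zi - m) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]

def vjp_simplex(y, dx):
    # Softmax Jacobian is diag(x) - x x^T (symmetric), so J^T dx = x * (dx - <x, dx>).
    # Pinning the last logit to 0 makes the full Jacobian K x (K-1): drop the last entry.
    x = constrain_simplex(y)
    inner = sum(xi * di for xi, di in zip(x, dx))
    dz = [xi * (di - inner) for xi, di in zip(x, dx)]
    return dz[:-1]

def vjp_finite_difference(y, dx, h=1e-6):
    # Reference answer: central differences of y -> <constrain_simplex(y), dx>
    out = []
    for i in range(len(y)):
        yp, ym = list(y), list(y)
        yp[i] += h
        ym[i] -= h
        fp = sum(a * b for a, b in zip(constrain_simplex(yp), dx))
        fm = sum(a * b for a, b in zip(constrain_simplex(ym), dx))
        out.append((fp - fm) / (2 * h))
    return out

y = [0.2, -1.0, 0.5]           # K - 1 = 3 unconstrained values
dx = [1.0, 2.0, -0.5, 0.3]     # gradient w.r.t. the K = 4 simplex entries
print(vjp_simplex(y, dx))
print(vjp_finite_difference(y, dx))  # should agree closely
print(vjp_simplex(y, [1.0] * 4))     # essentially zero: the off-simplex direction is dropped
```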

Simplex is an easy one because it has a linear constraint, but let's take the harder case of a unit vector, which has a nonlinear constraint. The Jacobian of l2-normalization $y \mapsto x = y/\lVert y \rVert$ is $(I - x x^\top) / \lVert y \rVert$. The left term $(I - x x^\top)$ projects any vector to the plane tangent to the unit sphere at $x$, so simply applying the chain rule in either forward- or reverse-mode ensures that everything works correctly.
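A quick numerical check of the unit-vector Jacobian, as a plain-Python sketch (function names are illustrative, not from Stan or BridgeStan). It builds the dense Jacobian $(I - x x^\top)/\lVert y \rVert$ and also applies it to a vector without forming the matrix, which is all a VJP needs; the radial direction $x$ maps to zero, as the projection argument predicts.

```python
import math

# Unit-vector transform: x = y / ||y||

def norm(v):
    return math.sqrt(sum(vi * vi for vi in v))

def constrain_unit(y):
    r = norm(y)
    return [yi / r for yi in y]

def jacobian_unit(y):
    # Dense Jacobian (I - x x^T) / ||y||; entry [i][j] = d x_i / d y_j
    r = norm(y)
    x = constrain_unit(y)
    n = len(y)
    return [[((1.0 if i == j else 0.0) - x[i] * x[j]) / r for j in range(n)]
            for i in range(n)]

def vjp_unit(y, dx):
    # Same product without forming the matrix: (dx - x <x, dx>) / ||y||.
    # The Jacobian is symmetric, so this is also the JVP.
    r = norm(y)
    x = constrain_unit(y)
    inner = sum(xi * di for xi, di in zip(x, dx))
    return [(di - xi * inner) / r for xi, di in zip(x, dx)]

y = [3.0, 4.0]
x = constrain_unit(y)            # about [0.6, 0.8]
print(vjp_unit(y, x))            # essentially [0, 0]: the radial direction is projected out
print(vjp_unit(y, [1.0, 0.0]))   # tangent component, about [0.128, -0.096]
```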

> I'm not sure what exactly you want with jvp (Jacobian-vector product?). We can calculate a Jacobian and do vector-Jacobian products explicitly, but we don't have the constraining transforms coded that way, so we can't do this very efficiently compared to an optimal implementation. Instead, they do the transform and add the log Jacobian determinant terms to the target density, then we do autodiff. Making this efficient would require recoding the transforms.

This is necessary because if a constrained parameter defined in the Stan model is ever used elsewhere in the model defined in the other PPL, then we need a way to "backpropagate" that parameter's gradient to the unconstrained parameter, which is also handled by the other PPL. This is only possible if a Jacobian-vector-product primitive is available for the constraining transform; without such a primitive, this use of a Stan model is not possible (except maybe by computing full Jacobians 🤢).

Do you think the necessary changes to the transforms would be a major project?

@sethaxen
Contributor Author

To be clear, I'm just gathering information at this point. I don't have a use case for this, but the people I was talking with do.

@bob-carpenter
Collaborator

> Do you think the necessary changes to the transforms would be a major project?

We could use autodiff to calculate the Jacobian of the transform and then explicitly multiply it by a vector. It wouldn't be efficient, but we'd get the right answer. Right now, we code the (scalar) log determinant of the absolute Jacobian and its derivative. We don't need Jacobian-vector products for transforms anywhere in Stan, so we've never coded it. It probably wouldn't be too hard to code all these as Jacobian-vector products as there aren't that many transforms and they're all well described mathematically in our reference manual.
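The brute-force option can be sketched in a few lines of Python. This is an illustration, not Stan code: it uses a hand-rolled forward-mode dual number (one pass per input builds one Jacobian column) on the ordered-vector transform $x_1 = y_1$, $x_k = x_{k-1} + \exp(y_k)$ from the Stan reference manual, multiplies the dense Jacobian by a vector, and compares that with a hand-derived VJP that never forms the matrix.

```python
import math

class Dual:
    """Forward-mode dual number carrying one tangent component."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

def exp_any(v):
    # exp for plain floats or Dual numbers (d exp(u) = exp(u) du)
    if isinstance(v, Dual):
        e = math.exp(v.val)
        return Dual(e, e * v.tan)
    return math.exp(v)

def constrain_ordered(y):
    # Ordered-vector transform: x_1 = y_1, x_k = x_{k-1} + exp(y_k)
    x = [y[0]]
    for yk in y[1:]:
        x.append(x[-1] + exp_any(yk))
    return x

def jacobian_forward(f, y):
    # Brute force: one forward pass per input; seeding y_i with tangent 1 gives column i
    cols = []
    for i in range(len(y)):
        seeded = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(y)]
        cols.append([xk.tan for xk in f(seeded)])
    return [list(row) for row in zip(*cols)]  # J[k][i] = d x_k / d y_i

def vjp_via_jacobian(f, y, dx):
    J = jacobian_forward(f, y)
    return [sum(J[k][i] * dx[k] for k in range(len(dx))) for i in range(len(y))]

def vjp_ordered(y, dx):
    # Direct pullback without forming J: suffix sums of dx, scaled by exp(y_k) for k > 1
    out = [0.0] * len(y)
    acc = 0.0
    for k in range(len(y) - 1, -1, -1):
        acc += dx[k]
        out[k] = acc if k == 0 else math.exp(y[k]) * acc
    return out

y = [0.5, -1.0, 0.2]
dx = [1.0, -2.0, 0.5]
print(vjp_via_jacobian(constrain_ordered, y, dx))  # n forward passes plus an n-by-n matrix
print(vjp_ordered(y, dx))                          # same numbers with O(n) work
```

Both give the same answer; the difference is that the brute-force route costs one pass per input dimension plus a dense matrix, which is what recoding the transforms would avoid.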

@sethaxen
Contributor Author

> We don't need Jacobian-vector products for transforms anywhere in Stan, so we've never coded it.
> It probably wouldn't be too hard to code all these as Jacobian-vector products as there aren't that many transforms and they're all well described mathematically in our reference manual.

No problem; vector-Jacobian products are generally more useful anyway for these applications. And these are implicitly available, right, in that Stan can reverse-mode autodiff through the transforms? When Stan computes e.g. l = exponential_lpdf(x | lambda), where lambda is also a sampled parameter, reverse-mode AD must backpropagate the gradient of l to the unconstrained parameter from which lambda is computed, so the vector-Jacobian product of the constraining transform is implicitly computed.

So for VJPs at least, would it be possible to simply provide a function like (pseudocode)

transform_adjoint(constrain_type, y, dx) -> dy

Here y is the unconstrained parameter, dx is the gradient with respect to the constrained parameter, and dy is the corresponding gradient with respect to the unconstrained parameter. The function would compute the transform using reverse-mode AD primitives, then seed the adjoint of the output with the provided gradient and backpropagate it to the unconstrained parameter. This is only wasteful in the sense that the user would end up computing the transform twice (once for the transform itself and once for the VJP), but this will generally be much cheaper than computing the full Jacobian.
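A minimal sketch of the proposed `transform_adjoint` for the three elementwise bound transforms (lower, upper, lower-and-upper), using the transform definitions from the Stan reference manual. Two assumptions: the bounds are passed as extra keyword arguments (not part of the pseudocode signature), and the derivatives are written out by hand instead of coming from a reverse-mode AD tape. Because these transforms act elementwise, the Jacobian is diagonal and the VJP is an elementwise product.

```python
import math

def inv_logit(u):
    # Numerically stable logistic function
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def constrain(constrain_type, y, lower=None, upper=None):
    # Forward transforms for scalar bounds, applied elementwise
    if constrain_type == "lower":
        return [lower + math.exp(v) for v in y]
    if constrain_type == "upper":
        return [upper - math.exp(v) for v in y]
    if constrain_type == "lower_upper":
        return [lower + (upper - lower) * inv_logit(v) for v in y]
    raise ValueError(f"unsupported constraint: {constrain_type}")

def transform_adjoint(constrain_type, y, dx, lower=None, upper=None):
    # Hypothetical VJP: y is unconstrained, dx is the gradient w.r.t. constrained values.
    # J is diagonal for these transforms, so J^T dx is an elementwise product.
    dy = []
    for v, d in zip(y, dx):
        if constrain_type == "lower":
            deriv = math.exp(v)                       # d/dy [lower + exp(y)]
        elif constrain_type == "upper":
            deriv = -math.exp(v)                      # d/dy [upper - exp(y)]
        elif constrain_type == "lower_upper":
            s = inv_logit(v)
            deriv = (upper - lower) * s * (1.0 - s)   # d/dy [lower + (upper - lower) * inv_logit(y)]
        else:
            raise ValueError(f"unsupported constraint: {constrain_type}")
        dy.append(deriv * d)
    return dy

print(transform_adjoint("lower_upper", [0.0, 1.5], [1.0, -2.0], lower=0.0, upper=1.0))
```

Transforms that are not elementwise (simplex, ordered, correlation matrices) have non-diagonal Jacobians, so each would need its own pullback.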

@bob-carpenter
Collaborator

We don't code the transforms with either a JVP or a VJP, to use Seth's lingo. When Stan executes

```stan
data {
  int<lower=0> N;
  vector<lower=0>[N] y;
}
parameters {
  real<lower=0> lambda;
}
model {
  target += exponential_lpdf(y | lambda);
}
```

In this example, lambda is the result of applying exp to the (implicit) unconstrained parameter corresponding to log(lambda). So the derivative propagated through exponential_lpdf will propagate back to the unconstrained scale.
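Written out by hand for this example (a plain-Python sketch with made-up data, not Stan's generated code), the implicit VJP is the single multiplication by $d\lambda/du = \lambda$ that turns the adjoint of `lambda` into the adjoint of the unconstrained `u = log(lambda)`; the `+ u` term is the log-Jacobian adjustment.

```python
import math

def exponential_lpdf(ys, lam):
    # sum_n log Exponential(y_n | rate=lam) = N log(lam) - lam * sum(y)
    return len(ys) * math.log(lam) - lam * sum(ys)

def d_exponential_lpdf_dlam(ys, lam):
    return len(ys) / lam - sum(ys)

def target_and_grad(u, ys, jacobian=True):
    # Value and gradient of target w.r.t. the unconstrained u = log(lambda)
    lam = math.exp(u)                  # constraining transform for <lower=0>
    value = exponential_lpdf(ys, lam)
    # Reverse pass through the transform: adjoint(u) = adjoint(lambda) * d lambda / du.
    # This single multiplication is the implicit VJP of the constraining transform.
    grad = d_exponential_lpdf_dlam(ys, lam) * lam
    if jacobian:
        value += u                     # log |d lambda / du| = u
        grad += 1.0
    return value, grad

ys = [0.5, 1.2, 2.3]                   # made-up data
print(target_and_grad(math.log(0.8), ys))
```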

Yes, it'd be possible to implement the transform_adjoint function you propose. But no, we don't have the pieces of it already coded that way, so we'd have to do it all from scratch.

Out of curiosity, why does someone want all this?

@sethaxen
Contributor Author

> In this example, lambda is the result of applying exp to the (implicit) unconstrained parameter corresponding to log(lambda). So the derivative propagated through exponential_lpdf will propagate back to the unconstrained scale.

Yes, this is precisely what I mean. The transform is computed on a number, and the derivative is backpropagated by the AD. The part of the program that backpropagates is a vjp automatically constructed by the AD. I don't really understand why one would need to rewrite the transforms from scratch to allow this to be done on the transform by itself instead of the transform composed with the downstream operations that compute the logpdf.

Is it the case that Stan's reverse-mode AD API only supports computing gradients? If so, one can still implement a VJP even if the only available primitive is a gradient by computing the gradient of the function y -> dot(constrain(y), dx), where dot is the usual dot product. This gradient is exactly the result of applying the VJP to dx.
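The dot-product construction can be checked with any AD that exposes only a gradient. Below is a deliberately tiny, hand-rolled reverse-mode tape in Python (not Stan Math's API); `gradient` is the only primitive the VJP uses, and softmax stands in for a constraining transform purely for illustration.

```python
import math

class Var:
    """Scalar node on a tiny reverse-mode tape (illustration only)."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents      # tuple of (parent node, local partial derivative)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value / other.value,
                   ((self, 1.0 / other.value), (other, -self.value / other.value ** 2)))

def vexp(v):
    e = math.exp(v.value)
    return Var(e, ((v, e),))

def gradient(f, y):
    """The only primitive assumed: the gradient of a scalar-valued f at y."""
    inputs = [Var(v) for v in y]
    out = f(inputs)
    order, seen = [], set()
    def visit(node):                # post-order DFS gives a topological order
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)
    visit(out)
    out.grad = 1.0
    for node in reversed(order):    # each node's adjoint is complete before it is used
        for parent, partial in node.parents:
            parent.grad += partial * node.grad
    return [v.grad for v in inputs]

def softmax(y):
    # Stand-in constraining transform (no max-shift, fine for small inputs)
    e = [vexp(v) for v in y]
    total = sum(e)
    return [ei / total for ei in e]

def vjp(constrain, y, dx):
    # The VJP at dx is the gradient of y -> <constrain(y), dx>
    return gradient(lambda ys: sum(xi * di for xi, di in zip(constrain(ys), dx)), y)

y = [0.1, -0.4, 0.7]
dx = [1.0, 0.0, -1.0]
print(vjp(softmax, y, dx))
```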

> Out of curiosity, why does someone want all this?

I haven't seen a detailed use case yet, but I think it's to support submodels, e.g. as supported by Turing and PyMC: https://discourse.pymc.io/t/nested-submodels/11725. The idea is that one's model might be composed of modular components that each have been well-characterized, are considered reliable, and have efficient and numerically stable implementations. I suspect the goal might be to support cases where such a Stan model already exists for a subset of the parameters, and one wants to augment the model with additional parameters or combine it with a separate model; that other model might require language features unavailable in Stan.

@bob-carpenter
Collaborator

Stan's really not compatible with that use case because of the way blocks are defined. Maria Gorinova wrote a cool thesis with an alternative design that's more like Turing.jl that would accommodate the generality of being able to define submodules modularly. I haven't seen this functionality in Turing or PyMC---is there a pointer to how they do it somewhere?

I'm having trouble imagining where that'd be the right thing to do from a software perspective (mixing Stan and something else), because Stan code only gets so complicated and thus isn't too big of a bottleneck to just reimplement.

Yes, it's the case that Stan's reverse-mode AD only computes gradients. We didn't template out the types so that we could autodiff through reverse mode. We have forward mode implemented for higher-order gradients, but not for our implicit functions like ODE solvers.

We have the transforms and inverse transforms implemented with templated functions. So we can do all of this by brute force with autodiff by explicitly constructing the Jacobian. To evaluate a Jacobian-adjoint product more efficiently without building the Jacobian explicitly, we'd have to rewrite the transform code.

@yebai

yebai commented Sep 12, 2023

@bob-carpenter Here is an example of using a Stan model inside a Turing model, and it is very helpful in two ways:

  1. It allows users to benefit from the excellent Stan math libraries; Julia is fast in many places but still underperforms Stan due to autodiff issues.
  2. It allows users to explore other inference options than HMC. In certain cases, more efficient algorithms exist for real-time inference.

@sethaxen
Contributor Author

> I haven't seen this functionality in Turing or PyMC---is there a pointer to how they do it somewhere?

Not that I'm aware of. I think that would get too much into the internals.

> @bob-carpenter Here is an example of using a Stan model inside a Turing model, and it is very helpful in two ways:

I don't think that example is really a great case for this feature, though. One can already benefit from Stan's math libraries and use alternatives to HMC without needing the additional features described here. E.g., the example in the README of StanLogDensityProblems shows how to sample a Stan model with DynamicHMC.jl, and one could easily swap in Pathfinder.jl to fit a variational approximation instead.

I think one only needs these features if one wants to combine a Stan model with additional parameters whose log-density depends on the constrained parameters in the Stan model, and it would be nice to see a compelling use case for that.

@yebai

yebai commented Sep 12, 2023

> I think one only needs these features if one wants to combine a Stan model with additional parameters whose log-density depends on the constrained parameters in the Stan model, and it would be nice to see a compelling use case for that.

Fair point. Examples include models involving discrete variables and nonparametric models where the model dimensionality changes during inference. These are hard to handle in Stan but are unavoidable in some applied areas.

@bob-carpenter
Collaborator

bob-carpenter commented Sep 12, 2023

@yebai: For algorithm development, we've been using BridgeStan, but that's still limited to Stan models. @roualdes originally developed it for use in Julia.

> Models involving discrete variables and non-parametric models where the model dimensionality changes during inference.

Those problems are too hard for us. We like to stick to examples where we can get simulation-based calibration to work. Even in simpler cases like K-means clustering where we can marginalize all the discrete parameters, the posterior is too multimodal to sample.

@WardBrian WardBrian added enhancement New feature or request question Further information is requested labels Dec 4, 2023
Projects
None yet
Development

No branches or pull requests

4 participants