Skip to content

Commit

Permalink
feat: Add grad_log_prob method to Model
Browse files Browse the repository at this point in the history
Added grad_log_prob method to Model instances, allowing users
to calculate the gradient of the log posterior evaluated at
the unconstrained parameters.

This feature is accompanied by a test: the grad_log_prob method is
validated by comparing the output against an analytical calculation
of the gradient.
  • Loading branch information
mjcarter95 authored and riddell-stan committed Jan 20, 2021
1 parent 399ec6c commit 66ce87c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 1 deletion.
37 changes: 36 additions & 1 deletion stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def log_prob(self, unconstrained_parameters: List[float], adjust_transform: bool
unconstrained_parameters: A sequence of unconstrained parameters.
adjust_transform: Apply jacobian adjust transform.
Returns: The log probability of the unconstrained parameters.
Returns:
The log probability of the unconstrained parameters.
Notes:
The unconstrained parameters are passed to the log_prob
Expand All @@ -346,6 +347,40 @@ async def go():

return asyncio.run(go())

def grad_log_prob(self, unconstrained_parameters: List[float]) -> float:
"""Calculate the gradient of the log posterior evaluated at
the unconstrained parameters.
Arguments:
unconstrained_parameters: A sequence of unconstrained parameters.
adjust_transform: Apply jacobian adjust transform.
Returns:
The gradient of the log posterior evalauted at the
unconstrained parameters.
Notes:
The unconstrained parameters are passed to the log_prob_grad
function in stan::model.
"""
assert isinstance(self.data, dict)

payload = {
"data": self.data,
"unconstrained_parameters": unconstrained_parameters,
}

async def go():
async with stan.common.httpstan_server() as (host, port):
log_prob_grad_url = f"http://{host}:{port}/v1/{self.model_name}/log_prob_grad"
async with aiohttp.request("POST", log_prob_grad_url, json=payload) as resp:
response_payload = await resp.json()
if resp.status != 200:
raise RuntimeError(response_payload)
return (await resp.json())["log_prob_grad"]

return asyncio.run(go())


def build(program_code, data=None, random_seed=None) -> Model:
"""Build (compile) a Stan program.
Expand Down
39 changes: 39 additions & 0 deletions tests/test_grad_log_prob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Test model with array parameter."""
import random

import numpy as np
import pytest

import stan

program = """
parameters {
real y;
}
model {
y ~ normal(0, 1);
}
"""

num_samples = 1000
num_chains = 4


def gaussian_gradient(x: float, mean: float, var: float) -> float:
"""Analytically evaluate Gaussian gradient."""
gradient = (mean - x) / (var ** 2)
return gradient


@pytest.fixture
def posterior(request):
return stan.build(request.param, random_seed=1)


@pytest.mark.parametrize("posterior", [program], indirect=True)
def test_grad_log_prob(posterior):
"""Test log probability against sampled model with no restriction."""
y = random.uniform(0, 10)
lp__ = gaussian_gradient(y, 0, 1)
lp = posterior.grad_log_prob(unconstrained_parameters=[y])
assert np.allclose(lp__, lp)

0 comments on commit 66ce87c

Please sign in to comment.