Skip to content

Commit

Permalink
feat: Add log_prob method to Model
Browse files Browse the repository at this point in the history
Added log_prob method to Model instances, allowing users to calculate
the log probability of a list of unconstrained parameters.

This feature is accompanied by a test: the log_prob method is validated
by comparing the output against the log probability (lp__) extracted
from a model fit.

Closes #40
  • Loading branch information
mjcarter95 authored and riddell-stan committed Jan 17, 2021
1 parent 2c530b8 commit 2426d63
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
33 changes: 33 additions & 0 deletions stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,39 @@ async def go():

return asyncio.run(go())

def log_prob(self, unconstrained_parameters: List[float], adjust_transform: bool = True) -> float:
"""Calculate the log probability of a set of unconstrained parameters.
Arguments:
unconstrained_parameters: A sequence of unconstrained parameters.
adjust_transform: Apply jacobian adjust transform.
Returns: The log probability of the unconstrained parameters.
Notes:
The unconstrained parameters are passed to the log_prob
function in stan::model.
"""
assert isinstance(self.data, dict)

payload = {
"data": self.data,
"unconstrained_parameters": unconstrained_parameters,
"adjust_transform": adjust_transform,
}

async def go():
async with stan.common.httpstan_server() as (host, port):
log_prob_url = f"http://{host}:{port}/v1/{self.model_name}/log_prob"
async with aiohttp.request("POST", log_prob_url, json=payload) as resp:
response_payload = await resp.json()
if resp.status != 200:
raise RuntimeError(response_payload)
return (await resp.json())["log_prob"]

return asyncio.run(go())


def build(program_code, data=None, random_seed=None) -> Model:
"""Build (compile) a Stan program.
Expand Down
54 changes: 54 additions & 0 deletions tests/test_log_prob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test model with array parameter."""
import numpy as np
import pytest

import stan

unrestricted_program = """
parameters {
real y;
}
model {
y ~ normal(0, 1);
}
"""

restricted_program = """
parameters {
real<lower=0> y;
}
model {
y ~ normal(0, 1);
}
"""

num_samples = 1000
num_chains = 4


@pytest.fixture
def posterior(request):
return stan.build(request.param, random_seed=1)


@pytest.mark.parametrize("posterior", [unrestricted_program], indirect=True)
def test_log_prob(posterior):
"""Test log probability against sampled model with unrestriction."""
fit = posterior.sample(num_chains=num_chains, num_samples=num_samples)
y = fit["y"][0][0]
lp__ = fit["lp__"][0][0]
lp = posterior.log_prob(unconstrained_parameters=[y])
assert np.allclose(lp__, lp)


@pytest.mark.parametrize("posterior", [restricted_program], indirect=True)
def test_log_prob_restricted(posterior):
"""Test log probability against sampled model with restriction."""
fit = posterior.sample(num_chains=num_chains, num_samples=num_samples)
y = fit["y"][0][0]
y = posterior.unconstrain_pars({"y": y})[0]
lp__ = fit["lp__"][0][0]
lp = posterior.log_prob(unconstrained_parameters=[y], adjust_transform=False)
assert np.allclose(lp__, lp + y)
adjusted_lp = posterior.log_prob(unconstrained_parameters=[y], adjust_transform=True)
assert np.allclose(lp__, adjusted_lp)

0 comments on commit 2426d63

Please sign in to comment.