Skip to content

Commit

Permalink
test: Test sampling for zero and non-zero sized vectors.
Browse files Browse the repository at this point in the history
Test sampling zero and non-zero sized vectors. Test catches KeyError raised by the helper function.
  • Loading branch information
Ari Hartikainen authored and ahartikainen committed Mar 5, 2021
1 parent 2a291d3 commit 3525dd7
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/test_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

headers = {"content-type": "application/json"}
program_code = "parameters {real y;} model {y ~ normal(0, 0.0001);}"
program_code_vector = "parameters {vector[2] z; vector[0] x;} model {z ~ normal(0, 0.0001);}"


@pytest.mark.asyncio
Expand Down Expand Up @@ -68,3 +69,27 @@ async def draws(random_seed: Optional[int] = None) -> List[Union[int, float]]:
# look at all draws
assert np.allclose(draws1, draws2)
assert not np.allclose(draws1, draws3)


@pytest.mark.asyncio
async def test_fits_vector_sizes(api_url: str) -> None:
"""Simple test of sampling with zero and non-zero vector sizes."""

num_samples = num_warmup = 500
payload = {
"function": "stan::services::sample::hmc_nuts_diag_e_adapt",
"random_seed": 123,
"num_samples": num_samples,
"num_warmup": num_warmup,
}
param_name = "z.1"
draws = await helpers.sample_then_extract(api_url, program_code_vector, payload, param_name)
assert len(draws) == num_samples, (len(draws), num_samples)

param_name = "z.2"
draws = await helpers.sample_then_extract(api_url, program_code_vector, payload, param_name)
assert len(draws) == num_samples, (len(draws), num_samples)

param_name = "x.1"
with pytest.raises(KeyError, match="No draws found for parameter `x.1`."):
await helpers.sample_then_extract(api_url, program_code_vector, payload, param_name)

0 comments on commit 3525dd7

Please sign in to comment.