Skip to content

Commit

Permalink
test: Test Stan shapes.
Browse files Browse the repository at this point in the history
Test Stan shapes for array, vector, matrix and combination of array and matrix.
  • Loading branch information
Ari Hartikainen authored and ahartikainen committed Mar 24, 2021
1 parent 948eab5 commit 8860ef5
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions tests/test_fit_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Test model parameter shapes."""
import pytest

import stan

program_code = """
data {
int K;
int L;
int M;
int N;
int O;
int P;
int Q;
int R;
int S;
}
parameters {
real a[K];
real B[L, M];
vector[N] c;
matrix[O, P] D;
matrix[R, S] E[Q];
}
model {
for (k in 1:K) {
a[k] ~ std_normal();
}
for (l in 1:L) {
for (m in 1:M) {
B[l, m] ~ std_normal();
}
}
for (n in 1:N) {
c[n] ~ std_normal();
}
for (o in 1:O) {
for (p in 1:P) {
D[o, p] ~ std_normal();
}
}
for (q in 1:Q) {
for (r in 1:R) {
for (s in 1:S) {
E[q, r, s] ~ std_normal();
}
}
}
}
"""
num_samples = 100
num_chains = 3

dims = {
"a": ("K",),
"B": ("L", "M"),
"c": ("N",),
"D": ("O", "P"),
"E": ("Q", "R", "S"),
}


def get_posterior(data):
return stan.build(program_code, data=data)


def get_fit(data):
posterior = get_posterior(data)
return posterior.sample(num_samples=num_samples, num_chains=num_chains)


def get_data(zero_dims):
data = {
"K": 2,
"L": 3,
"M": 2,
"N": 2,
"O": 3,
"P": 2,
"Q": 4,
"R": 3,
"S": 2,
}
for zero_dim in zero_dims:
assert zero_dim in data
data[zero_dim] = 0
return data


@pytest.mark.parametrize(
"zero_dims",
["K", "L", "M", "LM", "N", "O", "P", "OP", "Q", "R", "S", "QR", "QS", "RS", "QRS", "LMNOPQRS"],
)
def test_fit_empty_array_shape(zero_dims):
"""
Make sure shapes are correct.
"""
data = get_data(zero_dims)
fit = get_fit(data)
for parameter, dim in dims.items():
shape = tuple(map(data.get, dim)) + (num_samples * num_chains,)
assert fit[parameter].shape == shape

0 comments on commit 8860ef5

Please sign in to comment.