Skip to content

Commit

Permalink
Add mvnormal logp dlogp benchmark test
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 11, 2023
1 parent 9c06de2 commit 0f802ab
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular


def test_vectorize_blockwise():
Expand Down Expand Up @@ -320,3 +320,41 @@ class TestSolveVector(BlockwiseOpTester):
class TestSolveMatrix(BlockwiseOpTester):
core_op = Solve(lower=True, b_ndim=2)
signature = "(m, m),(m, n) -> (m, n)"


@pytest.mark.parametrize(
"mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}"
)
@pytest.mark.parametrize(
"cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
)
def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))

value_batch_shape = mu_batch_shape
if len(cov_batch_shape) > len(mu_batch_shape):
value_batch_shape = cov_batch_shape

value = tensor("value", shape=(*value_batch_shape, 10))
mu = tensor("mu", shape=(*mu_batch_shape, 10))
cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))

test_values = [
rng.normal(size=value.type.shape),
rng.normal(size=mu.type.shape),
np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)),
]

chol_cov = cholesky(cov, lower=True, on_error="raise")
delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1)
quaddist = (delta_trans**2).sum(axis=-1)
diag = diagonal(chol_cov, axis1=-2, axis2=-1)
logdet = log(diag).sum(axis=-1)
k = value.shape[-1]
norm = -0.5 * k * (np.log(2 * np.pi))

logp = norm - 0.5 * quaddist - logdet
dlogp = grad(logp.sum(), wrt=[value, mu, cov])

fn = pytensor.function([value, mu, cov], [logp, *dlogp])
benchmark(fn, *test_values)

0 comments on commit 0f802ab

Please sign in to comment.