From f5ee4de417aea17440446d767fe2d9696e2c3dc3 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Tue, 4 Apr 2023 21:02:10 +0100 Subject: [PATCH 01/15] added tranform classes for sinh cosh and tanh --- pymc/logprob/transforms.py | 54 +++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 903f013abe..5d313088fb 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -49,21 +49,24 @@ from pytensor.graph.op import Op from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter -from pytensor.scalar import Abs, Add, Exp, Log, Mul, Pow, Sqr, Sqrt +from pytensor.scalar import Abs, Add, Cosh, Exp, Log, Mul, Pow, Sinh, Sqr, Sqrt, Tanh from pytensor.scan.op import Scan from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import ( abs, add, + cosh, exp, log, mul, neg, pow, reciprocal, + sinh, sqr, sqrt, sub, + tanh, true_div, ) from pytensor.tensor.rewriting.basic import ( @@ -340,7 +343,7 @@ def apply(self, fgraph: FunctionGraph): class MeasurableTransform(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a transformed measurable variable""" - valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs) + valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh) # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` @@ -540,7 +543,7 @@ def measurable_sub_to_neg(fgraph, node): return [pt.add(minuend, pt.neg(subtrahend))] -@node_rewriter([exp, log, add, mul, pow, abs]) +@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh]) def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: """Find measurable transformations from Elemwise operators.""" @@ -602,6 +605,12 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform = LogTransform() elif isinstance(scalar_op, Abs): transform = AbsTransform() + elif isinstance(scalar_op, Sinh): + transform = SinhTransform() + elif isinstance(scalar_op, Cosh): + transform = CoshTransform() + elif isinstance(scalar_op, Tanh): + transform = TanhTransform() elif isinstance(scalar_op, Pow): # We only allow for the base to be measurable if measurable_input_idx != 0: @@ -682,6 +691,45 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li ) +class SinhTransform(RVTransform): + name = "sinh" + + def forward(self, value, *inputs): + return pt.sinh(value) + + def backward(self, value, *inputs): + return pt.arcsinh(value) + + def log_jac_det(self, value, *inputs): + return pt.log(pt.cosh(value)) + + +class CoshTransform(RVTransform): + name = "cosh" + + def forward(self, value, *inputs): + return pt.cosh(value) + + def backward(self, value, *inputs): + return pt.arccosh(value) + + def log_jac_det(self, value, *inputs): + return pt.log(pt.sinh(value)) + + +class TanhTransform(RVTransform): + name = "tanh" + + def forward(self, value, *inputs): + return pt.tanh(value) + + def backward(self, value, *inputs): + return pt.arctanh(value) + + def log_jac_det(self, value, *inputs): + return pt.log(1 / pt.cosh(value)) + + class LocTransform(RVTransform): name = "loc" From d80d8ecb2da7d97ca430c1d6a43ede6cecfa9099 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Tue, 4 Apr 2023 22:45:41 +0100 Subject: [PATCH 02/15] cleaned up if block in find_measurable_transform --- pymc/logprob/transforms.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 5d313088fb..1f079b1a31 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -599,19 +599,17 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li measurable_input_idx = 0 transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,) transform: RVTransform - if isinstance(scalar_op, Exp): - transform = ExpTransform() - elif isinstance(scalar_op, Log): - transform = LogTransform() - elif isinstance(scalar_op, Abs): - transform = AbsTransform() - elif isinstance(scalar_op, Sinh): - transform = SinhTransform() - elif isinstance(scalar_op, Cosh): - transform = CoshTransform() - elif isinstance(scalar_op, Tanh): - transform = TanhTransform() - elif isinstance(scalar_op, Pow): + + transform_dict = { + Exp: ExpTransform(), + Log: LogTransform(), + Abs: AbsTransform(), + Sinh: SinhTransform(), + Cosh: CoshTransform(), + Tanh: TanhTransform(), + } + transform = transform_dict.get(type(scalar_op), None) + if isinstance(scalar_op, Pow): # We only allow for the base to be measurable if measurable_input_idx != 0: return None @@ -628,7 +626,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform = LocTransform( transform_args_fn=lambda *inputs: inputs[-1], ) - else: + elif transform is None: transform_inputs = (measurable_input, pt.mul(*other_inputs)) transform = ScaleTransform( transform_args_fn=lambda *inputs: inputs[-1], From 592068f6d06403ab3ecc703ec194a1f35d79bcf1 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Sun, 9 Apr 2023 23:47:09 +0100 Subject: [PATCH 03/15] added an erfcx transform --- pymc/logprob/transforms.py | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 1f079b1a31..bf8d602f5e 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -728,6 +728,56 @@ def log_jac_det(self, value, *inputs): return pt.log(1 / pt.cosh(value)) +class ErfTransform(RVTransform): + name = "erf" + + def forward(self, value, *inputs): + return pt.erf(value) + + def backward(self, value, *inputs): + return pt.erfinv(value) + + def log_jac_det(self, value, *inputs): + return at.log(2 * at.exp(-(value**2)) / (at.sqrt(np.pi))) + + +class ErfcTransform(RVTransform): + name = "erfc" + + def forward(self, value, *inputs): + return at.erfc(value) + + def backward(self, value, *inputs): + return at.erfcinv(value) + + def log_jac_det(self, value, *inputs): + return at.log(-2 * at.exp(-(value**2)) / (at.sqrt(np.pi))) + + +class ErfcxTransform(RVTransform): + name = "erfcx" + + def forward(self, value, *inputs): + return at.erfcx(value) + + def backward(y, tol=1e-10, max_iter=100): + # Compute x using the appropriate formula for each value of y + x = at.switch(y <= 1, 1.0 / (y * at.sqrt(np.pi)), -at.sqrt(at.log(y))) + iter_count = 0 + while iter_count < max_iter: + iter_count += 1 + fx = at.erfcx(x) - y + fpx = 2 * x * at.erfcx(x) - 2 / at.sqrt(np.pi) + delta_x = fx / fpx + x = x - delta_x + if (at.abs(delta_x) < tol).all(): + break + return x + + def log_jac_det(self, value, *inputs): + return at.log(2 * x * at.exp(x**2) * at.erfc(x) - 2.0 / at.sqrt(np.pi)) + + class LocTransform(RVTransform): name = "loc" From 236e7d7f91160fc001c9c9315b9a45631e434a13 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Mon, 10 Apr 2023 12:31:40 +0100 Subject: [PATCH 04/15] added erf, erfc, erfcx now working in notebook --- pymc/logprob/transforms.py | 49 +++++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index bf8d602f5e..65905556b8 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -49,13 +49,31 @@ from pytensor.graph.op import Op from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter -from pytensor.scalar import Abs, Add, Cosh, Exp, Log, Mul, Pow, Sinh, Sqr, Sqrt, Tanh +from pytensor.scalar import ( + Abs, + Add, + Cosh, + Erf, + Erfc, + Erfcx, + Exp, + Log, + Mul, + Pow, + Sinh, + Sqr, + Sqrt, + Tanh, +) from pytensor.scan.op import Scan from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import ( abs, add, cosh, + erf, + erfc, + erfcx, exp, log, mul, @@ -343,7 +361,7 @@ def apply(self, fgraph: FunctionGraph): class MeasurableTransform(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a transformed measurable variable""" - valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh) + valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh, Erf, Erfc, Erfcx) # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` @@ -543,7 +561,7 @@ def measurable_sub_to_neg(fgraph, node): return [pt.add(minuend, pt.neg(subtrahend))] -@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh]) +@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh, erf, erfc, erfcx]) def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: """Find measurable transformations from Elemwise operators.""" @@ -607,6 +625,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li Sinh: SinhTransform(), Cosh: CoshTransform(), Tanh: TanhTransform(), + Erf: ErfTransform(), + Erfc: ErfcTransform(), + Erfcx: ErfcxTransform(), } transform = transform_dict.get(type(scalar_op), None) if isinstance(scalar_op, Pow): @@ -738,44 +759,44 @@ def backward(self, value, *inputs): return pt.erfinv(value) def log_jac_det(self, value, *inputs): - return at.log(2 * at.exp(-(value**2)) / (at.sqrt(np.pi))) + return pt.log(2 * pt.exp(-(value**2)) / (pt.sqrt(np.pi))) class ErfcTransform(RVTransform): name = "erfc" def forward(self, value, *inputs): - return at.erfc(value) + return pt.erfc(value) def backward(self, value, *inputs): - return at.erfcinv(value) + return pt.erfcinv(value) def log_jac_det(self, value, *inputs): - return at.log(-2 * at.exp(-(value**2)) / (at.sqrt(np.pi))) + return pt.log(-2 * pt.exp(-(value**2)) / (pt.sqrt(np.pi))) class ErfcxTransform(RVTransform): name = "erfcx" def forward(self, value, *inputs): - return at.erfcx(value) + return pt.erfcx(value) - def backward(y, tol=1e-10, max_iter=100): + def backward(self, value, tol=1e-10, max_iter=100): # Compute x using the appropriate formula for each value of y - x = at.switch(y <= 1, 1.0 / (y * at.sqrt(np.pi)), -at.sqrt(at.log(y))) + x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value))) iter_count = 0 while iter_count < max_iter: iter_count += 1 - fx = at.erfcx(x) - y - fpx = 2 * x * at.erfcx(x) - 2 / at.sqrt(np.pi) + fx = pt.erfcx(x) - value + fpx = 2 * x * pt.erfcx(x) - 2 / pt.sqrt(np.pi) delta_x = fx / fpx x = x - delta_x - if (at.abs(delta_x) < tol).all(): + if (pt.abs(delta_x) < tol).all(): break return x def log_jac_det(self, value, *inputs): - return at.log(2 * x * at.exp(x**2) * at.erfc(x) - 2.0 / at.sqrt(np.pi)) + return pt.log((2 * value * pt.exp(value**2) * pt.erfc(value) - 2.0) / pt.sqrt(np.pi)) class LocTransform(RVTransform): From 05b8d5506eaf871258928ac1a3e8566293135d4f Mon Sep 17 00:00:00 2001 From: Luke LB Date: Mon, 10 Apr 2023 23:02:06 +0100 Subject: [PATCH 05/15] added a comment to erfcx backward fn and simplified the iteration code --- pymc/logprob/transforms.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 65905556b8..a03b109360 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -782,14 +782,13 @@ def forward(self, value, *inputs): return pt.erfcx(value) def backward(self, value, tol=1e-10, max_iter=100): - # Compute x using the appropriate formula for each value of y + # computes the inverse of erfcx, this was adapted from + # https://tinyurl.com/4mxfd3cz x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value))) iter_count = 0 while iter_count < max_iter: iter_count += 1 - fx = pt.erfcx(x) - value - fpx = 2 * x * pt.erfcx(x) - 2 / pt.sqrt(np.pi) - delta_x = fx / fpx + delta_x = (pt.erfcx(x) - value) / (2 * x * pt.erfcx(x) - 2 / pt.sqrt(np.pi)) x = x - delta_x if (pt.abs(delta_x) < tol).all(): break From 1825315744182c36f550d2aaccd620f7f9538e29 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Sat, 15 Apr 2023 16:31:13 +0100 Subject: [PATCH 06/15] used scan for erfcx and got test running --- pymc/logprob/transforms.py | 24 +++++++++++++++--------- tests/logprob/test_transforms.py | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index a03b109360..ea6b6256a1 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -781,18 +781,24 @@ class ErfcxTransform(RVTransform): def forward(self, value, *inputs): return pt.erfcx(value) - def backward(self, value, tol=1e-10, max_iter=100): + def backward(self, value, k=10): # computes the inverse of erfcx, this was adapted from # https://tinyurl.com/4mxfd3cz x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value))) - iter_count = 0 - while iter_count < max_iter: - iter_count += 1 - delta_x = (pt.erfcx(x) - value) / (2 * x * pt.erfcx(x) - 2 / pt.sqrt(np.pi)) - x = x - delta_x - if (pt.abs(delta_x) < tol).all(): - break - return x + + def calc_delta_x(value, prior_result): + return prior_result - (pt.erfcx(prior_result) - value) / ( + 2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi) + ) + + result, updates = pytensor.scan( + fn=calc_delta_x, + outputs_info=at.ones_like(x), + sequences=value, + non_sequences=x, + n_steps=k, + ) + return result[-1] def log_jac_det(self, value, *inputs): return pt.log((2 * value * pt.exp(value**2) * pt.erfc(value) - 2.0) / pt.sqrt(np.pi)) diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index ba24fadee1..0ddf190e77 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -52,6 +52,7 @@ from pymc.logprob.basic import factorized_joint_logprob from pymc.logprob.transforms import ( ChainedTransform, + ErfTransform, ExpTransform, IntervalTransform, LocTransform, @@ -989,3 +990,22 @@ def test_multivariate_transform(shift, scale): scale_mat @ cov @ scale_mat.T, ), ) + + +@pytest.mark.parametrize("transform", [ErfTransform]) +def test_erf_logp(transform): + base_rv = pt.random.normal( + 0.5, 1, name="base_rv" + ) # Something not centered around 0 is usually better + rv = pt.erf(base_rv) + vv = rv.clone() + rv_logp = joint_logprob({rv: vv}) + + transform = transform() + expected_logp = joint_logprob({rv: transform.backward(vv)}) + transform.log_jac_det(vv) + + vv_test = np.array(0.25) # Arbitrary test value + np.testing.assert_almost_equal( + rv_logp.eval({vv: vv_test}), + expected_logp.eval({vv: vv_test}), + ) From f847786ab2757c516acda339538f980384693ee2 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Tue, 18 Apr 2023 23:33:52 +0100 Subject: [PATCH 07/15] test now running and passing on all but erfc and erfcx, will fix --- pymc/logprob/transforms.py | 5 ++-- tests/logprob/test_transforms.py | 40 ++++++++++++++++++++++++++------ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index ea6b6256a1..61354c4b14 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -42,6 +42,7 @@ import numpy as np import pytensor.tensor as pt +from pytensor import scan from pytensor.gradient import DisconnectedType, jacobian from pytensor.graph.basic import Apply, Node, Variable from pytensor.graph.features import AlreadyThere, Feature @@ -791,9 +792,9 @@ def calc_delta_x(value, prior_result): 2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi) ) - result, updates = pytensor.scan( + result, updates = scan( fn=calc_delta_x, - outputs_info=at.ones_like(x), + outputs_info=pt.ones_like(x), sequences=value, non_sequences=x, n_steps=k, diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 0ddf190e77..ae3e1544bf 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -49,9 +49,12 @@ from pymc.distributions.transforms import _default_transform, log, logodds from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob -from pymc.logprob.basic import factorized_joint_logprob +from pymc.logprob.basic import factorized_joint_logprob, logp from pymc.logprob.transforms import ( ChainedTransform, + CoshTransform, + ErfcTransform, + ErfcxTransform, ErfTransform, ExpTransform, IntervalTransform, @@ -60,6 +63,8 @@ LogTransform, RVTransform, ScaleTransform, + SinhTransform, + TanhTransform, TransformValuesMapping, TransformValuesRewrite, transformed_variable, @@ -992,20 +997,41 @@ def test_multivariate_transform(shift, scale): ) -@pytest.mark.parametrize("transform", [ErfTransform]) -def test_erf_logp(transform): +from pytensor.tensor import cosh, erf, erfc, erfcx, sinh, tanh + + +@pytest.mark.parametrize( + "pt_transform, transform", + [ + (erf, ErfTransform()), + (erfc, ErfcTransform()), + (erfcx, ErfcxTransform()), + (sinh, SinhTransform()), + (cosh, CoshTransform()), + (tanh, TanhTransform()), + ], +) +def test_erf_logp(pt_transform, transform): base_rv = pt.random.normal( 0.5, 1, name="base_rv" ) # Something not centered around 0 is usually better - rv = pt.erf(base_rv) + rv = pt_transform(base_rv) + vv = rv.clone() - rv_logp = joint_logprob({rv: vv}) + rv_logp = logp(rv, vv) - transform = transform() - expected_logp = joint_logprob({rv: transform.backward(vv)}) + transform.log_jac_det(vv) + expected_logp = logp(base_rv, transform.backward(vv)) + transform.log_jac_det(vv) vv_test = np.array(0.25) # Arbitrary test value np.testing.assert_almost_equal( rv_logp.eval({vv: vv_test}), expected_logp.eval({vv: vv_test}), ) + + +from pymc.testing import Rplusbig, Vector +from tests.distributions.test_transform import check_jacobian_det + + +def test_check_jac_det(): + check_jacobian_det(ErfTransform(), Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=False) From b5e15175afe8a85dd51081518a7db121ef5570f0 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Wed, 19 Apr 2023 19:36:27 +0100 Subject: [PATCH 08/15] test 2. now passing --- pymc/logprob/transforms.py | 5 ++--- tests/logprob/test_transforms.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 61354c4b14..cb3a005d04 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -413,7 +413,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0))) # The jacobian is used to ensure a value in the supported domain was provided - return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian) + return pt.switch(pt.isnan(input_logprob + jacobian), -np.inf, input_logprob + jacobian) @_logcdf.register(MeasurableTransform) @@ -795,8 +795,7 @@ def calc_delta_x(value, prior_result): result, updates = scan( fn=calc_delta_x, outputs_info=pt.ones_like(x), - sequences=value, - non_sequences=x, + non_sequences=value, n_steps=k, ) return result[-1] diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index ae3e1544bf..8489169b83 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -1025,7 +1025,9 @@ def test_erf_logp(pt_transform, transform): vv_test = np.array(0.25) # Arbitrary test value np.testing.assert_almost_equal( rv_logp.eval({vv: vv_test}), - expected_logp.eval({vv: vv_test}), + np.where( + np.isnan(expected_logp.eval({vv: vv_test})), -np.inf, expected_logp.eval({vv: vv_test}) + ), ) @@ -1034,4 +1036,11 @@ def test_erf_logp(pt_transform, transform): def test_check_jac_det(): - check_jacobian_det(ErfTransform(), Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=False) + check_jacobian_det( + ErfTransform(), + Vector(Rplusbig, 2), + pt.dvector, + [0.1, 0.1], + elemwise=True, + rv_var=pt.random.normal(0.5, 1, name="base_rv"), + ) From 3c5c733cfb6db82753f023e01629dc5fcf8489f8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 20 Apr 2023 09:56:45 +0200 Subject: [PATCH 09/15] Simplify nan to ninf in test --- tests/logprob/test_transforms.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 8489169b83..462554ae5d 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -997,18 +997,15 @@ def test_multivariate_transform(shift, scale): ) -from pytensor.tensor import cosh, erf, erfc, erfcx, sinh, tanh - - @pytest.mark.parametrize( "pt_transform, transform", [ - (erf, ErfTransform()), - (erfc, ErfcTransform()), - (erfcx, ErfcxTransform()), - (sinh, SinhTransform()), - (cosh, CoshTransform()), - (tanh, TanhTransform()), + (pt.erf, ErfTransform()), + (pt.erfc, ErfcTransform()), + (pt.erfcx, ErfcxTransform()), + (pt.sinh, SinhTransform()), + (pt.cosh, CoshTransform()), + (pt.tanh, TanhTransform()), ], ) def test_erf_logp(pt_transform, transform): @@ -1024,10 +1021,7 @@ def test_erf_logp(pt_transform, transform): vv_test = np.array(0.25) # Arbitrary test value np.testing.assert_almost_equal( - rv_logp.eval({vv: vv_test}), - np.where( - np.isnan(expected_logp.eval({vv: vv_test})), -np.inf, expected_logp.eval({vv: vv_test}) - ), + rv_logp.eval({vv: vv_test}), np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf) ) From b8bf86f40f9253dc66b8384399dd78f51f7358fc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 20 Apr 2023 10:09:15 +0200 Subject: [PATCH 10/15] Adapt default `RVTransform.log_jac_det` to univariate and vector transformations. --- pymc/logprob/transforms.py | 18 ++++++++++++------ tests/logprob/test_transforms.py | 27 +++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index cb3a005d04..da30cde92a 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -144,6 +144,8 @@ def remove_TransformedVariables(fgraph, node): class RVTransform(abc.ABC): + ndim_supp = None + @abc.abstractmethod def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: """Apply the transformation.""" @@ -157,12 +159,16 @@ def backward( def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: """Construct the log of the absolute value of the Jacobian determinant.""" - # jac = pt.reshape( - # gradient(pt.sum(self.backward(value, *inputs)), [value]), value.shape - # ) - # return pt.log(pt.abs(jac)) - phi_inv = self.backward(value, *inputs) - return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) + if self.ndim_supp not in (0, 1): + raise NotImplementedError( + f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}" + ) + if self.ndim_supp == 0: + jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape) + return pt.log(pt.abs(jac)) + else: + phi_inv = self.backward(value, *inputs) + return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) @node_rewriter(tracks=None) diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 462554ae5d..814e23fa11 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -333,6 +333,7 @@ def test_fallback_log_jac_det(ndim): class SquareTransform(RVTransform): name = "square" + ndim_supp = ndim def forward(self, value, *inputs): return pt.power(value, 2) @@ -342,13 +343,31 @@ def backward(self, value, *inputs): square_tr = SquareTransform() - value = pt.TensorType("float64", (None,) * ndim)("value") + value = pt.vector("value") value_tr = square_tr.forward(value) log_jac_det = square_tr.log_jac_det(value_tr) - test_value = np.full((2,) * ndim, 3) - expected_log_jac_det = -np.log(6) * test_value.size - assert np.isclose(log_jac_det.eval({value: test_value}), expected_log_jac_det) + test_value = np.r_[3, 4] + expected_log_jac_det = -np.log(2 * test_value) + if ndim == 1: + expected_log_jac_det = expected_log_jac_det.sum() + np.testing.assert_array_equal(log_jac_det.eval({value: test_value}), expected_log_jac_det) + + +@pytest.mark.parametrize("ndim", (None, 2)) +def test_fallback_log_jac_det_undefined_ndim(ndim): + class SquareTransform(RVTransform): + name = "square" + ndim_supp = ndim + + def forward(self, value, *inputs): + return pt.power(value, 2) + + def backward(self, value, *inputs): + return pt.sqrt(value) + + with pytest.raises(NotImplementedError, match=r"only implemented for ndim_supp in \(0, 1\)"): + SquareTransform().log_jac_det(0) def test_hierarchical_uniform_transform(): From bcc1eb9beada85be33abb946bb73a4947678b014 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 20 Apr 2023 10:19:58 +0200 Subject: [PATCH 11/15] Use np.testing in check_jacobian_det --- tests/distributions/test_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index a29ab16679..4acf463bfd 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -112,7 +112,7 @@ def check_jacobian_det( ) for yval in domain.vals: - close_to(actual_ljd(yval), computed_ljd(yval), tol) + np.testing.assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol) def test_simplex(): From 43f4150ddf5b4bdc8650c07a0bc8239f3098871f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 20 Apr 2023 10:20:39 +0200 Subject: [PATCH 12/15] Use default log_jac_det in ErfTransform --- pymc/logprob/transforms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index da30cde92a..a0a4b15fc0 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -758,6 +758,7 @@ def log_jac_det(self, value, *inputs): class ErfTransform(RVTransform): name = "erf" + ndim_supp = 0 def forward(self, value, *inputs): return pt.erf(value) @@ -765,9 +766,6 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.erfinv(value) - def log_jac_det(self, value, *inputs): - return pt.log(2 * pt.exp(-(value**2)) / (pt.sqrt(np.pi))) - class ErfcTransform(RVTransform): name = "erfc" From c429b2dfb4b5c2c25e27af6df160608477b69ed1 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Thu, 20 Apr 2023 21:39:34 +0100 Subject: [PATCH 13/15] tests fixed, required removing handwritten log_jac_det --- pymc/logprob/transforms.py | 25 +++++++++---------------- tests/logprob/test_transforms.py | 15 +++++++++++++-- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index a0a4b15fc0..12707bac65 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -719,6 +719,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li class SinhTransform(RVTransform): name = "sinh" + ndim_supp = 0 def forward(self, value, *inputs): return pt.sinh(value) @@ -726,12 +727,10 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.arcsinh(value) - def log_jac_det(self, value, *inputs): - return pt.log(pt.cosh(value)) - class CoshTransform(RVTransform): name = "cosh" + ndim_supp = 0 def forward(self, value, *inputs): return pt.cosh(value) @@ -739,12 +738,13 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.arccosh(value) - def log_jac_det(self, value, *inputs): - return pt.log(pt.sinh(value)) + # def log_jac_det(self, value, *inputs): + # return pt.log(pt.sinh(value)) class TanhTransform(RVTransform): name = "tanh" + ndim_supp = 0 def forward(self, value, *inputs): return pt.tanh(value) @@ -752,9 +752,6 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.arctanh(value) - def log_jac_det(self, value, *inputs): - return pt.log(1 / pt.cosh(value)) - class ErfTransform(RVTransform): name = "erf" @@ -769,6 +766,7 @@ def backward(self, value, *inputs): class ErfcTransform(RVTransform): name = "erfc" + ndim_supp = 0 def forward(self, value, *inputs): return pt.erfc(value) @@ -776,17 +774,15 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.erfcinv(value) - def log_jac_det(self, value, *inputs): - return pt.log(-2 * pt.exp(-(value**2)) / (pt.sqrt(np.pi))) - class ErfcxTransform(RVTransform): name = "erfcx" + ndim_supp = 0 def forward(self, value, *inputs): return pt.erfcx(value) - def backward(self, value, k=10): + def backward(self, value, *inputs): # computes the inverse of erfcx, this was adapted from # https://tinyurl.com/4mxfd3cz x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value))) @@ -800,13 +796,10 @@ def calc_delta_x(value, prior_result): fn=calc_delta_x, outputs_info=pt.ones_like(x), non_sequences=value, - n_steps=k, + n_steps=10, ) return result[-1] - def log_jac_det(self, value, *inputs): - return pt.log((2 * value * pt.exp(value**2) * pt.erfc(value) - 2.0) / pt.sqrt(np.pi)) - class LocTransform(RVTransform): name = "loc" diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 814e23fa11..3484b70ba4 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -1048,9 +1048,20 @@ def test_erf_logp(pt_transform, transform): from tests.distributions.test_transform import check_jacobian_det -def test_check_jac_det(): - check_jacobian_det( +@pytest.mark.parametrize( + "transform", + [ ErfTransform(), + ErfcTransform(), + ErfcxTransform(), + SinhTransform(), + CoshTransform(), + TanhTransform(), + ], +) +def test_check_jac_det(transform): + check_jacobian_det( + transform, Vector(Rplusbig, 2), pt.dvector, [0.1, 0.1], From 2f3087c0d1f78c7c7e498ddcfc2c690fe4b08140 Mon Sep 17 00:00:00 2001 From: Luke Lewis-Borrell <35955390+LukeLB@users.noreply.github.com> Date: Fri, 28 Apr 2023 09:36:12 +0100 Subject: [PATCH 14/15] Update pymc/logprob/transforms.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/logprob/transforms.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 12707bac65..7c76ea6e58 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -738,8 +738,6 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.arccosh(value) - # def log_jac_det(self, value, *inputs): - # return pt.log(pt.sinh(value)) class TanhTransform(RVTransform): From 11d41db2fb87a51e55128f3e9e902a7dab12d370 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Fri, 28 Apr 2023 09:42:50 +0100 Subject: [PATCH 15/15] reverted change to measurable_transform_logprob --- pymc/logprob/transforms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 7c76ea6e58..56280fd302 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -419,7 +419,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0))) # The jacobian is used to ensure a value in the supported domain was provided - return pt.switch(pt.isnan(input_logprob + jacobian), -np.inf, input_logprob + jacobian) + return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian) @_logcdf.register(MeasurableTransform) @@ -739,7 +739,6 @@ def backward(self, value, *inputs): return pt.arccosh(value) - class TanhTransform(RVTransform): name = "tanh" ndim_supp = 0