From 74b3fbbbd82387099c6ed68b8a520a351c8718f5 Mon Sep 17 00:00:00 2001 From: Mr-Imperium Date: Tue, 25 Nov 2025 21:57:58 +0545 Subject: [PATCH 1/2] Add transform_labels to Transform, Simplex, SumTo1, and ZeroSum #7907 Implemented transform_labels method to handle coordinate transformations. - Base Transform: Identity (returns labels unchanged). - SimplexTransform: Drops the last label. - SumTo1: Drops the last label. - ZeroSumTransform: Implemented logic to drop labels corresponding to zerosum_axes. Verified via existing test suite and manual verification script. --- pymc/distributions/transforms.py | 9 +++++++++ pymc/logprob/transforms.py | 10 +++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 38bc0e840a..8baca143de 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import singledispatch +from collections.abc import Sequence import numpy as np import pytensor.tensor as pt @@ -134,6 +135,10 @@ def log_jac_det(self, value, *inputs): y = pt.zeros(value.shape) return pt.sum(y, axis=-1) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Drop the last label since SumTo1 reduces dimensionality by 1.""" + return labels[:-1] + class CholeskyCovPacked(Transform): """Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale.""" @@ -311,6 +316,10 @@ def backward(self, value, *rv_inputs): def log_jac_det(self, value, *rv_inputs): return pt.constant(0.0) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Drop the last label since ZeroSumTransform reduces dimensionality by 1.""" + return labels[:-1] + log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 8d2bbacd26..6c6c7910da 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -35,7 +35,7 @@ # SOFTWARE. import abc -from collections.abc import Callable +from collections.abc import Callable, Sequence import numpy as np import pytensor.tensor as pt @@ -154,6 +154,10 @@ def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: phi_inv = self.backward(value, *inputs) return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Mutate user-provided coordinates associated with the variable to label transformed values returned by this class.""" + return labels + def __str__(self): """Return a string representation of the object.""" return f"{self.__class__.__name__}" @@ -1006,6 +1010,10 @@ def log_jac_det(self, value, *inputs): res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded) return pt.sum(res, -1) + def transform_labels(self, labels: Sequence[str]) -> Sequence[str]: + """Drop the last label since Simplex reduces dimensionality by 1.""" + return labels[:-1] + class CircularTransform(Transform): name = "circular" From 16d948dbe2ed3e3a27eb57b83e7d42f6efae2f4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 16:28:59 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 8baca143de..76db660f14 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import singledispatch from collections.abc import Sequence +from functools import singledispatch import numpy as np import pytensor.tensor as pt