From 594447a52b5673dd6983d49baae4e1c3e0271872 Mon Sep 17 00:00:00 2001
From: Max Kochurov
Date: Fri, 15 Dec 2023 12:42:45 +0000
Subject: [PATCH] rework KL

---
 pymc/distributions/__init__.py             |  6 ++++
 pymc/distributions/_stats/__init__.py      | 14 ++++++++
 pymc/distributions/_stats/kl_divergence.py | 36 +++++++++++++++++++
 pymc/distributions/continuous.py           | 19 +---------
 .../stats}/test_kl_divergence.py           |  0
 5 files changed, 57 insertions(+), 18 deletions(-)
 create mode 100644 pymc/distributions/_stats/__init__.py
 create mode 100644 pymc/distributions/_stats/kl_divergence.py
 rename tests/{logprob => distributions/stats}/test_kl_divergence.py (100%)

diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py
index bc3d9c7863..6c0d5de326 100644
--- a/pymc/distributions/__init__.py
+++ b/pymc/distributions/__init__.py
@@ -114,6 +114,12 @@
 )
 from pymc.distributions.truncated import Truncated
 
+# importing this module registers the dispatched stats (e.g. KL divergences)
+# isort: off
+import pymc.distributions._stats
+
+# isort: on
+
 __all__ = [
     "Uniform",
     "Flat",
diff --git a/pymc/distributions/_stats/__init__.py b/pymc/distributions/_stats/__init__.py
new file mode 100644
index 0000000000..1c70431eea
--- /dev/null
+++ b/pymc/distributions/_stats/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The PyMC Developers
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pymc.distributions._stats import kl_divergence
diff --git a/pymc/distributions/_stats/kl_divergence.py b/pymc/distributions/_stats/kl_divergence.py
new file mode 100644
index 0000000000..58a57bdec5
--- /dev/null
+++ b/pymc/distributions/_stats/kl_divergence.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The PyMC Developers
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List
+
+import pytensor.tensor as pt
+
+from pymc.distributions.continuous import Normal
+from pymc.logprob.abstract import _kl_div
+
+
+@_kl_div.register(Normal, Normal)
+def _normal_normal_kl(
+    q_dist: Normal,
+    p_dist: Normal,
+    q_inputs: List[pt.TensorVariable],
+    p_inputs: List[pt.TensorVariable],
+):
+    _, _, _, q_mu, q_sigma = q_inputs
+    _, _, _, p_mu, p_sigma = p_inputs
+    diff_log_scale = pt.log(q_sigma) - pt.log(p_sigma)
+    return (
+        0.5 * (q_mu / p_sigma - p_mu / p_sigma) ** 2
+        + 0.5 * pt.expm1(2.0 * diff_log_scale)
+        - diff_log_scale
+    )
diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index 00c154de86..2b41cb46d8 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -56,7 +56,7 @@
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorConstant
 
-from pymc.logprob.abstract import _kl_div, _logprob_helper
+from pymc.logprob.abstract import _logprob_helper
 from pymc.logprob.basic import icdf
 
 try:
@@ -551,23 +551,6 @@ def icdf(value, mu, sigma):
         )
 
 
-@_kl_div.register(Normal, Normal)
-def _normal_normal_kl(
-    q_dist: Normal,
-    p_dist: Normal,
-    q_inputs: List[pt.TensorVariable],
-    p_inputs: List[pt.TensorVariable],
-):
-    _, _, _, q_mu, q_sigma = q_inputs
-    _, _, _, p_mu, p_sigma = p_inputs
-    diff_log_scale = pt.log(q_sigma) - pt.log(p_sigma)
-    return (
-        0.5 * (q_mu / p_sigma - p_mu / p_sigma) ** 2
-        + 0.5 * pt.expm1(2.0 * diff_log_scale)
-        - diff_log_scale
-    )
-
-
 class TruncatedNormalRV(RandomVariable):
     name = "truncated_normal"
     ndim_supp = 0
diff --git a/tests/logprob/test_kl_divergence.py b/tests/distributions/stats/test_kl_divergence.py
similarity index 100%
rename from tests/logprob/test_kl_divergence.py
rename to tests/distributions/stats/test_kl_divergence.py
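
Note for reviewers: the expression in _normal_normal_kl is an algebraic rearrangement of
the textbook closed form KL(q || p) = log(p_sigma / q_sigma)
+ (q_sigma^2 + (q_mu - p_mu)^2) / (2 * p_sigma^2) - 1/2 for two univariate normals.
Below is a standalone sanity check of that identity, NumPy only and not part of the
patch; the parameter values and variable names (kl_reference, kl_patch) are arbitrary
illustrations, with np.expm1 standing in for pt.expm1.

    # Verify the rearranged KL expression against the textbook closed form.
    import numpy as np

    q_mu, q_sigma = 0.3, 1.5   # parameters of q = Normal(q_mu, q_sigma)
    p_mu, p_sigma = -1.0, 0.7  # parameters of p = Normal(p_mu, p_sigma)

    # Textbook form of KL(q || p) for univariate normals
    kl_reference = (
        np.log(p_sigma / q_sigma)
        + (q_sigma**2 + (q_mu - p_mu) ** 2) / (2 * p_sigma**2)
        - 0.5
    )

    # Rearranged form used in the patch: expm1(2 * diff_log_scale) replaces
    # q_sigma**2 / p_sigma**2 - 1, which keeps precision when the two scales
    # are close (diff_log_scale near zero)
    diff_log_scale = np.log(q_sigma) - np.log(p_sigma)
    kl_patch = (
        0.5 * (q_mu / p_sigma - p_mu / p_sigma) ** 2
        + 0.5 * np.expm1(2.0 * diff_log_scale)
        - diff_log_scale
    )

    assert np.isclose(kl_reference, kl_patch)

The expm1-based form avoids the cancellation that the naive ratio form suffers when
q_sigma is approximately equal to p_sigma, which is presumably why the patch keeps the
rearranged expression rather than the textbook one.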