Skip to content

Commit

Permalink
rework KL
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Dec 15, 2023
1 parent 50dc421 commit 594447a
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 18 deletions.
6 changes: 6 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@
)
from pymc.distributions.truncated import Truncated

# no dispatched stats are being initialized
# isort: off
import pymc.distributions._stats

# isort: on

__all__ = [
"Uniform",
"Flat",
Expand Down
14 changes: 14 additions & 0 deletions pymc/distributions/_stats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pymc.distributions._stats import kl_divergence
36 changes: 36 additions & 0 deletions pymc/distributions/_stats/kl_divergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import pytensor.tensor as pt

from pymc.distributions.continuous import Normal
from pymc.logprob.abstract import _kl_div


@_kl_div.register(Normal, Normal)
def _normal_normal_kl(
q_dist: Normal,
p_dist: Normal,
q_inputs: List[pt.TensorVariable],
p_inputs: List[pt.TensorVariable],
):
_, _, _, q_mu, q_sigma = q_inputs
_, _, _, p_mu, p_sigma = p_inputs
diff_log_scale = pt.log(q_sigma) - pt.log(p_sigma)
return (

Check warning on line 32 in pymc/distributions/_stats/kl_divergence.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/_stats/kl_divergence.py#L29-L32

Added lines #L29 - L32 were not covered by tests
0.5 * (q_mu / p_sigma - p_mu / p_sigma) ** 2
+ 0.5 * pt.expm1(2.0 * diff_log_scale)
- diff_log_scale
)
19 changes: 1 addition & 18 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import _kl_div, _logprob_helper
from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import icdf

try:
Expand Down Expand Up @@ -551,23 +551,6 @@ def icdf(value, mu, sigma):
)


@_kl_div.register(Normal, Normal)
def _normal_normal_kl(
q_dist: Normal,
p_dist: Normal,
q_inputs: List[pt.TensorVariable],
p_inputs: List[pt.TensorVariable],
):
_, _, _, q_mu, q_sigma = q_inputs
_, _, _, p_mu, p_sigma = p_inputs
diff_log_scale = pt.log(q_sigma) - pt.log(p_sigma)
return (
0.5 * (q_mu / p_sigma - p_mu / p_sigma) ** 2
+ 0.5 * pt.expm1(2.0 * diff_log_scale)
- diff_log_scale
)


class TruncatedNormalRV(RandomVariable):
name = "truncated_normal"
ndim_supp = 0
Expand Down
File renamed without changes.

0 comments on commit 594447a

Please sign in to comment.