Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorConstant

from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
Distribution,
SymbolicRandomVariable,
Expand All @@ -29,6 +31,7 @@
implicit_size_from_params,
rv_size_is_none,
)
from pymc.logprob.abstract import _logcdf
from pymc.util import check_dist_not_registered


Expand Down Expand Up @@ -156,3 +159,30 @@ def support_point_censored(op, rv, dist, lower, upper):
)
support_point = pt.full_like(dist, support_point)
return support_point


@_logcdf.register(CensoredRV)
def censored_logcdf(op, value, *inputs, **kwargs):
base_rv, lower, upper = inputs

base_rv_op = base_rv.owner.op
base_rv_inputs = base_rv.owner.inputs
logcdf_val = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)

is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))

if is_lower_bounded:
logcdf_val = pt.switch(pt.lt(value, lower), -np.inf, logcdf_val)

if is_upper_bounded:
logcdf_val = pt.switch(pt.ge(value, upper), 0.0, logcdf_val)

if is_lower_bounded and is_upper_bounded:
logcdf_val = check_parameters(
logcdf_val,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)

return logcdf_val
93 changes: 92 additions & 1 deletion tests/distributions/test_censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

import numpy as np
import pytest
import scipy as sp

import pymc as pm

from pymc import logp
from pymc import logcdf, logp
from pymc.distributions.shape_utils import change_dist_size


Expand Down Expand Up @@ -126,3 +127,93 @@ def test_censored_categorical(self):
logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
[0, 0, 0.3, 0.2, 0.5, 0, 0],
)

def test_censored_logcdf_continuous(self):
norm = pm.Normal.dist(0, 1)
eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
expected_logcdf_uncensored = sp.stats.norm.logcdf(eval_points)

match_str = "divide by zero encountered in log|invalid value encountered in subtract"

# No censoring
censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
with pytest.warns(RuntimeWarning, match=match_str):
censored_eval = logcdf(censored_norm, eval_points).eval()
np.testing.assert_allclose(censored_eval, expected_logcdf_uncensored)

# Left censoring
censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
expected_left = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
with pytest.warns(RuntimeWarning, match=match_str):
censored_eval = logcdf(censored_norm, eval_points).eval()
np.testing.assert_allclose(
censored_eval,
expected_left,
rtol=1e-6,
)

# Right censoring
censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
expected_right = np.where(eval_points >= 1, 0.0, expected_logcdf_uncensored)
with pytest.warns(RuntimeWarning, match=match_str):
censored_eval = logcdf(censored_norm, eval_points).eval()
np.testing.assert_allclose(
censored_eval,
expected_right,
rtol=1e-6,
)

# Interval censoring
censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
expected_interval = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
expected_interval = np.where(eval_points >= 1, 0.0, expected_interval)
with pytest.warns(RuntimeWarning, match=match_str):
censored_eval = logcdf(censored_norm, eval_points).eval()
np.testing.assert_allclose(
censored_eval,
expected_interval,
rtol=1e-6,
)

def test_censored_logcdf_discrete(self):
probs = [0.1, 0.2, 0.2, 0.3, 0.2]
cat = pm.Categorical.dist(probs)
eval_points = np.array([-1, 0, 1, 2, 3, 4, 5])

cdf = np.cumsum(probs)
log_cdf_base = np.log(cdf)
expected_logcdf_uncensored = np.full_like(eval_points, -np.inf, dtype=float)
expected_logcdf_uncensored[1:6] = log_cdf_base
expected_logcdf_uncensored[6] = 0.0

# No censoring
censored_cat = pm.Censored.dist(cat, lower=None, upper=None)
np.testing.assert_allclose(
logcdf(censored_cat, eval_points).eval(),
expected_logcdf_uncensored,
)

# Left censoring
censored_cat = pm.Censored.dist(cat, lower=1, upper=None)
expected_left = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored)
np.testing.assert_allclose(
logcdf(censored_cat, eval_points).eval(),
expected_left,
)

# Right censoring
censored_cat = pm.Censored.dist(cat, lower=None, upper=3)
expected_right = np.where(eval_points >= 3, 0.0, expected_logcdf_uncensored)
np.testing.assert_allclose(
logcdf(censored_cat, eval_points).eval(),
expected_right,
)

# Interval censoring
censored_cat = pm.Censored.dist(cat, lower=1, upper=3)
expected_interval = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored)
expected_interval = np.where(eval_points >= 3, 0.0, expected_interval)
np.testing.assert_allclose(
logcdf(censored_cat, eval_points).eval(),
expected_interval,
)
Loading