Expand Up
@@ -23,7 +23,7 @@
import pytest
import scipy .stats as st
from pytensor import scan
from pytensor import scan , shared
from pytensor .tensor import TensorVariable
import pymc as pm
Expand All
@@ -42,14 +42,16 @@
CustomDist ,
CustomDistRV ,
CustomSymbolicDistRV ,
PartialObservedRV ,
SymbolicRandomVariable ,
_moment ,
create_partial_observed_rv ,
moment ,
)
from pymc .distributions .shape_utils import change_dist_size , rv_size_is_none , to_tuple
from pymc .distributions .transforms import log
from pymc .exceptions import BlockModelAccessError
from pymc .logprob .basic import logcdf , logp
from pymc .logprob .basic import conditional_logp , logcdf , logp
from pymc .model import Deterministic , Model
from pymc .pytensorf import collect_default_updates
from pymc .sampling import draw , sample
Expand Down
Expand Up
@@ -700,3 +702,225 @@ def test_dtype(self, floatX):
assert pm .DiracDelta .dist (2 ** 16 ).dtype == "int32"
assert pm .DiracDelta .dist (2 ** 32 ).dtype == "int64"
assert pm .DiracDelta .dist (2.0 ).dtype == floatX
class TestPartialObservedRV :
@pytest .mark .parametrize ("symbolic_rv" , (False , True ))
def test_univariate (self , symbolic_rv ):
data = np .array ([0.25 , 0.5 , 0.25 ])
mask = np .array ([False , False , True ])
rv = pm .Normal .dist ([1 , 2 , 3 ])
if symbolic_rv :
# We use a Censored Normal so that PartialObservedRV is needed,
# but don't use the bounds for testing the logp
rv = pm .Censored .dist (rv , lower = - 100 , upper = 100 )
(obs_rv , obs_mask ), (unobs_rv , unobs_mask ), joined_rv = create_partial_observed_rv (rv , mask )
# Test types
if symbolic_rv :
assert isinstance (obs_rv .owner .op , PartialObservedRV )
assert isinstance (unobs_rv .owner .op , PartialObservedRV )
else :
assert isinstance (obs_rv .owner .op , Normal )
assert isinstance (unobs_rv .owner .op , Normal )
# Tesh shapes
assert tuple (obs_rv .shape .eval ()) == (2 ,)
assert tuple (unobs_rv .shape .eval ()) == (1 ,)
assert tuple (joined_rv .shape .eval ()) == (3 ,)
# Test logp
logp = conditional_logp (
{obs_rv : pt .as_tensor (data [~ mask ]), unobs_rv : pt .as_tensor (data [mask ])}
)
obs_logp , unobs_logp = pytensor .function ([], list (logp .values ()))()
np .testing .assert_allclose (obs_logp , st .norm ([1 , 2 ]).logpdf ([0.25 , 0.5 ]))
np .testing .assert_allclose (unobs_logp , st .norm ([3 ]).logpdf ([0.25 ]))
@pytest .mark .parametrize ("obs_component_selected" , (True , False ))
def test_multivariate_constant_mask_separable (self , obs_component_selected ):
if obs_component_selected :
mask = np .zeros ((1 , 4 ), dtype = bool )
else :
mask = np .ones ((1 , 4 ), dtype = bool )
obs_data = np .array ([[0.1 , 0.4 , 0.1 , 0.4 ]])
unobs_data = np .array ([[0.4 , 0.1 , 0.4 , 0.1 ]])
rv = pm .Dirichlet .dist ([1 , 2 , 3 , 4 ], shape = (1 , 4 ))
(obs_rv , obs_mask ), (unobs_rv , unobs_mask ), joined_rv = create_partial_observed_rv (rv , mask )
# Test types
assert isinstance (obs_rv .owner .op , pm .Dirichlet )
assert isinstance (unobs_rv .owner .op , pm .Dirichlet )
# Test shapes
if obs_component_selected :
expected_obs_shape = (1 , 4 )
expected_unobs_shape = (0 , 4 )
else :
expected_obs_shape = (0 , 4 )
expected_unobs_shape = (1 , 4 )
assert tuple (obs_rv .shape .eval ()) == expected_obs_shape
assert tuple (unobs_rv .shape .eval ()) == expected_unobs_shape
assert tuple (joined_rv .shape .eval ()) == (1 , 4 )
# Test logp
logp = conditional_logp (
{
obs_rv : pt .as_tensor (obs_data )[obs_mask ],
unobs_rv : pt .as_tensor (unobs_data )[unobs_mask ],
}
)
obs_logp , unobs_logp = pytensor .function ([], list (logp .values ()))()
if obs_component_selected :
expected_obs_logp = pm .logp (rv , obs_data ).eval ()
expected_unobs_logp = []
else :
expected_obs_logp = []
expected_unobs_logp = pm .logp (rv , unobs_data ).eval ()
np .testing .assert_allclose (obs_logp , expected_obs_logp )
np .testing .assert_allclose (unobs_logp , expected_unobs_logp )
def test_multivariate_constant_mask_unseparable (self ):
mask = pt .constant (np .array ([[True , True , False , False ]]))
obs_data = np .array ([[0.1 , 0.4 , 0.1 , 0.4 ]])
unobs_data = np .array ([[0.4 , 0.1 , 0.4 , 0.1 ]])
rv = pm .Dirichlet .dist ([1 , 2 , 3 , 4 ], shape = (1 , 4 ))
(obs_rv , obs_mask ), (unobs_rv , unobs_mask ), joined_rv = create_partial_observed_rv (rv , mask )
# Test types
assert isinstance (obs_rv .owner .op , PartialObservedRV )
assert isinstance (unobs_rv .owner .op , PartialObservedRV )
# Test shapes
assert tuple (obs_rv .shape .eval ()) == (2 ,)
assert tuple (unobs_rv .shape .eval ()) == (2 ,)
assert tuple (joined_rv .shape .eval ()) == (1 , 4 )
# Test logp
logp = conditional_logp (
{
obs_rv : pt .as_tensor (obs_data )[obs_mask ],
unobs_rv : pt .as_tensor (unobs_data )[unobs_mask ],
}
)
obs_logp , unobs_logp = pytensor .function ([], list (logp .values ()))()
# For non-separable cases the logp always shows up in the observed variable
expected_logp = pm .logp (rv , [[0.1 , 0.4 , 0.4 , 0.1 ]]).eval ()
np .testing .assert_almost_equal (obs_logp , expected_logp )
np .testing .assert_array_equal (unobs_logp , [])
def test_multivariate_shared_mask_separable (self ):
mask = shared (np .array ([True ]))
obs_data = np .array ([[0.1 , 0.4 , 0.1 , 0.4 ]])
unobs_data = np .array ([[0.4 , 0.1 , 0.4 , 0.1 ]])
rv = pm .Dirichlet .dist ([1 , 2 , 3 , 4 ], shape = (1 , 4 ))
(obs_rv , obs_mask ), (unobs_rv , unobs_mask ), joined_rv = create_partial_observed_rv (rv , mask )
# Test types
# Multivariate RVs with shared masks on the last component are always unseparable.
assert isinstance (obs_rv .owner .op , pm .Dirichlet )
assert isinstance (unobs_rv .owner .op , pm .Dirichlet )
# Test shapes
assert tuple (obs_rv .shape .eval ()) == (0 , 4 )
assert tuple (unobs_rv .shape .eval ()) == (1 , 4 )
assert tuple (joined_rv .shape .eval ()) == (1 , 4 )
# Test logp
logp = conditional_logp (
{
obs_rv : pt .as_tensor (obs_data )[obs_mask ],
unobs_rv : pt .as_tensor (unobs_data )[unobs_mask ],
}
)
logp_fn = pytensor .function ([], list (logp .values ()))
obs_logp , unobs_logp = logp_fn ()
expected_logp = pm .logp (rv , unobs_data ).eval ()
np .testing .assert_almost_equal (obs_logp , [])
np .testing .assert_array_equal (unobs_logp , expected_logp )
# Test that we can update a shared mask
mask .set_value (np .array ([False ]))
assert tuple (obs_rv .shape .eval ()) == (1 , 4 )
assert tuple (unobs_rv .shape .eval ()) == (0 , 4 )
new_expected_logp = pm .logp (rv , obs_data ).eval ()
assert not np .isclose (expected_logp , new_expected_logp ) # Otherwise test is weak
obs_logp , unobs_logp = logp_fn ()
np .testing .assert_almost_equal (obs_logp , new_expected_logp )
np .testing .assert_array_equal (unobs_logp , [])
def test_multivariate_shared_mask_unseparable (self ):
# Even if the mask is initially not mixing support dims,
# it could later be changed in a way that does!
mask = shared (np .array ([[True , True , True , True ]]))
obs_data = np .array ([[0.1 , 0.4 , 0.1 , 0.4 ]])
unobs_data = np .array ([[0.4 , 0.1 , 0.4 , 0.1 ]])
rv = pm .Dirichlet .dist ([1 , 2 , 3 , 4 ], shape = (1 , 4 ))
(obs_rv , obs_mask ), (unobs_rv , unobs_mask ), joined_rv = create_partial_observed_rv (rv , mask )
# Test types
# Multivariate RVs with shared masks on the last component are always unseparable.
assert isinstance (obs_rv .owner .op , PartialObservedRV )
assert isinstance (unobs_rv .owner .op , PartialObservedRV )
# Test shapes
assert tuple (obs_rv .shape .eval ()) == (0 ,)
assert tuple (unobs_rv .shape .eval ()) == (4 ,)
assert tuple (joined_rv .shape .eval ()) == (1 , 4 )
# Test logp
logp = conditional_logp (
{
obs_rv : pt .as_tensor (obs_data )[obs_mask ],
unobs_rv : pt .as_tensor (unobs_data )[unobs_mask ],
}
)
logp_fn = pytensor .function ([], list (logp .values ()))
obs_logp , unobs_logp = logp_fn ()
# For non-separable cases the logp always shows up in the observed variable
# Even in this case where all entries come from an unobserved component
expected_logp = pm .logp (rv , unobs_data ).eval ()
np .testing .assert_almost_equal (obs_logp , expected_logp )
np .testing .assert_array_equal (unobs_logp , [])
# Test that we can update a shared mask
mask .set_value (np .array ([[False , False , True , True ]]))
assert tuple (obs_rv .shape .eval ()) == (2 ,)
assert tuple (unobs_rv .shape .eval ()) == (2 ,)
new_expected_logp = pm .logp (rv , [0.1 , 0.4 , 0.4 , 0.1 ]).eval ()
assert not np .isclose (expected_logp , new_expected_logp ) # Otherwise test is weak
obs_logp , unobs_logp = logp_fn ()
np .testing .assert_almost_equal (obs_logp , new_expected_logp )
np .testing .assert_array_equal (unobs_logp , [])
def test_moment (self ):
x = pm .GaussianRandomWalk .dist (init_dist = pm .Normal .dist (- 5 ), mu = 1 , steps = 9 )
ref_moment = moment (x ).eval ()
assert not np .allclose (ref_moment [::2 ], ref_moment [1 ::2 ]) # Otherwise test is weak
(obs_x , _ ), (unobs_x , _ ), _ = create_partial_observed_rv (
x , mask = np .array ([False , True ] * 5 )
)
np .testing .assert_allclose (moment (obs_x ).eval (), ref_moment [::2 ])
np .testing .assert_allclose (moment (unobs_x ).eval (), ref_moment [1 ::2 ])
def test_wrong_mask (self ):
rv = pm .Normal .dist (shape = (5 ,))
invalid_mask = np .array ([0 , 2 , 4 ])
with pytest .raises (ValueError , match = "mask must be an array or tensor of boolean dtype" ):
create_partial_observed_rv (rv , invalid_mask )
invalid_mask = np .zeros ((1 , 5 ), dtype = bool )
with pytest .raises (ValueError , match = "mask can't have more dims than rv" ):
create_partial_observed_rv (rv , invalid_mask )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤯