Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbolic conditions allowed in sympy.stats.frv #16908

Merged
merged 12 commits into from Jun 10, 2019
@@ -11,13 +11,15 @@

from itertools import product

from sympy import (Basic, Symbol, symbols, cacheit, sympify, Mul,
And, Or, Tuple, Piecewise, Eq, Lambda, exp, I, Dummy, nan)
from sympy import (Basic, Symbol, symbols, cacheit, sympify, Mul, Add,
And, Or, Tuple, Piecewise, Eq, Lambda, exp, I, Dummy, nan, Rational)
from sympy.sets.sets import FiniteSet
from sympy.core.relational import Relational
from sympy.stats.rv import (RandomDomain, ProductDomain, ConditionalDomain,
PSpace, IndependentProductPSpace, SinglePSpace, random_symbols,
sumsets, rv_subs, NamedArgsMixin)
from sympy.core.containers import Dict
from sympy.core.logic import Logic
import random

class FiniteDensity(dict):
@@ -143,14 +145,6 @@ def __new__(cls, domain, condition):
if condition is True:
return domain
cond = rv_subs(condition)
# Check that we aren't passed a condition like die1 == z
# where 'z' is a symbol that we don't know about
# We will never be able to test this equality through iteration
if not cond.free_symbols.issubset(domain.free_symbols):
raise ValueError('Condition "%s" contains foreign symbols \n%s.\n' % (
condition, tuple(cond.free_symbols - domain.free_symbols)) +
"Will be unable to iterate using this condition")

return Basic.__new__(cls, domain, cond)


@@ -166,7 +160,7 @@ def _test(self, elem):
return val
elif val.is_Equality:
return val.lhs == val.rhs
raise ValueError("Undeciable if %s" % str(val))
raise ValueError("Undecidable if %s" % str(val))

def __contains__(self, other):
return other in self.fulldomain and self._test(other)
@@ -309,8 +303,15 @@ def compute_moment_generating_function(self, expr):
def compute_expectation(self, expr, rvs=None, **kwargs):
rvs = rvs or self.values
expr = expr.xreplace(dict((rs, rs.symbol) for rs in rvs))
return sum([expr.xreplace(dict(elem)) * self.prob_of(elem)
for elem in self.domain])
probs = [self.prob_of(elem) for elem in self.domain]
if isinstance(expr, (Logic, Relational)):
parse_domain = [tuple(elem)[0][1] for elem in self.domain]
bools = [expr.xreplace(dict(elem)) for elem in self.domain]
else:
parse_domain = [expr.xreplace(dict(elem)) for elem in self.domain]
bools = [True for elem in self.domain]
return sum([Piecewise((prob * elem, blv), (0, True))
for prob, elem, blv in zip(probs, parse_domain, bools)])

def compute_quantile(self, expr):
cdf = self.compute_cdf(expr)
@@ -322,7 +323,16 @@ def compute_quantile(self, expr):

def probability(self, condition):
cond_symbols = frozenset(rs.symbol for rs in random_symbols(condition))
assert cond_symbols.issubset(self.symbols)
cond = rv_subs(condition)
if not cond_symbols.issubset(self.symbols):
raise ValueError("Cannot compare foriegn random symbols, %s"
%(str(cond_symbols - self.symbols)))
if isinstance(condition, Relational) and \
(not cond.free_symbols.issubset(self.domain.free_symbols)):
rv = condition.lhs if isinstance(condition.rhs, Symbol) else condition.rhs
return sum(Piecewise(
(self.prob_of(elem), condition.subs(rv, list(elem)[0][1])),
(0, True)) for elem in self.domain)
return sum(self.prob_of(elem) for elem in self.where(condition))

def conditional_space(self, condition):
@@ -137,7 +137,7 @@ def check(sides):
@cacheit
def dict(self):
as_int(self.sides) # Check that self.sides can be converted to an integer
return super(DieDistribution, self).dict
return dict((k, Rational(1, self.sides)) for k in self.set)

@property
def set(self):
@@ -319,3 +319,16 @@ def test_FinitePSpace():
X = Die('X', 6)
space = pspace(X)
assert space.density == DieDistribution(6)

def test_symbolic_conditions():
B = Bernoulli('B', S(1)/4)
D = Die('D', 4)
b, n = symbols('b, n')
Y = P(Eq(B, b))
Z = E(D > n)
assert Y == \
Piecewise((S(1)/4, Eq(b, 1)), (0, True)) + \
Piecewise((S(3)/4, Eq(b, 0)), (0, True))
assert Z == \
Piecewise((S(1)/4, n < 1), (0, True)) + Piecewise((S(1)/2, n < 2), (0, True)) + \
Piecewise((S(3)/4, n < 3), (0, True)) + Piecewise((S(1), n < 4), (0, True))
@@ -166,8 +166,6 @@ def test_dependence():
XX, YY = given(Tuple(X, Y), Eq(X + Y, 3))
assert dependent(XX, YY)


@XFAIL
def test_dependent_finite():
X, Y = Die('X'), Die('Y')
# Dependence testing requires symbolic conditions which currently break
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.