Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handmade distributions have the option not to check for validity #20120

Merged
merged 9 commits into from Oct 5, 2020
4 changes: 2 additions & 2 deletions sympy/stats/crv.py
Expand Up @@ -333,7 +333,7 @@ def compute_characteristic_function(self, **kwargs):
"""
x, t = symbols('x, t', real=True, cls=Dummy)
pdf = self.pdf(x)
cf = integrate(exp(I*t*x)*pdf, (x, -oo, oo))
cf = integrate(exp(I*t*x)*pdf, (x, self.set))
return Lambda(t, cf)

def _characteristic_function(self, t):
Expand All @@ -355,7 +355,7 @@ def compute_moment_generating_function(self, **kwargs):
"""
x, t = symbols('x, t', real=True, cls=Dummy)
pdf = self.pdf(x)
mgf = integrate(exp(t * x) * pdf, (x, -oo, oo))
mgf = integrate(exp(t * x) * pdf, (x, self.set))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

return Lambda(t, mgf)

def _moment_generating_function(self, t):
Expand Down
18 changes: 13 additions & 5 deletions sympy/stats/crv_types.py
Expand Up @@ -127,10 +127,11 @@
def _(x):
return any([is_random(i) for i in x])

def rv(symbol, cls, args):
def rv(symbol, cls, args, **kwargs):
args = list(map(sympify, args))
dist = cls(*args)
dist.check(*args)
if kwargs.pop('check', True):
Maelstrom6 marked this conversation as resolved.
Show resolved Hide resolved
dist.check(*args)
pspace = SingleContinuousPSpace(symbol, dist)
if any(is_random(arg) for arg in args):
from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution
Expand All @@ -152,10 +153,10 @@ def set(self):
def check(pdf, set):
x = Dummy('x')
val = integrate(pdf(x), (x, set))
_value_check(val == S.One, "The pdf on the given set is incorrect.")
_value_check(Eq(val, 1) != S.false, "The pdf on the given set is incorrect.")


def ContinuousRV(symbol, density, set=Interval(-oo, oo)):
def ContinuousRV(symbol, density, set=Interval(-oo, oo), **kwargs):
"""
Create a Continuous Random Variable given the following:

Expand All @@ -168,6 +169,11 @@ def ContinuousRV(symbol, density, set=Interval(-oo, oo)):
Represents probability density function.
set : set/Interval
Represents the region where the pdf is valid, by default is real line.
check : bool
If True, it will check whether the given density
integrates to 1 over the given set. If False, it
will not perform this check. Default is False.


Returns
=======
Expand Down Expand Up @@ -196,7 +202,9 @@ def ContinuousRV(symbol, density, set=Interval(-oo, oo)):
"""
pdf = Piecewise((density, set.as_relational(symbol)), (0, True))
pdf = Lambda(symbol, pdf)
return rv(symbol.name, ContinuousDistributionHandmade, (pdf, set))
# have a default of False while `rv` should have a default of True
kwargs['check'] = kwargs.pop('check', False)
return rv(symbol.name, ContinuousDistributionHandmade, (pdf, set), **kwargs)

########################################
# Continuous Probability Distributions #
Expand Down
19 changes: 13 additions & 6 deletions sympy/stats/drv_types.py
Expand Up @@ -17,7 +17,7 @@

from sympy import (Basic, factorial, exp, S, sympify, I, zeta, polylog, log, beta,
hyper, binomial, Piecewise, floor, besseli, sqrt, Sum, Dummy,
Lambda)
Lambda, Eq)
from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace
from sympy.stats.rv import _value_check, is_random

Expand All @@ -33,10 +33,11 @@
]


def rv(symbol, cls, *args):
def rv(symbol, cls, *args, **kwargs):
args = list(map(sympify, args))
dist = cls(*args)
dist.check(*args)
if kwargs.pop('check', True):
Maelstrom6 marked this conversation as resolved.
Show resolved Hide resolved
dist.check(*args)
pspace = SingleDiscretePSpace(symbol, dist)
if any(is_random(arg) for arg in args):
from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution
Expand All @@ -58,9 +59,9 @@ def set(self):
def check(pdf, set):
x = Dummy('x')
val = Sum(pdf(x), (x, set._inf, set._sup)).doit()
_value_check(val == S.One, "The pdf is incorrect on the given set.")
_value_check(Eq(val, 1) != S.false, "The pdf is incorrect on the given set.")

def DiscreteRV(symbol, density, set=S.Integers):
def DiscreteRV(symbol, density, set=S.Integers, **kwargs):
"""
Create a Discrete Random Variable given the following:

Expand All @@ -73,6 +74,10 @@ def DiscreteRV(symbol, density, set=S.Integers):
Represents probability density function.
set : set
Represents the region where the pdf is valid, by default is real line.
check : bool
If True, it will check whether the given density
integrates to 1 over the given set. If False, it
will not perform this check. Default is False.

Examples
========
Expand All @@ -97,7 +102,9 @@ def DiscreteRV(symbol, density, set=S.Integers):
set = sympify(set)
pdf = Piecewise((density, set.as_relational(symbol)), (0, True))
pdf = Lambda(symbol, pdf)
return rv(symbol.name, DiscreteDistributionHandmade, pdf, set)
# have a default of False while `rv` should have a default of True
kwargs['check'] = kwargs.pop('check', False)
return rv(symbol.name, DiscreteDistributionHandmade, pdf, set, **kwargs)


#-------------------------------------------------------------------------------
Expand Down
20 changes: 15 additions & 5 deletions sympy/stats/frv_types.py
Expand Up @@ -36,10 +36,11 @@
'Rademacher'
]

def rv(name, cls, *args):
def rv(name, cls, *args, **kwargs):
args = list(map(sympify, args))
dist = cls(*args)
dist.check(*args)
if kwargs.pop('check', True):
Maelstrom6 marked this conversation as resolved.
Show resolved Hide resolved
dist.check(*args)
pspace = SingleFinitePSpace(name, dist)
if any(is_random(arg) for arg in args):
from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution
Expand All @@ -66,17 +67,24 @@ def check(density):
for p in density.values():
_value_check((p >= 0, p <= 1),
"Probability at a point must be between 0 and 1.")
_value_check(Eq(sum(density.values()), 1), "Total Probability must be 1.")
val = sum(density.values())
_value_check(Eq(val, 1) != S.false, "Total Probability must be 1.")

def FiniteRV(name, density):
def FiniteRV(name, density, **kwargs):
r"""
Create a Finite Random Variable given a dict representing the density.

Parameters
==========

name : Symbol
Represents name of the random variable.
density: A dict
Dictionary conatining the pdf of finite distribution
check : bool
If True, it will check whether the given density
integrates to 1 over the given set. If False, it
will not perform this check. Default is False.

Examples
========
Expand All @@ -97,7 +105,9 @@ def FiniteRV(name, density):
RandomSymbol

"""
return rv(name, FiniteDistributionHandmade, density)
# have a default of False while `rv` should have a default of True
kwargs['check'] = kwargs.pop('check', False)
return rv(name, FiniteDistributionHandmade, density, **kwargs)

class DiscreteUniformDistribution(SingleFiniteDistribution):

Expand Down
13 changes: 11 additions & 2 deletions sympy/stats/tests/test_continuous_rv.py
Expand Up @@ -337,7 +337,7 @@ def test_sample_continuous():
def test_ContinuousRV():
pdf = sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) # Normal distribution
# X and Y should be equivalent
X = ContinuousRV(x, pdf)
X = ContinuousRV(x, pdf, check=True)
Y = Normal('y', 0, 1)

assert variance(X) == variance(Y)
Expand All @@ -346,7 +346,16 @@ def test_ContinuousRV():
assert Z.pspace.domain.set == Interval(0, oo)
assert E(Z) == 1
assert P(Z > 5) == exp(-5)
raises(ValueError, lambda: ContinuousRV(z, exp(-z), set=Interval(0, 10)))
raises(ValueError, lambda: ContinuousRV(z, exp(-z), set=Interval(0, 10), check=True))

# the correct pdf for Gamma(k, theta) but the integral in `check`
# integrates to something equivalent to 1 and not to 1 exactly
_x, k, theta = symbols("x k theta", positive=True)
pdf = 1/(gamma(k)*theta**k)*_x**(k-1)*exp(-_x/theta)
X = ContinuousRV(_x, pdf, set=Interval(0, oo))
Y = Gamma('y', k, theta)
assert (E(X) - E(Y)).simplify() == 0
assert (variance(X) - variance(Y)).simplify() == 0


def test_arcsin():
Expand Down
10 changes: 8 additions & 2 deletions sympy/stats/tests/test_discrete_rv.py
Expand Up @@ -178,11 +178,17 @@ def test_DiscreteRV():
p = S(1)/2
x = Symbol('x', integer=True, positive=True)
pdf = p*(1 - p)**(x - 1) # pdf of Geometric Distribution
D = DiscreteRV(x, pdf, set=S.Naturals)
D = DiscreteRV(x, pdf, set=S.Naturals, check=True)
assert E(D) == E(Geometric('G', S(1)/2)) == 2
assert P(D > 3) == S(1)/8
assert D.pspace.domain.set == S.Naturals
raises(ValueError, lambda: DiscreteRV(x, x, FiniteSet(*range(4))))
raises(ValueError, lambda: DiscreteRV(x, x, FiniteSet(*range(4)), check=True))

# purposeful invalid pmf but it should not raise since check=False
# see test_drv_types.test_ContinuousRV for explanation
X = DiscreteRV(x, 1/x, S.Naturals)
assert P(X < 2) == 1
assert E(X) == oo

def test_precomputed_characteristic_functions():
import mpmath
Expand Down
14 changes: 10 additions & 4 deletions sympy/stats/tests/test_finite_rv.py
Expand Up @@ -383,7 +383,7 @@ def test_rademacher():


def test_FiniteRV():
F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)})
F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}, check=True)
p = Symbol("p", positive=True)

assert dict(density(F).items()) == {S.One: S.Half, S(2): Rational(1, 4), S(3): Rational(1, 4)}
Expand All @@ -395,10 +395,16 @@ def test_FiniteRV():
*[Eq(F.symbol, i) for i in [1, 2, 3]])

assert F.pspace.domain.set == FiniteSet(1, 2, 3)
raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: S.Half, 3: S.Half}))
raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: Rational(-1, 2), 3: S.One}))
raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: S.Half, 3: S.Half}, check=True))
raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: Rational(-1, 2), 3: S.One}, check=True))
raises(ValueError, lambda: FiniteRV('F', {1: S.One, 2: Rational(3, 2), 3: S.Zero,\
4: Rational(-1, 2), 5: Rational(-3, 4), 6: Rational(-1, 4)}))
4: Rational(-1, 2), 5: Rational(-3, 4), 6: Rational(-1, 4)}, check=True))

# purposeful invalid pmf but it should not raise since check=False
# see test_drv_types.test_ContinuousRV for explanation
X = FiniteRV('X', {1: 1, 2: 2})
assert E(X) == 5
assert P(X <= 2) + P(X > 2) != 1

def test_density_call():
from sympy.abc import p
Expand Down