Skip to content

Commit

Permalink
add mixture dist.
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Aug 4, 2020
1 parent 9a57bef commit d2fdd4c
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 0 deletions.
14 changes: 14 additions & 0 deletions sympy/core/tests/test_args.py
Expand Up @@ -1634,6 +1634,20 @@ def test_sympy__stats__joint_rv_types__NegativeMultinomialDistribution():
from sympy.stats.joint_rv_types import NegativeMultinomialDistribution
assert _test_args(NegativeMultinomialDistribution(5, [0.5, 0.1, 0.3]))

def test_sympy__stats__mixture_rv__MixturePSpace():
    """Run the shared ``_test_args`` check on a ``MixturePSpace`` instance."""
    from sympy import S
    from sympy.stats import Normal
    from sympy.stats.mixture_rv import MixturePSpace, MixtureDistribution
    # Equal-weight mixture of two normal random variables.
    weights = [S.Half, S.Half]
    components = [Normal('N', 1, 2), Normal('M', 2, 3)]
    mixture = MixtureDistribution(weights, components)
    assert _test_args(MixturePSpace('x', mixture))

def test_sympy__stats__mixture_rv__MixtureDistribution():
    """Run the shared ``_test_args`` check on a ``MixtureDistribution``."""
    from sympy import S
    from sympy.stats import Normal
    from sympy.stats.mixture_rv import MixtureDistribution
    # Equal-weight mixture of two normal random variables.
    first = Normal('N', 1, 2)
    second = Normal('M', 2, 3)
    mixture = MixtureDistribution([S.Half, S.Half], [first, second])
    assert _test_args(mixture)

def test_sympy__stats__rv__RandomIndexedSymbol():
from sympy.stats.rv import RandomIndexedSymbol, pspace
from sympy.stats.stochastic_process_types import DiscreteMarkovChain
Expand Down
4 changes: 4 additions & 0 deletions sympy/stats/__init__.py
Expand Up @@ -144,6 +144,8 @@

'MatrixGamma', 'Wishart', 'MatrixNormal',

'Mixture',

'Probability', 'Expectation', 'Variance', 'Covariance', 'Moment',
'CentralMoment',

Expand Down Expand Up @@ -193,6 +195,8 @@

from .matrix_distributions import MatrixGamma, Wishart, MatrixNormal

from .mixture_rv import Mixture

from .symbolic_probability import (Probability, Expectation, Variance,
Covariance, Moment, CentralMoment)

Expand Down
166 changes: 166 additions & 0 deletions sympy/stats/mixture_rv.py
@@ -0,0 +1,166 @@
from sympy import Basic, ImmutableMatrix, S, Symbol, Lambda
from sympy.core.sympify import sympify
from sympy.stats.rv import _symbol_converter, RandomSymbol, NamedArgsMixin, PSpace
from sympy.stats.crv import SingleContinuousPSpace
from sympy.stats.frv import SingleFinitePSpace
from sympy.stats.drv import SingleDiscretePSpace
from sympy.stats.crv_types import ContinuousDistributionHandmade
from sympy.stats.drv_types import DiscreteDistributionHandmade
from sympy.stats.frv_types import FiniteDistributionHandmade


class MixturePSpace(PSpace):
    """
    A temporary Probability Space for the Mixture Distribution.

    The space stores a symbol and a ``MixtureDistribution``. Every query
    (density, cdf, expectation, probability, conditioning, sampling) is
    answered by building an equivalent single-variable probability space
    around a "handmade" distribution that carries the mixture density, and
    delegating to it.
    """

    def __new__(cls, s, distribution):
        """
        Create the space.

        Parameters
        ==========

        s : symbol-like object accepted by ``_symbol_converter``.
        distribution : MixtureDistribution

        Raises
        ======

        ValueError
            If ``distribution`` is not a ``MixtureDistribution``.
        """
        s = _symbol_converter(s)
        if not isinstance(distribution, MixtureDistribution):
            # Wrap in a 1-tuple so that a tuple-valued argument is formatted
            # as a single value instead of breaking the %-formatting.
            raise ValueError("%s should be an instance of "
                             "MixtureDistribution" % (distribution,))
        return Basic.__new__(cls, s, distribution)

    @property
    def value(self):
        """The random symbol attached to this space."""
        return RandomSymbol(self.symbol, self)

    @property
    def symbol(self):
        """The symbol naming the random variable (first argument)."""
        return self.args[0]

    @property
    def distribution(self):
        """The underlying ``MixtureDistribution`` (second argument)."""
        return self.args[1]

    @property
    def pdf(self):
        """The mixture density evaluated at ``self.symbol``."""
        return self.distribution.pdf(self.symbol)

    @property
    def set(self):
        """The set reported by the mixture distribution."""
        return self.distribution.set

    @property
    def domain(self):
        """The domain of the equivalent single-variable space."""
        return self._get_newpspace().domain

    def _get_newpspace(self):
        """
        Return the equivalent single-variable space.

        Raises
        ======

        NotImplementedError
            If the distribution is neither continuous, discrete nor finite.
        """
        new_pspace = self._transform_pspace()
        if new_pspace is not None:
            return new_pspace
        message = ("Mixture Distribution for %s is not implemented yet"
                   % str(self.distribution))
        raise NotImplementedError(message)

    def _transform_pspace(self):
        """
        This function returns the new pspace of the distribution using handmade
        Distributions and their corresponding pspace.

        Returns ``None`` when the distribution reports none of
        ``is_Continuous``, ``is_Discrete`` or ``is_Finite``.
        """
        density_fn = Lambda(self.symbol, self.pdf)
        support = self.distribution.set
        if self.distribution.is_Continuous:
            return SingleContinuousPSpace(
                self.symbol, ContinuousDistributionHandmade(density_fn, support))
        elif self.distribution.is_Discrete:
            return SingleDiscretePSpace(
                self.symbol, DiscreteDistributionHandmade(density_fn, support))
        elif self.distribution.is_Finite:
            # A finite handmade distribution takes an explicit
            # {outcome: probability} mapping rather than a density function.
            dens = {k: density_fn(k) for k in support}
            return SingleFinitePSpace(self.symbol, FiniteDistributionHandmade(dens))
        return None

    def compute_density(self, expr, **kwargs):
        """Density of ``expr``, computed in the equivalent space."""
        new_pspace = self._get_newpspace()
        expr = expr.subs({self.value: new_pspace.value})
        return new_pspace.compute_density(expr, **kwargs)

    def compute_cdf(self, expr, **kwargs):
        """CDF of ``expr``, computed in the equivalent space."""
        new_pspace = self._get_newpspace()
        expr = expr.subs({self.value: new_pspace.value})
        return new_pspace.compute_cdf(expr, **kwargs)

    def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
        """
        Expectation of ``expr``, computed in the equivalent space.

        ``evaluate`` is forwarded only to non-finite spaces; the finite space
        is called without it.
        """
        new_pspace = self._get_newpspace()
        expr = expr.subs({self.value: new_pspace.value})
        if isinstance(new_pspace, SingleFinitePSpace):
            return new_pspace.compute_expectation(expr, rvs, **kwargs)
        return new_pspace.compute_expectation(expr, rvs, evaluate, **kwargs)

    def probability(self, condition, **kwargs):
        """
        Probability of ``condition``, computed in the equivalent space.

        Extra keyword arguments are accepted but not forwarded.
        """
        new_pspace = self._get_newpspace()
        condition = condition.subs({self.value: new_pspace.value})
        return new_pspace.probability(condition)

    def conditional_space(self, condition, **kwargs):
        """
        Conditional space for ``condition``, built from the equivalent space.

        Extra keyword arguments are accepted but not forwarded.
        """
        new_pspace = self._get_newpspace()
        condition = condition.subs({self.value: new_pspace.value})
        return new_pspace.conditional_space(condition)

    def sample(self, size=(), library='scipy'):
        """
        Draw samples through the equivalent space.

        Returns a dict mapping this space's random symbol to the samples
        produced for the equivalent space's random symbol.
        """
        new_pspace = self._get_newpspace()
        samp = new_pspace.sample(size, library)
        return {self.value: samp[new_pspace.value]}

def rv(symbol, cls, args):
    """
    Build a mixture random variable.

    Each entry of ``args`` is sympified, a distribution is created with
    ``cls(*args)`` and validated through its ``check`` method, and the random
    symbol of the resulting ``MixturePSpace`` is returned.
    """
    sympified = [sympify(arg) for arg in args]
    distribution = cls(*sympified)
    distribution.check(*sympified)
    return MixturePSpace(symbol, distribution).value


class MixtureDistribution(Basic, NamedArgsMixin):
    """
    Represents the Mixture distribution.

    Parameters
    ==========

    wts : iterable of weights, stored as an ``ImmutableMatrix``.
    rvs : iterable of random variables, stored as an ``ImmutableMatrix``.

    The support and the continuous/discrete/finite classification are taken
    from the first random variable only; ``check`` verifies that all
    components share the same domain set.
    """
    _argnames = ('wts', 'rvs')

    def __new__(cls, wts, rvs):
        wts = ImmutableMatrix(wts)
        rvs = ImmutableMatrix(rvs)
        return Basic.__new__(cls, wts, rvs)

    @property
    def set(self):
        """Set of the first component's distribution."""
        return (self.rvs[0, 0]).pspace.distribution.set

    @property
    def is_Continuous(self):
        """``is_Continuous`` flag of the first component's pspace."""
        return (self.rvs[0, 0]).pspace.is_Continuous

    @property
    def is_Discrete(self):
        """``is_Discrete`` flag of the first component's pspace."""
        return (self.rvs[0, 0]).pspace.is_Discrete

    @property
    def is_Finite(self):
        """``is_Finite`` flag of the first component's pspace."""
        return (self.rvs[0, 0]).pspace.is_Finite

    @staticmethod
    def check(wts, rvs):
        """
        Validate the weights and the random variables of a mixture.

        Raises
        ======

        TypeError
            If any element of ``rvs`` is not a ``RandomSymbol``.
        ValueError
            If the random variables are not all defined on the same set, a
            weight is not known to be positive, the two sequences differ in
            length, or the weights do not sum to 1.
        """
        rvs, wts = list(rvs), list(wts)
        # The reference set is taken from the first element only after that
        # element has passed the type check, so a non-random first element
        # yields the TypeError below rather than an AttributeError.
        common_set = None
        for rv in rvs:
            if not isinstance(rv, RandomSymbol):
                raise TypeError("Each of element should be a random variable")
            if common_set is None:
                common_set = rv.pspace.domain.set
            elif rv.pspace.domain.set != common_set:
                raise ValueError("Each random variable should be defined on same set")
        for wt in wts:
            if not wt.is_positive:
                raise ValueError("Weight of each random variable should be positive")
        if len(rvs) != len(wts):
            raise ValueError("Weights and RVs should be of same length")
        if sum(wts) != S.One:
            raise ValueError("Sum of the weights should be 1")

    def pdf(self, x):
        """
        Weighted sum of the component densities evaluated at ``x``.

        Finite components contribute through ``pmf``, all others through
        ``pdf``.
        """
        x = sympify(x)
        # Evaluate each component directly at ``x``. Going through an
        # intermediate Symbol('y') would also replace any component parameter
        # that happens to be named ``y``.
        density_attr = 'pmf' if self.is_Finite else 'pdf'
        total = S.Zero
        for wt, rv in zip(self.wts, self.rvs):
            component_density = getattr(rv.pspace.distribution, density_attr)
            total = total + wt*component_density(x)
        return total

def Mixture(name, wts, rvs):
    """
    Creates a random variable with mixture distribution.

    Parameters
    ==========

    name : name of the resulting random variable.
    wts : weights, one per component random variable.
    rvs : component random variables.
    """
    mixture_args = (wts, rvs)
    return rv(name, MixtureDistribution, mixture_args)
3 changes: 3 additions & 0 deletions sympy/stats/rv.py
Expand Up @@ -605,9 +605,12 @@ def pspace(expr):
if all(rv.pspace == rvs[0].pspace for rv in rvs):
return rvs[0].pspace
from sympy.stats.compound_rv import CompoundPSpace
from sympy.stats.mixture_rv import MixturePSpace
for rv in rvs:
if isinstance(rv.pspace, CompoundPSpace):
return rv.pspace
if isinstance(rv.pspace, MixturePSpace):
return rv.pspace
# Otherwise make a product space
return IndependentProductPSpace(*[rv.pspace for rv in rvs])

Expand Down
112 changes: 112 additions & 0 deletions sympy/stats/tests/test_mixture_rv.py
@@ -0,0 +1,112 @@
from sympy import (S, symbols, sqrt, Abs, exp, pi, oo, erf, erfc, Integral,
Interval, Dummy, factorial, binomial, log, beta, lerchphi,
Piecewise, lowergamma, gamma)
from sympy.stats import (P, E, density, sample, cdf, Normal, Laplace, Mixture,
NegativeBinomial, Poisson, Geometric, YuleSimon, Logarithmic,
Binomial, Hypergeometric, BetaBinomial)
from sympy.stats.mixture_rv import MixtureDistribution, MixturePSpace
from sympy.testing.pytest import raises, skip, ignore_warnings
from sympy.external import import_module

# Plain (non-random) symbols shared by the tests below: ``z`` is the point at
# which densities and cdfs are evaluated, ``y`` is passed as a non-random
# element in test_mixture_raises.
y, z = symbols('y z')
def test_continuous_mixture():
    """Mixture of two Normal variables and one Laplace variable."""
    std_normal = Normal("N", 0, 1)
    wide_normal = Normal('M', 1, 2)
    laplace = Laplace('L', 3, 1)
    weights = [S(2)/10, S(5)/10, S(3)/10]
    D = Mixture('D', weights, [std_normal, wide_normal, laplace])

    # Structural checks on the probability space and its distribution.
    dist = D.pspace.distribution
    assert dist.is_Continuous
    assert isinstance(dist, MixtureDistribution)
    assert isinstance(D.pspace, MixturePSpace)
    assert D.pspace.set == Interval(-oo, oo)

    expected_density = (3*exp(-Abs(z - 3))/20
                        + sqrt(2)*exp(-(z - 1)**2/8)/(8*sqrt(pi))
                        + sqrt(2)*exp(-z**2/2)/(10*sqrt(pi)))
    assert density(D)(z) == expected_density
    assert E(D).simplify() == S(7)/5

    k = Dummy('k')
    cdf_expr = (erf(sqrt(2)*(z - 1)/4)/4 - erfc(sqrt(2)*z/2)/10
                + 3*Integral(exp(-Abs(k - 3)), (k, -oo, z))/20 + S(9)/20)
    assert cdf(D)(z).simplify().dummy_eq(cdf_expr)

    prob = Integral(3*exp(-Abs(k - 3))/20
                    + sqrt(2)*exp(-(k - 1)**2/8)/(8*sqrt(pi))
                    + sqrt(2)*exp(-k**2/2)/(10*sqrt(pi)), (k, 0, oo))
    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
        assert P(D > 0, evaluate=False).rewrite(Integral).dummy_eq(prob)

    # Sampling is only exercised when scipy can be imported.
    scipy = import_module('scipy')
    if scipy:
        with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
            assert next(sample(D)) in D.pspace.set
            draws = next(sample(D, size=5))
            for draw in draws:
                assert draw in D.pspace.set
    else:
        skip('Scipy is not installed. Abort tests')

def test_discrete_mixture():
    """Mixtures of discrete random variables with fixed and derived weights."""
    # Two-component mixture with explicit weights.
    neg_binomial = NegativeBinomial('X', 5, S(1)/3)
    poisson = Poisson('Y', 2)
    D = Mixture('D', [S(2)/5, S(3)/5], [neg_binomial, poisson])
    assert D.pspace.distribution.is_Discrete
    assert D.pspace.set == S.Naturals0
    expected_density = (3*2**z*exp(-2)/(5*factorial(z))
                        + S(64)*3**(-z)*binomial(z + 4, z)/1215)
    assert density(D)(z).simplify() == expected_density
    assert E(D) == S(11)/5
    expected_cdf = Piecewise(
        (1 - 3*lowergamma(z + 1, 2)/(5*gamma(z + 1))
         - S(4)*3**(-z)*z**4/3645 - S(64)*3**(-z)*z**3/3645
         - S(392)*3**(-z)*z**2/3645 - S(1112)*3**(-z)*z/3645
         - S(422)*3**(-z)/1215, z >= 0),
        (0, True))
    assert cdf(D)(z).simplify() == expected_cdf
    assert P(D > 2).simplify() == S(2813)/3645 - 3*exp(-2)

    # Three-component mixture whose weights are the probabilities of a
    # Binomial(2, 1/10) variable.
    geometric = Geometric('X', S(2)/5)
    logarithmic = Logarithmic('Y', S(2)/3)
    yule_simon = YuleSimon('Z', 3)
    weight_source = Binomial('B', 2, S(1)/10)
    wts = list(density(weight_source).dict.values())
    D = Mixture('D', wts, [geometric, logarithmic, yule_simon])
    expected_density = (S(9)*10**z*15**(-z)/(50*z*log(3))
                        + S(3)*beta(z, 4)/100
                        + S(27)*15**(-z)*3**(2*z)/50)
    assert density(D)(z).simplify() == expected_density
    expected_cdf = Piecewise(
        ((-12*2**z*3**(-z)*lerchphi(S(2)/3, 1, z + 1)
          + (-81*3**z*5**(-z) - z*beta(z, 4) + 82)*log(3)
          + log(387420489))/(100*log(3)), z >= 1),
        (0, True))
    assert cdf(D)(z).simplify() == expected_cdf
    assert E(D).simplify() == S(9)/(25*log(3)) + S(51)/25

    # Sampling is only exercised when scipy can be imported.
    scipy = import_module('scipy')
    if scipy:
        with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
            assert next(sample(D)) in D.pspace.set
            draws = next(sample(D, size=5))
            for draw in draws:
                assert draw in D.pspace.set
    else:
        skip('Scipy is not installed. Abort tests')

def test_finite_mixture():
    """Mixture of three finite random variables on the outcomes 0..5."""
    hypergeometric = Hypergeometric('X', 10, 5, 5)
    beta_binomial = BetaBinomial('Y', 5, 1, 2)
    binomial_rv = Binomial('Z', 5, S(2)/3)
    # The weights are the probabilities of a Binomial(2, 1/3) variable.
    weight_source = Binomial('B', 2, S(1)/3)
    wts = list(density(weight_source).dict.values())
    D = Mixture('D', wts, [hypergeometric, beta_binomial, binomial_rv])
    assert D.pspace.distribution.is_Finite
    assert D.pspace.domain.set == {0, 1, 2, 3, 4, 5}
    assert sum(density(D).dict.values()).evalf() == 1

    expected_mean = (S(40)*beta(2, 6)/9 + S(40)*beta(6, 2)/9
                     + S(160)*beta(3, 5)/9 + S(160)*beta(5, 3)/9
                     + S(80)*beta(4, 4)/3 + S(40)/27)
    assert E(D) == expected_mean
    expected_prob = (S(40)*beta(5, 3)/9 + S(80)*beta(4, 4)/9
                     + S(80)*beta(3, 5)/9 + S(40)*beta(2, 6)/9
                     + S(10198)/15309)
    assert P(D < 5) == expected_prob

    # Sampling is only exercised when scipy can be imported.
    scipy = import_module('scipy')
    if scipy:
        with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
            assert next(sample(D)) in D.pspace.set
            draws = next(sample(D, size=5))
            for draw in draws:
                assert draw in D.pspace.set
    else:
        skip('Scipy is not installed. Abort tests')

def test_mixture_raises():
    """Invalid weights or components must be rejected by ``Mixture``."""
    X = Hypergeometric('X', 10, 5, 5)
    Y = BetaBinomial('Y', 5, 1, 2)
    binomial_n5 = Binomial('Z', 5, S(2)/3)
    binomial_n2 = Binomial('Z', 2, S(2)/3)
    bad_sum_wts = [1, 2, 3]

    # Weights that do not sum to 1.
    raises(ValueError,
           lambda: Mixture('D', bad_sum_wts, [X, Y, binomial_n5]))
    # A plain symbol among the random variables.
    raises(TypeError,
           lambda: Mixture('D', [S(2)/5, S(3)/5], [X, y]))
    # Same weights with a Binomial built with n=2 instead of n=5.
    raises(ValueError,
           lambda: Mixture('D', bad_sum_wts, [X, Y, binomial_n2]))
    # A negative weight.
    raises(ValueError,
           lambda: Mixture('D', [-S(2)/5, S(3)/5], [X, Y]))
    # More weights than random variables.
    raises(ValueError,
           lambda: Mixture('D', [S(2)/5, S(2)/5, S(1)/5], [X, Y]))

0 comments on commit d2fdd4c

Please sign in to comment.