From d2fdd4cfed01cd9b8fb61f163076a8614ba38d78 Mon Sep 17 00:00:00 2001 From: Smit Lunagariya Date: Tue, 4 Aug 2020 10:49:18 +0530 Subject: [PATCH] add mixture dist. --- sympy/core/tests/test_args.py | 14 +++ sympy/stats/__init__.py | 4 + sympy/stats/mixture_rv.py | 166 +++++++++++++++++++++++++++ sympy/stats/rv.py | 3 + sympy/stats/tests/test_mixture_rv.py | 112 ++++++++++++++++++ 5 files changed, 299 insertions(+) create mode 100644 sympy/stats/mixture_rv.py create mode 100644 sympy/stats/tests/test_mixture_rv.py diff --git a/sympy/core/tests/test_args.py b/sympy/core/tests/test_args.py index 19640af56275..06c38244b1cb 100644 --- a/sympy/core/tests/test_args.py +++ b/sympy/core/tests/test_args.py @@ -1634,6 +1634,20 @@ def test_sympy__stats__joint_rv_types__NegativeMultinomialDistribution(): from sympy.stats.joint_rv_types import NegativeMultinomialDistribution assert _test_args(NegativeMultinomialDistribution(5, [0.5, 0.1, 0.3])) +def test_sympy__stats__mixture_rv__MixturePSpace(): + from sympy.stats.mixture_rv import MixturePSpace, MixtureDistribution + from sympy.stats import Normal + from sympy import S + M = MixtureDistribution([S.Half, S.Half], [Normal('N', 1, 2), Normal('M', 2, 3)]) + assert _test_args(MixturePSpace('x', M)) + +def test_sympy__stats__mixture_rv__MixtureDistribution(): + from sympy.stats.mixture_rv import MixtureDistribution + from sympy.stats import Normal + from sympy import S + N , M = Normal('N', 1, 2), Normal('M', 2, 3) + assert _test_args(MixtureDistribution([S.Half, S.Half], [N, M])) + def test_sympy__stats__rv__RandomIndexedSymbol(): from sympy.stats.rv import RandomIndexedSymbol, pspace from sympy.stats.stochastic_process_types import DiscreteMarkovChain diff --git a/sympy/stats/__init__.py b/sympy/stats/__init__.py index be9d57f38070..e46196b40b95 100644 --- a/sympy/stats/__init__.py +++ b/sympy/stats/__init__.py @@ -144,6 +144,8 @@ 'MatrixGamma', 'Wishart', 'MatrixNormal', + 'Mixture', + 'Probability', 'Expectation', 'Variance', 'Covariance', 'Moment', 'CentralMoment', @@ -193,6 +195,8 @@ from .matrix_distributions import MatrixGamma, Wishart, MatrixNormal +from .mixture_rv import Mixture + from .symbolic_probability import (Probability, Expectation, Variance, Covariance, Moment, CentralMoment) diff --git a/sympy/stats/mixture_rv.py b/sympy/stats/mixture_rv.py new file mode 100644 index 000000000000..e4740521c8e5 --- /dev/null +++ b/sympy/stats/mixture_rv.py @@ -0,0 +1,166 @@ +from sympy import Basic, ImmutableMatrix, S, Symbol, Lambda +from sympy.core.sympify import sympify +from sympy.stats.rv import _symbol_converter, RandomSymbol, NamedArgsMixin, PSpace +from sympy.stats.crv import SingleContinuousPSpace +from sympy.stats.frv import SingleFinitePSpace +from sympy.stats.drv import SingleDiscretePSpace +from sympy.stats.crv_types import ContinuousDistributionHandmade +from sympy.stats.drv_types import DiscreteDistributionHandmade +from sympy.stats.frv_types import FiniteDistributionHandmade + + +class MixturePSpace(PSpace): + """ + A temporary Probability Space for the Mixture Distribution + """ + + def __new__(cls, s, distribution): + s = _symbol_converter(s) + if not isinstance(distribution, MixtureDistribution): + raise ValueError("%s should be an isinstance of " + "MixtureDistribution"%(distribution)) + return Basic.__new__(cls, s, distribution) + + @property + def value(self): + return RandomSymbol(self.symbol, self) + + @property + def symbol(self): + return self.args[0] + + @property + def distribution(self): + return self.args[1] + + @property + def pdf(self): + return self.distribution.pdf(self.symbol) + + @property + def set(self): + return self.distribution.set + + @property + def domain(self): + return self._get_newpspace().domain + + def _get_newpspace(self): + new_pspace = self._transform_pspace() + if new_pspace is not None: + return new_pspace + message = ("Mixture Distribution for %s is not implemeted yet" % str(self.distribution)) + raise NotImplementedError(message) + + def _transform_pspace(self): + """ + This function returns the new pspace of the distribution using handmade + Distributions and their corresponding pspace. + """ + pdf = Lambda(self.symbol, self.pdf) + _set = self.distribution.set + if self.distribution.is_Continuous: + return SingleContinuousPSpace(self.symbol, ContinuousDistributionHandmade(pdf, _set)) + elif self.distribution.is_Discrete: + return SingleDiscretePSpace(self.symbol, DiscreteDistributionHandmade(pdf, _set)) + elif self.distribution.is_Finite: + dens = dict((k, pdf(k)) for k in _set) + return SingleFinitePSpace(self.symbol, FiniteDistributionHandmade(dens)) + + def compute_density(self, expr, **kwargs): + new_pspace = self._get_newpspace() + expr = expr.subs({self.value: new_pspace.value}) + return new_pspace.compute_density(expr, **kwargs) + + def compute_cdf(self, expr, **kwargs): + new_pspace = self._get_newpspace() + expr = expr.subs({self.value: new_pspace.value}) + return new_pspace.compute_cdf(expr, **kwargs) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + new_pspace = self._get_newpspace() + expr = expr.subs({self.value: new_pspace.value}) + if isinstance(new_pspace, SingleFinitePSpace): + return new_pspace.compute_expectation(expr, rvs, **kwargs) + return new_pspace.compute_expectation(expr, rvs, evaluate, **kwargs) + + def probability(self, condition, **kwargs): + new_pspace = self._get_newpspace() + condition = condition.subs({self.value: new_pspace.value}) + return new_pspace.probability(condition) + + def conditional_space(self, condition, **kwargs): + new_pspace = self._get_newpspace() + condition = condition.subs({self.value: new_pspace.value}) + return new_pspace.conditional_space(condition) + + def sample(self, size=(), library='scipy'): + new_pspace = self._get_newpspace() + samp = new_pspace.sample(size, library) + return {self.value: samp[new_pspace.value]} + +def rv(symbol, cls, args): + args = list(map(sympify, args)) + dist = cls(*args) + dist.check(*args) + pspace = MixturePSpace(symbol, dist) + return pspace.value + + +class MixtureDistribution(Basic, NamedArgsMixin): + """Represents the Mixture distribution""" + _argnames = ('wts', 'rvs') + + def __new__(cls, wts, rvs): + wts = ImmutableMatrix(wts) + rvs = ImmutableMatrix(rvs) + return Basic.__new__(cls, wts, rvs) + + @property + def set(self): + return (self.rvs[0, 0]).pspace.distribution.set + + @property + def is_Continuous(self): + return (self.rvs[0, 0]).pspace.is_Continuous + + @property + def is_Discrete(self): + return (self.rvs[0, 0]).pspace.is_Discrete + + @property + def is_Finite(self): + return (self.rvs[0, 0]).pspace.is_Finite + + @staticmethod + def check(wts, rvs): + rvs, wts = list(rvs), list(wts) + set_ = rvs[0].pspace.domain.set + for rv in rvs: + if not isinstance(rv, RandomSymbol): + raise TypeError("Each of element should be a random variable") + if rv.pspace.domain.set != set_: + raise ValueError("Each random variable should be defined on same set") + for wt in wts: + if not wt.is_positive: + raise ValueError("Weight of each random variable should be positive") + if len(rvs) != len(wts): + raise ValueError("Weights and RVs should be of same length") + if sum(wts) != S.One: + raise ValueError("Sum of the weights should be 1") + + def pdf(self, x): + y = Symbol('y') + rvs, wts = list(self.rvs), list(self.wts) + pdf_ = S.Zero + if self.is_Finite: + for rv in range(len(rvs)): + pdf_ = pdf_ + wts[rv]*rvs[rv].pspace.distribution.pmf(y) + else: + for rv in range(len(rvs)): + pdf_ = pdf_ + wts[rv]*rvs[rv].pspace.distribution.pdf(y) + return Lambda(y, pdf_)(x) + +def Mixture(name, wts, rvs): + """Creates a random variable with mixture distribution""" + return rv(name, MixtureDistribution, (wts, rvs)) diff --git a/sympy/stats/rv.py b/sympy/stats/rv.py index c1bc92676621..a3f8fe6b65a7 100644 --- a/sympy/stats/rv.py +++ b/sympy/stats/rv.py @@ -605,9 +605,12 @@ def pspace(expr): if all(rv.pspace == rvs[0].pspace for rv in rvs): return rvs[0].pspace from sympy.stats.compound_rv import CompoundPSpace + from sympy.stats.mixture_rv import MixturePSpace for rv in rvs: if isinstance(rv.pspace, CompoundPSpace): return rv.pspace + if isinstance(rv.pspace, MixturePSpace): + return rv.pspace # Otherwise make a product space return IndependentProductPSpace(*[rv.pspace for rv in rvs]) diff --git a/sympy/stats/tests/test_mixture_rv.py b/sympy/stats/tests/test_mixture_rv.py new file mode 100644 index 000000000000..7a155cea454c --- /dev/null +++ b/sympy/stats/tests/test_mixture_rv.py @@ -0,0 +1,112 @@ +from sympy import (S, symbols, sqrt, Abs, exp, pi, oo, erf, erfc, Integral, + Interval, Dummy, factorial, binomial, log, beta, lerchphi, + Piecewise, lowergamma, gamma) +from sympy.stats import (P, E, density, sample, cdf, Normal, Laplace, Mixture, + NegativeBinomial, Poisson, Geometric, YuleSimon, Logarithmic, + Binomial, Hypergeometric, BetaBinomial) +from sympy.stats.mixture_rv import MixtureDistribution, MixturePSpace +from sympy.testing.pytest import raises, skip, ignore_warnings +from sympy.external import import_module + +y, z = symbols('y z') +def test_continuous_mixture(): + N = Normal("N", 0, 1) + M = Normal('M', 1, 2) + Z = Laplace('L', 3, 1) + D = Mixture('D', [S(2)/10, S(5)/10, S(3)/10], [N, M, Z]) + assert D.pspace.distribution.is_Continuous + assert isinstance(D.pspace.distribution, MixtureDistribution) + assert isinstance(D.pspace, MixturePSpace) + assert D.pspace.set == Interval(-oo, oo) + assert density(D)(z) == 3*exp(-Abs(z - 3))/20 + sqrt(2)*exp(-(z - 1)**2/8 + )/(8*sqrt(pi)) + sqrt(2)*exp(-z**2/2)/(10*sqrt(pi)) + assert E(D).simplify() == S(7)/5 + k = Dummy('k') + cdf_expr = erf(sqrt(2)*(z - 1)/4)/4 - erfc(sqrt(2)*z/2)/10 + 3*Integral( + exp(-Abs(k - 3)), (k, -oo, z))/20 + S(9)/20 + assert cdf(D)(z).simplify().dummy_eq(cdf_expr) + prob = Integral(3*exp(-Abs(k - 3))/20 + sqrt(2)*exp(-(k - 1)**2/8)/( + 8*sqrt(pi)) + sqrt(2)*exp(-k**2/2)/(10*sqrt(pi)), (k, 0, oo)) + + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert P(D > 0, evaluate=False).rewrite(Integral).dummy_eq(prob) + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + else: + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert next(sample(D)) in D.pspace.set + samp = next(sample(D, size=5)) + for sam in samp: + assert sam in D.pspace.set + +def test_discrete_mixture(): + X = NegativeBinomial('X', 5, S(1)/3) + Y = Poisson('Y', 2) + D = Mixture('D', [S(2)/5, S(3)/5], [X, Y]) + assert D.pspace.distribution.is_Discrete + assert D.pspace.set == S.Naturals0 + assert density(D)(z).simplify() == 3*2**z*exp(-2)/(5*factorial(z)) + S(64)*3**(-z)*binomial(z + 4, + z)/1215 + assert E(D) == S(11)/5 + assert cdf(D)(z).simplify() == Piecewise((1 - 3*lowergamma(z + 1, 2)/(5*gamma(z + 1) + ) - S(4)*3**(-z)*z**4/3645 - S(64)*3**(-z)*z**3/3645 - S(392)*3**(-z)*z**2/3645 -\ + S(1112)*3**(-z)*z/3645 - S(422)*3**(-z)/1215, z >= 0), (0, True)) + assert P(D > 2).simplify() == S(2813)/3645 - 3*exp(-2) + X = Geometric('X', S(2)/5) + Y = Logarithmic('Y', S(2)/3) + Z = YuleSimon('Z', 3) + B = Binomial('B', 2, S(1)/10) + wts = list(density(B).dict.values()) + D = Mixture('D', wts, [X, Y, Z]) + assert density(D)(z).simplify() == S(9)*10**z*15**(-z)/(50*z*log(3) + ) + S(3)*beta(z, 4)/100 + S(27)*15**(-z)*3**(2*z)/50 + assert cdf(D)(z).simplify() == Piecewise(((-12*2**z*3**(-z)*lerchphi(S(2)/3, 1, z + 1 + ) + (-81*3**z*5**(-z) - z*beta(z, 4) + 82)*log(3) + log(387420489))/(100*log(3)), + z >= 1), (0, True)) + assert E(D).simplify() == S(9)/(25*log(3)) + S(51)/25 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + else: + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert next(sample(D)) in D.pspace.set + samp = next(sample(D, size=5)) + for sam in samp: + assert sam in D.pspace.set + +def test_finite_mixture(): + X = Hypergeometric('X', 10, 5, 5) + Y = BetaBinomial('Y', 5, 1, 2) + Z = Binomial('Z', 5, S(2)/3) + B = Binomial('B', 2, S(1)/3) + wts = list(density(B).dict.values()) + D = Mixture('D', wts, [X, Y, Z]) + assert D.pspace.distribution.is_Finite + assert D.pspace.domain.set == {0, 1, 2, 3, 4, 5} + assert sum(density(D).dict.values()).evalf() == 1 + assert E(D) == S(40)*beta(2, 6)/9 + S(40)*beta(6, 2)/9 + S(160)*beta(3, 5 + )/9 + S(160)*beta(5, 3)/9 + S(80)*beta(4, 4)/3 + S(40)/27 + assert P(D < 5) == S(40)*beta(5, 3)/9 + S(80)*beta(4, 4)/9 + S(80)*beta(3, 5 + )/9 + S(40)*beta(2, 6)/9 + S(10198)/15309 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + else: + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert next(sample(D)) in D.pspace.set + samp = next(sample(D, size=5)) + for sam in samp: + assert sam in D.pspace.set + +def test_mixture_raises(): + wts = [1, 2, 3] + X = Hypergeometric('X', 10, 5, 5) + Y = BetaBinomial('Y', 5, 1, 2) + Z = Binomial('Z', 5, S(2)/3) + raises(ValueError, lambda: Mixture('D', wts, [X, Y, Z])) + raises(TypeError, lambda: Mixture('D', [S(2)/5, S(3)/5], [X, y])) + Z = Binomial('Z', 2, S(2)/3) + raises(ValueError, lambda: Mixture('D', wts, [X, Y, Z])) + raises(ValueError, lambda: Mixture('D', [-S(2)/5, S(3)/5], [X, Y])) + raises(ValueError, lambda: Mixture('D', [S(2)/5, S(2)/5, S(1)/5], [X, Y]))