Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9a57bef
commit d2fdd4c
Showing
5 changed files
with
299 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
from sympy import Basic, ImmutableMatrix, S, Symbol, Lambda | ||
from sympy.core.sympify import sympify | ||
from sympy.stats.rv import _symbol_converter, RandomSymbol, NamedArgsMixin, PSpace | ||
from sympy.stats.crv import SingleContinuousPSpace | ||
from sympy.stats.frv import SingleFinitePSpace | ||
from sympy.stats.drv import SingleDiscretePSpace | ||
from sympy.stats.crv_types import ContinuousDistributionHandmade | ||
from sympy.stats.drv_types import DiscreteDistributionHandmade | ||
from sympy.stats.frv_types import FiniteDistributionHandmade | ||
|
||
|
||
class MixturePSpace(PSpace): | ||
""" | ||
A temporary Probability Space for the Mixture Distribution | ||
""" | ||
|
||
def __new__(cls, s, distribution): | ||
s = _symbol_converter(s) | ||
if not isinstance(distribution, MixtureDistribution): | ||
raise ValueError("%s should be an isinstance of " | ||
"MixtureDistribution"%(distribution)) | ||
return Basic.__new__(cls, s, distribution) | ||
|
||
@property | ||
def value(self): | ||
return RandomSymbol(self.symbol, self) | ||
|
||
@property | ||
def symbol(self): | ||
return self.args[0] | ||
|
||
@property | ||
def distribution(self): | ||
return self.args[1] | ||
|
||
@property | ||
def pdf(self): | ||
return self.distribution.pdf(self.symbol) | ||
|
||
@property | ||
def set(self): | ||
return self.distribution.set | ||
|
||
@property | ||
def domain(self): | ||
return self._get_newpspace().domain | ||
|
||
def _get_newpspace(self): | ||
new_pspace = self._transform_pspace() | ||
if new_pspace is not None: | ||
return new_pspace | ||
message = ("Mixture Distribution for %s is not implemeted yet" % str(self.distribution)) | ||
raise NotImplementedError(message) | ||
|
||
def _transform_pspace(self): | ||
""" | ||
This function returns the new pspace of the distribution using handmade | ||
Distributions and their corresponding pspace. | ||
""" | ||
pdf = Lambda(self.symbol, self.pdf) | ||
_set = self.distribution.set | ||
if self.distribution.is_Continuous: | ||
return SingleContinuousPSpace(self.symbol, ContinuousDistributionHandmade(pdf, _set)) | ||
elif self.distribution.is_Discrete: | ||
return SingleDiscretePSpace(self.symbol, DiscreteDistributionHandmade(pdf, _set)) | ||
elif self.distribution.is_Finite: | ||
dens = dict((k, pdf(k)) for k in _set) | ||
return SingleFinitePSpace(self.symbol, FiniteDistributionHandmade(dens)) | ||
|
||
def compute_density(self, expr, **kwargs): | ||
new_pspace = self._get_newpspace() | ||
expr = expr.subs({self.value: new_pspace.value}) | ||
return new_pspace.compute_density(expr, **kwargs) | ||
|
||
def compute_cdf(self, expr, **kwargs): | ||
new_pspace = self._get_newpspace() | ||
expr = expr.subs({self.value: new_pspace.value}) | ||
return new_pspace.compute_cdf(expr, **kwargs) | ||
|
||
def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): | ||
new_pspace = self._get_newpspace() | ||
expr = expr.subs({self.value: new_pspace.value}) | ||
if isinstance(new_pspace, SingleFinitePSpace): | ||
return new_pspace.compute_expectation(expr, rvs, **kwargs) | ||
return new_pspace.compute_expectation(expr, rvs, evaluate, **kwargs) | ||
|
||
def probability(self, condition, **kwargs): | ||
new_pspace = self._get_newpspace() | ||
condition = condition.subs({self.value: new_pspace.value}) | ||
return new_pspace.probability(condition) | ||
|
||
def conditional_space(self, condition, **kwargs): | ||
new_pspace = self._get_newpspace() | ||
condition = condition.subs({self.value: new_pspace.value}) | ||
return new_pspace.conditional_space(condition) | ||
|
||
def sample(self, size=(), library='scipy'): | ||
new_pspace = self._get_newpspace() | ||
samp = new_pspace.sample(size, library) | ||
return {self.value: samp[new_pspace.value]} | ||
|
||
def rv(symbol, cls, args): | ||
args = list(map(sympify, args)) | ||
dist = cls(*args) | ||
dist.check(*args) | ||
pspace = MixturePSpace(symbol, dist) | ||
return pspace.value | ||
|
||
|
||
class MixtureDistribution(Basic, NamedArgsMixin): | ||
"""Represents the Mixture distribution""" | ||
_argnames = ('wts', 'rvs') | ||
|
||
def __new__(cls, wts, rvs): | ||
wts = ImmutableMatrix(wts) | ||
rvs = ImmutableMatrix(rvs) | ||
return Basic.__new__(cls, wts, rvs) | ||
|
||
@property | ||
def set(self): | ||
return (self.rvs[0, 0]).pspace.distribution.set | ||
|
||
@property | ||
def is_Continuous(self): | ||
return (self.rvs[0, 0]).pspace.is_Continuous | ||
|
||
@property | ||
def is_Discrete(self): | ||
return (self.rvs[0, 0]).pspace.is_Discrete | ||
|
||
@property | ||
def is_Finite(self): | ||
return (self.rvs[0, 0]).pspace.is_Finite | ||
|
||
@staticmethod | ||
def check(wts, rvs): | ||
rvs, wts = list(rvs), list(wts) | ||
set_ = rvs[0].pspace.domain.set | ||
for rv in rvs: | ||
if not isinstance(rv, RandomSymbol): | ||
raise TypeError("Each of element should be a random variable") | ||
if rv.pspace.domain.set != set_: | ||
raise ValueError("Each random variable should be defined on same set") | ||
for wt in wts: | ||
if not wt.is_positive: | ||
raise ValueError("Weight of each random variable should be positive") | ||
if len(rvs) != len(wts): | ||
raise ValueError("Weights and RVs should be of same length") | ||
if sum(wts) != S.One: | ||
raise ValueError("Sum of the weights should be 1") | ||
|
||
def pdf(self, x): | ||
y = Symbol('y') | ||
rvs, wts = list(self.rvs), list(self.wts) | ||
pdf_ = S.Zero | ||
if self.is_Finite: | ||
for rv in range(len(rvs)): | ||
pdf_ = pdf_ + wts[rv]*rvs[rv].pspace.distribution.pmf(y) | ||
else: | ||
for rv in range(len(rvs)): | ||
pdf_ = pdf_ + wts[rv]*rvs[rv].pspace.distribution.pdf(y) | ||
return Lambda(y, pdf_)(x) | ||
|
||
def Mixture(name, wts, rvs): | ||
"""Creates a random variable with mixture distribution""" | ||
return rv(name, MixtureDistribution, (wts, rvs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from sympy import (S, symbols, sqrt, Abs, exp, pi, oo, erf, erfc, Integral, | ||
Interval, Dummy, factorial, binomial, log, beta, lerchphi, | ||
Piecewise, lowergamma, gamma) | ||
from sympy.stats import (P, E, density, sample, cdf, Normal, Laplace, Mixture, | ||
NegativeBinomial, Poisson, Geometric, YuleSimon, Logarithmic, | ||
Binomial, Hypergeometric, BetaBinomial) | ||
from sympy.stats.mixture_rv import MixtureDistribution, MixturePSpace | ||
from sympy.testing.pytest import raises, skip, ignore_warnings | ||
from sympy.external import import_module | ||
|
||
y, z = symbols('y z') | ||
def test_continuous_mixture(): | ||
N = Normal("N", 0, 1) | ||
M = Normal('M', 1, 2) | ||
Z = Laplace('L', 3, 1) | ||
D = Mixture('D', [S(2)/10, S(5)/10, S(3)/10], [N, M, Z]) | ||
assert D.pspace.distribution.is_Continuous | ||
assert isinstance(D.pspace.distribution, MixtureDistribution) | ||
assert isinstance(D.pspace, MixturePSpace) | ||
assert D.pspace.set == Interval(-oo, oo) | ||
assert density(D)(z) == 3*exp(-Abs(z - 3))/20 + sqrt(2)*exp(-(z - 1)**2/8 | ||
)/(8*sqrt(pi)) + sqrt(2)*exp(-z**2/2)/(10*sqrt(pi)) | ||
assert E(D).simplify() == S(7)/5 | ||
k = Dummy('k') | ||
cdf_expr = erf(sqrt(2)*(z - 1)/4)/4 - erfc(sqrt(2)*z/2)/10 + 3*Integral( | ||
exp(-Abs(k - 3)), (k, -oo, z))/20 + S(9)/20 | ||
assert cdf(D)(z).simplify().dummy_eq(cdf_expr) | ||
prob = Integral(3*exp(-Abs(k - 3))/20 + sqrt(2)*exp(-(k - 1)**2/8)/( | ||
8*sqrt(pi)) + sqrt(2)*exp(-k**2/2)/(10*sqrt(pi)), (k, 0, oo)) | ||
|
||
with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | ||
assert P(D > 0, evaluate=False).rewrite(Integral).dummy_eq(prob) | ||
scipy = import_module('scipy') | ||
if not scipy: | ||
skip('Scipy is not installed. Abort tests') | ||
else: | ||
with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | ||
assert next(sample(D)) in D.pspace.set | ||
samp = next(sample(D, size=5)) | ||
for sam in samp: | ||
assert sam in D.pspace.set | ||
|
||
def test_discrete_mixture(): | ||
X = NegativeBinomial('X', 5, S(1)/3) | ||
Y = Poisson('Y', 2) | ||
D = Mixture('D', [S(2)/5, S(3)/5], [X, Y]) | ||
assert D.pspace.distribution.is_Discrete | ||
assert D.pspace.set == S.Naturals0 | ||
assert density(D)(z).simplify() == 3*2**z*exp(-2)/(5*factorial(z)) + S(64)*3**(-z)*binomial(z + 4, | ||
z)/1215 | ||
assert E(D) == S(11)/5 | ||
assert cdf(D)(z).simplify() == Piecewise((1 - 3*lowergamma(z + 1, 2)/(5*gamma(z + 1) | ||
) - S(4)*3**(-z)*z**4/3645 - S(64)*3**(-z)*z**3/3645 - S(392)*3**(-z)*z**2/3645 -\ | ||
S(1112)*3**(-z)*z/3645 - S(422)*3**(-z)/1215, z >= 0), (0, True)) | ||
assert P(D > 2).simplify() == S(2813)/3645 - 3*exp(-2) | ||
X = Geometric('X', S(2)/5) | ||
Y = Logarithmic('Y', S(2)/3) | ||
Z = YuleSimon('Z', 3) | ||
B = Binomial('B', 2, S(1)/10) | ||
wts = list(density(B).dict.values()) | ||
D = Mixture('D', wts, [X, Y, Z]) | ||
assert density(D)(z).simplify() == S(9)*10**z*15**(-z)/(50*z*log(3) | ||
) + S(3)*beta(z, 4)/100 + S(27)*15**(-z)*3**(2*z)/50 | ||
assert cdf(D)(z).simplify() == Piecewise(((-12*2**z*3**(-z)*lerchphi(S(2)/3, 1, z + 1 | ||
) + (-81*3**z*5**(-z) - z*beta(z, 4) + 82)*log(3) + log(387420489))/(100*log(3)), | ||
z >= 1), (0, True)) | ||
assert E(D).simplify() == S(9)/(25*log(3)) + S(51)/25 | ||
scipy = import_module('scipy') | ||
if not scipy: | ||
skip('Scipy is not installed. Abort tests') | ||
else: | ||
with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | ||
assert next(sample(D)) in D.pspace.set | ||
samp = next(sample(D, size=5)) | ||
for sam in samp: | ||
assert sam in D.pspace.set | ||
|
||
def test_finite_mixture(): | ||
X = Hypergeometric('X', 10, 5, 5) | ||
Y = BetaBinomial('Y', 5, 1, 2) | ||
Z = Binomial('Z', 5, S(2)/3) | ||
B = Binomial('B', 2, S(1)/3) | ||
wts = list(density(B).dict.values()) | ||
D = Mixture('D', wts, [X, Y, Z]) | ||
assert D.pspace.distribution.is_Finite | ||
assert D.pspace.domain.set == {0, 1, 2, 3, 4, 5} | ||
assert sum(density(D).dict.values()).evalf() == 1 | ||
assert E(D) == S(40)*beta(2, 6)/9 + S(40)*beta(6, 2)/9 + S(160)*beta(3, 5 | ||
)/9 + S(160)*beta(5, 3)/9 + S(80)*beta(4, 4)/3 + S(40)/27 | ||
assert P(D < 5) == S(40)*beta(5, 3)/9 + S(80)*beta(4, 4)/9 + S(80)*beta(3, 5 | ||
)/9 + S(40)*beta(2, 6)/9 + S(10198)/15309 | ||
scipy = import_module('scipy') | ||
if not scipy: | ||
skip('Scipy is not installed. Abort tests') | ||
else: | ||
with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | ||
assert next(sample(D)) in D.pspace.set | ||
samp = next(sample(D, size=5)) | ||
for sam in samp: | ||
assert sam in D.pspace.set | ||
|
||
def test_mixture_raises(): | ||
wts = [1, 2, 3] | ||
X = Hypergeometric('X', 10, 5, 5) | ||
Y = BetaBinomial('Y', 5, 1, 2) | ||
Z = Binomial('Z', 5, S(2)/3) | ||
raises(ValueError, lambda: Mixture('D', wts, [X, Y, Z])) | ||
raises(TypeError, lambda: Mixture('D', [S(2)/5, S(3)/5], [X, y])) | ||
Z = Binomial('Z', 2, S(2)/3) | ||
raises(ValueError, lambda: Mixture('D', wts, [X, Y, Z])) | ||
raises(ValueError, lambda: Mixture('D', [-S(2)/5, S(3)/5], [X, Y])) | ||
raises(ValueError, lambda: Mixture('D', [S(2)/5, S(2)/5, S(1)/5], [X, Y])) |