Skip to content

Commit

Permalink
pdfs.GammaPdf: implement sample() incl. tests
Browse files Browse the repository at this point in the history
  • Loading branch information
strohel committed Aug 20, 2012
1 parent fb82839 commit 9360a1c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pybayes/pdfs.py
Expand Up @@ -678,6 +678,9 @@ def eval_log(self, x, cond = None):
return -math.lgamma(self.k) - self.k*math.log(self.theta) + (self.k - 1)*float('-inf') return -math.lgamma(self.k) - self.k*math.log(self.theta) + (self.k - 1)*float('-inf')
return -math.lgamma(self.k) - self.k*math.log(self.theta) + (self.k - 1)*math.log(x[0]) - x[0]/self.theta return -math.lgamma(self.k) - self.k*math.log(self.theta) + (self.k - 1)*math.log(x[0]) - x[0]/self.theta


def sample(self, cond = None):
return random.gamma(self.k, self.theta, size=(1,))



class AbstractEmpPdf(Pdf): class AbstractEmpPdf(Pdf):
r"""An abstraction of empirical probability density functions that provides common methods such r"""An abstraction of empirical probability density functions that provides common methods such
Expand Down
13 changes: 13 additions & 0 deletions pybayes/tests/test_pdfs.py
Expand Up @@ -518,6 +518,19 @@ def test_eval_log(self):
self.assertApproxEqual(exp(self.gamma1.eval_log(x)), exp_results[i][0]) self.assertApproxEqual(exp(self.gamma1.eval_log(x)), exp_results[i][0])
self.assertApproxEqual(exp(self.gamma2.eval_log(x)), exp_results[i][1]) self.assertApproxEqual(exp(self.gamma2.eval_log(x)), exp_results[i][1])


@stochastic
def test_sample(self):
"""Test GaussPdf.sample() mean and variance."""
N = 500 # number of samples
emp1 = pb.EmpPdf(self.gamma1.samples(N)) # Emipirical pdf computes sample mean and variance for us
emp2 = pb.EmpPdf(self.gamma2.samples(N)) # Emipirical pdf computes sample mean and variance for us

self.assertTrue(np.all(abs(emp1.mean() - self.gamma1.mean()) <= 0.4))
self.assertTrue(np.all(abs(emp2.mean() - self.gamma2.mean()) <= 0.3))

self.assertTrue(np.all(abs(emp1.variance() - self.gamma1.variance()) <= 2.2))
self.assertTrue(np.all(abs(emp2.variance() - self.gamma2.variance()) <= 1.3))



class TestEmpPdf(PbTestCase): class TestEmpPdf(PbTestCase):
"""Test empirical pdf""" """Test empirical pdf"""
Expand Down

0 comments on commit 9360a1c

Please sign in to comment.