diff --git a/doc/pdfs.rst b/doc/pdfs.rst index 0ad8d28..c4005c9 100644 --- a/doc/pdfs.rst +++ b/doc/pdfs.rst @@ -5,6 +5,13 @@ Probability Density Functions .. automodule:: pybayes.pdfs :no-members: +Random Variables and their Components +===================================== + +.. autoclass:: RV + +.. autoclass:: RVComp + Probability Density Function prototype ====================================== diff --git a/pybayes/__init__.py b/pybayes/__init__.py index 7004aee..5632977 100644 --- a/pybayes/__init__.py +++ b/pybayes/__init__.py @@ -7,5 +7,5 @@ Bayesian statistics... TODO """ -from pdfs import CPdf, Pdf, UniPdf, GaussPdf, ProdPdf, MLinGaussCPdf, ProdCPdf +from pdfs import RVComp, RV, CPdf, Pdf, UniPdf, GaussPdf, ProdPdf, MLinGaussCPdf, ProdCPdf from kalman import Kalman diff --git a/pybayes/pdfs.pxd b/pybayes/pdfs.pxd index 6c94011..72dcbc0 100644 --- a/pybayes/pdfs.pxd +++ b/pybayes/pdfs.pxd @@ -8,6 +8,19 @@ cimport cython from numpywrap cimport * +cdef class RVComp(object): + cdef readonly int dimension + cdef readonly str name + + +cdef class RV(object): + cdef readonly int dimension + cdef readonly str name + cdef readonly list components + + cpdef bint contains(self, RVComp component) + + cdef class CPdf(object): cpdef int shape(self) except -1 cpdef int cond_shape(self) except -1 diff --git a/pybayes/pdfs.py b/pybayes/pdfs.py index 047cbc4..9a089d4 100644 --- a/pybayes/pdfs.py +++ b/pybayes/pdfs.py @@ -16,6 +16,57 @@ from numpywrap import * +class RVComp(object): + """Atomic component of a random value""" + + def __init__(self, name, dimension): + """Initialise new component of a random variable. It will be named + `name` with dimension `dimension`""" + + self.name = str(name) + if not isinstance(dimension, int): + raise TypeError("dimension must be integer (int)") + if dimension < 1: + raise ValueError("dimension must be non-zero positive") + self.dimension = dimension + + +class RV(object): + """Representation of a random variable made of one or more components + (see RVComp)""" + + def __init__(self, *components): + """Initialise new random variable made of one or more components. You may + also pass one or more existing RVs, whose components will be reused""" + self.dimension = 0 + self.name = '[' + self.components = [] + for component in components: + if isinstance(component, RVComp): + self._add_component(component) + elif isinstance(component, RV): + for subcomp in component.components: + self._add_component(subcomp) + else: + raise TypeError('component ' + component + ' is neither an instance ' + + 'of RVComp or RV') + self.name = self.name[:-2] + ']' + + def _add_component(self, component): + # TODO: check if component is already contained? (does it matter somewhere?) + self.components.append(component) + self.dimension += component.dimension + self.name += component.name + ", " + + def contains(self, component): + """Return True is this random value contains the exact same instance of + the `component`""" + for comp in self.components: + if id(comp) == id(component): + return True + return False + + class CPdf(object): """Base class for all Conditional Probability Density Functions diff --git a/pybayes/tests/test_pdfs.py b/pybayes/tests/test_pdfs.py index 7f8f309..239deac 100644 --- a/pybayes/tests/test_pdfs.py +++ b/pybayes/tests/test_pdfs.py @@ -12,6 +12,36 @@ from support import PbTestCase +class TestRVComp(PbTestCase): + """Test random variable component""" + + def test_init(self): + rvcomp = pb.RVComp("pretty name", 123) + self.assertEquals(rvcomp.name, "pretty name") + self.assertEquals(rvcomp.dimension, 123) + + def test_invalid_init(self): + self.assertRaises(TypeError, pb.RVComp, "def", 0.45) + self.assertRaises(TypeError, pb.RVComp, "def", "not a number") + self.assertRaises(ValueError, pb.RVComp, "abc", -1) + + +class TestRV(PbTestCase): + """Test random variable representation""" + + def test_init(self): + comp_a = pb.RVComp("a", 1) + comp_b = pb.RVComp("b", 2) + rv_1 = pb.RV(comp_a, comp_b) + self.assertEquals(rv_1.name, "[a, b]") + self.assertEquals(rv_1.dimension, 3) + self.assertTrue(rv_1.contains(comp_a)) + self.assertTrue(rv_1.contains(comp_b)) + rv_2 = pb.RV(rv_1) + self.assertEquals(rv_2.name, "[a, b]") + self.assertEquals(rv_2.dimension, 3) + + class TestCpdf(PbTestCase): """Test abstract class CPdf"""