From 22f67908d9a385ddedc40943a0c3c462b9cd54ce Mon Sep 17 00:00:00 2001 From: Sam Ireland Date: Fri, 16 Mar 2018 20:42:14 +0000 Subject: [PATCH] Add concept of 'or', 'and' and 'outcomes' to events --- inferi/probability.py | 29 +++++++++++++- tests/integration/test_probability.py | 4 ++ tests/unit/test_events.py | 58 +++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 1 deletion(-) diff --git a/inferi/probability.py b/inferi/probability.py index 03b0ecc..0fc0313 100644 --- a/inferi/probability.py +++ b/inferi/probability.py @@ -32,6 +32,18 @@ def __contains__(self, member): if event._outcome == member: return True + def __or__(self, event): + if not isinstance(event, Event): + raise TypeError(f"{event} is not an Event") + return Event(*(self._simple_events | event._simple_events)) + + + def __and__(self, event): + if not isinstance(event, Event): + raise TypeError(f"{event} is not an Event") + return Event(*(self._simple_events & event._simple_events)) + + def simple_events(self): """The set of simple events in this event. @@ -70,6 +82,18 @@ def mutually_exclusive_with(self, event): return not self._simple_events & event._simple_events + def outcomes(self, p=False): + """The set of outcomes that the event's simple events can produce. + + :param bool p: if ``True``, the results will be returned as a dict with + probabilities associated. + + :rtype: ``set`` or ``dict``""" + + if p: return {e._outcome: e._probability for e in self._simple_events} + return set([e._outcome for e in self._simple_events]) + + class SimpleEvent(Event): """Base class: py:class:`.Event` @@ -146,8 +170,11 @@ def simple_events(self): def outcomes(self, p=False): """The set of outcomes that the sample space's simple events can produce. + + :param bool p: if ``True``, the results will be returned as a dict with + probabilities associated. - :rtype: ``set``""" + :rtype: ``set`` or ``dict``""" if p: return {e.outcome(): e.probability() for e in self._simple_events} return set([e.outcome() for e in self._simple_events]) diff --git a/tests/integration/test_probability.py b/tests/integration/test_probability.py index c754e95..2b45a8e 100644 --- a/tests/integration/test_probability.py +++ b/tests/integration/test_probability.py @@ -83,6 +83,10 @@ def test_events(self): self.assertTrue(event2.mutually_exclusive_with(sample_space.event(1))) self.assertFalse(event2.mutually_exclusive_with(sample_space.event(2))) self.assertFalse(event1.mutually_exclusive_with(event2)) + combined = event1 | event2 + self.assertEqual(combined.outcomes(), {2, 4, 5, 6}) + combined = event1 & event2 + self.assertEqual(combined.outcomes(), {2}) # Unfair die sample_space = inferi.SampleSpace(1, 2, 3, 4, 5, 6, p={4: 0.3}) diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py index 946c763..2617c77 100644 --- a/tests/unit/test_events.py +++ b/tests/unit/test_events.py @@ -78,6 +78,48 @@ def test_events_in_event(self): +class EventOrTests(EventTest): + + def test_can_get_event_or(self): + event = Event(*self.simple_events[:3]) + mock_event = Mock(Event) + mock_event._simple_events = set(self.simple_events[3:6]) + new = event | mock_event + self.assertEqual(new._simple_events, set(self.simple_events[:6])) + + + def test_or_needs_event(self): + event = Event(*self.simple_events[:3]) + with self.assertRaises(TypeError): + event | "event" + + + +class EventAndTests(EventTest): + + def test_can_get_event_and(self): + event = Event(*self.simple_events[:4]) + mock_event = Mock(Event) + mock_event._simple_events = set(self.simple_events[3:6]) + new = event & mock_event + self.assertEqual(new._simple_events, set(self.simple_events[3:4])) + + + def test_can_get_empty_event_and(self): + event = Event(*self.simple_events[:3]) + mock_event = Mock(Event) + mock_event._simple_events = set(self.simple_events[3:6]) + new = event & mock_event + self.assertEqual(new._simple_events, set()) + + + def test_and_needs_event(self): + event = Event(*self.simple_events[:3]) + with self.assertRaises(TypeError): + event | "event" + + + class EventSimpleEvents(EventTest): def test_can_get_event_simple_events(self): @@ -103,6 +145,22 @@ def test_can_get_event_probability(self): +class EventOutcomestests(EventTest): + + def test_can_get_outcomes(self): + event = Event(*self.simple_events) + outcomes = event.outcomes() + self.assertEqual(outcomes, set(range(1, 11))) + + + def test_can_get_outcomes_with_odds(self): + event = Event(*self.simple_events) + outcomes = event.outcomes(p=True) + self.assertEqual(outcomes, {i: 5 for i in range(1, 11)}) + + + + class EventMutualExclusivityTests(EventTest): def test_not_mutually_exclusive_if_simple_events_in_common(self):