# Probability Monad

A probability monad, in the context of discrete probability distributions,
models computations that yield a distribution over possible results
rather than a single result.

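It obeys the three monad laws, where `pure` wraps a value in a distribution
that assigns it probability 1 and `bind` chains a distribution with a
function that returns a distribution: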

**1. Left Identity**

```python
pure(x).bind(f) == f(x)
```

**2. Right Identity**

```python
m.bind(pure) == m
```

**3. Associativity**

```python
m.bind(f).bind(g) == m.bind(lambda x: f(x).bind(g))
```
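
The laws can be checked numerically on a small example. The following is a minimal sketch, not the class defined below: it models a distribution as a plain `dict`, and the free functions `pure`, `bind`, and `approx_equal` are hypothetical helpers invented for this illustration.

```python
import math
from collections import defaultdict

# Hypothetical helpers for illustration only: a distribution is a plain dict
# mapping outcome -> probability.
def pure(x):
    return {x: 1.0}

def bind(m, f):
    out = defaultdict(float)
    for outcome, p in m.items():
        for inner_outcome, q in f(outcome).items():
            out[inner_outcome] += p * q
    return dict(out)

def approx_equal(a, b):
    # Same outcomes, and probabilities equal up to floating-point tolerance.
    return a.keys() == b.keys() and all(math.isclose(a[k], b[k]) for k in a)

m = {1: 0.25, 2: 0.75}
f = lambda x: {x: 0.5, x + 1: 0.5}
g = lambda x: {x * 10: 0.1, x * 10 + 1: 0.9}

left_identity = approx_equal(bind(pure(1), f), f(1))
right_identity = approx_equal(bind(m, pure), m)
associativity = approx_equal(bind(bind(m, f), g), bind(m, lambda x: bind(f(x), g)))
print(left_identity, right_identity, associativity)
```

Checking one example does not prove the laws, but it is a quick sanity test for any candidate `bind` implementation.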

In [1]:
import collections
import random
import math
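from typing import TypeVar, Callable, Dict, Generic, Union, Tuple, Iterator

A = TypeVar('A')
B = TypeVar('B')

# Generic[A] makes annotations such as 'ProbabilityMonad[A]' meaningful to type checkers
class ProbabilityMonad(Generic[A]):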
    """
    A Monad for discrete probability distributions.

    Represents a probability distribution as a mapping from possible outcomes
    to their probabilities.

    The probabilities for all outcomes must sum to 1.0 (or very close to it
    due to floating point arithmetic).
    """

    # A small epsilon for floating point comparisons
    _EPSILON = 1e-9

    def __init__(self, distribution: Dict[A, float]):
        """
        Initializes a ProbabilityMonad with a given distribution.

        Args:
            distribution (Dict[A, float]): A dictionary where keys are outcomes
                                            and values are their probabilities.
                                            Probabilities must be non-negative.
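        Raises:
            TypeError: If `distribution` is not a dict or a probability is
                       not a number.
            ValueError: If any probability is negative, or if normalization
                        fails.

        Note:
            If the probabilities do not sum to 1.0, a warning is printed and
            they are rescaled to sum to 1.0 rather than rejected.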
        """
        if not isinstance(distribution, dict):
            raise TypeError("Distribution must be a dictionary.")

        normalized_dist = {}
        for outcome, prob in distribution.items():
            if not isinstance(prob, (int, float)):
                raise TypeError(f"Probability for outcome '{outcome}' is not a number: {prob}")
            if prob < -self._EPSILON:
                # Allow for tiny negative due to fp errors
                raise ValueError(f"Probability for outcome '{outcome}' cannot be negative: {prob}")
            if prob > self._EPSILON:
                # Only include non-zero probabilities
                normalized_dist[outcome] = float(prob)

        total_prob = sum(normalized_dist.values())

        if not math.isclose(total_prob, 1.0, rel_tol=self._EPSILON, abs_tol=self._EPSILON):
            if total_prob < self._EPSILON:
                # Handle empty or all-zero distributions
                self._distribution = {}
            else:
                # Attempt to normalize if it's close but not exact
                # This can happen if the input isn't perfectly normalized
                print(f"Warning: Probabilities sum to {total_prob}. Attempting to normalize.")
                self._distribution = {k: v / total_prob for k, v in normalized_dist.items()}
                if not math.isclose(sum(self._distribution.values()), 1.0, rel_tol=self._EPSILON, abs_tol=self._EPSILON):
                     raise ValueError(f"Probabilities must sum to 1.0. Got {total_prob} and normalization failed.")
        else:
            self._distribution = normalized_dist

        # Filter out extremely small probabilities for cleaner representation and computation
        self._distribution = {
            k: v for k, v in self._distribution.items() if v > self._EPSILON
        }

        # Re-normalize if filtering changed sum slightly
        current_sum = sum(self._distribution.values())
        if not math.isclose(current_sum, 1.0, rel_tol=self._EPSILON, abs_tol=self._EPSILON):
            if current_sum > self._EPSILON: # Avoid division by zero
                self._distribution = {k: v / current_sum for k, v in self._distribution.items()}


    @classmethod
    def pure(cls, value: A) -> 'ProbabilityMonad[A]':
        """
        The 'pure' or 'unit' function.
        Creates a ProbabilityMonad representing a certain outcome.

        Args:
            value (A): The certain outcome.

        Returns:
            ProbabilityMonad[A]: A monad where `value` has probability 1.0.
        """
        return cls({value: 1.0})

    @classmethod
    def from_weighted_choices(cls, choices: Dict[A, Union[int, float]]) -> 'ProbabilityMonad[A]':
        """
        Creates a ProbabilityMonad from a dictionary of choices with raw weights.
        The weights will be normalized to probabilities.

        Args:
            choices (Dict[A, Union[int, float]]): A dictionary mapping outcomes
                                                  to their relative weights.

        Returns:
            ProbabilityMonad[A]: A monad representing the normalized distribution.
        """
        total_weight = sum(choices.values())
        if total_weight == 0:
            raise ValueError("Total weight must be greater than zero.")
        normalized_dist = {outcome: weight / total_weight for outcome, weight in choices.items()}
        return cls(normalized_dist)

    def bind(self, func: Callable[[A], 'ProbabilityMonad[B]']) -> 'ProbabilityMonad[B]':
        """
        The 'bind' operation (`>>=`).
        Applies a function that returns another ProbabilityMonad to each outcome
        of the current monad, effectively flattening the nested distributions.

        Args:
            func (Callable[[A], ProbabilityMonad[B]]): A function that takes an
                                                        outcome from the current
                                                        distribution and returns
                                                        a new ProbabilityMonad.

        Returns:
            ProbabilityMonad[B]: A new monad representing the combined, flattened
                                 distribution.
        """
        new_distribution = collections.defaultdict(float)
        for outer_outcome, outer_prob in self._distribution.items():
            if outer_prob == 0:
                continue

            inner_monad = func(outer_outcome)
            if not isinstance(inner_monad, ProbabilityMonad):
                raise TypeError(f"The function passed to bind must return a ProbabilityMonad, but got {type(inner_monad)}")

            for inner_outcome, inner_prob in inner_monad.distribution.items():
                new_distribution[inner_outcome] += outer_prob * inner_prob

        return ProbabilityMonad(dict(new_distribution)) # Convert defaultdict back to dict

    def map_f(self, func: Callable[[A], B]) -> 'ProbabilityMonad[B]':
        """
        The 'fmap' operation.
        Applies a function to each outcome of the current monad. Each outcome
        keeps its probability; if `func` maps several outcomes to the same
        value, their probabilities are summed.

        Args:
            func (Callable[[A], B]): A function to apply to each outcome.

        Returns:
            ProbabilityMonad[B]: A new monad with transformed outcomes.
        """
        # fmap is just a special case of bind: m >>= (lambda x: pure(f x))
        return self.bind(lambda x: ProbabilityMonad.pure(func(x)))

    @property
    def distribution(self) -> Dict[A, float]:
        """Returns the underlying probability distribution."""
        return self._distribution.copy()

    def expected_value(self) -> float:
        """
        Calculates the expected value of the distribution.
        This only makes sense if the outcomes are numerical.
        """
        if not self._distribution:
            return 0.0

        if not all(isinstance(k, (int, float)) for k in self._distribution.keys()):
            raise TypeError("Cannot calculate expected value: not all outcomes are numerical.")

        return sum(outcome * prob for outcome, prob in self._distribution.items())

    def sample(self, n: int = 1) -> Union[A, Tuple[A, ...]]:
        """
        Draws `n` random samples from the distribution.

        Args:
            n (int): The number of samples to draw. Defaults to 1.

        Returns:
            Union[A, Tuple[A, ...]]: A single sample if n=1, otherwise a tuple of samples.
        """
        if not self._distribution:
            if n == 1:
                return None
            return tuple([None] * n) # Return None for empty distribution

        # random.choices takes the population and its weights as sequences.
        # The empty case was already handled above.
        outcomes = list(self._distribution.keys())
        probabilities = list(self._distribution.values())

        samples = random.choices(outcomes, weights=probabilities, k=n)

        if n == 1:
            return samples[0]
        return tuple(samples)

    def __repr__(self) -> str:
        return f"ProbabilityMonad({self._distribution})"

    def __str__(self) -> str:
        if not self._distribution:
            return "ProbabilityMonad(Empty Distribution)"
        
        # Sort for consistent output
        sorted_dist = sorted(self._distribution.items(), key=lambda item: item[1], reverse=True)
        
        s = "Probability Distribution:\n"
        for outcome, prob in sorted_dist:
            s += f"  {repr(outcome)}: {prob:.6f}\n"
        return s
    
    # Optional: Operator overloading for syntactic sugar
    def __rshift__(self, func: Callable[[A], 'ProbabilityMonad[B]']) -> 'ProbabilityMonad[B]':
        """
        Overloads the '>>' operator for bind (monadic sequencing).
        The spelling is borrowed from Haskell's `>>=`; unlike Haskell's own
        `>>`, the result of the left-hand side is passed to `func`.
        `m >> f` is equivalent to `m.bind(f)`
        """
        return self.bind(func)

    def __or__(self, func: Callable[[A], B]) -> 'ProbabilityMonad[B]':
        """
        Overloads the '|' operator for fmap (functorial mapping).
        `m | f` is equivalent to `m.map_f(f)`
        """
        return self.map_f(func)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ProbabilityMonad):
            return NotImplemented
        # Compare distributions, accounting for float precision
        if len(self._distribution) != len(other._distribution):
            return False
        
        for k, v in self._distribution.items():
            if k not in other._distribution or not math.isclose(v, other._distribution[k], rel_tol=self._EPSILON, abs_tol=self._EPSILON):
                return False
        return True

    def __iter__(self) -> Iterator[Tuple[A, float]]:
        """Allows iteration over (outcome, probability) pairs."""
        return iter(self._distribution.items())

In [2]:
pm = ProbabilityMonad({"twilight": 0.8, "pinkie": 0.2})
pm

ProbabilityMonad({'twilight': 0.8, 'pinkie': 0.2})

## Examples

### Example 1: Fair Coin Flip

In [3]:
coin = ProbabilityMonad.from_weighted_choices({"Heads": 1, "Tails": 1})

print("1. Fair Coin:")
print(coin)

print(f"Sample coin (1): {coin.sample()}")
print(f"Sample coin (5): {coin.sample(5)}\n")

1. Fair Coin:
Probability Distribution:
  'Heads': 0.500000
  'Tails': 0.500000

Sample coin (1): Heads
Sample coin (5): ('Heads', 'Tails', 'Heads', 'Tails', 'Heads')
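
Sampled output like the above changes from run to run. A minimal sketch of repeatable sampling follows, assuming a dedicated `random.Random` instance with a fixed seed (the seed `42` and the names `rng_a`, `rng_b` are arbitrary choices for this illustration, not part of the class):

```python
import random
from collections import Counter

outcomes = ["Heads", "Tails"]
weights = [0.5, 0.5]

# Two generators with the same seed produce the same sequence of draws.
rng_a = random.Random(42)
rng_b = random.Random(42)
draws_a = rng_a.choices(outcomes, weights=weights, k=10_000)
draws_b = rng_b.choices(outcomes, weights=weights, k=10_000)
print(draws_a == draws_b)

# With many draws, the empirical frequency should be close to 0.5.
counts = Counter(draws_a)
freq_heads = counts["Heads"] / len(draws_a)
print(f"Empirical P(Heads): {freq_heads:.3f}")
```

Because `ProbabilityMonad.sample` calls the module-level `random.choices`, calling `random.seed(...)` before sampling should likewise make a notebook run repeatable.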



### Example 2: Loaded Die

In [None]:
loaded_die = ProbabilityMonad.from_weighted_choices({
    1: 1, 2: 1, 3: 1, 4: 1, 5: 1,
    6: 2 # is loaded!!
})

print("2. Loaded Die:")
print(loaded_die)
print(f"Expected value of loaded die: {loaded_die.expected_value():.3f}")
print(f"Sample loaded die (3): {loaded_die.sample(3)}\n")

2. Loaded Die:
Probability Distribution:
  6: 0.285714
  1: 0.142857
  2: 0.142857
  3: 0.142857
  4: 0.142857
  5: 0.142857

Expected value of loaded die: 3.857
Sample loaded die (3): (6, 5, 6)
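
As a hand check of the expected value above (weight 1 for faces 1 to 5 and weight 2 for face 6, so the total weight is 7):

$$
E[X] = \sum_{k=1}^{5} k \cdot \frac{1}{7} + 6 \cdot \frac{2}{7} = \frac{15}{7} + \frac{12}{7} = \frac{27}{7} \approx 3.857
$$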



### Example 3: Certain Outcome

In [None]:
certain_event = ProbabilityMonad.pure("Success")
print("3. Certain Event:")
print(certain_event)
print(f"Sample certain event: {certain_event.sample()}\n")

3. Certain Event:
Probability Distribution:
  'Success': 1.000000

Sample certain event: Success



### Example 4: Mapping

In [15]:
numbers = ProbabilityMonad({ 1: 0.3, 2: 0.5, 3: 0.2 })
squared_numbers = numbers.map_f(lambda x: x * x)
print("4. Squared Numbers (map_f):")
print(numbers)
print(f"Mapped to squares:\n{squared_numbers}")
print(f"Expected value of squared numbers: {squared_numbers.expected_value():.3f}\n")

4. Squared Numbers (map_f):
Probability Distribution:
  2: 0.500000
  1: 0.300000
  3: 0.200000

Mapped to squares:
Probability Distribution:
  4: 0.500000
  1: 0.300000
  9: 0.200000

Expected value of squared numbers: 4.100
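
Squaring is one-to-one on these outcomes, so the probabilities carry over unchanged. When the mapped function sends several outcomes to the same value, their probabilities add up. A minimal sketch with a plain `dict` and a hypothetical helper `map_dist` (not part of the class above):

```python
from collections import defaultdict

def map_dist(dist, func):
    """Hypothetical helper: push a dict-based distribution through func."""
    out = defaultdict(float)
    for outcome, prob in dist.items():
        out[func(outcome)] += prob  # outcomes with the same image are merged
    return dict(out)

numbers = {1: 0.3, 2: 0.5, 3: 0.2}
# 1 and 3 both map to "odd", so "odd" collects 0.3 + 0.2.
parity = map_dist(numbers, lambda x: "odd" if x % 2 else "even")
print(parity)
```

`map_f` should behave the same way, since it is implemented through `bind`, which accumulates probabilities per resulting outcome.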



### Example 5: Mapping with Pipe (|)

In [None]:
squared_numbers_op = numbers | (lambda x: x * x)
print(f"5. Mapped to squares (using | operator):\n{squared_numbers_op}")
assert squared_numbers == squared_numbers_op, "map_f operator not working as expected!"

5. Mapped to squares (using | operator):
Probability Distribution:
  4: 0.500000
  1: 0.300000
  9: 0.200000



### Example 6: Mapping Coin Outcomes to Uppercase

In [None]:
upper_coin = coin.map_f(lambda s: s.upper())
print("6. Uppercase Coin (map_f):")
print(upper_coin)
print("\n")

6. Uppercase Coin (map_f):
Probability Distribution:
  'HEADS': 0.500000
  'TAILS': 0.500000





### Example 7: Conditionals

In [None]:
def roll_if_heads_else_loaded_die(coin_outcome: str) -> ProbabilityMonad:
    if coin_outcome == "Heads":
        return ProbabilityMonad.from_weighted_choices({1:1, 2:1, 3:1, 4:1, 5:1, 6:1}) # Fair die
    elif coin_outcome == "Tails":
        return loaded_die # Our previously defined loaded die
    else:
        # Should not happen with well-formed initial coin distribution
        return ProbabilityMonad.pure(0) 

combined_roll = coin.bind(roll_if_heads_else_loaded_die)
print("7. Coin then Die (bind):")
print(combined_roll)
print(f"Expected value of combined roll: {combined_roll.expected_value():.3f}")
print(f"Sample combined roll (5): {combined_roll.sample(5)}\n")

# Verify probabilities (e.g., P(6) should be 0.5 * (1/6) + 0.5 * (2/7) )
# P(6) = 1/12 + 1/7 = 7/84 + 12/84 = 19/84 = 0.22619
if 6 in combined_roll.distribution:
    print(f"P(6) in combined roll: {combined_roll.distribution[6]:.5f} (Expected: 0.22619)\n")


7. Coin then Die (bind):
Probability Distribution:
  6: 0.226190
  1: 0.154762
  2: 0.154762
  3: 0.154762
  4: 0.154762
  5: 0.154762

Expected value of combined roll: 3.679
Sample combined roll (5): (3, 2, 1, 3, 4)

P(6) in combined roll: 0.22619 (Expected: 0.22619)
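
`bind` here computes the law of total probability: each die face gets the coin-weighted average of its probability under the two dice. Written out for the values used in this example:

$$
P(k) = P(\text{Heads})\,P(k \mid \text{Heads}) + P(\text{Tails})\,P(k \mid \text{Tails})
$$

$$
P(6) = \frac{1}{2}\cdot\frac{1}{6} + \frac{1}{2}\cdot\frac{2}{7} = \frac{19}{84} \approx 0.226190,
\qquad
P(k) = \frac{1}{2}\cdot\frac{1}{6} + \frac{1}{2}\cdot\frac{1}{7} = \frac{13}{84} \approx 0.154762 \quad (k = 1, \dots, 5)
$$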



### Example 8: Roll Two Dice, Sum Results

In [17]:
def roll_die() -> ProbabilityMonad:
    return ProbabilityMonad.from_weighted_choices({1:1, 2:1, 3:1, 4:1, 5:1, 6:1})

# First die roll
first_roll = roll_die()

# Now, bind the first roll to a function that performs the second roll
# and combines the results.
sum_of_two_dice = first_roll.bind(
    lambda r1: roll_die().bind( # For each result of r1, bind the second roll
        lambda r2: ProbabilityMonad.pure(r1 + r2) # Then pure the sum
    )
)
print("8. Sum of Two Dice (bind chain):")
print(sum_of_two_dice)
print(f"Expected value of sum of two dice: {sum_of_two_dice.expected_value():.3f}\n")


8. Sum of Two Dice (bind chain):
Probability Distribution:
  7: 0.166667
  6: 0.138889
  8: 0.138889
  5: 0.111111
  9: 0.111111
  4: 0.083333
  10: 0.083333
  3: 0.055556
  11: 0.055556
  2: 0.027778
  12: 0.027778

Expected value of sum of two dice: 7.000
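
The result can be cross-checked without any monad machinery by enumerating all 36 equally likely ordered pairs. This is a sketch of an independent check using exact fractions; the names `faces` and `sum_dist` are chosen for this illustration.

```python
from collections import Counter
from fractions import Fraction
from itertools import product

faces = range(1, 7)
# Each of the 36 ordered pairs (r1, r2) is equally likely for two fair dice.
counts = Counter(r1 + r2 for r1, r2 in product(faces, repeat=2))
sum_dist = {total: Fraction(n, 36) for total, n in sorted(counts.items())}

for total, prob in sum_dist.items():
    print(f"{total:>2}: {prob} = {float(prob):.6f}")

expected = sum(total * prob for total, prob in sum_dist.items())
print(f"Expected value: {expected}")
```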



### Example 9: Using the >> Operator for Bind

In [19]:
sum_of_two_dice_op = roll_die() >> (
    lambda r1: roll_die() >> (lambda r2: ProbabilityMonad.pure(r1 + r2))
)
print(f"9. Sum of Two Dice (using >> operator):\n{sum_of_two_dice_op}")
assert sum_of_two_dice == sum_of_two_dice_op, "bind operator not working as expected!"

9. Sum of Two Dice (using >> operator):
Probability Distribution:
  7: 0.166667
  6: 0.138889
  8: 0.138889
  5: 0.111111
  9: 0.111111
  4: 0.083333
  10: 0.083333
  3: 0.055556
  11: 0.055556
  2: 0.027778
  12: 0.027778



### Example 10: Chained Operations on Apples

Scenario: You have a number of apples. You give away some, then gain some.


In [20]:
# Initial apples: 10 with 50% prob, 12 with 50% prob
initial_apples = ProbabilityMonad({10: 0.5, 12: 0.5})

# First action: Give away 1 or 2 apples (each 50% prob)
def give_away(current_apples: int) -> ProbabilityMonad:
    return ProbabilityMonad({current_apples - 1: 0.5, current_apples - 2: 0.5})

# Second action: Gain 3 or 4 apples (each 50% prob)
def gain_some(current_apples: int) -> ProbabilityMonad:
    return ProbabilityMonad({current_apples + 3: 0.5, current_apples + 4: 0.5})

# Chain the operations
final_apples = initial_apples.bind(give_away).bind(gain_some)
# Or using the operator:
# final_apples_op = initial_apples >> give_away >> gain_some

print("10. Chained Operations (Apples):")
print(final_apples)
print(f"Expected final apples: {final_apples.expected_value():.3f}\n")

10. Chained Operations (Apples):
Probability Distribution:
  12: 0.250000
  13: 0.250000
  14: 0.250000
  11: 0.125000
  15: 0.125000

Expected final apples: 13.000
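
The expected value can be cross-checked by linearity of expectation, since the final count is the initial count minus the amount given away plus the amount gained:

$$
E[\text{final}] = E[\text{initial}] - E[\text{given}] + E[\text{gained}] = \frac{10 + 12}{2} - \frac{1 + 2}{2} + \frac{3 + 4}{2} = 11 - 1.5 + 3.5 = 13
$$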



### Example 11: Error Handling

In [22]:
try:
    bad_prob = ProbabilityMonad({"A": 0.6, "B": 0.3}) # Sums to 0.9
    print("This should warn, then normalize:")
    print(bad_prob)
except ValueError as e:
    print(f"Caught expected error for bad sum: {e}\n")

try:
    negative_prob = ProbabilityMonad({"A": 0.5, "B": -0.5, "C": 1.0})
except ValueError as e:
    print(f"Caught expected error for negative probability: {e}\n")

try:
    # .bind(_) should return a Monad
    non_monad_return = ProbabilityMonad.pure(1).bind(lambda x: x * 2)
except TypeError as e:
    print(f"Caught expected error for non-monad return from bind: {e}\n")

This should warn, then normalize:
Probability Distribution:
  'A': 0.666667
  'B': 0.333333

Caught expected error for negative probability: Probability for outcome 'B' cannot be negative: -0.5

Caught expected error for non-monad return from bind: The function passed to bind must return a ProbabilityMonad, but got <class 'int'>
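
The constructor reports an unnormalized distribution with `print`, which calling code cannot intercept. One possible alternative, sketched here under the assumption that a standard-library warning is acceptable, is `warnings.warn`. The function `normalize` below is a hypothetical stand-in for the constructor's normalization step, not the class's actual code.

```python
import warnings

def normalize(weights):
    """Hypothetical helper: scale positive weights so they sum to 1.0."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("Total weight must be greater than zero.")
    if abs(total - 1.0) > 1e-9:
        # A warning can be filtered, logged, or turned into an error by callers.
        warnings.warn(f"Probabilities sum to {total}; normalizing.", RuntimeWarning)
    return {k: v / total for k, v in weights.items()}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not suppressed
    dist = normalize({"A": 0.6, "B": 0.3})

print(dist)
print([str(w.message) for w in caught])
```

With this approach a caller could also escalate the warning to an exception via `warnings.simplefilter("error")`, which would make a `try`/`except` around a badly normalized input meaningful.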



## Limitations

**Discrete Only.** This implementation handles discrete probability
distributions with finitely many outcomes, since the distribution is
stored in a `dict`. Countably infinite supports cannot be represented,
and continuous probability distributions would require significantly
different mathematical tools (measure theory, integrals, etc.) and a
much more complex representation (e.g., functions representing
PDFs/CDFs or Monte Carlo simulations).

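**Performance for Large Distributions.** `bind` calls `func` once per
outcome and visits every (outer, inner) outcome pair, so its cost grows
with the product of the support sizes, and chained binds can enlarge the
support at each step. For extremely large distributions, alternative
data structures or a sampling-based approximation might be needed.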

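**Floating Point Precision.** Probability calculations are inherently
susceptible to floating-point errors. I've added an `_EPSILON` tolerance
(1e-9) for comparisons and included normalization logic to mitigate this.
The same threshold also discards any outcome whose probability is at or
below `_EPSILON`, so genuinely rare events are silently dropped.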
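
One way to sidestep rounding entirely, when the inputs are rational, is to carry probabilities as `fractions.Fraction` values. The sketch below assumes a plain-`dict` representation and a hypothetical helper `bind_exact`; it is not how the class above works, which converts every probability to `float`.

```python
from collections import defaultdict
from fractions import Fraction

def bind_exact(dist, func):
    """Hypothetical helper: bind over dicts whose probabilities are Fractions."""
    out = defaultdict(Fraction)  # Fraction() is an exact zero
    for outcome, prob in dist.items():
        for inner_outcome, inner_prob in func(outcome).items():
            out[inner_outcome] += prob * inner_prob
    return dict(out)

coin = {"Heads": Fraction(1, 2), "Tails": Fraction(1, 2)}
fair_die = {face: Fraction(1, 6) for face in range(1, 7)}
loaded_die = {face: Fraction(2 if face == 6 else 1, 7) for face in range(1, 7)}

# Same experiment as Example 7: fair die on Heads, loaded die on Tails.
combined = bind_exact(coin, lambda side: fair_die if side == "Heads" else loaded_die)
print(combined[6])             # exact rational probability of rolling a 6
print(sum(combined.values()))  # exact total, no tolerance needed
```

The trade-off is speed and growing numerators and denominators in long chains, so this is mainly useful for small, exact models and for testing.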

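**Type Safety.** Python's type hints are not enforced at runtime. The
only runtime check is in `bind`, which verifies
`isinstance(inner_monad, ProbabilityMonad)` on the value returned by
`func`. `map_f` accepts any function, and outcomes must be hashable
because they are used as dictionary keys.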