## Let's do the urn problem again with a bigger set.

A jar has $t$ tokens in it.

Of these, $r$ are red and $t-r$ are black.

We select tokens from the jar, with replacement.

We know $t$, but we don't know $r$.

In [None]:
import random
 
from fractions import Fraction
from enum import Enum
from plotly import graph_objects as go
from plotly.subplots import make_subplots


In [2]:
# Total number of tokens in every jar used below.
TOKENS = 100
# One shared generator for all jars; it is created without a seed.
rng = random.Random()


class Token(Enum):
    """Colour of a single token drawn from the jar."""

    red = "red"
    black = "black"

    def __repr__(self) -> str:
        # Display just the bare colour (e.g. ``red``) rather than the
        # default ``<Token.red: 'red'>``.
        return str(self.value)


class Jar:
    """A jar of ``TOKENS`` tokens, ``red_tokens`` of which are red.

    Draws use the module-level ``rng``. The red count is masked in ``repr`` so
    that it is not revealed by accident while experimenting.
    """

    def __init__(self, red_tokens=None):
        """Create a jar.

        :param red_tokens: number of red tokens, in ``0..TOKENS`` inclusive.
            When omitted it is drawn uniformly at random from that range.
        :raises ValueError: if ``red_tokens`` lies outside ``0..TOKENS``.
        """
        if red_tokens is None:
            red_tokens = rng.randint(0, TOKENS)
        elif not 0 <= red_tokens <= TOKENS:
            raise ValueError(
                f"red_tokens must be between 0 and {TOKENS}, got {red_tokens}"
            )
        self.red_tokens = red_tokens

    def __repr__(self) -> str:
        return "Jar(red_tokens=****)"

    def observe(self):
        """Draw one token uniformly at random, with replacement.

        Returns ``Token.red`` with probability exactly ``red_tokens / TOKENS``
        and ``Token.black`` otherwise.
        """
        # randrange(TOKENS) picks one of the TOKENS token positions
        # 0..TOKENS-1; the first ``red_tokens`` positions are the red ones.
        # (The previous randint(0, TOKENS) had TOKENS + 1 outcomes, giving
        # P(red) = (red_tokens + 1) / (TOKENS + 1): a jar with no red tokens
        # could still produce red.)
        if rng.randrange(TOKENS) < self.red_tokens:
            return Token.red
        return Token.black

In this problem, a *hypothesis* is an integer between 0 and 100 (`TOKENS`), inclusive,
and represents the number of red tokens we believe are in the jar.

In [3]:
def p_e_given_h(evidence: Token, hypothesis: int) -> Fraction:
    """Likelihood P(E | H).

    Return the probability of drawing ``evidence`` from a jar that holds
    exactly ``hypothesis`` red tokens out of ``TOKENS``.
    """
    assert isinstance(evidence, Token)
    assert 0 <= hypothesis <= TOKENS

    red_share = Fraction(hypothesis, TOKENS)
    black_share = Fraction(TOKENS - hypothesis, TOKENS)
    if evidence is Token.red:
        return red_share
    if evidence is Token.black:
        return black_share
    # Only reachable when asserts are disabled; mirrors a failed dict lookup.
    raise KeyError(evidence)


# Sanity-check the likelihood at the two extreme hypotheses.
# An all-red jar (hypothesis == TOKENS) can never yield black...
assert p_e_given_h(Token.black, TOKENS) == 0
# ...while an all-black jar (hypothesis == 0) always does.
assert p_e_given_h(Token.black, 0) == 1
# Symmetrically, an all-red jar always yields red...
assert p_e_given_h(Token.red, TOKENS) == 1
# ...and an all-black jar never does.
assert p_e_given_h(Token.red, 0) == 0


In [4]:
def p_e(evidence, priors):
    """Return the marginal probability P(E) of observing ``evidence``.

    My current belief that ``evidence`` occurs at all, obtained by summing
    (marginalising) over every hypothesis::

        P(E) = sum over H of P(E | H) * P(H)

    :param evidence: the observed ``Token``.
    :param priors: mapping of hypothesis (red-token count) to its prior
        probability; the probabilities must sum to exactly 1.
    :returns: P(E), a ``Fraction`` when the priors are ``Fraction`` values.
    """
    assert sum(priors.values()) == 1, priors.values()

    # A generator is enough: sum() consumes it directly without an
    # intermediate list.
    return sum(
        p_e_given_h(evidence, hypothesis) * prior
        for hypothesis, prior in priors.items()
    )


def infer(evidence: Token, priors: dict[int, Fraction]) -> dict[int, Fraction]:
    """Apply Bayes' rule to update ``priors`` after observing ``evidence``.

        P(H | E) = P(E | H) * P(H) / P(E)

    :param evidence: the observed ``Token``.
    :param priors: mapping of hypothesis (red-token count) to prior
        probability; the probabilities must sum to exactly 1.
    :returns: a new mapping from the same hypotheses (int keys) to their
        posterior probabilities. ``priors`` itself is not modified.
    :raises ZeroDivisionError: if ``evidence`` has zero probability under
        ``priors`` (every hypothesis carrying prior mass rules it out).
    """
    assert sum(priors.values()) == 1

    # P(E) is the same normaliser for every hypothesis, so compute it once.
    evidence_probability = p_e(evidence, priors)
    return {
        hypothesis: p_e_given_h(evidence, hypothesis) * p_h / evidence_probability
        for hypothesis, p_h in priors.items()
    }


def inference(jar: Jar, priors, hypothesis):
    """Yield the evolving belief in one ``hypothesis`` as evidence arrives.

    Each item is a dict with keys ``evidence`` and ``prior``. The first item
    carries ``evidence=None`` and the initial probability of ``hypothesis``.
    Every later item draws one token from ``jar``, applies Bayes' rule to the
    whole belief distribution, and reports the drawn token with the updated
    probability of ``hypothesis``. The generator never stops by itself.
    """
    beliefs = priors
    latest = None
    while True:
        yield dict(evidence=latest, prior=beliefs[hypothesis])
        latest = jar.observe()
        beliefs = infer(latest, beliefs)

In [5]:
jar = Jar()

steps = 100
# Start from a uniform prior: every red-token count 0..TOKENS is equally likely.
priors = {count: Fraction(1, 1 + TOKENS) for count in range(0, 1 + TOKENS)}
assert sum(priors.values()) == 1

# Draw `steps` tokens, updating the belief distribution after each draw.
for _ in range(steps):
    priors = infer(jar.observe(), priors)

# Report the most probable red-token count (the maximum a posteriori guess).
count, p = max(priors.items(), key=lambda pair: pair[1])
# Convert to float before formatting: Fraction only supports "f" format
# specs from Python 3.12 onwards.
print(f"Best guess after {steps} steps: {count}: {float(p):.2f}")
print(f"Actual: {jar.red_tokens}")

Best guess after 100 steps: 69: 0.09
Actual: 73


In [6]:
def _belief_bar(beliefs, name):
    """Build a bar trace of probability (y) against red-token count (x)."""
    return go.Bar(
        x=[int(count) for count in beliefs],
        y=[float(p) for p in beliefs.values()],
        name=name,
    )


def simulate(steps, jar: Jar, priors=None):
    """Run ``steps`` Bayesian updates against ``jar`` and plot every belief state.

    :param steps: number of tokens to draw (0 plots only the initial prior).
    :param jar: the jar to draw tokens from.
    :param priors: initial mapping of red-token count to probability. The
        values must sum to exactly 1 and all be strictly positive. Defaults
        to a uniform prior over ``0..TOKENS``.
    :returns: a plotly figure with ``1 + steps`` stacked bar charts: the
        initial prior on top, then the posterior after each draw. All rows
        share one y-axis range so their heights are directly comparable.
    """
    if priors is None:
        priors = {count: Fraction(1, 1 + TOKENS) for count in range(0, 1 + TOKENS)}
    assert sum(priors.values()) == 1
    assert all(prior > 0 for prior in priors.values())

    fig = make_subplots(rows=1 + steps, cols=1)
    fig.add_trace(_belief_bar(priors, "Initial"), row=1, col=1)
    # Seed the shared y-range with the initial prior. Its bars are then never
    # clipped, and steps == 0 still yields a non-empty range.
    max_y = float(max(priors.values()))

    for step in range(1, 1 + steps):
        priors = infer(jar.observe(), priors)
        fig.add_trace(_belief_bar(priors, f"Step {step}"), row=1 + step, col=1)
        max_y = max(max_y, float(max(priors.values())))

    for row in range(1, 2 + steps):
        fig.update_yaxes(range=(0, max_y), row=row, col=1)

    fig.update_layout(
        margin=dict(l=10, r=10, t=30, b=10),
        # 100px per subplot row, including the initial-prior row.
        height=(1 + steps) * 100,
    )

    return fig

In [7]:
def opinionated_priors():
    """Return a prior whose probability grows linearly with the red-token count.

    Count ``c`` (for ``c`` in ``0..TOKENS``) gets weight ``c + 1``. The
    weights are normalised to sum to exactly 1, so an all-red jar is
    ``TOKENS + 1`` times as likely as an all-black one, and every count keeps
    a non-zero probability.
    """
    weights = {count: count + 1 for count in range(TOKENS + 1)}
    total = sum(weights.values())
    return {count: Fraction(weight, total) for count, weight in weights.items()}



In [8]:
jar = Jar()
# A prior that leans towards jars with many red tokens. Bind it to `priors`
# so the notebook state reflects the prior actually simulated. (The previous
# geometric prior built here was never passed to `simulate`.)
priors = opinionated_priors()

simulate(jar=jar, steps=20, priors=priors)

In [9]:
# Reveal the jar's hidden red-token count to compare against the plotted posteriors.
jar.red_tokens

96

So even with a quite opinionated prior, the posterior trends towards the correct value.