## Now let's do the german tank problem

An enemy has some tanks.

You don't know how many, but occasionally you capture one.

Each tank has a sequential serial number. (So if you capture tank #50, you can be sure the enemy has produced _at least_ 50 tanks ) 

To make this simple, we imagine that we very kindly give the tanks back after we're done inspecting them.

In [None]:
import random

from fractions import Fraction


import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [17]:
MAXIMUM_TANKS = 100

In [18]:
# This is the only function that needs to change
# from the previous notebook


def p_e_given_h(evidence: int, hypothesis: int) -> Fraction:
    """
    evidence: The tank serial number.
    hypothesis: The number of tanks we believe they have
    """
    if evidence > hypothesis:
        # obviously they can't have produced only M tanks if N>M are observed
        return Fraction(0, 1)
    else:
        return Fraction(1, hypothesis)


def p_e(evidence, priors):
    assert sum(priors.values()) == 1, priors.values()

    return sum(
        [
            p_e_given_h(evidence, hypothesis) * prior
            for hypothesis, prior in priors.items()
        ]
    )


def infer(evidence: str, priors: dict[str, Fraction]) -> dict[str, Fraction]:
    assert sum(priors.values()) == 1
    p_e_ = p_e(evidence, priors)
    return {
        hypothesis: Fraction(p_e_given_h(evidence, hypothesis) * p_h, p_e_)
        for hypothesis, p_h in priors.items()
    }


def inference(tokens, priors, hypothesis):
    yield dict(evidence="", prior=priors[hypothesis])
    while True:
        evidence = random.choice(tokens)
        posteriors = infer(evidence, priors)
        priors = posteriors
        yield dict(evidence=evidence, prior=priors[hypothesis])

In [19]:
def best_interval(priors, width=10):
    intervals = [
        sum(p_h for h, p_h in priors.items() if h >= count if h < count + width)
        for count in range(0, MAXIMUM_TANKS - width)
    ]

    start, confidence = max(enumerate(intervals), key=lambda pair: pair[1])
    end = start + width
    return dict(interval=(start, end), confidence=float(confidence))


In [None]:
tanks = random.randint(1, MAXIMUM_TANKS)

steps = 10
priors = {count: Fraction(1, MAXIMUM_TANKS) for count in range(1, 1 + MAXIMUM_TANKS)}
assert sum(priors.values()) == 1

fig = make_subplots(
    rows=1 + steps,
    cols=1,
    shared_xaxes=True,
)
fig.add_trace(
    go.Bar(
        x=[int(count) for count, _ in priors.items()],
        y=[float(p) for _, p in priors.items()],
        name="Initial",
    ),
    row=1,
    col=1,
)
fig.update_yaxes(range=(0, 0.2), row=1, col=1)


for step in range(1, 1 + steps):
    evidence = random.randint(1, tanks)
    posteriors = infer(evidence, priors)
    priors = posteriors
    fig.add_trace(
        go.Bar(
            x=[int(count) for count, _ in priors.items()],
            y=[float(p) for _, p in priors.items()],
            name=f"Step {step}",
        ),
        row=1 + step,
        col=1,
    )
    fig.update_yaxes(range=(0, 0.2), row=1 + step, col=1)


fig.update_layout(
    margin=dict(l=10, r=10, t=30, b=10),
    height=1000,
)
fig

In [22]:
best = best_interval(priors, width=10)
print(
    f"The number of tanks is between {best['interval'][0]} and {best['interval'][1]} with confidence {(best['confidence']):.2f} "
)

The number of tanks is between 74 and 84 with confidence 0.73 


In [23]:
tanks

83