# Recovering latent coin probabilities with EM

Simulate 20 repetitions of the following procedure:
- Pick one of two coins; one has P(heads)=0.8, one has P(heads)=0.4
- Flip the chosen coin 5 times, recording 1 for 'heads' and 0 for 'tails'
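
The procedure above can be sketched in plain NumPy, independently of the `emstats` package. This is a minimal sketch, not the package's implementation: the function name, the seed, and the assumption that each coin is picked uniformly at random are illustrative choices that the description does not specify.

```python
import numpy as np

def simulate_coin_sequences(n_sequences=20, n_flips=5, p=(0.8, 0.4), seed=0):
    """Pick a coin per sequence, then flip it n_flips times (1=heads, 0=tails)."""
    rng = np.random.default_rng(seed)
    p = np.asarray(p, dtype=float)
    # Assumption: each coin is equally likely to be picked.
    coin = rng.integers(len(p), size=n_sequences)
    true_p = p[coin]  # heads probability of the coin behind each row
    # A flip is heads when a uniform draw falls below that row's probability.
    flips = (rng.random((n_sequences, n_flips)) < true_p[:, None]).astype(int)
    return flips, true_p

flips, true_p = simulate_coin_sequences()
print(flips.shape)  # one row per sequence, one column per flip
```

Each row of `flips` plays the role of one sequence, and `true_p` is the hidden quantity that the rest of this walkthrough tries to recover.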

In [1]:
from emstats import hiddencoins

sim = hiddencoins.Simulation(
    n_sequences=20,
    n_reps_per_sequence=5,
    p=(0.8, 0.4)
)
data = sim.run()
data.head()

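| | flip_0 | flip_1 | flip_2 | flip_3 | flip_4 | true_p |
|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 0 | 0 | 1 | 0.8 |
| 1 | 1 | 0 | 0 | 1 | 1 | 0.8 |
| 2 | 0 | 0 | 1 | 0 | 1 | 0.4 |
| 3 | 0 | 1 | 0 | 0 | 1 | 0.8 |
| 4 | 0 | 1 | 1 | 0 | 0 | 0.4 |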


Because this is a simulation, we can see the true heads probability behind each sequence. But what if we couldn't? Suppose we want to estimate the heads probability of the coin that generated each sequence (row), knowing only that there are two coins. The expectation-maximization (EM) algorithm is well suited to this kind of latent-variable problem:

In [2]:
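hc = hiddencoins.EM(
    X=data.iloc[:, :-1].values,  # flip columns only; the last column (true_p) is held out
    n_coins=2
)
hc.run()
# Heads probability averaged over the latent coins
data['p_est'] = hc.fitted_bernoulli_p()
# Heads probability of the single most likely latent coin
data['mode_p_est'] = hc.fitted_bernoulli_p(mode=True)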
data.tail()

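| | flip_0 | flip_1 | flip_2 | flip_3 | flip_4 | true_p | p_est | mode_p_est |
|---|---|---|---|---|---|---|---|---|
| 15 | 1 | 0 | 0 | 0 | 0 | 0.4 | 0.497789 | 0.49693 |
| 16 | 0 | 0 | 0 | 1 | 1 | 0.4 | 0.50377 | 0.49693 |
| 17 | 1 | 1 | 1 | 1 | 0 | 0.8 | 0.707683 | 0.888833 |
| 18 | 1 | 1 | 1 | 1 | 1 | 0.8 | 0.851206 | 0.888833 |
| 19 | 1 | 1 | 1 | 1 | 0 | 0.8 | 0.707683 | 0.888833 |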


Here p_est is the heads probability averaged over the latent coins, with each coin's fitted heads probability weighted by the estimated probability that it generated the sequence, while mode_p_est is the heads probability of the single most likely latent coin. With only 20 sequences of 5 flips each, it is not surprising that the estimates have substantial error, but they are at least positively correlated with the truth:

In [3]:
data.true_p.corr(data.p_est)

0.5760558486726696
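
To make the two estimates concrete, here is a minimal NumPy sketch of EM for a mixture of coins. It is an illustration under stated assumptions, not the `hiddencoins.EM` implementation: it assumes each row is a set of independent flips of one coin (so only the heads count per row matters), random initialisation, and a log-likelihood tolerance as the stopping rule. The function name `em_coin_mixture` and its parameters are hypothetical.

```python
import numpy as np

def em_coin_mixture(X, n_coins=2, max_iter=500, tol=1e-8, seed=0):
    """Fit mixing weights and heads probabilities for a mixture of coins."""
    X = np.asarray(X)
    n_flips = X.shape[1]
    heads = X.sum(axis=1)                        # heads count per sequence
    rng = np.random.default_rng(seed)
    theta = rng.uniform(0.2, 0.8, size=n_coins)  # heads probability per coin
    pi = np.full(n_coins, 1.0 / n_coins)         # prior weight per coin
    prev_ll = -np.inf
    for n_iter in range(1, max_iter + 1):
        # E-step: posterior probability of each coin for each sequence.
        th = np.clip(theta, 1e-12, 1 - 1e-12)    # keep the logs finite
        w = np.clip(pi, 1e-12, None)
        log_joint = (np.log(w)[None, :]
                     + heads[:, None] * np.log(th)[None, :]
                     + (n_flips - heads)[:, None] * np.log1p(-th)[None, :])
        log_norm = np.logaddexp.reduce(log_joint, axis=1)
        resp = np.exp(log_joint - log_norm[:, None])
        # M-step: responsibility-weighted heads fraction and coin weights.
        weight = np.maximum(resp.sum(axis=0), 1e-12)
        theta = (resp * heads[:, None]).sum(axis=0) / (weight * n_flips)
        pi = weight / weight.sum()
        # Stop once the log-likelihood no longer improves by more than tol.
        ll = log_norm.sum()
        if ll - prev_ll < tol:
            break
        prev_ll = ll
    # resp comes from the last E-step; theta and pi from the last M-step.
    return theta, pi, resp, n_iter

# The five sequences shown in the table above (rows 15 to 19).
X = np.array([[1, 0, 0, 0, 0],
              [0, 0, 0, 1, 1],
              [1, 1, 1, 1, 0],
              [1, 1, 1, 1, 1],
              [1, 1, 1, 1, 0]])
theta, pi, resp, n_iter = em_coin_mixture(X)
p_est = resp @ theta                     # average over the latent coins
mode_p_est = theta[resp.argmax(axis=1)]  # most likely coin only
```

In this sketch `p_est` varies smoothly with the heads count of a row, whereas `mode_p_est` can only take as many distinct values as there are coins, which matches the pattern in the table above. Fitting five rows in isolation will not reproduce the numbers from the 20-row fit.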

The same approach works for any number of coins. Here it is with three coins and 7 flips per sequence:

In [4]:
sim = hiddencoins.Simulation(
    n_sequences=20,
    n_reps_per_sequence=7,
    p=(0.8, 0.1, 0.4)
)
data = sim.run()
hc = hiddencoins.EM(
    X=data.iloc[:, :-1].values,
    n_coins=3
)
hc.run()
data['p_est'] = hc.fitted_bernoulli_p()
data['mode_p_est'] = hc.fitted_bernoulli_p(mode=True)
data.true_p.corr(data.p_est)

0.9213005598390488

The iter attribute shows how many EM iterations the fit ran:

In [5]:
hc.iter

66