In [9]:
# EM算法入门：https://www.zhihu.com/question/40797593/answer/275171156（推导 https://zhuanlan.zhihu.com/p/36331115）
# <What is the expectation maximization algorithm?>

In [30]:
import numpy as np

In [None]:
# ==================================== Classic coin-flipping example
# One entry per experiment, stored as an (on, off) count pair; every pair
# here sums to 10 flips.
coin_sets = [(5, 5), (9, 1), (8, 2), (4, 6), (7, 3)]


In [75]:
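# 1. Hard EM or soft EM?
# - Hard EM: with two coins A and B, assign a coin set to A whenever A is the more likely source; once every set is assigned to A or B, re-estimate and iterate again.
# - Assume the principle of indifference, i.e. a uniform prior P(A) = P(B) over the whole space.
# 2. Is the order of flips within a set itself information? Yes: each coin_set is one specific sequence, not the collection of all sequences with e.g. 5 heads and 5 tails (so no binomial coefficient is applied).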


class EM(object):
    """EM estimator for a mixture of biased coins.

    Each observation is an ``(on, off)`` count pair assumed to come from one
    unknown coin. ``probs[j]`` holds the current estimate of coin ``j``'s
    ``on`` probability. Call :meth:`expectation` and then :meth:`maxmization`
    repeatedly to refine the estimates. Every coin is treated as equally
    likely a priori; no mixing weights are learned.
    """

    def __init__(self, init_probs):
        """Store the starting guesses.

        Args:
            init_probs: sequence with one initial ``on`` probability per coin.
        """
        self.init_probs = init_probs
        # Float copy: updates never alias the caller's sequence, and integer
        # guesses such as [1, 0] are not kept as an integer array.
        self.probs = np.array(init_probs, dtype=float)
        self.coin_nums = len(init_probs)

        # Not read by any method of this class; kept so existing attribute
        # access keeps working.
        self.max_iter = 100
        self.lower_bound = 0.0

    @classmethod
    def _norm_by_rowsum(cls, np_2darray):
        """Divide every row of a 2-D array by that row's sum."""
        return np_2darray / np.sum(np_2darray, axis=1, keepdims=True)

    @classmethod
    def _one_hot(cls, indexs, class_num):
        """Return one-hot rows (width ``class_num``) for the given indices."""
        return np.eye(class_num)[indexs, :]

    def expectation(self, coin_sets, mode="hard"):
        """E-step: responsibility of each coin for each coin set.

        Likelihoods are evaluated in log space and shifted by the row maximum
        before exponentiating, so long flip sequences do not underflow to an
        all-zero row (which would normalise to nan).

        Args:
            coin_sets: sequence of ``(on, off)`` count pairs.
            mode: ``"hard"`` returns a one-hot row selecting the most likely
                coin; any other value returns the soft (normalised)
                responsibilities.

        Returns:
            Array of shape ``(len(coin_sets), coin_nums)``; every row sums to 1.
        """
        counts = np.asarray(coin_sets, dtype=float).reshape(-1, 2)
        on = counts[:, [0]]   # column vectors so they broadcast against probs
        off = counts[:, [1]]
        probs = np.asarray(self.probs, dtype=float)

        with np.errstate(divide="ignore", invalid="ignore"):
            log_on = np.log(probs)         # -inf when a coin can never show `on`
            log_off = np.log(1.0 - probs)  # -inf when a coin can never show `off`
            # A zero count contributes nothing, even against log(0); this
            # mirrors pow(0, 0) == 1 in the direct formula.
            log_lik = (
                np.where(on == 0, 0.0, on * log_on)
                + np.where(off == 0, 0.0, off * log_off)
            )
            row_max = np.max(log_lik, axis=1, keepdims=True)
            # A row where every coin is impossible has max == -inf; shift by 0
            # there so the subtraction below stays well defined.
            row_max = np.where(np.isfinite(row_max), row_max, 0.0)
            weights = np.exp(log_lik - row_max)

        # No coin can explain such a row: fall back to equal responsibilities
        # instead of dividing 0 by 0.
        weights[weights.sum(axis=1) == 0] = 1.0

        # Normalise to the probability of having drawn each coin.
        coin_seq = self._norm_by_rowsum(weights)

        if mode == "hard":
            return self._one_hot(np.argmax(coin_seq, axis=1), self.coin_nums)
        else:
            return coin_seq

    def maxmization(self, coin_seq, coin_sets):
        """M-step: re-estimate ``self.probs`` from the responsibilities.

        Args:
            coin_seq: responsibilities as returned by :meth:`expectation`,
                shape ``(len(coin_sets), coin_nums)``.
            coin_sets: the same ``(on, off)`` count pairs given to the E-step.

        Returns:
            None. ``self.probs`` is replaced by the new estimates. A coin that
            received no responsibility mass keeps its previous estimate
            rather than becoming nan.
        """
        counts = np.asarray(coin_sets, dtype=float).reshape(-1, 2)
        # Expected (on, off) counts attributed to each coin: (coin_nums, 2).
        coin_exp_on_off = np.matmul(np.transpose(coin_seq), counts)
        totals = coin_exp_on_off.sum(axis=1)
        has_data = totals > 0

        new_probs = np.array(self.probs, dtype=float)
        new_probs[has_data] = coin_exp_on_off[has_data, 0] / totals[has_data]
        self.probs = new_probs
        return


In [78]:
# ----------------------------------------  Soft EM: estimates after 10 iterations

em_obj = EM([0.6, 0.5])  # starting guesses, one per coin
for i in range(10):
    # E-step (soft responsibilities), then M-step (re-estimate em_obj.probs).
    coin_seq = em_obj.expectation(coin_sets, "soft")
    em_obj.maxmization(coin_seq=coin_seq, coin_sets=coin_sets)

print(em_obj.probs)

[0.79674415 0.51965866]


In [None]:
#----------------------------------------  HARD EM: in this case the result after 10 iterations is identical to the result after 1 iteration

In [79]:
em_obj = EM([0.6, 0.5])  # starting guesses, one per coin
for i in range(10):
    # E-step (one-hot assignment to the likelier coin), then M-step.
    coin_seq = em_obj.expectation(coin_sets, "hard")
    em_obj.maxmization(coin_seq=coin_seq, coin_sets=coin_sets)

print(em_obj.probs)

[0.8  0.45]


In [None]:
# UnitTest: one soft-EM iteration from the initial guess, checked with
# assertions so a regression fails loudly instead of only changing the output.
em_obj = EM(init_probs = [0.6, 0.5])
print(em_obj.probs)

coin_seq = em_obj.expectation(coin_sets, mode="soft")
print(coin_seq)
# E-step invariants: one row per coin set, one column per coin, rows sum to 1.
assert coin_seq.shape == (len(coin_sets), 2)
np.testing.assert_allclose(coin_seq.sum(axis=1), 1.0)
# First set (5, 5): 0.24**5 / (0.24**5 + 0.5**10) ~= 0.449 for the first coin.
np.testing.assert_allclose(coin_seq[0], [0.449, 0.551], atol=1e-3)

em_obj.maxmization(coin_seq, coin_sets)
print(em_obj.probs)
# Estimates after one soft iteration on this data (hand-derived values).
np.testing.assert_allclose(em_obj.probs, [0.713, 0.581], atol=1e-3)



In [None]:
# Manual trace of a single soft-EM iteration (same steps as the UnitTest cell).
em_obj = EM([0.6, 0.5])
print(em_obj.probs)  # initial guesses

coin_seq = em_obj.expectation(coin_sets, "soft")  # E-step
print(coin_seq)

em_obj.maxmization(coin_seq=coin_seq, coin_sets=coin_sets)  # M-step
print(em_obj.probs)

In [None]:
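# Initial, more verbose version of the algorithm (kept for reference only; it is incomplete and does not run as written, e.g. `coin_prob` is never computed):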
#     def hard_expectation(self, coin_sets):
#         coin_seq = {i: 0 for i in self.coin_idxs}
#         coin_seq = []
#         coin_sets_num = len)(c)
#         coin_seq = np.zeros(coin_)
#         for _ in coin_sets:
#             on,off = _
#             tmp_coin = 0
#             max_prob = 0.0

#             for idx, prob in enumerate(self.init_probs):                
#                 if coin_prob < max_prob:
#                     continue

#                 tmp_coin = idx
#                 max_prob = coin_prob
                
#             coin_seq.append(tmp_coin)
#         return coin_seq

################################################################

#     def maxmization(self, coin_seq, coin_sets):        
#         coin_result = {i:[0,0] for i in self.coin_idxs}
        
#         for _, on_off in zip(coin_seq, coin_sets):
#             on,off = on_off
#             coin_result[_][0] += on
#             coin_result[_][1] += off
        
#         for key, val in coin_result.items():
#             prob = val[0] / (val[0] + val[1])
#             self.init_probs[key] = prob

#         return