<a href="https://colab.research.google.com/github/tomonari-masada/course2025-stats1/blob/main/EM_for_zero_inflated_poisson.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

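# EM algorithm for the zero-inflated Poisson model

Each observation is a structural zero with probability $\pi$; otherwise it is drawn from a Poisson distribution with mean $\mu$:

$$p(x \mid \pi, \mu) = \pi\,\mathbb{1}[x = 0] + (1 - \pi)\,\frac{\mu^{x} e^{-\mu}}{x!}.$$

The EM algorithm alternates between two steps. The E step computes, for each observation $x_i$, the posterior probability $q_i$ that it is a structural zero. The M step re-estimates $\pi$ and $\mu$ from these weights.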

# 準備

In [None]:
import numpy as np
from scipy.stats import poisson
from scipy.special import expit
import seaborn as sns

# Single seeded generator shared by every stochastic cell below, so that
# Restart & Run All reproduces the same data and initialization.
SEED = 42
rng = np.random.default_rng(seed=SEED)

# データの合成

In [None]:
n_samples = 1000

# Ground-truth parameters of the generating model.
prob_true = 0.7  # probability that an observation is a structural zero
mu_true = 5      # mean of the Poisson component

x = np.zeros(n_samples)
for idx in range(n_samples):
    # A uniform draw above prob_true (probability 1 - prob_true) selects the
    # Poisson component; otherwise the observation is a structural zero.
    from_poisson = rng.uniform() > prob_true
    x[idx] = poisson.rvs(mu=mu_true, random_state=rng) if from_poisson else 0

In [None]:
# Frequency of each count value. The bar at 0 mixes structural zeros with
# zeros drawn from the Poisson component.
y = np.bincount(x.astype(np.int64))
ax = sns.barplot(x=np.arange(len(y)), y=y)
ax.set(xlabel="count value", ylabel="frequency", title="Synthetic zero-inflated Poisson sample");

# M step

In [None]:
def M_step(x, q):
    """M step: closed-form parameter updates given the current weights.

    Parameters
    ----------
    x : numpy.ndarray
        Observed counts.
    q : numpy.ndarray
        Per-sample weight of the structural-zero component (e.g. the output
        of ``E_step``), with the same length as ``x``.

    Returns
    -------
    tuple
        ``(prob_guessed, mu_guessed)``: the average of ``q``, and the
        average of ``x`` weighted by ``1 - q``.
    """
    poisson_weight = 1 - q  # weight each sample gives the Poisson component
    n_obs = len(q)
    prob_guessed = q.sum() / n_obs
    mu_guessed = (poisson_weight * x).sum() / poisson_weight.sum()
    return prob_guessed, mu_guessed

# E step

In [None]:
def E_step(x, prob_guessed, mu_guessed):
    """E step: posterior probability that each observation is a structural zero.

    Parameters
    ----------
    x : array_like
        Observed counts.
    prob_guessed : float
        Current estimate of the structural-zero probability.
    mu_guessed : float
        Current estimate of the Poisson mean.

    Returns
    -------
    numpy.ndarray
        Responsibilities ``q`` with the same shape as ``x``. ``q[i]`` is
        exactly 0 whenever ``x[i] != 0``, because only zeros can be structural.
    """
    x = np.asarray(x)
    numer = prob_guessed * (x == 0.0)
    denom = numer + (1 - prob_guessed) * poisson.pmf(x, mu_guessed)
    # Where the numerator is 0 the responsibility is exactly 0, so skip the
    # division there. Otherwise a Poisson pmf that underflows to 0 for a large
    # count (or prob_guessed == 1) yields 0/0 = NaN, which would poison every
    # later M step. A NaN numerator is still divided, so NaN parameters
    # propagate as before.
    q = np.zeros(np.shape(denom))
    np.divide(numer, denom, out=q, where=(numer != 0))
    return q

# 初期化

In [None]:
# Random initial weights: squash standard-normal draws through the logistic
# sigmoid so every initial q lies in (0, 1).
initial_logits = rng.normal(size=n_samples)
q = expit(initial_logits)
q[:10]

# EMアルゴリズム

In [None]:
N_EM_ITERATIONS = 10

for iteration in range(N_EM_ITERATIONS):
    # M step: re-estimate the parameters from the current weights q.
    prob_guessed, mu_guessed = M_step(x, q)
    # E step: recompute q under the freshly estimated parameters.
    q = E_step(x, prob_guessed, mu_guessed)
    print(f"{prob_guessed:.3f}, {mu_guessed:.2f}")