<a href="https://colab.research.google.com/github/tomonari-masada/course2025-stats2/blob/main/01_introduction_PyMC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Applying PyMC to the Example from Lecture 1

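### Example: Did the number of messages change? (from the reference book)
* **Reference book**: Cameron Davidson-Pilon (author), Toru Tamaki (translator), 『Pythonで体験するベイズ推論:PyMCによるMCMC入門』 (Japanese edition of *Bayesian Methods for Hackers*), Morikita Publishing (2017)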
 * https://www.amazon.co.jp/dp/4627077912

In [None]:
!wget "https://raw.githubusercontent.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers/master/Chapter1_Introduction/data/txtdata.csv"

In [None]:
import numpy as np
import matplotlib.pyplot as plt

PATH = 'txtdata.csv'

plt.figure(figsize=(12.5, 4))

count_data = np.loadtxt(PATH)
n_count_data = len(count_data)
plt.bar(np.arange(n_count_data), count_data, color="#348ABD")
plt.xlabel("Time (days)")
plt.ylabel("count of text-msgs received")
plt.title("Did the user's texting habits change over time?")
plt.xlim(0, n_count_data);

$$ X_t \sim \text{Poi}(\lambda_1) \; \; \text{ if $t < \tau$ } $$
$$ X_t \sim \text{Poi}(\lambda_2) \; \; \text{ if $t \geq \tau$ } $$

$$ \lambda_1 \sim \text{Exp}(\alpha) $$
$$ \lambda_2 \sim \text{Exp}(\alpha) $$


$$ P(\tau = k) = \frac{1}{N} \;\; \text{ for $k = 1,\ldots, N$ } $$
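
The generative model above can be read as a recipe for producing data: draw the two rates, draw the change point, then draw one Poisson count per day. The following is a minimal NumPy sketch of that forward simulation, not part of the original notebook. The function name `simulate_counts`, the seed, and the values `n_days=60` and `alpha=0.05` are assumptions chosen only for illustration.

```python
import numpy as np

def simulate_counts(n_days, alpha, rng):
    """Draw one synthetic dataset from the change-point model (illustrative sketch)."""
    # Rates before and after the change point; Exp(alpha) has mean 1 / alpha.
    lambda_1 = rng.exponential(scale=1.0 / alpha)
    lambda_2 = rng.exponential(scale=1.0 / alpha)
    # Change point drawn uniformly from {1, ..., N}.
    tau = rng.integers(1, n_days + 1)
    # Days t = 1..N: rate lambda_1 when t < tau, lambda_2 otherwise.
    t = np.arange(1, n_days + 1)
    rates = np.where(t < tau, lambda_1, lambda_2)
    counts = rng.poisson(rates)
    return counts, tau, lambda_1, lambda_2

rng_sim = np.random.default_rng(0)  # hypothetical seed for this sketch
sim_counts, sim_tau, sim_lambda_1, sim_lambda_2 = simulate_counts(
    n_days=60, alpha=0.05, rng=rng_sim
)
```

Plotting `sim_counts` with the same bar chart as the real data gives a rough feel for what kind of series this model can produce.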

In [None]:
import numpy as np
import pymc as pm
import arviz as az

RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

In [None]:
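with pm.Model() as model:
  n_data = len(count_data)

  # Rate of the exponential priors: the reciprocal of the mean message count,
  # so that each prior mean (1 / alpha) equals the sample mean (see the reference book).
  alpha = 1.0 / count_data.mean()
  lambda_1 = pm.Exponential("lambda_1", lam=alpha)
  lambda_2 = pm.Exponential("lambda_2", lam=alpha)

  # Change point. Note: this is a continuous uniform prior, whereas the formula above is discrete.
  tau = pm.Uniform("tau", lower=0, upper=n_data)

  # Days before tau use lambda_1; days from tau onward use lambda_2.
  idx = np.arange(n_data)
  lambda_ = pm.Deterministic("lambda_", pm.math.where(tau > idx, lambda_1, lambda_2))
  obs = pm.Poisson("obs", mu=lambda_, observed=count_data)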
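
The `pm.math.where(tau > idx, lambda_1, lambda_2)` expression in the model builds one rate per day by switching between the two rates at the change point. The sketch below shows the analogous elementwise behavior with `np.where` on plain numbers. The names ending in `_example` and their values are hypothetical and are not taken from the data.

```python
import numpy as np

# Hypothetical values chosen only for illustration.
tau_example = 3.5
lambda_1_example, lambda_2_example = 10.0, 20.0

idx_example = np.arange(6)  # day indices 0..5
# Days with index below tau get lambda_1, the remaining days get lambda_2.
rates_example = np.where(tau_example > idx_example, lambda_1_example, lambda_2_example)
print(rates_example)  # [10. 10. 10. 10. 20. 20.]
```

In the model the same switch is applied to random variables rather than fixed numbers, so every posterior draw of `tau`, `lambda_1`, and `lambda_2` yields its own vector of daily rates.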

In [None]:
pm.model_to_graphviz(model)

In [None]:
with model:
  idata = pm.sample(draws=1000, tune=1000, chains=2, cores=2, random_seed=rng)

In [None]:
az.plot_trace(idata);

In [None]:
var_names = ["lambda_1", "lambda_2", "tau"]
az.plot_trace(idata, var_names=var_names);

In [None]:
az.summary(idata, var_names=var_names)