## 【一般化線形混合モデル】果物の収穫量

果物の収穫量のデータに対してポアソン回帰（ランダム効果あり）のモデルを適用してみる。

In [None]:
import pymc as pm
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Use a 12pt base font for every figure drawn from here on.
plt.rcParams.update({'font.size': 12})

## Load & Check Data

In [None]:
# Read the dataset from 'data.csv' (resolved relative to the working
# directory) into a DataFrame; later cells use its 'num_fruits' column.
data = pd.read_csv('data.csv')

In [None]:
# Bare expression: shows the loaded DataFrame as the cell output.
data

In [None]:
# Histogram + KDE of the fruit counts with coarse (width 50) bins.
# The axis label is set through the returned FacetGrid.
sns.displot(
    data=data,
    x='num_fruits',
    kde=True,
    binwidth=50,
).set(xlabel='Number of Fruits');

In [None]:
# Same distribution plot with finer (width 25) bins.
# The axis label is set through the returned FacetGrid.
sns.displot(
    data=data,
    x='num_fruits',
    kde=True,
    binwidth=25,
).set(xlabel='Number of Fruits');

In [None]:
# Print the sample mean and variance of the counts (two decimals each)
# so they can be compared side by side.
fruit_counts = data['num_fruits']
print(f'平均：{fruit_counts.mean():.2f}')
print(f'分散：{fruit_counts.var():.2f}')

## Define Model & Inference

In [None]:
# Response vector as a plain NumPy array, and the number of observations
# (one random effect is allocated per observation in the model).
y = data['num_fruits'].to_numpy()
num_trees = y.shape[0]

In [None]:
# Poisson regression with a log link: a shared intercept plus one
# normally distributed random effect per observation.
with pm.Model() as model:

    # Shared intercept on the log scale.
    b = pm.Normal('b', mu=0, sigma=10)

    # Standard deviation of the random effects.
    s = pm.HalfCauchy('s', beta=5)
    # Random effects: a vector of length num_trees centred on zero.
    r = pm.Normal('r', mu=0, sigma=s, shape=num_trees)

    # Linear predictor; exp() maps it to a positive Poisson rate.
    theta = b + r
    mu = pm.math.exp(theta)

    # Likelihood of the observed counts.
    obs = pm.Poisson('obs', mu=mu, observed=y)

In [None]:
with model:

    # return_inferencedata=False keeps the sampler's own trace object, which
    # a later cell indexes by variable name (trace['r']).
    trace = pm.sample(
        draws=3000,
        tune=6000,
        target_accept=0.9,
        return_inferencedata=False,
    )
    # InferenceData copy of the same draws, used by the ArviZ calls.
    idata = pm.to_inference_data(trace)

## Check MCMC-samples

In [None]:
# Trace plots for the sampled variables; widen the gaps between panels
# of the current figure so they do not crowd each other.
az.plot_trace(idata)
plt.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Summary table of the posterior draws held in idata, shown as cell output.
az.summary(idata)

## Posterior Predictive Check

In [None]:
with model:

    # Draw posterior predictive samples from the fitted model.
    # return_inferencedata=False is requested so the result can be indexed
    # directly by variable name (ppc['obs']) in the cells that follow.
    ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False)

In [None]:
# Inspect the dimensions of the predictive draws for 'obs'; the plotting
# cell indexes this array with three axes, the last one being the observation.
ppc['obs'].shape

In [None]:
# Posterior predictive check, one panel per observation: the predictive
# distribution of the count with the observed value as a dashed red line.
#
# The 10 x 2 grid holds at most 20 panels. The loop is bounded by the number
# of observations (last axis of ppc['obs']) so that a dataset with fewer
# than 20 rows does not raise IndexError.
max_panels = 10 * 2
n_panels = min(max_panels, ppc['obs'].shape[-1])

fig = plt.figure(figsize=(12, 24))

for k in range(n_panels):

    ax = fig.add_subplot(10, 2, k+1)

    # Pass the axes explicitly instead of relying on the implicit
    # "current axes" state of pyplot.
    az.plot_dist(ppc['obs'][:, :, k], ax=ax)
    ax.axvline(y[k], color='r', linestyle='dashed')
    ax.set_title('ID = {}'.format(k))

plt.tight_layout()

## Check Random Effects

In [None]:
# Violin plots of the posterior of the random effects 'r', laid out in a
# single row with num_trees columns. The trailing semicolon suppresses the
# textual cell output.
az.plot_violin(idata.posterior['r'], grid=(1, num_trees), figsize=(12, 4));

## Check New Feature

In [None]:
# Load the second dataset from 'data_updated.csv' and preview its first rows;
# the next cell reads its 'span' column.
data_updated = pd.read_csv('data_updated.csv')
data_updated.head()

In [None]:
# Column 'span' of the updated dataset, compared against the random effects
# in the scatter plot at the end.
# NOTE(review): assumes the rows are in the same order as data.csv - confirm.
span = data_updated['span']

In [None]:
# Average the draws of 'r' over the first (sample) axis, leaving one
# posterior-mean random effect per remaining index.
r_mean = np.mean(trace['r'], axis=0)

In [None]:
# Scatter of the new 'span' feature against the posterior-mean random
# effects, drawn through the object-oriented Axes API.
fig, ax = plt.subplots(figsize=(6, 6))

sns.scatterplot(x=span, y=r_mean, s=100, ax=ax)

ax.set_xlabel('Span')
ax.set_ylabel('Random Effects');