## 【一般化線形混合モデル】果物の収穫量

新たに取得した特徴量で、木々の個体差の説明ができるかをモデルにより検証してみる。

In [None]:
import pymc as pm
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Use a 12 pt default font for every matplotlib figure drawn below.
plt.rcParams['font.size'] = 12

## Load Data

In [None]:
# Load the dataset from the working directory; the columns 'span' and
# 'num_fruits' are read from it in the preprocessing step.
data = pd.read_csv('data_updated.csv')
# Bare expression: the notebook renders the DataFrame as the cell output.
data

## Preprocess & Scale Data

In [None]:
# Pull the covariate ('span') and the response ('num_fruits') out of the
# DataFrame as plain arrays.
x, y = data['span'].values, data['num_fruits'].values

# One observation per tree, so the response length is the number of trees.
num_trees = y.shape[0]

In [None]:
# Standardise the covariate to zero mean and unit (population) standard
# deviation; the model below is fitted on x_scaled rather than raw x.
x_mu, x_sd = x.mean(), x.std()

x_scaled = (x - x_mu) / x_sd

## Define Model & Inference

In [None]:
# Poisson regression with a per-tree random effect:
#   log(mu_i) = a * x_scaled_i + b + r_i,   y_i ~ Poisson(mu_i)
with pm.Model() as model:

    # Fixed effects: `a` multiplies the scaled covariate (slope) and `b` is
    # the additive constant (intercept); both get Normal(0, 10) priors.
    a = pm.Normal('a', mu=0, sigma=10)
    b = pm.Normal('b', mu=0, sigma=10)

    # Random effects: one zero-mean Normal offset per tree, all sharing the
    # standard deviation `s`, which itself has a HalfCauchy prior.
    s = pm.HalfCauchy('s', 5)
    r = pm.Normal('r', mu=0, sigma=s, shape=num_trees)

    # Linear predictor on the log scale.
    theta = a * x_scaled + b + r

    # Exponentiate so the Poisson mean is always positive (log link).
    mu = pm.math.exp(theta)

    # Likelihood of the observed fruit counts.
    obs = pm.Poisson('obs', mu=mu, observed=y)

In [None]:
# Draw posterior samples: 3000 draws after 6000 tuning iterations, with a
# high target acceptance rate (0.99).
with model:

    # return_inferencedata=False keeps the raw trace object, which is indexed
    # by variable name later on (trace['r']).
    trace = pm.sample(3000, tune=6000, target_accept=0.99, return_inferencedata=False)
    # Separate InferenceData copy for the ArviZ plotting/summary calls.
    idata = pm.to_inference_data(trace)

## Check MCMC-samples

In [None]:
# Trace plots for every sampled variable; widen the gaps between the
# sub-panels so titles and tick labels do not overlap.
az.plot_trace(idata)
plt.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Tabular summary of the posterior samples, rendered as the cell output.
az.summary(idata)

## Posterior Predictive Check

In [None]:
# Simulate replicated observations from the fitted model.
with model:

    # return_inferencedata=False: the result is indexed by variable name
    # (ppc['obs']) in the cells that follow.
    ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False)

In [None]:
# Sanity-check the dimensions of the simulated counts. The plotting code
# indexes this array as [:, :, k], i.e. it expects three axes with the tree
# index last.
ppc['obs'].shape

In [None]:
# Posterior predictive check: one panel per tree showing the distribution of
# simulated counts, with the observed count marked by a dashed red line.
#
# The grid is derived from num_trees instead of a hard-coded 20 panels, so
# the cell neither raises IndexError for smaller datasets nor silently drops
# trees for larger ones. With num_trees == 20 the layout is unchanged
# (10 x 2 grid on a 12 x 24 inch figure).
n_cols = 2
n_rows = (num_trees + n_cols - 1) // n_cols  # ceiling division

fig = plt.figure(figsize=(12, 2.4 * n_rows))

for k in range(num_trees):

    ax = fig.add_subplot(n_rows, n_cols, k + 1)

    # Draw on the panel explicitly rather than relying on the current axes.
    az.plot_dist(ppc['obs'][:, :, k], ax=ax)
    ax.axvline(y[k], color='r', linestyle='dashed')
    ax.set_title(f'ID = {k}')

plt.tight_layout()

## Check Random Effect

In [None]:
# Violin plots of the posterior of the random effect `r`, one panel per tree
# laid out in a single row. The trailing semicolon suppresses the cell's
# text output.
az.plot_violin(idata.posterior['r'], grid=(1, num_trees), figsize=(12, 4));

## Compare with True values

In [None]:
# Load the reference data used for comparison with the estimates
# (presumably the ground truth behind the dataset; verify against its source).
data_true = pd.read_csv('data_true.csv')
# Show the first 10 rows as the cell output.
data_true.head(10)

In [None]:
# Reference random-effect values, kept as a pandas Series.
r_true = data_true['random_effects']

In [None]:
# Point estimate of the random effects: average over axis 0 of the sampled
# values (assumes axis 0 of trace['r'] is the draw axis, leaving one value
# per tree; TODO confirm).
r_mean = trace['r'].mean(axis=0)

In [None]:
# Overlay the reference and the estimated random effects, tree by tree.
fig = plt.figure(figsize=(10, 4))

# Reference values are plotted against explicit tree IDs; the estimates use
# their implicit positional index.
plt.plot(np.arange(num_trees), r_true, marker='o', linestyle='-', markersize=8, label='True Values')
plt.plot(r_mean, marker='o', linestyle='-', markersize=8, label='Estimated Values')

plt.xlabel('ID')
plt.ylabel('Random Effects')
# One tick per tree ID.
plt.xticks(np.arange(num_trees))
plt.legend()

plt.tight_layout()

In [None]:
# Scatter plot of estimated vs. reference random effects (one point per tree).
fig = plt.figure(figsize=(6, 6))

# Reference values go on x and estimates on y so that the data match the
# axis labels set below (previously the two were swapped).
sns.scatterplot(x=r_true, y=r_mean, s=100)

plt.xlabel('True Value')
plt.ylabel('Estimated Value');