## 【ロジスティック回帰】カブトムシの死亡率

カブトムシの死亡数のデータを使って、ロジスティック回帰を行ってみる。

In [None]:
import pymc as pm
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Notebook-wide matplotlib defaults: base font size and default figure size.
plt.rcParams.update({
    'font.size': 12,
    'figure.figsize': [8, 6],
})

## Load & Check Data

In [None]:
# Load the beetle data set from the working directory. Later cells read the
# 'gas', 'deaths' and 'beetles' columns from this DataFrame.
data = pd.read_csv('beetle.csv')

In [None]:
# Display the loaded DataFrame as the cell output for a quick visual check.
data

In [None]:
# Observed death proportion (deaths / beetles exposed) at each gas level.
death_rate = data['deaths'] / data['beetles']

ax = sns.scatterplot(x=data['gas'], y=death_rate, s=100)

ax.set_xlabel('Gas Concentration')
ax.set_ylabel('Death Rate');

## Preprocess & Scale Data

In [None]:
# Pull the three modelling columns out of the DataFrame as NumPy arrays.
gas, deaths, beetles = (
    data[column].to_numpy() for column in ('gas', 'deaths', 'beetles')
)

In [None]:
# Standardize the predictor: subtract its mean and divide by its standard
# deviation (NumPy's default, i.e. ddof=0).
gas_mu, gas_sd = gas.mean(), gas.std()

gas_scaled = (gas - gas_mu) / gas_sd

## Define Model & Inference

In [None]:
with pm.Model() as model:

    # Both regression coefficients get the same Normal(0, 10) prior:
    # `a` multiplies the standardized gas concentration, `b` is the offset.
    a = pm.Normal('a', sigma=10, mu=0)
    b = pm.Normal('b', sigma=10, mu=0)

    # Linear predictor on the logit scale, then mapped to a probability.
    mu = b + a * gas_scaled
    theta = pm.math.invlogit(mu)

    # Likelihood: deaths out of `beetles` trials with success probability theta.
    obs = pm.Binomial('obs', n=beetles, p=theta, observed=deaths)

In [None]:
with model:

    # Keep the raw sampler output as well: plot_curves() indexes it by
    # variable name, while the ArviZ calls below use the converted `idata`.
    trace = pm.sample(draws=3000, return_inferencedata=False)
    idata = pm.to_inference_data(trace=trace)

## Check MCMC Samples

In [None]:
# Trace plots for every sampled variable; widen the gaps between the panels
# of the figure that ArviZ just drew.
az.plot_trace(idata)

trace_fig = plt.gcf()
trace_fig.subplots_adjust(hspace=0.5, wspace=0.5)

In [None]:
# Tabulate summary statistics for the variables stored in `idata`.
az.summary(idata)

In [None]:
# Plot the posterior of each variable in `idata`; the trailing semicolon
# suppresses the notebook's textual echo of the returned object.
az.plot_posterior(idata);

In [None]:
def plot_curves(trace, samples_to_plot=50, ax=None):
    """Overlay logistic curves computed from sampled coefficients.

    Each curve is ``1 / (1 + exp(-(a * x + b)))`` evaluated on 50 evenly
    spaced points of the standardized predictor between -2 and 2, using one
    sampled ``(a, b)`` pair. The most recent draws are used, newest first.

    Parameters
    ----------
    trace : mapping-like
        Object for which ``trace['a']`` and ``trace['b']`` return sequences
        of sampled values (assumed 1-D and of equal length).
    samples_to_plot : int, optional
        Maximum number of curves to draw (default 50). When fewer draws are
        available, every available draw is plotted instead of raising.
    ax : object with a matplotlib-style ``plot`` method, optional
        Axes to draw on. Defaults to the current pyplot axes, which is what
        a bare ``plt.plot`` call would use.
    """
    if ax is None:
        ax = plt.gca()

    x_scaled_new = np.linspace(-2, 2, 50)

    a_draws = np.asarray(trace['a'])
    b_draws = np.asarray(trace['b'])

    # Clamp to what is available so short traces do not raise IndexError.
    n_curves = max(0, min(samples_to_plot, len(a_draws), len(b_draws)))

    # Run k up to and including n_curves: stopping at samples_to_plot - 1
    # would silently draw one curve fewer than requested.
    for k in range(1, n_curves + 1):

        mu = a_draws[-k] * x_scaled_new + b_draws[-k]

        # Inverse logit maps the linear predictor to a probability.
        p = 1 / (1 + np.exp(-mu))

        ax.plot(x_scaled_new, p, c='g', alpha=0.2)


In [None]:
# Display the raw object returned by pm.sample(..., return_inferencedata=False).
trace

In [None]:
# Observed death proportions with logistic curves from the sampled trace
# drawn on top.
observed_rate = deaths / beetles

plt.scatter(gas_scaled, observed_rate)
plot_curves(trace)

plt.gca().set(xlabel='Gas Concentration (Standardized)', ylabel='Death Rate');

## Prior Predictive Check

In [None]:
# Draw 50 samples from the priors alone. The plain (non-InferenceData) return
# value is requested because plot_curves() indexes it with ['a'] and ['b'].
prior_samples = pm.sample_prior_predictive(
    samples=50,
    model=model,
    return_inferencedata=False,
)

In [None]:
# Fresh figure: observed death proportions with curves implied by the prior
# samples overlaid.
ax = plt.subplots()[1]

plt.scatter(x=gas_scaled, y=deaths / beetles)

plot_curves(prior_samples)

ax.set_xlabel('Gas Concentration (Standardized)')
ax.set_ylabel('Death Rate');

## Posterior Predictive Check

In [None]:
# Simulate replicated 'obs' values from the fitted model; a plain dict-like
# result is requested so it can be indexed as ppc['obs'].
ppc = pm.sample_posterior_predictive(
    idata,
    model=model,
    return_inferencedata=False,
)

In [None]:
# Inspect the dimensions of the posterior-predictive draws of 'obs' before
# indexing into them.
ppc['obs'].shape

In [None]:
# Posterior predictive check: one panel per observation, showing the
# distribution of simulated death counts against the observed count
# (red dashed line).
obs_draws = np.asarray(ppc['obs'])

# Pool the draws from every chain rather than using chain 0 only. This
# assumes the last axis indexes observations (as the [0, :, k] indexing this
# replaces did) and flattens all leading sample axes together.
pooled_draws = obs_draws.reshape(-1, obs_draws.shape[-1])

# Derive the grid from the data instead of hard-coding 8 panels in 4 rows.
n_obs = len(deaths)
n_cols = 2
n_rows = (n_obs + n_cols - 1) // n_cols

# 3 inches of height per row reproduces the 12x12 figure for 8 observations.
fig = plt.figure(figsize=(12, 3 * n_rows))

for k in range(n_obs):

    ax = fig.add_subplot(n_rows, n_cols, k + 1)

    # Pass the axes explicitly rather than relying on the implicit current axes.
    az.plot_dist(pooled_draws[:, k], ax=ax)
    ax.axvline(deaths[k], color='r', linestyle='dashed')
    ax.set_title(f'Gas Concentration = {gas[k]:.2f}')

plt.tight_layout()