## 【線形回帰】新生児の体重

新生児の性別を無視して、妊娠期間と体重の関係を線形回帰モデルにあてはめてみる。

In [None]:
import pymc as pm
import arviz as az

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Global matplotlib defaults used by every figure below.
plt.rcParams.update({
    'font.size': 12,
    'figure.figsize': [8, 6],
})

## Load & Check Data

In [None]:
# Read the dataset from the working directory into a DataFrame.
# Later cells use its 'weeks', 'weight' and 'gender' columns.
data = pd.read_csv('babies_weight.csv')

In [None]:
# Show the first 10 rows as a quick sanity check of the loaded table.
data.head(10)

In [None]:
# Scatter plot of the raw data: x = 'weeks', y = 'weight', coloured by 'gender'.
sns.scatterplot(
    data=data,
    x='weeks',
    y='weight',
    hue='gender',
    s=100,
)

# Label both axes of the current axes in a single call.
plt.gca().set(xlabel='Period [weeks]', ylabel='Weight [g]');

## Preprocess & Scale Data

In [None]:
def standardize(x):
    """Return ``x`` shifted to zero mean and divided by its standard deviation.

    ``x`` must provide ``mean()`` and ``std()`` methods and support
    element-wise subtraction and division (e.g. a NumPy array). The standard
    deviation is whatever ``x.std()`` returns with its default arguments.
    """
    centered = x - x.mean()
    spread = x.std()

    return centered / spread

In [None]:
# Pull the predictor ('weeks') and the response ('weight') out of the
# DataFrame as plain arrays.
x, y = data['weeks'].values, data['weight'].values

# Kept separately; used only to colour points in the plots.
gender = data['gender'].values

In [None]:
# Standardize predictor and response independently with the helper above.
x_scaled, y_scaled = standardize(x), standardize(y)

## Define Model & Inference

In [None]:
# Simple linear regression on the standardized data:
#     obs ~ Normal(a * x_scaled + b, sd)
with pm.Model() as model:

    # Normal(0, 10) priors on the slope `a` and the intercept `b`.
    a = pm.Normal('a', mu=0, sigma=10)
    b = pm.Normal('b', mu=0, sigma=10)

    # Half-Cauchy prior (beta = 5) on the observation noise scale.
    sd = pm.HalfCauchy('sd', beta=5)

    # Linear predictor for every observation.
    mu = a * x_scaled + b

    # Likelihood tying the predictor to the observed standardized weights.
    obs = pm.Normal('obs', mu=mu, sigma=sd, observed=y_scaled)

In [None]:
with model:

    # Request 3000 draws and keep the raw trace object (not InferenceData),
    # because plot_lines() indexes it as trace['a'] / trace['b'].
    trace = pm.sample(draws=3000, return_inferencedata=False)

    # ArviZ-compatible copy of the same samples for diagnostics and summaries.
    idata = pm.to_inference_data(trace)

## Check MCMC-samples

In [None]:
# Trace plots for every sampled variable.
az.plot_trace(idata)

# Widen the gaps between subplots of the current figure.
plt.subplots_adjust(wspace=0.5, hspace=0.5)

In [None]:
# Tabular summary of the samples stored in `idata`.
az.summary(idata)

## Visualize Parameters

In [None]:
def plot_lines(trace, samples_to_plot=50):
    """Overlay regression lines built from the last samples in ``trace``.

    For each of the last ``samples_to_plot`` entries of ``trace['a']`` and
    ``trace['b']``, the line ``a * x + b`` is evaluated on 50 evenly spaced
    points in ``[-3, 3]`` and drawn on the current matplotlib axes as a
    faint green line.

    Parameters
    ----------
    trace
        Object indexable by ``'a'`` and ``'b'``; each lookup must return a
        sequence of sampled values that supports ``len()`` and negative
        indexing.
    samples_to_plot : int, optional
        Number of trailing samples to draw (default 50). If fewer samples
        are available, all of them are drawn.
    """
    x_scaled_new = np.linspace(-3, 3, 50)

    # Look the samples up once rather than on every loop iteration.
    a_samples = trace['a']
    b_samples = trace['b']

    # Never index past the start of the sample arrays.
    n_lines = min(samples_to_plot, len(a_samples), len(b_samples))

    # Inclusive upper bound so that exactly `n_lines` lines are drawn
    # (range(1, n) would stop one short).
    for k in range(1, n_lines + 1):

        mu = a_samples[-k] * x_scaled_new + b_samples[-k]

        plt.plot(x_scaled_new, mu, c='g', alpha=0.1)


In [None]:
# Regression lines from posterior samples first, then the standardized
# observations on the same axes, coloured by gender.
plot_lines(trace)

sns.scatterplot(
    x=x_scaled,
    y=y_scaled,
    hue=data['gender'],
    s=80,
)

plt.gca().set(xlabel='Period (Standardized)', ylabel='Weight (Standardized)');

## Posterior Predictive Check

In [None]:
# Simulate observations from the fitted model; the result is kept as a plain
# dict-like object so it can be indexed as ppc['obs'] below.
ppc = pm.sample_posterior_predictive(
    idata,
    model=model,
    return_inferencedata=False,
)

In [None]:
# Inspect the array shape of the simulated 'obs' values.
ppc['obs'].shape

In [None]:
# HDI band of the simulated observations over the standardized predictor,
# with the actual standardized data points drawn on the same axes.
az.plot_hdi(x_scaled, ppc['obs'])

sns.scatterplot(
    x=x_scaled,
    y=y_scaled,
    hue=gender,
    s=80,
)

plt.gca().set(xlabel='Period (Standardized)', ylabel='Weight (Standardized)');