In [35]:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Model

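For units $i = 1, \dots, n$, each observed $m$ times ($j = 1, \dots, m$):

$$\theta_i \sim N(0,1)$$
$$x_i = \theta_i + u_i, \quad u_i \sim N(0,1)$$
$$z_{ij} = \theta_i + v_{ij}, \quad v_{ij} \sim N(0,1)$$
$$y_{ij} = \alpha\theta_i + \delta z_{ij} + \varepsilon_{ij}, \quad \varepsilon_{ij} \sim N(0, 0.5^2)$$

$\theta_i$ is not observed; the estimators below use its posterior mean $E[\theta_i \mid x_i] = x_i / 2$ in its place.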


In [36]:
# True coefficients of the simulated outcome: y_ij = ALPHA * theta_i + DELTA * z_ij + noise.
ALPHA = 2  # coefficient on the latent theta_i (the target of estimation)
DELTA = 2  # coefficient on z_ij

# Fixed seed so that Restart & Run All reproduces exactly the same simulations.
SEED = 32742895
RNG = np.random.default_rng(SEED)

In [37]:
def ols(X,y):
    """Least-squares coefficients of y on X, with no intercept added.

    Parameters
    ----------
    X : numpy.ndarray
        Either a 1-D array (a single regressor) or a 2-D design matrix with
        one column per regressor.
    y : numpy.ndarray
        Response, with as many rows as X.

    Returns
    -------
    float or numpy.ndarray
        A scalar slope when X is 1-D, otherwise one coefficient per column
        of X.

    Raises
    ------
    numpy.linalg.LinAlgError
        If X is 2-D and X.T @ X is singular.
    """
    if X.ndim == 1:
        # Single regressor through the origin: beta = <x, y> / <x, x>.
        return (X @ y) / (X @ X)
    # Solve the normal equations (X'X) beta = X'y directly rather than
    # forming an explicit inverse: same estimate, better numerical stability.
    return np.linalg.solve(X.T @ X, X.T @ y)

def demean(y, n, m):
    """Remove each unit's mean from its observations.

    y is a flat array of length n*m in unit-major order (the m observations
    of unit 0 first, then unit 1, ...). Returns a flat array of the same
    length where each unit's m values have had their own mean subtracted.
    """
    by_unit = y.reshape(n, m)              # row k holds the m values of unit k
    unit_means = by_unit.mean(axis=1)[:, np.newaxis]
    centered = by_unit - unit_means
    return centered.reshape(-1)

In [38]:
class DataGenerator:
    """Simulate one panel from the model and estimate ALPHA in two ways.

    Units are indexed by i (n of them) and observations within a unit by j
    (m per unit). All arrays are kept flat with length n*m in unit-major
    order (all m rows of unit 0, then unit 1, ...), so unit-level draws are
    repeated m times.

    Parameters
    ----------
    n : int
        Number of units (distinct theta_i).
    m : int
        Number of observations per unit.
    alpha, delta : float, optional
        True coefficients on theta and z in y. Default to the module-level
        ALPHA and DELTA.
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to the module-level RNG, which is
        shared, so successive instances draw different panels.
    """

    def __init__(self, n, m, alpha=None, delta=None, rng=None):
        alpha = ALPHA if alpha is None else alpha
        delta = DELTA if delta is None else delta
        rng = RNG if rng is None else rng

        # Keep everything one-dimensional: each unit's theta_i is repeated
        # over its m rows.
        theta = np.repeat(rng.normal(size=n), m)
        # x_i is a unit-level proxy (single index in the model), so its noise
        # is drawn once per unit and repeated, not redrawn for every row.
        x = theta + np.repeat(rng.normal(size=n), m)
        # z_ij and the outcome noise vary per observation.
        z = theta + rng.normal(size=n*m)
        # scale is the standard deviation: outcome noise variance is 0.25.
        y = alpha * theta + delta * z + rng.normal(size=n*m, scale=0.5)

        self.theta = theta
        self.x = x
        self.z = z
        self.y = y
        # theta ~ N(0, 1) and x | theta ~ N(theta, 1) give E[theta | x] = x / 2.
        self.theta_post_mean = x / 2

        self.n = n
        self.m = m

    def direct_ols(self):
        """Joint OLS of y on [E[theta|x], z] (no intercept).

        Returns the coefficient on E[theta|x]. The part of theta not captured
        by E[theta|x] stays in the error term and is correlated with z, so
        this estimate is generally not centred on alpha.
        """
        X = np.stack([self.theta_post_mean, self.z], axis=-1)
        return ols(X, self.y)[0]

    def demean_ols(self):
        """Two-step estimate of alpha.

        1. Estimate delta by regressing within-unit demeaned y on demeaned z;
           theta_i is constant within a unit, so demeaning removes it.
        2. Regress y - delta_hat * z on E[theta|x] to obtain alpha_hat.
        """
        y_demean_flat = demean(self.y, self.n, self.m)
        z_demean_flat = demean(self.z, self.n, self.m)

        delta_hat = ols(z_demean_flat, y_demean_flat)

        y_minus_z = self.y - delta_hat * self.z
        alpha_hat = ols(self.theta_post_mean, y_minus_z)
        return alpha_hat

In [None]:
def repeat(n,m,B):
    """Run B independent simulations and collect both alpha estimates.

    Each replication draws a fresh DataGenerator(n, m) and records its
    direct_ols() and demean_ols() results. Returns a dict with those
    estimates as a DataFrame under 'df' (columns 'direct' and 'demean',
    one row per replication) plus the settings 'n', 'm', 'B'.
    """
    def _one_replication():
        gen = DataGenerator(n, m)
        return {'direct': gen.direct_ols(), 'demean': gen.demean_ols()}

    estimates = pd.DataFrame([_one_replication() for _ in range(B)])
    return {'df': estimates, 'n': n, 'm': m, 'B': B}


def plot_kde(repeat_dict, fname):
    """Plot the sampling distribution of each estimator and save the figure.

    Draws one KDE panel per column of repeat_dict['df'], marks the true
    ALPHA with a red vertical line, titles the figure with the simulation
    settings (n, m, B) and saves it to fname.

    Parameters
    ----------
    repeat_dict : dict
        Output of `repeat`: keys 'df', 'n', 'm', 'B'.
    fname : str or path-like
        Destination passed to Figure.savefig.
    """
    df = repeat_dict['df']
    n_panels = len(df.columns)
    # squeeze=False keeps axs 2-D even when there is a single column.
    fig, axs = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)

    # Outer double quotes: reusing ' inside a '-quoted f-string is a
    # SyntaxError before Python 3.12 (PEP 701).
    fig.suptitle(f"n={repeat_dict['n']}, m={repeat_dict['m']}, B={repeat_dict['B']}")
    for ax, col in zip(axs[0], df.columns):
        sns.kdeplot(df[col], ax=ax)
        ax.axvline(ALPHA, color='r', label=r'true $\alpha$')
        ax.set_title(col)
        ax.legend()

    fig.savefig(fname)
