In [None]:
# Re-import edited project modules (e.g. `mlp`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from torch.nn import Sigmoid

from mlp import get_batch


def truncate(pscores, thresh):
    """Clip propensity scores into the interval [thresh, 1 - thresh].

    Bounding the scores away from 0 and 1 keeps the inverse-propensity
    weights 1/p and 1/(1 - p) finite.

    Parameters
    ----------
    pscores : array_like
        Estimated propensity scores.
    thresh : float
        Lower bound; the upper bound is its mirror image ``1 - thresh``.

    Returns
    -------
    numpy.ndarray
        ``pscores`` with every value clipped into the interval.
    """
    lower = thresh
    upper = 1 - thresh
    return np.clip(pscores, lower, upper)

def ipw_ht(Z, Y, pscores):
    """Horvitz-Thompson inverse-propensity-weighted estimate of the ATE.

    Estimates E[Y(1)] - E[Y(0)] as

        mean(Z * Y / p) - mean((1 - Z) * Y / (1 - p)).

    Each arm's weighted outcome total is divided by the full sample size n.
    The self-normalised Hajek estimator (``ipw_hayek``) instead divides by
    the sum of that arm's weights.

    Parameters
    ----------
    Z : array_like of bool or {0, 1}
        Treatment indicator per unit.
    Y : array_like of float
        Observed outcome per unit.
    pscores : array_like of float
        Estimated P(Z = 1 | X) per unit, strictly inside (0, 1)
        (see ``truncate``).

    Returns
    -------
    float
        The estimated average treatment effect.
    """
    # Coerce to float arrays so bool masks and plain lists both work with `1 - Z`.
    Z = np.asarray(Z, dtype=float)
    Y = np.asarray(Y, dtype=float)
    pscores = np.asarray(pscores, dtype=float)

    treated_term = np.mean(Z * Y / pscores)
    control_term = np.mean((1 - Z) * Y / (1 - pscores))
    return treated_term - control_term


def ipw_hayek(Z, Y, pscores):
    """Hajek (self-normalised) inverse-propensity-weighted estimate of the ATE.

    Each arm's outcome is averaged with inverse-propensity weights, where
    the weights are normalised to sum to one within the arm.  The estimate
    is the treated-arm mean minus the control-arm mean.

    Parameters
    ----------
    Z : array_like of bool or {0, 1}
        Treatment indicator per unit.
    Y : array_like of float
        Observed outcome per unit.
    pscores : array_like of float
        Estimated P(Z = 1 | X) per unit, strictly inside (0, 1).

    Returns
    -------
    float
        Estimated average treatment effect.
    """
    treated_weights = Z / pscores
    control_weights = (1 - Z) / (1 - pscores)

    def _weighted_mean(weights):
        # Self-normalised weighted average of the outcomes.
        return np.sum(weights * Y) / np.sum(weights)

    return _weighted_mean(treated_weights) - _weighted_mean(control_weights)

# Prior configuration handed to get_batch in the statements that follow.
hyperparameters = dict(
    is_causal=True,
    num_causes=8,
    prior_mlp_hidden_dim=8,
    num_layers=3,
    noise_std=0.1,
    y_is_effect=True,
    pre_sample_weights=False,
    prior_mlp_dropout_prob=0.5,
    pre_sample_causes=False,
    prior_mlp_activations=Sigmoid,
    block_wise_dropout=True,
    init_std=1.0,
    prior_mlp_scale_weights_sqrt=True,
    sampling="normal",
    in_clique=False,
    sort_features=False,
    random_feature_rotation=False,
    new_mlp_per_example=True,
)

# Seed every RNG the prior sampler might draw from (stdlib, NumPy, torch) so
# Restart & Run All reproduces the same synthetic datasets.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): hyperparameters["sampling"] is "normal", but sampling="mixed"
# is passed explicitly here -- confirm which one get_batch honours.
x, y, _ = get_batch(batch_size=100, seq_len=10000, num_features=8, hyperparameters=hyperparameters,
                    num_outputs=3, sampling="mixed")
x = x.numpy()
y = y.numpy()
print(x.shape, y.shape)

plt.scatter(x[:, 0, 0], x[:, 0, 1], c=y[:, 0, 0], s=2)
plt.xlabel("x[:, 0, 0]")
plt.ylabel("x[:, 0, 1]")
plt.colorbar(label="y[:, 0, 0]");

In [None]:
# Shape of the outputs drawn above; the next cell indexes it as y[:, i, k].
y.shape

In [None]:
# For every synthetic dataset i, simulate a binary treatment with a known
# effect and compare the true ATE against the Hajek IPW estimate.
# NOTE(review): assumes axis 1 of x / y indexes datasets (batch_size=100 above)
# and axis 0 indexes samples -- confirm against get_batch.
PSCORE_CLIP = 0.01  # propensity scores are clipped to [PSCORE_CLIP, 1 - PSCORE_CLIP]

ATEs = []
ATE_hayeks = []
for i in range(x.shape[1]):  # one iteration per dataset
    X = x[:, i, :]
    Y0 = y[:, i, 0]
    # Treatment: dichotomise output 1 at its median.
    Z = y[:, i, 1] > np.median(y[:, i, 1])
    # Constant unit effect, so the true ATE is 1 up to float rounding.
    # Alternative heterogeneous effect: Y1 = Y0 + np.abs(y[:, i, 2]).
    Y1 = Y0 + 1
    Y = np.where(Z, Y1, Y0)
    ATE = np.mean(Y1) - np.mean(Y0)
    ATEs.append(ATE)

    # Propensity model P(Z = 1 | X), clipped to keep the IPW weights bounded.
    pscore_model = LogisticRegression()
    pscore_model.fit(X, Z)
    pscores = truncate(pscore_model.predict_proba(X)[:, 1], PSCORE_CLIP)

    ATE_hayek = ipw_hayek(Z, Y, pscores)
    ATE_hayeks.append(ATE_hayek)

# Z and pscores from the final iteration are inspected by the following cells.

In [None]:
# Propensity scores left over from the FINAL iteration of the loop above only,
# already clipped by truncate().
pscores

In [None]:
# Overlap check for the final loop iteration: propensity-score distribution
# among untreated (Z = 0) vs treated (Z = 1) units.
plt.hist(pscores[Z == 0], histtype='step', label='Z = 0')
plt.hist(pscores[Z == 1], histtype='step', label='Z = 1')
plt.xlabel('estimated propensity score')
plt.ylabel('count')
plt.legend();

In [None]:
# One point per dataset: true ATE vs. the Hajek IPW estimate.
plt.scatter(ATEs, ATE_hayeks)
plt.xlabel('true ATE')
plt.ylabel('Hajek IPW estimate')
plt.title('Hajek IPW estimate vs. true ATE per dataset');

In [None]:
# Sampling distribution of the Hajek estimates across datasets, with the mean
# true ATE as a reference line.
plt.hist(ATE_hayeks, label='Hajek IPW estimates')
plt.axvline(np.mean(ATEs), color='k', linestyle='--', label='mean true ATE')
plt.xlabel('estimated ATE')
plt.ylabel('count')
plt.legend();

In [None]:
# Smaller prior for the causes experiment: hidden width 3, 2 layers and no
# block-wise dropout (the earlier configuration used 8, 3 and True).
# This rebinds `hyperparameters` for every get_batch call below.
hyperparameters = dict(
    is_causal=True,
    num_causes=8,
    prior_mlp_hidden_dim=3,
    num_layers=2,
    noise_std=0.1,
    y_is_effect=True,
    pre_sample_weights=False,
    prior_mlp_dropout_prob=0.5,
    pre_sample_causes=False,
    prior_mlp_activations=Sigmoid,
    block_wise_dropout=False,
    init_std=1.0,
    prior_mlp_scale_weights_sqrt=True,
    sampling="normal",
    in_clique=False,
    sort_features=False,
    random_feature_rotation=False,
    new_mlp_per_example=True,
)

# Two independent draws from the prior, keeping only the second return value
# of each: 5 requested outputs become `x`, 3 become `u`.  Both are fed back in
# as causes for a third draw further down.
_, x, _ = get_batch(
    batch_size=100,
    seq_len=10000,
    num_features=8,
    hyperparameters=hyperparameters,
    num_outputs=5,
    sampling="mixed",
)
_, u, _ = get_batch(
    batch_size=100,
    seq_len=10000,
    num_features=8,
    hyperparameters=hyperparameters,
    num_outputs=3,
    sampling="mixed",
)

x, u = x.detach(), u.detach()

In [None]:
# Shapes of the two draws above (requested with num_outputs=5 and 3).
x.shape, u.shape

In [None]:
# Sanity check: grand mean over all entries of the u draw.
u.mean()

In [None]:
# Sanity check: grand mean over all entries of the x draw.
x.mean()

In [None]:
# One [x_i, u_i] pair per dataset i (axis 1), each slice keeping a singleton
# axis 1 via np.newaxis; the list is passed as `causes` to the next get_batch.
# NOTE(review): assumes axis 1 of x and u indexes datasets -- confirm.
assert x.shape[1] == u.shape[1], "x and u must contain the same number of datasets"
causes = [[x[:, i, np.newaxis, :], u[:, i, np.newaxis, :]] for i in range(x.shape[1])]

In [None]:
# Third draw: the per-dataset [x_i, u_i] pairs are supplied via `causes`, and
# two outputs are requested; only the second return value is kept as `y`.
_, y, _ = get_batch(
    batch_size=100,
    seq_len=10000,
    num_features=8,
    hyperparameters=hyperparameters,
    num_outputs=2,
    sampling="mixed",
    causes=causes,
)

In [None]:
# Shape of the outputs drawn from the supplied causes.
y.shape

In [None]:
# Independent standard-normal noise with one row per entry along y's axis 0,
# used as a null baseline for the correlations below (np.corrcoef needs the
# lengths to match, so the row count is taken from y instead of hard-coded).
rands = np.random.normal(0, 1, size=(y.shape[0], 8))

In [None]:
# NOTE(review): duplicates an earlier y.shape display cell; candidate for removal.
y.shape

In [None]:

# Baseline: 2x2 correlation matrix of y output 0 (dataset 0) with an independent
# noise column; the off-diagonal entry [0, 1] should be close to 0.
np.corrcoef(y[:,0,0], rands[:,3])

In [None]:
# 2x2 correlation matrix of y output 0 (dataset 0) with the first x column
# supplied as a cause for that dataset; the coefficient is entry [0, 1].
np.corrcoef(y[:,0,0], x[:,0,0])

In [None]:
# 2x2 correlation matrix of y output 0 (dataset 0) with u column 2 supplied as
# a cause for that dataset; the coefficient is entry [0, 1].
np.corrcoef(y[:,0,0], u[:,0,2])

In [None]:
# Dataset 0: y output 0 against the x cause it was generated from.
plt.scatter(y[:, 0, 0], x[:, 0, 0], s=2)
plt.xlabel('y[:, 0, 0]')
plt.ylabel('x[:, 0, 0]');

In [None]:
# Dataset 0: y output 0 against u column 2, which was also supplied as a cause.
plt.scatter(y[:, 0, 0], u[:, 0, 2], s=2)
plt.xlabel('y[:, 0, 0]')
plt.ylabel('u[:, 0, 2]');