In [25]:
import numpy as np
import jax
import jax.numpy as jnp
import torch
import torch.optim as optim
from scipy.optimize import minimize
from jax import grad, hessian
from jax.scipy.stats import norm as jnorm
from scipy.stats import norm
import pandas as pd

# ===================== Data Generation =====================

# Set random seed for reproducibility
np.random.seed(1234)
torch.manual_seed(1234)

# Data generation
N = 10000
t_np = np.random.binomial(1, 0.5, N)
e0_np = np.random.normal(0, 1, N)
e1_np = np.random.normal(0, 1, N)
a1 = a0 = -1
tau = 2
mu0_np = a0
mu1_np = a1 + tau * t_np
sig0 = 1
sig1 = 2
lny0_np = mu0_np + sig0 * e0_np
lny1_np = mu1_np + sig1 * e1_np
y0_np = np.exp(lny0_np)
y1_np = np.exp(lny1_np)
y_np = t_np * y1_np + (1 - t_np) * y0_np
lny_np = np.round(np.log(y_np), 4)
ate_true = np.exp(a1 + tau + 0.5 * sig1**2) - np.exp(a0 + 0.5 * sig0**2)
print(f"True ATE: {ate_true:.4f}")

# ===================== SciPy Estimation =====================

def negloglik_scipy(theta):
    mu0, mu1, log_sig0, log_sig1 = theta
    sig0 = np.exp(log_sig0)
    sig1 = np.exp(log_sig1)
    lny0_ll = norm.logpdf(lny_np, loc=mu0, scale=sig0)
    lny1_ll = norm.logpdf(lny_np, loc=mu1, scale=sig1)
    loss = -np.sum(t_np * lny1_ll + (1 - t_np) * lny0_ll)
    return loss

# Initial parameter guesses
theta0 = np.array([-1.0, -2.0, np.log(1.0), np.log(2.0)])

# Optimization using BFGS without providing gradients (numerical differentiation)
res_scipy = minimize(negloglik_scipy, theta0, method='BFGS', tol=1e-6)

# Extract estimated parameters
mu0_est_scipy, mu1_est_scipy = res_scipy.x[0], res_scipy.x[1]
log_sig0_est_scipy, log_sig1_est_scipy = res_scipy.x[2], res_scipy.x[3]
sig0_est_scipy = np.exp(log_sig0_est_scipy)
sig1_est_scipy = np.exp(log_sig1_est_scipy)

# Compute standard errors using the inverse Hessian
V_scipy = res_scipy.hess_inv
std_errors_scipy = np.sqrt(np.diag(V_scipy))
var_log_sig0_scipy = V_scipy[2, 2]
var_log_sig1_scipy = V_scipy[3, 3]
std_sig0_est_scipy = sig0_est_scipy * np.sqrt(var_log_sig0_scipy)
std_sig1_est_scipy = sig1_est_scipy * np.sqrt(var_log_sig1_scipy)

# Compute ATE and its variance
Ey0_scipy = np.exp(mu0_est_scipy + 0.5 * sig0_est_scipy**2)
Ey1_scipy = np.exp(mu1_est_scipy + 0.5 * sig1_est_scipy**2)
ate_est_scipy = Ey1_scipy - Ey0_scipy

delta_g_scipy = np.array([
    -Ey0_scipy,  # derivative w.r.t mu0
    Ey1_scipy,   # derivative w.r.t mu1
    -Ey0_scipy * sig0_est_scipy**2,  # derivative w.r.t log_sig0
    Ey1_scipy * sig1_est_scipy**2    # derivative w.r.t log_sig1
])

V_ATE_scipy = delta_g_scipy.T @ V_scipy @ delta_g_scipy
ate_std_error_scipy = np.sqrt(V_ATE_scipy)

# ===================== JAX Estimation =====================

# Convert data to JAX arrays
t_jax = jnp.array(t_np)
lny_jax = jnp.array(lny_np)

def negloglik_jax(theta):
    mu0, mu1, log_sig0, log_sig1 = theta
    sig0 = jnp.exp(log_sig0)
    sig1 = jnp.exp(log_sig1)
    lny0_ll = jnorm.logpdf(lny_jax, loc=mu0, scale=sig0)
    lny1_ll = jnorm.logpdf(lny_jax, loc=mu1, scale=sig1)
    loss = -jnp.sum(t_jax * lny1_ll + (1 - t_jax) * lny0_ll)
    return loss

# Compute gradient and Hessian using JAX
negloglik_grad_jax = grad(negloglik_jax)
negloglik_hessian_jax = hessian(negloglik_jax)

# Optimization using BFGS with gradient
theta0_jax = np.array([-1.0, -2.0, np.log(1.0), np.log(2.0)])
res_jax = minimize(negloglik_jax, theta0_jax, method='BFGS', jac=negloglik_grad_jax)

# Extract estimated parameters
mu0_est_jax, mu1_est_jax = res_jax.x[0], res_jax.x[1]
log_sig0_est_jax, log_sig1_est_jax = res_jax.x[2], res_jax.x[3]
sig0_est_jax = np.exp(log_sig0_est_jax)
sig1_est_jax = np.exp(log_sig1_est_jax)

# Compute standard errors using the Hessian
H_jax = negloglik_hessian_jax(res_jax.x)
V_jax = np.linalg.inv(H_jax)
std_errors_jax = np.sqrt(np.diag(V_jax))
var_log_sig0_jax = V_jax[2, 2]
var_log_sig1_jax = V_jax[3, 3]
std_sig0_est_jax = sig0_est_jax * np.sqrt(var_log_sig0_jax)
std_sig1_est_jax = sig1_est_jax * np.sqrt(var_log_sig1_jax)

# Compute ATE and its variance
Ey0_jax = np.exp(mu0_est_jax + 0.5 * sig0_est_jax**2)
Ey1_jax = np.exp(mu1_est_jax + 0.5 * sig1_est_jax**2)
ate_est_jax = Ey1_jax - Ey0_jax

delta_g_jax = np.array([
    -Ey0_jax,  # derivative w.r.t mu0
    Ey1_jax,   # derivative w.r.t mu1
    -Ey0_jax * sig0_est_jax**2,  # derivative w.r.t log_sig0
    Ey1_jax * sig1_est_jax**2    # derivative w.r.t log_sig1
])

V_ATE_jax = delta_g_jax.T @ V_jax @ delta_g_jax
ate_std_error_jax = np.sqrt(V_ATE_jax)

# ===================== PyTorch Estimation =====================

# Convert data to PyTorch tensors
t_torch = torch.tensor(t_np, dtype=torch.float32)
lny_torch = torch.tensor(lny_np, dtype=torch.float32)

def negloglik_torch(theta):
    mu0, mu1, log_sig0, log_sig1 = theta
    sig0 = torch.exp(log_sig0)
    sig1 = torch.exp(log_sig1)
    lny0_dist = torch.distributions.Normal(mu0, sig0)
    lny1_dist = torch.distributions.Normal(mu1, sig1)
    lny0_ll = lny0_dist.log_prob(lny_torch)
    lny1_ll = lny1_dist.log_prob(lny_torch)
    loss = -torch.sum(t_torch * lny1_ll + (1 - t_torch) * lny0_ll)
    return loss

# Initial parameter guesses
theta0_torch = torch.tensor([-1.0, -2.0, np.log(1.0), np.log(2.0)], requires_grad=True)

# Optimization using LBFGS
optimizer = optim.LBFGS([theta0_torch], max_iter=100, line_search_fn='strong_wolfe')

def closure():
    optimizer.zero_grad()
    loss = negloglik_torch(theta0_torch)
    loss.backward()
    return loss

optimizer.step(closure)

# Extract estimated parameters
mu0_est_torch = theta0_torch[0].item()
mu1_est_torch = theta0_torch[1].item()
log_sig0_est_torch = theta0_torch[2].item()
log_sig1_est_torch = theta0_torch[3].item()
sig0_est_torch = np.exp(log_sig0_est_torch)
sig1_est_torch = np.exp(log_sig1_est_torch)

# Compute standard errors using Hessian
def negloglik_torch_vector(theta_vector):
    return negloglik_torch(theta_vector)

theta_opt = theta0_torch.detach().clone().requires_grad_(True)
hessian_torch = torch.autograd.functional.hessian(negloglik_torch_vector, theta_opt)
H_torch = hessian_torch.detach().numpy()
V_torch = np.linalg.inv(H_torch)
std_errors_torch = np.sqrt(np.diag(V_torch))
var_log_sig0_torch = V_torch[2, 2]
var_log_sig1_torch = V_torch[3, 3]
std_sig0_est_torch = sig0_est_torch * np.sqrt(var_log_sig0_torch)
std_sig1_est_torch = sig1_est_torch * np.sqrt(var_log_sig1_torch)

# Compute ATE and its variance
Ey0_torch = np.exp(mu0_est_torch + 0.5 * sig0_est_torch**2)
Ey1_torch = np.exp(mu1_est_torch + 0.5 * sig1_est_torch**2)
ate_est_torch = Ey1_torch - Ey0_torch

delta_g_torch = np.array([
    -Ey0_torch,  # derivative w.r.t mu0
    Ey1_torch,   # derivative w.r.t mu1
    -Ey0_torch * sig0_est_torch**2,  # derivative w.r.t log_sig0
    Ey1_torch * sig1_est_torch**2    # derivative w.r.t log_sig1
])

V_ATE_torch = delta_g_torch.T @ V_torch @ delta_g_torch
ate_std_error_torch = np.sqrt(V_ATE_torch)

# ===================== Compile Results =====================

results = {
    'Method': ['SciPy', 'JAX', 'PyTorch'],
    'mu0': [mu0_est_scipy, mu0_est_jax, mu0_est_torch],
    'std_mu0': [std_errors_scipy[0], std_errors_jax[0], std_errors_torch[0]],
    'mu1': [mu1_est_scipy, mu1_est_jax, mu1_est_torch],
    'std_mu1': [std_errors_scipy[1], std_errors_jax[1], std_errors_torch[1]],
    'sig0': [sig0_est_scipy, sig0_est_jax, sig0_est_torch],
    'std_sig0': [std_sig0_est_scipy, std_sig0_est_jax, std_sig0_est_torch],
    'sig1': [sig1_est_scipy, sig1_est_jax, sig1_est_torch],
    'std_sig1': [std_sig1_est_scipy, std_sig1_est_jax, std_sig1_est_torch],
    'ATE': [ate_est_scipy, ate_est_jax, ate_est_torch],
    'ATE_std_error': [ate_std_error_scipy, ate_std_error_jax, ate_std_error_torch]
}

df_results = pd.DataFrame(results)

# Format the numbers for better readability
pd.options.display.float_format = '{:.4f}'.format

# Display the results
print("\nEstimation Results:")
display(df_results)


True ATE: 19.4790

Estimation Results:


Unnamed: 0,Method,mu0,std_mu0,mu1,std_mu1,sig0,std_sig0,sig1,std_sig1,ATE,ATE_std_error
0,SciPy,-0.9895,0.0143,0.9941,0.0288,1.0056,0.0103,1.9972,0.0199,19.2381,0.9654
1,JAX,-0.9895,0.0143,0.9941,0.0281,1.0056,0.0101,1.9973,0.0199,19.2426,0.9667
2,PyTorch,-0.9893,0.0143,0.9929,0.0281,1.0056,0.0101,1.9966,0.0199,19.1927,0.9632
