1 change: 1 addition & 0 deletions .gitignore
@@ -27,3 +27,4 @@
*.blg
*.out
*.synctex.gz
nb/scratch.ipynb
8 changes: 7 additions & 1 deletion jaxonometrics/__init__.py
@@ -5,15 +5,21 @@
__version__ = "0.0.1"

from .base import BaseEstimator
from .causal import EntropyBalancing
from .causal import EntropyBalancing, IPW, AIPW
from .gmm import GMM, LinearIVGMM, TwoStepGMM
from .linear import LinearRegression
from .mle import LogisticRegression, PoissonRegression, MaximumLikelihoodEstimator

__all__ = [
"BaseEstimator",
"EntropyBalancing",
"IPW", # Added
"AIPW", # Added
"GMM",
"LinearIVGMM",
"TwoStepGMM",
"LinearRegression",
"MaximumLikelihoodEstimator",
"LogisticRegression",
"PoissonRegression",
]
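
A quick way to sanity-check the expanded export list (an editorial sketch, not part of the diff; it assumes only that the package imports as `jaxonometrics`):

import jaxonometrics

# Every name advertised in __all__ should resolve to an attribute of the package.
for name in jaxonometrics.__all__:
    assert hasattr(jaxonometrics, name), f"missing export: {name}"
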
266 changes: 265 additions & 1 deletion jaxonometrics/causal.py
@@ -1,9 +1,12 @@
from typing import Dict, Optional
from typing import Dict, Optional, Any

import jax
import optax  # used for optimizer type hints
import jax.numpy as jnp
from jaxopt import LBFGS

from .base import BaseEstimator
from .linear import LinearRegression # For default outcome model in AIPW


class EntropyBalancing(BaseEstimator):
@@ -19,6 +22,7 @@ def __init__(self):
super().__init__()

@staticmethod
@jax.jit
def _eb_moment(b: jnp.ndarray, X0: jnp.ndarray, X1: jnp.ndarray) -> jnp.ndarray:
"""The moment condition for entropy balancing."""
return jnp.log(jnp.exp(-1 * X0 @ b).sum()) + X1 @ b
@@ -44,3 +48,263 @@ def fit(
wt /= wt.sum()
self.params = {"weights": wt}
return self
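
The objective minimized in `fit` is the dual of the entropy-balancing program: at the minimizer `b*`, the gradient condition of `_eb_moment` says that the weights `exp(-X0 @ b*)`, once normalized, reproduce the treated covariate means. A numerical check of that balance property (an editorial sketch, not part of the diff; `X1` is taken to be the vector of treated covariate means, as the scalar-valued objective implies):

import jax
import jax.numpy as jnp
from jaxopt import LBFGS

def eb_moment(b, X0, X1):
    # Same dual objective as EntropyBalancing._eb_moment
    return jnp.log(jnp.exp(-X0 @ b).sum()) + X1 @ b

key = jax.random.PRNGKey(0)
X0 = jax.random.normal(key, (200, 3))   # control-group covariates
X1 = jnp.array([0.3, -0.1, 0.2])        # target (treated) covariate means

b_star = LBFGS(fun=eb_moment).run(jnp.zeros(3), X0=X0, X1=X1).params
w = jnp.exp(-X0 @ b_star)
w = w / w.sum()
print(w @ X0)  # ≈ X1: the fitted weights balance the control means to the target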


class IPW(BaseEstimator):
"""
Inverse Propensity Weighting estimator for Average Treatment Effect (ATE).
This implementation uses Logistic Regression for propensity score estimation.
"""

def __init__(
self,
propensity_optimizer: Optional[optax.GradientTransformation] = None,
propensity_maxiter: int = 5000,
ps_clip_epsilon: float = 1e-6,
):
"""
Initialize the IPW estimator.

Args:
propensity_optimizer: Optional Optax optimizer for the Logit propensity score model.
Defaults to optax.adam(1e-3) if None.
propensity_maxiter: Maximum iterations for the Logit model optimization.
ps_clip_epsilon: Small constant to clip propensity scores.
"""
super().__init__()
from .mle import LogisticRegression

self.logit_model = LogisticRegression(
optimizer=propensity_optimizer, maxiter=propensity_maxiter
)
self.ps_clip_epsilon = ps_clip_epsilon
self.params: Dict[str, Any] = {"ate": None, "propensity_scores": None}

def fit(
self,
X: jnp.ndarray,
W: jnp.ndarray,
y: jnp.ndarray,
) -> "IPW":
"""
Estimate the Average Treatment Effect (ATE) using IPW.

Args:
X: Covariate matrix of shape (n_samples, n_features).
It's assumed that X includes an intercept column if one is desired for the propensity score model.
            W: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
y: Outcome vector of shape (n_samples,).

Returns:
The fitted estimator with ATE and propensity scores.
"""
        # Ensure W is a JAX array for the logistic propensity model
        W_jax = jnp.asarray(W)

        # 1. Estimate propensity scores P(W=1|X) via logistic regression
self.logit_model.fit(X, W_jax)
propensity_scores = self.logit_model.predict_proba(X)

# Clip propensity scores using the instance attribute
propensity_scores = jnp.clip(
propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
)

self.params["propensity_scores"] = propensity_scores

# 2. Calculate IPW weights
# Weight for treated: 1 / p_score
# Weight for control: 1 / (1 - p_score)
weights = W_jax / propensity_scores + (1 - W_jax) / (1 - propensity_scores)

        # 3. Estimate ATE = E[Y(1)] - E[Y(0)] with self-normalized (Hájek) weighted means.
        # Since W is binary, W_jax * weights == W_jax / propensity_scores, so:
        #   E[Y(1)] = Σ_i W_i y_i / e_i         /  Σ_i W_i / e_i
        #   E[Y(0)] = Σ_i (1-W_i) y_i / (1-e_i) /  Σ_i (1-W_i) / (1-e_i)
        # (The unnormalized Horvitz-Thompson / Hahn (1998) form divides by N instead.)

        mean_y1 = jnp.sum(W_jax * y * weights) / jnp.sum(W_jax * weights)
        mean_y0 = jnp.sum((1 - W_jax) * y * weights) / jnp.sum((1 - W_jax) * weights)

ate = mean_y1 - mean_y0
self.params["ate"] = ate

return self

def summary(self) -> None:
super().summary() # Calls BaseEstimator summary
if self.params and "ate" in self.params and self.params["ate"] is not None:
print(f" Estimated ATE: {self.params['ate']:.4f}")
if (
self.params
and "propensity_scores" in self.params
and self.params["propensity_scores"] is not None
):
print(
f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
)
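
As an end-to-end check of `IPW` (an editorial sketch, not part of the diff: simulated data with a true ATE of 1.5, using the `fit(X, W, y)` signature above):

import jax
import jax.numpy as jnp
from jaxonometrics import IPW

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
n = 2000
x = jax.random.normal(k1, (n,))
X = jnp.column_stack([jnp.ones(n), x])       # include an intercept, per the docstring
e = jax.nn.sigmoid(0.5 * x)                  # true propensity scores
W = jax.random.bernoulli(k2, e).astype(jnp.float32)
y = 1.0 + 2.0 * x + 1.5 * W + 0.1 * jax.random.normal(k3, (n,))

est = IPW().fit(X, W, y)
print(est.params["ate"])  # should be close to the true ATE of 1.5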


class AIPW(BaseEstimator):
"""
Augmented Inverse Propensity Weighting (AIPW) estimator for ATE.
Also known as doubly robust estimator.
"""

def __init__(
self,
outcome_model: Optional[BaseEstimator] = None,
        propensity_model: Optional[Any] = None,  # e.g., a LogisticRegression instance
ps_clip_epsilon: float = 1e-6,
):
"""
Initialize the AIPW estimator.

Args:
outcome_model: A regression model (like LinearRegression or a custom one)
to estimate E[Y|X, T=t]. If None, LinearRegression is used.
The model should have a `fit(X,y)` and `predict(X)` method.
            propensity_model: A binary classifier (like LogisticRegression) to estimate P(W=1|X).
                              If None, LogisticRegression() is used. The model should have
                              `fit(X, W)` and `predict_proba(X)` methods.
ps_clip_epsilon: Small constant to clip propensity scores to avoid extreme values.
"""
super().__init__()
from .mle import LogisticRegression

        self.outcome_model_template = (
            outcome_model if outcome_model else LinearRegression()
        )
        # Two fresh copies of this template are fitted in `fit`, one per treatment arm.
        self.propensity_model = (
            propensity_model if propensity_model else LogisticRegression()
        )

self.ps_clip_epsilon = ps_clip_epsilon
self.params: Dict[str, Any] = {
"ate": None,
"propensity_scores": None,
"mu0_params": None,
"mu1_params": None,
}

def fit(
self,
X: jnp.ndarray,
W: jnp.ndarray,
y: jnp.ndarray,
) -> "AIPW":
"""
Estimate the Average Treatment Effect (ATE) using AIPW.

Args:
X: Covariate matrix of shape (n_samples, n_features).
It's assumed that X includes an intercept column if one is desired for the outcome and propensity score models.
            W: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
y: Outcome vector of shape (n_samples,).

Returns:
The fitted estimator with ATE.
"""
        # Coerce inputs to JAX arrays (no-ops if they already are)
        W_jax = jnp.asarray(W)
        y_jax = jnp.asarray(y)
        X_jax = jnp.asarray(X)

n_samples = X_jax.shape[0]

# 1. Estimate propensity scores P(T=1|X) = e(X)
self.propensity_model.fit(X_jax, W_jax)
propensity_scores = self.propensity_model.predict_proba(X_jax)
propensity_scores = jnp.clip(
propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
)
self.params["propensity_scores"] = propensity_scores

# 2. Estimate outcome models E[Y|X, T=1] = μ_1(X) and E[Y|X, T=0] = μ_0(X)
# Need to handle potential issues if one group has no samples (though unlikely with real data)
X_treated = X_jax[W_jax == 1]
y_treated = y_jax[W_jax == 1]
X_control = X_jax[W_jax == 0]
y_control = y_jax[W_jax == 0]

        # Fit a fresh instance of the outcome model for each treatment arm.
        # `__class__()` assumes the template is constructible with no arguments;
        # templates requiring constructor args would need explicit cloning.
        model1 = self.outcome_model_template.__class__()
if X_treated.shape[0] > 0:
model1.fit(X_treated, y_treated)
mu1_X = model1.predict(X_jax)
self.params["mu1_params"] = model1.params
else: # Should not happen in typical scenarios
mu1_X = jnp.zeros(n_samples)
self.params["mu1_params"] = None

        model0 = self.outcome_model_template.__class__()
if X_control.shape[0] > 0:
model0.fit(X_control, y_control)
mu0_X = model0.predict(X_jax)
self.params["mu0_params"] = model0.params
else: # Should not happen
mu0_X = jnp.zeros(n_samples)
self.params["mu0_params"] = None

# 3. Calculate AIPW estimator components
        # ψ_i = μ_1(X_i) - μ_0(X_i) + W_i/e(X_i) * (Y_i - μ_1(X_i)) - (1-W_i)/(1-e(X_i)) * (Y_i - μ_0(X_i))
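        # Editorial note on why ψ is doubly robust (standard argument, not from
        # this diff): with e*, μ* denoting the true nuisances,
        #   E[ W/e(X) * (Y - μ_1(X)) | X ] = (e*(X)/e(X)) * (μ_1*(X) - μ_1(X)).
        # If the outcome model is correct (μ_1 = μ_1*) this vanishes; if the
        # propensity model is correct (e = e*) it reduces to μ_1*(X) - μ_1(X),
        # exactly cancelling the bias of the plug-in term μ_1 - μ_0. The control
        # term behaves symmetrically, so E[ψ] equals the ATE if either model is right.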

term1 = mu1_X - mu0_X
term2 = (W_jax / propensity_scores) * (y_jax - mu1_X)
term3 = ((1 - W_jax) / (1 - propensity_scores)) * (y_jax - mu0_X)

psi_i = term1 + term2 - term3

ate = jnp.mean(psi_i)
self.params["ate"] = ate

return self

def summary(self) -> None:
super().summary()
if self.params and "ate" in self.params and self.params["ate"] is not None:
print(f" Estimated ATE (AIPW): {self.params['ate']:.4f}")
if (
self.params
and "propensity_scores" in self.params
and self.params["propensity_scores"] is not None
):
print(
f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
)
# Could add info about outcome model parameters if desired
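
Continuing the simulated data from the `IPW` sketch above, `AIPW` with its defaults should recover the same truth (an editorial sketch; defaults per this diff are `LinearRegression` outcomes and `LogisticRegression` propensities):

from jaxonometrics import AIPW

est = AIPW().fit(X, W, y)   # X, W, y as in the IPW sketch
est.summary()               # expect "Estimated ATE (AIPW):" ≈ 1.5
# Both nuisance models are well specified here; double robustness means the
# estimate would remain consistent if only one of them were.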