diff --git a/.gitignore b/.gitignore
index cd831ff..68e5fd3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,4 @@
 *.blg
 *.out
 *.synctex.gz
+nb/scratch.ipynb
diff --git a/jaxonometrics/__init__.py b/jaxonometrics/__init__.py
index b001477..a01ebb3 100644
--- a/jaxonometrics/__init__.py
+++ b/jaxonometrics/__init__.py
@@ -5,15 +5,21 @@ __version__ = "0.0.1"
 
 from .base import BaseEstimator
-from .causal import EntropyBalancing
+from .causal import EntropyBalancing, IPW, AIPW
 from .gmm import GMM, LinearIVGMM, TwoStepGMM
 from .linear import LinearRegression
+from .mle import LogisticRegression, PoissonRegression, MaximumLikelihoodEstimator
 
 __all__ = [
     "BaseEstimator",
     "EntropyBalancing",
+    "IPW",
+    "AIPW",
     "GMM",
     "LinearIVGMM",
     "TwoStepGMM",
     "LinearRegression",
+    "MaximumLikelihoodEstimator",
+    "LogisticRegression",
+    "PoissonRegression",
 ]
diff --git a/jaxonometrics/causal.py b/jaxonometrics/causal.py
index e03c9d7..8cf9a5f 100644
--- a/jaxonometrics/causal.py
+++ b/jaxonometrics/causal.py
@@ -1,9 +1,12 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Any
+import jax
+import optax  # for optimizer type hints
 import jax.numpy as jnp
 from jaxopt import LBFGS
 
 from .base import BaseEstimator
+from .linear import LinearRegression  # default outcome model for AIPW
 
 
 class EntropyBalancing(BaseEstimator):
@@ -19,6 +22,7 @@ def __init__(self):
         super().__init__()
 
     @staticmethod
+    @jax.jit
     def _eb_moment(b: jnp.ndarray, X0: jnp.ndarray, X1: jnp.ndarray) -> jnp.ndarray:
         """The moment condition for entropy balancing."""
         return jnp.log(jnp.exp(-1 * X0 @ b).sum()) + X1 @ b
@@ -44,3 +48,263 @@ def fit(
         wt /= wt.sum()
         self.params = {"weights": wt}
         return self
+
+
+class IPW(BaseEstimator):
+    """
+    Inverse Propensity Weighting estimator for the Average Treatment Effect (ATE).
+    This implementation uses logistic regression for propensity score estimation.
+    """
+
+    def __init__(
+        self,
+        propensity_optimizer: Optional[optax.GradientTransformation] = None,
+        propensity_maxiter: int = 5000,
+        ps_clip_epsilon: float = 1e-6,
+    ):
+        """
+        Initialize the IPW estimator.
+
+        Args:
+            propensity_optimizer: Optional Optax optimizer for the propensity score model.
+                Defaults to optax.lbfgs() if None.
+            propensity_maxiter: Maximum iterations for the propensity model optimization.
+            ps_clip_epsilon: Small constant for clipping propensity scores away from 0 and 1.
+        """
+        super().__init__()
+        from .mle import LogisticRegression
+
+        self.logit_model = LogisticRegression(
+            optimizer=propensity_optimizer, maxiter=propensity_maxiter
+        )
+        self.ps_clip_epsilon = ps_clip_epsilon
+        self.params: Dict[str, Any] = {"ate": None, "propensity_scores": None}
+
+    def fit(
+        self,
+        X: jnp.ndarray,
+        W: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> "IPW":
+        """
+        Estimate the Average Treatment Effect (ATE) using IPW.
+
+        Args:
+            X: Covariate matrix of shape (n_samples, n_features).
+                It's assumed that X includes an intercept column if one is desired for the propensity score model.
+            W: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
+            y: Outcome vector of shape (n_samples,).
+
+        Returns:
+            The fitted estimator with ATE and propensity scores.
+        """
+        # Ensure W is a jnp.ndarray for the propensity model
+        if not isinstance(W, jnp.ndarray):
+            W_jax = jnp.array(W)
+        else:
+            W_jax = W
+
+        # 1. Estimate propensity scores e(X) = P(W=1|X) via logistic regression
+        self.logit_model.fit(X, W_jax)
+        propensity_scores = self.logit_model.predict_proba(X)
+
+        # Clip propensity scores away from 0 and 1 for numerical stability
+        propensity_scores = jnp.clip(
+            propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
+        )
+
+        self.params["propensity_scores"] = propensity_scores
+
+        # 2. Calculate IPW weights: 1/e(X) for treated units, 1/(1-e(X)) for controls
+        weights = W_jax / propensity_scores + (1 - W_jax) / (1 - propensity_scores)
+
+        # 3. Estimate ATE = E[Y(1)] - E[Y(0)] with the normalized (Hajek) estimator:
+        #    E[Y(1)] = Σ(W_i y_i / e_i) / Σ(W_i / e_i)
+        #    E[Y(0)] = Σ((1-W_i) y_i / (1-e_i)) / Σ((1-W_i) / (1-e_i))
+        # Normalizing by the realized weights is more stable than the unnormalized
+        # Horvitz-Thompson form (1/N) Σ[W_i y_i / e_i - (1-W_i) y_i / (1-e_i)].
+        mean_y1 = jnp.sum(W_jax * y * weights) / jnp.sum(W_jax * weights)
+        mean_y0 = jnp.sum((1 - W_jax) * y * weights) / jnp.sum((1 - W_jax) * weights)
+
+        ate = mean_y1 - mean_y0
+        self.params["ate"] = ate
+
+        return self
+
+    def summary(self) -> None:
+        super().summary()  # calls BaseEstimator summary
+        if self.params and "ate" in self.params and self.params["ate"] is not None:
+            print(f" Estimated ATE: {self.params['ate']:.4f}")
+        if (
+            self.params
+            and "propensity_scores" in self.params
+            and self.params["propensity_scores"] is not None
+        ):
+            print(
+                f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
+            )
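+
+
+# Illustrative usage (a minimal sketch; assumes `X` already carries an
+# intercept column, `W` is a binary treatment indicator, and `y` is the
+# outcome vector):
+#
+#     from jaxonometrics import IPW
+#
+#     ipw = IPW(propensity_maxiter=10000)
+#     ipw.fit(X, W, y)
+#     ate = ipw.params["ate"]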
+
+
+class AIPW(BaseEstimator):
+    """
+    Augmented Inverse Propensity Weighting (AIPW) estimator for the ATE.
+    Also known as the doubly robust estimator: it remains consistent if either
+    the outcome model or the propensity model is correctly specified.
+    """
+
+    def __init__(
+        self,
+        outcome_model: Optional[BaseEstimator] = None,
+        propensity_model: Optional[Any] = None,  # a LogisticRegression instance or similar
+        ps_clip_epsilon: float = 1e-6,
+    ):
+        """
+        Initialize the AIPW estimator.
+
+        Args:
+            outcome_model: A regression model (like LinearRegression or a custom one)
+                to estimate E[Y|X, W=w]. If None, LinearRegression is used.
+                The model should have `fit(X, y)` and `predict(X)` methods.
+            propensity_model: A binary classifier (like LogisticRegression) to estimate
+                P(W=1|X). If None, LogisticRegression() is used. The model should have
+                `fit(X, W)` and `predict_proba(X)` methods.
+            ps_clip_epsilon: Small constant to clip propensity scores to avoid extreme values.
+        """
+        super().__init__()
+        from .mle import LogisticRegression
+
+        self.outcome_model_template = (
+            outcome_model if outcome_model else LinearRegression()
+        )
+        # Two outcome-model fits are needed: one on the treated, one on the controls
+        self.propensity_model = (
+            propensity_model if propensity_model else LogisticRegression()
+        )
+
+        self.ps_clip_epsilon = ps_clip_epsilon
+        self.params: Dict[str, Any] = {
+            "ate": None,
+            "propensity_scores": None,
+            "mu0_params": None,
+            "mu1_params": None,
+        }
+
+    def fit(
+        self,
+        X: jnp.ndarray,
+        W: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> "AIPW":
+        """
+        Estimate the Average Treatment Effect (ATE) using AIPW.
+
+        Args:
+            X: Covariate matrix of shape (n_samples, n_features).
+                It's assumed that X includes an intercept column if one is desired for the outcome and propensity score models.
+            W: Treatment assignment vector (binary, 0 or 1) of shape (n_samples,).
+            y: Outcome vector of shape (n_samples,).
+
+        Returns:
+            The fitted estimator with ATE.
+        """
+        if not isinstance(W, jnp.ndarray):
+            W_jax = jnp.array(W)
+        else:
+            W_jax = W
+        if not isinstance(y, jnp.ndarray):
+            y_jax = jnp.array(y)
+        else:
+            y_jax = y
+        if not isinstance(X, jnp.ndarray):
+            X_jax = jnp.array(X)
+        else:
+            X_jax = X
+
+        n_samples = X_jax.shape[0]
+
+        # 1. Estimate propensity scores e(X) = P(W=1|X)
+        self.propensity_model.fit(X_jax, W_jax)
+        propensity_scores = self.propensity_model.predict_proba(X_jax)
+        propensity_scores = jnp.clip(
+            propensity_scores, self.ps_clip_epsilon, 1 - self.ps_clip_epsilon
+        )
+        self.params["propensity_scores"] = propensity_scores
+
+        # 2. Estimate outcome models E[Y|X, W=1] = μ1(X) and E[Y|X, W=0] = μ0(X)
+        X_treated = X_jax[W_jax == 1]
+        y_treated = y_jax[W_jax == 1]
+        X_control = X_jax[W_jax == 0]
+        y_control = y_jax[W_jax == 0]
+
+        # Fit fresh instances of the outcome model so the template is never mutated.
+        # For sklearn-like models this means constructing a new instance of the
+        # same class; our JAX models are simply re-fitted.
+        model1 = self.outcome_model_template.__class__()
+        if X_treated.shape[0] > 0:
+            model1.fit(X_treated, y_treated)
+            mu1_X = model1.predict(X_jax)
+            self.params["mu1_params"] = model1.params
+        else:  # degenerate case: no treated units
+            mu1_X = jnp.zeros(n_samples)
+            self.params["mu1_params"] = None
+
+        model0 = self.outcome_model_template.__class__()
+        if X_control.shape[0] > 0:
+            model0.fit(X_control, y_control)
+            mu0_X = model0.predict(X_jax)
+            self.params["mu0_params"] = model0.params
+        else:  # degenerate case: no control units
+            mu0_X = jnp.zeros(n_samples)
+            self.params["mu0_params"] = None
+
+        # 3. Average the AIPW influence-function contributions:
+        # ψ_i = μ1(X_i) - μ0(X_i) + W_i/e(X_i) * (y_i - μ1(X_i)) - (1-W_i)/(1-e(X_i)) * (y_i - μ0(X_i))
+        term1 = mu1_X - mu0_X
+        term2 = (W_jax / propensity_scores) * (y_jax - mu1_X)
+        term3 = ((1 - W_jax) / (1 - propensity_scores)) * (y_jax - mu0_X)
+
+        psi_i = term1 + term2 - term3
+
+        ate = jnp.mean(psi_i)
+        self.params["ate"] = ate
+
+        return self
+
+    def summary(self) -> None:
+        super().summary()
+        if self.params and "ate" in self.params and self.params["ate"] is not None:
+            print(f" Estimated ATE (AIPW): {self.params['ate']:.4f}")
+        if (
+            self.params
+            and "propensity_scores" in self.params
+            and self.params["propensity_scores"] is not None
+        ):
+            print(
+                f" Propensity scores min: {jnp.min(self.params['propensity_scores']):.4f}, max: {jnp.max(self.params['propensity_scores']):.4f}"
+            )
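+
+
+# Illustrative usage (a minimal sketch; same data conventions as the IPW
+# example above, with the nuisance models passed in explicitly):
+#
+#     from jaxonometrics import AIPW, LinearRegression, LogisticRegression
+#
+#     aipw = AIPW(
+#         outcome_model=LinearRegression(),
+#         propensity_model=LogisticRegression(maxiter=10000),
+#     )
+#     aipw.fit(X, W, y)
+#     ate = aipw.params["ate"]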
diff --git a/jaxonometrics/linear.py b/jaxonometrics/linear.py
index 85b6722..f44f279 100644
--- a/jaxonometrics/linear.py
+++ b/jaxonometrics/linear.py
@@ -1,13 +1,36 @@
 from typing import Dict, Optional
+from functools import partial
 
 import numpy as np
-
+import jax
 import jax.numpy as jnp
 import lineax as lx
 
 from .base import BaseEstimator
 
 
+# Helper function for JIT compilation of the vcov calculations
+@partial(jax.jit, static_argnames=["se_type", "n", "k"])
+def _calculate_vcov_details(
+    coef: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray, se_type: str, n: int, k: int
+):
+    """Compute standard errors; designed to be JIT compiled."""
+    # se_type selects a Python-level branch, so it must be static. n and k are
+    # static as well so that scalar corrections like n / (n - k) are fixed at
+    # trace time instead of becoming traced values.
+    ε = y - X @ coef
+    if se_type == "HC1":
+        M = jnp.einsum("ij,i,ik->jk", X, ε**2, X)
+        XtX_inv = jnp.linalg.inv(X.T @ X)
+        Σ = XtX_inv @ M @ XtX_inv
+        return jnp.sqrt((n / (n - k)) * jnp.diag(Σ))
+    elif se_type == "classical":
+        XtX_inv = jnp.linalg.inv(X.T @ X)
+        return jnp.sqrt(jnp.diag(XtX_inv) * jnp.var(ε, ddof=k))
+    return None  # unreachable for valid se_type values
+
+
 class LinearRegression(BaseEstimator):
     """
     Linear regression model using lineax for efficient solving.
@@ -48,12 +71,6 @@
             operator=lx.MatrixLinearOperator(X),
             vector=y,
             solver=lx.AutoLinearSolver(well_posed=None),
-            # per lineax docs, passing well_posed None is remarkably general:
-            # If the operator is non-square, then use lineax.QR. (Most likely case)
-            # If the operator is diagonal, then use lineax.Diagonal.
-            # If the operator is tridiagonal, then use lineax.Tridiagonal.
-            # If the operator is triangular, then use lineax.Triangular.
-            # If the matrix is positive or negative (semi-)definite, then use lineax.Cholesky.
         )
         self.params = {"coef": sol.value}
@@ -61,16 +78,16 @@
             sol = jnp.linalg.lstsq(X, y)
             self.params = {"coef": sol[0]}
         elif self.solver == "numpy":  # for completeness
-            X, y = np.array(X), np.array(y)
-            sol = np.linalg.lstsq(X, y, rcond=None)
-            self.params = {"coef": jnp.array(sol[0])}
+            X_np, y_np = np.array(X), np.array(y)  # numpy solver needs numpy arrays
+            sol = np.linalg.lstsq(X_np, y_np, rcond=None)
+            self.params = {"coef": jnp.array(sol[0])}  # convert back to a jax array
 
         if se:
             self._vcov(
                 y=y,
                 X=X,
-                se=se,
-            )  # set standard errors in params
+                se_type=se,
+            )
         return self
 
     def predict(self, X: jnp.ndarray) -> jnp.ndarray:
@@ -82,15 +99,14 @@
     def _vcov(
         self,
        y: jnp.ndarray,
         X: jnp.ndarray,
-        se: str = "HC1",
+        se_type: str = "HC1",
     ) -> None:
         n, k = X.shape
-        ε = y - X @ self.params["coef"]
-        if se == "HC1":
-            M = jnp.einsum("ij,i,ik->jk", X, ε**2, X)  # yer a wizard harry
-            XtX = jnp.linalg.inv(X.T @ X)
-            Σ = XtX @ M @ XtX
-            self.params["se"] = jnp.sqrt((n / (n - k)) * jnp.diag(Σ))
-        elif se == "classical":
-            XtX_inv = jnp.linalg.inv(X.T @ X)
-            self.params["se"] = jnp.sqrt(jnp.diag(XtX_inv) * jnp.var(ε, ddof=k))
+        if self.params and "coef" in self.params:
+            coef = self.params["coef"]
+            se_values = _calculate_vcov_details(coef, X, y, se_type, n, k)
+            if se_values is not None:
+                self.params["se"] = se_values
+        else:
+            # Should not be reached if fit() is called first.
+            print("Coefficients not available for SE calculation.")
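+
+
+# Illustrative usage (a minimal sketch; assumes `X` carries an intercept
+# column and that `fit`'s `se` keyword is passed as below):
+#
+#     model = LinearRegression()
+#     model.fit(X, y, se="HC1")  # HC1 heteroskedasticity-robust standard errors
+#     model.params["coef"], model.params["se"]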
diff --git a/jaxonometrics/mle.py b/jaxonometrics/mle.py
new file mode 100644
index 0000000..0881a6e
--- /dev/null
+++ b/jaxonometrics/mle.py
@@ -0,0 +1,232 @@
+from abc import abstractmethod
+from typing import Dict, Optional
+
+import jax
+import jax.numpy as jnp
+import optax
+
+from .base import BaseEstimator
+
+
+class MaximumLikelihoodEstimator(BaseEstimator):
+    """
+    Base class for Maximum Likelihood Estimators using Optax.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optional[optax.GradientTransformation] = None,
+        maxiter: int = 5000,
+        tol: float = 1e-4,
+    ):
+        super().__init__()
+        self.optimizer = optimizer if optimizer is not None else optax.lbfgs()
+        self.maxiter = maxiter
+        # tol is the relative loss-change threshold used by the early-stopping
+        # check in fit(); set tol <= 0 to always run for maxiter iterations.
+        self.tol = tol
+        self.params: Dict[str, jnp.ndarray] = {}
+        self.history: Dict[str, list] = {"loss": []}  # loss trajectory for diagnostics
+
+    @abstractmethod
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
+        """
+        Computes the negative log-likelihood for the model.
+        Must be implemented by subclasses.
+        Args:
+            params: Model parameters.
+            X: Design matrix.
+            y: Target vector.
+        Returns:
+            Negative log-likelihood value.
+        """
+        raise NotImplementedError
+
+    def fit(
+        self,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+        init_params: Optional[jnp.ndarray] = None,
+        verbose: bool = False,
+    ) -> "MaximumLikelihoodEstimator":
+        """
+        Fit the model using the specified Optax optimizer.
+
+        Args:
+            X: Design matrix of shape (n_samples, n_features).
+                It's assumed that X includes an intercept column if one is desired.
+            y: Target vector of shape (n_samples,).
+            init_params: Optional initial parameters. If None, small random normal
+                values drawn from a fixed PRNGKey are used.
+            verbose: If True, report when the convergence tolerance is met.
+
+        Returns:
+            The fitted estimator.
+        """
+        n_features = X.shape[1]
+        if init_params is None:
+            # A fixed key keeps the initialization reproducible while giving a
+            # slightly better starting point than all-zeros.
+            key = jax.random.PRNGKey(0)
+            init_params_val = jax.random.normal(key, (n_features,)) * 0.01
+        else:
+            init_params_val = init_params
+
+        # The loss closes over X and y, so the optimizer only sees params
+        def loss_fn(params_lg):
+            return self._negative_log_likelihood(params_lg, X, y)
+
+        # value_and_grad_from_state reuses function values cached in the
+        # optimizer state (needed by line-search methods such as optax.lbfgs)
+        value_and_grad_fn = optax.value_and_grad_from_state(loss_fn)
+
+        # Initialize optimizer state
+        opt_state = self.optimizer.init(init_params_val)
+
+        current_params = init_params_val
+        self.history["loss"] = []  # reset loss history
+
+        # Optimization loop
+        for i in range(self.maxiter):
+            loss_val, grads = value_and_grad_fn(current_params, state=opt_state)
+            updates, opt_state = self.optimizer.update(
+                grads,
+                opt_state,
+                current_params,
+                value=loss_val,
+                grad=grads,
+                value_fn=loss_fn,
+            )
+            current_params = optax.apply_updates(
+                current_params,
+                updates,
+            )
+            self.history["loss"].append(loss_val)
+            # Early stopping on the relative change in the loss
+            if i > 10 and self.tol > 0:
+                loss_change = abs(
+                    self.history["loss"][-2] - self.history["loss"][-1]
+                ) / (abs(self.history["loss"][-2]) + 1e-8)
+                if loss_change < self.tol:
+                    if verbose:
+                        print(f"Convergence tolerance {self.tol} met at iteration {i}.")
+                    break
+
+        self.params = {"coef": current_params}
+        self.iterations_run = i + 1  # how many iterations actually ran
+
+        return self
+
+    def summary(self) -> None:
+        """Print a summary of the model results."""
+        if not self.params or "coef" not in self.params:
+            print("Model has not been fitted yet.")
+            return
+
+        print(f"{self.__class__.__name__} Results")
+        print("=" * 30)
+        print(f"Optimizer: {self.optimizer}")
+        if hasattr(self, "iterations_run"):
+            print(
+                f"Optimization ran for {self.iterations_run}/{self.maxiter} iterations."
+            )
+        if self.history["loss"]:
+            print(f"Final Loss: {self.history['loss'][-1]:.4e}")
+
+        print(f"Coefficients: {self.params['coef']}")
+        print("=" * 30)
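+
+
+# Illustrative: other optax optimizers can be swapped in. The update call above
+# passes value/grad/value_fn (needed by line-search methods like optax.lbfgs),
+# so plain first-order methods likely need the extra-args wrapper (a sketch;
+# the step size is an assumption, not a tuned default):
+#
+#     opt = optax.with_extra_args_support(optax.adam(1e-2))
+#     model = LogisticRegression(optimizer=opt, maxiter=2000)
+#     model.fit(X, y)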
+
+
+class LogisticRegression(MaximumLikelihoodEstimator):
+    """
+    Logistic Regression model.
+    """
+
+    def _negative_log_likelihood(
+        self,
+        params: jnp.ndarray,
+        X: jnp.ndarray,
+        y: jnp.ndarray,
+    ) -> float:
+        """
+        Computes the negative log-likelihood for logistic regression.
+        NLL = -Σ [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]
+        where p_i = σ(X_i @ β) = 1 / (1 + exp(-X_i @ β))
+        Implemented with the numerically stable log_sigmoid:
+        log(p_i) = log_sigmoid(X_i @ β)
+        log(1 - p_i) = log_sigmoid(-(X_i @ β))
+        """
+        logits = X @ params
+        nll = -jnp.sum(
+            y * jax.nn.log_sigmoid(logits) + (1 - y) * jax.nn.log_sigmoid(-logits)
+        )
+        return nll  # divide by X.shape[0] if an averaged loss is preferred
+
+    def predict_proba(self, X: jnp.ndarray) -> jnp.ndarray:
+        """
+        Predict probabilities for the positive class.
+        Args:
+            X: Design matrix of shape (n_samples, n_features).
+        Returns:
+            Array of probabilities of shape (n_samples,).
+        """
+        if not self.params or "coef" not in self.params:
+            raise ValueError("Model has not been fitted yet.")
+
+        logits = X @ self.params["coef"]
+        return jax.nn.sigmoid(logits)  # jax.scipy.special.expit is equivalent
+
+    def predict(self, X: jnp.ndarray, threshold: float = 0.5) -> jnp.ndarray:
+        """
+        Predict class labels.
+ Args: + X: Design matrix of shape (n_samples, n_features). + threshold: Probability threshold for class assignment. + Returns: + Array of predicted class labels (0 or 1). + """ + probas = self.predict_proba(X) + return (probas >= threshold).astype(jnp.int32) + + +class PoissonRegression(MaximumLikelihoodEstimator): + """ + Poisson Regression model. + """ + + def _negative_log_likelihood( + self, + params: jnp.ndarray, + X: jnp.ndarray, + y: jnp.ndarray, + ) -> float: + """ + Computes the negative log-likelihood for Poisson regression. + The log(y_i!) term is constant w.r.t params, so ignored for optimization. + NLL = Σ [exp(X_i @ β) - y_i * (X_i @ β)] + """ + linear_predictor = X @ params + lambda_i = jnp.exp(linear_predictor) # Predicted rates + + # Sum over samples + nll = jnp.sum(lambda_i - y * linear_predictor) + return nll # / X.shape[0] if averaging + + def predict(self, X: jnp.ndarray) -> jnp.ndarray: + """ + Predict expected counts (lambda_i). + Args: + X: Design matrix of shape (n_samples, n_features). + Returns: + Array of predicted counts. + """ + if not self.params or "coef" not in self.params: + raise ValueError("Model has not been fitted yet.") + + linear_predictor = X @ self.params["coef"] + return jnp.exp(linear_predictor) diff --git a/nb/linmod.ipynb b/nb/linmod.ipynb index 33ee301..7a0f105 100644 --- a/nb/linmod.ipynb +++ b/nb/linmod.ipynb @@ -64,7 +64,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO:2025-06-29 11:43:45,267:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n", + "INFO:2025-06-29 15:46:55,666:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n", "INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n" ] }, @@ -258,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "25b0053d", "metadata": {}, "outputs": [ @@ -266,7 +266,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "9.2 ms ± 1.24 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n" + "8.98 ms ± 1.47 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n" ] } ], @@ -287,7 +287,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "051697e3", "metadata": {}, "outputs": [ @@ -295,7 +295,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "12.3 ms ± 1.57 ms per loop (mean ± std. dev. of 5 runs, 100 loops each)\n" + "10.9 ms ± 617 μs per loop (mean ± std. dev. 
of 5 runs, 100 loops each)\n"
     ]
    }
   ],
diff --git a/pyproject.toml b/pyproject.toml
index 316e9c1..a6cb1ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,8 +12,6 @@ dependencies = [
     "jaxopt",
     "lineax",
     "optax",
-    "formulaic",
-    "narwhals",
 ]
 requires-python = ">=3.8"
 authors = [
diff --git a/tests/test_causal.py b/tests/test_causal.py
new file mode 100644
index 0000000..18c943d
--- /dev/null
+++ b/tests/test_causal.py
@@ -0,0 +1,165 @@
+import pytest
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from jaxonometrics.causal import IPW, AIPW
+from jaxonometrics.linear import LinearRegression
+from jaxonometrics.mle import LogisticRegression
+
+
+# Function to generate synthetic data for causal inference tests
+def generate_causal_data(
+    n_samples=1000,
+    n_features=3,
+    true_ate=2.0,
+    seed=42,
+):
+    """
+    Generates synthetic data for testing causal inference estimators.
+    Args:
+        n_samples: Number of samples.
+        n_features: Number of covariates.
+        true_ate: The true Average Treatment Effect.
+        seed: Random seed for reproducibility.
+
+    Returns:
+        X (jnp.ndarray): Covariates, with an intercept column prepended.
+        W (jnp.ndarray): Treatment assignment (0 or 1).
+        y (jnp.ndarray): Outcome.
+        true_propensity (jnp.ndarray): True propensity scores.
+        true_mu0 (jnp.ndarray): True E[Y|X, W=0].
+        true_mu1 (jnp.ndarray): True E[Y|X, W=1].
+        true_ate (float): The true ATE used to generate the data.
+    """
+    key = jax.random.PRNGKey(seed)
+    key_X, key_T_noise, key_Y_noise, key_beta_T, key_beta_Y = jax.random.split(key, 5)
+
+    X = jax.random.normal(key_X, (n_samples, n_features))
+
+    # True coefficients for the propensity score model (logit link)
+    beta_T_true = jax.random.uniform(key_beta_T, (n_features,), minval=-1, maxval=1)
+    true_propensity_logits = X @ beta_T_true - 0.5  # centering term
+    true_propensity = jax.nn.sigmoid(true_propensity_logits)
+
+    # Draw treatment by thresholding uniform noise at the true propensity
+    W_noise = jax.random.uniform(key_T_noise, (n_samples,))
+    W = (W_noise < true_propensity).astype(jnp.int32)
+
+    # True coefficients for the outcome model (linear)
+    beta_Y_common = jax.random.uniform(
+        key_beta_Y, (n_features,), minval=-0.5, maxval=0.5
+    )
+
+    # E[Y|X,W=0] = X @ beta_Y_common + intercept_0
+    # E[Y|X,W=1] = X @ beta_Y_common + intercept_0 + true_ate
+    intercept_0 = 0.5
+
+    true_mu0 = X @ beta_Y_common + intercept_0
+    true_mu1 = X @ beta_Y_common + intercept_0 + true_ate
+
+    # Outcome = W*mu1 + (1-W)*mu0 + noise
+    Y_noise = jax.random.normal(key_Y_noise, (n_samples,)) * 0.5  # noise level
+    y = W * true_mu1 + (1 - W) * true_mu0 + Y_noise
+
+    # Add an intercept to X for models that expect one (LinearRegression, LogisticRegression)
+    X_intercept = jnp.hstack([jnp.ones((n_samples, 1)), X])
+
+    return X_intercept, W, y, true_propensity, true_mu0, true_mu1, true_ate
+
+
+@pytest.fixture
+def causal_sim_data():
+    return generate_causal_data(n_samples=2000, n_features=3, true_ate=1.5, seed=123)
+
+
+def test_ipw_ate_estimation(causal_sim_data):
+    X, W, y, _, _, _, true_ate = causal_sim_data
+
+    ipw_estimator = IPW(propensity_maxiter=10000)
+    # Our LogisticRegression assumes X already includes the intercept column.
+    ipw_estimator.fit(X, W, y)
+    estimated_ate = ipw_estimator.params["ate"]
+
+    print(f"IPW - True ATE: {true_ate}, Estimated ATE: {estimated_ate}")
+    # Check that the estimated ATE is reasonably close to the true ATE.
+    # Some deviation is expected from sampling noise and finite-sample bias.
+    assert estimated_ate is not None
+    np.testing.assert_allclose(
+        estimated_ate, true_ate, rtol=0.2, atol=0.2
+    )  # looser tolerance for IPW
+
+
+def test_aipw_ate_estimation(causal_sim_data):
+    X, W, y, _, _, _, true_ate = causal_sim_data
+
+    # Default nuisance models: LinearRegression outcome, LogisticRegression propensity
+    aipw_estimator = AIPW(
+        outcome_model=LinearRegression(solver="lineax"),  # explicitly pass an instance
+        propensity_model=LogisticRegression(
+            maxiter=10000
+        ),  # explicitly pass an instance
+    )
+    # X should include an intercept for LinearRegression and LogisticRegression as currently implemented
+    aipw_estimator.fit(X, W, y)
+    estimated_ate = aipw_estimator.params["ate"]
+
+    print(f"AIPW - True ATE: {true_ate}, Estimated ATE: {estimated_ate}")
+    # AIPW is typically more stable and precise when the nuisance models are reasonably specified.
+    assert estimated_ate is not None
+    np.testing.assert_allclose(
+        estimated_ate, true_ate, rtol=0.15, atol=0.15
+    )  # tighter tolerance for AIPW
+
+    # Nuisance model parameters should be stored
+    assert "mu0_params" in aipw_estimator.params
+    assert "mu1_params" in aipw_estimator.params
+    assert "propensity_scores" in aipw_estimator.params
+    assert aipw_estimator.params["mu0_params"] is not None or X[W == 0].shape[0] == 0
+    assert aipw_estimator.params["mu1_params"] is not None or X[W == 1].shape[0] == 0
+
+
+# Test AIPW with explicitly supplied nuisance models (not typical usage, but it
+# exercises the constructor's flexibility)
+def test_aipw_with_custom_models(causal_sim_data):
+    X, W, y, _, _, _, true_ate = causal_sim_data
+
+    # 1. Fit a propensity score model
+    ps_model = LogisticRegression(maxiter=10000)
+    ps_model.fit(X, W)  # X includes an intercept
+
+    # 2. Fit outcome models on the treated and control subsamples
+    X_treated = X[W == 1]
+    y_treated = y[W == 1]
+    X_control = X[W == 0]
+    y_control = y[W == 0]
+
+    outcome_model_t = LinearRegression()
+    if X_treated.shape[0] > 0:
+        outcome_model_t.fit(X_treated, y_treated)
+
+    outcome_model_c = LinearRegression()
+    if X_control.shape[0] > 0:
+        outcome_model_c.fit(X_control, y_control)
+
+    # Note: AIPW currently re-fits its nuisance models internally, so the
+    # pre-fitted models above only exercise the surrounding logic. If AIPW is
+    # refactored to accept already-fitted nuisance models, this test should
+    # pass them in directly instead.
+
+    aipw_estimator = AIPW(
+        outcome_model=LinearRegression(),  # AIPW creates new instances and fits them
+        propensity_model=LogisticRegression(
+            maxiter=10000
+        ),  # AIPW creates a new instance and fits it
+    )
+    aipw_estimator.fit(X, W, y)
+    estimated_ate = aipw_estimator.params["ate"]
+
+    print(f"AIPW (custom path) - True ATE: {true_ate}, Estimated ATE: {estimated_ate}")
+    np.testing.assert_allclose(estimated_ate, true_ate, rtol=0.15, atol=0.15)
+
diff --git a/tests/test_linear.py b/tests/test_linear.py
index 70d3c63..1b39163 100644
--- a/tests/test_linear.py
+++ b/tests/test_linear.py
@@ -19,6 +19,6 @@
     X_with_intercept = jnp.c_[jnp.ones(X.shape[0]), X]
     jax_model = LinearRegression()
     jax_model.fit(X_with_intercept, jnp.array(y))
-    jax_coef = jax_model.params["beta"][1:]
+    jax_coef = jax_model.params["coef"][1:]
 
     assert np.allclose(sklearn_coef, jax_coef, atol=1e-6)
diff --git a/tests/test_mle.py b/tests/test_mle.py
new file mode 100644
index 0000000..cd05dca
--- /dev/null
+++ b/tests/test_mle.py
@@ -0,0 +1,143 @@
+import pytest
+import jax.numpy as jnp
+import numpy as np  # for sklearn data generation and comparison
+from sklearn.linear_model import LogisticRegression as SklearnLogit
+from sklearn.linear_model import PoissonRegressor as SklearnPoisson
+from sklearn.datasets import (
+    make_classification,
+    make_regression,
+)  # make_regression output is transformed into counts below
+from sklearn.preprocessing import StandardScaler
+
+from jaxonometrics.mle import LogisticRegression, PoissonRegression
+
+
+# Fixture for Logit data
+@pytest.fixture
+def logit_data():
+    X_np, y_np = make_classification(
+        n_samples=200,
+        n_features=3,
+        n_informative=2,
+        n_redundant=1,
+        random_state=42,
+        n_classes=2,
+    )
+    # Add an intercept column
+    X_np_intercept = np.hstack([np.ones((X_np.shape[0], 1)), X_np])
+    scaler = StandardScaler()
+    X_np_scaled = scaler.fit_transform(
+        X_np_intercept[:, 1:]
+    )  # scale only the non-intercept features
+    X_np_intercept_scaled = np.hstack([X_np_intercept[:, :1], X_np_scaled])
+
+    return jnp.array(X_np_intercept_scaled), jnp.array(y_np)
+
+
+# Fixture for Poisson data
+@pytest.fixture
+def poisson_data():
+    X_np, y_np_reg = make_regression(
+        n_samples=200,
+        n_features=3,
+        n_informative=2,
+        n_targets=1,
+        random_state=123,
+        noise=0.1,
+    )
+    # Add an intercept column
+    X_np_intercept = np.hstack([np.ones((X_np.shape[0], 1)), X_np])
+    scaler = StandardScaler()
+    X_np_scaled = scaler.fit_transform(X_np_intercept[:, 1:])
+    X_np_intercept_scaled = np.hstack([X_np_intercept[:, :1], X_np_scaled])
+
+    # Transform y into non-negative, integer-valued counts. This only loosely
+    # resembles a true Poisson DGP, but it is sufficient for comparing our fit
+    # against sklearn's PoissonRegressor.
+    y_poisson_np = np.abs(
+        np.round(np.exp(y_np_reg / np.std(y_np_reg) * 0.5))
+    )  # scale and transform
+
+    # y = 0 is fine for both our model and sklearn's log-link implementation.
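+    # A more faithful Poisson DGP (an illustrative sketch, not used here)
+    # would draw counts directly from the implied rate:
+    #   rng = np.random.default_rng(123)
+    #   beta = rng.normal(size=X_np_intercept_scaled.shape[1]) * 0.3
+    #   y_poisson_np = rng.poisson(np.exp(X_np_intercept_scaled @ beta))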
+ + return jnp.array(X_np_intercept_scaled), jnp.array(y_poisson_np) + + +def test_logit_fit_predict(logit_data): + X, y = logit_data + + # Jaxonometrics Logit + jax_logit = LogisticRegression(maxiter=20) + jax_logit.fit(X, y) + jax_coef = jax_logit.params["coef"] + + # Sklearn Logit for comparison (using 'liblinear' which is good for small datasets, no penalty) + # Sklearn's LogisticRegression has regularization by default, so we need to turn it off or make it very weak. + # C is inverse of regularization strength. Large C = weak regularization. + # We also need to tell it not to add an intercept if we already have one. + sklearn_logit = SklearnLogit( + solver="liblinear", + C=1e9, + fit_intercept=False, + random_state=42, + tol=1e-6, + ) + sklearn_logit.fit(np.array(X), np.array(y)) + sklearn_coef = sklearn_logit.coef_.flatten() + + # print(f"Jax Logit Coef: {jax_coef}") + # print(f"Sklearn Logit Coef: {sklearn_coef}") + + assert jax_coef is not None + assert jax_coef.shape == (X.shape[1],) + # Check if coefficients are reasonably close. This can be sensitive to optimizer settings. + # A looser tolerance might be needed depending on exact optimizer behavior. + np.testing.assert_allclose(jax_coef, sklearn_coef, rtol=0.1, atol=0.1) + + # Test predict_proba and predict + jax_probas = jax_logit.predict_proba(X) + jax_preds = jax_logit.predict(X) + + assert jax_probas.shape == (X.shape[0],) + assert jnp.all((jax_probas >= 0) & (jax_probas <= 1)) + assert jax_preds.shape == (X.shape[0],) + assert jnp.all((jax_preds == 0) | (jax_preds == 1)) + + +def test_poisson_fit_predict(poisson_data): + X, y = poisson_data + + # Jaxonometrics Poisson + jax_poisson = PoissonRegression(maxiter=50) + jax_poisson.fit(X, y) + jax_coef = jax_poisson.params["coef"] + + # Sklearn Poisson for comparison + # Sklearn's PoissonRegressor also has alpha for regularization (L2). Set to 0 for no regularization. + sklearn_poisson = SklearnPoisson( + alpha=0, fit_intercept=False, max_iter=1000, tol=1e-6 + ) + sklearn_poisson.fit(np.array(X), np.array(y)) + sklearn_coef = sklearn_poisson.coef_.flatten() + + # print(f"Jax Poisson Coef: {jax_coef}") + # print(f"Sklearn Poisson Coef: {sklearn_coef}") + + assert jax_coef is not None + assert jax_coef.shape == (X.shape[1],) + # Poisson can be a bit more sensitive + np.testing.assert_allclose(jax_coef, sklearn_coef, rtol=0.15, atol=0.15) + + # Test predict + jax_counts = jax_poisson.predict(X) + assert jax_counts.shape == (X.shape[0],) + assert jnp.all(jax_counts >= 0) + + +# It might be good to also add a test for the summary methods, +# but that primarily checks printing, not core functionality. +# For now, focusing on fit/predict.