Skip to content

Conversation

@apoorvalal
Copy link
Member

Adds Maximum Likelihood Estimators:

  • Base MaximumLikelihoodEstimator class.
  • Logit regression.
  • PoissonRegression.

Adds new Causal Inference methods:

  • IPW (Inverse Propensity Weighting) using Logit for propensity scores.
  • AIPW (Augmented Inverse Propensity Weighting) using configurable outcome and propensity score models (defaults to LinearRegression and Logit).

Includes:

  • Performance enhancement by JIT-compiling relevant functions in existing LinearRegression and EntropyBalancing classes.
  • Unit tests for all new estimators (tests/test_mle.py, tests/test_causal.py) using synthetic data and comparison with sklearn where applicable.
  • Docstrings and API refinements for the new components.

google-labs-jules bot and others added 4 commits June 29, 2025 20:17
Adds Maximum Likelihood Estimators:
- Base `MaximumLikelihoodEstimator` class.
- `Logit` regression.
- `PoissonRegression`.

Adds new Causal Inference methods:
- `IPW` (Inverse Propensity Weighting) using Logit for propensity scores.
- `AIPW` (Augmented Inverse Propensity Weighting) using configurable outcome and propensity score models (defaults to LinearRegression and Logit).

Includes:
- Performance enhancement by JIT-compiling relevant functions in existing `LinearRegression` and `EntropyBalancing` classes.
- Unit tests for all new estimators (`tests/test_mle.py`, `tests/test_causal.py`) using synthetic data and comparison with sklearn where applicable.
- Docstrings and API refinements for the new components.
This resolves a TypeError when calling the JIT-compiled standard error
calculation function with a string argument (se_type) that was not
marked as static. `n` and `k` are also marked static for robustness as they
can influence computations and array shapes known at compile time.
Refactors MaximumLikelihoodEstimator and its subclasses (Logit, PoissonRegression)
to use Optax optimizers directly, replacing the deprecated jaxopt.OptaxSolver.
- MLE `fit` method now uses a standard Optax optimization loop.
- MLE `__init__` accepts an Optax GradientTransformation object.

Updates causal estimators:
- `IPW` constructor now takes `propensity_optimizer` (Optax object) and
  `propensity_maxiter` for its internal Logit model.
- `AIPW` remains compatible, as its default Logit propensity model
  now uses Optax, and a custom-configured Logit (with Optax) can be passed.

Test files (`test_mle.py`, `test_causal.py`) have been updated or confirmed
to be compatible with these changes.
@apoorvalal apoorvalal changed the title feat: Implement MLE and new Causal Inference methods feat: Implement MLE and IPW/AIPW Jun 29, 2025
@apoorvalal apoorvalal force-pushed the feat/mle-causal-methods branch from 7692512 to 04e5a90 Compare June 29, 2025 23:41
@apoorvalal apoorvalal requested a review from Copilot June 30, 2025 01:47
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces new maximum likelihood estimators (Logistic and Poisson regression) along with causal inference methods (IPW and AIPW) while also enhancing performance via JIT-compilation and expanding unit test coverage.

  • Adds MLE models and causal estimators with corresponding unit tests.
  • Refactors naming conventions (e.g. using "coef" consistently) and integrates new methods in the package's public API.

Reviewed Changes

Copilot reviewed 8 out of 9 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
tests/test_mle.py New unit tests added for logistic and Poisson regression models.
tests/test_linear.py Updated to reflect the naming change from "beta" to "coef".
tests/test_causal.py New tests for IPW and AIPW estimators with synthetic data.
nb/linmod.ipynb Minor notebook adjustments (timestamps and execution counts).
jaxonometrics/mle.py Added the MaximumLikelihoodEstimator base and its subclasses.
jaxonometrics/linear.py Updated VCOV computation using JIT and consistent parameter naming.
jaxonometrics/causal.py Added IPW and AIPW implementations and enhanced documentation.
jaxonometrics/init.py Updated public API to include new estimators and models.
Comments suppressed due to low confidence (2)

tests/test_linear.py:22

  • Ensure that the renaming from 'beta' to 'coef' is consistently reflected in all related tests and modules.
    jax_coef = jax_model.params["coef"][1:]

jaxonometrics/causal.py:221

  • [nitpick] When instantiating new outcome models via class(), ensure that any necessary constructor arguments are preserved to avoid losing configuration from the original template.
        model1 = self.outcome_model_template.__class__() # Create a new instance of the same type

self.params["se"] = se_values
else:
# This case should ideally not be reached if fit() is called first.
print("Coefficients not available for SE calculation.")
Copy link

Copilot AI Jun 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider raising an error or using a more robust error-handling approach instead of printing a warning when coefficients are unavailable during SE calculation.

Suggested change
print("Coefficients not available for SE calculation.")
raise ValueError("Coefficients are not available. Ensure `fit()` is called before `_vcov()`.")

Copilot uses AI. Check for mistakes.
@apoorvalal apoorvalal merged commit 8e04850 into master Jun 30, 2025
@apoorvalal apoorvalal deleted the feat/mle-causal-methods branch June 30, 2025 14:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants