In [1]:
import jax
import jax.numpy as jnp
from mcnnm import estimate, generate_data, complete_matrix

In [4]:
# Pin JAX to the CPU backend to avoid problems with Metal on Apple Silicon machines.
JAX_PLATFORM = 'cpu'
jax.config.update('jax_platforms', JAX_PLATFORM)

# Example 1: Generate data and estimate the treatment effect using holdout validation (faster than cross-validation, but potentially less accurate)

In [5]:
# Simulate a long-format panel with a known treatment effect. All three *_cov
# flags are off (presumably: no X/Z/V covariates are generated).
data, true_params = generate_data(
    nobs=500,
    nperiods=30,
    seed=42,
    assignment_mechanism='last_periods',
    X_cov=False,
    Z_cov=False,
    V_cov=False,
    treatment_probability=0.4,
)


def _to_unit_period_matrix(frame, column):
    """Pivot long-format `frame` so rows are units and columns are periods, as a JAX array of `column`."""
    wide = frame.pivot(index='unit', columns='period', values=column)
    return jnp.array(wide.values)


# Outcome matrix (from column 'y') and treatment matrix (from column 'treat');
# both pivots sort units/periods the same way, so the two matrices line up.
Y = _to_unit_period_matrix(data, 'y')
W = _to_unit_period_matrix(data, 'treat')

# Fit MC-NNM, choosing the regularization via a holdout split.
results = estimate(Y, W, validation_method='holdout')

# Compare the simulated ground truth with the estimate.
true_tau = true_params['treatment_effect']
print(f"\nTrue effect: {true_tau}, Estimated effect: {results.tau:.4f}")
print(f"Chosen lambda_L: {results.lambda_L:.4f}")


True effect: 1.0, Estimated effect: 1.0044
Chosen lambda_L: 0.0010
