In [3]:
import numpy as np

def soft_threshold(x, lambda_):
    """Apply elementwise soft-thresholding (shrinkage) to ``x``.

    Each entry's magnitude is reduced by ``lambda_`` and clamped at zero,
    keeping the original sign. This is the proximal operator of
    ``lambda_ * ||.||_1``.

    Parameters
    ----------
    x : array_like or scalar
        Values to shrink.
    lambda_ : float
        Shrinkage amount.

    Returns
    -------
    numpy.ndarray or numpy scalar
        ``sign(x) * max(|x| - lambda_, 0)`` evaluated elementwise.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - lambda_, 0)
    direction = np.sign(x)
    return direction * shrunk_magnitude

def admm_lasso(A, b, lambda_, rho=1.0, max_iter=1000, tol=1e-4):
    """Solve the lasso problem with ADMM (scaled dual form).

    Minimizes ``0.5 * ||A x - b||_2^2 + lambda_ * ||x||_1`` by splitting the
    variable as ``x = z``. Each iteration performs three steps:

    * a ridge-like x-update: solve ``(A^T A + rho I) x = A^T b + rho (z - u)``;
    * a soft-threshold z-update;
    * a scaled dual update of ``u``.

    Parameters
    ----------
    A : array_like, shape (n, m)
        Design matrix.
    b : array_like, shape (n,)
        Observation vector.
    lambda_ : float
        Weight of the L1 penalty.
    rho : float, optional
        ADMM penalty parameter. Must be strictly positive.
    max_iter : int, optional
        Maximum number of ADMM iterations.
    tol : float, optional
        Iteration stops once both the primal residual ``||x - z||`` and the
        dual residual ``rho * ||z - z_prev||`` are below ``tol``.

    Returns
    -------
    numpy.ndarray, shape (m,)
        The ``z`` iterate. It is produced by soft-thresholding, so it is
        exactly sparse. On convergence it agrees with the ``x`` iterate to
        within ``tol``.

    Raises
    ------
    ValueError
        If ``rho`` is not strictly positive.
    """
    if rho <= 0:
        raise ValueError("rho must be strictly positive")

    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    _, m = A.shape

    z = np.zeros(m)
    u = np.zeros(m)

    Atb = A.T @ b
    # The x-update matrix (A^T A + rho I) never changes, so factor it once.
    # A^T A is positive semidefinite, so the eigenvalues of this matrix are
    # at least rho and the Cholesky factorization succeeds. NumPy has no
    # triangular solver, so invert the triangular factor once up front.
    # That keeps each iteration at two O(m^2) mat-vecs instead of a general
    # O(m^3) solve. (M = L L^T  =>  M^{-1} q = L^{-T} (L^{-1} q).)
    L = np.linalg.cholesky(A.T @ A + rho * np.eye(m))
    L_inv = np.linalg.inv(L)

    for _ in range(max_iter):
        # z is rebound (not mutated) below, so keeping a reference is safe.
        z_old = z

        # x update: minimize 0.5||Ax - b||^2 + (rho/2)||x - z + u||^2
        q = Atb + rho * (z - u)
        x = L_inv.T @ (L_inv @ q)

        # z update: proximal step of lambda_ * ||.||_1 with weight rho
        z = soft_threshold(x + u, lambda_ / rho)

        # u update: accumulate the scaled constraint violation x - z
        u += x - z

        # Standard ADMM stopping rule: both primal feasibility (x == z) and
        # dual feasibility (z has stopped moving) must hold.
        primal_residual = np.linalg.norm(x - z)
        dual_residual = rho * np.linalg.norm(z - z_old)
        if primal_residual < tol and dual_residual < tol:
            break

    return z

# Example usage: recover a sparse vector from noisy, underdetermined
# measurements (50 observations of a 200-dimensional signal).
if __name__ == "__main__":
    np.random.seed(0)
    n, m = 50, 200

    A = np.random.randn(n, m)

    # Ground truth: each entry is independently zeroed with probability 0.8.
    x_true = np.random.randn(m)
    zero_mask = np.random.rand(m) < 0.8
    x_true[zero_mask] = 0

    noise = 0.5 * np.random.randn(n)
    b = A @ x_true + noise

    lambda_ = 0.1
    x_est = admm_lasso(A, b, lambda_)

    for label, vec in (("True x:", x_true), ("Estimated x:", x_est)):
        print(label, vec)


True x: [ 0.          0.          0.          0.190649    0.          0.
  0.          0.          0.         -0.62325381  0.          0.
  0.          0.          0.          0.          0.         -0.40125494
  0.          0.          0.          0.91706862  0.          0.55203214
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.17294185  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.         -1.3630435
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.         -0.85432549
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.         -0.50465599
  0.          0.          0.74653145  1.79105614  0.          0.
  0.07114298  0.          0.          0.          0.          0.
 -1.91167094  0.93480541  0.          0.   