In [1]:
%matplotlib inline


# Implicit differentiation of lasso.


In [5]:
import sys
sys.argv = sys.argv[:1]
from absl import app
from absl import flags
import jax
import jax.numpy as jnp
from jaxopt import BlockCoordinateDescent
from jaxopt import objective
from jaxopt import OptaxSolver
from jaxopt import prox
from jaxopt import ProximalGradient
import optax

from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing

flags.DEFINE_bool("unrolling", True, "Whether to use unrolling.")
flags.DEFINE_string("solver", "bcd", "Solver to use (bcd or pg).")
FLAGS = flags.FLAGS

def outer_objective(theta, init_inner, data):
    """
    Validation loss.
    """
    X_tr, X_val, y_tr, y_val = data
    # We use the bijective mapping lam = jnp.exp(theta) to ensure positivity.
    lam = jnp.exp(theta)
    if FLAGS.solver == "pg":
        solver = ProximalGradient(
            fun=objective.least_squares,
            prox=prox.prox_lasso,
            implicit_diff=not FLAGS.unrolling,
            maxiter=500)
    elif FLAGS.solver == "bcd":
        solver = BlockCoordinateDescent(
            fun=objective.least_squares,
            block_prox=prox.prox_lasso,
            implicit_diff=not FLAGS.unrolling,
            maxiter=500)
    else:
        raise ValueError("Unknown solver.")

    # The format is run(init_params, hyperparams_prox, *args, **kwargs)
    # where *args and **kwargs are passed to `fun`.
    w_fit = solver.run(init_inner, lam, (X_tr, y_tr)).params

    y_pred = jnp.dot(X_val, w_fit)
    loss_value = jnp.mean((y_pred - y_val) ** 2)

    # We return w_fit as auxiliary data.
    # Auxiliary data is stored in the optimizer state (see below).
    return loss_value, w_fit


def main(argv):
    del argv
    print("Solver:", FLAGS.solver)
    print("Unrolling:", FLAGS.unrolling)

    # Prepare data.
    X, y = datasets.load_boston(return_X_y=True)
    X = preprocessing.normalize(X)
    # data = (X_tr, X_val, y_tr, y_val)
    data = model_selection.train_test_split(X, y, test_size=0.33, random_state=0)

    # Initialize solver.
    solver = OptaxSolver(opt=optax.adam(1e-2), fun=outer_objective, has_aux=True)
    theta = 1.0
    state = solver.init_state(theta)
    init_w = jnp.zeros(X.shape[1])

    # Run outer loop.
    for _ in range(10):
        theta, state = solver.update(params=theta, state=state, init_inner=init_w,
                                     data=data)
        # The auxiliary data returned by the outer loss is stored in the state.
        init_w = state.aux
        print(f"[Step {state.iter_num}] Validation loss: {state.value:.3f}.")


DuplicateFlagError: The flag 'unrolling' is defined twice. First from /home/sam/anaconda3/envs/mohca/lib/python3.9/site-packages/ipykernel_launcher.py, Second from /home/sam/anaconda3/envs/mohca/lib/python3.9/site-packages/ipykernel_launcher.py.  Description from first occurrence: Whether to use unrolling.

In [3]:
app.run(main)


    The Boston housing prices dataset has an ethical problem. You can refer to
    the documentation of this function for further details.

    The scikit-learn maintainers therefore strongly discourage the use of this
    dataset unless the purpose of the code is to study and educate about
    ethical issues in data science and machine learning.

    In this special case, you can fetch the dataset from the original
    source::

        import pandas as pd
        import numpy as np


        data_url = "http://lib.stat.cmu.edu/datasets/boston"
        raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
        data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
        target = raw_df.values[1::2, 2]

    Alternative datasets include the California housing dataset (i.e.
    :func:`~sklearn.datasets.fetch_california_housing`) and the Ames housing
    dataset. You can load the datasets as follows::

        from sklearn.datasets import fetch_california_h

Solver: bcd
Unrolling: False
[Step 1] Validation loss: 89.985.
[Step 2] Validation loss: 79.132.
[Step 3] Validation loss: 77.374.
[Step 4] Validation loss: 76.564.
[Step 5] Validation loss: 76.100.
[Step 6] Validation loss: 75.832.
[Step 7] Validation loss: 75.679.
[Step 8] Validation loss: 75.588.
[Step 9] Validation loss: 75.526.
[Step 10] Validation loss: 75.474.


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [4]:
%tb

SystemExit: 