In [1]:
# Dummy `(bounds, constraints)` pairs for exercising `get_system`.
# `constraints` is either `None` or an `(equality, inequality)` pair, and any
# entry may be `None` to mark it as absent. The integers are placeholders:
# `get_system` only checks whether each entry `is None`.
# dum1-dum4 have bounds, dum5-dum8 do not.
dum1, dum2, dum3, dum4 = (
    (1, (2, 3)),
    (1, (2, None)),
    (1, (None, 3)),
    (1, None),
)
dum5, dum6, dum7, dum8 = (
    (None, (2, 3)),
    (None, (2, None)),
    (None, (None, 3)),
    (None, None),
)


def get_system(f_info):
    """Report which interior-point descent handles a given constraint layout.

    This is a dispatch sketch. For a supported combination it prints a
    one-line description. For an unsupported combination it raises.

    Arguments:
        f_info: a pair `(bounds, constraints)`.
            - `bounds` is `None` when the problem has no bounds.
            - `constraints` is `None` when there are no constraints.
              Otherwise it is a pair `(equality, inequality)`, where either
              entry may be `None` to mark that kind of constraint as absent.

    Returns:
        `f_info`, unchanged.

    Raises:
        NotImplementedError: no descent is implemented yet for this
            combination. That covers the unconstrained problem, bounds
            alone, and bounds together with inequality constraints.
        ValueError: `constraints` is the pair `(None, None)`.
    """
    bounds, constraints = f_info
    has_bounds = bounds is not None

    if constraints is None:
        if has_bounds:
            # Bounds without other constraints: not implemented for interior point
            raise NotImplementedError(
                "Bounds without equality or inequality constraints are not supported."
            )
        # Unconstrained Newton step
        raise NotImplementedError("Unconstrained problems are not supported.")

    equality, inequality = constraints
    has_equality = equality is not None
    has_inequality = inequality is not None

    if not (has_equality or has_inequality):
        # Previously `assert False`. An explicit raise is not stripped under
        # `python -O` and tells the caller what went wrong.
        raise ValueError(
            "`constraints` must contain equality and/or inequality constraints; "
            "use `constraints=None` for a problem without constraints."
        )

    if not has_bounds:
        if has_equality and has_inequality:
            # No bounds, inequality constraints and equality constraints:
            # XDYcYd descent
            print("No bounds, equality constraints and inequality constraints.")
        elif has_equality:
            # IPOPTLike descent with dummy bounds
            print("No bounds, equality constraints, no inequality constraints.")
        else:
            # InteriorDescent (but not robust)
            print("No bounds, no equality constraints, inequality constraints.")
    elif not has_inequality:
        # Bounds and equality constraints: IPOPTLike descent
        print("Bounds, equality constraints, no inequality constraints.")
    else:
        # Bounds with inequality constraints, with or without equality constraints
        raise NotImplementedError(
            "Bounds combined with inequality constraints are not supported."
        )
    return f_info


# Run only the configurations that have a matching descent:
#   dum2 -> IPOPTLike descent
#   dum5 -> XDYcYd descent
#   dum6 -> IPOPTLike descent with dummy bounds
#   dum7 -> InteriorDescent (not robust)
# dum1, dum3 and dum4 (bounds with inequality constraints, or bounds alone)
# and dum8 (a regular Newton step, which is not patched through) raise
# NotImplementedError, so they are not run here.
for f_info in (dum2, dum5, dum6, dum7):
    print(get_system(f_info))

Bounds, equality constraints, no inequality constraints.
(1, (2, None))
No bounds, equality constraints and inequality constraints.
(None, (2, 3))
No bounds, equality constraints, no inequality constraints.
(None, (2, None))
No bounds, no equality constraints, inequality constraints.
(None, (None, 3))


In [2]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx


# Keep JIT compilation enabled. Set this to True to run functions eagerly,
# e.g. when stepping through the solver with a debugger or print statements.
jax.config.update("jax_disable_jit", False)

In [3]:
def paraboloid(y, args):
    """Objective function: the sum of squares of a two-component `y`.

    `args` is accepted only to match the `fn(y, args)` calling convention and
    is ignored. Unpacking `y` requires exactly two components.
    """
    del args
    first, second = y
    squared_first = first**2
    squared_second = second**2
    return squared_first + squared_second


def constraint(y):
    """Constraint function returning `(equality, inequality)` residuals.

    - equality: `x1**2 + x2**2 - 1`, which is zero exactly on the unit circle.
    - inequality: `x1 - x2`.

    The order matches the `(equality, inequality)` convention used above.
    """
    x1, x2 = y
    equality_residual = x1**2 + x2**2 - 1  # single equality constraint
    inequality_residual = x1 - x2  # single inequality constraint
    return equality_residual, inequality_residual


# Start from IPOPTLike with a purely absolute tolerance (rtol=0, atol=1e-3).
solver = optx.IPOPTLike(rtol=0, atol=1e-3)
# Replace the solver's default descent with the experimental NewInteriorDescent.
# `eqx.tree_at` returns a copy of `solver` with the `descent` field replaced,
# rather than mutating it in place.
descent = optx.NewInteriorDescent()
solver = eqx.tree_at(lambda s: s.descent, solver, descent)

y0 = jnp.ones(2)
# Presumably (lower, upper) = ([0, 0], [2, 2]); TODO confirm the ordering
# optimistix expects for `bounds`.
bounds = (0 * y0, 2 * y0)

# TODO: remove the default infinite bounds
# TODO: IPOPTLike currently errors out when no constraints are present
# Short smoke-test run. `max_steps=3` caps the iterations, and `throw=False`
# means a failed or unfinished solve is reported rather than raised, so the
# outcome should be read from `solution` (e.g. `solution.result`).
solution = optx.minimise(
    paraboloid,
    solver,
    y0,
    constraint=constraint,
    bounds=bounds,
    throw=False,
    max_steps=3,
)

system: _BoundedEqualityInequalityConstrainedKKTSystem()
