In [1]:
import pyomo.environ as pyo
from pyomo.opt import SolverFactory


In [2]:
# Four binary variables x[1..4]; the objective is the sum of the x variables.
model = pyo.ConcreteModel()
model.nVars = pyo.Param(initialize=4)
model.N = pyo.RangeSet(model.nVars)
model.x = pyo.Var(model.N, within=pyo.Binary)
model.obj = pyo.Objective(expr=pyo.summation(model.x))
# Starts empty; one cut is appended per iteration of the loop below.
model.cuts = pyo.ConstraintList()
opt = SolverFactory("glpk")
# Initial solve so every x[j] carries a value before the first cut is built.
opt.solve(model)


def _flip_term(var):
    """Return a term that is 1 when `var` moves away from its current 0/1 value.

    A variable currently at 0 contributes `var`; one currently at 1
    contributes `1 - var`. The 0.5 threshold tolerates solver values that are
    not exactly 0 or 1.
    """
    return var if pyo.value(var) < 0.5 else 1 - var


# Iterate, adding a cut to exclude the previously found solution
for i in range(5):
    # Summing the flip terms and requiring the total to be >= 1 forces at
    # least one variable to differ from the incumbent solution.
    expr = 0
    for j in model.x:
        expr += _flip_term(model.x[j])
    model.cuts.add(expr >= 1)
    results = opt.solve(model)
    print("\n===== iteration", i)
    model.display()


===== iteration 0
Model unknown

  Variables:
    x : Size=4, Index=N
        Key : Lower : Value : Upper : Fixed : Stale : Domain
          1 :     0 :   1.0 :     1 : False : False : Binary
          2 :     0 :   0.0 :     1 : False : False : Binary
          3 :     0 :   0.0 :     1 : False : False : Binary
          4 :     0 :   0.0 :     1 : False : False : Binary

  Objectives:
    obj : Size=1, Index=None, Active=True
        Key  : Active : Value
        None :   True :   1.0

  Constraints:
    cuts : Size=1
        Key : Lower : Body : Upper
          1 :   1.0 :  1.0 :  None

===== iteration 1
Model unknown

  Variables:
    x : Size=4, Index=N
        Key : Lower : Value : Upper : Fixed : Stale : Domain
          1 :     0 :   0.0 :     1 : False : False : Binary
          2 :     0 :   1.0 :     1 : False : False : Binary
          3 :     0 :   0.0 :     1 : False : False : Binary
          4 :     0 :   0.0 :     1 : False : False : Binary

  Objectives:
    obj : Si

In [42]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import sympy

import sympy2jax

def assert_equal(x, y):
    """Assert that two PyTrees match exactly, both in structure and in leaves.

    Both arguments are flattened with `jtu.tree_flatten`. The tree definitions
    must compare equal, and each pair of leaves must have the same Python
    type. Array leaves must additionally agree in shape, dtype and every
    element; all other leaves are compared with `==`.

    Raises:
        AssertionError: on the first mismatch encountered.
    """
    leaves_x, treedef_x = jtu.tree_flatten(x)
    leaves_y, treedef_y = jtu.tree_flatten(y)
    assert treedef_x == treedef_y
    for left, right in zip(leaves_x, leaves_y):
        assert type(left) is type(right)
        if not isinstance(left, jnp.ndarray):
            # Non-array leaf: plain equality is sufficient.
            assert left == right
            continue
        # Array leaf: check metadata first so the element-wise comparison
        # below is never asked to broadcast mismatched shapes.
        assert left.shape == right.shape
        assert left.dtype == right.dtype
        assert jnp.all(left == right)

def assert_sympy_allclose(x, y):
    """Assert that two SymPy expressions match structurally, floats approximately.

    Both arguments must be `sympy.Expr` instances with the same `func`. The
    comparison then depends on the kind of `x`:

    * `Float`: the values must differ by less than 1e-5.
    * `Integer`: compared with `==`.
    * `Rational`: numerators and denominators must match.
    * `Symbol`: names must match.
    * anything else: same number of `args`, each pair compared recursively.

    Raises:
        AssertionError: on the first mismatch encountered.
    """
    for operand in (x, y):
        assert isinstance(operand, sympy.Expr)
    assert x.func is y.func

    if isinstance(x, sympy.Float):
        assert abs(float(x) - float(y)) < 1e-5
        return
    # Integer is tested before Rational, in the same order as the original
    # if/elif chain, so integers take the plain equality path.
    if isinstance(x, sympy.Integer):
        assert x == y
        return
    if isinstance(x, sympy.Rational):
        assert x.numerator == y.numerator  # pyright: ignore
        assert x.denominator == y.denominator  # pyright: ignore
        return
    if isinstance(x, sympy.Symbol):
        assert x.name == y.name  # pyright: ignore
        return

    # Compound expression: recurse into the operands pairwise.
    assert len(x.args) == len(y.args)
    for sub_x, sub_y in zip(x.args, y.args):
        assert_sympy_allclose(sub_x, sub_y)

# NOTE(review): the output recorded for this cell is a TypeError about a
# positional argument passed to SymbolicModule.__call__, yet every call below
# passes the input by keyword (x_sym=...). The recorded output may be stale;
# confirm by re-running on a fresh kernel.

# Symbolic model y = 2.1 * x_sym**2, evaluated at x = 1.1.
x_sym = sympy.symbols("x_sym")
y = 2.1 * x_sym**2
mod = sympy2jax.SymbolicModule(y)
x = jnp.array(1.1)


def _evaluate(module, point):
    """Evaluate `module` with `point` bound to the symbol named x_sym."""
    return module(x_sym=point)


# Differentiate once with the module as first argument and once with the input
# as first argument (arguments swapped by the lambda).
grad_m = eqx.filter_grad(_evaluate)(mod, x)
grad_z = eqx.filter_grad(lambda point, module: _evaluate(module, point))(x, mod)

# Expected module gradient: the coefficient 2.1 is replaced by
# d(c * x**2)/dc = x**2 = 1.1**2 = 1.21.
true_grad_m = eqx.filter(
    sympy2jax.SymbolicModule(1.21 * x_sym**2), eqx.is_inexact_array
)
# Expected input gradient: d(2.1 * x**2)/dx = 4.2 * x.
true_grad_z = jnp.array(4.2 * x)
print(grad_z)

assert_equal(grad_m, true_grad_m)
assert_equal(grad_z, true_grad_z)

# Applying the gradient as an update should give 2.1 + 1.21 = 3.31 as the new
# coefficient in the recovered SymPy expression.
mod2 = eqx.apply_updates(mod, grad_m)
expr = mod2.sympy()

assert_sympy_allclose(expr, 3.31 * x_sym**2)

TypeError: SymbolicModule.__call__() takes 1 positional argument but 2 were given