<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/Broyden.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from scipy.optimize import root, minimize, NonlinearConstraint
eps=1e-12

In [2]:
def func(x):
    return jnp.array([jnp.sin(x[0])  + 0.5 * (x[0] - x[1])**3 - 1.0,
            0.5 * (x[1] - x[0])**3 + x[1]])

In [3]:
# Broyden update with no Sherman-Morrison
x = jnp.zeros(2)
J = jax.jacobian(func)(x)

f = func(x)

for i in range(20):
  xp = jnp.linalg.solve(J, -f) + x
  dx = xp - x
  fp = func(xp)
  f= fp
  x= xp
  print(x,fp)
  if jnp.linalg.norm(fp) < 1e-12:
    break

  J = J + jnp.outer(fp,dx)/jnp.linalg.norm(dx)**2



[1. 0.] [ 0.34147098 -0.5       ]
[0.74545034 0.37272517] [-0.29580697  0.34683492]
[0.87452582 0.23424262] [-0.10151446  0.10299655]
[0.94426608 0.18266999] [ 0.03094058 -0.03820376]
[0.92886637 0.19874589] [-0.00445347  0.00414106]
[0.931058   0.19761267] [-0.00047246  0.00033713]
[0.93134054 0.1975642 ] [-3.66025399e-05  2.14376618e-05]
[0.93136529 0.19756394] [-1.63573041e-06  9.88074612e-07]
[0.93136644 0.19756391] [ 9.80559234e-09 -5.98158895e-09]
[0.93136643 0.19756391] [-1.81729076e-11  1.10534082e-11]
[0.93136643 0.19756391] [-1.01030295e-14  6.16173779e-15]


In [4]:
def broyden(func, x, J=None, max_iter=100, verbose=0):
  J = jax.jacobian(func)(x) if J is None else J(x)
  Jinv = jnp.linalg.inv(J)
  f = func(x)

  for i in range(max_iter):
    xp = x - Jinv @ f
    dx = xp - x
    fp = func(xp)
    f= fp
    x= xp
    if verbose>0:
      print(x, f)
    if jnp.linalg.norm(fp) < 1e-12:
      break

    u = jnp.expand_dims(fp,1)
    v = jnp.expand_dims(dx,1)/jnp.linalg.norm(dx)**2
    Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
  return x, f

broyden(func, jnp.zeros(2))

(DeviceArray([0.93136643, 0.19756391], dtype=float64),
 DeviceArray([-1.01030295e-14,  6.16173779e-15], dtype=float64))

In [5]:
prng = jax.random.PRNGKey(1234)
J = jax.random.uniform(prng, (3,3))
f = jax.random.uniform(prng, (3,1))
fp = jax.random.uniform(prng, (3,1))
dx = jnp.linalg.solve(J,-f)

def constraint(x):
    Jp = x.reshape((3,3))
    return jnp.squeeze(Jp @ dx)

nlc = NonlinearConstraint(constraint,jnp.squeeze(fp-f),jnp.squeeze(fp-f), jac=jax.jacobian(constraint))

def func(x):
  Jp = x.reshape((3,3))
  return jnp.linalg.norm(Jp-J)

res = minimize(func, jnp.zeros(9), method='SLSQP', jac=jax.jacobian(func), constraints=(nlc))
print(res.x.reshape((3,3))) #via constrained minimization
print(J + fp@dx.T/jnp.linalg.norm(dx)**2) #via Broyden update

[[0.81624494 0.71880643 0.15777668]
 [0.22143383 0.01955389 0.22160245]
 [0.7486602  0.24306155 0.56889726]]
[[0.8163133  0.71886663 0.15778989]
 [0.22145238 0.01955553 0.221621  ]
 [0.7487229  0.24308191 0.56894491]]


In [6]:
# Accomodate box bounds on variables

def broyden2(func, x, J=None, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
  Jf = jax.jacobian(func) if J is None else J
  J = Jf(x)
  Jinv = jnp.linalg.inv(J)
  f = func(x)

  for i in range(max_iter):
    dx = - Jinv @ f

    alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
    alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
    alpha = min(alpha_max_limits, alpha_min_limits)

    print(alpha)
    dx = alpha*dx
    xp = x + dx
    fp = func(xp)
    f= fp
    x= xp
    if verbose>0:
      print(x, f)
    if jnp.linalg.norm(fp) < 1e-12:
      break

    u = jnp.expand_dims(fp,1)
    v = jnp.expand_dims(dx,1)/jnp.linalg.norm(dx)**2
    Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
  return x, f

In [7]:
def func2(x):
    return jnp.array([jnp.sin(x[0])  + 0.5 * (x[0] - x[1])**3 - 0.01*jnp.sqrt(x[1]-0.1) - 1.0,
            0.5 * (x[1] - x[0])**3 + x[1] + 0.001*jnp.sqrt(1.-x[0])])

In [8]:
broyden(func2, 0.95*jnp.ones(2),verbose=1, max_iter=20)

[1.27776387e+00 5.09295507e-04] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]
[nan nan] [nan nan]


(DeviceArray([nan, nan], dtype=float64),
 DeviceArray([nan, nan], dtype=float64))

In [9]:
broyden2(func2, 0.95*jnp.ones(2), verbose=1, max_iter=20, xmin=jnp.array([-jnp.inf, 0.1]), xmax = jnp.array([1., jnp.inf]))

0.1525488433491314
[1.         0.80515629] [-0.16322784  0.80145776]
1.0
[0.93959614 0.98225636] [-0.20211182  0.98254095]
0.19327261837160434
[1.         0.79656208] [-0.1626652   0.79235224]
1.0
[0.92070903 1.0408568 ] [-0.21453601  1.04200559]
0.24205077931482566
[1.         0.79406184] [-0.16249308  0.78969486]
1.0
[0.88356639 1.15670184] [-0.24746165  1.16723142]
0.3484361198064109
[1.         0.76610502] [-0.16029271  0.75970719]
1.0
[0.75065498 1.6076198 ] [-0.6448333   1.92279178]
1.0
[0.9274454  0.20208157] [-0.01227885  0.01152487]
1.0
[0.94127706 0.19437996] [ 0.01356875 -0.01370796]
1.0
[0.93420749 0.19877441] [-0.00013078  0.00014708]
1.0
[0.93426884 0.1987212 ] [-5.09566974e-07  7.82997886e-07]
1.0
[0.93426899 0.19872083] [ 1.0061536e-08 -8.9609632e-09]
1.0
[0.93426899 0.19872083] [ 2.1699087e-10 -1.9291255e-10]
1.0
[0.93426899 0.19872083] [ 1.08801856e-14 -9.52585450e-15]


(DeviceArray([0.93426899, 0.19872083], dtype=float64),
 DeviceArray([ 1.08801856e-14, -9.52585450e-15], dtype=float64))

In [32]:
# Limit step size to ensure a decrease in norm

def broyden3(func, x, J=None, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
    Jf = jax.jacobian(func) if J is None else J
    J = Jf(x)
    Jinv = jnp.linalg.inv(J)
    f = func(x)
    print(Jinv)

    for i in range(max_iter):

        dx = - Jinv @ f
        if verbose>0:
            print(f"\nIter: {i}  dx: {dx}")
        alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
        alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
        alpha = min(alpha_max_limits, alpha_min_limits)

        while alpha > 0.01:
            dx_try = alpha*dx
            xp = x + dx_try
            fp = func(xp)
            dnorm = jnp.linalg.norm(fp)-jnp.linalg.norm(f)
            if verbose>1:
                print(f"Alpha {alpha}   dnorm {dnorm}  dx_try {dx_try}   f {f}    fp {fp}")
            if dnorm > 0:
                alpha *= 0.5
            else:
                break
        if alpha <= 0.01:
            if verbose>0:
                print("reevaluate J")
            Jinv = jnp.linalg.inv(Jf(x))
            continue

        dx=dx_try
        f= fp
        x= xp
        if verbose>0:
          print(x, f)
        if jnp.linalg.norm(fp) < 1e-12:
          break

        u = jnp.expand_dims(fp,1)
        v = jnp.expand_dims(dx,1)/jnp.linalg.norm(dx)**2
        Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  #Sherman-Morrison
    return x, f

In [None]:
broyden3(func2, 0.95*jnp.ones(2), verbose=1, max_iter=20, xmin=jnp.array([-jnp.inf, 0.1+eps]), xmax = jnp.array([1.-eps, jnp.inf]))

In [77]:
# One equality constraint

def rosen(x):
    return 100*(x[1]-x[0]**2)**2 + (1-x[0])**2

def constr(x):
    return 2*x[0] + x[1] - 1

def grads(x):
    return jax.grad(rosen)(x[:2]) + x[2]*jax.grad(constr)(x[:2])

@jax.jit
def eqs(x):
    return jnp.concatenate([grads(x), jnp.atleast_1d(constr(x))])

x0 = jnp.array([0.,0., 1.])
broyden3(eqs, x0)

[[ 0.00124688 -0.00249377  0.49875312]
 [-0.00249377  0.00498753  0.00249377]
 [ 0.49875312  0.00249377 -0.49875312]]


(DeviceArray([0.41494432, 0.17011137, 0.41348319], dtype=float64),
 DeviceArray([ 1.03250741e-14, -5.82173199e-15,  0.00000000e+00], dtype=float64))

In [78]:
def L(x):
  return rosen(x)-x[2]*constr(x)

dL = jax.jit(jax.grad(L))
x0 = jnp.array([0.,0., 1.])
broyden3(dL, x0)


[[ 0.00124688 -0.00249377 -0.49875312]
 [-0.00249377  0.00498753 -0.00249377]
 [-0.49875312 -0.00249377 -0.49875312]]


(DeviceArray([ 0.41494432,  0.17011137, -0.41348319], dtype=float64),
 DeviceArray([ 1.21760241e-14, -4.87804241e-15,  0.00000000e+00], dtype=float64))

In [74]:
# One inequality constraint

def rosen(x):
    return 100*(x[1]-x[0]**2)**2 + (1-x[0])**2

def constr(x):
    return x[0]**2 + x[1] + x[2]**2 - 1

def grads(x):
    return jax.jacobian(rosen)(x[:3]) + x[3]*jax.jacobian(constr)(x[:3])

@jax.jit
def eqs(x):
    return jnp.concatenate([grads(x), jnp.atleast_1d(constr(x))])

x0 = jnp.array([0.1,0.1, 10., 1.])
x, f= broyden3(eqs, x0, max_iter=500, verbose=0)
print(x)
print(rosen(x[:3]))

[[-3.12507812e-02 -6.24984375e-03  6.25000000e-04 -6.25000000e-05]
 [-6.24984375e-03  3.74996875e-03 -1.25000000e-04  1.25000000e-05]
 [ 6.25000000e-04 -1.25000000e-04  0.00000000e+00  5.00000000e-02]
 [-6.25000000e-05  1.25000000e-05  5.00000000e-02 -5.00000000e-03]]
[ 7.07472158e-01  4.99483146e-01 -3.44769115e-16  2.06741593e-01]
0.08567939371082826


In [None]:
def L(x):
  return rosen(x)-x[3]*constr(x)

dL = jax.jit(jax.grad(L)) 
x0 = jnp.array([0.1,0.1, 10., 1.])
x,f=broyden3(dL, x0, max_iter=500, verbose=1)
print(x)
print(rosen(x[:3]))