<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/Broyden.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from scipy.optimize import root, minimize, NonlinearConstraint

In [None]:
def func(x):
    """Residual of the 2-equation nonlinear test system solved by the Broyden cells.

        f0 = sin(x0) + 0.5 * (x0 - x1)**3 - 1
        f1 = 0.5 * (x1 - x0)**3 + x1

    Args:
        x: length-2 array [x0, x1].

    Returns:
        jnp.ndarray of shape (2,) holding [f0, f1].
    """
    # The first residual previously contained `0.1*sqrt(x[0])`. `sqrt` is not defined in
    # this notebook (NameError), and even jnp.sqrt has an infinite derivative at the
    # x = 0 starting point, which would make the initial Jacobian non-finite. The printed
    # iterates (first step x = [1, 0] -> f = [sin(1) - 0.5, -0.5]) match the system
    # without that term, so it is dropped.
    return jnp.array([
        jnp.sin(x[0]) + 0.5 * (x[0] - x[1])**3 - 1.0,
        0.5 * (x[1] - x[0])**3 + x[1],
    ])

In [None]:
# "Good" Broyden iteration: evaluate the exact Jacobian once at the starting point,
# then keep it current with rank-one secant updates instead of re-differentiating.
x = jnp.zeros(2)
J = jax.jacobian(func)(x)

f = func(x)

for i in range(20):
  # Quasi-Newton step: solve J @ (xp - x) = -f for the next iterate.
  xp = x + jnp.linalg.solve(J, -f)
  dx = xp - x
  fp = func(xp)
  x, f = xp, fp
  print(x, fp)
  # Stop once the residual norm is at round-off level.
  if jnp.linalg.norm(fp) < 1e-12:
    break

  # Secant update J += (df - J dx) dx^T / ||dx||^2. The step solved J dx = -f_old,
  # so (in exact arithmetic) df - J dx = (fp - f_old) + f_old = fp.
  J = J + jnp.outer(fp, dx) / jnp.linalg.norm(dx)**2






[1. 0.] [ 0.34147098 -0.5       ]
[0.74545034 0.37272517] [-0.29580697  0.34683492]
[0.87452582 0.23424262] [-0.10151446  0.10299655]
[0.94426608 0.18266999] [ 0.03094058 -0.03820376]
[0.92886637 0.19874589] [-0.00445347  0.00414106]
[0.931058   0.19761267] [-0.00047246  0.00033713]
[0.93134054 0.1975642 ] [-3.66025399e-05  2.14376618e-05]
[0.93136529 0.19756394] [-1.63573041e-06  9.88074612e-07]
[0.93136644 0.19756391] [ 9.80559234e-09 -5.98158895e-09]
[0.93136643 0.19756391] [-1.81729076e-11  1.10534082e-11]
[0.93136643 0.19756391] [-1.01030295e-14  6.16173779e-15]


In [None]:
# Same Broyden iteration, but carrying the inverse Jacobian: each step becomes a
# matrix-vector product, and the rank-one update is inverted with Sherman-Morrison.
x = jnp.zeros(2)
J = jax.jacobian(func)(x)
Jinv = jnp.linalg.inv(J)

f = func(x)

for i in range(20):
  xp = x - Jinv @ f
  dx = xp - x
  fp = func(xp)
  x, f = xp, fp
  print(x, f)
  if jnp.linalg.norm(fp) < 1e-12:
    break

  # The Jacobian update J + u v^T uses u = fp and v = dx / ||dx||^2 (column vectors).
  # Sherman-Morrison: inv(A + u v^T) = Ainv - Ainv u v^T Ainv / (1 + v^T Ainv u).
  u = fp[:, None]
  v = dx[:, None] / jnp.linalg.norm(dx)**2
  Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)

[1. 0.] [ 0.34147098 -0.5       ]
[0.74545034 0.37272517] [-0.29580697  0.34683492]
[0.87452582 0.23424262] [-0.10151446  0.10299655]
[0.94426608 0.18266999] [ 0.03094058 -0.03820376]
[0.92886637 0.19874589] [-0.00445347  0.00414106]
[0.931058   0.19761267] [-0.00047246  0.00033713]
[0.93134054 0.1975642 ] [-3.66025399e-05  2.14376618e-05]
[0.93136529 0.19756394] [-1.63573041e-06  9.88074612e-07]
[0.93136644 0.19756391] [ 9.80559234e-09 -5.98158895e-09]
[0.93136643 0.19756391] [-1.81729076e-11  1.10534082e-11]
[0.93136643 0.19756391] [-1.01030295e-14  6.16173779e-15]


In [None]:
# Random 3x3 test problem for the "least-change" view of the Broyden update:
# find Jp closest to J (Frobenius norm) such that Jp @ dx == fp - f.
prng = jax.random.PRNGKey(1234)
# JAX keys must never be reused. Drawing dx and fp from the same key with the same
# shape made them identical, which degenerates the demo (the least-change distance
# ||fp|| / ||dx|| collapsed to exactly 1). Split into independent subkeys instead.
key_J, key_dx, key_fp = jax.random.split(prng, 3)
J = jax.random.uniform(key_J, (3, 3))
dx = jax.random.uniform(key_dx, (3, 1))
fp = jax.random.uniform(key_fp, (3, 1))
# Residual at the old point, chosen so that J @ dx = -f (a full quasi-Newton step).
f = -J @ dx

In [None]:
def constraint(x):
    """Secant-constraint function for SLSQP.

    Views the 9-vector `x` as a 3x3 candidate Jacobian Jp and returns Jp @ dx
    (dx is the module-level (3, 1) step) with singleton axes removed.
    """
    return (jnp.reshape(x, (3, 3)) @ dx).squeeze()

# Secant condition Jp @ dx == fp - f, posed as an equality constraint (lb == ub),
# with its exact Jacobian supplied by JAX.
nlc = NonlinearConstraint(
    constraint,
    lb=jnp.squeeze(fp - f),
    ub=jnp.squeeze(fp - f),
    jac=jax.jacobian(constraint),
)

def func(x):
  """Least-change objective ||Jp - J||_F, with Jp the 9-vector `x` viewed as 3x3.

  NOTE(review): this rebinds `func`, shadowing the 2-equation residual defined
  earlier in the notebook; re-running the Broyden cells afterwards would use this one.
  """
  return jnp.linalg.norm(jnp.reshape(x, (3, 3)) - J)

In [None]:
# Solve the least-change problem numerically: minimize ||Jp - J||_F over the 9 entries
# of Jp, subject to the secant constraint `nlc`, starting from Jp = 0.
res = minimize(
    func,
    jnp.zeros(9),
    method='SLSQP',
    jac=jax.grad(func),
    constraints=[nlc],
)

In [None]:
# Display the SLSQP OptimizeResult: `fun` is the minimized ||Jp - J||_F and `x` is the
# optimal Jp flattened to 9 entries.
res

     fun: 1.0000000367506372
     jac: array([0.88233216, 0.18608908, 0.26313258, 0.18584955, 0.03914606,
       0.05548666, 0.26309337, 0.05546026, 0.07859303])
 message: 'Optimization terminated successfully.'
    nfev: 3
     nit: 3
    njev: 3
  status: 0
 success: True
       x: array([1.60609461, 0.99142794, 0.50577275, 0.38780682, 0.0769163 ,
       0.29498072, 0.98421639, 0.32432932, 0.6728414 ])

In [None]:
# Closed-form least-change (Broyden) update J + fp dx^T / ||dx||^2. Compare it with
# res.x reshaped to 3x3; they should agree up to the optimizer's tolerance.
J + jnp.outer(fp, dx) / jnp.linalg.norm(dx)**2

DeviceArray([[1.60614037, 0.99120491, 0.50577685],
             [0.38782331, 0.07692147, 0.29492175],
             [0.98425969, 0.32429676, 0.67271918]], dtype=float64)