In [1]:
from magpi.prelude import *
from magpi import calc

In [129]:

def _tree_set(M, v, i):
    return tree_map(lambda M, v: M.at[i].set(v), M, v)

def tree_is_zero(t):
    return tree_reduce(lambda a, b: a & jnp.all(b == 0), t, jnp.array(True))

def lanczos_iteration(k, w0, hvp, last_update=None, *args, **kwargs):
    
    alpha = zeros((k,))
    beta = zeros((k,))
    W = tree_map(lambda t: zeros((k, *t.shape)), w0)
    V = tree_map(lambda t: zeros((k, *t.shape)), w0)
    w_last = w0
    v_last = tree_zeros_like(w0)
    
    def body(i, state):
        (V, W, alpha, beta, v_last, w_last) = state
        beta_i = tree_l2_norm(w_last)
        vi = tree_scalar_mul(1 / beta_i, w_last)
        if last_update is None:
            wi = hvp(vi)
        else:
            wi = lax.cond(
                tree_is_zero(last_update),
                lambda: hvp(vi),
                lambda: lax.cond(i == k-1, lambda: last_update, lambda: hvp(vi)
                )
            )
        W = _tree_set(W, wi, i)
        V = _tree_set(V, vi, i)
        
        alpha_i = tree_vdot(wi, vi)
        wi = tree_sub(wi, tree_scalar_mul(alpha_i, vi))
        wi = tree_sub(wi, tree_scalar_mul(beta_i, v_last))
        alpha = alpha.at[i].set(alpha_i)
        beta = beta.at[i].set(beta_i)
        return (V, W, alpha, beta, vi, wi)

    state = (V, W, alpha, beta, v_last, w_last)
    (V, W, alpha, beta, _, _) = lax.fori_loop(0, k, body, state)
    return V, W, alpha, beta
        
        
# key = random.PRNGKey(42)
# key, _key = random.split(key)
# # x = random.uniform(_key, (10,))
# last_update = tree_map(zeros_like, x)

H = array([[1.0, 2.0, 3.0], [2.0, 1.0, 2.0], [3.0, 2.0, 1.0]])
H = H @ H
def hvp(p):
    w = H @ asarray(p)
    return w
    #return w[0], w[1], w[2]

def f(x):
    return x @ H @ x

def df(x):
    return H @ x

r0 = array([1.0, 1.2, 0.5])
#r0 = (array(1.0), array(1.2), array(0.5))

#hvp = lambda p, t: calc.hvp_forward_over_reverse(f, (p,), (t,))
V, W, alpha, beta = lanczos_iteration(2, r0, hvp, None)#.shape
V, W, alpha, beta

(Array([[ 0.60971075,  0.7316529 ,  0.30485538],
        [ 0.15015638, -0.48427072,  0.86193675]], dtype=float32),
 Array([[18.901033, 15.730537, 17.681612],
        [ 5.87885 ,  5.762495,  8.725971]], dtype=float32),
 Array([28.423792 ,  5.6133747], dtype=float32),
 Array([ 1.6401219, 10.460705 ], dtype=float32))

In [197]:
from jaxopt import ScipyMinimize
        
def tree_gram(a, b):
    vmap_left = jax.vmap(tree_vdot, in_axes=(0,None))
    vmap_right = jax.vmap(vmap_left, in_axes=(None,0))
    return vmap_right(a, b)

x = r0

def sfn(f, init_params, *args, **kwargs):
    state = init_params
    for k in range(5):
        params = state
        hvp = lambda p: calc.hvp_forward_over_reverse(f, (params,), (p,), *args, **kwargs)
        g = grad(f)(params)
        V, W, alpha, beta = lanczos_iteration(3, -g, hvp, None)
        G = tree_gram(V, W)
        lam, v = jnp.linalg.eigh(G)
        def _fun(alpha, params):
            return f(params + alpha @ V)
            
        def fun(l, params, g):
            return _fun(v.T * 1 / (jnp.abs(lam) + l) @ v @ g, params)
        a = zeros_like(alpha)
        _state = (params, a)
        g = -grad(_fun)(a, params)
        for i in range(20):
            params, alpha = _state
            solver = ScipyMinimize("Newton-CG", fun=fun, tol=1e-6)
            l,_ = solver.run(1.0, params, g)
            alpha = v.T * 1 / (jnp.abs(lam) + l) @ v @ g
            params = params + alpha @ V
            _state = (params, alpha)
        
        state = params

    return params

In [198]:
params = sfn(f, ones((3,)))

In [199]:
f(params), f(r0)

(Array(36172432., dtype=float32), Array(76.46, dtype=float32))

In [169]:
params

Array([-7.152033 , -6.4241576, -7.6393805], dtype=float32)

In [83]:
fun(1., alpha)

Array(6939.327, dtype=float32)

In [None]:
v.T * jnp.abs(lam) @ v, G

In [70]:
v.T * jnp.abs(lam) @ v, G

(Array([[5.1206317, 1.4842291],
        [1.484229 , 1.8428023]], dtype=float32),
 Array([[ 4.7918215 ,  2.337143  ],
        [ 2.337143  , -0.36961633]], dtype=float32))

In [45]:
tree_vdot(W, V)

Array(5.9604645e-08, dtype=float32)

In [37]:
V, W, alpha, beta = lanczos_iteration(2, r0, hvp, None)#.shape
V, W, alpha, beta

Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>


((Array([0.60971075, 0.02822144], dtype=float32),
  Array([ 0.7316529, -0.4044105], dtype=float32),
  Array([0.30485538, 0.91414213], dtype=float32)),
 (Array([2.9875827, 1.9618268], dtype=float32),
  Array([2.560785 , 1.4803166], dtype=float32),
  Array([3.5972934 , 0.18998545], dtype=float32)),
 Array([ 4.7918215 , -0.36961633], dtype=float32),
 Array([1.6401219, 2.337143 ], dtype=float32))

In [28]:
V @ H @ V.T

Array([[ 4.7918215 ,  2.337143  ],
       [ 2.337143  , -0.36961633]], dtype=float32)

In [26]:
H

Array([[1., 2., 3.],
       [2., 1., 2.],
       [3., 2., 1.]], dtype=float32)

In [48]:
last_update = tree_map(zeros_like, x)
df = grad(f)
hvp = lambda p, t: calc.hvp_forward_over_reverse(f, (p,), (t,))
lanczos_iteration(5, df, hvp, ones_like(last_update) * 10, x)#.shape

foo
foo
foo
foo


Array([[ 0.09587097,  1.1853564 ,  0.8064642 ,  1.3451769 ,  1.2022831 ,
         1.2225151 ,  0.91227436,  1.4312901 ,  1.2658458 ,  0.08718801],
       [-0.02851362, -0.35254464, -0.23985583, -0.40007794, -0.35757893,
        -0.36359626, -0.2713255 , -0.4256894 , -0.3764835 , -0.02593116],
       [-0.02851362, -0.35254464, -0.2398558 , -0.4000779 , -0.3575789 ,
        -0.36359626, -0.2713255 , -0.4256894 , -0.37648347, -0.02593116],
       [ 0.02851361,  0.35254458,  0.2398558 ,  0.40007788,  0.3575789 ,
         0.3635962 ,  0.27132547,  0.42568937,  0.37648347,  0.02593116],
       [ 0.06783528,  0.35184446,  0.25307408,  0.39350677,  0.3562569 ,
         0.36153105,  0.28065687,  0.41595492,  0.37282652,  0.06557178]],      dtype=float32)

In [46]:
tree_is_zero(last_update)

Array(True, dtype=bool)

In [36]:
def r(a, b):
    print(a, b)
    return jnp.sum(a) + jnp.sum(b)
tree_reduce(r, last_update, 0.0)

0.0 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


Array(0., dtype=float32)

In [38]:
tree_reduce(lambda a, b: a & jnp.all(b == 0), last_update, jnp.array(True))

Array(True, dtype=bool)

In [25]:
import operator
jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])

21