Updated Arnoldi's backprop
AndPotap authored and committed Aug 19, 2023
1 parent d325cce · commit af58da7
Showing 3 changed files with 17 additions and 35 deletions.
47 changes: 15 additions & 32 deletions cola/algorithms/arnoldi.py
@@ -1,4 +1,3 @@
-from functools import partial
 from cola.ops import LinearOperator
 from cola.ops import Array
 from cola.ops import Householder, Product
@@ -12,40 +11,24 @@ def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs):
     val_grads, eig_grads, _ = grads
     op_args, (eig_vals, eig_vecs, _) = res
     A = unflatten(op_args)
-    N = A.shape[0]
     xnp = A.xnp
-
-    def fun(*theta, loc):
-        Aop = unflatten(theta)
-        return Aop @ eig_vecs[:, loc]
-
-    d_params_vals = []
-    for idx in range(eig_vecs.shape[-1]):
-        fn = partial(fun, loc=idx)
-        dlam = xnp.vjp_derivs(fn, op_args, eig_vecs[:, idx])[0]
-        required_shape = dlam.shape
-        d_params_vals.append(dlam.reshape(-1))
-    d_vals = xnp.stack(d_params_vals)
-    d_vals = (val_grads @ d_vals).reshape(required_shape)
-
-    def fun_eig(*theta, loc):
-        Aop = unflatten(theta)
-        op_diag = 1. / (eig_vals[loc] - eig_vals)
-        op_diag = xnp.nan_to_num(op_diag, nan=0., posinf=0., neginf=0.)
-        D = cola.ops.Diagonal(op_diag)
-        weights = eig_vecs @ D @ eig_vecs.T
-        return weights @ Aop @ eig_vecs[:, loc]
-
-    d_params_vecs = []
-    for idx in range(eig_vecs.shape[-1]):
-        fn = partial(fun_eig, loc=idx)
-        dl_jac = xnp.jacrev(fn)(*op_args)
-        dl_jac = dl_jac.reshape(N, N * N)
-        out = xnp.sum(eig_grads[:, idx] @ dl_jac)
-        d_params_vecs.append(out)
-    d_vecs = xnp.stack(d_params_vecs)
-
-    d_params = d_vals + d_vecs
+    e = eig_vals
+    V = eig_vecs  # (n, m)
+    W = eig_grads  # (n, m)
+
+    def altogether(*theta):
+        Aop = unflatten(theta)
+        AV = Aop @ V
+        eigs = (AV * V).sum(axis=-2)  # output 1
+        out1 = xnp.sum(eigs * val_grads)
+        VHAV = V.conj().T @ AV
+        diff = xnp.nan_to_num(1 / (e[:, None] - e[None, :]), nan=0., posinf=0., neginf=0.)
+        C = (W.conj().T @ V) * diff
+        out2 = (C.T * VHAV).sum()
+        return out1 + out2
+
+    d_params = xnp.grad(altogether)(*op_args)
     dA = unflatten([d_params])
     return (dA, )

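Why the rewrite helps: the old backward looped over eigenpairs, calling xnp.vjp_derivs once per eigenvalue and materializing an N-by-(N*N) Jacobian with xnp.jacrev per eigenvector. The new altogether closure folds both cotangents into a single scalar, so one reverse-mode call returns the whole parameter gradient. A sketch of the identity it appears to rely on, assuming an (approximately) orthonormal eigenbasis V with distinct eigenvalues, and writing \bar\lambda for val_grads and \bar w_j for the columns of eig_grads:

    L(\theta) = \sum_i \bar\lambda_i \, v_i^H A(\theta) \, v_i                                      (out1)
              + \sum_{i \ne j} \frac{(\bar w_j^H v_i)(v_i^H A(\theta) \, v_j)}{\lambda_j - \lambda_i}   (out2)

First-order perturbation theory gives \partial\lambda_i = v_i^H (\partial A) v_i and \partial v_j = \sum_{i \ne j} v_i \, (v_i^H (\partial A) v_j) / (\lambda_j - \lambda_i), so \nabla_\theta L = \sum_i \bar\lambda_i \nabla_\theta \lambda_i + \sum_j \bar w_j^H \nabla_\theta v_j, which is the VJP the custom backward must return; the nan_to_num call zeroes the excluded i = j terms.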
1 change: 0 additions & 1 deletion cola/algorithms/lanczos.py
@@ -1,4 +1,3 @@
-from functools import partial
 from cola import SelfAdjoint
 from cola.fns import lazify
 from cola.ops import LinearOperator
4 changes: 2 additions & 2 deletions tests/algorithms/test_arnoldi.py
@@ -18,7 +18,7 @@
 # config.update("jax_enable_x64", True)
 
 
-@parametrize([torch_fns])
+@parametrize([torch_fns, jax_fns])
 def test_arnoldi_vjp(xnp):
     dtype = xnp.float64
     matrix = [[6., 2., 3.], [2., 3., 1.], [3., 1., 4.]]
@@ -61,7 +61,7 @@ def f_alt(theta):
     print(approx)
     print(soln)
     abs_error = xnp.norm(soln - approx)
-    assert abs_error < 1e-5
+    assert abs_error < 5e-5
 
 
 @parametrize([torch_fns, jax_fns])

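A minimal standalone check of the same construction, in plain PyTorch (a sketch, independent of cola and not part of this commit; all names are illustrative): it autodiffs a dense version of the altogether closure on the 3x3 matrix from the test above and compares against the closed-form adjoint V (diag(val_grads) + C^T) V^T implied by the identity after the arnoldi.py diff.

import torch

A0 = torch.tensor([[6., 2., 3.], [2., 3., 1.], [3., 1., 4.]], dtype=torch.float64)
e, V = torch.linalg.eigh(A0)                     # eigenpairs, held fixed inside the closure
torch.manual_seed(0)
val_grads = torch.randn(3, dtype=torch.float64)  # cotangent for the eigenvalues
W = torch.randn(3, 3, dtype=torch.float64)       # cotangent for the eigenvectors

def altogether(theta):
    AV = theta @ V
    out1 = ((AV * V).sum(dim=-2) * val_grads).sum()
    diff = torch.nan_to_num(1 / (e[:, None] - e[None, :]), nan=0., posinf=0., neginf=0.)
    C = (W.T @ V) * diff
    out2 = (C.T * (V.T @ AV)).sum()
    return out1 + out2

theta = A0.clone().requires_grad_(True)
dA = torch.autograd.grad(altogether(theta), theta)[0]

# Closed form predicted by the perturbation identity: dA = V (diag(val_grads) + C^T) V^T.
diff = torch.nan_to_num(1 / (e[:, None] - e[None, :]), nan=0., posinf=0., neginf=0.)
C = (W.T @ V) * diff
dA_closed = V @ (torch.diag(val_grads) + C.T) @ V.T
print(torch.allclose(dA, dA_closed))  # expected: True

Since V and e are held fixed from the forward pass, altogether is a scalar function of theta alone, and a single autograd call replaces the per-eigenpair vjp_derivs and jacrev loops deleted above.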