Skip to content

Commit

Permalink
Changed backprop formulation for Lanczos
Browse files: browse the repository at this point in the history
  • Loading branch information
AndPotap authored and committed on Aug 19, 2023
1 parent 28936d6 commit d325cce
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
46 changes: 15 additions & 31 deletions cola/algorithms/lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,24 @@ def lanczos_eig_bwd(res, grads, unflatten, *args, **kwargs):
val_grads, eig_grads, _ = grads
op_args, (eig_vals, eig_vecs, _) = res
A = unflatten(op_args)
N = A.shape[0]
xnp = A.xnp

def fun(*theta, loc):
Aop = unflatten(theta)
return Aop @ eig_vecs[:, loc]

d_params_vals = []
for idx in range(eig_vecs.shape[-1]):
fn = partial(fun, loc=idx)
dlam = xnp.vjp_derivs(fn, op_args, eig_vecs[:, idx])[0]
required_shape = dlam.shape
d_params_vals.append(dlam.reshape(-1))
d_vals = xnp.stack(d_params_vals)
d_vals = (val_grads @ d_vals).reshape(required_shape)

def fun_eig(*theta, loc):
e = eig_vals
V = eig_vecs # (n, m)
W = eig_grads # (n, m)

def altogether(*theta):
Aop = unflatten(theta)
op_diag = 1. / (eig_vals[loc] - eig_vals)
op_diag = xnp.nan_to_num(op_diag, nan=0., posinf=0., neginf=0.)
D = cola.ops.Diagonal(op_diag)
weights = eig_vecs @ D @ eig_vecs.T
return weights @ Aop @ eig_vecs[:, loc]

d_params_vecs = []
for idx in range(eig_vecs.shape[-1]):
fn = partial(fun_eig, loc=idx)
dl_jac = xnp.jacrev(fn)(*op_args)
dl_jac = dl_jac.reshape(N, N * N)
out = xnp.sum(eig_grads[:, idx] @ dl_jac)
d_params_vecs.append(out)
d_vecs = xnp.stack(d_params_vecs)

d_params = d_vals + d_vecs
AV = Aop @ V
eigs = (AV * V).sum(axis=-2) # output 1
out1 = xnp.sum(eigs * val_grads)
VHAV = V.conj().T @ AV
diff = xnp.nan_to_num(1 / (e[:, None] - e[None, :]), nan=0., posinf=0., neginf=0.)
C = (W.conj().T @ V) * diff
out2 = (C.T * VHAV).sum()
return out1 + out2

d_params = xnp.grad(altogether)(*op_args)
dA = unflatten([d_params])
return (dA, )

Expand Down
5 changes: 2 additions & 3 deletions tests/algorithms/test_lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
_tol = 1e-6


# @parametrize([torch_fns, jax_fns])
@parametrize([torch_fns])
@parametrize([torch_fns, jax_fns])
def test_lanczos_vjp(xnp):
dtype = xnp.float64
# diag = xnp.Parameter(xnp.array([3., 4., 5.], dtype=dtype))
Expand Down Expand Up @@ -75,7 +74,7 @@ def f_alt(theta):
print(approx)
print(soln)
abs_error = xnp.norm(soln - approx)
assert abs_error < _tol * 10
assert abs_error < _tol * 50


@parametrize([torch_fns, jax_fns])
Expand Down

0 comments on commit d325cce

Please sign in to comment.