Ensure PyTorch CUDA compatibility
Previously, many arrays were created without specifying a device. Under
PyTorch with CUDA, such arrays default to the CPU, which caused
device-mismatch errors when using CoLA with GPU data.
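
For context, the failure mode is the standard PyTorch cross-device error; a
minimal sketch (plain PyTorch, assuming a CUDA-capable machine, not CoLA code):

import torch

rhs = torch.randn(4, 1, device="cuda")  # user data lives on the GPU
Q = torch.zeros(4, 4)                   # allocated without device= -> defaults to CPU
out = Q @ rhs                           # RuntimeError: Expected all tensors to be on
                                        # the same device, but found at least two
                                        # devices, cuda:0 and cpu!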

This PR adds the device kwarg to every use of the following xnp methods (the fix pattern is sketched after the list):
- zeros
- ones
- array
- canonical
- eye
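
The pattern of the fix is the same everywhere: derive the target device from
the operator (A.device) or from an existing array (xnp.get_device(...)) and
thread it through the allocation. Using the initialize_householder_arnoldi
change below as the representative example:

# Before: allocated on the backend's default device (CPU under PyTorch)
H = xnp.zeros(shape=(max_iters, max_iters + 1), dtype=dtype)

# After: allocated on the same device as the right-hand side
device = xnp.get_device(rhs)
H = xnp.zeros(shape=(max_iters, max_iters + 1), dtype=dtype, device=device)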
gpleiss committed Aug 21, 2023
1 parent 91ce584 · commit 505f471
Showing 14 changed files with 68 additions and 55 deletions.
17 changes: 10 additions & 7 deletions cola/algorithms/arnoldi.py
@@ -145,7 +145,8 @@ def get_householder_vec(x, idx, xnp):
 def run_householder_arnoldi(A: LinearOperator, rhs: Array, max_iters: int):
     xnp = A.xnp
     dtype = A.dtype
-    Ps = [Householder(xnp.zeros((rhs.shape[-2], 1), dtype=dtype)) for _ in range(max_iters + 2)]
+    device = A.device
+    Ps = [Householder(xnp.zeros((rhs.shape[-2], 1), dtype=dtype, device=device)) for _ in range(max_iters + 2)]

     def body_fun(idx, state):
         Q, H, zj = state
@@ -178,8 +179,9 @@ def last_iter_fun(state):


 def initialize_householder_arnoldi(xnp, rhs, max_iters, dtype):
-    H = xnp.zeros(shape=(max_iters, max_iters + 1), dtype=dtype)
-    Q = xnp.zeros(shape=(rhs.shape[-2], max_iters + 1), dtype=dtype)
+    device = xnp.get_device(rhs)
+    H = xnp.zeros(shape=(max_iters, max_iters + 1), dtype=dtype, device=device)
+    Q = xnp.zeros(shape=(rhs.shape[-2], max_iters + 1), dtype=dtype, device=device)
     rhs = rhs / xnp.norm(rhs)
     Q = xnp.update_array(Q, xnp.copy(rhs[:, 0]), ..., 0)
     zj = Q[:, 0]
@@ -199,7 +201,7 @@ def cond_fun(state):
     def body_fun(state):
         Q, H, idx, _ = state
         new_vec = A @ Q[..., idx, :]
-        h_vec = xnp.zeros(shape=(max_iters + 1, rhs.shape[-1]), dtype=new_vec.dtype)
+        h_vec = xnp.zeros(shape=(max_iters + 1, rhs.shape[-1]), dtype=new_vec.dtype, device=xnp.get_device(new_vec))

         def inner_loop(jdx, result):
             new_vec, h_vec = result
@@ -225,9 +227,10 @@ def inner_loop(jdx, result):


 def initialize_arnoldi(xnp, rhs, max_iters, dtype):
-    idx = xnp.array(0, dtype=xnp.int32)
-    H = xnp.zeros(shape=(max_iters + 1, max_iters, rhs.shape[-1]), dtype=dtype)
-    Q = xnp.zeros(shape=(rhs.shape[-2], max_iters + 1, rhs.shape[-1]), dtype=dtype)
+    device = xnp.get_device(rhs)
+    idx = xnp.array(0, dtype=xnp.int32, device=device)
+    H = xnp.zeros(shape=(max_iters + 1, max_iters, rhs.shape[-1]), dtype=dtype, device=device)
+    Q = xnp.zeros(shape=(rhs.shape[-2], max_iters + 1, rhs.shape[-1]), dtype=dtype, device=device)
     rhs = rhs / xnp.norm(rhs, axis=-2)
     Q = xnp.update_array(Q, xnp.copy(rhs), ..., 0, slice(None, None, None))
     norm = xnp.norm(rhs, axis=-2)
23 changes: 14 additions & 9 deletions cola/algorithms/cg.py
@@ -104,9 +104,10 @@ def initialize(A, b, preconditioner, x0, xnp):
     r0 = b - A @ x0
     z0 = preconditioner @ r0
     p0 = z0
+    device = A.device
     gamma0 = xnp.sum(xnp.conj(r0) * z0, axis=-2, keepdims=True)
-    alpha0 = xnp.zeros(shape=gamma0.shape, dtype=r0.dtype)
-    beta0 = xnp.zeros(shape=gamma0.shape, dtype=r0.dtype)
+    alpha0 = xnp.zeros(shape=gamma0.shape, dtype=r0.dtype, device=device)
+    beta0 = xnp.zeros(shape=gamma0.shape, dtype=r0.dtype, device=device)
     return (x0, 0, r0, p0, alpha0, beta0, gamma0)


@@ -120,7 +121,7 @@ def cond_fun(value, tol, max_iters, xnp):

 def take_cg_step(state, A, preconditioner, xnp):
     x0, k, r0, p0, _, _, gamma0 = state
-    eps = xnp.array(_small_value, dtype=p0.real.dtype)
+    eps = xnp.array(_small_value, dtype=p0.real.dtype, device=A.device)
     has_converged = xnp.norm(r0, axis=-2, keepdims=True) < eps
     Ap0 = A @ p0

@@ -137,19 +138,22 @@ def take_cg_step(state, A, preconditioner, xnp):
 def update_alpha(gamma, p, Ap, has_converged, xnp):
     denom = xnp.sum(xnp.conj(p) * Ap, axis=-2, keepdims=True)
     alpha = do_safe_div(gamma, denom, xnp=xnp)
-    alpha = xnp.where(has_converged, x=xnp.array(0.0, dtype=p.dtype), y=alpha)
+    device = xnp.get_device(p)
+    alpha = xnp.where(has_converged, x=xnp.array(0.0, dtype=p.dtype, device=device), y=alpha)
     return alpha


 def update_gamma_beta(r, z, gamma0, has_converged, xnp):
     gamma1 = xnp.sum(xnp.conj(r) * z, axis=-2, keepdims=True)
     beta = do_safe_div(gamma1, gamma0, xnp=xnp)
-    beta = xnp.where(has_converged, x=xnp.array(0.0, dtype=r.dtype), y=beta)
+    device = xnp.get_device(r)
+    beta = xnp.where(has_converged, x=xnp.array(0.0, dtype=r.dtype, device=device), y=beta)
     return gamma1, beta


 def do_safe_div(num, denom, xnp):
-    is_zero = xnp.abs(denom) < xnp.array(_small_value, dtype=num.real.dtype)
+    device = xnp.get_device(num)
+    is_zero = xnp.abs(denom) < xnp.array(_small_value, dtype=num.real.dtype, device=device)
     denom = xnp.where(is_zero, _small_value, denom)
     output = num / denom
     return output
@@ -178,9 +182,10 @@ def body_fun(state):
 def initialize_track(A, b, preconditioner, x0, max_iters, xnp):
     state = initialize(A=A, b=b, preconditioner=preconditioner, x0=x0, xnp=xnp)
     *_, gamma0 = state
-    alphas = xnp.zeros(shape=(max_iters, ) + gamma0.shape, dtype=b.dtype)
-    betas = xnp.zeros(shape=(max_iters, ) + gamma0.shape, dtype=b.dtype)
-    rs = xnp.zeros(shape=(max_iters, ) + gamma0.shape, dtype=b.dtype)
+    device = A.device
+    alphas = xnp.zeros(shape=(max_iters, ) + gamma0.shape, dtype=b.dtype, device=device)
+    betas = xnp.zeros(shape=(max_iters, ) + gamma0.shape, dtype=b.dtype, device=device)
+    rs = xnp.zeros(shape=(max_iters, ) + gamma0.shape, dtype=b.dtype, device=device)
     return (state, (rs, alphas, betas))
6 changes: 3 additions & 3 deletions cola/algorithms/diagonal_estimation.py
@@ -15,14 +15,14 @@ def get_I_chunk_like(A: LinearOperator, i, bs, shift=0):
     elif k <= 0:
         k = abs(k)
         I_chunk = Id[:, i:i + bs + k].to_dense()
-        padded_chunk = A.xnp.zeros((A.shape[0], bs + k), dtype=A.dtype)
+        padded_chunk = A.xnp.zeros((A.shape[0], bs + k), dtype=A.dtype, device=A.device)
         slc = np.s_[:I_chunk.shape[-1]]
         padded_chunk = xnp.update_array(padded_chunk, I_chunk, slice(0, None), slc)
         chunk = I_chunk[:, :bs]
         shifted_chunk = padded_chunk[:, k:k + bs]
     else:
         I_chunk = Id[:, max(i - k, 0):i + bs].to_dense()
-        padded_chunk = A.xnp.zeros((A.shape[0], bs + k), dtype=A.dtype)
+        padded_chunk = A.xnp.zeros((A.shape[0], bs + k), dtype=A.dtype, device=A.device)
         slc = np.s_[-I_chunk.shape[-1]:]
         padded_chunk = xnp.update_array(padded_chunk, I_chunk, slice(0, None), slc)
         chunk = I_chunk[:, -bs:]
@@ -140,7 +140,7 @@ def cond(state):

     while_loop, infos = xnp.while_loop_winfo(err, tol, pbar=pbar)
     # while_loop = xnp.while_loop
-    zeros = xnp.zeros((A.shape[0] - abs(k), ), dtype=A.dtype)
+    zeros = xnp.zeros((A.shape[0] - abs(k), ), dtype=A.dtype, device=A.device)
     n, diag_sum, *_ = while_loop(cond, body, (0, zeros, zeros, xnp.PRNGKey(42)))
     mean = diag_sum / (n * bs)
     return mean, infos
6 changes: 4 additions & 2 deletions cola/algorithms/gmres.py
@@ -64,6 +64,7 @@ def fun(*theta):
 @iterative_autograd(gmres_bwd)
 def gmres_fwd(A, rhs, x0, max_iters, tol, P, use_householder, use_triangular, pbar):
     xnp = A.xnp
+    device = A.device
     res = rhs - A @ x0
     if use_householder:
         Q, H, infodict = run_householder_arnoldi(A=A, rhs=res, max_iters=max_iters)
@@ -72,7 +73,7 @@ def gmres_fwd(A, rhs, x0, max_iters, tol, P, use_householder, use_triangular, pbar):
                              pbar=pbar)
         Q, H = Q[:, :idx, :], H[:idx, :idx, :]
     beta = xnp.norm(res, axis=-2)
-    e1 = xnp.zeros(shape=(H.shape[0], beta.shape[0]), dtype=rhs.dtype)
+    e1 = xnp.zeros(shape=(H.shape[0], beta.shape[0]), dtype=rhs.dtype, device=A.device)
     e1 = xnp.update_array(e1, beta, 0)

     if use_triangular:
@@ -94,11 +95,12 @@ def gmres_fwd(A, rhs, x0, max_iters, tol, P, use_householder, use_triangular, pbar):


 def get_hessenberg_triangular_qr(H, xnp):
+    device = xnp.get_device(H)
     R = xnp.copy(H)
     Gs = []
     for jdx in range(H.shape[0] - 1):
         cx, sx = get_givens_cos_sin(R[jdx, jdx], R[jdx + 1, jdx], xnp)
-        G = xnp.array([[cx, sx], [-sx, cx]], dtype=H.dtype)
+        G = xnp.array([[cx, sx], [-sx, cx]], dtype=H.dtype, device=device)
         Gs.append(G)
         update = G.T @ R[[jdx, jdx + 1], :]
         R = xnp.update_array(R, update, [jdx, jdx + 1])
4 changes: 2 additions & 2 deletions cola/algorithms/iram.py
@@ -15,13 +15,13 @@ def iram(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, to
     del start_vector

     def matvec(x):
-        X = xnp.array(x, dtype=A.dtype)
+        X = xnp.array(x, dtype=A.dtype, device=A.device)
         out = A @ X
         return np.array(out, dtype=np.float32)

     A2 = LO(shape=A.shape, dtype=np.float32, matvec=matvec)
     k = min(A.shape[0] - 1, max_iters)
     eigvals, eigvecs = eigsh(A2, k=k, M=None, sigma=None, which="LM", v0=None, tol=tol)
-    eigvals, eigvecs = xnp.array(eigvals, dtype=A.dtype), xnp.array(eigvecs, dtype=A.dtype)
+    eigvals, eigvecs = xnp.array(eigvals, dtype=A.dtype, device=A.device), xnp.array(eigvecs, dtype=A.dtype, device=A.device)
     info = {}
     return eigvals, Dense(eigvecs), info
27 changes: 15 additions & 12 deletions cola/algorithms/lanczos.py
@@ -172,10 +172,11 @@ def cond_fun(state):


 def initialize_lanczos_vec(xnp, rhs, max_iters, dtype):
-    i = xnp.array(1, dtype=xnp.int32)
-    beta = xnp.zeros(shape=(rhs.shape[-1], max_iters), dtype=dtype)
-    alpha = xnp.zeros(shape=(rhs.shape[-1], max_iters + 1), dtype=dtype)
-    vec = xnp.zeros(shape=(rhs.shape[-1], rhs.shape[0], max_iters + 2), dtype=dtype)
+    device = xnp.get_device(rhs)
+    i = xnp.array(1, dtype=xnp.int32, device=device)
+    beta = xnp.zeros(shape=(rhs.shape[-1], max_iters), dtype=dtype, device=device)
+    alpha = xnp.zeros(shape=(rhs.shape[-1], max_iters + 1), dtype=dtype, device=device)
+    vec = xnp.zeros(shape=(rhs.shape[-1], rhs.shape[0], max_iters + 2), dtype=dtype, device=device)
     rhs = rhs / xnp.norm(rhs, axis=-2, keepdims=True)
     vec = xnp.update_array(vec, xnp.copy(rhs.T), ..., 1)
     return i, vec, beta, alpha
@@ -226,9 +227,10 @@ def cond_fun(state):


 def initialize_lanczos(xnp, vec, max_iters, dtype):
-    i = xnp.array(1, dtype=xnp.int32)
-    beta = xnp.zeros(shape=(max_iters + 1, 1), dtype=dtype)
-    alpha = xnp.zeros(shape=(max_iters + 1, 1), dtype=dtype)
+    device = xnp.get_device(vec)
+    i = xnp.array(1, dtype=xnp.int32, device=device)
+    beta = xnp.zeros(shape=(max_iters + 1, 1), dtype=dtype, device=device)
+    alpha = xnp.zeros(shape=(max_iters + 1, 1), dtype=dtype, device=device)
     vec /= xnp.norm(vec)
     vec_prev = xnp.copy(vec)
     return i, vec, vec_prev, beta, alpha
@@ -241,25 +243,26 @@ def construct_tridiagonal(alpha: Array, beta: Array, gamma: Array) -> Array:

 def construct_tridiagonal_batched(alpha: Array, beta: Array, gamma: Array) -> Array:
     xnp = get_library_fns(beta.dtype)
-    T = xnp.zeros(shape=(beta.shape[-2], beta.shape[-1], beta.shape[-1]), dtype=beta.dtype)
-    diag_ind = xnp.array([idx for idx in range(beta.shape[-1])], dtype=xnp.int64)
+    device = xnp.get_device(beta)
+    T = xnp.zeros(shape=(beta.shape[-2], beta.shape[-1], beta.shape[-1]), dtype=beta.dtype, device=device)
+    diag_ind = xnp.array([idx for idx in range(beta.shape[-1])], dtype=xnp.int64, device=device)
     T = xnp.update_array(T, beta, ..., diag_ind, diag_ind)
-    shifted_ind = xnp.array([idx + 1 for idx in range(gamma.shape[-1])], dtype=xnp.int64)
+    shifted_ind = xnp.array([idx + 1 for idx in range(gamma.shape[-1])], dtype=xnp.int64, device=device)
     T = xnp.update_array(T, gamma, ..., diag_ind[:-1], shifted_ind)
     T = xnp.update_array(T, alpha, ..., shifted_ind, diag_ind[:-1])
     return T


 def get_lu_from_tridiagonal(A: LinearOperator) -> Array:
     xnp = A.xnp
-    eigenvals = xnp.zeros(shape=(A.shape[0], ), dtype=A.dtype)
+    eigenvals = xnp.zeros(shape=(A.shape[0], ), dtype=A.dtype, device=A.device)
     eigenvals = xnp.update_array(eigenvals, A.beta[0, 0], 0)

     def body_fun(i, state):
         pi = A.beta[i + 1, 0] - ((A.alpha[i, 0] * A.gamma[i, 0]) / state[i])
         state = xnp.update_array(state, pi, i + 1)
         return state

-    lower, upper = xnp.array(0, dtype=xnp.int32), xnp.array(A.shape[0] - 1, dtype=xnp.int32)
+    lower, upper = xnp.array(0, dtype=xnp.int32), xnp.array(A.shape[0] - 1, dtype=xnp.int32, device=A.device)
     eigenvals = xnp.for_loop(lower, upper, body_fun, init_val=eigenvals)
     return eigenvals
6 changes: 3 additions & 3 deletions cola/algorithms/lobpcg.py
@@ -14,16 +14,16 @@ def lobpcg(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
     del pbar, start_vector, tol

     def matvec(x):
-        X = xnp.array(x, dtype=A.dtype)
+        X = xnp.array(x, dtype=A.dtype, device=A.device)
         out = A @ X
         return np.array(out, dtype=np.float32)

     A2 = LO(shape=A.shape, dtype=np.float32, matvec=matvec)
     k = min(A.shape[0] - 1, max_iters)
     X = np.random.normal(size=(A.shape[0], k)).astype(np.float32)
     eigvals, eigvecs = lobpcg_sp(A2, X)
-    eigvals = xnp.array(np.copy(eigvals), dtype=A.dtype)
-    eigvecs = xnp.array(np.copy(eigvecs), dtype=A.dtype)
+    eigvals = xnp.array(np.copy(eigvals), dtype=A.dtype, device=A.device)
+    eigvecs = xnp.array(np.copy(eigvecs), dtype=A.dtype, device=A.device)
     idx = xnp.argsort(eigvals, axis=-1)
     info = {}
     return eigvals[idx], Dense(eigvecs[:, idx]), info
2 changes: 1 addition & 1 deletion cola/algorithms/preconditioners.py
@@ -140,7 +140,7 @@ def get_nys_approx(A, Omega, eps):
     xnp = A.xnp
     Omega, _ = xnp.qr(Omega, full_matrices=False)
     Y = A @ Omega
-    # Y = xnp.array(Y, dtype=xnp.float64)
+    # Y = xnp.array(Y, dtype=xnp.float64, device=A.device)
     nu = eps * xnp.norm(Y, ord="fro")
     Y += nu * Omega
     C = xnp.cholesky(Omega.T @ Y)
2 changes: 1 addition & 1 deletion cola/jax_fns.py
@@ -121,7 +121,7 @@ def lu_solve(a, b):


 def get_device(array):
-    if not isinstance(array, jax.core.Tracer):
+    if not isinstance(array, jax.core.Tracer) and hasattr(array, 'device'):
         return array.device()
     else:
         return get_default_device()
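
The added hasattr guard matters because JAX tracers inside jit-compiled code
do not carry a concrete device. The torch backend needs no such
special-casing; its counterpart is presumably just an attribute lookup, along
the lines of this sketch (an assumption, not the actual cola/torch_fns.py source):

def get_device(array):
    # torch tensors always carry a concrete device
    return array.device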
6 changes: 3 additions & 3 deletions cola/linalg/diag_trace.py
@@ -58,17 +58,17 @@ def diag(A: Dense, k=0, **kwargs):
 @dispatch
 def diag(A: Identity, k=0, **kwargs):
     if k == 0:
-        return A.xnp.ones(A.shape[0], A.dtype)
+        return A.xnp.ones(A.shape[0], A.dtype, device=A.device)
     else:
-        return A.xnp.zeros(A.shape[0] - k, A.dtype)
+        return A.xnp.zeros(A.shape[0] - k, A.dtype, device=A.device)


 @dispatch
 def diag(A: Diagonal, k=0, **kwargs):
     if k == 0:
         return A.diag
     else:
-        return A.xnp.zeros(A.shape[0] - k, A.dtype)
+        return A.xnp.zeros(A.shape[0] - k, A.dtype, device=A.device)


 @dispatch

(Codecov flagged the three added lines in this file, diag_trace.py lines 61, 63, and 71, as not covered by tests.)
2 changes: 1 addition & 1 deletion cola/linalg/eigs.py
@@ -82,7 +82,7 @@ def eig(A: LinearOperator, **kwargs):
 # def eig(A: LowerTriangular, **kwargs):
 #     xnp = A.xnp
 #     eig_vals = diag(A.A)[eig_slice]
-#     eig_vecs = xnp.eye(eig_vals.shape[0], eig_vals.shape[0])
+#     eig_vecs = xnp.eye(eig_vals.shape[0], eig_vals.shape[0], dtype=A.dtype, device=A.device)
 #     return eig_vals, eig_vecs
 # else:
 #     raise ValueError(f"Unknown method {method}")
2 changes: 1 addition & 1 deletion cola/linalg/unary.py
@@ -56,7 +56,7 @@ def _matmat(self, V):  # (n,bs)
         self.info.update(info)
         eigvals, P = self.xnp.eig(H)
         norms = self.xnp.norm(V, axis=0)
-        e0 = self.xnp.canonical(0, (P.shape[1], V.shape[-1]), dtype=P.dtype)
+        e0 = self.xnp.canonical(0, (P.shape[1], V.shape[-1]), dtype=P.dtype, device=self.device)
         Pinv0 = self.xnp.solve(P, e0.T)  # (bs, m, m) vs (bs, m)
         out = Pinv0 * norms[:, None]  # (bs, m)
         Q = self.xnp.cast(Q, dtype=P.dtype)  # (bs, n, m)
2 changes: 1 addition & 1 deletion cola/ops/operator_base.py
@@ -106,7 +106,7 @@ def _rmatmat(self, X: Array) -> Array:
         if self.isa(cola.annotations.SelfAdjoint):
             return self.xnp.conj(self._matmat(self.xnp.conj(XT)).T)
         primals = self.xnp.zeros(shape=(self.shape[1], XT.shape[1]), dtype=XT.dtype,
-                                 device=X.device)
+                                 device=self.device)
         out = self.xnp.linear_transpose(self._matmat, primals=primals, duals=XT)
         return out.T
18 changes: 9 additions & 9 deletions cola/ops/operators.py
@@ -65,7 +65,7 @@ class ScalarMul(LinearOperator):
     """ Linear Operator representing scalar multiplication"""
     def __init__(self, c, shape, dtype=None):
         super().__init__(dtype=dtype or type(c), shape=shape)
-        self.c = self.xnp.array(c, dtype=dtype)
+        self.c = self.xnp.array(c, dtype=dtype, device=self.device)
         self.ensure_const_register_as_array()

     def ensure_const_register_as_array(self):
@@ -198,8 +198,8 @@ def __str__(self):

 def kronsum(A, B):
     xnp = get_library_fns(A.dtype)
-    IA = xnp.eye(A.shape[-2])
-    IB = xnp.eye(B.shape[-2])
+    IA = xnp.eye(A.shape[-2], dtype=A.dtype, device=A.device)
+    IB = xnp.eye(B.shape[-2], dtype=B.dtype, device=B.device)
     return xnp.kron(A, IB) + xnp.kron(IA, B)

@@ -328,8 +328,8 @@ def __init__(self, alpha: Array, beta: Array, gamma: Array):

     def _matmat(self, X: Array) -> Array:
         xnp = self.xnp
-        aux_alpha = xnp.zeros(shape=X.shape, dtype=X.dtype)
-        aux_gamma = xnp.zeros(shape=X.shape, dtype=X.dtype)
+        aux_alpha = xnp.zeros(shape=X.shape, dtype=X.dtype, device=X.device)
+        aux_gamma = xnp.zeros(shape=X.shape, dtype=X.dtype, device=X.device)

         output = self.beta * X
         aux_gamma = xnp.update_array(aux_gamma, self.gamma * X[1:], np.s_[:-1])
@@ -390,7 +390,7 @@ def __init__(self, A, slices):
     def _matmat(self, X: Array) -> Array:
         xnp = self.xnp
         start_slices, end_slices = self.slices
-        Y = xnp.zeros(shape=(self.A.shape[-1], X.shape[-1]), dtype=self.dtype)
+        Y = xnp.zeros(shape=(self.A.shape[-1], X.shape[-1]), dtype=self.dtype, device=X.device)
         Y = xnp.update_array(Y, X, end_slices)
         output = self.A @ Y
         return output[start_slices]
@@ -473,14 +473,14 @@ def _matmat(self, X):
         xnp = self.xnp
         # hack to make it work with pytorch
         if xnp.__name__ == 'cola.torch_fns' and False:
-            expanded_x = self.x[None, :] + self.xnp.zeros((X.shape[0], 1), dtype=self.x.dtype)
+            expanded_x = self.x[None, :] + self.xnp.zeros((X.shape[0], 1), dtype=self.x.dtype, device=self.device)
             fn = partial(self.xnp.vjp_derivs, self.xnp.vmap(self.xnp.grad(self.f)), (expanded_x, ))
             out = fn((X, ))
         else:
             mvm = partial(xnp.jvp_derivs, xnp.grad(self.f), (self.x, ), create_graph=False)
             out = xnp.vmap(mvm)((X.T, )).T
         if xnp.__name__ == 'cola.torch_fns':  # pytorch converts to double silently
-            out = out.to(dtype=self.dtype)
+            out = out.to(dtype=self.dtype, device=self.device)
         return out

     def __str__(self):
@@ -557,7 +557,7 @@ class Householder(LinearOperator):
     def __init__(self, vec, beta=2.):
         super().__init__(shape=(vec.shape[-2], vec.shape[-2]), dtype=vec.dtype)
         self.vec = vec
-        self.beta = self.xnp.array(beta, dtype=vec.dtype)
+        self.beta = self.xnp.array(beta, dtype=vec.dtype, device=self.device)

     def _matmat(self, X: Array) -> Array:
         xnp = self.xnp
