-
-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CG solver #312
Add CG solver #312
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
tests/optim/test_solver.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test doesn't follow the existing style. Refer to this test or this test with predefined data. Use class style.
pypose/optim/solver.py
Outdated
if self.maxiter is None: | ||
maxiter = n*10 | ||
else: | ||
maxiter = self.maxiter | ||
r = b - bmv(A, x) if x.any() else b.clone() | ||
rho_prev, p = None, None | ||
|
||
for iteration in range(maxiter): | ||
if (torch.linalg.norm(r, dim=-1) < atol).all(): | ||
return x | ||
|
||
z = bmv(M, r) if M is not None else r | ||
rho_cur = vecdot(r, z) | ||
if iteration > 0: | ||
beta = rho_cur / rho_prev | ||
p = p * beta.unsqueeze(-1) + z | ||
else: # First spin | ||
p = torch.empty_like(r) | ||
p[:] = z[:] | ||
|
||
q = bmv(A, p) | ||
alpha = rho_cur / vecdot(p, q) | ||
x += alpha.unsqueeze(-1)*p | ||
r -= alpha.unsqueeze(-1)*q | ||
rho_prev = rho_cur | ||
|
||
else: # for loop exhausted | ||
# Return incomplete progress | ||
return x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about
if self.maxiter is None:
maxiter = n * 10
else:
maxiter = self.maxiter
r = b - bmv(A, x) if x.any() else b.clone()
p = torch.empty_like(r)
rho_prev, p[:]= None, z[:]
for iteration in range(maxiter):
if (torch.linalg.norm(r, dim=-1) < atol).all():
return x
z = bmv(M, r) if M is not None else r
rho_cur = vecdot(r, z)
beta = rho_cur / rho_prev
p = p * beta.unsqueeze(-1) + z
q = bmv(A, p)
alpha = rho_cur / vecdot(p, q)
x += alpha.unsqueeze(-1)*p
r -= alpha.unsqueeze(-1)*q
rho_prev = rho_cur
return x
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
z has to be defined before it is used in rho_prev, p[:]= None, z[:]
Add cg solver
Solve #311