Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CG solver #312

Merged
merged 18 commits into from
Dec 8, 2023
Merged

Add CG solver #312

merged 18 commits into from
Dec 8, 2023

Conversation

zitongzhan
Copy link
Contributor

@zitongzhan zitongzhan commented Dec 7, 2023

Add cg solver
Solve #311

@zitongzhan zitongzhan changed the title Bsr/cg Add CG solver Dec 7, 2023
@zitongzhan zitongzhan marked this pull request as draft December 7, 2023 05:39
@zitongzhan zitongzhan marked this pull request as ready for review December 7, 2023 06:29
@zitongzhan
Copy link
Contributor Author

Copy link
Member

@hxu296 hxu296 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

pypose/optim/solver.py Outdated Show resolved Hide resolved
pypose/optim/solver.py Outdated Show resolved Hide resolved
pypose/optim/solver.py Outdated Show resolved Hide resolved
pypose/optim/solver.py Outdated Show resolved Hide resolved
pypose/optim/solver.py Outdated Show resolved Hide resolved
pypose/optim/solver.py Outdated Show resolved Hide resolved
pypose/optim/solver.py Outdated Show resolved Hide resolved
pypose/optim/solver.py Outdated Show resolved Hide resolved
@zitongzhan zitongzhan requested review from wang-chen and removed request for wang-chen December 7, 2023 21:04
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test doesn't follow the existing style. Refer to this test or this test with predefined data. Use class style.

Comment on lines 304 to 332
if self.maxiter is None:
maxiter = n*10
else:
maxiter = self.maxiter
r = b - bmv(A, x) if x.any() else b.clone()
rho_prev, p = None, None

for iteration in range(maxiter):
if (torch.linalg.norm(r, dim=-1) < atol).all():
return x

z = bmv(M, r) if M is not None else r
rho_cur = vecdot(r, z)
if iteration > 0:
beta = rho_cur / rho_prev
p = p * beta.unsqueeze(-1) + z
else: # First spin
p = torch.empty_like(r)
p[:] = z[:]

q = bmv(A, p)
alpha = rho_cur / vecdot(p, q)
x += alpha.unsqueeze(-1)*p
r -= alpha.unsqueeze(-1)*q
rho_prev = rho_cur

else: # for loop exhausted
# Return incomplete progress
return x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about

        if self.maxiter is None:
            maxiter = n * 10
        else:
            maxiter = self.maxiter
        r = b - bmv(A, x) if x.any() else b.clone()
        p = torch.empty_like(r)
        rho_prev, p[:]= None, z[:]

        for iteration in range(maxiter):
            if (torch.linalg.norm(r, dim=-1) < atol).all():
                return x

            z = bmv(M, r) if M is not None else r
            rho_cur = vecdot(r, z)
            beta = rho_cur / rho_prev
            p = p * beta.unsqueeze(-1) + z
            q = bmv(A, p)
            alpha = rho_cur / vecdot(p, q)
            x += alpha.unsqueeze(-1)*p
            r -= alpha.unsqueeze(-1)*q
            rho_prev = rho_cur

        return x

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

z has to be defined before it is used in rho_prev, p[:]= None, z[:]

@zitongzhan zitongzhan merged commit 375740e into main Dec 8, 2023
3 of 4 checks passed
@zitongzhan zitongzhan deleted the bsr/cg branch December 8, 2023 03:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants