Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
9490ff8
add normal matvec and memory profiler
zitongzhan Dec 13, 2025
9c90aca
print peak cuda allocation
zitongzhan Dec 23, 2025
6256e79
add warp memory pool report
zitongzhan Dec 28, 2025
3a5ce9b
use `A._get_Jt` when matrix_free_normal
zitongzhan Dec 28, 2025
0064146
add back schur by warp's matmul
zitongzhan Dec 28, 2025
acd1b3c
safely import cudss
zitongzhan Jan 14, 2026
91c8ade
Add future plans section to README
zitongzhan Dec 20, 2025
19774c3
add normal matvec and memory profiler
zitongzhan Dec 13, 2025
4ca9c86
print peak cuda allocation
zitongzhan Dec 23, 2025
b71f1a3
add warp memory pool report
zitongzhan Dec 28, 2025
d678867
use `A._get_Jt` when matrix_free_normal
zitongzhan Dec 28, 2025
d127b88
add back schur by warp's matmul
zitongzhan Dec 28, 2025
fa9ab70
Merge branch 'schur-matmul' of github.com:zitongzhan/bae_private into…
zitongzhan Jan 26, 2026
6619808
Merge remote-tracking branch 'upstream/release' into schur-matmul
SEOKWOOPARK Apr 13, 2026
3e4761d
Preventing TrustRegion from accepting diverging steps
SEOKWOOPARK Apr 20, 2026
5d9e2b2
fix(optimizer/LM): Remove redundant solver calls so matrix_free_norma…
SEOKWOOPARK Apr 29, 2026
e34bea2
feat(optim/Schur): Add Matrix-Free path and matrix_free_normal branch
SEOKWOOPARK Apr 29, 2026
5f4f093
Resolving conflict with release branch in README
SEOKWOOPARK Apr 29, 2026
f64d00b
Version up to 0.2.1
SEOKWOOPARK Apr 29, 2026
40798f1
Fix deprecated function in Warp
SEOKWOOPARK May 23, 2026
165104d
Replace Warp with Triton kernels and adjust corresponding codes
SEOKWOOPARK May 23, 2026
b305f81
Remove codes relevant to Chunk
SEOKWOOPARK May 23, 2026
3a97f9e
Merge branch 'release' into memory-issue-swp
SEOKWOOPARK May 24, 2026
a0b4b8b
Remove ba_helpers.py
SEOKWOOPARK May 24, 2026
f46fb74
Fix a conflict in ba_example.py
SEOKWOOPARK May 24, 2026
48ad787
Potential fix for pull request finding 'Variable defined multiple times'
zitongzhan May 24, 2026
8cc6eb3
Potential fix for pull request finding 'Unused local variable'
zitongzhan May 24, 2026
074b931
minimize diff
zitongzhan May 24, 2026
4746522
restore pysolvers
zitongzhan May 24, 2026
7f3ea3d
revert import shuffle
zitongzhan May 24, 2026
d3e24d9
restore LM
zitongzhan May 24, 2026
04908d9
fix import order ba example
zitongzhan May 24, 2026
fd5cc96
Revert a version with Triton to a version with Warp
SEOKWOOPARK May 27, 2026
35b9b79
Remove Triton implmentation file
SEOKWOOPARK May 27, 2026
733c0a6
Add 'final' and 'venice' dataset in ba_example.py and remove unnecess…
SEOKWOOPARK May 27, 2026
c94f27a
Potential fix for pull request finding 'Unused import'
zitongzhan May 28, 2026
aaa7e5b
Free PyTorch cache before Warp handoff in Schur.step
SEOKWOOPARK May 29, 2026
dff7575
Rollback unnecessary changes in gitignore
SEOKWOOPARK May 30, 2026
2bf5e62
Rollback single quote to double quote in ba_example.py
SEOKWOOPARK May 30, 2026
50075ec
Rollback the optimizer's default class from Schur to LM and adjust up…
SEOKWOOPARK May 30, 2026
0ced1df
Remove class Reproj and use Residual
SEOKWOOPARK May 30, 2026
ba6dc25
Rollback single to double quote in Time print
SEOKWOOPARK May 30, 2026
13dc2a2
Remove overlapped torch.synchronize
SEOKWOOPARK May 30, 2026
ec3262e
Single quote -> double quote in dataset's key
SEOKWOOPARK May 30, 2026
f5c1751
Rollback dataset's declaration in main
SEOKWOOPARK May 31, 2026
90b14f8
Remove rotate_quat
SEOKWOOPARK May 31, 2026
6bb75ee
Rollback least_square_error's parameter name
SEOKWOOPARK May 31, 2026
a55f607
Remove unused variable 'USE_QUATERNIONS'
SEOKWOOPARK May 31, 2026
d7f9ebc
Retrieve print for Initial loss
SEOKWOOPARK May 31, 2026
077dce4
Retrieve the decorator
SEOKWOOPARK May 31, 2026
fb41830
Remove overlapped import
SEOKWOOPARK May 31, 2026
60a9530
Remove empty_cache()
SEOKWOOPARK May 31, 2026
7adb46e
Rollback variable names in class LM
SEOKWOOPARK May 31, 2026
49a18f1
Transfer TrustRegion and Adaptive into strategy.py in the optim direc…
SEOKWOOPARK Jun 1, 2026
8196c03
Potential fix for pull request finding 'Unused import'
SEOKWOOPARK Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 82 additions & 3 deletions ba_example.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from time import perf_counter
from datetime import datetime
from pathlib import Path

import pypose as pp
import torch
import torch.nn as nn
import warp as wp
from pypose.autograd.function import psjac

from bae.optim.optimizer import Schur
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
SEOKWOOPARK marked this conversation as resolved.
Dismissed
from datapipes.bal_loader import get_problem
from bae.optim import LM
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
from bae.utils.pysolvers import PCG
from bae.optim.strategy import TrustRegion


TARGET_DATASET = "trafalgar"
TARGET_PROBLEM = "problem-257-65132-pre"
Expand All @@ -16,10 +22,31 @@
# TARGET_PROBLEM = "problem-1723-156502-pre"
# TARGET_DATASET = "dubrovnik"
# TARGET_PROBLEM = "problem-356-226730-pre"
# TARGET_DATASET = "venice"
# TARGET_PROBLEM = "problem-1778-993923-pre"
# TARGET_DATASET = "final"
# TARGET_PROBLEM = "problem-13682-4456117-pre"

# Torch device string: the dataset tensors and the model are moved to it via
# `.to(DEVICE)`, and the CUDA/Warp memory reporting in main() is only enabled
# when it starts with "cuda".
DEVICE = "cuda"
# Chooses how many leading camera-parameter columns are optimized: 10 when
# True, 7 when False (presumably the extra 3 are intrinsics — TODO confirm
# against the BAL camera layout).
OPTIMIZE_INTRINSICS = True
# Column count used to slice `dataset["camera_params"][:, :NUM_CAMERA_PARAMS]`.
NUM_CAMERA_PARAMS = 10 if OPTIMIZE_INTRINSICS else 7
# When True (and DEVICE is a CUDA device), main() records Warp mempool usage
# before the run and prints the current/high-water deltas afterwards.
REPORT_WARP_MEMPOOL = True


def _format_bytes(num_bytes: int) -> str:
sign = "-" if num_bytes < 0 else ""
size = float(abs(num_bytes))
units = ["B", "KiB", "MiB", "GiB", "TiB"]

for unit in units:
if size < 1024.0 or unit == units[-1]:
break
size /= 1024.0

if unit == "B":
return f"{sign}{int(size)} {unit}"

return f"{sign}{size:.2f} {unit}"


@psjac
Expand Down Expand Up @@ -57,7 +84,12 @@ def least_square_error(camera_params, points, cidx, pidx, observes):
def main():
dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET)
print(f"Fetched {TARGET_PROBLEM} from {TARGET_DATASET}")

file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}'
cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None
memory_snapshot_path = None
warp_device = None
warp_mempool_start_current = None
warp_mempool_start_high = None
dataset = {
key: value.to(DEVICE)
for key, value in dataset.items()
Expand All @@ -69,11 +101,36 @@ def main():
"pidx": dataset["point_index_of_observations"],
}

if DEVICE.startswith("cuda") and torch.cuda.is_available():
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_dir = Path("memory_traces")
snapshot_dir.mkdir(exist_ok=True)
memory_snapshot_path = snapshot_dir / f"{file_name}_cuda_memory_{timestamp}.pickle"
torch.cuda.memory._record_memory_history(
enabled="all",
context="all",
stacks="python",
device=cuda_device,
clear_history=True,
)

if REPORT_WARP_MEMPOOL and DEVICE.startswith("cuda"):
try:
if wp.is_cuda_available():
warp_device = wp.get_device("cuda:0" if DEVICE == "cuda" else DEVICE)
if not wp.is_mempool_enabled(warp_device):
wp.set_mempool_enabled(warp_device, True)
warp_mempool_start_current = wp.get_mempool_used_mem_current(warp_device)
warp_mempool_start_high = wp.get_mempool_used_mem_high(warp_device)
except Exception as e:
print(f"Warning: failed to query Warp mempool stats: {e}")

model = Residual(
dataset["camera_params"][:, :NUM_CAMERA_PARAMS].clone(),
dataset["points_3d"].clone(),
).to(DEVICE)
strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4)

strategy = TrustRegion(up=2.0, down=0.5**4)
solver = PCG(tol=1e-4, maxiter=250)
optimizer = LM(model, strategy=strategy, solver=solver, reject=30)

Expand All @@ -96,6 +153,28 @@ def main():
end = perf_counter()
print("Time", end - start)

if memory_snapshot_path:
torch.cuda.memory._dump_snapshot(str(memory_snapshot_path))
print(f"CUDA memory snapshot saved to {memory_snapshot_path}")

if cuda_device is not None and torch.cuda.is_available():
peak_allocated = torch.cuda.max_memory_allocated(cuda_device)
try:
peak_reserved = torch.cuda.max_memory_reserved(cuda_device)
except AttributeError:
peak_reserved = torch.cuda.max_memory_cached(cuda_device)
print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}")
print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}")

if warp_device is not None and warp_mempool_start_current is not None:
try:
warp_current = wp.get_mempool_used_mem_current(warp_device)
warp_high = wp.get_mempool_used_mem_high(warp_device)
print(f"Warp CUDA mempool current: {_format_bytes(warp_current)} (Δ {_format_bytes(warp_current - warp_mempool_start_current)})")
print(f"Warp CUDA mempool high-water: {_format_bytes(warp_high)} (Δ {_format_bytes(warp_high - warp_mempool_start_high)})")
except Exception as e:
print(f"Warning: failed to query Warp mempool stats: {e}")

print('Ending loss:', least_square_error(
model.pose,
model.points,
Expand All @@ -106,4 +185,4 @@ def main():


if __name__ == "__main__":
main()
main()
180 changes: 169 additions & 11 deletions bae/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@
import torch
from pypose.optim import LevenbergMarquardt as ppLM
import pypose as pp

from ..autograd.graph import jacobian
from ..autograd.function import TrackingTensor
from ..sparse.py_ops import diagonal_op_
from ..sparse.py_ops import diagonal_op_, inv_op
from ..sparse.spgemm import CuSparse
from ..utils.linear_operator import NormalMatVec
from ..utils.parameter import parameter_update_shape


import warp as wp
from warp import sparse
from warp.optim import linear
from bae.sparse.warp_wrappers import format_vec_for_bsr, torchbsr2wp, wp2torchbsr


class LM(ppLM):
def __init__(self, *args, **kwargs):
def __init__(self, *args, matrix_free_normal: bool = False, **kwargs):
self.matrix_free_normal = matrix_free_normal
super(LM, self).__init__(*args, **kwargs)
self.mm = CuSparse()

Expand All @@ -25,21 +31,31 @@ def step(self, input, target=None, weight=None):
J = jacobian(R, pg['params'])
if isinstance(R, TrackingTensor):
R = R.tensor()
J = torch.cat([j.to_sparse_coo() for j in J], dim=-1)
J = torch.cat([j.to_sparse_coo() for j in J], dim=-1).to_sparse_csr()

self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target)
J_T = J.mT
self.reject_count = 0
J_T = J_T.to_sparse_csr()
J = J.to_sparse_csr()
A = self.mm(J_T, J)

diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))
if self.matrix_free_normal:
diag = NormalMatVec._compute_diag(J).clamp(min=pg['min'], max=pg['max'])
A = NormalMatVec(J, damping=0.0, diag=diag)
rhs = -(A._get_Jt() @ R.view(-1, 1))
diag_scale = 1.0
else:
J_T = J.mT.to_sparse_csr()
rhs = -J_T @ R.view(-1, 1)
A = self.mm(J_T, J)
del J_T
diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))

while self.last <= self.loss:
diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping']))
if self.matrix_free_normal:
diag_scale *= 1.0 + pg['damping']
A.set_damping(diag_scale - 1.0)
else:
diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping']))
Comment thread
zitongzhan marked this conversation as resolved.
try:
D = self.solver(A, -J_T @ R.view(-1, 1))
D = self.solver(A, rhs)
except Exception as e:
print(e, "\nLinear solver failed. Breaking optimization step...")
break
Expand Down Expand Up @@ -69,3 +85,145 @@ def update_parameter(self, params, step):
param[:, 7:] += step_view[..., 6:]
else:
param.add_(step_view)


class Schur(LM):
    """LM variant that solves each damped step through a Schur complement.

    The Jacobian is consumed as two block-sparse pieces, `J[0]` and `J[1]`,
    converted to Warp BSR matrices. With U = J0^T J0, V = J1^T J1 and
    W = J0^T J1, the first block update is obtained from the reduced system
    (U - W V^-1 W^T) D_c = rhs_c and the second block by a follow-up solve
    with V. When `self.matrix_free_normal` is set, W is never formed and the
    reduced operator is applied as a chain of sparse mat-vecs instead.

    NOTE(review): the `_c` / `_p` suffixes suggest J[0] is the camera block
    and J[1] the point block — confirm against what `jacobian` returns.
    """
    @torch.no_grad()
    def step(self, input, target=None, weight=None):
        """Run one damped optimization step per parameter group.

        Args:
            input: forwarded to `self.model(...)` and `self.model.loss(...)`.
            target: forwarded alongside `input`; defaults to None.
            weight: falls back to `self.weight` when None. It is not read
                again inside this method.

        Returns:
            `self.loss` after the accept/reject loop has finished.
        """
        for pg in self.param_groups:
            self.reject_count = 0
            weight = self.weight if weight is None else weight
            R = self.model(input, target)[0]
            J = jacobian(R, pg['params'])

            # `last` and `loss` start out equal, so the while-loop below runs
            # at least once.
            self.last = self.loss = self.loss if hasattr(self, 'loss') else self.model.loss(input, target)
            J0wp = torchbsr2wp(J[0])
            J1wp = torchbsr2wp(J[1])
            J0twp = sparse.bsr_transposed(J0wp)
            J1twp = sparse.bsr_transposed(J1wp)
            # Diagonal blocks of the normal equations: U = J0^T J0, V = J1^T J1.
            U = sparse.bsr_mm(J0twp, J0wp)
            V = sparse.bsr_mm(J1twp, J1wp)

            if self.matrix_free_normal:
                # Matrix-free path never forms the coupling block W.
                del J0twp, J1twp
            else:
                # Explicit coupling block W = J0^T J1 and its transpose.
                W = sparse.bsr_mm(J0twp, J1wp)
                Wt = sparse.bsr_transposed(W)
                del J0twp, J1twp

            # NOTE(review): the loop below damps `Upt`/`Vpt` but later reads
            # the Warp-side `U`/`V`; that is only consistent if wp2torchbsr
            # returns views that share storage with `U`/`V` — confirm.
            Upt = wp2torchbsr(U)
            Vpt = wp2torchbsr(V)
            diagonal_op_(Upt, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))
            diagonal_op_(Vpt, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))
            R_flat = R.reshape(-1).contiguous()
            Rwp = format_vec_for_bsr(R_flat, (J0wp.block_shape[1], J0wp.block_shape[0]))
            # Negative gradient pieces: Ic = -J0^T R, Ip = -J1^T R.
            Ic = sparse.bsr_mv(J0wp, Rwp, alpha=-1.0, transpose=True)
            Ip = sparse.bsr_mv(J1wp, Rwp, alpha=-1.0, transpose=True)
            # Buffers allocated once and reused across retries of the loop.
            rhs_c = wp.empty_like(Ic)
            rhs_p = wp.empty_like(Ip)
            scratch_pts2 = wp.empty_like(Ip)

            if self.matrix_free_normal:
                # Extra scratch vectors only needed by the mat-vec chain.
                scratch_obs = wp.empty_like(Rwp)
                scratch_pts = wp.empty_like(Ip)

            # Reuse the configured solver's settings for Warp's CG when it
            # exposes them; otherwise tol=None and maxiter=0 are passed.
            solver_tol = getattr(self.solver, "tol", None)
            solver_maxiter = getattr(self.solver, "maxiter", 0) or 0

            while self.last <= self.loss:
                # Scale both diagonals by (1 + damping). This is applied to
                # the same matrices on every retry, so presumably the scaling
                # compounds across rejected steps (assumes diagonal_op_ works
                # in place — TODO confirm).
                damp = partial(torch.mul, other=1+pg['damping'])
                diagonal_op_(Upt, op=damp)
                diagonal_op_(Vpt, op=damp)

                # Inverse of the damped V, converted back to a Warp matrix.
                V_i = torchbsr2wp(inv_op(Vpt))

                if self.matrix_free_normal:
                    # Warp LinearOperator matvec contract:
                    # z = alpha * (A @ x) + beta * y, here with
                    # A = U - J0^T J1 V_i J1^T J0. `_V_i` is bound as a
                    # default so the closure keeps this iteration's inverse.
                    def schur_matvec(x, y, z, alpha, beta, _V_i=V_i):
                        sparse.bsr_mv(J0wp, x, y=scratch_obs, beta=0.0)
                        sparse.bsr_mv(J1wp, scratch_obs, y=scratch_pts, beta=0.0, transpose=True)
                        sparse.bsr_mv(_V_i, scratch_pts, y=scratch_pts2, beta=0.0)
                        sparse.bsr_mv(J1wp, scratch_pts2, y=scratch_obs, beta=0.0)
                        # Seed z with y so the beta term below accumulates
                        # onto the right vector when they are distinct.
                        if z.ptr != y.ptr and beta != 0.0:
                            wp.copy(src=y, dest=z)
                        sparse.bsr_mv(J0wp, scratch_obs, y=z, alpha=-alpha, beta=beta, transpose=True)
                        sparse.bsr_mv(U, x, y=z, alpha=alpha, beta=1.0)

                    schur_op = linear.LinearOperator(
                        shape=U.shape, dtype=U.values.dtype, device=U.device,
                        matvec=schur_matvec,
                    )
                    # Preconditioner is built from U alone, since the reduced
                    # matrix is not assembled on this path.
                    schur_M = linear.preconditioner(U)

                    # rhs_c = Ic - J0^T J1 V_i Ip
                    wp.copy(src=Ic, dest=rhs_c)
                    sparse.bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0)
                    sparse.bsr_mv(J1wp, scratch_pts2, y=scratch_obs, beta=0.0)
                    sparse.bsr_mv(J0wp, scratch_obs, y=rhs_c, alpha=-1.0, beta=1.0, transpose=True)
                else:
                    # Assemble the reduced matrix explicitly: U - W V_i W^T.
                    WV_i = sparse.bsr_mm(W, V_i)
                    WVi_Wt = sparse.bsr_mm(WV_i, Wt)
                    # A copy of the damped U is handed to bsr_axpy as `y`;
                    # presumably because bsr_axpy accumulates into `y` in
                    # place and Upt must survive for later retries — confirm.
                    U_clone_torch = torch.sparse_bsr_tensor(
                        crow_indices=Upt.crow_indices().clone(),
                        col_indices=Upt.col_indices().clone(),
                        values=Upt.values().clone(),
                        size=Upt.shape, device=Upt.device, dtype=Upt.dtype,
                    )
                    schur_op = sparse.bsr_axpy(WVi_Wt, torchbsr2wp(U_clone_torch), alpha=-1.0)
                    schur_M = linear.preconditioner(schur_op)
                    # rhs_c = Ic - W V_i Ip
                    wp.copy(src=Ic, dest=rhs_c)
                    sparse.bsr_mv(V_i, Ip, y=scratch_pts2, beta=0.0)
                    sparse.bsr_mv(W, scratch_pts2, y=rhs_c, alpha=-1.0, beta=1.0)

                # Solve the reduced system for the first block, starting
                # from a zero initial guess.
                D_c = wp.zeros_like(rhs_c)
                linear.cg(
                    A=schur_op,
                    b=rhs_c,
                    x=D_c,
                    tol=solver_tol,
                    maxiter=solver_maxiter,
                    M=schur_M,
                )


                # Back-substitution right-hand side: rhs_p = Ip - W^T D_c.
                wp.copy(src=Ip, dest=rhs_p)

                if self.matrix_free_normal:
                    # W^T D_c evaluated as J1^T (J0 D_c) without forming W.
                    sparse.bsr_mv(J0wp, D_c, y=scratch_obs, beta=0.0)
                    sparse.bsr_mv(J1wp, scratch_obs, y=rhs_p,
                                  alpha=-1.0, beta=1.0, transpose=True)
                else:
                    sparse.bsr_mv(Wt, D_c, y=rhs_p, alpha=-1.0, beta=1.0)

                # Second block: solve V D_p = rhs_p with CG.
                # NOTE(review): V_i is already available above; applying it
                # directly might replace this solve — verify before changing.
                D_p = wp.zeros_like(rhs_p)
                linear.cg(
                    A=V,
                    b=rhs_p,
                    x=D_p,
                    tol=solver_tol,
                    maxiter=solver_maxiter,
                    M=linear.preconditioner(V),
                )

                # Concatenate both block updates into one step vector, apply
                # it, and re-evaluate the loss.
                D_c_t = wp.to_torch(D_c).flatten()
                D_p_t = wp.to_torch(D_p).flatten()
                D = torch.cat([D_c_t, D_p_t])
                self.update_parameter(pg['params'], D)
                self.loss = self.model.loss(input, target)
                print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping'])

                # Let the strategy adjust the damping in `pg`; the Warp
                # Jacobians, per-block steps and flattened residual are
                # passed along as extra keyword arguments.
                self.strategy.update(
                    pg,
                    last=self.last,
                    loss=self.loss,
                    J=J,
                    Jwp=[J0wp, J1wp],
                    D=[D_c_t, D_p_t],
                    R=R_flat.view(-1, 1),
                )

                # Loss got worse and retries remain: undo the step, restore
                # the previous loss and go around again. Otherwise stop.
                if self.last < self.loss and self.reject_count < self.reject: # reject step
                    self.update_parameter(params=pg['params'], step=-D)
                    self.loss, self.reject_count = self.last, self.reject_count + 1
                else:
                    break

        return self.loss
Loading
Loading