
Commit

Enabled back ellipsoidal damping in LM with linear solvers support checks (facebookresearch#87)

* Enabled back ellipsoidal damping in LM and added checks for supporting linear solvers.

* Fixed bad dtype handling in Objective and SparseLinearization.

* Renamed _LM_ALLOWED_SOLVERS to _LM_ALLOWED_ELLIPS_DAMP_SOLVERS.

* Changed nox so that pytest only runs tests marked 'not cudaext'.

* Added batch_size kwarg to LUCudaSparseSolver.
luisenp authored Mar 8, 2022
1 parent 185bf0e commit 89b3d51
Showing 6 changed files with 74 additions and 12 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
@@ -21,4 +21,4 @@ def mypy_and_tests(session):
session.install("-r", "requirements/dev.txt")
session.run("mypy", "theseus")
session.install("-e", ".")
session.run("pytest", "theseus")
session.run("pytest", "theseus", "-m", "not cudaext")
4 changes: 2 additions & 2 deletions theseus/core/objective.py
@@ -462,5 +462,5 @@ def to(self, *args, **kwargs):
      for cost_function in self.cost_functions.values():
          cost_function.to(*args, **kwargs)
      device, dtype, *_ = torch._C._nn._parse_to(*args, **kwargs)
-     self.device = device
-     self.dtype = dtype
+     self.device = device or self.device
+     self.dtype = dtype or self.dtype
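
Note on the `Objective.to` fix: `torch._C._nn._parse_to` mirrors `Tensor.to` argument parsing and returns `None` for any component the caller did not specify, so a call like `objective.to(dtype=torch.double)` previously overwrote `self.device` with `None`. A small sketch of the failure mode and the `or` fallback (assuming the usual 4-tuple return of this private helper):

```python
import torch

# Unspecified parts come back as None from the private parsing helper.
device, dtype, *_ = torch._C._nn._parse_to(dtype=torch.float64)
print(device, dtype)  # None torch.float64

current_device = torch.device("cuda:0")    # stand-in for self.device
current_device = device or current_device  # keeps cuda:0 instead of None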
6 changes: 5 additions & 1 deletion theseus/optimizer/linear/lu_cuda_sparse_solver.py
@@ -21,6 +21,7 @@ def __init__(
      linearization_cls: Optional[Type[Linearization]] = None,
      linearization_kwargs: Optional[Dict[str, Any]] = None,
      num_solver_contexts=1,
+     batch_size: Optional[int] = None,
      **kwargs,
  ):
      if not torch.cuda.is_available():
@@ -39,7 +40,10 @@ def __init__(
      self._num_solver_contexts: int = num_solver_contexts

      if self.linearization.structure().num_rows:
-         self.reset()
+         if batch_size is not None:
+             self.reset(batch_size=batch_size)
+         else:
+             self.reset()

  def reset(self, batch_size: int = 16):
      if not torch.cuda.is_available():
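
Usage note: the new `batch_size` kwarg lets callers pre-allocate the solver contexts for their actual batch size instead of relying on the default `reset(batch_size=16)`. A hedged sketch, mirroring the CUDA test added below (assumes a CUDA device is available):

```python
import torch
import theseus as th

batch_size = 8
objective = th.Objective()
v1 = th.Vector(1, name="v1")
v2 = th.Vector(1, name="v2")
objective.add(th.eb.VariableDifference(v1, th.ScaleCostWeight(1.0), v2))
objective.to(device="cuda", dtype=torch.double)
objective.update(
    {
        "v1": torch.ones(batch_size, 1, device="cuda", dtype=torch.double),
        "v2": torch.zeros(batch_size, 1, device="cuda", dtype=torch.double),
    }
)
# batch_size is forwarded to LUCudaSparseSolver.__init__, which now calls
# reset(batch_size=batch_size) instead of reset() with its default of 16.
optimizer = th.LevenbergMarquardt(
    objective, th.LUCudaSparseSolver, linear_solver_kwargs={"batch_size": batch_size}
)
```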
27 changes: 20 additions & 7 deletions theseus/optimizer/nonlinear/levenberg_marquardt.py
@@ -9,10 +9,20 @@

  from theseus.core import Objective
  from theseus.optimizer import Linearization
- from theseus.optimizer.linear import DenseSolver, LinearSolver
+ from theseus.optimizer.linear import DenseSolver, LinearSolver, LUCudaSparseSolver

  from .nonlinear_least_squares import NonlinearLeastSquares

+ _LM_ALLOWED_ELLIPS_DAMP_SOLVERS = [DenseSolver, LUCudaSparseSolver]
+
+
+ def _check_ellipsoidal_damping_cls(linear_solver: LinearSolver):
+     good = False
+     for lsc in _LM_ALLOWED_ELLIPS_DAMP_SOLVERS:
+         if isinstance(linear_solver, lsc):
+             good = True
+     return good
+
+
  # See Nocedal and Wright, Numerical Optimization, pp. 258 - 261
  # https://www.csie.ntu.edu.tw/~r97002/temp/num_optimization.pdf
@@ -40,6 +50,7 @@ def __init__(
          max_iterations=max_iterations,
          step_size=step_size,
      )
+     self._allows_ellipsoidal = _check_ellipsoidal_damping_cls(self.linear_solver)

  def compute_delta(
      self,
@@ -48,15 +59,17 @@ def compute_delta(
      damping_eps: Optional[float] = None,
      **kwargs,
  ) -> torch.Tensor:
-     if ellipsoidal_damping:
-         raise NotImplementedError("Ellipsoidal damping is not currently supported.")
-     if ellipsoidal_damping and not isinstance(self.linear_solver, DenseSolver):
+
+     solvers_str = ",".join(c.__name__ for c in _LM_ALLOWED_ELLIPS_DAMP_SOLVERS)
+     if ellipsoidal_damping and not self._allows_ellipsoidal:
          raise NotImplementedError(
-             "Ellipsoidal damping is only supported when using DenseSolver."
+             f"Ellipsoidal damping is only supported by solvers with type "
+             f"[{solvers_str}]."
          )
-     if damping_eps and not isinstance(self.linear_solver, DenseSolver):
+     if damping_eps and not self._allows_ellipsoidal:
          raise NotImplementedError(
-             "damping eps is only supported when using DenseSolver."
+             f"damping eps is only supported by solvers with type "
+             f"[{solvers_str}]."
          )
      damping_eps = damping_eps or 1e-8

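
Behavior after this change, sketched below: `ellipsoidal_damping=True` (and a non-default `damping_eps`) is accepted again, but only for solvers whose type appears in `_LM_ALLOWED_ELLIPS_DAMP_SOLVERS`; other solvers raise `NotImplementedError`. A CPU-only sketch using the same tiny objective as the new tests:

```python
import torch
import theseus as th

objective = th.Objective()
v1 = th.Vector(1, name="v1")
v2 = th.Vector(1, name="v2")
objective.add(th.eb.VariableDifference(v1, th.ScaleCostWeight(1.0), v2))
objective.update({"v1": torch.ones(1, 1), "v2": torch.zeros(1, 1)})

# Dense solvers are on the allowed list, so ellipsoidal damping works.
optimizer = th.LevenbergMarquardt(objective, th.CholeskyDenseSolver)
optimizer.optimize(ellipsoidal_damping=True)
optimizer.optimize(damping_eps=0.1)

# CholmodSparseSolver is not on the list; this raises NotImplementedError.
optimizer = th.LevenbergMarquardt(objective, th.CholmodSparseSolver)
try:
    optimizer.optimize(ellipsoidal_damping=True)
except NotImplementedError as e:
    print(e)
```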
45 changes: 44 additions & 1 deletion theseus/optimizer/nonlinear/tests/test_levenberg_marquardt.py
@@ -4,14 +4,24 @@
  # LICENSE file in the root directory of this source tree.

  import pytest  # noqa: F401
+ import torch

  import theseus as th

  from .common import run_nonlinear_least_squares_check


+ @pytest.fixture
+ def mock_objective():
+     objective = th.Objective()
+     v1 = th.Vector(1, name="v1")
+     v2 = th.Vector(1, name="v2")
+     objective.add(th.eb.VariableDifference(v1, th.ScaleCostWeight(1.0), v2))
+     return objective
+
+
  def test_levenberg_marquartd():
-     for ellipsoidal_damping in [False]:
+     for ellipsoidal_damping in [True, False]:
          for damping in [0, 0.001, 0.01, 0.1]:
              run_nonlinear_least_squares_check(
                  th.LevenbergMarquardt,
@@ -22,3 +32,36 @@ def test_levenberg_marquartd():
                  },
                  singular_check=damping < 0.001,
              )


+ def test_ellipsoidal_damping_compatibility(mock_objective):
+     mock_objective.update({"v1": torch.ones(1, 1), "v2": torch.zeros(1, 1)})
+     for lsc in [th.LUDenseSolver, th.CholeskyDenseSolver]:
+         optimizer = th.LevenbergMarquardt(mock_objective, lsc)
+         optimizer.optimize(ellipsoidal_damping=True)
+         optimizer.optimize(damping_eps=0.1)
+
+     for lsc in [th.CholmodSparseSolver]:
+         optimizer = th.LevenbergMarquardt(mock_objective, lsc)
+         with pytest.raises(RuntimeError):
+             optimizer.optimize(ellipsoidal_damping=True)
+         with pytest.raises(RuntimeError):
+             optimizer.optimize(damping_eps=0.1)
+
+
+ @pytest.mark.cuda
+ def test_ellipsoidal_damping_compatibility_cuda(mock_objective):
+     mock_objective.to(device="cuda", dtype=torch.double)
+     batch_size = 2
+     mock_objective.update(
+         {
+             "v1": torch.ones(batch_size, 1, device="cuda", dtype=torch.double),
+             "v2": torch.zeros(batch_size, 1, device="cuda", dtype=torch.double),
+         }
+     )
+     for lsc in [th.LUCudaSparseSolver]:
+         optimizer = th.LevenbergMarquardt(
+             mock_objective, lsc, linear_solver_kwargs={"batch_size": batch_size}
+         )
+         optimizer.optimize(ellipsoidal_damping=True)
+         optimizer.optimize(damping_eps=0.1)
2 changes: 2 additions & 0 deletions theseus/optimizer/sparse_linearization.py
@@ -92,10 +92,12 @@ def _linearize_jacobian_impl(self):
          self.A_val = torch.empty(
              size=(self.objective.batch_size, len(self.A_col_ind)),
              device=self.objective.device,
+             dtype=self.objective.dtype,
          )
          self.b = torch.empty(
              size=(self.objective.batch_size, self.num_rows),
              device=self.objective.device,
+             dtype=self.objective.dtype,
          )

          err_row_idx = 0
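
Note on the `SparseLinearization` fix: when `dtype` is omitted, `torch.empty` falls back to the global default dtype (normally `float32`), so a `torch.double` objective previously got mismatched `A_val`/`b` buffers. A minimal sketch of the failure mode:

```python
import torch

b_old = torch.empty(size=(2, 4))                       # dtype defaults to float32
b_new = torch.empty(size=(2, 4), dtype=torch.float64)  # matches a double objective
print(b_old.dtype, b_new.dtype)  # torch.float32 torch.float64
```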
