Vectorization refactor (facebookresearch#205)
* Created a wrapper cost function class that combines the aux vars for a cost function and its weight.

* Disabled support for optimization variables in cost weights.

* Changed Objective to iterate over CFWrappers if available, and TheseusLayer to create them at init time.

* Added a Vectorizer class and moved CFWrappers there.

* Renamed the vectorizer class to Vectorize, added logic to replace the Objective iterator, and added it to TheseusLayer (a usage sketch follows this list).

* Added a CostFunctionSchema -> List[CostFunction] mapping to use for vectorization grouping.

* _CostFunctionWrapper is now meant to just store a cached value coming from vectorization. Otherwise it works just like the cost function it wraps.

* Added code to automatically compute shared vars in Vectorize.

* Changed the construction of vectorized costs to ensure that their weight is also a copy.

* Implemented main cost function vectorization logic.

* Fixed a bug that was causing detached gradients.

* Fixed invalid check in theseus end-to-end unit tests.

* Added unit test for schema and shared var computation.

* Added a test to check that computed vectorized errors are correct.

* Moved vectorization update call to base linearization class.

* Changed code to allow batch_size > 1 in shared variables.

* Fixed a unit test and added a call to Objective.update() in update_vectorization() when batch_size is None.

* Added new private iterator for vectorized costs.

* Replaced _register_vars_in_list with TheseusFunction.register_vars.

* Renamed vectorize_cost_fns kwarg as vectorize.

* Added license headers.

* Small refactor.

* Fixed a bug that was preventing vectorized costs from working with to(). End-to-end tests now run with vectorize=True.

* Renamed the private Objective cost function iterator to _get_iterator().

* Renamed kwarg in register_vars.

* Set vectorize=True for inverse kinematics and backward tests.

* Removed lingering comments.
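
For reference, a minimal usage sketch of the entry points described above. This is a sketch only: the Difference argument order follows the tests in this diff, the vectorize kwarg name comes from the rename noted above, and the GaussNewton/max_iterations setup is illustrative rather than taken from this commit.

import torch
import theseus as th

# Two Difference costs that share the same weight and target, so the
# vectorizer can group them and evaluate them in a single batched call.
v1 = th.Vector(data=torch.randn(1, 2), name="v1")
v2 = th.Vector(data=torch.randn(1, 2), name="v2")
target = th.Vector(data=torch.zeros(1, 2), name="target")
w = th.ScaleCostWeight(1.0)

objective = th.Objective()
objective.add(th.Difference(v1, w, target))
objective.add(th.Difference(v2, w, target))

# Option A: attach vectorization to an objective directly:
#     th.Vectorize(objective)
# Option B: request it from TheseusLayer at construction time via the
# renamed vectorize kwarg (the optimizer choice here is illustrative).
optimizer = th.GaussNewton(objective, max_iterations=10)
layer = th.TheseusLayer(optimizer, vectorize=True)
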


Co-authored-by: Taosha Fan <6612911+fantaosha@users.noreply.github.com>
luisenp and fantaosha authored Jun 9, 2022
1 parent ca01b29 commit e5405fd
Showing 16 changed files with 574 additions and 50 deletions.
1 change: 1 addition & 0 deletions theseus/__init__.py
@@ -11,6 +11,7 @@
Objective,
ScaleCostWeight,
Variable,
Vectorize,
)
from .geometry import (
SE2,
1 change: 1 addition & 0 deletions theseus/core/__init__.py
@@ -7,3 +7,4 @@
from .cost_weight import CostWeight, DiagonalCostWeight, ScaleCostWeight
from .objective import Objective
from .variable import Variable
from .vectorizer import Vectorize
14 changes: 2 additions & 12 deletions theseus/core/cost_function.py
@@ -111,22 +111,12 @@ def __init__(
# this avoids doing aux_vars=[], which is a bad default since [] is mutable
aux_vars = aux_vars or []

def _register_vars_in_list(var_list_, is_optim=False):
for var_ in var_list_:
if hasattr(self, var_.name):
raise RuntimeError(f"Variable name {var_.name} is not allowed.")
setattr(self, var_.name, var_)
if is_optim:
self.register_optim_var(var_.name)
else:
self.register_aux_var(var_.name)

if len(optim_vars) < 1:
raise ValueError(
"AutodiffCostFunction must receive at least one optimization variable."
)
_register_vars_in_list(optim_vars, is_optim=True)
_register_vars_in_list(aux_vars, is_optim=False)
self.register_vars(optim_vars, is_optim_vars=True)
self.register_vars(aux_vars, is_optim_vars=False)

self._err_fn = err_fn
self._dim = dim
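
For context, the removed local helper above indicates what the shared TheseusFunction.register_vars presumably does. The following is a hedged reconstruction inferred from that removed code and the new call sites, not the actual implementation:

# Hedged sketch of TheseusFunction.register_vars, inferred from the removed
# _register_vars_in_list helper and the is_optim_vars call sites above.
def register_vars(self, vars_, is_optim_vars: bool = False):
    for var in vars_:
        if hasattr(self, var.name):
            raise RuntimeError(f"Variable name {var.name} is not allowed.")
        # Expose the variable as an attribute of the function ...
        setattr(self, var.name, var)
        # ... and record it as an optimization or auxiliary variable.
        if is_optim_vars:
            self.register_optim_var(var.name)
        else:
            self.register_aux_var(var.name)
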
31 changes: 26 additions & 5 deletions theseus/core/objective.py
@@ -5,7 +5,7 @@

import warnings
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union

import torch

@@ -61,6 +61,15 @@ def __init__(self, dtype: Optional[torch.dtype] = None):
# objective structure might break optimizer initialization).
self.current_version = 0

# ---- Callbacks for vectorization ---- #
# This gets replaced when cost function vectorization is used
self._cost_functions_iterable: Optional[Iterable[CostFunction]] = None

# Used to vectorize cost functions after update
self._vectorization_run: Optional[Callable] = None

self._vectorization_to: Optional[Callable] = None

def _add_function_variables(
self,
function: TheseusFunction,
@@ -160,15 +169,14 @@ def add(self, cost_function: CostFunction):
self.cost_functions_for_weights[cost_function.weight] = []

if cost_function.weight.num_optim_vars() > 0:
warnings.warn(
raise RuntimeError(
f"The cost weight associated to {cost_function.name} receives one "
"or more optimization variables. Differentiating cost "
"weights with respect to optimization variables is not currently "
"supported, thus jacobians computed by our optimizers will be "
"incorrect. You may want to consider moving the weight computation "
"inside the cost function, so that the cost weight only receives "
"auxiliary variables.",
RuntimeWarning,
"auxiliary variables."
)

self.cost_functions_for_weights[cost_function.weight].append(cost_function)
@@ -471,9 +479,20 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
batch_sizes.extend([v.data.shape[0] for v in self.aux_vars.values()])
self._batch_size = _get_batch_size(batch_sizes)

def update_vectorization(self):
if self._vectorization_run is not None:
if self._batch_size is None:
self.update()
self._vectorization_run()

# iterates over cost functions
def __iter__(self):
return iter([f for f in self.cost_functions.values()])
return iter([cf for cf in self.cost_functions.values()])

def _get_iterator(self):
if self._cost_functions_iterable is None:
return iter([cf for cf in self.cost_functions.values()])
return iter([cf for cf in self._cost_functions_iterable])

# Applies to() with given args to all tensors in the objective
def to(self, *args, **kwargs):
@@ -482,3 +501,5 @@ def to(self, *args, **kwargs):
device, dtype, *_ = torch._C._nn._parse_to(*args, **kwargs)
self.device = device or self.device
self.dtype = dtype or self.dtype
if self._vectorization_to is not None:
self._vectorization_to(*args, **kwargs)
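
To illustrate how these new private hooks fit together, here is a minimal sketch of a helper that installs itself into an Objective. The class name and its trivial behavior are assumptions for illustration; the actual wiring lives in th.Vectorize.

import theseus as th

class NaiveVectorizationHook:
    # Illustrative stand-in for th.Vectorize: wires up the Objective callbacks.
    def __init__(self, objective: th.Objective):
        self._objective = objective
        # A real implementation wraps cost functions in _CostFunctionWrapper
        # objects grouped by schema; here we simply reuse the originals.
        self._wrappers = list(objective.cost_functions.values())
        # Optimizers and linearization will now iterate over these wrappers.
        objective._cost_functions_iterable = self._wrappers
        # Called by Objective.update_vectorization() after variable updates.
        objective._vectorization_run = self._run
        # Called by Objective.to() so any cached tensors follow device/dtype.
        objective._vectorization_to = self._to

    def _run(self):
        # A real implementation batches error/jacobian computation per schema
        # and caches the results in the wrappers.
        for cf in self._wrappers:
            cf.weighted_error()

    def _to(self, *args, **kwargs):
        # Nothing is cached in this naive version, so there is nothing to move.
        pass
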
4 changes: 2 additions & 2 deletions theseus/core/tests/common.py
@@ -103,10 +103,10 @@ def __init__(

def error(self):
mu = torch.stack([v.data for v in self.optim_vars]).sum()
return mu * torch.ones(self._dim)
return mu * torch.ones(1, self._dim)

def jacobians(self):
return [self.error()] * len(self._optim_vars_attr_names)
return [torch.ones(1, self._dim, self._dim)] * len(self._optim_vars_attr_names)

def dim(self) -> int:
return self._dim
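
The shape fix above matches the batched conventions used elsewhere in the library. As a quick reference, assuming batch size B, error dimension D, and a mock variable whose dof equals D:

import torch

B, D = 4, 2
# CostFunction.error() returns a (batch_size, dim) tensor ...
error = torch.ones(B, D)
# ... and jacobians() returns one (batch_size, dim, dof) tensor per optim var.
jacobians = [torch.ones(B, D, D)]
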
26 changes: 9 additions & 17 deletions theseus/core/tests/test_objective.py
@@ -75,7 +75,7 @@ def _create_cost_function_with_n_vars_and_m_aux(
yet_another_cost_function = _create_cost_function_with_n_vars_and_m_aux(
"yet_another", ["yet_var_1"], ["yet_aux_1"], cost_weight
)
with pytest.warns(RuntimeWarning): # optim var associated to weight
with pytest.raises(RuntimeError): # optim var associated to weight
objective.add(yet_another_cost_function) # no conflict here

cost_weight_with_conflict_in_aux_var = MockCostWeight(
@@ -151,24 +151,19 @@ def test_add_and_erase_step_by_step():
var3 = MockVar(1, data=None, name="var3")
aux1 = MockVar(1, data=None, name="aux1")
aux2 = MockVar(1, data=None, name="aux2")
cw1 = MockCostWeight(aux1, name="cw1", add_dummy_var_with_name="ignored_optim_var")
cw2 = MockCostWeight(aux2, name="cw2", add_optim_var=var1)
cw1 = MockCostWeight(aux1, name="cw1")
cw2 = MockCostWeight(aux2, name="cw2")

cf1 = MockCostFunction([var1, var2], [aux1, aux2], cw1, name="cf1")
cf2 = MockCostFunction([var1, var3], [aux1], cw2, name="cf2")
cf3 = MockCostFunction([var2, var3], [aux2], cw2, name="cf3")

objective = th.Objective()
for cost_function in [cf1, cf2, cf3]:
if cost_function is not cf3:
with pytest.warns(RuntimeWarning):
# a warning should emit the first time cw1/cw2 are added
objective.add(cost_function)
else:
objective.add(cost_function)
objective.add(cost_function)

for name in ["var1", "ignored_optim_var"]:
assert name in objective.cost_weight_optim_vars
for name in ["var1", "var2", "var3"]:
assert name in objective.optim_vars
for name in ["aux1", "aux2"]:
@@ -207,13 +202,10 @@ def _check_all_vars(v1_lis_, v2_lis_, v3_lis_, cw1_opt_lis_, a1_lis_, a2_lis_):
_check_funs_for_variable(var1, v1_lis_)
_check_funs_for_variable(var2, v2_lis_)
_check_funs_for_variable(var3, v3_lis_)
_check_funs_for_variable(
cw1.optim_var_at(0), cw1_opt_lis_, is_cost_weight_optim=True
)
_check_funs_for_variable(aux1, a1_lis_, optim_var=False)
_check_funs_for_variable(aux2, a2_lis_, optim_var=False)

v1_lis = [cw2, cf1, cf2]
v1_lis = [cf1, cf2]
v2_lis = [cf1, cf3]
v3_lis = [cf2, cf3]
cw1o_lis = [cw1]
@@ -223,7 +215,7 @@ def _check_all_vars(v1_lis_, v2_lis_, v3_lis_, cw1_opt_lis_, a1_lis_, a2_lis_):
_check_all_vars(v1_lis, v2_lis, v3_lis, cw1o_lis, a1_lis, a2_lis)

objective.erase("cf1")
v1_lis = [cw2, cf2]
v1_lis = [cf2]
v2_lis = [cf3]
cw1o_lis = []
a1_lis = [cf2] # cf1 and cw1 are deleted, since cw1 not used by any other cost fn
@@ -232,7 +224,7 @@ def _check_all_vars(v1_lis_, v2_lis_, v3_lis_, cw1_opt_lis_, a1_lis_, a2_lis_):
assert cw1 not in objective.cost_functions_for_weights

objective.erase("cf2")
v1_lis = [cw2] # cw2 still used by cf3
v1_lis = []
v3_lis = [cf3]
a1_lis = []
_check_all_vars(v1_lis, v2_lis, v3_lis, cw1o_lis, a1_lis, a2_lis)
194 changes: 194 additions & 0 deletions theseus/core/tests/test_vectorizer.py
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

import theseus as th
from theseus.core.vectorizer import _CostFunctionWrapper


def test_costs_vars_and_err_before_vectorization():
for _ in range(20):
objective = th.Objective()
batch_size = torch.randint(low=1, high=10, size=(1,)).item()
v1 = th.Vector(data=torch.randn(batch_size, 1), name="v1")
v2 = th.Vector(data=torch.randn(batch_size, 1), name="v2")
odummy = th.Vector(1, name="odummy")
t1 = th.Vector(data=torch.zeros(1, 1), name="t1")
adummy = th.Variable(data=torch.zeros(1, 1), name="adummy")
cw1 = th.ScaleCostWeight(th.Variable(torch.zeros(1, 1), name="w1"))
cw2 = th.ScaleCostWeight(th.Variable(torch.zeros(1, 1), name="w2"))
cf1 = th.Difference(v1, cw1, t1)

# Also test with autodiff cost
def err_fn(optim_vars, aux_vars):
return optim_vars[0] - aux_vars[0]

cf2 = th.AutoDiffCostFunction([v2, odummy], err_fn, 1, cw2, [t1, adummy])

# Check that the vectorizer has the correct number of wrappers
objective.add(cf1)
objective.add(cf2)
th.Vectorize(objective)

# Update weights after creating vectorizer to see if data is picked up correctly
w1 = torch.randn(1, 1) # also check that broadcasting works
w2 = torch.randn(batch_size, 1)

# disable for this test since we are not checking the result
objective._vectorization_run = None
objective.update({"w1": w1, "w2": w2})

def _check_attr(cf, var):
return hasattr(cf, var.name) and getattr(cf, var.name) is var

# Check that the vectorizer's cost functions have the right variables and error
saw_cf1 = False
saw_cf2 = False
for cf in objective._get_iterator():
assert isinstance(cf, _CostFunctionWrapper)
optim_vars = [v for v in cf.optim_vars]
aux_vars = [v for v in cf.aux_vars]
assert t1 in aux_vars
assert _check_attr(cf, t1)
w_err = cf.weighted_error()
if cf.cost_fn is cf1:
assert v1 in optim_vars
assert w_err.allclose((v1.data - t1.data) * w1)
assert _check_attr(cf, v1)
saw_cf1 = True
elif cf.cost_fn is cf2:
assert v2 in optim_vars and odummy in optim_vars
assert adummy in aux_vars
assert _check_attr(cf, v2) and _check_attr(cf, odummy)
assert w_err.allclose((v2.data - t1.data) * w2)
saw_cf2 = True
else:
assert False
assert saw_cf1 and saw_cf2


def test_correct_schemas_and_shared_vars():
v1 = th.Vector(1)
v2 = th.Vector(1)
tv = th.Vector(1)
w1 = th.ScaleCostWeight(1.0)
mv = th.Vector(1)

v3 = th.Vector(3)
v4 = th.Vector(3)

s1 = th.SE2()
s2 = th.SE2()
ts = th.SE2()

objective = th.Objective()
# these two can be grouped
cf1 = th.Difference(v1, w1, tv)
cf2 = th.Difference(v2, w1, tv)
objective.add(cf1)
objective.add(cf2)

# this one uses the same weight and v1, v2, but cannot be grouped
cf3 = th.Between(v1, v2, w1, mv)
objective.add(cf3)

# this one is the same cost function type, var type, and weight but different
# dimension, so cannot be grouped either
cf4 = th.Difference(v3, w1, v4)
objective.add(cf4)

# Now add another group with a different data type (non-shared weights)
w2 = th.ScaleCostWeight(1.0)
w3 = th.ScaleCostWeight(2.0)
cf5 = th.Difference(s1, w2, ts)
cf6 = th.Difference(s2, w3, ts)
objective.add(cf5)
objective.add(cf6)

# Not grouped with cf1 and cf2 because the weight type is different
w7 = th.DiagonalCostWeight([1.0])
cf7 = th.Difference(v1, w7, tv)
objective.add(cf7)

vectorization = th.Vectorize(objective)

assert len(vectorization._schema_dict) == 5
seen_cnt = [0] * 7
for schema, cost_fn_wrappers in vectorization._schema_dict.items():
cost_fns = [w.cost_fn for w in cost_fn_wrappers]
var_names = vectorization._var_names[schema]
if cf1 in cost_fns:
assert len(cost_fns) == 2
assert cf2 in cost_fns
seen_cnt[0] += 1
seen_cnt[1] += 1
assert f"{th.Vectorize._SHARED_TOKEN}{w1.scale.name}" in var_names
assert f"{th.Vectorize._SHARED_TOKEN}{tv.name}" in var_names
if cf3 in cost_fns:
assert len(cost_fns) == 1
seen_cnt[2] += 1
if cf4 in cost_fns:
assert len(cost_fns) == 1
seen_cnt[3] += 1
if cf5 in cost_fns:
assert len(cost_fns) == 2
assert cf6 in cost_fns
seen_cnt[4] += 1
seen_cnt[5] += 1
assert f"{th.Vectorize._SHARED_TOKEN}{w2.scale.name}" not in var_names
assert f"{th.Vectorize._SHARED_TOKEN}{w3.scale.name}" not in var_names
assert f"{th.Vectorize._SHARED_TOKEN}{ts.name}" in var_names
if cf7 in cost_fns:
assert len(cost_fns) == 1
seen_cnt[6] += 1
assert seen_cnt == [1] * 7


def test_vectorized_error():
rng = np.random.default_rng(0)
generator = torch.Generator()
generator.manual_seed(0)
for _ in range(20):
dim = rng.choice([1, 2])
objective = th.Objective()
batch_size = rng.choice(range(1, 11))

vectors = [
th.Vector(
data=torch.randn(batch_size, dim, generator=generator), name=f"v{i}"
)
for i in range(rng.choice([1, 10]))
]
target = th.Vector(dim, name="target")
w = th.ScaleCostWeight(torch.randn(1, generator=generator))
for v in vectors:
objective.add(th.Difference(v, w, target))

se3s = [
th.SE3(
data=th.SE3.rand(batch_size, generator=generator).data,
requires_check=False,
)
for i in range(rng.choice([1, 10]))
]
s_target = th.SE3.rand(1, generator=generator)
ws = th.DiagonalCostWeight(torch.randn(6, generator=generator))
# ws = th.ScaleCostWeight(torch.randn(1, generator=generator))
for s in se3s:
objective.add(th.Difference(s, ws, s_target))

vectorization = th.Vectorize(objective)
objective.update_vectorization()

assert objective._cost_functions_iterable is vectorization._cost_fn_wrappers
for w in vectorization._cost_fn_wrappers:
for cost_fn in objective.cost_functions.values():
if cost_fn is w.cost_fn:
w_jac, w_err = cost_fn.weighted_jacobians_error()
assert w._cached_error.allclose(w_err)
for jac, exp_jac in zip(w._cached_jacobians, w_jac):
assert jac.allclose(exp_jac, atol=1e-6)
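
The schema test above suggests that cost functions are grouped by a signature built from the cost function type, its variables' types and dimensions, and the cost weight type. The sketch below shows one way such a grouping key could be computed; the key format and helper names are assumptions, not the actual CostFunctionSchema used by th.Vectorize.

from collections import defaultdict

import theseus as th

def hypothetical_schema(cost_fn: th.CostFunction):
    # Assumed grouping key: cost function class, optim var classes and dofs,
    # aux var classes, and the cost weight class.
    optim = tuple((type(v).__name__, v.dof()) for v in cost_fn.optim_vars)
    aux = tuple(type(v).__name__ for v in cost_fn.aux_vars)
    return (type(cost_fn).__name__, optim, aux, type(cost_fn.weight).__name__)

def group_by_schema(objective: th.Objective):
    groups = defaultdict(list)
    for cf in objective:  # Objective.__iter__ yields its cost functions
        groups[hypothetical_schema(cf)].append(cf)
    return groups
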