General in-place binary op support in dynamo (#94203)
Continues the approach taken in #93271, expanding support to in-place binary ops (e.g. `__iadd__`).

Pull Request resolved: #94203
Approved by: https://github.com/ezyang
jbschlosser authored and pytorchmergebot committed Feb 7, 2023
1 parent f954498 commit bf4fe5d
Showing 4 changed files with 148 additions and 57 deletions.
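For context (not part of the diff), a minimal sketch of the kind of user code this change lets dynamo trace in a single graph: in-place arithmetic on ints taken from tensor shapes, and `+=` on lists of tensors. It assumes a PyTorch build with `torch._dynamo` available and uses the "eager" backend, which simply runs the captured graph:

# Illustrative only -- not from the commit.
import torch
import torch._dynamo

def fn(x):
    p = x.shape[0]
    p += 2            # in-place int op on a value derived from a shape
    p //= 3
    out = [x + p]
    out += [x - p]    # list += list, now handled by the builtin handlers
    return out

opt_fn = torch._dynamo.optimize("eager")(fn)
print(opt_fn(torch.randn(4)))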
44 changes: 44 additions & 0 deletions test/dynamo/test_misc.py
@@ -171,6 +171,22 @@ def fn(x):
self, fn, 1, expected_ops=1, expected_ops_dynamic=11
)

def test_int_shape_inplace_binops(self):
def fn(x):
p = x.shape[0]
p += 2
p -= 2
p **= 2
p /= 2
p *= 2
p //= 2
p %= 2
return x + p

torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=10
)

def test_param_shape_binops(self):
class MyModule(torch.nn.Module):
def __init__(self):
@@ -780,6 +796,34 @@ def fn(a):
self, fn, 1, expected_ops=1, expected_ops_dynamic=3
)

def test_tuple_iadd_with_shape(self):
def fn(a):
output = (a + a.shape[0], a - a.shape[0])
# tuple += tuple
output += (a - a.shape[0], a + a.shape[0])
# tuple += constant tuple
output += (2, 3)
return output

# expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=4, expected_ops_dynamic=12
)

def test_list_iadd_with_shape(self):
def fn(a):
output = [a + a.shape[0], a - a.shape[0]]
# list += list
output += [a - a.shape[0], a + a.shape[0]]
# list += tuple
output += (a + a.shape[0], a - a.shape[0])
return output

# expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=6, expected_ops_dynamic=18
)

def test_user_getattr1(self):
class MyConfig(dict):
def __getattr__(self, name):
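As a rough sketch of what `standard_test` asserts in the tests above (assuming the existing `CompileCounter` backend from `torch._dynamo.testing`; the exact counts depend on whether dynamic shapes are enabled), the op counting can be reproduced by hand:

# Sketch only: count captured ops for the in-place shape example.
import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

def fn(x):
    p = x.shape[0]
    p += 2
    p *= 2
    return x + p

cnt = CompileCounter()
opt_fn = torch._dynamo.optimize(cnt)(fn)
opt_fn(torch.randn(3))
print(cnt.frame_count, cnt.op_count)  # with static shapes: 1 frame, 1 op (the add)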
5 changes: 0 additions & 5 deletions torch/__init__.py
@@ -243,11 +243,6 @@ def __bool__(self):
def __int__(self):
return self.node.int_()

# This is a hack, shouldn't be necessary. Helps
# pyhpc_turbulent_kinetic_energy and vision_maskrcnn
def __iadd__(self, other):
return self + other

# Magic methods installed by torch.fx.experimental.symbolic_shapes

def __eq__(self, other: object) -> builtins.bool:
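One way to read the removal above (an inference, not stated in the diff): with `operator.iadd` now covered by the builtin handlers in `builtin.py`, `SymInt` no longer needs a hand-written `__iadd__`, because plain Python already falls back to `__add__` when `__iadd__` is absent:

# Standard Python semantics, for reference: `x += y` uses __iadd__ if the
# type defines it, otherwise it calls __add__ and rebinds the name.
class Sym:
    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        return Sym(self.v + other)

s = Sym(3)
s += 2           # no __iadd__ -> Sym.__add__ runs, s is rebound to the result
print(s.v)       # 5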
141 changes: 103 additions & 38 deletions torch/_dynamo/variables/builtin.py
@@ -164,6 +164,27 @@ def _reversible_binops():
}
return fns

@staticmethod
@functools.lru_cache(None)
def _inplace_binops():
fns = {
operator.ipow: "__ipow__",
operator.imul: "__imul__",
operator.imatmul: "__imatmul__",
operator.ifloordiv: "__ifloordiv__",
operator.itruediv: "__itruediv__",
operator.imod: "__imod__",
operator.iadd: "__iadd__",
operator.iconcat: "__iconcat__",
operator.isub: "__isub__",
operator.ilshift: "__ilshift__",
operator.irshift: "__irshift__",
operator.iand: "__iand__",
operator.ixor: "__ixor__",
operator.ior: "__ior__",
}
return fns

@staticmethod
@functools.lru_cache(None)
def _binop_handlers():
@@ -174,34 +174,49 @@ def _binop_handlers():

# Override table contains: op_fn -> [list of handlers]
op_handlers = {}
for (
op,
(forward_name, reverse_name),
) in BuiltinVariable._reversible_binops().items():
for (op, magic_method_names) in itertools.chain(
BuiltinVariable._inplace_binops().items(),
BuiltinVariable._reversible_binops().items(),
):
handlers = []

# User-defined args (highest precedence)
def user_defined_handler(
tx, a, b, options, forward_name=forward_name, reverse_name=reverse_name
):
# Manually handle reversing logic if needed (e.g. call __radd__)

# TODO: If we expand this to handle tensor args, we need to manually
# handle cases like this:
#
# class A(int):
# def __radd__(self, other):
# print("woof")
# torch.randn(3) + A(3)
#
# In this example, A.__radd__() is not called -> nothing is printed, because
# Tensor.__add__ only does a subtype test against int and will ignore the subclass.
# To be fully correct, we should not call A.__radd__() here, and there may be
# other cases to reason about and add exceptions for.
if isinstance(a, UserDefinedVariable):
if isinstance(magic_method_names, tuple):
# Reversible binary ops have forward / backward magic methods
forward_name, reverse_name = magic_method_names

def user_defined_handler(
tx,
a,
b,
options,
forward_name=forward_name,
reverse_name=reverse_name,
):
# Manually handle reversing logic if needed (e.g. call __radd__)

# TODO: If we expand this to handle tensor args, we need to manually
# handle cases like this:
#
# class A(int):
# def __radd__(self, other):
# print("woof")
# torch.randn(3) + A(3)
#
# In this example, A.__radd__() is not called -> nothing is printed, because
# Tensor.__add__ only does a subtype test against int, ignoring the subclass.
# To be fully correct, we should not call A.__radd__() here, and there may be
# other cases to reason about and add exceptions for.
if isinstance(a, UserDefinedVariable):
return a.call_method(tx, forward_name, [b], {})
else:
return b.call_method(tx, reverse_name, [a], {})

else:
forward_name = magic_method_names

def user_defined_handler(tx, a, b, options, forward_name=forward_name):
return a.call_method(tx, forward_name, [b], {})
else:
return b.call_method(tx, reverse_name, [a], {})

handlers.append(
((UserDefinedVariable, VariableTracker), user_defined_handler)
@@ -230,33 +230,64 @@ def dynamic_handler(tx, a, b, options, fn=op):
# Special cases - lower precedence but still prefer these over constant folding

# List-like addition (e.g. [1, 2] + [3, 4])
def tuple_add_handler(tx, a, b, options):
return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)

list_like_addition_handlers = [
# NB: Prefer the tuple-specific logic over base logic because of
# some SizeVariable weirdness. Specifically, the tuple-specific logic
# drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
(
(TupleVariable, TupleVariable),
tuple_add_handler,
),
(
(TupleVariable, ConstantVariable),
lambda tx, a, b, options: TupleVariable(
a.items + list(b.unpack_var_sequence(tx)), **options
),
tuple_add_handler,
),
(
(ConstantVariable, TupleVariable),
lambda tx, a, b, options: TupleVariable(
list(a.unpack_var_sequence(tx)) + b.items, **options
),
),
(
(TupleVariable, TupleVariable),
lambda tx, a, b, options: TupleVariable(a.items + b.items, **options),
),
(
(BaseListVariable, BaseListVariable),
lambda tx, a, b, options: type(a)(a.items + b.items, **options),
),
]
op_handlers[operator.add].extend(list_like_addition_handlers)

def list_iadd_handler(tx, a, b, options):
if not a.mutable_local or not b.has_unpack_var_sequence(tx):
# Handler doesn't apply
return None

return tx.replace_all(
a,
ListVariable(
list(a.items) + list(b.unpack_var_sequence(tx)),
regen_guards=False,
**options,
),
)

list_like_iadd_handlers = [
(
(ListVariable, VariableTracker),
list_iadd_handler,
),
(
(TupleVariable, TupleVariable),
tuple_add_handler,
),
(
(TupleVariable, ConstantVariable),
tuple_add_handler,
),
]
op_handlers[operator.iadd].extend(list_like_iadd_handlers)

# List-like expansion (e.g. [1, 2, 3] * 3)
def expand_list_like(tx, lst, const, options):
return lst.__class__(
@@ -466,18 +466,19 @@ def call_function(
)
return out

# Handle functions that are reversible (e.g. __add__ / __radd__)
# Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
# NB: Tensor args are handled above and not here
reversible_binops = self._reversible_binops()
if self.fn in reversible_binops:
if self.fn in self._reversible_binops() or self.fn in self._inplace_binops():
assert len(kwargs) == 0 and len(args) == 2

# Try to find a handler for the arg types; otherwise, fall through to constant handler
binop_handler = BuiltinVariable._find_binop_handler(
self.fn, args[0], args[1]
)
if binop_handler:
return binop_handler(tx, args[0], args[1], options)
res = binop_handler(tx, args[0], args[1], options)
if res is not None:
return res

handler = getattr(self, f"call_{self.fn.__name__}", None)
if handler:
@@ -696,9 +764,6 @@ def call_enumerate(self, tx, *args):
def call_len(self, tx, *args, **kwargs):
return args[0].call_method(tx, "__len__", args[1:], kwargs)

def call_iadd(self, tx, *args, **kwargs):
return args[0].call_method(tx, "__iadd__", args[1:], kwargs)

def call_getitem(self, tx, *args, **kwargs):
if self.unspec_python_args(*args, **kwargs):
args, kwargs = specialize_args_kwargs(tx, args, kwargs)
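The precedence scheme above, in miniature (a standalone toy, not dynamo's actual code; names are illustrative): handlers for an operator are tried per argument type, a handler may return None to signal "doesn't apply" — which is why `call_function` now checks `res is not None` — and the caller falls through to simply running the op:

# Toy sketch of the handler-table dispatch used by _binop_handlers().
import operator

def list_iadd_handler(a, b):
    if not isinstance(a, list):
        return None              # decline; let a lower-precedence handler try
    a.extend(b)                  # lists mutate in place
    return a

def tuple_add_handler(a, b):
    if not isinstance(a, tuple):
        return None
    return a + tuple(b)          # tuples always produce a new object

HANDLERS = {operator.iadd: [list_iadd_handler, tuple_add_handler]}

def apply_binop(op, a, b):
    for handler in HANDLERS.get(op, []):
        res = handler(a, b)
        if res is not None:
            return res
    return op(a, b)              # fallback: just execute the operator

print(apply_binop(operator.iadd, [1, 2], (3, 4)))   # [1, 2, 3, 4]
print(apply_binop(operator.iadd, (1, 2), [3, 4]))   # (1, 2, 3, 4)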
15 changes: 1 addition & 14 deletions torch/_dynamo/variables/lists.py
@@ -223,7 +223,7 @@ def call_method(
)
return ConstantVariable(None)
elif (
name in ("extend", "__iadd__")
name == "extend"
and self.mutable_local
and args
and args[0].has_unpack_var_sequence(tx)
@@ -296,19 +296,6 @@ def call_method(
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
options = VariableTracker.propagate(self, args, kwargs.values())
if name == "__iadd__" and len(args) == 1 and isinstance(args[0], TupleVariable):
assert not kwargs
return TupleVariable(self.items + args[0].items, **options)
elif (
name == "__iadd__"
and len(args) == 1
and isinstance(args[0], variables.ConstantVariable)
):
assert not kwargs
return TupleVariable(
self.items + list(args[0].unpack_var_sequence(self)), **options
)
return super().call_method(tx, name, args, kwargs)


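Background on why both the list and tuple `__iadd__` special cases can move out of `lists.py` into the builtin handlers, yet end up handled differently (plain Python behavior, not from the diff): `+=` mutates a list in place, so dynamo must swap the tracked variable via `tx.replace_all`, whereas `+=` on a tuple always builds a new object, so `tuple_add_handler` can simply return a fresh `TupleVariable`:

# Plain-Python reminder of the list/tuple difference under `+=`.
xs = [1, 2]
alias = xs
xs += [3]
print(alias)      # [1, 2, 3] -- the same list object was mutated in place

t = (1, 2)
alias_t = t
t += (3,)
print(alias_t)    # (1, 2) -- original tuple untouched; t was rebound to a new one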
