Skip to content

Commit

Permalink
return_and_correct_aliasing: massage some schemas to work with torchg…
Browse files Browse the repository at this point in the history
…en (#108897)

This issue is that `str(torch.ops.aten.conv2d.default._schema)` does not return the same schema that is in native_functions.yaml ([link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L1654)).

Torchscript appears to change the default arg string `int[2] strides=1` to `int[2] strides=[1, 1]`. If you try to parse that with torchgen, torchgen is unhappy (it tries to split arguments on comma, but now we have a comma inside of the default argument).

Fixing the issue directly in torchgen was a bit more painful, so I opted just to undo the transformation that torchscript made: convert `=[1, 1]` back into `=1`.

Pull Request resolved: #108897
Approved by: https://github.com/ezyang
ghstack dependencies: #106404, #107917
  • Loading branch information
bdhirsh authored and pytorchmergebot committed Sep 15, 2023
1 parent 0ad5959 commit 71b4b32
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
10 changes: 10 additions & 0 deletions test/test_python_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2249,6 +2249,16 @@ def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)

def test_wrapper_subclass_aliasing_conv2d(self, device):
args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
kwargs = {}
# conv2d has a default arg 'int[2] strides=0',
# which torchscript expands into 'int[2] strides=[0, 0]'
# Make sure that _return_and_correct_aliasing can handle this case
# (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch)
with torch.inference_mode():
self._test_wrapper_subclass_aliasing(torch.ops.aten.conv2d.default, args, kwargs)

instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())

if __name__ == '__main__':
Expand Down
44 changes: 23 additions & 21 deletions torch/utils/_python_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def alias_non_inplace_storage(arg, ret):
# plain tensors, we could remove the assert and just not perform the aliasing,
# but it seems safer to learn more about this case first.
if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
assert type(arg) == type(ret), f"""Called {str(func)} with input of type {type(arg)}
ret_list = ret if isinstance(ret, list) else [ret]
for r in ret_list:
assert type(arg) == type(r), f"""Called {str(func)} with input of type {type(arg)}
and output of type {type(ret)}. But expected types to match."""
# Need to run under no_dispatch, because we explicitly do **not**
# want our subclass to intercept the set_() call.
Expand All @@ -211,7 +213,12 @@ def alias_non_inplace_storage(arg, ret):
# Example: out = inp.expand(inp.shape[0], inp.shape[0])
# This requires swapping the storage of out to be the same as inp,
# but we do *not* want it to change the sizes/strides that were compute for out.
torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape)
if isinstance(ret, list):
for r in ret:
torch.ops.aten.set_.source_Storage_storage_offset(r, arg.untyped_storage(), r.storage_offset(), r.shape)
else:
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape)
finally:
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)

Expand All @@ -226,23 +233,6 @@ def is_read_only_alias_match(arg, ret):
if is_read_only_alias_match(schema_info.args[arg_idx], schema_info.outs[return_idx]):
alias_non_inplace_storage(args[arg_idx], outs[return_idx])

# Sigh... the torchscript parser has a bug where alias annotations for Tensor[](a) don't show up properly
# See https://github.com/pytorch/pytorch/issues/106173
if func.overloadpacket in [
torch.ops.aten.chunk,
torch.ops.aten.tensor_split,
torch.ops.aten.split,
torch.ops.aten.split_with_sizes,
torch.ops.aten.hsplit,
torch.ops.aten.vsplit,
torch.ops.aten.dsplit,
torch.ops.aten.unbind,
]:
assert isinstance(outs, list) and all(isinstance(x, torch.Tensor) for x in outs)
for o in outs:
# For lists of outputs, need to alias every individual tensor to the input
alias_non_inplace_storage(args[0], o)

# This abstracts over the fact that in return_and_correct_aliasing,
# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
Expand All @@ -267,7 +257,19 @@ def get_alias_info(func) -> SchemaInfo:
# For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
# properly for some ops that output tensorlists)
if func.namespace == "aten":
torchgen_schema = torchgen.model.FunctionSchema.parse(str(func._schema))
torchgen_schema_str = str(func._schema)
assert torchgen_schema_str.startswith("aten::")
# remove the aten:: namespace, which is added by the torchscript parser,
# and torchgen doesn't know how to handle
torchgen_schema_str = torchgen_schema_str[6:]
import re
# the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1],
# which torchgen chokes on.
torchgen_schema_str = re.sub(r'=\[[0, ]+\]', '=0', torchgen_schema_str)
torchgen_schema_str = re.sub(r'=\[[1, ]+\]', '=1', torchgen_schema_str)
# for aten::rot90
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
arg_schemas = [AliasInfo(
alias_set=set() if a.annotation is None else set(a.annotation.alias_set),
is_write=a.annotation is not None and a.annotation.is_write
Expand Down Expand Up @@ -331,7 +333,7 @@ def get_arg_idx_from_alias(output_alias):

# Fix up the storages of any outs so that they point to the same storage as the input,
# if func is a view op.
_correct_storage_aliasing(func, schema_info, args, [out] if not isinstance(out, (list, tuple)) else out)
_correct_storage_aliasing(func, schema_info, args, (out,) if not isinstance(out, tuple) else out)

# For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
# metadata is set correctly.
Expand Down

0 comments on commit 71b4b32

Please sign in to comment.