-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[JIT] Add basic aliasing checks for tensor inputs #79474
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful links
✅ No Failures (0 Pending)As of commit f17a432 (more details on the Dr. CI page): Expand to see more💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good! Just had one major comment, and a few minor suggestions
after = tree_flatten(arguments.get(name))[0] | ||
if (any([has_mutated(i, j) for i, j in zip(before, after)]) and not argument.is_mutable): | ||
raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated") | ||
for i in before: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very minor nit: i feel like i
is usually the name of an index, maybe rename it b
? totally not important though, you can ignore if you don't want to change it.
# This TorchDispatchTensor Subclass is used to test incorrectly aliasing ops | ||
# This is done by aliasing the output to the first argument for all ops in INCORRECT_OPS | ||
|
||
class IncorrectAliasTensor(torch.Tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome! This could be pretty useful in the future too if we get around to fixing batch_norm (in which case we would need to find a new way to test mutation)
- Added aliasing checks in SchemaCheckMode that raise a RuntimeError if an argument aliases an output that is not specified by the matching schema - Added python bindings for isAliasOf: a method that checks if two tensors are aliases of each other, before_set: which returns the before sets and after_set: which returns the after sets. - Created IncorrectAliasTensor as a subclass of torch.Tensor that purposefully aliases incorrectly - Tested that alias checks in SchemaCheckMode work for both correct and incorrect situations **Note that the cases of inputs aliasing with each other, outputs aliasing with each other, and cases where the before and after sets are not equivalent (this should only occur when the after set is the wildcard set) are not implemented. These cases will be addressed in a later pr.** [ghstack-poisoned]
torch/csrc/jit/python/init.cpp
Outdated
py::call_guard<py::gil_scoped_release>()); | ||
|
||
m.def("_is_alias_of", [](const at::Tensor& self, const at::Tensor& other) { | ||
return IValue(self).isAliasOf(IValue(other)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: any reason why we're converting the tensors to IValues first? There's also a Tensor::is_alias_of
method.
test/jit/test_schema_check.py
Outdated
"instead.") | ||
|
||
# This TorchDispatchTensor Subclass is used to test incorrectly aliasing ops | ||
# This is done by aliasing the output to the first argument for all ops in INCORRECT_OPS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor nit: the comment says that this subclass is used to test incorrectly aliasing ops, but technically it's SchemaCheckMode
that's responsible for that. It looks like the main purpose of this class is to test that SchemaCheckMode
behaves as expected, by "simulating" an op with an incorrect schema and asserting that the mode picks up the error.
Do you mind updating the description a bit?
Can you make it a bit clearer in this comment that the purpose of this class is basically as a test case for schema_check_mode
? It doesn't look like it's act
test/jit/test_schema_check.py
Outdated
|
||
# Tests that an exception is raised for a mismatching alias | ||
def test_alias_check_fail(self): | ||
with self.assertRaises(RuntimeError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: for best practice, can you use self.assertRaisesRegex(RuntimeError, "my expected error message")
?
- Added aliasing checks in SchemaCheckMode that raise a RuntimeError if an argument aliases an output that is not specified by the matching schema - Added python bindings for isAliasOf: a method that checks if two tensors are aliases of each other, before_set: which returns the before sets and after_set: which returns the after sets. - Created IncorrectAliasTensor as a subclass of torch.Tensor that purposefully aliases incorrectly - Tested that alias checks in SchemaCheckMode work for both correct and incorrect situations **Note that the cases of inputs aliasing with each other, outputs aliasing with each other, and cases where the before and after sets are not equivalent (this should only occur when the after set is the wildcard set) are not implemented. These cases will be addressed in a later pr.** [ghstack-poisoned]
- Added aliasing checks in SchemaCheckMode that raise a RuntimeError if an argument aliases an output that is not specified by the matching schema - Added python bindings for isAliasOf: a method that checks if two tensors are aliases of each other, before_set: which returns the before sets and after_set: which returns the after sets. - Created IncorrectAliasTensor as a subclass of torch.Tensor that purposefully aliases incorrectly - Tested that alias checks in SchemaCheckMode work for both correct and incorrect situations **Note that the cases of inputs aliasing with each other, outputs aliasing with each other, and cases where the before and after sets are not equivalent (this should only occur when the after set is the wildcard set) are not implemented. These cases will be addressed in a later pr.** [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! just had one comment, please add that and then you can merge once CI is green
return bool(len(lhs_argument.before_set & rhs_argument.before_set)) | ||
|
||
def unwrap(e): | ||
if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could you add a comment about why you need this? (e.g. to handle IncorrectAliasTensor)
@pytorchbot merge |
@pytorchbot successfully started a merge job. Check the current status here |
Hey @goldenxuett. |
Summary: Pull Request resolved: #79474 Approved by: https://github.com/davidberard98 Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/1432a3d6acc7ac58269e5ea7ef6b2373e9a3dbf6 Reviewed By: malfet Differential Revision: D37278870 Pulled By: goldenxuett fbshipit-source-id: 2316e8ac4e368fedc8c0d6dd9a4ffb28bd33571c
Stack from ghstack (oldest at bottom):
Note that the cases of inputs aliasing with each other, outputs aliasing with each other, and cases where the before and after sets are not equivalent (this should only occur when the after set is the wildcard set) are not implemented. These cases will be addressed in a later pr.