-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PyTorch tests - 6/n - Add type check for list/tuple elems in CONSTANT_MATCH #303
Conversation
torchdynamo/guards.py
Outdated
@@ -257,6 +263,18 @@ def LIST_LENGTH(self, guard): | |||
self.code.append(f"___check_type_id({ref}, {self.id_ref(type(value))})") | |||
self.code.append(f"len({ref}) == {len(value)}") | |||
|
|||
def add_guard_recursively(prefix_index_str, lst): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we are going to add this to the LIST_LENGTH
guard, then we should rename it to something else since it is now checking more than just length.
For length of nested lists, I would expect there to be a different guard that checks the inner length. For something like [[Foo()]]
, I'd expect:
ListVariable([
ListVariable([
UserDefinedObjectVariable(..., guards={TYPE_MATCH of inner object})
],
guards={LIST_LENGTH of inner list}
),
],
guards={LIST_LENGTH of outer list}
)
So there would be 3 different guards to check the 2 lengths and the type of the inner value. If this is failing somewhere, perhaps we need to do something like copy guards from the items of the list up to the containing list. Keeping the guards smaller and sperate will allow them to get automatically de-dupliated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm! You are right. I think currently for some reason, we are not generating ListVariable for the inner list. I am making it more complicated than necessary. Let me check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed this. Now, recursively collecting guards by walking the TupleVariable.
@jansel This is ready for review. The additional guards added above are correct. And I also confirmed by running Torchbench that speedups dont change. |
This PR eliminated the type check for list/tuple EQUALS_MATCH: as in, we will accept a list where a tuple was previously expected. Is this expected? |
OK, I can trigger a soundness error without this guard, going to submit a fix |
For guards like
if
args
actual runtime value is a list of tensor likeThe type check passes and it proceeds to comparing Tensor and a scalar value, leading to boolean Tensor and guard executing failure. This PR adds an extra type checks on each element of the list/tuple when it is a CONSTANT_MATCH guard.
This PR also add list length guard recursively for list.