Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5246,6 +5246,21 @@ def forward(self, input):
prof.report()
)

def test_guards_strip_function_call(self):
from torch._dynamo.guards import strip_function_call

test_case = [
("___odict_getitem(a, 1)", "a"),
("a.layers[slice(2)][0]._xyz", "a"),
("getattr(a.layers[slice(2)][0]._abc, '0')", "a"),
("getattr(getattr(a.x[3], '0'), '3')", "a"),
("a.layers[slice(None, -1, None)][0]._xyz", "a"),
("a.layers[func('offset', -1, None)][0]._xyz", "a"),
]
# strip_function_call should extract the object from the string.
for name, expect_obj in test_case:
self.assertEqual(strip_function_call(name), expect_obj)


class CustomFunc1(torch.autograd.Function):
@staticmethod
Expand Down
19 changes: 16 additions & 3 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,23 @@
def strip_function_call(name):
"""
"___odict_getitem(a, 1)" => "a"
"a.layers[slice(2)][0]._xyz" ==> "a"
"getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a"
"getattr(getattr(a.x[3], '0'), '3')" ==> "a"
"a.layers[slice(None, -1, None)][0]._xyz" ==> "a"
"""
m = re.search(r"([a-z0-9_]+)\(([^(),]+)[^()]*\)", name)
if m and m.group(1) != "slice":
return strip_function_call(m.group(2))
# recursively find valid object name in fuction
valid_name = re.compile("[A-Za-z_].*")
curr = ""
for char in name:
if char in " (":
curr = ""
elif char in "),[]":
if curr and curr != "None" and valid_name.match(curr):
return strip_function_call(curr)
else:
curr += char

return strip_getattr_getitem(name)


Expand Down