Skip to content

Commit

Permalink
[dynamo] graph break on issubclass call with non-const args (#125943)
Browse files Browse the repository at this point in the history
Fixes #125942

Pull Request resolved: #125943
Approved by: https://github.com/jansel
ghstack dependencies: #125882
  • Loading branch information
williamwen42 authored and pytorchmergebot committed May 15, 2024
1 parent 100e3c1 commit 56a89fc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
9 changes: 9 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -4962,6 +4962,15 @@ def fn(x):
inp = torch.randn(3, 3)
self.assertEqual(fn(inp), opt_fn(inp))

def test_nonconst_issubclass(self):
def fn(x):
if issubclass(x.__class__, np.ndarray):
return 1
return 0

opt_fn = torch.compile(fn, backend="eager")
opt_fn(np.ones([3, 3]))


instantiate_parametrized_tests(ReproTests)

Expand Down
11 changes: 8 additions & 3 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,10 +1402,15 @@ def check_type(ty):

def call_issubclass(self, tx, left_ty, right_ty):
"""Checks if first arg is subclass of right arg"""
left_ty = left_ty.as_python_constant()
right_ty = right_ty.as_python_constant()
try:
left_ty_py = left_ty.as_python_constant()
right_ty_py = right_ty.as_python_constant()
except NotImplementedError:
unimplemented(
f"call_issubclass args not constant left_ty: {left_ty}, right_ty: {right_ty}"
)

return variables.ConstantVariable(issubclass(left_ty, right_ty))
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))

def call_super(self, tx, a, b):
return variables.SuperVariable(a, b)
Expand Down

0 comments on commit 56a89fc

Please sign in to comment.