[dynamo] torch._dynamo.exc.Unsupported: comparison SymNodeVariable() <built-in function is_> ListVariable() #109504

Open
ezyang opened this issue Sep 18, 2023 · 9 comments
Labels
good first issue, module: dynamo, oncall: pt2, triaged

Comments

@ezyang
Contributor

ezyang commented Sep 18, 2023

🐛 Describe the bug

Self explanatory.

Occurred when implementing the self-referentiality check in torch.tensor (ref):

            if cur_item is obj:
                raise TypeError("new(): self-referential lists are incompatible")
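
For context, here is a minimal pure-Python sketch of the kind of check this snippet belongs to. The helper name `check_no_self_reference` and the recursion over nested lists are assumptions made for illustration; only the `is` comparison and the error message come from the snippet above, and this is not PyTorch's actual implementation.

```python
def check_no_self_reference(obj):
    """Raise TypeError if a list or tuple directly contains itself.

    Hypothetical helper for illustration only. It detects direct
    self-reference; indirect cycles are not handled in this sketch.
    """
    if not isinstance(obj, (list, tuple)):
        return
    for cur_item in obj:
        # Identity, not equality: the item must be the very same object.
        if cur_item is obj:
            raise TypeError("new(): self-referential lists are incompatible")
        check_no_self_reference(cur_item)


ok = [[1, 2], [3, 4]]
check_no_self_reference(ok)  # passes silently

bad = [1, 2]
bad.append(bad)  # bad now contains itself
try:
    check_no_self_reference(bad)
    raised = False
except TypeError:
    raised = True
```

The point relevant to this issue is that the check relies on `is`, so tracing it requires Dynamo to support identity comparison between whatever variable types the two operands happen to be.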

Versions

main

cc @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

@aqbewtra

This is my first issue. Do you mind providing more detail on how I can recreate the bug?

@ezyang
Contributor Author

ezyang commented Sep 18, 2023

import torch

@torch.compile(backend='eager', fullgraph=True, dynamic=True)
def f(x, xs):
    if x.size(0) is xs:
        return x + 1
    else:
        return x * 2

f(torch.randn(2), [1, 2])

@jansel added the triaged label Sep 22, 2023
@vdesai2014

vdesai2014 commented Sep 25, 2023

@ezyang

Currently, when the 'is' operator is used with a SymNodeVariable, an unimplemented exception is thrown because the op isn't part of supported_tensor_comparison_ops (line 1442, dynamo/variables/builtin.py). There is an existing if statement that handles the 'is' operator called on different types (lines 1458-1461), but because it sits lower down in the function, the unimplemented exception is thrown first.

I think moving the check for the 'is' op on different types to the top of the function should fix this issue without any second-order effects. Can you confirm this makes sense? This is my first PyTorch issue, so I want to make sure.
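
The ordering problem described here can be illustrated with a toy dispatcher. Everything in this sketch is hypothetical (`SUPPORTED_OPS`, `Unsupported`, and both functions are stand-ins, not Dynamo code); it only shows how a guard placed before a special case prevents that special case from ever running.

```python
import operator

# Stand-in for the set of supported comparison ops; contents are illustrative.
SUPPORTED_OPS = {operator.lt, operator.gt, operator.eq, operator.ne}


class Unsupported(Exception):
    """Stand-in for the real exception type; illustrative only."""


def compare_original_order(op, left, right):
    # The supported-op guard runs first, so `is` never reaches its handler.
    if op not in SUPPORTED_OPS:
        raise Unsupported(f"comparison {op}")
    if op is operator.is_ and type(left) is not type(right):
        return False
    return op(left, right)


def compare_reordered(op, left, right):
    # Handling `is` on mismatched types first avoids the early raise.
    if op is operator.is_ and type(left) is not type(right):
        return False
    if op not in SUPPORTED_OPS:
        raise Unsupported(f"comparison {op}")
    return op(left, right)
```

In this toy version, `compare_reordered(operator.is_, 2, [1, 2])` returns `False`, while `compare_reordered(operator.is_, 2, 2)` still raises, because only the mismatched-type case is handled by the reordering.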

@ezyang
Contributor Author

ezyang commented Sep 25, 2023

Give it a try

@jon-chuang
Collaborator

jon-chuang commented Oct 19, 2023

How about this test case?

import torch

@torch.compile(backend='eager', fullgraph=True, dynamic=True)
def f(x, xs):
    if x.size(0) is xs:
        return x + 1
    else:
        return x * 2

f(torch.randn(2), 2)

We also want to handle the case where it is indeed self-referential etc. For that we need something closer to #111550

@laithsakka
Contributor

@vdesai2014 @jon-chuang Is this still an open issue? Is anyone working on it?

@ezyang
Contributor Author

ezyang commented Jan 31, 2024

Feel free to grab this

@ezyang
Contributor Author

ezyang commented Mar 23, 2024

still repros as of 0a1b3be

@joshue031

If this is still free to work on, I would propose adding the following logic to _comparison_with_symnode in torch/_dynamo/variables/builtin.py:

if op in [operator.is_, operator.is_not]:
    is_result = (
        isinstance(left, SymNodeVariable)
        and isinstance(right, SymNodeVariable)
        and left.sym_num is right.sym_num
    )
    if op is operator.is_:
        return ConstantVariable.create(is_result)
    else:
        return ConstantVariable.create(not is_result)

This is similar to what is done in _comparison_with_tensor and seems to work for the two example test cases above.
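
To make the proposed branch easier to reason about in isolation, here is a self-contained sketch that runs it against mock classes. The classes below are hypothetical stand-ins that mimic only the attributes the snippet touches (`sym_num`, `ConstantVariable.create`); they are not the real Dynamo types, and `compare_identity` is an illustrative wrapper, not the actual `_comparison_with_symnode`.

```python
import operator


class VariableTracker:
    """Hypothetical stand-in base class for Dynamo's variable trackers."""


class SymNodeVariable(VariableTracker):
    def __init__(self, sym_num):
        self.sym_num = sym_num


class ListVariable(VariableTracker):
    def __init__(self, items):
        self.items = items


class ConstantVariable(VariableTracker):
    def __init__(self, value):
        self.value = value

    @classmethod
    def create(cls, value):
        return cls(value)


def compare_identity(op, left, right):
    """Mirror the proposed branch for `is` / `is not`.

    Other ops return None in this sketch.
    """
    if op in [operator.is_, operator.is_not]:
        is_result = (
            isinstance(left, SymNodeVariable)
            and isinstance(right, SymNodeVariable)
            and left.sym_num is right.sym_num
        )
        if op is operator.is_:
            return ConstantVariable.create(is_result)
        return ConstantVariable.create(not is_result)
    return None


s0 = object()  # opaque stand-in for a symbolic size
size = SymNodeVariable(s0)
xs = ListVariable([1, 2])

is_list = compare_identity(operator.is_, size, xs).value
is_not_list = compare_identity(operator.is_not, size, xs).value
is_same_sym = compare_identity(operator.is_, size, SymNodeVariable(s0)).value
```

Under these mocks, comparing a symbolic size with a list gives `False` for `is` and `True` for `is not`, and two wrappers around the same underlying symbol compare as identical. Any right-hand operand that is not a `SymNodeVariable` yields `False` for `is` here; how the real variable types map onto this for a plain int argument is not shown by the sketch.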
