
[bug] Dynamo graph break when using Python module heapq (manipulates lists), although it succeeds when placing heapq.py near the test script #106885

Open
vadimkantorov opened this issue Aug 9, 2023 · 8 comments
Labels
low priority: We're unlikely to get around to doing this in the near future
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@vadimkantorov
Contributor

vadimkantorov commented Aug 9, 2023

🐛 Describe the bug

OP: https://discuss.pytorch.org/t/is-heapq-module-supported-for-compilation-by-dynamo/185863:

heapq is a standard Python module useful for priority-queue loops in Dijkstra-like algorithms:

@msaroufim: @voznesenskym on GitHub might say this is a good dynamo starter task XD

import heapq
import torch

@torch.compile(fullgraph=True)
def program():
  h = []
  heapq.heappush(h, 3)
  heapq.heappush(h, 1)
  heapq.heappush(h, 4)
  val = heapq.heappop(h)
  return val * torch.randn(10)

program()

Versions

N/A

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @anijain2305 @ipiszy

@janeyx99 janeyx99 added triaged, module: dynamo labels Aug 9, 2023
@msaroufim
Member

msaroufim commented Aug 10, 2023

Out of curiosity, have you seen heaps used in model code? I'm curious what kinds of applications you're thinking of.

@vadimkantorov
Contributor Author

vadimkantorov commented Aug 10, 2023

I have my own reimplementation of selective search, which is an iterative graph-merging procedure with a fairly tight Python loop. So I was wondering if Dynamo could compile the whole thing and approach C++ in loop codegen :)

Or you might want to have Dijkstra / Prim algorithms as well.
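
For concreteness, here is a minimal sketch of the kind of heapq-driven loop I mean (the graph representation and names are illustrative assumptions, not code from my project):

import heapq

def dijkstra(adj, src):
    # adj: dict mapping node -> list of (neighbor, weight) pairs.
    # A tight pure-Python loop; in practice the weights could come from
    # a tensor via .item(), which is where compiling this would pay off.
    dist = {src: 0.0}
    frontier = [(0.0, src)]  # heap of (distance, node)
    while frontier:
        d, u = heapq.heappop(frontier)
        if d > dist.get(u, float("inf")):
            continue  # stale entry: node already relaxed via a shorter path
        for v, w in adj.get(u, []):
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                heapq.heappush(frontier, (nd, v))
    return dist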

@anijain2305 anijain2305 added the low priority label Sep 13, 2023
@vadimkantorov
Contributor Author

vadimkantorov commented Dec 10, 2023

It seems that Dynamo does not want to trace into the Python stdlib's heapq.heappush, even though there is a pure-Python code path inside that deals only with Python lists: torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(heappush) __call__ [ListVariable(), ConstantVariable(int)] {}

TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1 python3 foo.py
[2023-12-10 15:23:32,763] torch._dynamo.eval_frame: [DEBUG] skipping helper /usr/lib/python3.10/contextlib.py
[2023-12-10 15:23:32,764] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /usr/lib/python3.10/contextlib.py
[2023-12-10 15:23:32,764] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /usr/lib/python3.10/contextlib.py
[2023-12-10 15:23:32,765] torch._dynamo.eval_frame: [DEBUG] skipping backend_cache_wrapper /home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
[2023-12-10 15:23:32,766] torch._dynamo.eval_frame: [DEBUG] skipping _maybe_init_guarded_backend_cache /home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
[2023-12-10 15:23:32,766] torch._dynamo.eval_frame: [DEBUG] skipping innermost_fn /home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
[2023-12-10 15:23:32,767] torch._dynamo.eval_frame: [DEBUG] skipping _set_current_backend /home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
[2023-12-10 15:23:32,768] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /usr/lib/python3.10/contextlib.py
[2023-12-10 15:23:32,768] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /usr/lib/python3.10/contextlib.py
[2023-12-10 15:23:32,769] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
[2023-12-10 15:23:32,771] [0/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing program /mnt/c/Users/vadim/notionexport/notionfun/foo.py:4
[2023-12-10 15:23:32,774] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /mnt/c/Users/vadim/notionexport/notionfun/foo.py:4 in program
[2023-12-10 15:23:32,774] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]     @torch.compile(fullgraph=True)
[2023-12-10 15:23:32,775] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /mnt/c/Users/vadim/notionexport/notionfun/foo.py:6 in program
[2023-12-10 15:23:32,775] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]       h = []
[2023-12-10 15:23:32,775] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_LIST 0 []
[2023-12-10 15:23:32,776] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST h [ListVariable()]
[2023-12-10 15:23:32,776] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /mnt/c/Users/vadim/notionexport/notionfun/foo.py:7 in program
[2023-12-10 15:23:32,776] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]       heapq.heappush(h, 3)
[2023-12-10 15:23:32,777] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_GLOBAL heapq []
[2023-12-10 15:23:32,778] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR heappush [PythonModuleVariable()]
[2023-12-10 15:23:32,797] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST h [UserDefinedObjectVariable(heappush)]
[2023-12-10 15:23:32,797] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST 3 [UserDefinedObjectVariable(heappush), ListVariable()]
[2023-12-10 15:23:32,798] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 2 [UserDefinedObjectVariable(heappush), ListVariable(), ConstantVariable(int)]
[2023-12-10 15:23:32,798] [0/0] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
Traceback (most recent call last):
  File "/mnt/c/Users/vadim/notionexport/notionfun/foo.py", line 13, in <module>
    program()
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 406, in _fn
    return fn(*args, **kwargs)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 554, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 140, in _fn
    return fn(*args, **kwargs)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
    return _compile(
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 559, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 190, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 481, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 451, in transform
    tracer.run()
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2094, in run
    super().run()
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 739, in run
    and self.step()
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 702, in step
    getattr(self, inst.opname)(inst)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 403, in wrapper
    return inner_fn(self, inst)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1135, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 388, in call_function
    return self.call_method(tx, "__call__", args, kwargs)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 308, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 329, in call_method
    raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/home/vadimkantorov/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 176, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(heappush) __call__ [ListVariable(), ConstantVariable(int)] {}

from user code:
   File "/mnt/c/Users/vadim/notionexport/notionfun/foo.py", line 7, in program
    heapq.heappush(h, 3)


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

[2023-12-10 15:23:32,843] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2023-12-10 15:23:32,843] torch._dynamo.utils: [INFO] Function, Runtimes (s)
[2023-12-10 15:23:32,843] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner, 0.0000

If I download heapq.py with wget https://raw.githubusercontent.com/python/cpython/3.12/Lib/heapq.py and run the same code, it evaluates the value as a constant:

Important

I commented out the _heapq C-module imports at the bottom of the file:

# If available, use C implementation
#try:
#    from _heapq import *
#except ImportError:
#    pass
#try:
#    from _heapq import _heapreplace_max
#except ImportError:
#    pass
#try:
#    from _heapq import _heapify_max
#except ImportError:
#    pass
#try:
#    from _heapq import _heappop_max
#except ImportError:
#    pass
#
#
#if __name__ == "__main__":
#
#    import doctest # pragma: no cover
#    print(doctest.testmod()) # pragma: no cover

How could we make the graph capture dynamic? E.g. so that heapq behavior can be traced as data-dependent (e.g. some pytorch_scalar.item(), or even a pytorch_scalar passed directly to heapq.heappush)?

[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_0 =====
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self):
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /mnt/c/Users/vadim/notionexport/notionfun/foo.py:11, code: return val * torch.randn(10)
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         randn = torch.randn(10)
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         mul = 1 * randn;  randn = None
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (mul,)
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2023-12-10 15:27:01,400] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
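
In the meantime, a workaround consistent with the above: since Dynamo does trace user-defined functions, reimplementing the list-based push inline sidesteps the skipped C _heapq module. A minimal sketch (my own sift-up for plain ints, not taken verbatim from heapq.py):

import torch

def heappush_py(heap, item):
    # Pure-Python sift-up on a plain list, mirroring the fallback code
    # path heapq uses when the C _heapq module is unavailable.
    heap.append(item)
    pos = len(heap) - 1
    while pos > 0:
        parent = (pos - 1) >> 1
        if heap[pos] < heap[parent]:
            heap[pos], heap[parent] = heap[parent], heap[pos]
            pos = parent
        else:
            break

@torch.compile(fullgraph=True)
def program():
    h = []
    for x in (3, 1, 4):
        heappush_py(h, x)
    return h[0] * torch.randn(10)  # h[0] is the root of a min-heap, i.e. 1

program()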

@vadimkantorov
Contributor Author

E.g. this at least doesn't break in regular Python:

import heapq
import torch
# @torch.compile(fullgraph=True)
def program(x):
  h = []
  heapq.heappush(h, x)
  val = heapq.heappop(h)
  return val * torch.randn(10)
program(torch.tensor(1))

whereas with torch.compile:

import heapq
import torch
@torch.compile(fullgraph=True)
def program(x):
  h = []
  heapq.heappush(h, x)
  val = heapq.heappop(h)
  return val * torch.randn(10)
program(torch.tensor(1))

produces the attached log.txt

@vadimkantorov vadimkantorov changed the title Dynamo graph break when using Python module heapq (manipulates lists) [bug] Dynamo graph break when using Python module heapq (manipulates lists), although succeeds when placing heapq.py near the test script Dec 10, 2023
@furlat

furlat commented Jan 2, 2024

Any updates? Does it work in eager mode?

@vadimkantorov
Contributor Author

I'll check, but for scalars it should work, yeah (as comparisons should work). Even with torch.compile it seems that it might work if heapq.py is copied next to the script. It breaks outright if heapq is used from the system library, which is strange... A sketch of the scalar case follows below.
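
To sketch what I mean by scalars working (assuming 0-dim tensor values): keep the ordering key a plain Python number via .item() and add a counter as a tiebreaker, per the standard heapq recipe, so tensors themselves are never compared:

import heapq
import itertools
import torch

counter = itertools.count()  # tiebreaker: equal keys never fall through to the tensor
h = []
for x in (torch.tensor(3.0), torch.tensor(1.0), torch.tensor(4.0)):
    heapq.heappush(h, (x.item(), next(counter), x))

key, _, val = heapq.heappop(h)  # key == 1.0, val is torch.tensor(1.0)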

@anijain2305 anijain2305 added low priority and removed low priority, module: dynamo labels Jan 31, 2024
@vadimkantorov
Contributor Author

@anijain2305 I would say that this is still a dynamo issue?

@0x00b1
Contributor

0x00b1 commented May 31, 2024

I ran into this today while working on a PyTorch implementation of farthest-first traversal.
