anmyachev
Collaborator

@anmyachev anmyachev commented Sep 18, 2024

#136087 updated pybind11 to 2.13.6, and that release adds a feature exposed through a new function, _pybind11_conduit_v1_. The presence of this function breaks the serialization mechanisms used by Triton and by PyTorch itself.

Possible errors that have been noticed due to this change:

the first error
_________ KernelTests.test_layout_constraint_needs_fixed_stride_order __________
Traceback (most recent call last):
  File "/runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/pytorch/test/inductor/test_triton_kernels.py", line 1072, in test_layout_constraint_needs_fixed_stride_order
    eager_out = f(x)
  File "/runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/pytorch/test/inductor/test_triton_kernels.py", line 1068, in f
    arange_out(x, y)
  File "/runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/pytorch/test/inductor/test_triton_kernels.py", line 1059, in arange_out
    kernel[grid](x, out, n_elements, BLOCK_SIZE=4)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/triton/runtime/jit.py", line 330, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/triton/runtime/jit.py", line 657, in run
    kernel = self.compile(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/triton/compiler/compiler.py", line 315, in compile
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/json/__init__.py", line 234, in dumps
    return cls(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
TypeError: vars() argument must have __dict__ attribute
the second error
________________ TestTritonWrapper.test_wrapper_using_gpu_seed _________________
Traceback (most recent call last):
  File "/cache/pytorch-c5e9d03a2da4b93481737594cbe2f5931fa569aa833f206a638189cad2c36d3c-11/test/inductor/test_triton_wrapper.py", line 40, in test_wrapper_using_gpu_seed
    out = f(x, y)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1292, in __call__
    return self._torchdynamo_orig_callable(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1087, in __call__
    result = self._inner_convert(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
    self._return(inst)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
    self.output.compile_subgraph(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1528, in compile_fx
    return aot_autograd(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1357, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1428, in _fw_compiler_base
    return inner_compile(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 479, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 665, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 1341, in load
    compiled_graph = compile_fx_fn(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 882, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1952, in compile_to_fn
    return self.compile_to_module().call
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1878, in compile_to_module
    return self._compile_to_module()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1906, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2866, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/tmps59zkbew/kg/ckgkb4gt5fs5pll4o7fqawppsmdezu5h52cq6nmrvi3yy6j7ddq4.py", line 45, in <module>
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/async_compile.py", line 198, in triton
    kernel = TritonCodeCache.load(kernel_name, source_code)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2916, in load
    return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2853, in load
    return cls.load_by_key_path(key, path, linemap, attrs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2866, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/runtime/compile_tasks.py", line 39, in _reload_python_module
    raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Failed to import /tmp/tmps59zkbew/g3/cg3zgxsidsjhdlz2lzvajvubdq6kg2x2hzd2kznfj43qwvlv33du.py
SyntaxError: invalid syntax (cg3zgxsidsjhdlz2lzvajvubdq6kg2x2hzd2kznfj43qwvlv33du.py, line 14)

cc @alexbaden @etaf

pytorch-bot bot commented Sep 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136280

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit eff8243 with merge base 538ee7b (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Sep 18, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@anmyachev anmyachev changed the title from "Fix Triton tests after update pybind11 to 2.13.6" to "[Inductor] Fix Triton tests after update pybind11 to 2.13.6" Sep 18, 2024
@anmyachev anmyachev marked this pull request as ready for review September 18, 2024 19:11
@anmyachev
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Sep 18, 2024
@anmyachev anmyachev changed the title from "[Inductor] Fix Triton tests after update pybind11 to 2.13.6" to "[Inductor] Fix Triton tests after updating pybind11 to 2.13.6" Sep 18, 2024
@etaf etaf added the ciflow/xpu Run XPU CI tasks label Sep 19, 2024
@EikanWang EikanWang requested a review from guangyey September 19, 2024 00:54
@bdhirsh bdhirsh added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 19, 2024
@rwgk

rwgk commented Sep 19, 2024

I'm the author of pybind/pybind11#5296

Sorry for the trouble, but it's a game-changing new feature.

Do you have ideas for what we could do to help fix the problem on your end?

@@ -232,9 +232,14 @@ def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
         Dict[str, Any]: the xpu capability dictionary of the device
     """
     props = get_device_properties(device)
-    return {
+    props_dict = {
         prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__")
Contributor

Suggested change
- prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__")
+ prop: getattr(props, prop) for prop in dir(props) if not prop.startswith(("__", "_pybind11_"))

Or "_pybind11_conduit_". This is better in case there's a v2 eventually.
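The suggestion relies on `str.startswith` accepting a tuple of prefixes. A minimal sketch of that behavior follows; the attribute names are hypothetical stand-ins for what `dir(props)` might return, and `_pybind11_conduit_v2_` is an invented name used only to illustrate the "in case there's a v2" point.

```python
# Hypothetical attribute names, standing in for dir(props) on a bound object.
names = [
    "__class__",
    "_pybind11_conduit_v1_",
    "_pybind11_conduit_v2_",  # invented name: a possible future variant
    "name",
    "total_memory",
]

# str.startswith accepts a tuple and returns True if any prefix matches.
skipped_prefixes = ("__", "_pybind11_")
kept = [n for n in names if not n.startswith(skipped_prefixes)]

print(kept)
```

Filtering on the broader `_pybind11_` prefix drops both conduit names, which is the robustness argument made above.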

Contributor

Though it seems this should filter all bound methods; it's only supposed to be looking for data members, I'd assume?

Collaborator Author

Though it seems this should filter all bound methods; it's only supposed to be looking for data members, I'd assume?

It seems so, but the solution you proposed looks robust enough that we don't need to complicate the code with bound-method detection for now. Is that OK?

Contributor

Yes, I think so.

@anmyachev
Collaborator Author

Note to reviewers: CI failures are related to intel/intel-xpu-backend-for-triton#2188. Apparently PyTorch has not yet switched to the new version of Triton, which has a fix for this problem: intel/intel-xpu-backend-for-triton#2297.

pytorchmergebot pushed a commit to hoshibara/pytorch that referenced this pull request Sep 23, 2024
@etaf
Collaborator

etaf commented Sep 24, 2024

Please fix lint issue.

     return {
-        prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__")
+        prop: getattr(props, prop) for prop in dir(props) if not prop.startswith(("__", "_pybind11_"))
Collaborator

Please fix lint error.

@guangyey
Collaborator

@anmyachev Thanks for your fix. Could you elaborate on why the triton case will fail? In my understanding, the dict returned by torch.xpu.get_device_capability just added a new keyword _pybind11_conduit_v1_. It shouldn't cause any failure because we could ignore it.

@anmyachev
Collaborator Author

@anmyachev Thanks for your fix. Could you elaborate on why the triton case will fail? In my understanding, the dict returned by torch.xpu.get_device_capability just added a new keyword _pybind11_conduit_v1_. It shouldn't cause any failure because we could ignore it.

The new keyword _pybind11_conduit_v1_ contains a callable object <bound method PyCapsule._pybind11_conduit_v1_ of _XpuDeviceProperties..> that breaks serialization (json.dumps([dict], default=vars)) in Triton, which is not designed for this.

Small example:

>>> import json
>>> test_dict = {"a": "some data"}        
>>> json.dumps(test_dict, default=vars) 
'{"a": "some data"}'
>>> test_dict = {"a": "some data", "callable": dict.copy} 
>>> json.dumps(test_dict, default=vars)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "...\Lib\json\__init__.py", line 238, in dumps
    **kw).encode(obj)
          ^^^^^^^^^^^
  File "...\Lib\json\encoder.py", line 200, in encode
    chunks = self.iterencode(o, _one_shot=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\json\encoder.py", line 258, in iterencode
    return _iterencode(o, 0)
           ^^^^^^^^^^^^^^^^^
TypeError: vars() argument must have __dict__ attribute

In principle, one can ignore or filter this field in different places, but by doing it closer to the source, we reduce the likelihood that it will have to be done in several places.
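To make the "closer to the source" argument concrete, here is a minimal sketch. It assumes a hypothetical `FakeDeviceProperties` class in place of the real bound properties type, with the built-in `len` standing in for the C++-implemented method (built-in callables have no `__dict__`, which is what makes `vars()` fail). The `collect` helper is also invented for illustration and is not the PyTorch function.

```python
import json


class FakeDeviceProperties:
    """Hypothetical stand-in for a pybind11-bound properties object."""

    # Assumption: a built-in callable mimics the injected method here,
    # because built-ins (like C++-implemented methods) lack __dict__.
    _pybind11_conduit_v1_ = len

    def __init__(self):
        self.name = "fake-xpu"
        self.max_compute_units = 512


def collect(props, skipped_prefixes):
    # Same shape as the dict comprehension under review in this PR.
    return {
        prop: getattr(props, prop)
        for prop in dir(props)
        if not prop.startswith(skipped_prefixes)
    }


props = FakeDeviceProperties()

# Old filter: the callable leaks into the dict and the encoder falls back
# to default=vars, which raises TypeError for objects without __dict__.
unfiltered = collect(props, ("__",))
try:
    json.dumps(unfiltered, default=vars)
    failed = False
except TypeError:
    failed = True

# New filter: the callable never reaches any consumer of the dict.
filtered = collect(props, ("__", "_pybind11_"))
encoded = json.dumps(filtered, default=vars)
print(failed, encoded)
```

Because the filtering happens where the dict is built, every downstream consumer that serializes it benefits without its own workaround.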

@etaf
Collaborator

etaf commented Sep 24, 2024

@pytorchbot rebase -b main

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@etaf etaf added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 24, 2024
@etaf etaf requested review from albanD and jansel September 24, 2024 13:05
@albanD
Collaborator

albanD commented Sep 24, 2024

@rwgk very interesting feature for sure!
Is the implication in the context of serialization that no python class bound via py::class_ can be serialized? Or this was not possible before anyways?

@henryiii
Contributor

It’s a method, not a data member?

@albanD
Collaborator

albanD commented Sep 24, 2024

Right, so extracting all the content this way into a dict is prone to error as any method being defined on the base object would fail the same way.
Should we generally update this code to never store any callable in the dict, with if not callable(prop), to make it more robust to any future change?

@henryiii
Contributor

See #136280 (comment).

@rwgk

rwgk commented Sep 24, 2024

@rwgk very interesting feature for sure! Is the implication in the context of serialization that no python class bound via py::class_ can be serialized? Or this was not possible before anyways?

@albanD Short answer: Not sure!

Trying to see what sticks:

  • py::class_ objects are generally not serializable (and I expected that my PR #5296 doesn't change that in any way).

  • Tangential probably: However, pybind11 has support for pickling, which can be enabled as an option. That makes the objects serializable with pickle.

  • I didn't look enough to understand why exactly the Triton tests are stumbling over the added method. — I'm wondering, what's special about _pybind11_conduit_v1_, compared to any other random methods that people may be .def-ing?

  • I feel (and felt) it would have been better to make the method __pybind11_conduit_v1__, but I wasn't (still am not) sure, and I didn't speak up when I maybe should have (here). — It wouldn't be too late IMO to convert to the dunder convention. On the back of my mind: If it solves the issue here (?), possibly/probably it'll also avoid hiccups elsewhere?

@henryiii
Contributor

henryiii commented Sep 24, 2024

Dunder methods just happened to work here since this was filtering __*. Dunder names are meant for Python's own use, we should not be adding them in general. It's not Python's convention for Protocols, it's Python's internal reserved names and it just happens a lot of them are Protocols. The problem is the filtering is wrong; you don't need to store bound methods in general, not just dunder attributes.

Filtering out callable is probably okay, though if you store an object that happens to be callable, it will also be wrong. It would be better to filter on methods specifically.

compared to any other random methods that people

This is the first pybind11 injected method (at least that's not a Python dunder method like __new__). Maybe people aren't adding pybind11 methods here?

I don't know anything about how this is working, but I'm assuming it collects all members, does something, then unpacks it into a new object and reapplies all the members. Bound methods should still be present if the object type is the same.
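The difference between filtering on `callable` and filtering on methods specifically can be sketched as follows. `Scale` and `FakeProps` are hypothetical classes invented for this illustration, and `inspect.isroutine` is only one possible way to test for methods; the sketch does not verify how `inspect` classifies methods bound by pybind11.

```python
import inspect


class Scale:
    """Hypothetical data object that happens to define __call__."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


class FakeProps:
    """Hypothetical stand-in for a properties object."""

    def __init__(self):
        self.name = "fake-xpu"
        self.scale = Scale(2)  # a data member that is also callable

    def describe(self):
        return self.name


def public_names(obj):
    return [n for n in dir(obj) if not n.startswith("__")]


props = FakeProps()

# Filtering on callable() also drops the callable data member.
not_callable = {
    n for n in public_names(props) if not callable(getattr(props, n))
}

# Filtering on routines drops the method but keeps the callable data member.
not_routine = {
    n for n in public_names(props)
    if not inspect.isroutine(getattr(props, n))
}

print(sorted(not_callable), sorted(not_routine))
```

Note that the test has to be applied to `getattr(props, n)`, not to the attribute name itself, since a name is a string and never callable.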

@etaf
Collaborator

etaf commented Sep 25, 2024

Hi, @albanD @jansel , ciflow/xpu is blocked by this fix. Can you please take the time to review this pr?

-        prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__")
+        prop: getattr(props, prop)
+        for prop in dir(props)
+        if not prop.startswith(("__", "_pybind11_"))
Collaborator

@guangyey guangyey Sep 25, 2024

Suggested change
- if not prop.startswith(("__", "_pybind11_"))
+ if (not prop.startswith("__")) and (not callable(getattr(props, prop)))

I agree with @henryiii and @albanD about filtering out all callables, which makes code more generic.

@guangyey
Collaborator

guangyey commented Sep 25, 2024

@anmyachev Thanks for your fix. Could you elaborate on why the triton case will fail? In my understanding, the dict returned by torch.xpu.get_device_capability just added a new keyword _pybind11_conduit_v1_. It shouldn't cause any failure because we could ignore it.

Thanks for the explanation. It makes it easier for me to understand where the root cause lies.

@rwgk

rwgk commented Sep 25, 2024

@henryiii wrote:

Dunder names are reserved for Python's own use

I couldn't find any evidence for that.

So I asked around. A colleague also looked around and couldn't find any evidence either.

Another person I trust wrote: They're not reserved, but you do run the risk of clashing with the runtime. There's other third-party packages that use their own dunders.

My take: Since we have pybind11 in the method name, there's practically zero risk of clashing with the Python runtime.

Can someone here find evidence that the dunder names are actually reserved?

If not, I'd be in favor of adopting __pybind11_conduit_v1__ as the method name, and yanking the pybind11 2.11.2, 2.12.1, and 2.13.6 releases (they were all released on the same day, Friday September 13, 2024). I believe that'd resolve this issue here in the most comprehensive way (released versions of pytorch and its dependencies would continue to be compatible with newer pybind11 versions).

@EikanWang
Collaborator

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: xpu / linux-jammy-xpu-py3.9 / test (default, 1, 4, linux.idc.xpu), xpu / linux-jammy-xpu-py3.9 / test (default, 3, 4, linux.idc.xpu), xpu / linux-jammy-xpu-py3.9 / test (default, 4, 4, linux.idc.xpu)

@EikanWang
Collaborator

Since all the failures are now Intel GPU only, I will merge this PR first. We have a breaking change in PyTorch that requires updating the Intel GPU Triton commit pin. But the CI for the Triton commit pin update will fail after updating pybind11 to 2.13.6. Therefore, we need to land this PR first. If we find a more elegant fix, we can submit another PR.

@anmyachev anmyachev deleted the patch-1 branch September 25, 2024 08:55
@henryiii
Contributor

https://docs.python.org/3/reference/lexical_analysis.html#identifiers

Any use of __*__ names, in any context, that does not follow explicitly documented use, is subject to breakage without warning.

Emphasis is from the original docs.

@rwgk

rwgk commented Sep 25, 2024

https://docs.python.org/3/reference/lexical_analysis.html#identifiers

Any use of __*__ names, in any context, that does not follow explicitly documented use, is subject to breakage without warning.

Emphasis is from the original docs.

To close the loop here: I replied under pybind/pybind11#5296 (comment)

BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this pull request Sep 25, 2024
…h#136280)

pytorch#136087 updated pybind11 to 2.13.6, and that release adds a feature exposed through [a new function](https://pybind11.readthedocs.io/en/latest/changelog.html#version-2-13-6-september-13-2024) `_pybind11_conduit_v1_`. The presence of this function breaks the serialization mechanisms used by Triton and by PyTorch itself.

Possible errors that have been noticed due to this change:

<details>
<summary> the first error </summary>

```bash
_________ KernelTests.test_layout_constraint_needs_fixed_stride_order __________
Traceback (most recent call last):
  File "/runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/pytorch/test/inductor/test_triton_kernels.py", line 1072, in test_layout_constraint_needs_fixed_stride_order
    eager_out = f(x)
  File "/runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/pytorch/test/inductor/test_triton_kernels.py", line 1068, in f
    arange_out(x, y)
  File "/runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/pytorch/test/inductor/test_triton_kernels.py", line 1059, in arange_out
    kernel[grid](x, out, n_elements, BLOCK_SIZE=4)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/triton/runtime/jit.py", line 330, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/triton/runtime/jit.py", line 657, in run
    kernel = self.compile(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/triton/compiler/compiler.py", line 315, in compile
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/json/__init__.py", line 234, in dumps
    return cls(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
TypeError: vars() argument must have __dict__ attribute
```
</details>

<details>
<summary> the second error </summary>

```bash
________________ TestTritonWrapper.test_wrapper_using_gpu_seed _________________
Traceback (most recent call last):
  File "/cache/pytorch-c5e9d03a2da4b93481737594cbe2f5931fa569aa833f206a638189cad2c36d3c-11/test/inductor/test_triton_wrapper.py", line 40, in test_wrapper_using_gpu_seed
    out = f(x, y)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1292, in __call__
    return self._torchdynamo_orig_callable(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1087, in __call__
    result = self._inner_convert(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
    self._return(inst)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
    self.output.compile_subgraph(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1528, in compile_fx
    return aot_autograd(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1357, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1428, in _fw_compiler_base
    return inner_compile(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 479, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 665, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 1341, in load
    compiled_graph = compile_fx_fn(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 574, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 882, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1952, in compile_to_fn
    return self.compile_to_module().call
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1878, in compile_to_module
    return self._compile_to_module()
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/graph.py", line 1906, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2866, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/tmps59zkbew/kg/ckgkb4gt5fs5pll4o7fqawppsmdezu5h52cq6nmrvi3yy6j7ddq4.py", line 45, in <module>
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/async_compile.py", line 198, in triton
    kernel = TritonCodeCache.load(kernel_name, source_code)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2916, in load
    return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2853, in load
    return cls.load_by_key_path(key, path, linemap, attrs)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 2866, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/_inductor/runtime/compile_tasks.py", line 39, in _reload_python_module
    raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Failed to import /tmp/tmps59zkbew/g3/cg3zgxsidsjhdlz2lzvajvubdq6kg2x2hzd2kznfj43qwvlv33du.py
SyntaxError: invalid syntax (cg3zgxsidsjhdlz2lzvajvubdq6kg2x2hzd2kznfj43qwvlv33du.py, line 14)
```
</details>
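
The `SyntaxError` in a generated module above fits a failure mode where an object's attributes are dumped into source text via `repr()`. The sketch below is only an illustration under that assumption, not the actual PyTorch or Triton code path: `FakeBinding`, `public_fields`, and `try_compile` are hypothetical names, and the `_pybind11_conduit_v1_` method is mimicked with a plain Python method.

```python
class FakeBinding:
    """Hypothetical stand-in for an object exposed through pybind11."""

    def __init__(self):
        self.num_warps = 4

    def _pybind11_conduit_v1_(self):
        # Mimics the extra method that newer pybind11 adds to bound types.
        return None


def public_fields(obj, skip_callables):
    """Collect non-dunder attributes, optionally dropping callables."""
    fields = {}
    for name in dir(obj):
        if name.startswith("__"):
            continue
        value = getattr(obj, name)
        if skip_callables and callable(value):
            continue
        fields[name] = value
    return fields


def try_compile(fields):
    """Embed the fields into generated source and report whether it parses."""
    source = f"meta = {fields!r}\n"
    try:
        compile(source, "<generated>", "exec")
        return True
    except SyntaxError:
        # The repr of a bound method ("<bound method ...>") is not valid Python.
        return False


obj = FakeBinding()
print(try_compile(public_fields(obj, skip_callables=False)))
print(try_compile(public_fields(obj, skip_callables=True)))
```

In this sketch, filtering out callables before serializing keeps the generated source parseable. Whether the real fix uses the same filter is not something this example claims.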

Pull Request resolved: pytorch#136280
Approved by: https://github.com/etaf, https://github.com/jansel, https://github.com/EikanWang

Co-authored-by: Henry Schreiner <HenrySchreinerIII@gmail.com>
Labels
ciflow/trunk Trigger trunk jobs on your pull request ciflow/xpu Run XPU CI tasks Merged open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module