
[AOTI] Support ReinterpretView in abi mode #114169

Closed · wants to merge 2 commits

Conversation

@oulgen (Contributor) commented Nov 20, 2023

Stack from ghstack (oldest at bottom):

#113967 added support for
ReinterpretView, but it turns out we codegen it differently in abi
compat mode. This PR adds support for abi compat mode as well.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

pytorch-bot bot commented Nov 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114169

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9bf3290 with merge base 8f8722e:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Nov 20, 2023
#113967 added support for
ReinterpretView, but it turns out we codegen it differently in abi
compat mode. This PR adds support for abi compat mode as well.

ghstack-source-id: f435abbdd5497343124c4c779f35743d87ca1dfd
Pull Request resolved: #114169
@oulgen added the ciflow/trunk and "topic: not user facing" labels Nov 20, 2023
```diff
@@ -60,7 +60,9 @@ def is_aligned(
     https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
     """
     if isinstance(x, TensorArg):
-        if x.buffer.startswith("reinterpret_tensor"):
+        if x.buffer.startswith("reinterpret_tensor") or x.buffer.startswith(
+            "RAIIAtenTensorHandle"
```
A reviewer (Contributor) commented:

This looks too general. Many things can start with RAIIAtenTensorHandle: do we want to return False on all of them?

@oulgen (Author) replied:

Well, all we know at this point is that the name is RAIIAtenTensorHandle(tmp_tensor_handle_0). Previously, other RAIIAtenTensorHandle names would have failed because there would be no matching tensor in self.name_to_node, so returning False on all of them is the right direction.

The longer-term fix would be to remove these prefixes from tensor names, or, better yet, to stop using tensor names for these comparisons.
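The failure mode described above can be reduced to a minimal standalone sketch (made-up names throughout; `name_to_node` stands in for the real mapping on the wrapper codegen object, and this is not the actual Inductor code):

```python
# Only real buffers are ever registered in the name-to-node mapping;
# wrapped temporaries like "RAIIAtenTensorHandle(tmp_tensor_handle_0)"
# never are, so a lookup on them is guaranteed to fail.
name_to_node = {"buf0": object()}

def is_aligned(buffer_name: str) -> bool:
    # Conservatively treat any reinterpreted/wrapped name as unaligned
    # instead of attempting the lookup that would fail.
    if buffer_name.startswith(("reinterpret_tensor", "RAIIAtenTensorHandle")):
        return False
    return buffer_name in name_to_node
```

Returning False here is the safe default: an arbitrary view may carry any offset, so alignment cannot be assumed.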

Another reviewer (Contributor) commented:

I share the concern that checking for RAIIAtenTensorHandle in a general-purpose utility function like config_of seems too aggressive. I wonder if we could lift the check to its caller in this particular case, i.e.:

```python
triton_meta = {
    "signature": signature_to_meta(signature, size_dtype=index_dtype),
    "device": V.graph.scheduler.current_device.index,
    "device_type": V.graph.scheduler.current_device.type,
    "constants": constants,
    "configs": [config_of(signature)],
}
```

where we could check whether the kernel is a ReinterpretView and, if so, simply set False in the "configs" field of triton_meta; otherwise, we call the general config_of function. This way, we would not pollute the general utility function.
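That suggestion could be sketched roughly as follows (a hypothetical helper, not code from this PR; the real config_of and the kernel's IR node are stood in by parameters):

```python
# Hypothetical caller-side dispatch: special-case ReinterpretView kernels
# here instead of inside the general-purpose config_of utility.
def build_configs(signature, is_reinterpret_view, config_of):
    if is_reinterpret_view:
        # Alignment cannot be assumed for an arbitrary view, so bypass
        # the general analysis and mark the kernel as unaligned.
        return [False]
    return [config_of(signature)]
```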

@jansel (Contributor) commented Nov 21, 2023

Is RAIIAtenTensorHandle always the same reinterpret_tensor? Can you give an example of what the buffer is?

@oulgen (Author) commented Nov 21, 2023

```python
def codegen_reinterpret_view(
    self, data, size_list, stride_list, offset, writer
) -> str:
    dim = str(len(size_list))
    size = self.codegen_shape_tuple(size_list)
    stride = self.codegen_shape_tuple(stride_list)
    offset = self.codegen_sizevar(offset)
    if config.aot_inductor.abi_compatible:
        tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}"
        # Because the memory planning is done in two passes (see the implementation
        # of self.generate), the writeline behavior is different in the two passes.
        if writer is None:
            writer = self
        args = [
            f"{data.get_name()}",
            dim,
            self.codegen_int_array_var(size, writer),
            self.codegen_int_array_var(stride, writer),
            offset,
            f"&{tmp_name}",
        ]

        def gen_reinterpret_call(writer, args):
            writer.writeline(f"AtenTensorHandle {tmp_name};")
            writer.writeline(
                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor({', '.join(args)}));"
            )

        if (
            self.can_cache_buffer_in_thread_local(data)
            and self.is_statically_known_list_of_ints(size_list)
            and self.is_statically_known_list_of_ints(stride_list)
        ):
            self.cached_thread_locals.add(tmp_name)
            writer.writeline(
                f"thread_local RAIIAtenTensorHandle {tmp_name}_handle = ([&] {{"
            )
            if hasattr(writer, "indent"):
                indent = writer.indent()
            else:
                indent = contextlib.nullcontext()
            with indent:
                gen_reinterpret_call(writer, args)
                writer.writeline(f"return {tmp_name};")
            writer.writeline("})();")
            writer.writeline(
                f"AtenTensorHandle {tmp_name}({tmp_name}_handle.get());"
            )
            return tmp_name

        gen_reinterpret_call(writer, args)
        # NB, the return handle here represents a temporary tensor, which will be
        # automatically released.
        # Here's a sample usage in the cpp wrapper code:
        # ```
        # aoti_torch_addmm_out(
        #     buf1,
        #     arg1_1,
        #     RAIIAtenTensorHandle(tmp_tensor_handle_0),
        #     buf0,
        #     1L,
        #     1L));
        # ```
        # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call
        # to addmm_out.
        # This could be problematic when it's used in a different pattern, for example:
        # ```
        # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6};
        # aoti_torch_proxy_executor_call_function(..., tensor_args);
        # ```
        # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in
        # the latter kernel call.
        #
        # This is solved by updating the proxy_executor invocation to
        # ```
        # aoti_torch_proxy_executor_call_function(...,
        #     std::vector<AtenTensorHandle>{
        #         RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6
        #     }.data()
        # );
        # ```
        return f"RAIIAtenTensorHandle({tmp_name})"
    else:
        args = [data.get_name(), size, stride, offset]
        return f"reinterpret_tensor({', '.join(args)})"
```

In this code for codegen_reinterpret_view, we either generate RAIIAtenTensorHandle or reinterpret_tensor.

On my test case, the emitted code is P887066630
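The branch above can be reduced to a tiny sketch of the two emitted forms (a hypothetical simplification; the real argument lists are elided):

```python
# Hypothetical reduction of the two code-generation paths: ABI-compatible
# mode emits a wrapped temporary handle; the default mode emits a direct
# reinterpret_tensor call. Buffer names are illustrative.
def reinterpret_view_expr(buffer_name: str, abi_compatible: bool, tmp_id: int) -> str:
    if abi_compatible:
        return f"RAIIAtenTensorHandle(tmp_tensor_handle_{tmp_id})"
    return f"reinterpret_tensor({buffer_name})"
```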

oulgen added a commit that referenced this pull request Nov 21, 2023
#113967 added support for
ReinterpretView, but it turns out we codegen it differently in abi
compat mode. This PR adds support for abi compat mode as well.

ghstack-source-id: 18dc01c478550f2256f59ea4016989cd91771d96
Pull Request resolved: #114169
@aakhundov (Contributor) left a comment:

Looks cleaner and more specific now. Thanks!

@oulgen (Author) commented Nov 21, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status for details.

pytorchmergebot pushed a commit that referenced this pull request Nov 24, 2023
This IR node mutates in place; it needs to use the argument, not the
target.

Fixes #113440

Pull Request resolved: #114436
Approved by: https://github.com/jansel
ghstack dependencies: #114169