
[export][decomp] aten::einsum decomposition is unavoidable #115883

Open
BowenBao opened this issue Dec 14, 2023 · 3 comments
Labels
module: decompositions Topics related to decomposition (excluding PrimTorch) module: dynamo oncall: export oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@BowenBao
Collaborator

BowenBao commented Dec 14, 2023

Is there a way to not decompose aten::einsum during export?

import torch


def f(x, y):
    softmax_scale = 1.0
    q = x
    k = y
    return torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)


x = torch.randn(1, 2, 3, 4)
y = torch.randn(1, 2, 3, 4)

print(f(x, y))

torch.export.export(f, (x, y)).graph_module.print_readable()
class GraphModule(torch.nn.Module):
    def forward(self, q: "f32[1, 2, 3, 4]", k: "f32[1, 2, 3, 4]"):
        # File: test_einsum_decomp.py:8, code: return torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        mul: "f32[1, 2, 3, 4]" = torch.ops.aten.mul.Tensor(k, 1.0);  k = None
        unsqueeze: "f32[1, 2, 3, 4, 1]" = torch.ops.aten.unsqueeze.default(q, 4);  q = None
        permute: "f32[1, 3, 2, 1, 4]" = torch.ops.aten.permute.default(unsqueeze, [0, 2, 1, 4, 3]);  unsqueeze = None
        unsqueeze_1: "f32[1, 2, 3, 4, 1]" = torch.ops.aten.unsqueeze.default(mul, 4);  mul = None
        permute_1: "f32[1, 3, 1, 2, 4]" = torch.ops.aten.permute.default(unsqueeze_1, [0, 2, 4, 1, 3]);  unsqueeze_1 = None
        permute_2: "f32[3, 2, 4, 1, 1]" = torch.ops.aten.permute.default(permute, [1, 2, 4, 0, 3]);  permute = None
        view: "f32[3, 2, 4]" = torch.ops.aten.view.default(permute_2, [3, 2, 4]);  permute_2 = None
        permute_3: "f32[3, 4, 1, 2, 1]" = torch.ops.aten.permute.default(permute_1, [1, 4, 0, 3, 2]);  permute_1 = None
        view_1: "f32[3, 4, 2]" = torch.ops.aten.view.default(permute_3, [3, 4, 2]);  permute_3 = None
        bmm: "f32[3, 2, 2]" = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
        view_2: "f32[3, 2, 1, 1, 2]" = torch.ops.aten.view.default(bmm, [3, 2, 1, 1, 2]);  bmm = None
        permute_4: "f32[1, 3, 2, 2, 1]" = torch.ops.aten.permute.default(view_2, [3, 0, 1, 4, 2]);  view_2 = None
        view_3: "f32[1, 3, 2, 2]" = torch.ops.aten.view.default(permute_4, [1, 3, 2, 2]);  permute_4 = None
        return (view_3,)
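For reference, the long unsqueeze/permute/view/bmm chain in the exported graph is numerically just a batched matmul contracting over the `d` axis. A quick sanity check (my own illustration, not from the issue):

```python
import torch

q = torch.randn(1, 2, 3, 4)  # [b, t, h, d]
k = torch.randn(1, 2, 3, 4)  # [b, s, h, d]

ref = torch.einsum("bthd,bshd->bhts", q, k)

# Move h ahead of t/s, then contract over d with a batched matmul:
# [b, h, t, d] @ [b, h, d, s] -> [b, h, t, s]
via_matmul = torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1))

assert torch.allclose(ref, via_matmul, atol=1e-6)
print(ref.shape)  # torch.Size([1, 3, 2, 2])
```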

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @SherlockNoMad @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @gchanan

@BowenBao BowenBao added the module: decompositions Topics related to decomposition (excluding PrimTorch) label Dec 14, 2023
@cpuhrsch
Contributor

I'm marking this as high priority, from the far away lands of PTO, to discuss the right tags. Thank you!

@voznesenskym voznesenskym added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module oncall: pt2 module: dynamo and removed triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Dec 19, 2023
BowenBao added a commit that referenced this issue Jan 12, 2024
…skips using custom operator"


A context manager that disables the decomposition of certain ops during dynamo tracing.

The approach is to temporarily hijack the operator callable with a PT2 custom operator.
The custom operator will not be decomposed and will show up as a single node to be exported to ONNX.

For the time being, the decomposition of these ops is otherwise unavoidable.

#116684
#115883

This solution will no longer be required once the issue is resolved.

[ghstack-poisoned]
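The hijack pattern the commit message describes can be sketched in plain Python: temporarily swap out the operator callable, then restore it on exit. The `fake_torch` namespace and both lambdas below are stand-ins for illustration only; in the actual PR the replacement is a PT2 custom operator that the tracer treats as opaque.

```python
import contextlib
import types


@contextlib.contextmanager
def no_decomposition(namespace, op_name, replacement):
    """Temporarily replace `namespace.<op_name>` with `replacement`,
    restoring the original callable on exit."""
    original = getattr(namespace, op_name)
    setattr(namespace, op_name, replacement)
    try:
        yield
    finally:
        setattr(namespace, op_name, original)


# Toy stand-in for the torch module (hypothetical, for illustration only).
fake_torch = types.SimpleNamespace(einsum=lambda eq, *ops: ("decomposed", eq))
opaque_einsum = lambda eq, *ops: ("opaque", eq)

with no_decomposition(fake_torch, "einsum", opaque_einsum):
    # Inside the context, callers hit the opaque replacement.
    assert fake_torch.einsum("ij,jk->ik") == ("opaque", "ij,jk->ik")

# Outside the context, the original callable is restored.
assert fake_torch.einsum("ij,jk->ik") == ("decomposed", "ij,jk->ik")
```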
@anijain2305
Contributor

cc @tugsbayasgalan @angelayi

@anijain2305 anijain2305 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 12, 2024
pytorchmergebot pushed a commit that referenced this issue Jan 18, 2024

A context manager that disables the decomposition of certain ops during dynamo tracing.

The approach is to temporarily hijack the operator callable with a PT2 custom operator.
The custom operator will not be decomposed and will show up as a single node to be exported to ONNX.

For the time being, the decomposition of these ops is otherwise unavoidable.

#116684
#115883

This solution will no longer be required once the issue is resolved.
Pull Request resolved: #117314
Approved by: https://github.com/justinchuby, https://github.com/malfet
@zhxchen17
Contributor

@BowenBao Hi, sorry for the late response as I'm cleaning up the issue queue... We'll soon migrate to a higher decomposition level (known as pre-dispatch), which might preserve aten::einsum in the graph, and in our tentative plan this mode will become the default when we do a formal release. Currently there is a private flag to enable this mode, hidden in torch.export._trace._export. If you want to try it out, feel free to go ahead.
