-
Notifications
You must be signed in to change notification settings - Fork 24.8k
Closed
Labels
module: decompositions (Topics related to decomposition, excluding PrimTorch) · module: dynamo · oncall: export · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Is there a way to prevent `aten::einsum` from being decomposed during export?
import torch
def f(x, y):
    """Contract ``x`` and ``y`` with ``torch.einsum("bthd,bshd->bhts", ...)``.

    ``y`` is multiplied by ``softmax_scale`` (fixed at 1.0 here) before the
    contraction, so the result equals the plain einsum of ``x`` and ``y``.
    The multiplication is kept so that it shows up in the exported graph.
    """
    softmax_scale = 1.0
    # Local aliases; the exported graph shown in this report names its
    # inputs ``q`` and ``k``.
    q = x
    k = y
    return torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
# Two random 4-D inputs of shape (1, 2, 3, 4); x is sampled before y.
x, y = torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4)
# Eager result, for reference.
print(f(x, y))
# Export f and dump the resulting graph in readable form.
torch.export.export(f, args=(x, y)).graph_module.print_readable()
# Readable dump of the exported graph (output of print_readable() above).
# The single einsum call appears as unsqueeze/permute/view ops around one bmm.
class GraphModule(torch.nn.Module):
    def forward(self, q: "f32[1, 2, 3, 4]", k: "f32[1, 2, 3, 4]"):
        # File: test_einsum_decomp.py:8, code: return torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        mul: "f32[1, 2, 3, 4]" = torch.ops.aten.mul.Tensor(k, 1.0); k = None
        # Append a size-1 axis to each operand and reorder the axes.
        unsqueeze: "f32[1, 2, 3, 4, 1]" = torch.ops.aten.unsqueeze.default(q, 4); q = None
        permute: "f32[1, 3, 2, 1, 4]" = torch.ops.aten.permute.default(unsqueeze, [0, 2, 1, 4, 3]); unsqueeze = None
        unsqueeze_1: "f32[1, 2, 3, 4, 1]" = torch.ops.aten.unsqueeze.default(mul, 4); mul = None
        permute_1: "f32[1, 3, 1, 2, 4]" = torch.ops.aten.permute.default(unsqueeze_1, [0, 2, 4, 1, 3]); unsqueeze_1 = None
        # Collapse both operands to 3-D so they can be fed to a single bmm.
        permute_2: "f32[3, 2, 4, 1, 1]" = torch.ops.aten.permute.default(permute, [1, 2, 4, 0, 3]); permute = None
        view: "f32[3, 2, 4]" = torch.ops.aten.view.default(permute_2, [3, 2, 4]); permute_2 = None
        permute_3: "f32[3, 4, 1, 2, 1]" = torch.ops.aten.permute.default(permute_1, [1, 4, 0, 3, 2]); permute_1 = None
        view_1: "f32[3, 4, 2]" = torch.ops.aten.view.default(permute_3, [3, 4, 2]); permute_3 = None
        bmm: "f32[3, 2, 2]" = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
        # Expand the bmm result back to 5-D, reorder, and drop the size-1 axis
        # to obtain the 4-D output.
        view_2: "f32[3, 2, 1, 1, 2]" = torch.ops.aten.view.default(bmm, [3, 2, 1, 1, 2]); bmm = None
        permute_4: "f32[1, 3, 2, 2, 1]" = torch.ops.aten.permute.default(view_2, [3, 0, 1, 4, 2]); view_2 = None
        view_3: "f32[1, 3, 2, 2]" = torch.ops.aten.view.default(permute_4, [1, 3, 2, 2]); permute_4 = None
        return (view_3,)
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @SherlockNoMad @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @gchanan
vadimkantorov
Metadata
Metadata
Assignees
Labels
module: decompositions (Topics related to decomposition, excluding PrimTorch) · module: dynamo · oncall: export · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)