Skip to content

[export][decomp] aten::einsum decomposition is unavoidable #115883

@BowenBao

Description

@BowenBao

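Is there a way to keep `aten::einsum` from being decomposed during `torch.export`? In the exported graph below, the einsum call no longer appears as a single `aten.einsum` node. It has been lowered into `unsqueeze`, `permute`, `view`, and `bmm` ops.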

import torch


def f(x, y):
    """Contract ``x`` and ``y`` over their last axis with ``torch.einsum``.

    The equation ``"bthd,bshd->bhts"`` works as follows:
    - ``b`` and ``h`` are shared by both operands.
    - ``t`` comes from ``x`` and ``s`` comes from ``y``.
    - ``d`` is summed away.
    The result is therefore laid out as ``(b, h, t, s)``.

    The key operand is multiplied by a scale of 1.0 first. That is numerically
    a no-op, but it keeps an ``aten.mul`` node in the exported graph.

    This is a minimal repro for ``aten::einsum`` being decomposed by
    ``torch.export``.
    """
    query, key = x, y
    scale = 1.0
    scaled_key = key * scale
    return torch.einsum("bthd,bshd->bhts", query, scaled_key)


x = torch.randn(1, 2, 3, 4)
y = torch.randn(1, 2, 3, 4)

print(f(x, y))

torch.export.export(f, (x, y)).graph_module.print_readable()
class GraphModule(torch.nn.Module):
    def forward(self, q: "f32[1, 2, 3, 4]", k: "f32[1, 2, 3, 4]"):
        # File: test_einsum_decomp.py:8, code: return torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        mul: "f32[1, 2, 3, 4]" = torch.ops.aten.mul.Tensor(k, 1.0);  k = None
        unsqueeze: "f32[1, 2, 3, 4, 1]" = torch.ops.aten.unsqueeze.default(q, 4);  q = None
        permute: "f32[1, 3, 2, 1, 4]" = torch.ops.aten.permute.default(unsqueeze, [0, 2, 1, 4, 3]);  unsqueeze = None
        unsqueeze_1: "f32[1, 2, 3, 4, 1]" = torch.ops.aten.unsqueeze.default(mul, 4);  mul = None
        permute_1: "f32[1, 3, 1, 2, 4]" = torch.ops.aten.permute.default(unsqueeze_1, [0, 2, 4, 1, 3]);  unsqueeze_1 = None
        permute_2: "f32[3, 2, 4, 1, 1]" = torch.ops.aten.permute.default(permute, [1, 2, 4, 0, 3]);  permute = None
        view: "f32[3, 2, 4]" = torch.ops.aten.view.default(permute_2, [3, 2, 4]);  permute_2 = None
        permute_3: "f32[3, 4, 1, 2, 1]" = torch.ops.aten.permute.default(permute_1, [1, 4, 0, 3, 2]);  permute_1 = None
        view_1: "f32[3, 4, 2]" = torch.ops.aten.view.default(permute_3, [3, 4, 2]);  permute_3 = None
        bmm: "f32[3, 2, 2]" = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
        view_2: "f32[3, 2, 1, 1, 2]" = torch.ops.aten.view.default(bmm, [3, 2, 1, 1, 2]);  bmm = None
        permute_4: "f32[1, 3, 2, 2, 1]" = torch.ops.aten.permute.default(view_2, [3, 0, 1, 4, 2]);  view_2 = None
        view_3: "f32[1, 3, 2, 2]" = torch.ops.aten.view.default(permute_4, [1, 3, 2, 2]);  permute_4 = None
        return (view_3,)

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @SherlockNoMad @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @gchanan

Metadata

Metadata

Assignees

No one assigned

    Labels

    module: decompositions (Topics related to decomposition, excluding PrimTorch)
    module: dynamo
    oncall: export
    oncall: pt2
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions