
Einsum to avoid transpose and reshape #4

azzever opened this issue Aug 25, 2019 · 2 comments

@azzever commented Aug 25, 2019


Thank you for the great post about Transformers.
Actually, you can avoid transpose/reshape using torch.einsum.

Here is an example that behaves exactly like your implementation (except for mask=True and the asserts 😊):

# assumes the imports from the post: import math; import torch; import torch.nn.functional as F
def forward_einsum(self, x):
    b, t, e = x.size()
    h = self.heads

    keys    = self.tokeys(x).view(b, t, h, e)
    queries = self.toqueries(x).view(b, t, h, e)
    values  = self.tovalues(x).view(b, t, h, e)

    dot = torch.einsum('bthe,bihe->bhti', queries, keys) / math.sqrt(e)
    dot = F.softmax(dot, dim=-1)

    out = torch.einsum('bhtd,bdhe->bthe', dot, values)

    # the weight reshape could move to __init__; it's kept here to compare with the original implementation
    out = torch.einsum('bthe,khe->btk', out, self.unifyheads.weight.view(e, h, e))
    return out + self.unifyheads.bias
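(The last einsum is just the `unifyheads` linear layer written out by hand. A quick sketch on random tensors checks the equivalence; the sizes are illustrative, and it assumes `unifyheads = nn.Linear(h * e, e)` as in the post.)

```python
import torch
import torch.nn as nn

b, t, h, e = 2, 8, 4, 16  # illustrative sizes
unifyheads = nn.Linear(h * e, e)  # assumed layer shape, matching weight.view(e, h, e)
out = torch.randn(b, t, h, e)

# einsum with the reshaped weight, as in the snippet above
via_einsum = torch.einsum('bthe,khe->btk', out, unifyheads.weight.view(e, h, e)) + unifyheads.bias

# equivalent: flatten the head dimension and apply the linear layer directly
via_linear = unifyheads(out.reshape(b, t, h * e))

assert torch.allclose(via_einsum, via_linear, atol=1e-5)
```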

Although the code became very short, it's probably hard to understand for people who don't know einsum notation, so this is definitely not the best code to explain the idea 😊.

@pbloem self-assigned this Aug 25, 2019

@pbloem commented Aug 26, 2019

Wow, cool. I admit einsum is a bit of a blind spot for me (one of many).

I'm curious, do you know whether this is actually faster than a transpose/reshape, or does it just end up transposing under the hood? I can't find much information on how einsum is implemented, but it seems to be based mostly on reshaping and applying bmm().
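(One way to probe this is to check that the two routes compute the same thing on random tensors, and then time each line, e.g. with `timeit`. The sketch below uses illustrative sizes and only covers the query–key dot products.)

```python
import torch

b, t, h, e = 2, 8, 4, 16  # illustrative sizes
queries = torch.randn(b, t, h, e)
keys = torch.randn(b, t, h, e)

# einsum route: contract over the embedding dimension directly
dot_einsum = torch.einsum('bthe,bihe->bhti', queries, keys)

# transpose/reshape route: fold heads into the batch dimension, then bmm
q = queries.transpose(1, 2).reshape(b * h, t, e)
k = keys.transpose(1, 2).reshape(b * h, t, e)
dot_bmm = torch.bmm(q, k.transpose(1, 2)).view(b, h, t, t)

assert torch.allclose(dot_einsum, dot_bmm, atol=1e-5)
```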

You're right that adding einsum will probably make the post more difficult to read, but it's good to know that it exists. I'll see if I can reference it somewhere from the blogpost.


@pbloem pbloem closed this Aug 29, 2019

@FabricioArendTorres commented Feb 18, 2021

As far as I know, tf.einsum does not yet optimize in a similar way to the einsum in numpy etc.

However, there is e.g.

So as of now, einsum just complicates things, imo.

