Einsum to avoid transpose and reshape #4

azzever opened this issue Aug 25, 2019 · 1 comment



commented Aug 25, 2019


Thank you for the great post about Transformers.
By the way, you can avoid the transpose/reshape steps by using torch.einsum.

Here is an example that behaves exactly like your implementation (except for mask=True and the asserts 😊):

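```python
import math

import torch
import torch.nn.functional as F

def forward_einsum(self, x):
    # x: (batch b, sequence length t, embedding dim e)
    b, t, e = x.size()
    h = self.heads

    # project and split into h heads of size e: (b, t, h, e)
    keys    = self.tokeys(x).view(b, t, h, e)
    queries = self.toqueries(x).view(b, t, h, e)
    values  = self.tovalues(x).view(b, t, h, e)

    # scaled attention scores per head: (b, h, t, t)
    dot = torch.einsum('bthe,bihe->bhti', queries, keys) / math.sqrt(e)
    dot = F.softmax(dot, dim=-1)  # normalize over the key positions

    # weighted sum of the values per head: (b, t, h, e)
    out = torch.einsum('bhtd,bdhe->bthe', dot, values)

    # merge the heads with unifyheads (weight shape (e, h*e)) without reshaping out.
    # We could move the reshape of the weights to __init__; I left it here just to compare with the original implementation
    out = torch.einsum('bthe,khe->btk', out, self.unifyheads.weight.view(e, h, e))
    return out + self.unifyheads.bias
```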
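To check that the einsum version behaves like the transpose/reshape version, the sketch below puts both forward passes on one module and compares their outputs on random input. The `SelfAttention` class, its constructor arguments and the sizes are assumptions for illustration: it presumes each head has the full embedding size `e`, so `tokeys`, `toqueries` and `tovalues` map `e -> h*e` without bias and `unifyheads` maps `h*e -> e`, which is what the `view(b, t, h, e)` and `view(e, h, e)` calls above imply. Masking is left out, as in the einsum code.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Hypothetical reconstruction of a multi-head self-attention layer.

    Assumption: every head has the full embedding size e, so the projections
    map e -> h*e and unifyheads maps h*e -> e.
    """

    def __init__(self, e, heads=8):
        super().__init__()
        self.heads = heads
        self.tokeys = nn.Linear(e, e * heads, bias=False)
        self.toqueries = nn.Linear(e, e * heads, bias=False)
        self.tovalues = nn.Linear(e, e * heads, bias=False)
        self.unifyheads = nn.Linear(heads * e, e)

    def forward(self, x):
        # transpose/reshape version: fold the heads into the batch dimension
        b, t, e = x.size()
        h = self.heads
        keys = self.tokeys(x).view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)
        queries = self.toqueries(x).view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)
        values = self.tovalues(x).view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

        dot = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(e)  # (b*h, t, t)
        dot = F.softmax(dot, dim=2)
        out = torch.bmm(dot, values).view(b, h, t, e)
        # move the heads next to the embedding dim and concatenate them
        out = out.transpose(1, 2).reshape(b, t, h * e)
        return self.unifyheads(out)

    def forward_einsum(self, x):
        # the einsum version from the comment above
        b, t, e = x.size()
        h = self.heads
        keys = self.tokeys(x).view(b, t, h, e)
        queries = self.toqueries(x).view(b, t, h, e)
        values = self.tovalues(x).view(b, t, h, e)
        dot = torch.einsum('bthe,bihe->bhti', queries, keys) / math.sqrt(e)
        dot = F.softmax(dot, dim=-1)
        out = torch.einsum('bhtd,bdhe->bthe', dot, values)
        out = torch.einsum('bthe,khe->btk', out, self.unifyheads.weight.view(e, h, e))
        return out + self.unifyheads.bias


torch.manual_seed(0)
layer = SelfAttention(e=16, heads=4)
x = torch.randn(2, 5, 16)
same = torch.allclose(layer(x), layer.forward_einsum(x), atol=1e-5)
print(same)  # True
```

If the two ever disagree beyond floating-point noise, the likely culprit is the head-concatenation order: `view(e, h, e)` on the `unifyheads` weight only matches a `(b, t, h, e) -> (b, t, h*e)` concatenation with the head index outermost.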

Although the code is much shorter, it is probably hard to follow for anyone who doesn't know einsum notation, so this is definitely not the best code for explaining the idea 😊.
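For readers new to the notation, here is a small sketch of what the first einsum string computes, written out two other ways: as a permute followed by a batched matrix product, and as an explicit sum for a single output entry. The tensor sizes and the chosen entry are arbitrary, picked only for illustration.

```python
import torch

torch.manual_seed(0)
b, t, h, e = 2, 3, 4, 5
queries = torch.randn(b, t, h, e)
keys = torch.randn(b, t, h, e)

# Each letter names an axis. Letters that appear in the inputs but not in the
# output ('e') are summed over; the others are kept, in the output's order.
dot = torch.einsum('bthe,bihe->bhti', queries, keys)  # shape (b, h, t, t)

# The same computation with explicit permutes and a batched matrix product:
# (b, h, t, e) @ (b, h, e, t) -> (b, h, t, t)
dot_matmul = queries.permute(0, 2, 1, 3) @ keys.permute(0, 2, 3, 1)

# One entry dot[bi, hi, ti, ii] written out as a plain sum over the e axis
bi, hi, ti, ii = 1, 2, 0, 2
entry = sum(queries[bi, ti, hi, d] * keys[bi, ii, hi, d] for d in range(e))

print(dot.shape)  # torch.Size([2, 4, 3, 3])
print(torch.allclose(dot, dot_matmul, atol=1e-6))
print(torch.allclose(dot[bi, hi, ti, ii], entry, atol=1e-5))
```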

pbloem self-assigned this Aug 25, 2019



commented Aug 26, 2019

Wow, cool. I admit einsum is a bit of a blind spot for me (one of many).

I'm curious: do you know whether this is actually faster than a transpose/reshape, or does it just end up transposing under the hood? I can't find much information on how einsum is implemented, but it seems to be based mostly on reshaping and applying bmm().
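The speed question is easiest to settle by measuring on the hardware and PyTorch version at hand. The sketch below times only the attention-score step, once with einsum and once with transpose, reshape and `torch.bmm`; the helper names and tensor sizes are hypothetical, and no timings are shown because they vary by machine. Since einsum may itself permute its operands and dispatch to a batched matrix multiply internally, neither version is guaranteed to be faster.

```python
import timeit

import torch

torch.manual_seed(0)
b, t, h, e = 8, 64, 8, 64  # arbitrary sizes for illustration
queries = torch.randn(b, t, h, e)
keys = torch.randn(b, t, h, e)


def scores_einsum():
    # attention scores computed directly from the (b, t, h, e) layout
    return torch.einsum('bthe,bihe->bhti', queries, keys)


def scores_transpose_bmm():
    # explicit version: fold the heads into the batch, then a batched matmul
    q = queries.transpose(1, 2).reshape(b * h, t, e)
    k = keys.transpose(1, 2).reshape(b * h, t, e)
    return torch.bmm(q, k.transpose(1, 2)).view(b, h, t, t)


# sanity check: both variants must produce the same scores before timing them
assert torch.allclose(scores_einsum(), scores_transpose_bmm(), atol=1e-4)

for fn in (scores_einsum, scores_transpose_bmm):
    # best of 5 repeats of 20 calls each; absolute numbers depend on the machine
    best = min(timeit.repeat(fn, number=20, repeat=5))
    print(f"{fn.__name__}: {best / 20 * 1e3:.3f} ms per call")
```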

You're right that adding einsum would probably make the post harder to read, but it's good to know that it exists. I'll see if I can reference it somewhere in the blog post.

pbloem closed this Aug 29, 2019
