Einsum to avoid transpose and reshape #4

Closed

azzever opened this issue Aug 25, 2019 · 3 comments

azzever commented Aug 25, 2019

Hi,

Thank you for the great post about Transformers.
Actually, you can avoid transpose/reshape using torch.einsum.

Here is an example that behaves exactly like your implementation (except for mask=True and the asserts 😊):

def forward_einsum(self, x):
    b, t, e = x.size()
    h = self.heads

    keys    = self.tokeys(x).view(b, t, h, e)
    queries = self.toqueries(x).view(b, t, h, e)
    values  = self.tovalues(x).view(b, t, h, e)

    # raw attention weights: contract the embedding dimension of queries and keys,
    # giving one t-by-t weight matrix per head
    dot = torch.einsum('bthe,bihe->bhti', queries, keys) / math.sqrt(e)
    dot = F.softmax(dot, dim=-1)

    # weighted sum over the values, per head (d runs over the key/value positions)
    out = torch.einsum('bhtd,bdhe->bthe', dot, values)

    # unify the heads by contracting (h, e) with the reshaped output projection.
    # we could move the reshape of the weights to init; it's left here just to
    # compare with the original implementation
    out = torch.einsum('bthe,khe->btk', out, self.unifyheads.weight.view(e, h, e))
    return out + self.unifyheads.bias

Although the code has become very short, it's probably hard to understand for people who don't know einsum notation, so this is definitely not the best code for explaining the idea 😊.

@pbloem pbloem self-assigned this Aug 25, 2019

pbloem commented Aug 26, 2019

Wow, cool. I admit einsum is a bit of a blind spot for me (one of many).

I'm curious, do you know whether this is actually faster than a transpose/reshape, or does it just end up transposing under the hood? I can't find much information on how einsum is implemented, but it seems to be based mostly on reshaping and applying bmm().

You're right that adding einsum will probably make the post more difficult to read, but it's good to know that it exists. I'll see if I can reference it somewhere from the blogpost.
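
One way to get a feel for the speed question is to time the two contractions directly. Below is a minimal, CPU-only sketch (sizes are arbitrary, and it only compares the query/key dot product, not the full forward pass):

import math
import time
import torch

b, t, h, e = 16, 128, 8, 64   # arbitrary sizes, chosen just for the comparison
queries = torch.randn(b, t, h, e)
keys    = torch.randn(b, t, h, e)

def dot_einsum():
    return torch.einsum('bthe,bihe->bhti', queries, keys) / math.sqrt(e)

def dot_bmm():
    # the transpose/reshape route: fold the heads into the batch dimension and use bmm
    q = queries.transpose(1, 2).contiguous().view(b * h, t, e)
    k = keys.transpose(1, 2).contiguous().view(b * h, t, e)
    return torch.bmm(q, k.transpose(1, 2)).view(b, h, t, t) / math.sqrt(e)

assert torch.allclose(dot_einsum(), dot_bmm(), atol=1e-5)

for name, fn in [('einsum', dot_einsum), ('bmm', dot_bmm)]:
    start = time.perf_counter()
    for _ in range(100):
        fn()
    print(name, time.perf_counter() - start)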

@pbloem pbloem closed this as completed Aug 29, 2019
FabricioArendTorres commented

As far as I know, tf.einsum does not yet optimize contractions the way numpy's einsum does.

However, there is, e.g., https://github.com/dgasmith/opt_einsum

So as of now, einsum just complicates things, imo.
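
For reference, opt_einsum is essentially a drop-in replacement: opt_einsum.contract takes the same subscript strings and picks a cheap contraction order, and recent numpy can do the same with optimize=True. A minimal sketch, assuming the package is installed:

import numpy as np
import opt_einsum as oe

a = np.random.rand(64, 32)
b = np.random.rand(32, 48)
c = np.random.rand(48, 16)

# same subscripts as np.einsum, but the contraction order is optimized
chain_np  = np.einsum('ij,jk,kl->il', a, b, c)
chain_oe  = oe.contract('ij,jk,kl->il', a, b, c)
chain_opt = np.einsum('ij,jk,kl->il', a, b, c, optimize=True)

print(np.allclose(chain_np, chain_oe))
print(np.allclose(chain_np, chain_opt))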

mfenner1 commented

@pbloem Thank you so much for this code and your wonderful blog post. Very helpful as I try to give a presentation to some non-specialist undergrad students.

As a thank you and as some quick material for anyone else trying to understand einsum attention, here are a few examples I made years ago that helped me start understanding (NumPy's) einsum:

import numpy as np

# np.random.rand needs integer sizes (floats like 1e4 raise a TypeError)
arr_1d = np.random.rand(10_000)
arr_2d = np.random.rand(1_000, 1_000)
arr_3d = np.random.rand(100, 100, 100)

def sum1():
    return np.sum(arr_3d)
def sum2():
    # empty output subscripts: every index is summed out
    return np.einsum('ijk->', arr_3d)

print(np.allclose(sum1(), sum2()))

def eltSq1():
    return arr_3d * arr_3d
def eltSq2():
    # same indices on inputs and output: elementwise multiply, nothing summed
    return np.einsum('ijk,ijk->ijk', arr_3d, arr_3d)

print(np.allclose(eltSq1(), eltSq2()))

def eltCube1():
    return arr_3d * arr_3d * arr_3d
def eltCube2():
    # same idea with three operands: elementwise cube
    return np.einsum('ijk,ijk,ijk->ijk', arr_3d, arr_3d, arr_3d)

print(np.allclose(eltCube1(), eltCube2()))

arr_1d = np.random.rand(5000)

def outer1():
    return np.outer(arr_1d, arr_1d)
def outer2():
    # distinct input indices, both kept in the output: outer product
    return np.einsum('i,j->ij', arr_1d, arr_1d)
def outer3():
    return arr_1d[:,np.newaxis] * arr_1d[np.newaxis, :]

o1 = outer1()
print(np.allclose(o1, outer2()))
print(np.allclose(o1, outer3()))
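
Continuing the same session, the notation scales straight up to the batched matrix products that appear in the attention code above; a small sketch:

b, t, e = 4, 10, 8
queries = np.random.rand(b, t, e)
keys    = np.random.rand(b, t, e)

def batched_dot1():
    # one t-by-t matrix of dot products per batch element
    return np.matmul(queries, keys.transpose(0, 2, 1))
def batched_dot2():
    # contract the shared embedding index e, keep the batch and both position indices
    return np.einsum('bte,bie->bti', queries, keys)

print(np.allclose(batched_dot1(), batched_dot2()))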
