
Einsum to avoid transpose and reshape #4

azzever opened this issue Aug 25, 2019 · 2 comments

@azzever commented Aug 25, 2019


Thank you for the great post about Transformers.
Actually, you can avoid transpose/reshape using torch.einsum.

Here is an example that behaves exactly like your implementation (except for mask=True and the asserts 😊):

# assumes the imports from the post: import math; import torch; import torch.nn.functional as F
def forward_einsum(self, x):
    b, t, e = x.size()
    h = self.heads

    keys    = self.tokeys(x).view(b, t, h, e)
    queries = self.toqueries(x).view(b, t, h, e)
    values  = self.tovalues(x).view(b, t, h, e)

    dot = torch.einsum('bthe,bihe->bhti', queries, keys) / math.sqrt(e)
    dot = F.softmax(dot, dim=-1)

    out = torch.einsum('bhtd,bdhe->bthe', dot, values)

    # the weight reshape could move to __init__; it's kept here to compare with the original implementation
    out = torch.einsum('bthe,khe->btk', out, self.unifyheads.weight.view(e, h, e))
    return out + self.unifyheads.bias
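(The last einsum is just the `unifyheads` linear layer written out by hand. A quick sketch on random tensors checks the equivalence; the sizes are illustrative, and it assumes `unifyheads = nn.Linear(h * e, e)` as in the post.)

```python
import torch
import torch.nn as nn

b, t, h, e = 2, 8, 4, 16  # illustrative sizes
unifyheads = nn.Linear(h * e, e)  # assumed layer shape, matching weight.view(e, h, e)
out = torch.randn(b, t, h, e)

# einsum with the reshaped weight, as in the snippet above
via_einsum = torch.einsum('bthe,khe->btk', out, unifyheads.weight.view(e, h, e)) + unifyheads.bias

# equivalent: flatten the head dimension and apply the linear layer directly
via_linear = unifyheads(out.reshape(b, t, h * e))

assert torch.allclose(via_einsum, via_linear, atol=1e-5)
```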

Although the code became very short, it's probably hard to understand for people who don't know einsum notation, so this is definitely not the best code to explain the idea 😊.

@pbloem self-assigned this Aug 25, 2019

@pbloem commented Aug 26, 2019

Wow, cool. I admit einsum is a bit of a blind spot for me (one of many).

I'm curious, do you know whether this is actually faster than a transpose/reshape, or does it just end up transposing under the hood? I can't find much information on how einsum is implemented, but it seems to be based mostly on reshaping and applying bmm().
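(One way to probe this is to check that the two routes compute the same thing on random tensors, and then time each line, e.g. with `timeit`. The sketch below uses illustrative sizes and only covers the query–key dot products.)

```python
import torch

b, t, h, e = 2, 8, 4, 16  # illustrative sizes
queries = torch.randn(b, t, h, e)
keys = torch.randn(b, t, h, e)

# einsum route: contract over the embedding dimension directly
dot_einsum = torch.einsum('bthe,bihe->bhti', queries, keys)

# transpose/reshape route: fold heads into the batch dimension, then bmm
q = queries.transpose(1, 2).reshape(b * h, t, e)
k = keys.transpose(1, 2).reshape(b * h, t, e)
dot_bmm = torch.bmm(q, k.transpose(1, 2)).view(b, h, t, t)

assert torch.allclose(dot_einsum, dot_bmm, atol=1e-5)
```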

You're right that adding einsum will probably make the post more difficult to read, but it's good to know that it exists. I'll see if I can reference it somewhere from the blogpost.


@pbloem pbloem closed this Aug 29, 2019

@FabricioArendTorres commented Feb 18, 2021

As far as I know, tf.einsum does not yet optimize in a similar way to the einsum in numpy etc.

However, there is e.g.

So as of now, einsum just complicates things, imo.

