In [2]:
from einops import rearrange, reduce, asnumpy, parse_shape
import numpy as np

### einsum rules

https://numpy.org/doc/stable/reference/generated/numpy.einsum.html

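* The subscripts string is a comma-separated list of subscript labels, where each label refers to a dimension of the corresponding operand.
* In *implicit* mode (no `->`), whenever a label is repeated it is summed, e.g. `np.einsum('i,i', a, b)` is `np.inner(a, b)`.
* In implicit mode, a label that appears only once is not summed; it becomes an output axis (output axes are ordered alphabetically by label).
* In *explicit* mode (with `->`), the output labels decide. Every label **not** listed after `->` is summed over, whether or not it is repeated. Every label that **is** listed is kept, even if it is repeated across operands; it then acts as a batch / element-wise index. So summation can be forced or disabled.
* NumPy-style broadcasting is done by adding an ellipsis to the left of each term, like `np.einsum('...ii->...i', a)`.
* It's possible to have nothing on the RHS of the `->` in explicit mode. Then every label is summed, e.g. `np.einsum('i,i->', a, b)` returns the scalar inner product.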

### Questions

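* What happens to an index that is not repeated and not part of the explicit-mode output?
  * **Answer:** it is summed over. In explicit mode every label that does not appear after `->` is summed, repeated or not. For example, in `np.einsum('ijk,jih->ih', x, y)` below, `k` (which appears only in `x`) is summed out along with `j`: $\text{out}_{ih} = \sum_j \sum_k x_{ijk}\, y_{jih}$.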

In [4]:
# Two small integer tensors with different axis orders, so the einsum examples
# below can permute and contract axes purely by label:
#   x: shape (3, 4, 2) -> labelled (i, j, k) in the cells below
#   y: shape (4, 3, 2) -> labelled (j, i, h) / (j, i, k) in the cells below
x = np.arange(24).reshape(3, 4, 2)
y = np.arange(24, 48).reshape(4, 3, 2)
for tensor in (x, y):
    print(tensor)

[[[ 0  1]
  [ 2  3]
  [ 4  5]
  [ 6  7]]

 [[ 8  9]
  [10 11]
  [12 13]
  [14 15]]

 [[16 17]
  [18 19]
  [20 21]
  [22 23]]]
[[[24 25]
  [26 27]
  [28 29]]

 [[30 31]
  [32 33]
  [34 35]]

 [[36 37]
  [38 39]
  [40 41]]

 [[42 43]
  [44 45]
  [46 47]]]


In [5]:
# Explicit mode: every label absent from the output 'ih' is summed over.
# That is j (shared by both operands) AND k (appears only in x), NOT i:
# i is shared by both operands but kept in the output, so it acts as a
# batch index (x[i] is paired with y[:, i] for each i).
#   out[i, h] = sum_j sum_k x[i, j, k] * y[j, i, h]
#             = sum_j (sum_k x[i, j, k]) * y[j, i, h]
out_ih = np.einsum('ijk,jih->ih', x, y)
# Same result with k summed out of x explicitly first.
assert np.array_equal(out_ih, np.einsum('ij,jih->ih', x.sum(axis=2), y))
out_ih

array([[1044, 1072],
       [3340, 3432],
       [5892, 6048]])

In [14]:
# Hand-check of element [0, 0] of np.einsum('ijk,jih->ih', x, y).
# The first attempt pairs the wrong slices: x[:, 0] is x[i, 0, k] and y[0] is
# y[0, i, h] (both fix j = 0). Flattening both and taking the inner product
# computes sum_{i,k} x[i, 0, k] * y[0, i, k]. That sums over i (which the
# einsum keeps) and pairs k with h, so it is a different contraction: 1417.
print(x[:,0])
print(y[0])

# Correct check: fix the output position (i=0, h=0) and sum over j and k:
#   out[0, 0] = sum_j sum_k x[0, j, k] * y[j, 0, 0]
expected_00 = np.sum(x[0] * y[:, 0, 0][:, None])
assert expected_00 == np.einsum('ijk,jih->ih', x, y)[0, 0] == 1044

np.inner(x[:,0].reshape(-1), y[0].reshape(-1))  # 1417: NOT out[0, 0], see above

[[ 0  1]
 [ 8  9]
 [16 17]]
[[24 25]
 [26 27]
 [28 29]]


1417

In [6]:
# Now k appears in both operands AND in the output 'ik', so it is no longer
# summed. Like i, it is a batch (element-wise) index. Only j, which is absent
# from the output, is summed:
#   out[i, k] = sum_j x[i, j, k] * y[j, i, k]
# Compare with 'ijk,jih->ih', where k was summed out and h was a free axis of y.
out_ik = np.einsum('ijk,jik->ik', x, y)
# Same thing with plain broadcasting: align y to (i, j, k), multiply, sum over j.
assert np.array_equal(out_ik, (x * y.transpose(1, 0, 2)).sum(axis=1))
out_ik

array([[ 456,  604],
       [1600, 1788],
       [2872, 3100]])

In [6]:
# Batched (3-D) matrix multiplication three ways: a is (2, 2, 4), b is (2, 4, 2).
a = np.arange(2 * 2 * 4).reshape((2, 2, 4))
b = np.arange(2 * 2 * 4).reshape((2, 4, 2))

mm = np.matmul(a, b)
print("here comes matmul\n", mm)

# Ellipsis form: '...' stands for the leading (batch) axes, like matmul's broadcasting.
mm_ellipsis = np.einsum('...ij,...jk->...ik', a, b)
print("here comes einsum 1\n", mm_ellipsis)

# Explicit batch label: 'i' appears in both operands. Because it is also kept
# in the output 'ijh', it is NOT summed; it is the batch index. In explicit
# mode a label is summed only when it is left out of the output, which here is
# 'k' alone:
#   out[i, j, h] = sum_k a[i, j, k] * b[i, k, h]
mm_explicit = np.einsum('ijk,ikh->ijh', a, b)
print("here comes einsum 2\n", mm_explicit)

# All three agree exactly (integer inputs, so no rounding differences).
assert np.array_equal(mm, mm_ellipsis)
assert np.array_equal(mm, mm_explicit)

here comes matmul
 [[[ 28  34]
  [ 76  98]]

 [[428 466]
  [604 658]]]
here comes einsum 1
 [[[ 28  34]
  [ 76  98]]

 [[428 466]
  [604 658]]]
here comes einsum 2
 [[[ 28  34]
  [ 76  98]]

 [[428 466]
  [604 658]]]
