# Einsum (Einsum is all you need)

- **Free Indices:** Are the indices specified in the outpus
- **Summation Indices:** All other indices. Those that appear in the input argument
  but Not in output specification.

  ex: 矩阵乘法运算中 $A_{i k} \cdot B_{k  j} = Result_{ij}$， i 和 j 就是 Free Indices, k 是 Summation Indices.

In [1]:
import torch

In [3]:
mat_x = torch.rand(3, 5)
mat_y = torch.rand(5, 2)
mat_c = torch.empty((3, 2))

In [4]:
for i in range(mat_x.shape[0]):
    for j in range(mat_y.shape[1]):
        tmp_total = 0
        for k in range(mat_x.shape[1]):
            tmp_total += mat_x[i][k] * mat_y[k][j]

        mat_c[i][j] = tmp_total

In [5]:
mat_t1 = mat_x @ mat_y
mat_t1

tensor([[1.9031, 0.7166],
        [0.8968, 0.6699],
        [1.3841, 0.8715]])

In [7]:
mat_c == mat_t1

tensor([[True, True],
        [True, True],
        [True, True]])

In [8]:
mat_einsum = torch.einsum('ik, kj -> ij', mat_x, mat_y)

In [9]:
mat_einsum

tensor([[1.9031, 0.7166],
        [0.8968, 0.6699],
        [1.3841, 0.8715]])

上例中 i, j 是 Free Indices,
k 是 Summation Indices.
在 k 循环中求和

In [18]:
a = torch.rand(5)
b = torch.rand(3)
c = torch.empty((5, 3))

outer = torch.einsum('i, j -> ij', a, b)

In [19]:
print(a)
print(b)
print(outer)

tensor([0.8842, 0.6382, 0.2734, 0.6178, 0.3061])
tensor([0.6580, 0.6709, 0.2253])
tensor([[0.5818, 0.5932, 0.1992],
        [0.4199, 0.4282, 0.1438],
        [0.1799, 0.1834, 0.0616],
        [0.4065, 0.4145, 0.1392],
        [0.2014, 0.2054, 0.0690]])


In [20]:
for i in range(a.shape[0]):
    for j in range(b.shape[0]):
        c[i][j] = a[i] * b[j]

In [21]:
c

tensor([[0.5818, 0.5932, 0.1992],
        [0.4199, 0.4282, 0.1438],
        [0.1799, 0.1834, 0.0616],
        [0.4065, 0.4145, 0.1392],
        [0.2014, 0.2054, 0.0690]])

In [22]:
x = torch.ones(3)
x

tensor([1., 1., 1.])

In [25]:
torch.einsum('i ->', x)

tensor(3.)

自己的一点小理解，就是循环每个索引。

先循环结果中的索引(i, j)，再循环结果中没有的索引(k)。

对于结果中没有的索引，计算后再相加进行累积，然后填入到结果中。

## RULES:
1. Repeating letters in different inputs means 
   those values will be multiplied and those products will be the output.
2. Omitting a letter means that axis will be summed.
3. We can return the unsummed axes in any order.

## Code Show

In [2]:
import torch

x = torch.rand((2, 3))
x

tensor([[0.3718, 0.5213, 0.8097],
        [0.1012, 0.9702, 0.6299]])

In [47]:
# Permutation of Tensor

len_i = x.shape[0]
len_j = x.shape[1]
result = torch.empty((len_j, len_i))
for i in range(len_i):
    for j in range(len_j):
        result[j][i] = x[i][j]
        
t = torch.einsum("ij -> ji", x)

print(torch.all(result == t))
print(t)

tensor(True)
tensor([[0.5497, 0.7491],
        [0.9246, 0.6787],
        [0.5731, 0.8049]])


In [51]:
# Summation
len_i = x.shape[0]
len_j = x.shape[1]
result = torch.empty((1))

tmp_sum = 0
for i in range(len_i):
    for j in range(len_j):
        tmp_sum += x[i][j]

result = tmp_sum

t = torch.einsum("ij ->", x)
print(t == result)
print(t)

tensor(True)
tensor(4.2801)


In [3]:
# Column Sum
# column sum 即把元素按列求和， 如果有 m 列，每一列都求和得到 (1,m)。 (n, m) -> (1, m)
# 即有多少列，就能得到几个数，每一列的数都被求和
len_i = x.shape[0]
len_j = x.shape[1]
result = torch.empty((len_j))


for i in range(len_i):
    for j in range(len_j):
        result[j] += x[i][j]

t = torch.einsum("ij -> j", x)
print(torch.all(t == result))
print(t)

tensor(False)
tensor([0.4730, 1.4914, 1.4396])


In [54]:
# Row Sum
len_i = x.shape[0]
len_j = x.shape[1]
result = torch.empty((len_i))

for i in range(len_i):
    tmp_total = 0
    for j in range(len_j):
        tmp_total += x[i][j]
    result[i] = tmp_total

t = torch.einsum("ij -> i", x)
print(torch.all(t == result))
print(t)

tensor(True)
tensor([2.0475, 2.2326])


In [55]:
# Matrix-Vector Multiplication
v = torch.rand((1, 3))

len_i = x.shape[0]
len_j = x.shape[1]
len_k = v.shape[0]
result = torch.empty((len_i, len_k))

for i in range(len_i):
    for k in range(len_k):
        tmp_total = 0
        for j in range(len_j):
            tmp_total += x[i][j] * v[k][j]
        result[i][k] = tmp_total

t = torch.einsum("ij, kj -> ik", x, v)  # (2, 3) @ (3, 1)
print(torch.all(t == result))
print(t)

tensor(True)
tensor([[1.3543],
        [1.5067]])


In [57]:
# Matrix-Matrix Multiplication
len_i = x.shape[0]
len_j = x.shape[1]
len_k = x.shape[0]
result = torch.empty((len_i, len_k))

for i in range(len_i):
    for k in range(len_k):
        tmp_total = 0
        for j in range(len_j):
            tmp_total += x[i][j] * x[k][j]
        result[i][k] = tmp_total

t = torch.einsum("ij, kj -> ik", x, x)  # (2, 3) @ (3, 2)
print(torch.all(t == result))
print(t)

tensor(True)
tensor([[1.4857, 1.5006],
        [1.5006, 1.6696]])


In [37]:
# Dot product first row with first row of matrix
torch.einsum("i, i -> ", x[0], x[0])

tensor(1.4857)

In [43]:
# Dot product with matrix
torch.einsum("ij, ij ->",x, x)

tensor(3.1552)

In [58]:
# Hadamard Product (element-wise multiplication)
len_i = x.shape[0]
len_j = x.shape[1]
result = torch.empty((len_i, len_j))

for i in range(len_i):
    for j in range(len_j):
        result[i][j] = x[i][j] * x[i][j]

t = torch.einsum("ij, ij -> ij", x, x)
torch.all(t == result)

tensor(True)

In [68]:
# Outer Product
a = torch.rand((3))
b = torch.rand((5))

len_i = a.shape[0]
len_j = b.shape[0]
result = torch.empty((len_i, len_j))

for i in range(len_i):
    for j in range(len_j):
        result[i][j] = a[i] * b[j]

t = torch.einsum("i, j -> ij", a, b)
# torch.all(t, result)
t == result

tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

In [69]:
# Batch Matrix Multiplication
a = torch.rand((3 , 2, 5))
b = torch.rand((3, 5, 3))
torch.einsum("ijk, ikl -> ijl", a, b)

tensor([[[2.2354, 1.4121, 1.6984],
         [2.4894, 1.5692, 1.6881]],

        [[1.2697, 1.6723, 0.9452],
         [0.7643, 1.5241, 0.7664]],

        [[1.2094, 0.4053, 0.4002],
         [1.1668, 0.5641, 0.7423]]])

In [76]:
x = torch.rand((3, 3))

In [77]:
# Matrix Diagonal
len_i = x.shape[0]
result = torch.empty((len_i))

for i in range(len_i):
    result[i] = x[i][i]

t = torch.einsum("ii -> i", x)
t == result

tensor([True, True, True])

In [84]:
len_i = x.shape[0]
len_j = x.shape[1]
result = torch.empty((len_i))

for i in range(len_i):
    tmp_total = 0
    for j in range(len_j):
        tmp_total += x[i][j]
    result[i] = tmp_total

t = torch.einsum("ij -> i", x)
t == result

tensor([True, True, True])

In [82]:
# Matrix Trace
len_i = x.shape[0]
result = torch.empty((1))

for i in range(len_i):
    result += x[i][i]

t = torch.einsum("ii->", x)
t == result

tensor([True])

## vector attention test

In [97]:
attn = torch.rand((4, 10, 3, 32))
value = torch.rand((4, 10, 3, 32))

batch_size, num_points, k_num, feature = attn.shape

# result = torch.empty((batch_size, num_points, feature))
result = torch.zeros((batch_size, num_points, feature))

for b in range(batch_size):
    for n in range(num_points):
        for k in range(k_num):
            for f in range(feature):
                result[b, n, f] += attn[b, n, k, f] * value[b, n, k, f]


t = torch.einsum('bnkf,bnkf->bnf', attn, value)
torch.all(t == result)

tensor(True)

In [87]:
t.shape

torch.Size([4, 10, 32])