<img src="attachments/einsum.png" width='500'>

<br> source: https://rockt.ai/2018/04/30/einsum

free indices = outer loop<br>
summation indices = inner loop

In [2]:
import torch
import itertools

def is_same(a, b):
    """Return True when tensors ``a`` and ``b`` have the same shape and equal elements.

    The shape is checked first because ``a == b`` broadcasts: without the
    guard a ``(2,)`` tensor of ones would compare "equal" to a ``(2, 2)``
    tensor of ones, and non-broadcastable shapes would raise instead of
    simply reporting a mismatch.

    Dtypes may differ (e.g. an integer result vs. a float buffer); elements
    are compared with ``==``.

    Returns:
        bool: plain Python bool, usable directly in ``assert``.
    """
    if a.shape != b.shape:
        return False
    return bool(torch.all(a == b))

def all_combination(*sizes):
    """Iterate over every index tuple of a grid with the given dimension sizes.

    For sizes ``(n0, n1, ...)`` this yields tuples ``(i0, i1, ...)`` with each
    ``ik`` drawn from ``range(nk)``; the last index varies fastest.
    """
    axes = map(range, sizes)  # one index range per dimension
    return itertools.product(*axes)  # cartesian product of those ranges

# Batched Matrix Multiplication (batched Dot Product)

In [22]:
B, T, C = 2, 3, 5

Q = torch.randint(low=0, high=10, size=(B, T, C))
K = torch.randint(low=0, high=10, size=(B, T, C))

# reference: batched matmul, (B,T,C) x (B,C,T) -> (B,T,T)
y = torch.matmul(Q, K.transpose(-2, -1))

# einsum on the explicitly transposed K
y_ = torch.einsum('btc,bcT->btT', Q, K.transpose(-2, -1))
assert is_same(y, y_)

# einsum on K as stored: labelling its axes 'bTc' makes the transpose implicit
y_ = torch.einsum('btc,bTc->btT', Q, K)
assert is_same(y, y_)

# the same computation with explicit loops
# free indices (outer loop): b, t, T -- each output element is computed
#   independently, so the order in which they are visited does not matter
# summation index (inner loop): c -- it appears in both inputs but not in the output
y_ = torch.empty((B, T, T))
for b, tq, tk in all_combination(B, T, T):
    # multiply along c and add up the products
    y_[b, tq, tk] = sum(Q[b, tq, c] * K[b, tk, c] for c in range(C))
assert is_same(y, y_)


# extra
# batched with head: h is one more index shared by both inputs and kept in the output
H = 5
x = torch.randint(low=0, high=10, size=(B,H,T,C))
M = torch.randint(low=0, high=10, size=(B,H,C,T))
y = torch.einsum('bhtc, bhcT -> bhtT', x, M)
# same contraction on inputs whose head axis was moved to position 2;
# the subscripts are permuted to match, so the result must not change
y_ = torch.einsum('bthc, bchT -> bhtT', x.transpose(1,2), M.transpose(1,2))
assert y_.shape == torch.Size([B, H, T, T])
# compare values too, not only the shape: both layouts describe the same product
assert is_same(y, y_)

# transpose

In [7]:
B, T, H, C = 2, 3, 4, 5
x = torch.randint(low=0, high=10, size=(B, T, H, C))

# swap axes 1 and 2: (B,T,H,C) -> (B,H,T,C)
y = torch.transpose(x, 1, 2)
# every input index also appears in the output, so nothing is summed;
# einsum only reorders the axes
y_ = torch.einsum('bthc->bhtc', x)
assert is_same(y, y_)

# Sum

In [17]:
B, T = 2, 3

x = torch.randint(low=0, high=10, size=(B, T, T))

# --- sum over the last dimension: T is dropped from the output ---
y = torch.sum(x, dim=-1)
y_ = torch.einsum('btT->bt', x)
assert is_same(y, y_)

# free indices: b, t (outer loop); summation index: T (inner loop)
y_ = torch.empty((B, T))
for b, t1 in all_combination(B, T):
    y_[b, t1] = sum(x[b, t1, t2] for t2 in range(T))
assert is_same(y, y_)

# --- sum over the second dimension: t is dropped from the output ---
y = torch.sum(x, dim=1)
y_ = torch.einsum('btT->bT', x)
assert is_same(y, y_)

# free indices: b, T (outer loop); summation index: t (inner loop)
y_ = torch.empty((B, T))
for b, t2 in all_combination(B, T):
    y_[b, t2] = sum(x[b, t1, t2] for t1 in range(T))
assert is_same(y, y_)

# --- sum over the first dimension: b is dropped from the output ---
y = torch.sum(x, dim=0)
y_ = torch.einsum('btT->tT', x)
assert is_same(y, y_)

# --- sum of all elements: empty output subscript, every index is summed ---
y = torch.sum(x)
y_ = torch.einsum('btT->', x)
assert is_same(y, y_)

# Matrix-Vector Multiplication (element-wise multiply, then sum)

In [34]:
B, T = 2, 3
m = torch.randint(low=0, high=10, size=(B, T))
n = torch.randint(low=0, high=10, size=(T,))

# t is shared by both inputs and absent from the output, so it is summed over
y = torch.einsum('bt,t->b', m, n)
y_ = torch.sum(m * n, dim=1)  # n is broadcast across the rows of m, then each row is reduced
assert is_same(y, y_)

# free index: b (outer loop); summation index: t (inner loop)
y_ = torch.empty(B)
for b in range(B):
    y_[b] = sum(m[b, t] * n[t] for t in range(T))
assert is_same(y, y_)

# Hadamard Product (element wise multiplication)

In [2]:
B, T, C = 2, 3, 4

a = torch.randint(low=0, high=10, size=(B,T,C))

# every index appears in the output, so nothing is summed: plain element-wise product
y = torch.einsum('btc,btc->btc', a, a)
y_ = a * a
assert is_same(y, y_)

# free indices: btc
# summation indices: None (no inner loop, hence no running total)
y_ = torch.empty(B,T,C)
for b, t, c in all_combination(B,T,C):
    y_[b,t,c] = a[b,t,c] * a[b,t,c]
assert is_same(y, y_)

# Outer Product

In [9]:
l = 3
a = torch.randint(low=0, high=4, size=(l,))
b = torch.randint(low=5, high=10, size=(l,))

y = torch.einsum('i,j->ij', a, b)
# column vector times row vector: [i,1] * [1,j] broadcasts to [i,j]
y_ = a[:, None] * b[None, :]
assert is_same(y, y_)

# free indices: i, j
# summation indices: none, so each output entry is a single product
y_ = torch.empty(l, l)
for i in range(l):
    for j in range(l):
        y_[i, j] = a[i] * b[j]
assert is_same(y, y_)

# tensor contraction (IMPORTANT)

Most tensor operations are tensor contractions (compare with the operations in the previous sections).

In [23]:
U, V = 3, 5
a, b, c, d, e = 2, 7, 11, 13, 17

x1 = torch.randint(low=0, high=4, size=(a, U, V, b))
x2 = torch.randint(low=5, high=10, size=(c, d, U, e, V))

y = torch.einsum('aUVb,cdUeV->abcde', x1, x2)

# free indices (outer loop): a, b, c, d, e
# summation indices (inner loop): U, V -- shared by both inputs, absent from the output
y_ = torch.empty(a, b, c, d, e)
for ai, bi, ci, di, ei in all_combination(a, b, c, d, e):
    # contract the (U, V) patch of x1 against the matching patch of x2
    y_[ai, bi, ci, di, ei] = sum(
        x1[ai, u, v, bi] * x2[ci, di, u, ei, v]
        for u, v in all_combination(U, V)
    )
assert is_same(y, y_)


# batch matrix multiplication is a special case of tensor contraction
B, T, C, D = 2, 3, 4, 7
x = torch.randint(low=0, high=10, size=(B, T, C))
M = torch.randint(low=0, high=10, size=(B, C, D))

# free indices: b, t, d -- b labels an axis of both inputs but is kept in the
#   output, so it is matched up rather than repeated or summed (the special case)
# summation index: c
y_ = torch.einsum('btc,bcd->btd', x, M)
assert y_.shape == (B, T, D)

# with different batch labels (b vs B) both batch axes show up in the output
B1, B2 = 2, 3
x = torch.randint(low=0, high=10, size=(B1, T, C))
M = torch.randint(low=0, high=10, size=(B2, C, D))
y_ = torch.einsum('btc,Bcd->bBtd', x, M)
assert y_.shape == (B1, B2, T, D)

# batched with head: h behaves like a second shared batch index
H = 5
x = torch.randint(low=0, high=10, size=(B, H, T, C))
M = torch.randint(low=0, high=10, size=(B, H, C, D))
y_ = torch.einsum('bhtc, bhcd -> bhtd', x, M)
assert y_.shape == (B, H, T, D)

# this is also why the second T needs its own label in the einsum:
# it must become a separate dimension of the output
x = torch.randint(low=0, high=10, size=(B, H, T, C))
M = torch.randint(low=0, high=10, size=(B, H, C, T))
y_ = torch.einsum('bhtc, bhcT -> bhtT', x, M)
assert y_.shape == (B, H, T, T)

# Bilinear Transformation

In [25]:
i,j,K,L = 2,3,5,7

x1 = torch.randn(i, K)
x2 = torch.randn(j, K, L)
x3 = torch.randn(i, L)

# free indices: ij
# summation indices: KL
y = torch.einsum('iK, jKL, iL -> ij', x1, x2, x3)

# the same result in two pairwise steps:
# iK, jKL -> ijL (tensor contraction)
# ijL, iL -> ij (batched matrix vector)
y_ = torch.einsum('iK, jKL -> ijL', x1, x2)
y_ = torch.einsum('ijL, iL -> ij', y_, x3)
# The inputs are floats, and the single three-operand call is not guaranteed to
# add the products in the same order as the two-step version, so compare with a
# tolerance instead of exact equality.
assert torch.allclose(y, y_, rtol=1e-4, atol=1e-4)