`tf.einsum`: tensor contraction over specified indices, and outer products.

https://www.tensorflow.org/api_docs/python/tf/einsum


Related reading


*  https://ita9naiwa.github.io/numeric%20calculation/2018/11/10/Einsum.html



In [6]:
import tensorflow as tf

### Matrix multiplication (matmul)

In [7]:
# A is (3, 4) and B is (4, 5). In "ij,jk -> ik" the shared index j does not
# appear in the output, so it is summed over: that is exactly the matrix
# product C = A @ B, with shape (3, 5).
A = tf.constant([[2,6,5,2],
              [2,-2,2,3],
              [1,5,4,0]])
B = tf.constant([[2,9,0,3,0],
              [3,6,8,-2,2],
              [1,3,5,0,1],
              [3,0,2,0,5]])

print(A.shape)
print(B.shape)

C_matmul = tf.matmul(A, B)
C_einsum = tf.einsum("ij,jk -> ik", A, B)

print("Matmul C =:\n")
print(C_matmul, "\n")

print("Einsum C =:\n" )
print(C_einsum)

# Verify the equivalence programmatically instead of comparing printouts by eye.
tf.debugging.assert_equal(C_matmul, C_einsum)

(3, 4)
(4, 5)
Matmul C =:

tf.Tensor(
[[33 69 77 -6 27]
 [ 9 12  0 10 13]
 [21 51 60 -7 14]], shape=(3, 5), dtype=int32) 

Einsum C =:

tf.Tensor(
[[33 69 77 -6 27]
 [ 9 12  0 10 13]
 [21 51 60 -7 14]], shape=(3, 5), dtype=int32)


### Element-wise (Hadamard) multiplication

In [9]:
# Using the same indices on both inputs *and* keeping them all in the output
# ("ij,ij -> ij") multiplies matching entries without summing anything.
# This is the element-wise (Hadamard) product, identical to A * B.
A = tf.constant([[2,6,5,2],
              [2,-2,2,3],
              [1,5,4,0]])
B = tf.constant([[2,9,0,3],
              [3,6,8,-2],
              [1,3,5,0]])
print(A.shape, B.shape)

C_hadamard = A * B
C_einsum = tf.einsum("ij,ij -> ij", A, B)

print("Hadamard C =: \n")
print(C_hadamard, "\n")

print("Einsum C =:\n" )
print(C_einsum)

# Verify the equivalence programmatically instead of comparing printouts by eye.
tf.debugging.assert_equal(C_hadamard, C_einsum)

(3, 4) (3, 4)
Hardamond C =: 

tf.Tensor(
[[  4  54   0   6]
 [  6 -12  16  -6]
 [  1  15  20   0]], shape=(3, 4), dtype=int32) 

Einsum C =:

tf.Tensor(
[[  4  54   0   6]
 [  6 -12  16  -6]
 [  1  15  20   0]], shape=(3, 4), dtype=int32)


### Transpose

In [12]:
# Writing the input indices in a different order on the output side
# ("ij -> ji") permutes the axes without summing anything: a transpose.
A = tf.constant([[2,6,5,2],
              [2,-2,2,3],
              [1,5,4,0]])

A_T = tf.transpose(A)
A_T_einsum = tf.einsum("ij -> ji", A)

print("Transposed A =: \n")
print(A_T, "\n")

print("Einsum Transpose A =:\n" )
print(A_T_einsum)

# Verify the equivalence programmatically instead of comparing printouts by eye.
tf.debugging.assert_equal(A_T, A_T_einsum)

Transposed A =: 

tf.Tensor(
[[ 2  2  1]
 [ 6 -2  5]
 [ 5  2  4]
 [ 2  3  0]], shape=(4, 3), dtype=int32) 

Einsum Transpose A =:

tf.Tensor(
[[ 2  2  1]
 [ 6 -2  5]
 [ 5  2  4]
 [ 2  3  0]], shape=(4, 3), dtype=int32)


### Batch matrix multiplication (3D)

In [14]:
# A is (2, 3, 4) and B is (2, 4, 5). The batch index b appears in both inputs
# and in the output, so it is carried through untouched. Only j is summed out,
# so "bij,bjk -> bik" performs one (3, 4) @ (4, 5) product per batch entry.
A = tf.constant([
              [[2,6,5,2],
              [2,-2,2,3],
              [1,5,4,0]],
              [[1,3,1,22],
               [0,2,2,0],
               [1,5,4,1]]
              ])
B = tf.constant([
                [[2,9,0,3,0],
                [3,6,8,-2,2],
                [1,3,5,0,1],
                [3,0,2,0,5]],
                [[1,0,0,3,0],
                [3,0,4,-2,2],
                [1,0,2,0,0],
                [3,0,1,1,0]]])
print(A.shape, B.shape)

C_matmul = tf.matmul(A, B)
C_einsum = tf.einsum("bij,bjk -> bik", A, B)

print("Batch Matmul C =:\n")
print(C_matmul, "\n")

print("Batch Einsum C =:\n" )
print(C_einsum)

# Verify the equivalence programmatically instead of comparing printouts by eye.
tf.debugging.assert_equal(C_matmul, C_einsum)

(2, 3, 4) (2, 4, 5)
Batch Matmul C =:

tf.Tensor(
[[[33 69 77 -6 27]
  [ 9 12  0 10 13]
  [21 51 60 -7 14]]

 [[77  0 36 19  6]
  [ 8  0 12 -4  4]
  [23  0 29 -6 10]]], shape=(2, 3, 5), dtype=int32) 

Batch Einsum C =:

tf.Tensor(
[[[33 69 77 -6 27]
  [ 9 12  0 10 13]
  [21 51 60 -7 14]]

 [[77  0 36 19  6]
  [ 8  0 12 -4  4]
  [23  0 29 -6 10]]], shape=(2, 3, 5), dtype=int32)


### Sum

In [19]:
# Any index that is missing from the output side is summed over. With an
# empty output ("bij ->") every index is summed, so the result is a scalar.
A = tf.constant([
              [[2,6,5,2],
              [2,-2,2,3],
              [1,5,4,0]],
              [[1,3,1,22],
               [0,2,2,0],
               [1,5,4,1]]
              ])

total_reduce = tf.math.reduce_sum(A)
total_einsum = tf.einsum("bij ->", A)  # empty output: sum over all elements

print("Sum A =:\n")
print(total_reduce, "\n")

print("Einsum A =:\n" )
print(total_einsum)

# Verify the equivalence programmatically instead of comparing printouts by eye.
tf.debugging.assert_equal(total_reduce, total_einsum)

Sum A =:

tf.Tensor(72, shape=(), dtype=int32) 

Einsum A =:

tf.Tensor(72, shape=(), dtype=int32)


In [23]:
# Keeping only j in the output ("ij -> j") sums over i, the row index. The
# result is one total per column, the same as reduce_sum over axis 0.
A = tf.constant([[2,6,5,2],
              [2,-2,2,3],
              [1,5,4,0]])

col_sums_reduce = tf.math.reduce_sum(A, axis=0)  # axis 0 (rows) is summed out
col_sums_einsum = tf.einsum('ij -> j', A)       # i is dropped, so it is summed over

print("Sum up all columns of A =:\n")
print(col_sums_reduce, '\n')

print("Einsum A =:\n")
print(col_sums_einsum)

# Verify the equivalence programmatically instead of comparing printouts by eye.
tf.debugging.assert_equal(col_sums_reduce, col_sums_einsum)

Sum up all colums of A =:

tf.Tensor([ 5  9 11  5], shape=(4,), dtype=int32) 

Einsum A =:

tf.Tensor([ 5  9 11  5], shape=(4,), dtype=int32)


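### Attention

Attention scores compare every query with every key inside the same batch element. With `matmul`, `K` has to be transposed first. With einsum, the shared feature index `m` is simply summed out: `bqm,bkm -> bqk`.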

In [26]:
# Queries and keys share the batch axis (b = 32) and the feature axis (m = 512)
# but have different sequence lengths (q = 64, k = 128).
# A plain matmul would need K transposed to (b, m, k) first. Einsum instead
# names the axes and sums out m directly:  bqm,bkm -> bqk.
Q = tf.random.normal(shape=(32, 64, 512))
K = tf.random.normal(shape=(32, 128, 512))

In [29]:
attention_scores = tf.einsum("bqm, bkm -> bqk", Q, K)  # m is summed out -> (b, q, k)
attention_scores.shape

TensorShape([32, 64, 128])

### Reformer

https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html

In [30]:
# The numeric results of the next cells are displayed, so fix the global seed.
# Calling set_seed also resets TF's internal random counters, so re-running
# this cell reproduces the same A and B.
tf.random.set_seed(42)
A = tf.random.normal((2,4,4,2)) # axes labelled b, c, i, j in the einsum below
B = tf.random.normal((2,4,4,1)) # axes labelled b, c, i, k in the einsum below

In [32]:
# i is summed out; for each (b, c) pair this computes B^T @ A, giving (k, j) = (1, 2).
reformer_einsum = tf.einsum("bcik,bcij -> bckj", B, A)
reformer_einsum.shape

TensorShape([2, 4, 1, 2])

In [34]:
# Without einsum, B's last two axes must be swapped explicitly,
# (b, c, i, k) -> (b, c, k, i), before the batched matmul.
# (tf.matmul(B, A, transpose_a=True) is an equivalent shortcut.)
reformer_matmul = tf.matmul(tf.transpose(B, (0,1,3,2)), A)

# Confirm agreement with the einsum formulation. The values are float32, so
# allow for tiny round-off differences rather than requiring exact equality.
tf.debugging.assert_near(
    reformer_matmul,
    tf.einsum("bcik,bcij -> bckj", B, A),
    rtol=1e-5,
    atol=1e-5,
)
reformer_matmul

<tf.Tensor: shape=(2, 4, 1, 2), dtype=float32, numpy=
array([[[[-0.02408861, -0.49770772]],

        [[-0.09428987, -0.20355959]],

        [[-3.7941003 ,  1.3513222 ]],

        [[-2.3759034 ,  0.85872626]]],


       [[[-2.5640306 , -5.3546343 ]],

        [[-1.9126546 , -2.592187  ]],

        [[ 2.918517  , -0.0059047 ]],

        [[ 3.4186466 , -2.4441628 ]]]], dtype=float32)>

In [35]:
# using einsum is much cleaner: the same contraction extends to any number of
# leading batch axes (here b, c, d) without permuting axes by hand.
A = tf.random.normal((2,4,5,4,2)) # bcdij
B = tf.random.normal((2,4,5,4,1)) # bcdik
tf.einsum("bcdik,bcdij -> bcdkj", B, A).shape

TensorShape([2, 4, 5, 1, 2])