In [None]:
import tensorflow as tf

In [None]:
# Scalar-matrix multiplication: `a` holds a single value, so summing over its
# length-1 axis `k` leaves every entry of the matrix `b` scaled by that value.

a = tf.random.normal(shape=[1])
b = tf.random.normal(shape=[2, 3])
mult = tf.einsum('k, ij -> ij', a, b)
print(a, b, mult, sep='\n')

tf.Tensor([2.4281507], shape=(1,), dtype=float32)
tf.Tensor(
[[ 0.9201611   0.7869265  -0.1516344 ]
 [-0.28274438 -0.24008621  1.2965367 ]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[ 2.23429     1.9107761  -0.36819115]
 [-0.68654597 -0.5829655   3.1481864 ]], shape=(2, 3), dtype=float32)


In [None]:
# Matrix-matrix multiplication: the shared axis `j` is contracted, so a
# [2 x 5] matrix times a [5 x 3] matrix gives a [2 x 3] result.

a = tf.random.normal(shape=[2, 5])
b = tf.random.normal(shape=[5, 3])
mult = tf.einsum('ij, jk -> ik', a, b)
print(a, b, mult, sep='\n')

tf.Tensor(
[[-0.04595557 -1.0043219   0.17543502  0.39695907 -0.45605093]
 [ 1.2592007   1.6554294   0.5069472  -0.5516412   0.19223353]], shape=(2, 5), dtype=float32)
tf.Tensor(
[[ 1.083499   -0.00988905  0.4788915 ]
 [ 0.02091054 -1.5508705   0.16104497]
 [ 0.33273518 -0.8711445  -2.8092797 ]
 [-0.53535753  0.2049706   0.73983634]
 [-0.08899844  0.15808834 -0.10611595]], shape=(5, 3), dtype=float32)
tf.Tensor(
[[-0.18434753  1.4144672  -0.33451575]
 [ 1.8458548  -3.1041136  -0.9830607 ]], shape=(2, 3), dtype=float32)


In [None]:
# Outer product: no index is shared between the operands and none is summed,
# so prod[i, j] = a[i] * b[j].

a = tf.range(4)
b = tf.range(3, 6)
prod = tf.einsum('i,j -> ij', a, b)
print(a, b, prod, sep='\n')

tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([3 4 5], shape=(3,), dtype=int32)
tf.Tensor(
[[ 0  0  0]
 [ 3  4  5]
 [ 6  8 10]
 [ 9 12 15]], shape=(4, 3), dtype=int32)


In [None]:
# Full contraction to a scalar: the output spec is empty, so every index
# (i, j and k) is summed - i.e. the sum of all entries of the matrix product.

a = tf.reshape(tf.range(9), shape=[3, 3])
b = tf.reshape(tf.range(6), shape=[3, 2])

prod = tf.einsum('ij, jk ->', a, b)
print(a, b, prod, sep='\n')

tf.Tensor(
[[0 1 2]
 [3 4 5]
 [6 7 8]], shape=(3, 3), dtype=int32)
tf.Tensor(
[[0 1]
 [2 3]
 [4 5]], shape=(3, 2), dtype=int32)
tf.Tensor(204, shape=(), dtype=int32)


In [None]:
# Hadamard (element-wise) product: both operands carry the same indices and
# none is summed, so prod[i, j] = a[i, j] * b[i, j].

# Two [2 x 4] integer matrices holding 0..7 and 4..11 respectively.
a = tf.range(8)
a = tf.reshape(a, [2, 4])
b = tf.range(4, 12)
b = tf.reshape(b, [2, 4])

prod = tf.einsum('ij, ij -> ij', a, b)
print(a)
print(b)
print(prod)

tf.Tensor(
[[0 1 2 3]
 [4 5 6 7]], shape=(2, 4), dtype=int32)
tf.Tensor(
[[ 4  5  6  7]
 [ 8  9 10 11]], shape=(2, 4), dtype=int32)
tf.Tensor(
[[ 0  5 12 21]
 [32 45 60 77]], shape=(2, 4), dtype=int32)


In [None]:
# Batched matrix multiplication: the leading axis `b` appears in both inputs
# and in the output, so it is kept; `j` is contracted within each batch entry.

a = tf.random.normal(shape=[3, 3, 2])
b = tf.random.normal(shape=[3, 2, 2])
batch_mult = tf.einsum('bij, bjk -> bik', a, b)
print(a, b, batch_mult, sep='\n')

tf.Tensor(
[[[ 2.2917955   0.21000162]
  [ 0.8124304   0.06777922]
  [-0.11358453  0.7659347 ]]

 [[-0.1675065   0.5600017 ]
  [-0.66384035 -0.1080124 ]
  [-0.00793505 -0.4128089 ]]

 [[ 0.29294032 -0.714986  ]
  [ 1.2740775   2.499109  ]
  [-0.37249446 -1.272966  ]]], shape=(3, 3, 2), dtype=float32)
tf.Tensor(
[[[-0.6430566  -0.3129209 ]
  [ 0.49324626 -0.547294  ]]

 [[-1.3205546  -0.7890352 ]
  [-0.11901931  0.02230156]]

 [[-0.2517252  -2.472473  ]
  [ 1.0076808   0.45598066]]], shape=(3, 2, 2), dtype=float32)
tf.Tensor(
[[[-1.3701717  -0.83208334]
  [-0.48900685 -0.2913216 ]
  [ 0.4508357  -0.3836485 ]]

 [[ 0.15455046  0.14465745]
  [ 0.88949305  0.52138454]
  [ 0.0596109  -0.00294525]]

 [[-0.7942182  -1.0503068 ]
  [ 2.1975868  -2.0105767 ]
  [-1.1889771   0.34053457]]], shape=(3, 3, 2), dtype=float32)


In [None]:
# Tensor reduction: the axes shared by both operands (`q` and `r`) are summed
# out; the remaining axes are emitted in the order p, s, t, u, v.

a = tf.random.normal(shape=[2, 3, 5, 7])
b = tf.random.normal(shape=[4, 1, 3, 11, 5])
reduction = tf.einsum('pqrs, tuqvr -> pstuv', a, b)
print(*(t.shape for t in (a, b, reduction)))

(2, 3, 5, 7) (4, 1, 3, 11, 5) (2, 7, 4, 1, 11)


In [None]:
# Transpose: listing the output indices in swapped order flips the two axes.

a = tf.reshape(tf.range(8), shape=[4, 2])
transpose = tf.einsum('ij -> ji', a)
print(a, transpose, sep='\n')

tf.Tensor(
[[0 1]
 [2 3]
 [4 5]
 [6 7]], shape=(4, 2), dtype=int32)
tf.Tensor(
[[0 2 4 6]
 [1 3 5 7]], shape=(2, 4), dtype=int32)


In [None]:
# Bilinear transformation:
#   bilinear[i, j] = sum over k, l of a[i, k] * b[j, k, l] * c[i, l]

a = tf.random.normal(shape=[2, 3])
b = tf.random.normal(shape=[3, 3, 4])
c = tf.random.normal(shape=[2, 4])
bilinear = tf.einsum('ik, jkl, il -> ij', a, b, c)
print(a, b, c, bilinear, sep='\n')

tf.Tensor(
[[ 0.3146136  0.4827846  1.1269066]
 [-1.5365981 -1.0660455  1.1572146]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[[ 0.14003627  0.01163231 -0.13139397  0.7894577 ]
  [-0.2158014   1.813187   -0.5014698  -0.9422409 ]
  [ 1.2023556  -0.13701855  0.3905423  -0.9621358 ]]

 [[ 0.9206127   1.1183867  -0.61091995 -0.4539109 ]
  [ 1.5881643  -0.67286307  2.8616853   0.10222552]
  [-1.4472909  -0.21255776  0.34756663  0.32714945]]

 [[-0.13337424 -1.2615469   0.6651735   0.5972824 ]
  [ 1.0825644   0.23683508 -2.28232     0.25065213]
  [ 0.36147472 -0.3565847  -1.0303156  -0.8182818 ]]], shape=(3, 3, 4), dtype=float32)
tf.Tensor(
[[ 0.140482    0.39513645 -1.0058335   1.2766947 ]
 [ 0.55331546 -0.561943   -0.62871015 -0.6802107 ]], shape=(2, 4), dtype=float32)
tf.Tensor(
[[-1.3372614  -1.4036021   1.137095  ]
 [ 2.1155162  -1.5282719   0.30341208]], shape=(2, 3), dtype=float32)


In [None]:
# Attention

# Parameters (read as globals by `attention`)
# Three vectors of size [hidden_dimension], drawn in the order bM, br, w
bM, br, w = (tf.random.normal(shape=[7]) for _ in range(3))
# Four matrices of size [hidden_dimension x hidden_dimension],
# drawn in the order WY, Wh, Wr, Wt
WY, Wh, Wr, Wt = (tf.random.normal(shape=[7, 7]) for _ in range(4))

def attention(Y, ht, rt1):
  """Compute one attention step over the sequence `Y`.

  Reads the module-level parameters `WY`, `Wh`, `Wr`, `Wt`, `bM`, `br` and `w`.

  Args:
    Y: [batch_size x sequence_length x hidden_dimension] tensor to attend over.
    ht: [batch_size x hidden_dimension] tensor projected through `Wh`.
    rt1: [batch_size x hidden_dimension] tensor projected through `Wr` and `Wt`.

  Returns:
    A tuple `(rt, at)` where `rt` is [batch_size x hidden_dimension] and `at`
    is [batch_size x sequence_length], the softmax weights over the sequence.
  """
  # [batch_size x hidden_dimension]
  projected = tf.einsum('ik, kl -> il', ht, Wh) + tf.einsum('ik, kl -> il', rt1, Wr)

  # [batch_size x sequence_length x hidden_dimension]
  # A length-1 sequence axis is inserted and the addition broadcasts it over
  # every position. This replaces an explicit tf.tile that needed the static
  # Y.shape[1], which is None when the sequence length is not known statically.
  Mt = tf.tanh(tf.einsum('ijk, kl -> ijl', Y, WY) + tf.expand_dims(projected, 1) + bM)

  # [batch_size x sequence_length]
  at = tf.nn.softmax(tf.einsum('ijk, k -> ij', Mt, w))

  # [batch_size x hidden_dimension]
  # Attention-weighted sum of Y plus a tanh projection of rt1.
  rt = tf.einsum('ijk, ij -> ik', Y, at) + tf.tanh(tf.einsum('ij, jk -> ik', rt1, Wt) + br)

  return rt, at

# Inputs
# Y: [batch_size x sequence_length x hidden_dimension]
Y = tf.random.normal(shape=[3, 5, 7])
# ht and rt1: [batch_size x hidden_dimension]
ht = tf.random.normal(shape=[3, 7])
rt1 = tf.random.normal(shape=[3, 7])

rt, at = attention(Y, ht, rt1)

# Show only the attention weights, [batch_size x sequence_length].
print(at)

tf.Tensor(
[[0.04286945 0.13299473 0.6374267  0.04080512 0.14590405]
 [0.3995505  0.45393756 0.03856793 0.01970782 0.08823621]
 [0.26846936 0.0719656  0.0737488  0.31743228 0.268384  ]], shape=(3, 5), dtype=float32)


In [None]:
# Treeqn

def transition(zl):
  """Apply every action's transition to the state `zl`.

  Reads the module-level `W` and `b`, which must be defined before the call.

  Args:
    zl: [batch_size x hidden_dimension] tensor.

  Returns:
    [batch_size x num_actions x hidden_dimension] tensor: `zl` with an added
    action axis, plus the tanh of the per-action linear map of `zl`.
  """
  # Contract the hidden axis `k` of zl against each action's matrix in W.
  per_action = tf.einsum('bk, aki -> bai', zl, W)
  update = tf.tanh(per_action + b)
  # The length-1 action axis on zl broadcasts against all actions.
  return tf.expand_dims(zl, 1) + update

# Input state - [batch_size x hidden_dimension]
zl = tf.random.normal(shape=[2, 3])
# Bias read by `transition` - [num_actions x hidden_dimension]
b = tf.random.normal(shape=[5, 3])
# Per-action matrices read by `transition` -
# [num_actions x hidden_dimension x hidden_dimension]
W = tf.random.normal(shape=[5, 3, 3])

transition(zl)

<tf.Tensor: shape=(2, 5, 3), dtype=float32, numpy=
array([[[-1.0403477 , -0.17211628, -0.93623567],
        [-1.329174  , -0.2081781 , -2.7835898 ],
        [-0.87566197, -0.44945782, -2.7821753 ],
        [-1.7981579 , -2.1150217 , -0.94191855],
        [ 0.1530087 , -0.28926474, -2.9333048 ]],

       [[ 0.78875744, -1.0552517 , -0.23725814],
        [-0.8185021 , -1.3224752 , -0.14671803],
        [ 1.0521754 , -0.94115245,  0.9765916 ],
        [ 0.28786626, -0.73542947,  1.185014  ],
        [ 0.9353899 , -1.3044728 ,  0.02783903]]], dtype=float32)>