In [2]:
import tensorflow as tf
import torch.nn as n
import torch.nn.functional as fun
from tensorflow.keras.layers import Dense, Softmax

In [4]:
# Random 2x3 float matrix with entries drawn from [-10, 10); later cells
# reuse it for the matrix product and the transpose examples.
x = tf.random.uniform((2, 3), minval=-10.0, maxval=10.0)
x

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[-2.1658254,  8.22505  , -3.808651 ],
       [ 6.729145 ,  6.849077 ,  6.204609 ]], dtype=float32)>

scalar vector multiplication

In [6]:
# The "scalar" is a one-element rank-1 tensor: index `i` is absent from the
# output, so einsum sums over it, which here just scales `vector` by 4.
scalar = tf.constant([4])
vector = tf.constant([1, 2, 3])

scaled = tf.einsum("i,j->j", scalar, vector)
print(scaled)

tf.Tensor([ 4  8 12], shape=(3,), dtype=int32)


matrix vector multiplication

In [7]:
# `v` is a single row of length 3; contracting the shared index `k` with the
# 2x3 matrix `x` yields a (2, 1) result, i.e. x multiplied by v transposed.
v = tf.random.uniform((1, 3), minval=-10.0, maxval=10.0)
tf.einsum("ik,jk->ij", x, v)

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[-52.543365],
       [-21.04249 ]], dtype=float32)>

outer product

In [8]:
# No index is shared or dropped, so every pairing a[i] * b[j] is kept:
# a length-3 and a length-4 vector give a 3x4 matrix.
a = tf.constant([1, 2, 3])
b = tf.constant([4, 5, 6, 7])
tf.einsum("i,j->ij", a, b)

<tf.Tensor: shape=(3, 4), dtype=int32, numpy=
array([[ 4,  5,  6,  7],
       [ 8, 10, 12, 14],
       [12, 15, 18, 21]], dtype=int32)>

scalar dot product

In [12]:
# Both indices are missing from the (empty) output, so the element-wise
# products are summed over the whole matrix, producing a rank-0 tensor.
a = tf.constant(
    [[1, 2],
     [3, 4]]
)
b = tf.constant(
    [[5, 6],
     [7, 8]]
)
tf.einsum("ij,ij->", a, b)

<tf.Tensor: shape=(), dtype=int32, numpy=70>

hadamard product

In [13]:
# Identical index pattern on both inputs and the output: nothing is summed,
# so the result is the element-wise product of the two 2x3 matrices.
mat1 = tf.constant(
    [[1, 2, 3],
     [4, 5, 6]]
)
mat2 = tf.constant(
    [[1, 2, 3],
     [4, 5, 6]]
)
tf.einsum("ij,ij->ij", mat1, mat2)

<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 1,  4,  9],
       [16, 25, 36]], dtype=int32)>

batch matrix multiplication

In [14]:
# Index `i` (size 3) is carried through to the output, while `k` (size 5) is
# contracted: three independent (2x5) @ (5x3) products.
a = tf.random.uniform((3, 2, 5), minval=-10.0, maxval=10.0)
b = tf.random.uniform((3, 5, 3), minval=-10.0, maxval=10.0)

tf.einsum("ijk,ikl->ijl", a, b)

<tf.Tensor: shape=(3, 2, 3), dtype=float32, numpy=
array([[[ 134.33382  ,  -17.328762 ,  -90.35362  ],
        [ -40.731388 ,  -13.250099 ,   99.280655 ]],

       [[  24.042496 ,  -18.281195 ,   -6.8287354],
        [  50.604687 ,  -34.24412  ,   74.72541  ]],

       [[ 127.27751  ,  -56.893227 ,  129.34631  ],
        [ -52.75166  ,   58.70645  , -196.2848   ]]], dtype=float32)>

tensor reduction


In [15]:
# The shared indices p (size 2) and q (size 17) are contracted; the remaining
# indices r, s, t, w, m make up the result, so only its shape is displayed.
a = tf.random.uniform((2, 17, 5, 7), minval=-10.0, maxval=10.0)
b = tf.random.uniform((11, 2, 4, 17, 6), minval=-10.0, maxval=10.0)

tf.einsum("pqrs,tpwqm->rstwm", a, b).shape

TensorShape([5, 7, 11, 4, 6])

transpose

In [16]:
# Swapping the order of the output indices transposes the 2x3 matrix `x`.
x_transposed = tf.einsum("ij -> ji", x)
print(x_transposed)

tf.Tensor(
[[-2.1658254  6.729145 ]
 [ 8.22505    6.849077 ]
 [-3.808651   6.204609 ]], shape=(3, 2), dtype=float32)


bilinear transformation

In [17]:
# Three-operand einsum: `k` (size 3) and `l` (size 7) are contracted, leaving
# one value per pair (i, j), i.e. a (2, 5) result.
a = tf.random.uniform((2, 3), minval=-10.0, maxval=10.0)
b = tf.random.uniform((5, 3, 7), minval=-10.0, maxval=10.0)
c = tf.random.uniform((2, 7), minval=-10.0, maxval=10.0)
tf.einsum("ik,jkl,il->ij", a, b, c)
     

<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[ 559.5431  , -329.48898 ,  -99.15325 ,  763.09424 ,  133.90778 ],
       [ -48.326393,   71.41156 ,  -42.681236,  380.50006 , -271.79953 ]],
      dtype=float32)>

attention

In [18]:
def random_tensors(shape, num=1, requires_grad=False):
    """Create `num` tf.Variables of the given shape filled with normal noise.

    Args:
        shape: shape passed to `tf.random.normal` for every variable.
        num: how many variables to create.
        requires_grad: forwarded as the `trainable` flag of each variable.

    Returns:
        The single variable when `num == 1`, otherwise a list of `num`
        variables.
    """
    created = []
    for _ in range(num):
        created.append(tf.Variable(tf.random.normal(shape), trainable=requires_grad))

    if num == 1:
        return created[0]
    return created

# Trainable parameters read as globals by `attention`:
# three length-7 vectors and four 7x7 matrices.
bM, br, w = random_tensors(shape=[7], num=3, requires_grad=True)
WY, Wh, Wr, Wt = random_tensors(shape=[7, 7], num=4, requires_grad=True)

# Dump the freshly sampled parameters: vectors first, then matrices.
print(bM, br, w)
print(WY, Wh, Wr, Wt)

def attention(Y, ht, rt1):
    """Run one attention step expressed entirely as einsum contractions.

    Reads the notebook-level parameters `WY`, `Wh`, `Wr`, `Wt` (matrices) and
    `bM`, `br`, `w` (vectors).

    Args:
        Y: rank-3 tensor indexed "ijk" (the notebook passes shape (3, 5, 7)).
        ht: rank-2 tensor indexed "ik" (the notebook passes shape (3, 7)).
        rt1: rank-2 tensor indexed "ik", same shape as `ht` in the notebook.

    Returns:
        A pair `(rt, at)`: `rt` is indexed "ik" and `at` is indexed "ij",
        with `at` normalised by a softmax over axis 1 (the `j` axis).
    """
    # Project both rank-2 inputs through their weight matrices and combine.
    tmp = tf.einsum("ik,kl->il", ht, Wh) + tf.einsum("ik,kl->il", rt1, Wr)

    # expand_dims inserts the `j` axis so `tmp` broadcasts against projected Y.
    Mt = tf.math.tanh(tf.einsum("ijk,kl->ijl", Y, WY) + tf.expand_dims(tmp, 1) + bM)

    # Stateless softmax: avoids building a new Keras layer object on each call.
    at = tf.nn.softmax(tf.einsum("ijk,k->ij", Mt, w), axis=1)

    # Weight Y by `at` along `j`, then add the transformed previous `rt1`.
    rt = tf.einsum("ijk,ij->ik", Y, at) + tf.math.tanh(tf.einsum("ij,jk->ik", rt1, Wt) + br)

    return rt, at


# `tf.random.normal` already returns a tensor, so no `tf.constant` wrapper is
# needed (and `tf.constant` would reject a symbolic tensor if ever traced).
Y = tf.random.normal([3, 5, 7])

# Two non-trainable (3, 7) inputs for the attention step.
ht, rt1 = random_tensors([3, 7], num=2)

print(ht)
print(rt1)

# Only the attention weights `at` are printed; `rt` is kept for inspection.
rt, at = attention(Y, ht, rt1)
print(at)

<tf.Variable 'Variable:0' shape=(7,) dtype=float32, numpy=
array([-0.7089832 ,  0.5624019 , -0.6840893 ,  0.03597113,  0.9779826 ,
        2.6247714 , -2.3527417 ], dtype=float32)> <tf.Variable 'Variable:0' shape=(7,) dtype=float32, numpy=
array([ 0.22511412, -0.6013475 , -0.7842833 ,  0.05097978, -0.79549587,
       -0.32434383,  0.40228274], dtype=float32)> <tf.Variable 'Variable:0' shape=(7,) dtype=float32, numpy=
array([-0.11647047, -0.05216181, -0.86743826,  0.17193635,  0.4600736 ,
       -0.39271578,  1.7877902 ], dtype=float32)>
<tf.Variable 'Variable:0' shape=(7, 7) dtype=float32, numpy=
array([[-0.84986687,  0.44461298, -0.14038365, -0.34058392,  0.30389097,
        -1.1205903 , -0.270132  ],
       [ 0.92856604, -1.2996588 ,  0.8174464 , -1.4669772 ,  0.0517964 ,
        -0.8099352 , -0.73960286],
       [ 0.40043527, -1.7654386 , -1.1371478 ,  0.95960563,  0.08699208,
        -1.0108954 , -0.9886355 ],
       [ 1.2375736 , -0.13933623,  0.49699235, -0.02013335,  0.1903941 ,

TreeQN

In [20]:
# Parameters read as globals by `transition`: a (5, 3) bias and a stack of
# five 3x3 weight matrices.
b = tf.Variable(tf.random.normal([5, 3]))
W = tf.Variable(tf.random.normal([5, 3, 3]))

for param in (b, W):
    print(param)

def transition(zl):
    """Apply the residual tanh transition to `zl` for every slice of `W`.

    Reads the notebook-level variables `W` (indexed "aki") and `b`.

    Args:
        zl: rank-2 tensor indexed "bk" (the notebook passes shape (2, 3)).

    Returns:
        Tensor indexed "bai": `zl` broadcast along a new axis 1, plus
        tanh of the einsum of `zl` with `W`, offset by `b`.
    """
    projected = tf.einsum("bk,aki->bai", zl, W)
    update = tf.math.tanh(projected + b)
    # The new axis 1 lets `zl` broadcast against the `a` axis of `update`.
    return tf.expand_dims(zl, axis=1) + update

# Random (2, 3) input for the transition function.
zl = tf.random.normal([2, 3])

# A bare `zl` in the middle of a cell is never displayed (Jupyter only echoes
# the last expression), so print it explicitly alongside the result.
print(zl)
print(transition(zl))


<tf.Variable 'Variable:0' shape=(5, 3) dtype=float32, numpy=
array([[ 1.2794458 , -0.31569585,  1.2092395 ],
       [-0.28959262, -0.11288652,  0.94261295],
       [ 0.42567143, -0.06800645,  0.06119423],
       [ 0.13189146,  0.22345331, -1.98815   ],
       [-0.11826276,  0.9600913 , -0.90108657]], dtype=float32)>
<tf.Variable 'Variable:0' shape=(5, 3, 3) dtype=float32, numpy=
array([[[-0.82465965, -0.02026201,  0.11017417],
        [ 0.68671507,  0.5564525 , -0.05423731],
        [-0.20295382, -0.07523216, -0.13472307]],

       [[-0.7306126 ,  0.77577025, -1.5224178 ],
        [ 1.382998  , -0.38332722,  1.3451985 ],
        [-0.32168388, -0.5818914 , -0.34726024]],

       [[-0.8179229 , -0.04278861, -0.5552251 ],
        [ 0.5011446 , -1.3177824 , -1.0807792 ],
        [ 1.2364876 , -0.5974472 , -2.0174239 ]],

       [[ 0.99304485,  1.0578467 ,  0.03344636],
        [ 0.25676224,  1.0589397 ,  0.28696394],
        [ 1.235737  , -0.299699  , -0.13352005]],

       [[-0.6227369 , 