In [1]:
import tensorflow as tf
from tensorflow.keras.layers import *
import numpy as np


In [40]:
# Single-head scaled dot-product attention, step by step.
# x: (batch, time_steps, embedding_dim) = (16, 16, 512); each "word"/frame is a 512-d vector.
x = tf.ones((16, 16, 512))

# Project every word/frame into query, key and value spaces with learned linear maps Wq, Wk, Wv.
# Queries and keys must share a size because they are compared by dot product.
key_dim = query_dim = 64
value_dim = 64

wq = Dense(query_dim)
wk = Dense(key_dim)
wv = Dense(value_dim)

q = wq(x)  # (batch, time_steps, query_dim)
k = wk(x)  # (batch, time_steps, key_dim)
v = wv(x)  # (batch, time_steps, value_dim)

# Similarity of each word/frame (query) with every word/frame (key): (batch, T_q, T_k).
# The logits are scaled by 1/sqrt(d_k) BEFORE the softmax so their magnitude does not
# grow with key_dim and push the softmax into saturation.
pre_score = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(key_dim, q.dtype))

# Normalize over the key axis (last axis) so each query's attention weights sum to 1.
score = tf.keras.activations.softmax(pre_score, axis=-1)

# Context vector at each time step t: the attention-weighted sum of the value projections.
z = tf.matmul(score, v)  # (batch, time_steps, value_dim)
z.shape

TensorShape([16, 16, 64])

In [52]:
class Attention(tf.keras.layers.Layer):
  """Multi-head scaled dot-product self-attention.

  Each head has its own query/key/value ``Dense`` projections, whose sizes are the
  total sizes divided (integer division) by ``num_of_head``. Head outputs are
  concatenated on the last axis and projected to ``output_dim`` by ``wo``.

  Args:
    query_dim: Total query projection size across all heads.
    key_dim: Total key projection size across all heads. Its per-head size must
      equal the per-head query size, because queries are dotted with keys.
    value_dim: Total value projection size across all heads.
    output_dim: Size of the final output projection.
    num_of_head: Number of attention heads (>= 1).
    **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).

  Raises:
    ValueError: If a size argument is missing, ``num_of_head`` < 1, a per-head
      size would be < 1, or the per-head query and key sizes differ.
  """

  def __init__(self, query_dim=None, key_dim=None, value_dim=None, output_dim=None, num_of_head=None, **kwargs):
    super(Attention, self).__init__(**kwargs)

    # Fail fast with a clear message instead of a TypeError from `None // None`.
    sizes = {
        "query_dim": query_dim,
        "key_dim": key_dim,
        "value_dim": value_dim,
        "output_dim": output_dim,
        "num_of_head": num_of_head,
    }
    missing = [name for name, value in sizes.items() if value is None]
    if missing:
      raise ValueError(f"Attention requires {', '.join(missing)} to be set.")
    if num_of_head < 1:
      raise ValueError(f"num_of_head must be >= 1, got {num_of_head}.")

    # Per-head sizes (remainders of non-divisible totals are dropped).
    self.query_dim = query_dim // num_of_head
    self.key_dim = key_dim // num_of_head
    self.value_dim = value_dim // num_of_head
    self.output_dim = output_dim
    self.num_of_head = num_of_head

    if min(self.query_dim, self.key_dim, self.value_dim) < 1:
      raise ValueError(
          f"query_dim, key_dim and value_dim must each be >= num_of_head ({num_of_head}) "
          f"so every head gets at least one unit; got {query_dim}, {key_dim}, {value_dim}.")
    if self.query_dim != self.key_dim:
      # q @ k^T needs matching per-head sizes; catch it here rather than inside matmul.
      raise ValueError(
          f"Per-head query size ({self.query_dim}) must equal per-head key size ({self.key_dim}).")

    self.wq = [Dense(self.query_dim) for _ in range(self.num_of_head)]
    self.wk = [Dense(self.key_dim) for _ in range(self.num_of_head)]
    self.wv = [Dense(self.value_dim) for _ in range(self.num_of_head)]
    self.wo = Dense(self.output_dim)

  def call(self, x):
    """Apply multi-head self-attention to ``x``.

    Args:
      x: Input tensor of rank >= 2. The last axis is the embedding axis and the
        second-to-last is the time axis that is attended over, e.g.
        (batch_size, time_steps, embedding_dims).

    Returns:
      Tensor with the same leading axes as ``x`` and last axis ``output_dim``.
    """
    heads_context = []

    for h in range(self.num_of_head):
      q = self.wq[h](x)
      k = self.wk[h](x)
      v = self.wv[h](x)

      # Scale the logits by 1/sqrt(d_k) before the softmax so each row stays a
      # proper probability distribution.
      pre_score = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(self.key_dim, q.dtype))
      # Softmax over the key axis (last) so each query's weights sum to 1.
      score = tf.keras.activations.softmax(pre_score, axis=-1)

      heads_context.append(tf.matmul(score, v))

    # Concatenate heads on the feature axis, then mix them with the output projection.
    heads = tf.concat(heads_context, axis=-1)
    return self.wo(heads)







In [58]:
# Shape contract of the Attention layer:
#   input : (batch_size, time_steps, embedding_dims)
#   output: (batch_size, time_steps, output_dims)
x = tf.ones((12, 16, 512))

f = Attention(
    query_dim=64,
    key_dim=64,
    value_dim=16,
    output_dim=128,
    num_of_head=8,
)
z = f(x)

print(z.shape)

(12, 16, 128)