In [1]:
import numpy as np

In [2]:
def softmax(x):
    return(np.exp(x)/np.exp(x).sum())

In [3]:
def split_heads(x, num_heads):
  split_head = np.reshape(x, (np.shape(x)[0], np.shape(x)[1], num_heads, -1))
  return split_head

In [4]:
def scaled_dot_product_attention(query, key, value):

    # Compute dot product of query and key
    dot_product=np.matmul(query, np.transpose(key, (0,2,1)))  
    #print("dot product shape:", dot_product.shape)

    # Calculating Scale
    key_dim=key.shape[-1]
    #print(key_dim)
    
    # Scale dot-product by key dimension
    scaled_dot_product = dot_product / np.sqrt(key_dim)
    #print("scaled dot product shape:", scaled_dot_product.shape)

    # Compute attention weights by softmax application
    attention_weights=softmax(scaled_dot_product)
    #print("attention weights:", attention_weights.shape)

    # Compute attention output
    scaled_dot_product_attention_output=np.matmul(attention_weights, value)
    #print("attention output:", attention_output.shape)

    return scaled_dot_product_attention_output

In [5]:
def multi_head_scaled_attention(query, key, value, num_heads, W_q, W_k, W_v):
  
    # Get the linearly projected values of query, key and value
    
    query = np.matmul(query,W_q)
    key = np.matmul(key,W_k)
    value = np.matmul(value, W_v)
    
    # Split q, k, and v into num_heads separate 2D tensors
    
    #queries = np.reshape(query, (np.shape(query)[0], np.shape(query)[1], num_heads, -1))
    queries = split_heads(query, num_heads)
    #print("Queries init: = ",queries.shape)

    #keys = np.reshape(key, (np.shape(key)[0], np.shape(key)[1], num_heads, -1))
    keys = split_heads(key, num_heads)
    #print("Keys init: =", keys.shape)

    #values = np.reshape(value, (np.shape(value)[0], np.shape(value)[1], num_heads, -1))
    values = split_heads(value, num_heads)
    #print("Values init: =",values.shape)

    # Concatenate attentions by calculating the attention for each head
    attention_outputs = []

    for i in range(num_heads):
        attention_output=scaled_dot_product_attention(queries[:,:,:,i], keys[:,:,:,i], values[:,:,:,i])
        attention_outputs.append(attention_output)

    # Make an array of concatenated attention outputs
    attention = np.array(attention_outputs)
    #print("Concat attention:", attention.shape)

    # Project the concatenated attention back to the original size
    multi_head_attention_output = np.transpose(attention, (1, 2, 3, 0))
    multi_head_attention_output = np.reshape(multi_head_attention_output, (multi_head_attention_output.shape[0], multi_head_attention_output.shape[1], -1))
    #print("Multi Head Attention Output:", multi_head_attention_output.shape)

    return multi_head_attention_output

In [6]:
# Testing out with following input values 

input_seq_len=5 # Maximum length of the input sequence
d_q=64          # Dimensionality of thge linearly projected queries
d_k=64          # Dimensionality of the linearly projected keys
d_v=64          # Dimensionality of the linearly projected values
batch_size=64   # Batch size from the training process
num_heads=8     # Number of self-attention heads

query = np.random.randn(batch_size, input_seq_len, d_q)   # generating input query matrix
key = np.random.randn(batch_size, input_seq_len, d_k)     # generating input key matrix
value = np.random.randn(batch_size, input_seq_len, d_v)   # generating input value matrix

W_q = np.random.randn(d_q, d_q)                           # for generating num head projection matrices for queries
W_k = np.random.randn(d_k, d_k)                           # for generating num head projection matrices for keys
W_v = np.random.randn(d_v, d_v)                           # for generating num head projection matrices for values 

In [7]:
# Testing code of scaled dot product attention

attention=scaled_dot_product_attention(query, key, value)
print("Scaled Dot Product Attention:", attention)
print("Scaled Dot Product Attention Shape:", attention.shape)

Scaled Dot Product Attention: [[[-5.42331165e-05  8.93264585e-04 -3.62929491e-04 ... -1.53482501e-03
   -1.64819457e-03  4.49659776e-04]
  [-3.51388413e-04 -9.18896634e-05  2.43727082e-05 ... -7.87239600e-04
   -6.95673509e-04  1.87105840e-03]
  [-3.80599868e-04  1.25855370e-03  7.43379396e-05 ... -2.57113688e-03
   -2.94539998e-03  4.09979698e-03]
  [-2.10901490e-04  5.92957119e-04 -8.53901711e-04 ... -2.07863604e-03
   -2.17466432e-03  1.17826968e-03]
  [-2.93210528e-06  5.43493552e-04  1.66037258e-04 ... -1.42462097e-03
   -1.44958928e-03  1.67015369e-03]]

 [[-1.38830841e-03  4.85514576e-03  1.64503101e-03 ...  2.32063376e-03
   -1.13489613e-03 -2.03268644e-03]
  [-7.09308490e-04  1.22612488e-03 -2.06180656e-04 ...  6.22827005e-04
   -4.06224227e-04 -3.49757754e-04]
  [-2.22514445e-03  1.14259259e-02  4.14821600e-03 ...  4.95919281e-03
   -1.53045199e-03 -4.93872052e-03]
  [-7.59862671e-04  1.91651466e-03 -1.39841116e-04 ...  5.29781486e-04
   -1.91300142e-04 -4.11299925e-04]
  [-3

In [8]:
# Testing code of multi head scaled attention

multi_head_attention=multi_head_scaled_attention(query, key, value, num_heads, W_q, W_k, W_v)
print("Multi Head Scaled Attention", multi_head_attention)
print("Multi Head Scaled Attention Shape:", multi_head_attention.shape)

Multi Head Scaled Attention [[[-6.82548377e-085  2.49936909e-132  1.69240724e-114 ...
    1.16807424e-108 -7.62068133e-064  2.29188417e-075]
  [-1.47432394e-132  7.39977915e-109  8.23297448e-091 ...
    4.81963390e-125  2.77499089e-084 -3.12620788e-064]
  [-2.73078037e-125  3.86699723e-124  1.00098713e-093 ...
    2.18581769e-081 -7.47450150e-093 -8.90008576e-072]
  [-3.75508114e-099 -8.62033077e-053  2.16251409e-126 ...
    8.46616557e-126  1.27894956e-087  8.96634175e-067]
  [-2.35275564e-067  6.17343298e-123  1.38601228e-086 ...
   -6.03470506e-115  3.23234402e-102  7.77983719e-082]]

 [[ 2.18208946e-111  9.89663008e-093  3.44639193e-123 ...
   -4.69289384e-114  2.41793168e-102  1.19141644e-061]
  [-2.15476646e-096  8.36805203e-105  1.59630012e-117 ...
   -3.17069982e-107  2.17700555e-101 -1.36687805e-074]
  [ 7.00885194e-125  2.87290024e-128  1.24635476e-055 ...
   -1.91052387e-112 -8.66578646e-106  2.64999271e-093]
  [ 1.29097858e-097  7.13375964e-090  3.25965575e-097 ...
   -2.98