In [1]:
import torch 
import torch.nn as nn
import fasttext as ft
import math

In [None]:
# ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P {path}
! wget -c https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz
# !gunzip cc.en.300.bin cc.en.300.bin.gz
!gunzip cc.fr.300.bin cc.fr.300.bin.gz

In [None]:
# Load the pretrained fastText models from the extracted binaries in the working directory.
# NOTE(review): only the French download is active in this notebook; cc.en.300.bin must
# already be present for the first line to succeed.
en = ft.load_model('cc.en.300.bin')
fr = ft.load_model('cc.fr.300.bin')

# New Implementation

In [2]:
d_model = 5  # width of each word-embedding vector (number of columns in the word-vector matrix)

In [3]:
def getWordVectors(sentence):
    """Return a random placeholder embedding for every token in `sentence`.

    Tokens are obtained by splitting on single spaces; no real embedding lookup
    is performed.

    Args:
        sentence: Input text as a string.

    Returns:
        A tensor of shape (number_of_tokens, d_model) with values drawn
        uniformly from [0, 1).
    """
    num_tokens = len(sentence.split(' '))
    return torch.rand((num_tokens, d_model))

In [4]:
def PositionalEncoding(wordVecs):
    """Add sinusoidal positional encodings to a matrix of word vectors.

    For the token at row `pos` and embedding column `i`:

        even i: sin(pos / 10000 ** (2 * (i // 2) / d))
        odd  i: cos(pos / 10000 ** (2 * (i // 2) / d))

    where `d` is the embedding width taken from `wordVecs` itself. Each
    sin/cos column pair therefore shares one frequency.

    Args:
        wordVecs: 2-D tensor with one row per token and one column per
            embedding dimension.

    Returns:
        A new tensor of the same shape containing `wordVecs` plus the
        encoding. The input tensor is left unmodified, so calling this again
        on the same input gives the same result.
    """
    seq_len, dim = wordVecs.shape
    device = wordVecs.device

    positions = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
    columns = torch.arange(dim, device=device)

    # `columns - columns % 2` equals 2 * (column // 2): columns 0/1 -> 0, 2/3 -> 2, ...
    pair_exponent = (columns - columns % 2).to(torch.float32) / dim
    angles = positions / torch.pow(10000.0, pair_exponent)

    # Even columns take the sine of the angle, odd columns the cosine.
    encoding = torch.where(columns % 2 == 0, torch.sin(angles), torch.cos(angles))

    return wordVecs + encoding
                                           

In [5]:
def get_qkv_weights(r,c):
    """Create random projection matrices for queries, keys and values.

    Args:
        r: Number of rows of each weight matrix.
        c: Number of columns of each weight matrix.

    Returns:
        A tuple (query_weights, key_weights, value_weights), each a tensor of
        shape (r, c) with values drawn uniformly from [0, 1).
    """
    shape = (r, c)
    # Matrices are drawn in query, key, value order so seeded runs stay reproducible.
    query_w, key_w, value_w = (torch.rand(shape) for _ in range(3))
    return query_w, key_w, value_w
    

In [6]:
def qkvs(vectorMatrix, new_dim):
    """Project a matrix of token vectors into query, key and value matrices.

    Three random weight matrices are obtained from `get_qkv_weights` and the
    input is multiplied by each of them.

    Args:
        vectorMatrix: Tensor whose last dimension is the embedding width,
            with one row per token.
        new_dim: Width of the projected query/key/value vectors.

    Returns:
        A tuple (queries, keys, values); each has the same leading shape as
        `vectorMatrix` with a last dimension of `new_dim`.
    """
    # The weight matrices need as many rows as the input has columns. Reading
    # that from the input keeps the matmul valid for any embedding width.
    in_dim = vectorMatrix.shape[-1]
    query_weights, key_weights, value_weights = get_qkv_weights(in_dim, new_dim)

    return (
        torch.matmul(vectorMatrix, query_weights),
        torch.matmul(vectorMatrix, key_weights),
        torch.matmul(vectorMatrix, value_weights),
    )

# Check whether a transpose is needed in the matrix multiplication

In [7]:
def qk_dotproducts(queries, keys):
    """Compute the dot product of every query vector with every key vector.

    Args:
        queries: 2-D tensor with one query vector per row.
        keys: 2-D tensor with one key vector per row, with the same number
            of columns as `queries`.

    Returns:
        A tensor of shape (number_of_queries, number_of_keys) whose entry
        [i][j] is the dot product of `queries[i]` and `keys[j]`.
    """
    # Multiplying by the transposed keys yields all pairwise dot products in
    # one call, without growing a tensor element by element.
    return torch.matmul(queries, keys.transpose(0, 1))

In [8]:
def getSoftmaxed_qkdp(qk_dotproductmatrix):
    """Apply softmax independently to every row of the score matrix.

    Args:
        qk_dotproductmatrix: 2-D tensor of query-key scores, one row per
            query.

    Returns:
        A tensor of the same shape in which every row is non-negative and
        sums to 1.
    """
    # Softmax over the last dimension normalises each row on its own, which is
    # what a row-by-row loop would compute.
    return torch.softmax(qk_dotproductmatrix, dim=-1)
    

In [56]:
def getSoftmaxWeightedValues(softmaxed_qkdp, values):
    """Scale every value vector by the weight each query assigns to it.

    The weighted vectors are NOT summed over the value axis; each query keeps
    its own stack of individually weighted value vectors.

    Args:
        softmaxed_qkdp: 2-D tensor of weights, indexed [query][value]. Only
            the first `values.shape[0]` columns are used.
        values: 2-D tensor with one value vector per row.

    Returns:
        A 3-D tensor of shape (number_of_queries, number_of_values,
        value_width) whose entry [i][j] equals
        `softmaxed_qkdp[i][j] * values[j]`.

    Raises:
        IndexError: If `softmaxed_qkdp` has fewer columns than `values` has
            rows.
    """
    num_values = values.shape[0]
    if softmaxed_qkdp.shape[1] < num_values:
        raise IndexError(
            "softmaxed_qkdp has fewer columns than values has rows"
        )

    # (queries, values, 1) * (1, values, width) broadcasts to
    # (queries, values, width): one multiply instead of nested loops with
    # repeated torch.cat.
    weights = softmaxed_qkdp[:, :num_values].unsqueeze(-1)
    return weights * values.unsqueeze(0)

In [None]:
# Disabled (string-quoted) single-cell copy of the pipeline; the cells that follow run
# the same steps one at a time so intermediate results can be inspected.
'''
wordVecs = getWordVectors('Hi there this is nuts')
pos_encoded = PositionalEncoding(wordVecs)

new_dim = 3
queries, keys, values = qkvs(pos_encoded, new_dim)
qk_dotproductmatrix = qk_dotproducts(queries, keys)

d_k = keys.shape[1] # to be changed later to square root of 'key' vector dimension
qk_dotproductmatrix/=d_k

softmaxed_qkdp = getSoftmaxed_qkdp(qk_dotproductmatrix)
softmax_weighted_values = getSoftmaxWeightedValues(softmaxed_qkdp, values)

'''

In [10]:
# Build random placeholder embeddings for a five-token example sentence (one row per token).
wordVecs = getWordVectors('Hi there this is nuts')
# wordVecs

In [11]:
# Add the position-dependent sine/cosine terms to every token embedding.
pos_encoded = PositionalEncoding(wordVecs)
# pos_encoded

In [12]:
# Project the position-encoded embeddings into query, key and value vectors of width new_dim.
new_dim = 3
queries, keys, values = qkvs(pos_encoded, new_dim)

In [13]:
# Raw attention scores: entry [i][j] is the dot product of query i with key j.
qk_dotproductmatrix = qk_dotproducts(queries, keys)

In [14]:
# Scaled dot-product attention divides the raw scores by sqrt(d_k), where d_k is the
# key-vector dimension. The scores are recomputed from queries/keys rather than divided
# in place, so re-running this cell does not scale them a second time.
d_k = keys.shape[1]
qk_dotproductmatrix = qk_dotproducts(queries, keys) / math.sqrt(d_k)

In [None]:
# Inspect the query-key score matrix.
qk_dotproductmatrix

In [15]:
# Turn each row of scores into weights via a row-wise softmax.
softmaxed_qkdp = getSoftmaxed_qkdp(qk_dotproductmatrix)

In [25]:
# One row of weights per query token and one column per key token (5 x 5 for the example sentence).
softmaxed_qkdp.shape

torch.Size([5, 5])

In [57]:
# Scale every value vector by its weight; the result is indexed [query][value][feature].
softmax_weighted_values = getSoftmaxWeightedValues(softmaxed_qkdp, values)

In [58]:
# Inspect the weighted value vectors (one 5 x 3 block per query token in the recorded output).
softmax_weighted_values

tensor([[[0.3957, 0.3757, 0.2720],
         [0.8062, 0.9068, 0.8936],
         [1.2354, 1.1684, 0.9549],
         [1.1423, 1.1615, 0.8895],
         [0.0890, 0.1170, 0.0900]],

        [[0.3786, 0.3594, 0.2602],
         [0.7820, 0.8795, 0.8668],
         [1.3036, 1.2330, 1.0076],
         [1.1613, 1.1808, 0.9043],
         [0.0704, 0.0925, 0.0712]],

        [[0.3447, 0.3273, 0.2369],
         [0.7939, 0.8929, 0.8799],
         [1.3247, 1.2529, 1.0239],
         [1.1892, 1.2092, 0.9261],
         [0.0595, 0.0783, 0.0602]],

        [[0.3423, 0.3250, 0.2353],
         [0.7984, 0.8981, 0.8850],
         [1.2992, 1.2288, 1.0042],
         [1.2049, 1.2252, 0.9383],
         [0.0627, 0.0824, 0.0634]],

        [[0.4656, 0.4420, 0.3200],
         [0.7968, 0.8962, 0.8831],
         [1.1077, 1.0477, 0.8562],
         [1.0780, 1.0961, 0.8395],
         [0.1467, 0.1928, 0.1484]]])