In [1]:
import torch 
import torch.nn as nn
import fasttext as ft
import math

In [10]:
class Encoder():
    """Toy single-head self-attention block over a matrix of word vectors.

    The pipeline is: add sinusoidal positional encodings, project the result
    to queries/keys/values with freshly drawn random weights, compute scaled
    dot-product scores, softmax each query's scores, and weight every value
    vector by those probabilities.

    Note: the final per-query sum over the weighted values is not performed;
    ``returnRepresentation`` returns the un-summed 3-D tensor.
    """

    def __init__(self, vectorRepresentations, d_model=None, new_dim=5):
        """Store the input matrix and the model dimensions.

        Args:
            vectorRepresentations: 2-D tensor with one row per token.
            d_model: Width of each input vector. When ``None`` it is read
                from the last dimension of ``vectorRepresentations``.
            new_dim: Width of the projected query/key/value vectors.
        """
        self.vectorRepresentations = vectorRepresentations
        # Infer the width from the data so inputs that are not 10-wide work.
        self.d_model = vectorRepresentations.shape[-1] if d_model is None else d_model
        self.new_dim = new_dim


    def PositionalEncoding(self,wordVecs):
        """Return a new tensor equal to ``wordVecs`` plus sinusoidal encodings.

        Even columns receive ``sin(pos / 10000**(2k / d_model))`` and odd
        columns ``cos`` of the same angle, where ``k = column // 2``. The
        input tensor is left untouched, so calling this repeatedly on the
        same data does not accumulate the encoding.

        Args:
            wordVecs: 2-D tensor of shape (tokens, width); expected to have a
                floating-point dtype.

        Returns:
            Tensor with the same shape and dtype as ``wordVecs``.
        """
        seq_len, width = wordVecs.shape
        positions = torch.arange(seq_len, dtype=torch.float64, device=wordVecs.device).unsqueeze(1)
        columns = torch.arange(width, device=wordVecs.device)
        # columns - columns % 2 == 2 * (column // 2): both members of a
        # sin/cos pair share one frequency.
        pair_exponent = (columns - columns % 2).to(torch.float64) / self.d_model
        angles = positions / (10000.0 ** pair_exponent)
        encoding = torch.where(columns % 2 == 0, torch.sin(angles), torch.cos(angles))
        # Out-of-place add: the caller's tensor must not be modified.
        return wordVecs + encoding.to(wordVecs.dtype)


    def get_qkv_weights(self,r,c):
        """Draw three independent uniform-random (r, c) projection matrices.

        Returns:
            Tuple ``(query_weights, key_weights, value_weights)``.
        """
        query_weights = torch.rand((r,c))
        key_weights = torch.rand((r,c))
        value_weights = torch.rand((r,c))
        return query_weights, key_weights, value_weights



    def qkvs(self,vectorMatrix, new_dim):
        """Project ``vectorMatrix`` to queries, keys and values.

        Args:
            vectorMatrix: Tensor of shape (tokens, d_model).
            new_dim: Width of each projected vector.

        Returns:
            Tuple ``(queries, keys, values)``, each of shape (tokens, new_dim).
        """
        query_weights, key_weights, value_weights = self.get_qkv_weights(self.d_model,new_dim)
        # (tokens, d_model) @ (d_model, new_dim): no transpose is required.
        return torch.matmul(vectorMatrix, query_weights), torch.matmul(vectorMatrix, key_weights), \
        torch.matmul(vectorMatrix, value_weights)


    def qk_dotproducts(self,queries, keys):
        """Return the matrix whose entry [i, j] is ``dot(queries[i], keys[j])``."""
        return torch.matmul(queries, keys.transpose(0, 1))


    def getSoftmaxed_qkdp(self,qk_dotproductmatrix):
        """Apply softmax independently to every row of the score matrix."""
        return torch.softmax(qk_dotproductmatrix, dim=-1)


    def getSoftmaxWeightedValues(self,softmaxed_qkdp, values):
        """Scale every value vector by every attention probability.

        Args:
            softmaxed_qkdp: Tensor of shape (queries, keys).
            values: Tensor of shape (keys, width).

        Returns:
            Tensor of shape (queries, keys, width) where entry [i, j] equals
            ``softmaxed_qkdp[i, j] * values[j]``.
        """
        return softmaxed_qkdp.unsqueeze(-1) * values.unsqueeze(0)


    def returnRepresentation(self):
        """Run the full pipeline on the stored vectors.

        Returns:
            Tensor of shape (tokens, tokens, new_dim) holding the
            softmax-weighted value vectors for every query/key pair.
        """
        pos_encoded = self.PositionalEncoding(self.vectorRepresentations)
        queries, keys, values = self.qkvs(pos_encoded, self.new_dim)
        # Scale by sqrt(d_k), not d_k, to keep score variance near one.
        scale = math.sqrt(keys.shape[1])
        scaled_scores = self.qk_dotproducts(queries, keys) / scale
        softmaxed_qkdp = self.getSoftmaxed_qkdp(scaled_scores)
        return self.getSoftmaxWeightedValues(softmaxed_qkdp, values)
                                           


In [None]:
def getWordVectors(sentence, dim=10):
    """Return one random vector per word of ``sentence``.

    The vectors are drawn from ``torch.rand`` and do not depend on the words
    themselves; only the word count determines the number of rows.

    Args:
        sentence: Text to tokenize on whitespace.
        dim: Width of each generated vector.

    Returns:
        Tensor of shape (number_of_words, dim).
    """
    # split() without an argument collapses runs of whitespace and ignores
    # leading/trailing blanks, so no rows are produced for empty tokens.
    words = sentence.split()
    return torch.rand((len(words), dim))

In [13]:
# Build the demo input: one random vector per word of the sample sentence.
wordVecs = getWordVectors('Hi there this is nuts')

In [14]:
# Run the encoder on the demo vectors; the bare last expression displays the result.
# The Q/K/V weights come from torch.rand and no seed is set in the visible cells,
# so the printed values differ between runs.
a = Encoder(wordVecs)
a.returnRepresentation()

tensor([[[7.4512e-01, 1.0554e+00, 7.8300e-01, 7.4537e-01, 9.6133e-01],
         [1.2013e+00, 1.5676e+00, 1.1543e+00, 1.2686e+00, 1.5323e+00],
         [3.4198e+00, 3.9404e+00, 2.6120e+00, 3.4701e+00, 4.1074e+00],
         [1.3480e-04, 1.7156e-04, 1.1612e-04, 1.3747e-04, 1.6242e-04],
         [4.3894e-04, 6.7405e-04, 5.1444e-04, 4.3864e-04, 5.5125e-04]],

        [[7.3713e-01, 1.0440e+00, 7.7460e-01, 7.3737e-01, 9.5102e-01],
         [1.2079e+00, 1.5762e+00, 1.1607e+00, 1.2756e+00, 1.5407e+00],
         [3.4220e+00, 3.9429e+00, 2.6137e+00, 3.4723e+00, 4.1100e+00],
         [1.0125e-04, 1.2887e-04, 8.7223e-05, 1.0326e-04, 1.2199e-04],
         [3.6273e-04, 5.5702e-04, 4.2512e-04, 3.6248e-04, 4.5554e-04]],

        [[7.6089e-01, 1.0777e+00, 7.9957e-01, 7.6114e-01, 9.8168e-01],
         [1.2070e+00, 1.5751e+00, 1.1598e+00, 1.2747e+00, 1.5396e+00],
         [3.3949e+00, 3.9117e+00, 2.5930e+00, 3.4448e+00, 4.0775e+00],
         [1.2326e-04, 1.5688e-04, 1.0618e-04, 1.2571e-04, 1.4852e-04],
  