In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext

In [119]:
class NVSM(nn.Module):
    def __init__(self, n_doc, n_tok, dim_doc_emb, dim_tok_emb, neg_sampling_rate):
        super(NVSM, self).__init__()
        self.doc_emb           = nn.Embedding(n_doc, embedding_dim = dim_doc_emb)
        self.tok_emb           = nn.Embedding(n_tok, embedding_dim = dim_tok_emb)
        self.tok_to_doc        = nn.Linear(dim_tok_emb, dim_doc_emb)
        self.bias              = nn.Parameter(torch.Tensor(dim_doc_emb))
        self.neg_sampling_rate = neg_sampling_rate
        
    def query_to_tensor(self, query):
        '''
        Computes the average of the word embeddings of the query. This method 
        corresponds to the function 'g' in the article.
        '''
        # Create a mask to ignore padding embeddings
        query_mask    = (query != 0).float()
        # Compute the number of tokens in each query to properly compute the 
        # average
        tok_by_input  = query_mask.sum(dim = 1)
        query_tok_emb = self.tok_emb(query)
        query_tok_emb = query_tok_emb * query_mask.unsqueeze(-1)
        # Compute the average of the embeddings
        query_emb     = query_tok_emb.sum(dim = 1) / tok_by_input.unsqueeze(-1)
        
        return query_emb
    
    def normalize_query_tensor(self, query_tensor):
        '''
        Divides each query tensor by its L2 norm. This method corresponds to 
        the function 'norm' in the article.
        '''
        norm = torch.norm(query_tensor, dim = 1) # we might have to detach this value 
                                                 # from the computation graph.
        return query_tensor / norm.unsqueeze(-1)
        
    def query_to_doc_space(self, query):
        '''
        Projects a query vector into the document vector space. This method corresponds 
        to the function 'f' in the article.
        '''
        return self.tok_to_doc(query)
    
    def score(self, query, document):
        '''
        Computes the cosine similarity between a query and a document embedding.
        This method corresponds to the function 'score' in the article.
        '''
        # batch dot product using batch matrix multiplication
        num   = torch.bmm(query.unsqueeze(1), document.unsqueeze(-1))
        denum = torch.norm(query, dim = 1) * torch.norm(document, dim = 1)
        
        return num / denum
        
    def non_stand_projection(self, n_gram):
        '''
        Computes the non-standard projection of a n-gram into the document vector 
        space. This method corresponds to the function 'T^~' in the article.
        '''
        n_gram_tensor      = self.query_to_tensor(n_gram)
        norm_n_gram_tensor = self.normalize_query_tensor(n_gram_tensor)
        projection         = self.query_to_doc_space(norm_n_gram_tensor)
        
        return projection
    
    def _custom_batchnorm(self, batch):
        '''
        Computes the variant of the batch normalization formula used in this article. 
        It only uses a bias and no weights.
        '''
        batch_feat_norm = (batch - batch.mean(dim = 0)) / batch.std(dim = 0)
        batch_feat_norm = batch_feat_norm + self.bias
        
        return batch_feat_norm
    
    def stand_projection(self, batch):
        '''
        Computes the standard projection of a n-gram into document vector space with
        a hardtanh activation. This method corresponds to the function 'T' in the 
        article.
        '''
        non_stand_proj = self.non_stand_projection(batch) 
        bn             = self._custom_batchnorm(non_stand_proj)
        activation     = F.hardtanh(bn)

        return activation
    
    def representation_similarity(self, query, document):
        '''
        Computes the similarity between a query and a document. This method corresponds 
        to the function 'P' in the article.
        '''
        document_emb  = self.doc_emb(document)
        query_proj    = self.stand_projection(query)
        # If we have a single document to match against each query, we have
        # to reshape the tensor to compute a simple dot product.
        # Otherwise, we compute a simple matrix multiplication to match the 
        # query against each document.
        if len(document_emb.shape) == 2:
            document_emb = document_emb.unsqueeze(1)
        dot_product   = torch.bmm(document_emb, query_proj.unsqueeze(-1))
        similarity    = torch.sigmoid(dot_product)
        
        return similarity.squeeze()
    
    def proba_doc_query(self, query, document):
        '''
        Approximates the probability of document given query by uniformly sampling 
        constrastive examples.
        '''
        # Positive term, this should be maximized as it indicates how similar the
        # correct document is to the query
        pos_repr = self.representation_similarity(query, document)
        
        # Sampling uniformly 'self.neg_sampling_rate' documents to compute the 
        # negative term. We first randomly draw the indices of the documents and 
        # then we compute the similarity with the query.
        n_docs             = self.doc_emb.num_embeddings
        neg_sample_size    = (query.size(0), self.neg_sampling_rate)
        neg_sample         = torch.randint(low = 0, high = n_docs, size = neg_sample_size)
#         neg_sample_doc_emb = self.doc_emb(neg_sample)
        neg_repr           = self.representation_similarity(query, neg_sample)
        
        return pos_repr, neg_repr
    
    def forward(self, query, document_ids):
        pass

In [120]:
# Small model instance that the cells below call to inspect output shapes.
nvsm = NVSM(n_doc = 20, n_tok = 9, dim_doc_emb = 10, dim_tok_emb = 7, neg_sampling_rate = 4)

In [121]:
# Three queries of token ids (0 is the padding id masked out by NVSM) and one
# document id per query.
l = [[1, 2, 3, 0, 0],
     [4, 5, 6, 7, 8],
     [2, 7, 8, 3, 0]]
query, document = torch.tensor(l), torch.tensor([2, 3, 7])
print(query, document, sep = "\n")

tensor([[1, 2, 3, 0, 0],
        [4, 5, 6, 7, 8],
        [2, 7, 8, 3, 0]])
tensor([2, 3, 7])


In [122]:
# Similarity of each query with its paired document (one value per query).
nvsm.representation_similarity(query, document)

tensor([0.0150, 0.4598, 0.6190], grad_fn=<SqueezeBackward0>)

In [123]:
# pos_repr: similarity with the given documents; neg_repr: similarity with the
# randomly drawn documents (values change between runs; no seed is set).
pos_repr, neg_repr = nvsm.proba_doc_query(query, document)

In [124]:
# One positive similarity per query.
pos_repr.shape

torch.Size([3])

In [125]:
# One row per query, one column per sampled document (neg_sampling_rate = 4).
neg_repr.shape

torch.Size([3, 4])

In [126]:
# Sigmoid outputs, so every entry lies between 0 and 1.
neg_repr

tensor([[0.0895, 0.0451, 0.6147, 0.5289],
        [0.0224, 0.9056, 0.9005, 0.9289],
        [0.8214, 0.9935, 0.9186, 0.2677]], grad_fn=<SqueezeBackward0>)

In [6]:
# NOTE(review): `t` is not defined in any visible cell; this cell relies on
# kernel state from elsewhere and will fail after Restart & Run All -- confirm
# which tensor of token ids was intended.
res = nvsm.non_stand_projection(t)

In [7]:
# Display the projection computed in the previous cell.
res

tensor([[ 0.2862,  0.1525,  0.4956, -0.2371,  0.3743,  0.1711,  0.1562,  0.3258,
          0.0477,  0.0066],
        [ 0.5237, -0.5505,  0.3950, -0.2800,  0.0744, -0.3260, -0.0512, -0.0246,
          0.3163,  0.5055]], grad_fn=<AddmmBackward>)

In [85]:
# Toy integer tensors for checking the batched product by hand:
# 2 queries of size 4, and 3 documents of size 4 for each query.
query_emb = torch.tensor([[1, 2, 3, 4], [1, 0, 0, 0]])
doc_emb   = torch.arange(24).reshape(2, 3, 4)
print(query_emb.size(), doc_emb.size())
print(query_emb, doc_emb, sep = "\n")

torch.Size([2, 4]) torch.Size([2, 3, 4])
tensor([[1, 2, 3, 4],
        [1, 0, 0, 0]])
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])


In [86]:
# Operand shapes for bmm: (batch, n_docs, dim) and (batch, dim, 1).
print(doc_emb.size(), query_emb[:, :, None].size())

torch.Size([2, 3, 4]) torch.Size([2, 4, 1])


In [87]:
# Batched matrix-vector product: each document row of a batch entry is dotted
# with that entry's query vector, giving a (batch, n_docs, 1) tensor.
result = torch.bmm(doc_emb, query_emb[:, :, None])
print(result.shape, result, sep = "\n")

torch.Size([2, 3, 1])
tensor([[[ 20],
         [ 60],
         [100]],

        [[ 12],
         [ 16],
         [ 20]]])


In [31]:
# Alternative layout of the batched product: the query becomes a row vector and
# the documents are transposed, so the scores end up on the last axis.
# NOTE(review): `t1` and `t2` are not defined in any visible cell, and the
# shapes recorded for them in the cells below do not match this output --
# this cell depends on hidden kernel state; confirm or remove.
result = torch.bmm(t1.unsqueeze(1), t2.transpose(-1, -2))
print(result.shape)
print(result)

torch.Size([3, 1, 3])
tensor([[[ 20,  60, 100]],

        [[ 12,  16,  20]],

        [[ 26,  30,  34]]])


In [13]:
# NOTE(review): `t1` is not defined in any visible cell (hidden kernel state).
t1.shape

torch.Size([3])

In [14]:
# NOTE(review): `t2` is not defined in any visible cell (hidden kernel state).
t2.shape

torch.Size([3, 3])

In [20]:
# Swap the last two axes of `t2`.
# NOTE(review): `t2` is not defined in any visible cell (hidden kernel state).
t2.transpose(-1, -2)

tensor([[[0, 3, 6],
         [1, 4, 7],
         [2, 5, 8]]])

In [40]:
# Keep a handle on the model's document embedding table for the lookups below.
emb = nvsm.doc_emb
emb

Embedding(20, 10)

In [55]:
# Draw a 3 x 5 grid of document ids uniformly from [0, 9); no seed is set, so
# the values differ on every run.
negative_sample = torch.randint(0, 9, (3, 5))
negative_sample

tensor([[2, 7, 7, 8, 2],
        [6, 8, 0, 8, 2],
        [1, 1, 6, 3, 1]])

In [56]:
# Look up one embedding per sampled document id.
emb_doc = emb(negative_sample)

In [57]:
# The lookup appends the embedding dimension to the shape of the id tensor.
emb_doc.shape

torch.Size([3, 5, 10])

In [43]:
# Check the integer dtype produced by torch.randint.
negative_sample.dtype

torch.int64

In [45]:
# Python type of the tensor after casting it with .int().
type(negative_sample.int())

torch.Tensor

In [49]:
# NOTE(review): the recorded output of this cell does not match the (3, 5)
# sample created above; this cell has an earlier execution count, so the output
# looks stale -- re-run after a kernel restart.
negative_sample.shape

torch.Size([5, 1])

In [58]:
# Number of rows in the document embedding table (n_doc of the model).
emb.num_embeddings

20