In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext

In [149]:
class NVSM(nn.Module):
    def __init__(self, n_doc, n_tok, dim_doc_emb, dim_tok_emb, neg_sampling_rate):
        super(NVSM, self).__init__()
        self.doc_emb           = nn.Embedding(n_doc, embedding_dim = dim_doc_emb)
        self.tok_emb           = nn.Embedding(n_tok, embedding_dim = dim_tok_emb)
        self.tok_to_doc        = nn.Linear(dim_tok_emb, dim_doc_emb)
        self.bias              = nn.Parameter(torch.Tensor(dim_doc_emb))
        self.neg_sampling_rate = neg_sampling_rate
        
    def query_to_tensor(self, query):
        '''
        Computes the average of the word embeddings of the query. This method 
        corresponds to the function 'g' in the article.
        '''
        # Create a mask to ignore padding embeddings
        query_mask    = (query != 0).float()
        # Compute the number of tokens in each query to properly compute the 
        # average
        tok_by_input  = query_mask.sum(dim = 1)
        query_tok_emb = self.tok_emb(query)
        query_tok_emb = query_tok_emb * query_mask.unsqueeze(-1)
        # Compute the average of the embeddings
        query_emb     = query_tok_emb.sum(dim = 1) / tok_by_input.unsqueeze(-1)
        
        return query_emb
    
    def normalize_query_tensor(self, query_tensor):
        '''
        Divides each query tensor by its L2 norm. This method corresponds to 
        the function 'norm' in the article.
        '''
        norm = torch.norm(query_tensor, dim = 1) # we might have to detach this value 
                                                 # from the computation graph.
        return query_tensor / norm.unsqueeze(-1)
        
    def query_to_doc_space(self, query):
        '''
        Projects a query vector into the document vector space. This method corresponds 
        to the function 'f' in the article.
        '''
        return self.tok_to_doc(query)
    
    def score(self, query, document):
        '''
        Computes the cosine similarity between a query and a document embedding.
        This method corresponds to the function 'score' in the article.
        '''
        # batch dot product using batch matrix multiplication
        num   = torch.bmm(query.unsqueeze(1), document.unsqueeze(-1))
        denum = torch.norm(query, dim = 1) * torch.norm(document, dim = 1)
        
        return num / denum
        
    def non_stand_projection(self, n_gram):
        '''
        Computes the non-standard projection of a n-gram into the document vector 
        space. This method corresponds to the function 'T^~' in the article.
        '''
        n_gram_tensor      = self.query_to_tensor(n_gram)
        norm_n_gram_tensor = self.normalize_query_tensor(n_gram_tensor)
        projection         = self.query_to_doc_space(norm_n_gram_tensor)
        
        return projection
    
    def _custom_batchnorm(self, batch):
        '''
        Computes the variant of the batch normalization formula used in this article. 
        It only uses a bias and no weights.
        '''
        batch_feat_norm = (batch - batch.mean(dim = 0)) / batch.std(dim = 0)
        batch_feat_norm = batch_feat_norm + self.bias
        
        return batch_feat_norm
    
    def stand_projection(self, batch):
        '''
        Computes the standard projection of a n-gram into document vector space with
        a hardtanh activation. This method corresponds to the function 'T' in the 
        article.
        '''
        non_stand_proj = self.non_stand_projection(batch) 
        bn             = self._custom_batchnorm(non_stand_proj)
        activation     = F.hardtanh(bn)

        return activation
    
    def representation_similarity(self, query, document):
        '''
        Computes the similarity between a query and a document. This method corresponds 
        to the function 'P' in the article.
        '''
        document_emb  = self.doc_emb(document)
        query_proj    = self.stand_projection(query)
        # If we have a single document to match against each query, we have
        # to reshape the tensor to compute a simple dot product.
        # Otherwise, we compute a simple matrix multiplication to match the 
        # query against each document.
        if len(document_emb.shape) == 2:
            document_emb = document_emb.unsqueeze(1)
        dot_product   = torch.bmm(document_emb, query_proj.unsqueeze(-1))
        similarity    = torch.sigmoid(dot_product)
        
        return similarity.squeeze()
    
    def proba_doc_query(self, query, document):
        '''
        Approximates the probability of document given query by uniformly sampling 
        constrastive examples. This method corresponds to the 'P^~' function in the 
        article.
        '''
        # Positive term, this should be maximized as it indicates how similar the
        # correct document is to the query
        pos_repr = self.representation_similarity(query, document)
        
        # Sampling uniformly 'self.neg_sampling_rate' documents to compute the 
        # negative term. We first randomly draw the indices of the documents and 
        # then we compute the similarity with the query.
        z               = self.neg_sampling_rate # corresponds to the z variable in 
                                                 # the article
        n_docs          = self.doc_emb.num_embeddings
        neg_sample_size = (query.size(0), z)
        neg_sample      = torch.randint(low = 0, high = n_docs, size = neg_sample_size)
        neg_repr        = self.representation_similarity(query, neg_sample)
        
        # Probability computation
        positive_term = torch.log(pos_repr)
        negative_term = torch.log(1 - neg_repr).sum(dim = 1)
        proba         = ((z + 1) / (2 * z)) * (z * positive_term + negative_term)
        
        return proba
    
    def forward(self, query, document_ids):
        pass

In [169]:
def loss_function(nvsm, pred, lamb):
    output_term = pred.mean()
    sum_square  = lambda m: (m.weight * m.weight).sum()
    reg_term    = sum_square(nvsm.tok_emb) + \
                  sum_square(nvsm.doc_emb) + \
                  sum_square(nvsm.tok_to_doc)
    loss        = -output_term + (lamb / (2 * pred.shape[0])) * reg_term
    
    return loss

In [170]:
nvsm = NVSM(
    n_doc       = 20, 
    n_tok       = 9, 
    dim_doc_emb = 10, 
    dim_tok_emb = 7,
    neg_sampling_rate = 4
)

In [171]:
l = [
        [1, 2, 3, 0, 0], 
        [4, 5, 6, 7, 8],
        [2, 7, 8, 3, 0]
]
query = torch.tensor(l)
document = torch.tensor([2,3,7])
print(query)
print(document)

tensor([[1, 2, 3, 0, 0],
        [4, 5, 6, 7, 8],
        [2, 7, 8, 3, 0]])
tensor([2, 3, 7])


In [172]:
pred = nvsm.representation_similarity(query, document)
pred

tensor([0.4251, 0.8351, 0.3350], grad_fn=<SqueezeBackward0>)

In [173]:
loss = loss_function(nvsm, pred, .1)
loss

tensor(3.8049, grad_fn=<AddBackward0>)