In [16]:
import pandas as pd
import torch
import torch.nn as nn
import torchtext

In [128]:
class NVSM(nn.Module):
    def __init__(self, n_doc, n_tok, dim_doc_emb, dim_tok_emb, neg_sampling_rate):
        super(NVSM, self).__init__()
        self.doc_emb           = nn.Embedding(n_doc, embedding_dim = dim_doc_emb)
        self.tok_emb           = nn.Embedding(n_tok, embedding_dim = dim_tok_emb)
        self.tok_to_doc        = nn.Linear(dim_doc_emb, dim_tok_emb)
        self.neg_sampling_rate = neg_sampling_rate
        
    def query_to_tensor(self, query):
        '''
        Computes the average of the word embeddings of the query. This method 
        corresponds to the function 'g' in the article.
        '''
        # Create a mask to ignore padding embeddings
        query_mask    = (query != 0).float()
        # Compute the number of tokens in each query to properly compute the 
        # average
        tok_by_input  = query_mask.sum(dim = 1)
        query_tok_emb = self.tok_emb(query)
        query_tok_emb = query_tok_emb * query_mask.unsqueeze(-1)
        # Compute the average of the embeddings
        query_emb     = query_tok_emb.sum(dim = 1) / tok_by_input.unsqueeze(-1)
        
        return query_emb
    
    def normalize_query_tensor(self, query_tensor):
        '''
        Divide each query tensor by its L2 norm. This method corresponds to 
        the function 'norm' in the article.
        '''
        norm = torch.norm(query_tensor, dim = 1) # we might have to detach this value 
                                                 # from the computation graph.
        return query_tensor / norm.unsqueeze(-1)
        
    def query_to_doc_space(self, query):
        '''
        Project a query vector into the document vector space. This method corresponds 
        to the function 'f' in the paper.
        '''
        return self.linear(query)
        
    def forward(self, query, document_ids):
        pass

In [92]:
l = [
        [1, 2, 3, 0, 0], 
        [4, 5, 6, 7, 8]
]
t = torch.tensor(l)

In [95]:
nvsm = NVSM(
    n_doc       = 20, 
    n_tok       = 9, 
    dim_doc_emb = 10, 
    dim_tok_emb = 7
)

In [96]:
nvsm.query_to_tensor(t)

query_mask tensor([[1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1.]])
tok_by_input tensor([3., 5.])
query_tok_emb.shape torch.Size([2, 5, 7])
raw query_tok_emb tensor([[[ 6.8339e-01,  1.1856e-01,  5.2280e-01, -1.3694e+00, -1.7462e-01,
           5.6550e-01, -3.3361e-01],
         [-8.7057e-01,  1.4503e-01, -1.1387e+00,  1.2351e-01, -1.1690e+00,
          -1.1901e+00, -4.3799e-01],
         [ 2.5456e-01, -1.5504e-02,  2.4088e-01,  4.0717e-01, -2.7572e-02,
           3.8714e-02,  1.5409e-01],
         [-1.5461e-02, -6.2406e-02,  1.9464e+00,  2.8442e-01,  5.5147e-01,
          -7.0005e-01,  1.8206e-02],
         [-1.5461e-02, -6.2406e-02,  1.9464e+00,  2.8442e-01,  5.5147e-01,
          -7.0005e-01,  1.8206e-02]],

        [[-2.1760e-01,  7.9879e-02,  7.1749e-01,  5.0667e-01,  7.5838e-01,
          -7.7784e-02, -1.3649e+00],
         [ 8.4118e-01, -2.6891e-02,  3.9473e-01,  1.9589e-01, -2.1137e-03,
          -1.0450e+00, -1.5264e+00],
         [-4.0436e-01, -3.6638e-01,  1.0578e+00, -1.

tensor([[ 0.0225,  0.0827, -0.1250, -0.2796, -0.4571, -0.1953, -0.2058],
        [ 0.0950, -0.5259,  0.2783, -0.1781, -0.0685, -0.6384, -0.2491]],
       grad_fn=<DivBackward0>)

In [66]:
mask = (t != 0).float()

In [78]:
t_by_input = mask.sum(dim = 1)

In [68]:
emb = nn.Embedding(9, 10)

In [69]:
t_emb = emb(t)
t_emb.shape

torch.Size([2, 5, 10])

In [70]:
t_emb

tensor([[[-0.1371, -1.0148, -1.0533,  0.4884,  1.2920, -0.4554,  0.2100,
           0.3689, -0.8928, -0.9624],
         [-0.1229, -1.5732,  0.4160,  1.3271,  0.7516,  0.7500,  1.8796,
           1.1547, -0.0250,  0.7064],
         [-0.5257, -0.1013, -0.7273, -0.4093, -1.2507, -0.0586, -0.7206,
          -1.8130,  0.7023, -0.1104],
         [-0.4782, -0.1574,  0.6121,  0.0597,  0.1085,  0.1639,  1.5457,
           0.2678, -0.9187,  0.9393],
         [-0.4782, -0.1574,  0.6121,  0.0597,  0.1085,  0.1639,  1.5457,
           0.2678, -0.9187,  0.9393]],

        [[ 0.9484, -0.8201, -1.2047, -1.3277,  1.9986,  0.4751, -1.0518,
          -0.0499, -1.2592,  0.1978],
         [-0.5886, -0.1557, -1.6149, -1.3309, -0.9811,  2.5172,  2.1142,
           0.4247,  0.3768,  0.1469],
         [-1.9336,  1.9258,  0.3230,  0.1779,  0.6453,  2.0279,  0.3779,
          -0.2205, -0.7898, -1.2786],
         [ 1.6658,  0.7742,  1.5495,  0.4461, -1.0913,  0.7044, -0.9092,
          -0.8003, -0.2587,  0.4214],

In [71]:
t_emb * mask.unsqueeze(-1)

tensor([[[-0.1371, -1.0148, -1.0533,  0.4884,  1.2920, -0.4554,  0.2100,
           0.3689, -0.8928, -0.9624],
         [-0.1229, -1.5732,  0.4160,  1.3271,  0.7516,  0.7500,  1.8796,
           1.1547, -0.0250,  0.7064],
         [-0.5257, -0.1013, -0.7273, -0.4093, -1.2507, -0.0586, -0.7206,
          -1.8130,  0.7023, -0.1104],
         [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000, -0.0000,  0.0000],
         [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000, -0.0000,  0.0000]],

        [[ 0.9484, -0.8201, -1.2047, -1.3277,  1.9986,  0.4751, -1.0518,
          -0.0499, -1.2592,  0.1978],
         [-0.5886, -0.1557, -1.6149, -1.3309, -0.9811,  2.5172,  2.1142,
           0.4247,  0.3768,  0.1469],
         [-1.9336,  1.9258,  0.3230,  0.1779,  0.6453,  2.0279,  0.3779,
          -0.2205, -0.7898, -1.2786],
         [ 1.6658,  0.7742,  1.5495,  0.4461, -1.0913,  0.7044, -0.9092,
          -0.8003, -0.2587,  0.4214],

In [72]:
tok_embs = t_emb * mask.unsqueeze(-1)

In [81]:
ratata = tok_embs.sum(dim = 1) 

In [83]:
t_by_input.shape

torch.Size([2])

In [82]:
ratata.shape

torch.Size([2, 10])

In [84]:
ratata 

tensor([[-0.7856, -2.6893, -1.3645,  1.4062,  0.7928,  0.2359,  1.3690, -0.2894,
         -0.2155, -0.3665],
        [-0.0931,  1.4814, -1.4352, -1.9225,  0.5221,  3.9315,  1.3737, -1.3703,
         -0.4889, -1.1191]], grad_fn=<SumBackward2>)

In [85]:
ratata / t_by_input.unsqueeze(-1)

tensor([[-0.2619, -0.8964, -0.4548,  0.4687,  0.2643,  0.0786,  0.4563, -0.0965,
         -0.0718, -0.1222],
        [-0.0186,  0.2963, -0.2870, -0.3845,  0.1044,  0.7863,  0.2747, -0.2741,
         -0.0978, -0.2238]], grad_fn=<DivBackward0>)

In [86]:
t_by_input

tensor([3., 5.])

In [111]:
tens = torch.tensor([
    [1., 1.],
    [1., 0.],
    [0., 1.]
])
tens

tensor([[1., 1.],
        [1., 0.],
        [0., 1.]])

In [112]:
tens_norm = torch.norm(tens, dim = 1)
tens_norm

tensor([1.4142, 1.0000, 1.0000])

In [113]:
tens / tens_norm

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

In [114]:
tens_norm

tensor([1.4142, 1.0000, 1.0000])

In [115]:
tens_norm.expand_as(tens)

RuntimeError: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 1.  Target sizes: [3, 2].  Tensor sizes: [3]

In [116]:
tens_norm.expand_as?

In [117]:
tens_norm.expand?

In [120]:
tens.requires_grad = True

In [127]:
tens / torch.norm(tens, dim = 1).unsqueeze(-1)

tensor([[0.7071, 0.7071],
        [1.0000, 0.0000],
        [0.0000, 1.0000]], grad_fn=<DivBackward0>)