In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('../')
from constant import *

In [68]:
class Attention(nn.Module):
    """Length-masked attention over encoder outputs.

    Scores every encoder time step against the decoder's last hidden state,
    masks the padded time steps and normalises the scores with a softmax.

    Scoring methods (h = last_hidden, s = encoder output at one time step):
        dot:      h . s
        general:  (Wa h) . s
        concat:   va . tanh(Wa [h; s])
        bahdanau: va . tanh(Wa h + Ua s)

    Args:
        hidden_size: feature size shared by ``last_hidden`` and
            ``encoder_outputs``.
        method: one of 'dot', 'general', 'concat', 'bahdanau'.

    Inputs:
        last_hidden: (batch_size, hidden_size)
        encoder_outputs: (batch_size, max_time, hidden_size)
        encoder_batch_len: (batch_size,) number of valid (non-padded) time
            steps of each sequence; every entry should be >= 1, because a
            fully masked row turns into NaN after the softmax.
    Returns:
        attention_weights: (batch_size, max_time), each row sums to 1 and is
        exactly 0 on the padded positions.

    Raises:
        NotImplementedError: if ``method`` is not a supported name.
    """
    def __init__(self, hidden_size, method="general"):
        super(Attention, self).__init__()
        self.method = method
        self.hidden_size = hidden_size
        if method == 'dot':
            # Plain dot product: nothing to learn.
            pass
        elif method == 'general':
            # The projection must keep the feature size, otherwise the result
            # cannot be multiplied with the encoder outputs in _score.
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        elif method == "concat":
            # Wa consumes [last_hidden; encoder_output] joined on the feature
            # axis, hence the doubled input width.
            self.Wa = nn.Linear(2 * hidden_size, hidden_size, bias=False)
            self.va = self._make_va(hidden_size)
        elif method == 'bahdanau':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
            self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
            self.va = self._make_va(hidden_size)
        else:
            raise NotImplementedError("unknown attention method: %r" % (method,))

    @staticmethod
    def _make_va(hidden_size):
        """Create the learned scoring vector ``va`` of shape (hidden_size,).

        The vector is shared by every batch element (so the module does not
        depend on the batch size) and is explicitly initialised instead of
        being left as uninitialised memory.
        """
        bound = hidden_size ** -0.5
        return nn.Parameter(torch.empty(hidden_size).uniform_(-bound, bound))

    def forward(self, last_hidden, encoder_outputs, encoder_batch_len):
        """Return the attention weights (batch_size, max_time).

        :param last_hidden: (batch_size, hidden_dim)
        :param encoder_outputs: (batch_size, max_time, hidden_dim)
        :param encoder_batch_len: (batch_size,) valid length of each sequence
        """
        attention_energies = self._score(last_hidden, encoder_outputs, self.method)

        # masking: build the mask on the same device as the scores so the
        # module also works when the inputs live on a GPU.
        device = attention_energies.device
        maxlen = encoder_outputs.size(1)
        lengths = torch.as_tensor(encoder_batch_len, device=device)
        mask = torch.arange(maxlen, device=device)[None, :] < lengths[:, None]
        # Out-of-place fill keeps the score tensor intact for autograd.
        attention_energies = attention_energies.masked_fill(~mask, float('-inf'))
        return F.softmax(attention_energies, -1)

    def _score(self, last_hidden, encoder_outputs, method):
        """
        Computes an attention score
        :param last_hidden: (batch_size, hidden_dim)
        :param encoder_outputs: (batch_size, max_time, hidden_dim)
        :param method: str (`dot`, `general`, `concat`, `bahdanau`)
        :return: a score (batch_size, max_time)
        """
        if method == 'dot':
            query = last_hidden.unsqueeze(-1)
            return encoder_outputs.bmm(query).squeeze(-1)

        elif method == 'general':
            query = self.Wa(last_hidden).unsqueeze(-1)
            return encoder_outputs.bmm(query).squeeze(-1)

        elif method == "concat":
            # Repeat the hidden state along time and join it with the encoder
            # outputs on the feature axis, so there is one score per time step.
            max_time = encoder_outputs.size(1)
            query = last_hidden.unsqueeze(1).expand(-1, max_time, -1)
            combined = torch.cat((query, encoder_outputs), -1)
            return torch.tanh(self.Wa(combined)).matmul(self.va)

        elif method == "bahdanau":
            # The (batch, 1, hidden) query broadcasts over the time axis.
            query = self.Wa(last_hidden).unsqueeze(1)
            return torch.tanh(query + self.Ua(encoder_outputs)).matmul(self.va)

        else:
            raise NotImplementedError("unknown attention method: %r" % (method,))

In [27]:
from load_data_exp import *
from encoder.encoder import *

In [44]:
# Pull only the first batch out of train_dataloader: the loop stops right
# after binding it, so `i` holds that batch for the cells that follow.
for i in train_dataloader:
    break

In [45]:
i[3]

tensor([48, 48, 48, 45, 39, 34, 34, 33, 33, 33, 31, 31, 27, 27, 27, 24, 21, 21,
        20, 19, 16, 16, 16, 16, 15, 15, 13, 12, 10,  7,  7,  7])

In [46]:
# Build a sentence encoder and run it on the sampled batch.
# NOTE(review): the constructor arguments look like (vocabulary size, word
# embedding dim, hidden size, pretrained embeddings, combine mode) -- this is
# presumed from the names only; confirm against encoder/encoder.py.
enc = EncoderSentence(len(word_mapping)+1,WORD_DIM,128,pretrained_word_embeds,'sum')
# `e` is what gets passed to Attention as encoder_outputs further down;
# `f` is not used in the cells shown here.
e,f = enc(i[1],i[3])

In [47]:
e.size()

torch.Size([32, 48, 128])

In [51]:
# Random stand-in for a decoder hidden state: 32 rows of 128 features, which
# matches the first and last dimensions of `e` (torch.Size([32, 48, 128])).
last_hidden = torch.rand(32,128)

In [52]:
last_hidden.size()

torch.Size([32, 128])

In [74]:
# Smoke-test the parameter-free 'dot' attention on the sampled batch.
# i[3] is passed as encoder_batch_len, i.e. the per-sequence lengths used to
# mask the padded time steps before the softmax.
attn = Attention(128,method='dot')
a = attn(last_hidden,e,i[3])

In [75]:
a.size()

torch.Size([32, 48])

In [73]:
a[-1]

tensor([0.0818, 0.1367, 0.1534, 0.1557, 0.1519, 0.1632, 0.1573, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000], grad_fn=<SelectBackward>)

In [76]:
a[-1]

tensor([0.0337, 0.1525, 0.1803, 0.1508, 0.1615, 0.1803, 0.1410, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000], grad_fn=<SelectBackward>)

------------------