In [1]:
import torch
from gensim.models import KeyedVectors

import han
from data_iterator import DataIterator
from utils import load_pickled_obj

In [2]:
# Load the pickled training sample and the saved word vectors from disk.
dataset = load_pickled_obj("data/amazon_train_sample.pkl")
wv = KeyedVectors.load("embedding/amazon_sample.wv")

# Embedding table and its dimensions, all read straight off the loaded vectors.
# NOTE(review): `wv.vocab` is the pre-4.0 gensim attribute; gensim 4 exposes
# `key_to_index` instead -- confirm the pinned gensim version before upgrading.
embedding_matrix = wv.vectors
vocab_size, embedding_dim = len(wv.vocab), wv.vector_size

In [3]:
class WordAttention(torch.nn.Module):
    """Attention pooling that collapses the leading axis of its input.

    Every position along dim 0 is passed through a linear layer and ``tanh``,
    then scored against a learned context vector. The scores are
    softmax-normalised along dim 0 and used as the weights of a weighted sum
    of the input taken along that same axis.

    Args:
        input_size: Size of the last (feature) dimension of the input.
    """

    def __init__(self, input_size=100):
        super().__init__()
        # torch.empty only allocates memory, so the values are filled in
        # explicitly; an uninitialised parameter can hold arbitrary or NaN data.
        self.context_vector = torch.nn.Parameter(torch.empty(input_size, 1))
        torch.nn.init.xavier_uniform_(self.context_vector)
        self.linear = torch.nn.Linear(in_features=input_size, out_features=input_size)

    def forward(self, input):
        """Pool ``input`` along dim 0 with learned attention weights.

        Args:
            input: Tensor of shape ``(length, ..., input_size)``. In this
                notebook it is the encoder output, presumably laid out as
                ``(words, batch, input_size)`` -- confirm against the encoder.

        Returns:
            Tensor of shape ``(..., input_size)``: the attention-weighted sum
            of ``input`` over dim 0.
        """
        hidden = torch.tanh(self.linear(input))
        # Drop only the trailing singleton left by the matmul; a bare
        # squeeze() would also remove a size-1 batch or length axis.
        scores = torch.matmul(hidden, self.context_vector).squeeze(-1)
        # Normalise over the axis that is summed below so the weights of each
        # pooled group add up to one.
        weights = torch.nn.functional.softmax(scores, dim=0)
        return (weights.unsqueeze(-1) * input).sum(dim=0)

In [5]:
# Run the word-level encoder and attention over a single batch as a shape check.
word_encoder = han.WordEncoder(embedding_matrix)
word_attention = WordAttention()
it = iter(DataIterator(dataset, wv.vocab))
labels, batch = next(it)

# Swap the first two axes so that iterating `batch` walks along what was axis 1
# of the iterator's output.
batch = torch.LongTensor(batch).permute(1, 0, 2)

output_list = []
for sentence_batch in batch:
    # Transpose each 2-D slice before handing it to the encoder.
    word_major = sentence_batch.permute(1, 0)
    # Size the initial hidden state from the data instead of hard-coding 64,
    # so a batch of any other size still runs.
    # NOTE(review): the leading 2 and trailing 50 are assumed to match
    # han.WordEncoder's configuration -- confirm against its definition.
    initial_hidden = torch.zeros(2, word_major.size(1), 50)
    encoded, _ = word_encoder(word_major, initial_hidden)
    output_list.append(word_attention(encoded))

# Stack the per-slice results along dim 0.
output = torch.cat(output_list, 0)
output.shape

(64, 240, 30)


torch.Size([15360, 100])