In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from gensim.models import KeyedVectors

from dataset import WordDataset
from han import WordEncoder, Attention
from config import BATCH_SIZE, DEVICE, WORD_HIDDEN_SIZE

In [5]:
class Wan(nn.Module):
    """Word-level Attention Network.

    Encodes a batch of token-id sequences one word at a time with
    ``WordEncoder``, pools the stacked per-word encoder outputs with
    ``Attention`` and classifies the pooled vector with a linear layer
    followed by ``log_softmax``.

    Args:
        embedding_matrix: pre-trained embedding weights, forwarded to
            ``WordEncoder`` unchanged.
        word_hidden_size: hidden size of the word encoder. Its per-word
            outputs are expected to be ``2 * word_hidden_size`` wide (that is
            the input width of both ``Attention`` and ``fc``).
        num_classes: number of output classes (width of ``fc``).
        batch_size: batch size used to pre-allocate the initial hidden state.
            Kept for backward compatibility; ``forward`` re-sizes the hidden
            state to the batch it actually receives.
    """

    def __init__(self, embedding_matrix, word_hidden_size, num_classes, batch_size):
        super().__init__()
        self.word_hidden_size = word_hidden_size
        self.word_encoder = WordEncoder(embedding_matrix, word_hidden_size)
        self.word_attention = Attention(word_hidden_size * 2)
        self.fc = nn.Linear(word_hidden_size * 2, num_classes)
        self.init_hidden_state(batch_size)

    def init_hidden_state(self, batch_size):
        """Reset the stored encoder hidden state to zeros for ``batch_size`` sequences.

        The leading dimension of 2 presumably covers the two directions of a
        bidirectional single-layer GRU inside ``WordEncoder`` — TODO confirm there.
        """
        self.word_hidden_state = torch.zeros(
            2, batch_size, self.word_hidden_size, device=DEVICE
        )

    def forward(self, input):
        """Classify a batch of token-id sequences.

        Args:
            input: 2-D tensor whose two axes are swapped below so that the
                loop walks over words; i.e. laid out as (batch, seq_len).

        Returns:
            ``F.log_softmax(..., dim=1)`` of the classifier logits, presumably
            shaped (batch, num_classes) — depends on ``Attention`` returning
            (batch, 2 * word_hidden_size).

        Raises:
            ValueError: if the sequence dimension is empty, since there would
                be no encoder outputs to attend over.
        """
        # (batch, seq_len) -> (seq_len, batch): iterating now yields one word per step.
        input = input.permute(1, 0)
        if input.size(0) == 0:
            raise ValueError("Wan.forward received sequences of length 0")
        # Allocate a fresh zero hidden state sized to *this* batch, so a short
        # final batch from the DataLoader works without the caller having to
        # call init_hidden_state first.
        self.init_hidden_state(input.size(1))
        hidden_state = self.word_hidden_state
        word_encoder_outputs = []
        for word in input:
            # Add an empty dimension because the GRU needs a 3D input,
            # moreover this is the dimension where all the encoder
            # outputs will be concatenated
            output, hidden_state = self.word_encoder(word.unsqueeze(0), hidden_state)
            word_encoder_outputs.append(output)
        # Keep the final hidden state on the module, as before.
        self.word_hidden_state = hidden_state
        # Stack along the word axis, then put the batch axis first for Attention.
        word_attn_input = torch.cat(word_encoder_outputs, dim=0).permute(1, 0, 2)
        output = self.word_attention(word_attn_input)
        output = self.fc(output)
        return F.log_softmax(output, dim=1)

In [6]:
# Reproducibility: seeds both the weight initialisation below and the
# DataLoader's shuffle order (it draws from torch's global RNG).
SEED = 0
torch.manual_seed(SEED)

EMBEDDING_PATH = "embedding/yelp.wv"
DATA_PATH = "data/yelp_minisample.csv"
# NOTE(review): 5 output classes, presumably the 1-5 star Yelp ratings —
# confirm against the label encoding used by WordDataset.
NUM_CLASSES = 5

wv = KeyedVectors.load(EMBEDDING_PATH)
df = pd.read_csv(DATA_PATH).fillna("")
# NOTE(review): `KeyedVectors.vocab` was removed in gensim 4.0 (replaced by
# `key_to_index`); this line assumes gensim < 4 — verify the installed version.
dataset = WordDataset(df.text, df.label, wv.vocab)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Move the parameters to DEVICE so they live alongside the hidden state,
# which Wan.init_hidden_state allocates on DEVICE.
model = Wan(wv.vectors, WORD_HIDDEN_SIZE, NUM_CLASSES, BATCH_SIZE).to(DEVICE)

In [7]:
# Smoke test: push a single batch through the model.
labels, features = next(iter(loader))
# The model's parameters and hidden state live on DEVICE; the batch must too.
features = features.to(DEVICE)
model.init_hidden_state(len(labels))
predictions = model(features)
# Expect one row of class log-probabilities per example in the batch.
assert predictions.shape == (len(labels), model.fc.out_features), predictions.shape

torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5, 200])
torch.Size([1, 5

In [None]:
# TODO(review): leftover scratch cell — the bare `torch.zeros` expression only
# displayed the function object and affected nothing; delete this cell.