In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from gensim.models import KeyedVectors

from dataset import SentWordDataset
from config import BATCH_SIZE

In [16]:
class WordEncoder(nn.Module):
    """Word-level encoder: frozen pre-trained embeddings feeding a bidirectional GRU.

    Args:
        embedding_matrix: 2-D array of pre-trained word vectors. Its second
            dimension is used as the GRU input size.
        hidden_size: Number of hidden units per GRU direction. The GRU is
            bidirectional, so its output features are ``2 * hidden_size`` wide.
    """

    def __init__(self, embedding_matrix, hidden_size):
        super().__init__()
        vector_width = embedding_matrix.shape[1]
        pretrained = torch.FloatTensor(embedding_matrix)
        # freeze=True: the pre-trained vectors are not updated during training.
        self.embedding = nn.Embedding.from_pretrained(
            embeddings=pretrained,
            freeze=True,
        )
        self.gru = nn.GRU(
            input_size=vector_width,
            hidden_size=hidden_size,
            bidirectional=True,
        )

    def forward(self, input, hidden_state):
        """Embed word indices and run them through the GRU.

        Args:
            input: Tensor of word indices. The GRU is built without
                ``batch_first``, so the embedded input is consumed as
                (seq_len, batch, embedding_dim).
            hidden_state: Initial GRU hidden state, passed straight to the GRU.

        Returns:
            The ``(output, hidden)`` pair produced by the GRU.
        """
        embedded = self.embedding(input)
        return self.gru(embedded, hidden_state)

In [50]:
class WordAttention(nn.Module):
    """Scores each position of its input against a learned context vector.

    The input is passed through a square linear layer and ``tanh``; the result
    is then multiplied with ``context_vector``, which contracts the last
    (feature) dimension. The returned values are raw scores: no softmax or
    weighted sum is applied in this module.

    Args:
        input_size: Size of the last dimension of the tensors passed to
            ``forward``.
    """

    def __init__(self, input_size):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=input_size)
        self.context_vector = nn.Parameter(torch.randn(input_size))

    def forward(self, input):
        """Return one score per leading position of ``input``.

        Args:
            input: Tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor with the shape of ``input`` minus its last dimension.
        """
        # NOTE(review): both prints are shape traces used while exploring the
        # module in the notebook; they fire on every forward call.
        projected = self.fc(input).tanh()
        print(projected.shape)
        scores = projected.matmul(self.context_vector)
        print(scores.shape)
        return scores

In [51]:
# Hidden units per GRU direction. Defined once so that the encoder and the
# attention layer (which consumes the encoder output) cannot drift apart.
WORD_HIDDEN_SIZE = 50

# NOTE(review): `wv.vocab` is the gensim < 4.0 API; gensim 4 replaced it with
# `key_to_index`. Confirm the pinned gensim version and what SentWordDataset
# expects before upgrading.
wv = KeyedVectors.load("embedding/yelp.wv")
# Missing values are replaced with empty strings in every column.
df = pd.read_csv("data/yelp_train_sample.csv").fillna("")
dataset = SentWordDataset(df.text, df.label, wv.vocab)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
word_encoder = WordEncoder(embedding_matrix=wv.vectors, hidden_size=WORD_HIDDEN_SIZE)
# The encoder GRU is bidirectional, so its outputs are 2 * WORD_HIDDEN_SIZE wide.
word_attention = WordAttention(input_size=2 * WORD_HIDDEN_SIZE)

In [52]:
# Smoke test: push the first sentence position of one batch through the word
# encoder one word at a time, then through the attention module.
iter_loader = iter(loader)
labels, features = next(iter_loader)
# Assumes the loader yields features as (batch, sentences, words) - TODO confirm
# against SentWordDataset. After the permute, iterating yields one sentence
# position at a time, laid out as (words, batch).
features = features.permute(1, 2, 0)
# Size the initial hidden state from the model and the actual batch rather than
# from BATCH_SIZE and literals: a batch can hold fewer than BATCH_SIZE rows
# (short final batch, or a small dataset), and the GRU rejects a hidden state
# whose batch dimension does not match its input.
num_directions = 2 if word_encoder.gru.bidirectional else 1
hidden_size = word_encoder.gru.hidden_size
for sentence in features:
    batch_size = sentence.size(1)
    word_hidden_state = torch.zeros(num_directions, batch_size, hidden_size)
    word_encoder_outputs = []
    for word in sentence:
        # Add an empty dimension because the GRU needs a 3D input,
        # moreover this is the dimension where all the encoder
        # outputs will be concatenated
        word = word.unsqueeze(0)
        output, word_hidden_state = word_encoder(word, word_hidden_state)
        word_encoder_outputs.append(output)
    # (words, batch, features) -> (batch, words, features) for the attention layer.
    attn_input = torch.cat(word_encoder_outputs, dim=0)
    attn_input = attn_input.permute(1, 0, 2)
    print(attn_input.shape)
    output = word_attention(attn_input)
    # Only the first sentence position is needed for this shape check.
    break

torch.Size([3, 23, 100])
torch.Size([3, 23, 100])
torch.Size([3, 23])


In [49]:
# Shape check: multiplying a (3, 23, 10) tensor by a length-10 vector contracts
# the last dimension and leaves a (3, 23) result.
u = torch.ones(10)
t = torch.arange(10, dtype=torch.float32).expand(3, 23, 10)
(t @ u).shape

torch.Size([3, 23])

In [46]:
# Broadcast the integers 0..9 along two new leading dimensions.
t = torch.arange(10).expand(3, 23, 10)
t.size()

torch.Size([3, 23, 10])

In [44]:
# 0 + 1 + ... + 9: the value each matmul of the 0..9 row with a ones vector should give.
sum(range(10))

45