In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from gensim.models import KeyedVectors

from dataset import SentWordDataset
from config import BATCH_SIZE, WORD_HIDDEN_SIZE, SENT_HIDDEN_SIZE

In [2]:
class WordEncoder(nn.Module):
    """Word-level encoder: frozen pretrained embeddings feeding a bidirectional GRU.

    Args:
        embedding_matrix: 2-D array of pretrained vectors; its second
            dimension is used as the GRU input size.
        hidden_size: hidden size of each GRU direction.
    """

    def __init__(self, embedding_matrix, hidden_size):
        super().__init__()
        pretrained = torch.FloatTensor(embedding_matrix)
        # freeze=True keeps the pretrained vectors out of the optimiser.
        self.embedding = nn.Embedding.from_pretrained(embeddings=pretrained, freeze=True)
        self.gru = nn.GRU(
            input_size=embedding_matrix.shape[1],
            hidden_size=hidden_size,
            bidirectional=True,
        )

    def forward(self, input, hidden_state):
        """Embed the token ids and run them through the GRU.

        Returns:
            The ``(output, h_n)`` pair produced by ``nn.GRU``.
        """
        embedded = self.embedding(input)
        return self.gru(embedded, hidden_state)

In [8]:
class SentEncoder(nn.Module):
    """Sentence-level encoder: a single bidirectional GRU.

    Args:
        input_size: feature size of each input vector.
        hidden_size: hidden size of each GRU direction.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=True,
        )

    def forward(self, input, hidden_state):
        """Run ``input`` through the GRU starting from ``hidden_state``.

        Returns:
            The ``(output, h_n)`` pair produced by ``nn.GRU``.
        """
        return self.gru(input, hidden_state)

In [9]:
class Attention(nn.Module):
    """Additive attention pooling over the sequence dimension.

    Each time step is projected through a linear layer and ``tanh``, scored
    against a learned context vector, and the scores are softmax-normalised
    over the sequence. The result is the attention-weighted sum of the inputs.

    Args:
        input_size: feature size of the vectors being pooled.
    """

    def __init__(self, input_size):
        super(Attention, self).__init__()
        self.input_size = input_size
        self.fc = nn.Linear(self.input_size, self.input_size)
        self.context_vector = nn.Parameter(torch.randn(self.input_size))

    def forward(self, input):
        """Pool ``input`` of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, input_size)`` on the same device and
            with the same dtype as the computation on ``input``.
        """
        projected = torch.tanh(self.fc(input))
        scores = torch.matmul(projected, self.context_vector)  # (batch, seq_len)
        weights = F.softmax(scores, dim=1)
        # Broadcast the per-step weights over the feature axis and reduce over
        # the sequence axis. Doing this in one tensor expression (rather than
        # accumulating into a freshly allocated torch.zeros buffer) keeps the
        # result on the input's device and in its dtype.
        return torch.sum(weights.unsqueeze(-1) * input, dim=1)

In [13]:
class Han(nn.Module):
    """Hierarchical attention network for document classification.

    Words are encoded and attention-pooled into one vector per sentence;
    sentence vectors are encoded and attention-pooled into one vector per
    document, which a linear layer maps to class log-probabilities.

    Args:
        embedding_matrix: pretrained word vectors passed to ``WordEncoder``.
        word_hidden_size: per-direction hidden size of the word-level GRU.
        sent_hidden_size: per-direction hidden size of the sentence-level GRU.
        num_classes: number of output classes.
        batch_size: kept for backward compatibility and stored as
            ``self.batch_size``. Hidden states are sized from the actual
            input on every forward call, so batches of any size are accepted.
    """

    def __init__(self, embedding_matrix, word_hidden_size, sent_hidden_size, num_classes, batch_size):
        super(Han, self).__init__()
        self.word_encoder = WordEncoder(embedding_matrix, word_hidden_size)
        self.word_attention = Attention(word_hidden_size * 2)
        self.sent_encoder = SentEncoder(word_hidden_size * 2, sent_hidden_size)
        self.sent_attention = Attention(sent_hidden_size * 2)
        self.fc = nn.Linear(sent_hidden_size * 2, num_classes)
        self.word_hidden_size = word_hidden_size
        self.sent_hidden_size = sent_hidden_size
        self.batch_size = batch_size

    def _init_hidden(self, batch_size, hidden_size):
        """Return a zero initial state for a bidirectional GRU (2 directions)."""
        # Derive device and dtype from a parameter so the state follows the
        # module through .to(device) / .double() without being a buffer.
        return self.fc.weight.new_zeros(2, batch_size, hidden_size)

    def forward(self, input):
        """Classify a batch of documents.

        Args:
            input: token-id tensor indexed as ``(batch, sentence, word)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding log-probabilities.
        """
        # (batch, sentence, word) -> (sentence, word, batch)
        input = input.permute(1, 2, 0)
        batch_size = input.shape[2]
        sent_hidden_state = self._init_hidden(batch_size, self.sent_hidden_size)
        sent_encoder_outputs = []
        for sentence in input:
            # The word-level state restarts from zero for every sentence.
            word_hidden_state = self._init_hidden(batch_size, self.word_hidden_size)
            word_encoder_outputs = []
            for word in sentence:
                # Add an empty dimension because the GRU needs a 3D input,
                # moreover this is the dimension where all the encoder
                # outputs will be concatenated
                word = word.unsqueeze(0)
                # Feed the state returned by the previous step back in, so the
                # encoder actually sees the preceding words of the sentence.
                # NOTE(review): stepping a bidirectional GRU one token at a
                # time means its reverse direction never sees later tokens;
                # consider encoding the whole sentence in a single GRU call.
                output, word_hidden_state = self.word_encoder(word, word_hidden_state)
                word_encoder_outputs.append(output)
            word_attn_input = torch.cat(word_encoder_outputs, dim=0)
            word_attn_input = word_attn_input.permute(1, 0, 2)
            output = self.word_attention(word_attn_input)
            # Add an empty dimension (as before)
            output = output.unsqueeze(0)
            # Carry the sentence-level state across sentences of the document.
            output, sent_hidden_state = self.sent_encoder(output, sent_hidden_state)
            sent_encoder_outputs.append(output)
        sent_attn_input = torch.cat(sent_encoder_outputs, dim=0)
        sent_attn_input = sent_attn_input.permute(1, 0, 2)
        output = self.sent_attention(sent_attn_input)
        output = self.fc(output)
        output = F.log_softmax(output, dim=1)
        return output

In [14]:
# Load the pretrained word vectors and the labelled review sample, then wire
# up the dataset, the batch loader and the (untrained) model.
wv = KeyedVectors.load("embedding/yelp.wv")
df = (
    pd.read_csv("data/yelp_train_sample.csv")
    .fillna("")  # missing values become empty strings
)
dataset = SentWordDataset(df["text"], df["label"], wv.vocab)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
model = Han(
    embedding_matrix=wv.vectors,
    word_hidden_size=WORD_HIDDEN_SIZE,
    sent_hidden_size=SENT_HIDDEN_SIZE,
    num_classes=5,
    batch_size=BATCH_SIZE,
)

In [16]:
# Smoke test: push a single batch through the model and display its output.
iter_loader = iter(loader)
labels, features = next(iter_loader)
predictions = model(features)
predictions

tensor([[-1.5154, -1.6213, -1.6352, -1.6340, -1.6473],
        [-1.5101, -1.6382, -1.6389, -1.6329, -1.6336],
        [-1.5246, -1.6430, -1.6746, -1.6095, -1.6018]],
       grad_fn=<LogSoftmaxBackward>)

In [17]:
# Scratch tensor of standard-normal samples used by the experiments below.
t = torch.randn(5, 4, 3)
t

tensor([[[ 1.7711e+00,  5.5403e-01, -2.7802e-01],
         [ 9.7187e-01, -3.8668e-01, -9.7937e-01],
         [ 1.1483e+00,  1.4296e+00, -7.4583e-01],
         [ 1.8697e+00,  6.9198e-01, -1.0255e-01]],

        [[-2.5148e+00, -2.5694e-01, -7.5848e-01],
         [ 1.9415e+00,  9.1977e-01,  3.5009e-01],
         [-5.3225e-01, -1.0782e+00, -1.1308e+00],
         [ 9.1710e-01,  1.0315e-01, -1.0322e-01]],

        [[-1.9805e-03, -2.3041e-01,  2.1059e-01],
         [-1.4332e+00, -1.6309e+00, -3.9389e-01],
         [-1.8161e-01,  2.1778e-01,  1.2066e-01],
         [ 7.7609e-01,  4.9084e-01, -7.2770e-01]],

        [[-2.0876e-01,  9.7782e-01,  1.9496e+00],
         [ 1.4178e-01,  7.9491e-01, -5.2412e-01],
         [-7.2215e-01, -4.0194e-01,  9.1482e-01],
         [-1.3133e+00, -1.5092e+00,  1.1074e+00]],

        [[-2.2761e-01,  3.8097e-01,  4.7954e-01],
         [-1.0804e+00,  4.4056e-01,  6.7093e-01],
         [-3.2950e-01,  8.9058e-01, -4.0156e-01],
         [ 1.1965e+00, -6.1624e-01, -1.900

In [20]:
# Show, as a plain string, which device the scratch tensor lives on.
device_name = str(t.device)
device_name

'cpu'

In [None]:
# Scratch check of the broadcast used for attention weighting: inspect only
# the first (a, b) pair and stop.
# NOTE(review): `h` is not defined in any cell visible here — confirm where it
# comes from, otherwise this cell fails on a fresh kernel.
for a, b in zip(t, h):
    a = a.unsqueeze(1).expand_as(b)
    product = a * b
    print(a.shape)
    print(b.shape)
    print(product.shape)
    print(product)
    break