In [1]:
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from tqdm import tqdm

In [34]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
embedding_dim = 200
hidden_dim = 200
epochs = 5

In [24]:
 # define Field
TEXT = data.ReversibleField(lower=True, include_lengths=True)
LABEL = data.Field(sequential=False)
# make splits for data
train, test = datasets.IMDB.splits(TEXT, LABEL)
# build the vocabulary
TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=embedding_dim))
LABEL.build_vocab(train)

train_iter, test_iter = data.BucketIterator.splits(
        (train, test), sort_key=lambda x:len(x.text),
        sort_within_batch=True, 
        batch_size=batch_size, device=device,
        repeat=False)

100%|██████████| 400000/400000 [00:27<00:00, 14621.98it/s]


In [25]:
class SelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )

    def forward(self, encoder_outputs):
        batch_size = encoder_outputs.size(0)
        # (B, L, H) -> (B , L, 1)
        energy = self.projection(encoder_outputs)
        weights = F.softmax(energy.squeeze(-1), dim=1)
        # (B, L, H) * (B, L, 1) -> (B, H)
        outputs = (encoder_outputs * weights.unsqueeze(-1)).sum(dim=1)
        return outputs, weights

class AttnClassifier(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        self.attention = SelfAttention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)
        
    def set_embedding(self, vectors):
        self.embedding.weight.data.copy_(vectors)
        
    def forward(self, inputs, lengths):
        batch_size = inputs.size(1)
        # (L, B)
        embedded = self.embedding(inputs)
        # (L, B, E)
        packed_emb = nn.utils.rnn.pack_padded_sequence(embedded, lengths)
        out, hidden = self.lstm(packed_emb)
        out = nn.utils.rnn.pad_packed_sequence(out)[0]
        out = out[:, :, :self.hidden_dim] + out[:, :, self.hidden_dim:]
        # (L, B, H)
        embedding, attn_weights = self.attention(out.transpose(0, 1))
        # (B, HOP, H)
        outputs = self.fc(embedding.view(batch_size, -1))
        # (B, 1)
        return outputs, attn_weights

In [26]:
def train(train_iter, model, optimizer, criterion):
    model.train()
    epoch_loss = 0
    bar = tqdm(total=len(train_iter))
    b_ix = 1
    for batch in train_iter:
        (x, x_l), y = batch.text, batch.label - 1
        optimizer.zero_grad()
        outputs, _ = model(x, x_l)
        loss = criterion(outputs.view(-1), y.float())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if b_ix % 10 == 0:
            bar.update(10)
            bar.set_description('current loss:{:.4f}'.format(epoch_loss / b_ix))
        b_ix += 1
    bar.update((b_ix - 1) % 10)
    bar.close()
    return epoch_loss / len(train_iter)

In [35]:
model = AttnClassifier(len(TEXT.vocab), embedding_dim, hidden_dim).to(device)
model.set_embedding(TEXT.vocab.vectors)
# optim
optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss().to(device)
# train model 
for epoch in range(epochs):
    train(train_iter, model, optimizer, criterion)

current loss:0.3752: 100%|██████████| 391/391 [00:29<00:00, 13.20it/s]
current loss:0.1868: 100%|██████████| 391/391 [00:28<00:00, 13.71it/s]
current loss:0.0909: 100%|██████████| 391/391 [00:27<00:00, 14.20it/s]
current loss:0.0266: 100%|██████████| 391/391 [00:30<00:00, 13.01it/s]
current loss:0.0075: 100%|██████████| 391/391 [00:30<00:00, 12.78it/s]


In [36]:
def binary_accuracy(preds, y):
    # round predictions to the closest integer
    rounded_preds = torch.round(F.sigmoid(preds))
    correct = (rounded_preds == y).float()  # convert into float for division
    acc = correct.sum() / len(correct)
    return acc

def accuracy(model, test_iter):
    model.eval()
    total_acc = 0
    for batch in test_iter:
        (x, x_l), y = batch.text, batch.label - 1
        outputs,_ = model(x, x_l)
        total_acc += binary_accuracy(outputs.view(-1), y.float()).item()
    return total_acc / len(test_iter)

In [37]:
print(accuracy(model, test_iter))

0.8739050512423601


In [38]:
def highlight(word, attn):
    html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
    return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(seq, attns):
    html = ""
    for ix, attn in zip(seq, attns):
        html += ' ' + highlight(
            TEXT.vocab.itos[ix],
            attn
        )
    return html + "<br><br>\n"

In [43]:
from IPython.display import HTML, display
with torch.no_grad():
    for batch in test_iter:
        x, x_l = batch.text
        y = batch.label - 1
        outputs, attn_weights = model(x, x_l)
        for i in range(batch_size):
            if torch.round(F.sigmoid(outputs[i])) != y[i].float():
                text = mk_html(x.t()[i].cpu().numpy(), attn_weights[i].cpu().numpy())
                display(HTML(text))
        break