# BiLSTM for PoS Tagging

A multi-layer bi-directional LSTM for Part-of-Speech (PoS) Tagging.  

In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

SEED = 515
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Preparing Data

The dataset is Universal Dependencies English Web Treebank (UDPOS).  
This dataset actually has two different sets of tags, [universal dependency (UD) tags](https://universaldependencies.org/u/pos/) and [Penn Treebank (PTB) tags](https://www.sketchengine.eu/penn-treebank-tagset/).  

In [2]:
from torchtext.data import Field, BucketIterator

TEXT = Field(lower=True, include_lengths=True)
# Because the set of possible tags is finite, do NOT use unknown token for it. 
UD_TAGS = Field(unk_token=None, include_lengths=True)
PTB_TAGS = Field(unk_token=None, include_lengths=True)

In [3]:
from torchtext.datasets import UDPOS

fields = [('text', TEXT), ('udtags', UD_TAGS), ('ptbtags', PTB_TAGS)]
train_data, valid_data, test_data = UDPOS.splits(fields=fields, root='data/')

In [4]:
print(train_data[0].text)
print(train_data[0].udtags)
print(train_data[0].ptbtags)

['al', '-', 'zaman', ':', 'american', 'forces', 'killed', 'shaikh', 'abdullah', 'al', '-', 'ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'qaim', ',', 'near', 'the', 'syrian', 'border', '.']
['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']
['NNP', 'HYPH', 'NNP', ':', 'JJ', 'NNS', 'VBD', 'NNP', 'NNP', 'NNP', 'HYPH', 'NNP', ',', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'NNP', ',', 'IN', 'DT', 'JJ', 'NN', '.']


In [5]:
TEXT.build_vocab(train_data, min_freq=2, 
                 vectors="glove.6B.100d", vectors_cache="vector_cache", 
                 unk_init=torch.Tensor.normal_)

UD_TAGS.build_vocab(train_data)
PTB_TAGS.build_vocab(train_data)

len(TEXT.vocab), len(UD_TAGS.vocab), len(PTB_TAGS.vocab)

(8866, 18, 51)

In [6]:
BATCH_SIZE = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size=BATCH_SIZE, device=device)

In [7]:
for batch in train_iterator:
    batch_text, batch_text_lens = batch.text
    batch_tags, batch_tags_lens = batch.udtags
    break

print(batch_text)
print(batch_text_lens)
print(batch_tags)
print(batch_tags_lens)

tensor([[  27,   56,  116,  ...,  127,    9, 3715],
        [  12,  244,    4,  ...,    4,   76,    1],
        [  73,   13,    1,  ...,    1, 1904,    1],
        ...,
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1],
        [   1,    1,    1,  ...,    1,    1,    1]], device='cuda:0')
tensor([19, 16,  2, 20, 44, 11, 29, 13, 10, 38, 22, 71, 17,  7, 15, 12,  7, 10,
        12, 29, 20,  5, 42, 20, 25, 11, 11,  4, 22, 16, 31, 28,  2, 24, 60, 18,
         4,  7,  4, 17, 26, 38, 34,  5,  2,  6,  1,  4, 23, 24, 33,  9, 16,  1,
        20, 27, 26, 23, 20, 13, 14, 20, 29, 14,  7, 13,  6, 23, 15, 11, 14, 27,
        31, 18,  2, 38, 52,  2,  2,  5,  7, 22,  7, 12, 16, 12,  5, 42, 18, 19,
        15,  8, 11, 13,  3, 33,  7,  4,  7,  1, 25, 48, 20, 11,  2, 26, 22, 19,
        21,  4, 12,  9, 33, 16, 15, 25, 10, 36,  3,  9,  5, 20, 17, 14,  4,  2,
        19,  1], device='cuda:0')
tensor([[14, 13,  8,  ...,  1,  4,  7],
        [ 4,  1,  2,  .

## Building the Model

A Seq2Seq model  
* The elements in two sequences are not matched one by one  
* The two sequences may have different lengths  

A PoS-tagger  
* The elements in two sequences are strictly matched one by one  
* The two sequences have definitely the same length  

![BiLSTM for PoS Tagging](fig/BiLSTM-for-PoS-Tagging.png)

In [8]:
class PoSTagger(nn.Module):
    def __init__(self, voc_dim, emb_dim, hid_dim, tag_dim, n_layers, bidirect, dropout, pad_idx):
        super().__init__()
        self.emb = nn.Embedding(voc_dim, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim, hid_dim, num_layers=n_layers, 
                           bidirectional=bidirect, dropout=dropout)
        self.fc = nn.Linear(hid_dim*2 if bidirect else hid_dim, tag_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text, text_lens):
        # text: (step, batch)
        embedded = self.dropout(self.emb(text))
        # Pack sequence
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lens, enforce_sorted=False)
        # hidden: (num_layers*num_directions, batch, hid_dim)
        packed_outs, (hidden, cell) = self.rnn(packed_embedded)
        # Unpack sequence
        # outs: (step, batch, hid_dim)
        outs, out_lens = nn.utils.rnn.pad_packed_sequence(packed_outs)

        # preds: (step, batch, tag_dim)
        preds = self.fc(self.dropout(outs))
        return preds

In [9]:
VOC_DIM = len(TEXT.vocab)
EMB_DIM = 100
HID_DIM = 128
TAG_DIM = len(UD_TAGS.vocab)

N_LAYERS = 2
BIDIRECT = True
DROPOUT = 0.25
TEXT_PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]


tagger = PoSTagger(VOC_DIM, EMB_DIM, HID_DIM, TAG_DIM, 
                   N_LAYERS, BIDIRECT, DROPOUT, TEXT_PAD_IDX).to(device)
preds = tagger(batch_text, batch_text_lens)

print(batch_text.size())
print(preds.size())

torch.Size([71, 128])
torch.Size([71, 128, 18])


In [10]:
# The padding positions have idential values 
preds[:, batch_text_lens.argmin()]

tensor([[-0.0345, -0.0701,  0.0183,  ..., -0.0255,  0.0777,  0.0155],
        [-0.0354, -0.0426, -0.0365,  ...,  0.0001,  0.0534,  0.0074],
        [-0.0354, -0.0426, -0.0365,  ...,  0.0001,  0.0534,  0.0074],
        ...,
        [-0.0354, -0.0426, -0.0365,  ...,  0.0001,  0.0534,  0.0074],
        [-0.0354, -0.0426, -0.0365,  ...,  0.0001,  0.0534,  0.0074],
        [-0.0354, -0.0426, -0.0365,  ...,  0.0001,  0.0534,  0.0074]],
       device='cuda:0', grad_fn=<SelectBackward>)

In [11]:
# Check if data are mixed across different samples in a batch.
tagger.eval()
preds_012 = tagger(batch_text[:, 0:3], batch_text_lens[0:3])
preds_123 = tagger(batch_text[:, 1:4], batch_text_lens[1:4])

step = min(preds_012.size(0), preds_123.size(0))
(preds_012[:step, 1:] - preds_123[:step, :2]).abs().max()

tensor(1.4901e-08, device='cuda:0', grad_fn=<MaxBackward1>)

## Training the Model

In [12]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.normal_(param.data, mean=0, std=0.1)

def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


tagger = PoSTagger(VOC_DIM, EMB_DIM, HID_DIM, TAG_DIM, 
                   N_LAYERS, BIDIRECT, DROPOUT, TEXT_PAD_IDX).to(device)

tagger.apply(init_weights)
print(f'The model has {count_parameters(tagger):,} trainable parameters')

The model has 1,522,010 trainable parameters


In [13]:
# Initialize Embeddings with Pre-Trained Vectors
print(TEXT.vocab.vectors.size())
print(tagger.emb.weight.size())

tagger.emb.weight.data.copy_(TEXT.vocab.vectors)

TEXT_UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]
tagger.emb.weight.data[TEXT_UNK_IDX].zero_()
tagger.emb.weight.data[TEXT_PAD_IDX].zero_()

print(tagger.emb.weight[:5, :8])

torch.Size([8866, 100])
torch.Size([8866, 100])
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0382, -0.2449,  0.7281, -0.3996,  0.0832,  0.0440, -0.3914,  0.3344],
        [-0.3398,  0.2094,  0.4635, -0.6479, -0.3838,  0.0380,  0.1713,  0.1598],
        [-0.1077,  0.1105,  0.5981, -0.5436,  0.6740,  0.1066,  0.0389,  0.3548]],
       device='cuda:0', grad_fn=<SliceBackward>)


In [14]:
loss_func = nn.CrossEntropyLoss(ignore_index=TAG_PAD_IDX, reduction='mean')
optimizer = optim.AdamW(tagger.parameters())

In [15]:
def train_epoch(tagger, iterator, optimizer, loss_func):
    tagger.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch in iterator:
        # Forward pass
        text, text_lens = batch.text
        tags, tags_lens = batch.udtags
        preds = tagger(text, text_lens)

        # Calculate loss
        preds_flattened = preds.view(-1, preds.size(-1))
        tags_flattened = tags.flatten()
        loss = loss_func(preds_flattened, tags_flattened)

        # Backward propagation
        optimizer.zero_grad()
        loss.backward()
        # Update weights
        optimizer.step()
        # Accumulate loss and acc
        epoch_loss += loss.item()
        non_padding = (tags_flattened != loss_func.ignore_index)
        epoch_acc += (preds_flattened.argmax(dim=-1) == tags_flattened)[non_padding].sum().item() / non_padding.sum().item()
    return epoch_loss/len(iterator), epoch_acc/len(iterator)

def eval_epoch(tagger, iterator, loss_func):
    tagger.eval()
    epoch_loss = 0
    epoch_acc = 0
    with torch.no_grad():
        for batch in iterator:
            # Forward pass
            text, text_lens = batch.text
            tags, tags_lens = batch.udtags
            preds = tagger(text, text_lens)

            # Calculate loss
            preds_flattened = preds.view(-1, preds.size(-1))
            tags_flattened = tags.flatten()
            loss = loss_func(preds_flattened, tags_flattened)
            
            # Accumulate loss and acc
            epoch_loss += loss.item()
            non_padding = (tags_flattened != loss_func.ignore_index)
            epoch_acc += (preds_flattened.argmax(dim=-1) == tags_flattened)[non_padding].sum().item() / non_padding.sum().item()
    return epoch_loss/len(iterator), epoch_acc/len(iterator)

In [16]:
import time
N_EPOCHS = 10
best_valid_loss = np.inf

for epoch in range(N_EPOCHS):
    t0 = time.time()
    train_loss, train_acc = train_epoch(tagger, train_iterator, optimizer, loss_func)
    valid_loss, valid_acc = eval_epoch(tagger, valid_iterator, loss_func)
    epoch_secs = time.time() - t0

    epoch_mins, epoch_secs = int(epoch_secs // 60), int(epoch_secs % 60)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(tagger.state_dict(), 'models/tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 6s
	Train Loss: 1.307 | Train Acc: 59.53%
	 Val. Loss: 0.643 |  Val. Acc: 81.14%
Epoch: 02 | Epoch Time: 0m 7s
	Train Loss: 0.466 | Train Acc: 85.49%
	 Val. Loss: 0.470 |  Val. Acc: 85.72%
Epoch: 03 | Epoch Time: 0m 6s
	Train Loss: 0.340 | Train Acc: 89.29%
	 Val. Loss: 0.422 |  Val. Acc: 85.75%
Epoch: 04 | Epoch Time: 0m 6s
	Train Loss: 0.282 | Train Acc: 91.09%
	 Val. Loss: 0.391 |  Val. Acc: 86.64%
Epoch: 05 | Epoch Time: 0m 6s
	Train Loss: 0.245 | Train Acc: 92.22%
	 Val. Loss: 0.372 |  Val. Acc: 88.05%
Epoch: 06 | Epoch Time: 0m 6s
	Train Loss: 0.218 | Train Acc: 93.07%
	 Val. Loss: 0.356 |  Val. Acc: 88.34%
Epoch: 07 | Epoch Time: 0m 6s
	Train Loss: 0.199 | Train Acc: 93.63%
	 Val. Loss: 0.343 |  Val. Acc: 88.46%
Epoch: 08 | Epoch Time: 0m 6s
	Train Loss: 0.186 | Train Acc: 94.04%
	 Val. Loss: 0.335 |  Val. Acc: 88.76%
Epoch: 09 | Epoch Time: 0m 6s
	Train Loss: 0.172 | Train Acc: 94.51%
	 Val. Loss: 0.336 |  Val. Acc: 88.78%
Epoch: 10 | Epoch Time: 0m 6

In [17]:
tagger.load_state_dict(torch.load('models/tut1-model.pt'))

valid_loss, valid_acc = eval_epoch(tagger, valid_iterator, loss_func)
test_loss, test_acc = eval_epoch(tagger, test_iterator, loss_func)

print(f'Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Val. Loss: 0.327 | Val. Acc: 88.78%
Test Loss: 0.324 | Test Acc: 89.20%


## Check Embeddings
* The Embeddings of `<unk>` and `<pad>` tokens
    * Because the `padding_idx` has been passed to `nn.Embedding`, so the `<pad>` embedding will remain zeros throughout training.  
    * While the `<unk>` embedding will be learned.

In [18]:
print(tagger.emb.weight[:5, :8])

tensor([[-1.0665e-01,  1.3830e-01,  3.2588e-02,  2.8069e-02,  1.8685e-02,
          1.3227e-01,  1.1601e-01,  8.0786e-02],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-2.5757e-01, -1.8532e-01,  7.9732e-01, -5.6174e-01,  3.3207e-02,
          1.6820e-01, -5.3720e-01,  5.0092e-01],
        [-5.1804e-01,  9.9945e-02,  6.1691e-01, -7.6312e-01, -2.8021e-01,
          2.5440e-02,  3.5878e-01,  8.2394e-02],
        [-2.7172e-01, -2.2799e-04,  7.4437e-01, -6.6373e-01,  5.3565e-01,
          1.2817e-01,  2.2783e-01,  1.8637e-01]], device='cuda:0',
       grad_fn=<SliceBackward>)


## Inference

In [19]:
import spacy
nlp = spacy.load('en')

def tag_pos(tagger, sentence):
    tagger.eval()
    if isinstance(sentence, str):
        tokenized = [tok.text for tok in nlp.tokenizer(sentence)]
    else:
        tokenized = sentence
    indexed = [TEXT.vocab.stoi[tok] for tok in tokenized]
    lens = len(indexed)

    indexed = torch.tensor(indexed, dtype=torch.long).unsqueeze(1).to(device)
    lens = torch.tensor(lens, dtype=torch.long).unsqueeze(0).to(device)
    pred = tagger(indexed, lens).squeeze(1).argmax(dim=-1)
    return [UD_TAGS.vocab.itos[i.item()] for i in pred]

In [20]:
ex_idx = 0
tokens = train_data[ex_idx].text
real_tags = train_data[ex_idx].udtags
pred_tags = tag_pos(tagger, tokens)

for tok, rtag, ptag in zip(tokens, real_tags, pred_tags):
    correct = '✔' if ptag == rtag else '✘'
    print(f"{ptag}\t{rtag}\t{correct}\t{tok}")

PROPN	PROPN	✔	al
PUNCT	PUNCT	✔	-
PROPN	PROPN	✔	zaman
PUNCT	PUNCT	✔	:
ADJ	ADJ	✔	american
NOUN	NOUN	✔	forces
VERB	VERB	✔	killed
PROPN	PROPN	✔	shaikh
PROPN	PROPN	✔	abdullah
PROPN	PROPN	✔	al
PUNCT	PUNCT	✔	-
PROPN	PROPN	✔	ani
PUNCT	PUNCT	✔	,
DET	DET	✔	the
NOUN	NOUN	✔	preacher
ADP	ADP	✔	at
DET	DET	✔	the
NOUN	NOUN	✔	mosque
ADP	ADP	✔	in
DET	DET	✔	the
NOUN	NOUN	✔	town
ADP	ADP	✔	of
PROPN	PROPN	✔	qaim
PUNCT	PUNCT	✔	,
ADP	ADP	✔	near
DET	DET	✔	the
ADJ	ADJ	✔	syrian
NOUN	NOUN	✔	border
PUNCT	PUNCT	✔	.


In [21]:
ex_idx = 1
tokens = train_data[ex_idx].text
real_tags = train_data[ex_idx].udtags
pred_tags = tag_pos(tagger, tokens)

for tok, rtag, ptag in zip(tokens, real_tags, pred_tags):
    correct = '✔' if ptag == rtag else '✘'
    print(f"{ptag}\t{rtag}\t{correct}\t{tok}")

PUNCT	PUNCT	✔	[
DET	DET	✔	this
NOUN	NOUN	✔	killing
ADP	ADP	✔	of
DET	DET	✔	a
ADJ	ADJ	✔	respected
NOUN	NOUN	✔	cleric
AUX	AUX	✔	will
AUX	AUX	✔	be
VERB	VERB	✔	causing
PRON	PRON	✔	us
NOUN	NOUN	✔	trouble
ADP	ADP	✔	for
NOUN	NOUN	✔	years
PART	PART	✔	to
VERB	VERB	✔	come
PUNCT	PUNCT	✔	.
PUNCT	PUNCT	✔	]
