# Fine-Tuning Pretrained Transformers for PoS Tagging

A BERT followed by a linear layer for Part-of-Speech (PoS) Tagging.  

In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

SEED = 515
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Loading the Tokenizer

In [2]:
from transformers import BertTokenizer

# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# tokenizer.save_pretrained('./transformers_cache/bert-base-uncased/')
tokenizer = BertTokenizer.from_pretrained('./transformers_cache/bert-base-uncased/')
print(len(tokenizer.vocab))

30522


In [3]:
# This will tokenize and lower case the data in a way that is consistent with the pre-trained transformer model.
text = "Hello WORLD how ARE yoU?"
tokens = tokenizer.tokenize(text)
print(tokens)

indexes = tokenizer.convert_tokens_to_ids(tokens)
print(indexes)

indexes = tokenizer.encode(text, add_special_tokens=True)
print(indexes)

['hello', 'world', 'how', 'are', 'you', '?']
[7592, 2088, 2129, 2024, 2017, 1029]
[101, 7592, 2088, 2129, 2024, 2017, 1029, 102]


In [4]:
# `cls_token`: The classifier token which is used when doing sequence classification (classification of the whole
# sequence instead of per-token classification). It is the first token of the sequence when built with special tokens.
init_token = tokenizer.cls_token
# `sep_token`: The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
# for sequence classification or for a text and a question for question answering. It is also used as the last token of 
# a sequence built with special tokens.
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

[CLS] [SEP] [PAD] [UNK]


In [5]:
init_token_idx = tokenizer.cls_token_id
eos_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [6]:
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']
print(max_input_length)

def cut_and_convert_to_ids(tokens, tokenizer, max_len):
    # Add special `[CLS]` and `[SEP]` tokens to the start and end of the tokens
    return tokenizer.convert_tokens_to_ids(tokens[:max_len-2])

def cut_to_max_len(tags, max_len):
    # Add special `[CLS]` and `[SEP]` tokens to the start and end of the tokens
    return tags[:max_len-2]

512


In [7]:
import functools

# Use `functools.partial` to pass functions with some of their arguments supplied
text_preprocessor = functools.partial(cut_and_convert_to_ids, tokenizer=tokenizer, max_len=max_input_length)
tag_preprocessor = functools.partial(cut_to_max_len, max_len=max_input_length)

## Preparing Data

The dataset is Universal Dependencies English Web Treebank (UDPOS).  
This dataset actually has two different sets of tags, [universal dependency (UD) tags](https://universaldependencies.org/u/pos/) and [Penn Treebank (PTB) tags](https://www.sketchengine.eu/penn-treebank-tagset/).  

In [8]:
from torchtext.data import Field, BucketIterator

# The data for PoS Tagging have already been tokenized, so NO need for a tokenizer
# `use_vocab`: Whether to use a Vocab object. If False, the data in this field should already be numerical.
TEXT = Field(batch_first=True, use_vocab=False, lower=True, preprocessing=text_preprocessor, 
             init_token=init_token_idx, eos_token=eos_token_idx, pad_token=pad_token_idx, unk_token=unk_token_idx,
             include_lengths=True)

# Because the set of possible tags is finite, do NOT use unknown token for it. 
# Add `<pad>` to both the start and end of tags, consistent to which has been done to the tokens. 
# Note the `<pad>` token will be ignored when calculating loss and accuracy. 
UD_TAGS = Field(batch_first=True, preprocessing=tag_preprocessor, 
                init_token='<pad>', eos_token='<pad>', unk_token=None, include_lengths=True)
PTB_TAGS = Field(batch_first=True, preprocessing=tag_preprocessor, 
                 init_token='<pad>', eos_token='<pad>', unk_token=None, include_lengths=True)

In [9]:
from torchtext.datasets import UDPOS

fields = [('text', TEXT), ('udtags', UD_TAGS), ('ptbtags', PTB_TAGS)]
train_data, valid_data, test_data = UDPOS.splits(fields=fields, root='data/')

In [10]:
print(train_data[0].text)
print(train_data[0].udtags)
print(train_data[0].ptbtags)

[2632, 1011, 100, 1024, 2137, 2749, 2730, 100, 14093, 2632, 1011, 100, 1010, 1996, 14512, 2012, 1996, 8806, 1999, 1996, 2237, 1997, 100, 1010, 2379, 1996, 9042, 3675, 1012]
['PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'ADJ', 'NOUN', 'VERB', 'PROPN', 'PROPN', 'PROPN', 'PUNCT', 'PROPN', 'PUNCT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'ADP', 'PROPN', 'PUNCT', 'ADP', 'DET', 'ADJ', 'NOUN', 'PUNCT']
['NNP', 'HYPH', 'NNP', ':', 'JJ', 'NNS', 'VBD', 'NNP', 'NNP', 'NNP', 'HYPH', 'NNP', ',', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN', 'IN', 'NNP', ',', 'IN', 'DT', 'JJ', 'NN', '.']


In [11]:
UD_TAGS.build_vocab(train_data)
PTB_TAGS.build_vocab(train_data)

print(len(UD_TAGS.vocab), len(PTB_TAGS.vocab))
print(UD_TAGS.vocab.itos)

18 51
['<pad>', 'NOUN', 'PUNCT', 'VERB', 'PRON', 'ADP', 'DET', 'PROPN', 'ADJ', 'AUX', 'ADV', 'CCONJ', 'PART', 'NUM', 'SCONJ', 'X', 'INTJ', 'SYM']


In [12]:
BATCH_SIZE = 100

device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size=BATCH_SIZE, device=device)

In [13]:
for batch in train_iterator:
    batch_text, batch_text_lens = batch.text
    batch_tags, batch_tags_lens = batch.udtags
    break

print(batch_text)
print(batch_text_lens)
print(batch_tags)
print(batch_tags_lens)

print(batch_text_lens == batch_tags_lens)

tensor([[  101,  2004,  2017,  ...,     0,     0,     0],
        [  101,  2028,  2518,  ...,     0,     0,     0],
        [  101,  2190,  1010,  ...,     0,     0,     0],
        ...,
        [  101, 11556,  2479,  ...,     0,     0,     0],
        [  101,  1996, 18196,  ...,     0,     0,     0],
        [  101, 12159,   102,  ...,     0,     0,     0]], device='cuda:3')
tensor([21, 18,  4, 22, 46, 13, 31, 15, 12, 40, 24, 73, 19,  9, 17, 14,  9, 12,
        14, 31, 22,  7, 44, 22, 27, 13, 13,  6, 24, 18, 33, 30,  4, 26, 62, 20,
         6,  9,  6, 19, 28, 40, 36,  7,  4,  8,  3,  6, 25, 26, 35, 11, 18,  3,
        22, 29, 28, 25, 22, 15, 16, 22, 31, 16,  9, 15,  8, 25, 17, 13, 16, 29,
        33, 20,  4, 40, 54,  4,  4,  7,  9, 24,  9, 14, 18, 14,  7, 44, 20, 21,
        17, 10, 13, 15,  5, 35,  9,  6,  9,  3], device='cuda:3')
tensor([[ 0, 14,  4,  ...,  0,  0,  0],
        [ 0, 13,  1,  ...,  0,  0,  0],
        [ 0,  8,  2,  ...,  0,  0,  0],
        ...,
        [ 0,  7,  7,  

## Building the Model

`BertModel.forward`
* Input
    * `input_ids`: (batch, step)
    * `attention_mask`: (batch, step)
        * Mask to avoid performing attention on padding token indices  
        * A `torch.FloatTensor` with values selected in `{0, 1}`; The value being `0` means `masked`, and the value being `1` means `not-masked` 
* Output
    * `last_hidden_state`: (batch, step, hidden)
        * Sequence of hidden-states at the output of the last layer of the model  
    * `pooler_output`: (batch, hidden)
        * Last layer hidden-state of the first token of the sequence (classification token)
        * It will be further processed by a linear layer and a `tanh`, which was trained for next sentence prediction (classification) objective  
    * `attentions`: tuple of (batch, head, step, step), returned when `config.output_attentions=True`  
        * Attention weights after the `softmax`  

In [14]:
from transformers import BertModel

# bert = BertModel.from_pretrained('bert-base-uncased')
# bert.save_pretrained('./transformers_cache/bert-base-uncased/')

# Set `output_attentions=True` to return attentions from `bert.forward`
bert = BertModel.from_pretrained('./transformers_cache/bert-base-uncased/', output_attentions=True).to(device)
bert.config.output_attentions

True

In [15]:
# mask: (batch, step)
mask = (batch.text[0] != pad_token_idx).float()
bert_outs, bert_pooled_outs, attens = bert(batch.text[0], attention_mask=mask)
print(batch.text[0].size())
print(bert_outs.size())
print(bert_pooled_outs.size())

print(len(attens), bert.config.num_hidden_layers)
print(attens[0].size())

torch.Size([100, 73])
torch.Size([100, 73, 768])
torch.Size([100, 768])
12 12
torch.Size([100, 12, 73, 73])


In [16]:
# Check whether the attention is 0 on padding positions 
print((attens[0].sum(dim=-1) - 1).abs().max())
print(((attens[0] != 0) == mask.view(mask.size(0), 1, 1, -1)).all())

# Show the first head attention
print(attens[0][:, 0])

tensor(2.9802e-07, device='cuda:3', grad_fn=<MaxBackward1>)
tensor(True, device='cuda:3')
tensor([[[0.0371, 0.0607, 0.0216,  ..., 0.0000, 0.0000, 0.0000],
         [0.0758, 0.0309, 0.0481,  ..., 0.0000, 0.0000, 0.0000],
         [0.0611, 0.0352, 0.0358,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0088, 0.0112, 0.0748,  ..., 0.0000, 0.0000, 0.0000],
         [0.0088, 0.0120, 0.0691,  ..., 0.0000, 0.0000, 0.0000],
         [0.0070, 0.0118, 0.0770,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0360, 0.0951, 0.0281,  ..., 0.0000, 0.0000, 0.0000],
         [0.0832, 0.0308, 0.0522,  ..., 0.0000, 0.0000, 0.0000],
         [0.0587, 0.0141, 0.0430,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0107, 0.0171, 0.0850,  ..., 0.0000, 0.0000, 0.0000],
         [0.0104, 0.0145, 0.0789,  ..., 0.0000, 0.0000, 0.0000],
         [0.0085, 0.0161, 0.0877,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.1473, 0.0975, 0.1198,  ..., 0.0000, 0.0000, 0.0000],
         [0.2599, 0.0979, 0.3725,

In [17]:
# The values at padding positions are NOT zeros? 
# Yes, but they will never pollute the non-padding positions, since the attentions are applied with masking. 
print(batch.text[0])
print(bert_outs)

tensor([[  101,  2004,  2017,  ...,     0,     0,     0],
        [  101,  2028,  2518,  ...,     0,     0,     0],
        [  101,  2190,  1010,  ...,     0,     0,     0],
        ...,
        [  101, 11556,  2479,  ...,     0,     0,     0],
        [  101,  1996, 18196,  ...,     0,     0,     0],
        [  101, 12159,   102,  ...,     0,     0,     0]], device='cuda:3')
tensor([[[ 0.0769, -0.0570,  0.2451,  ..., -0.2046,  0.6273,  0.4673],
         [ 0.4734,  0.3849,  0.9369,  ..., -0.3964,  0.8343, -0.1047],
         [-0.2909, -0.4783,  0.4526,  ...,  0.0791,  1.0221, -0.1757],
         ...,
         [ 0.9056, -0.0645,  0.1375,  ..., -0.0397,  0.9553, -0.7583],
         [ 0.9373, -0.0585,  0.2285,  ..., -0.0013,  0.7479, -0.7201],
         [ 0.4119, -0.0657,  0.3063,  ...,  0.2168,  0.4914, -0.3562]],

        [[-0.1209,  0.1067, -0.0999,  ..., -0.1584,  0.2379,  0.7386],
         [-0.0933, -0.6518, -0.5936,  ...,  0.4867,  0.4932,  0.4381],
         [-0.2321,  0.2492, -0.4463, 

Instead of using an embedding layer to get embeddings for our text, we'll be using the pre-trained transformer model. These embeddings will then be fed into a linear layer to predict the tag for each token.  

We get the embedding dimension size (called the `hidden_size`) from the transformer via its config attribute.

In [18]:
class PoSTagger(nn.Module):
    def __init__(self, bert, tag_dim, dropout):
        super().__init__()
        # Use `bert` to provide word embeddings. 
        self.bert = bert
        emb_dim = bert.config.hidden_size
        
        self.fc = nn.Linear(emb_dim, tag_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        # text/mask: (batch, step)
        mask = (text != self.bert.config.pad_token_id).float()
        embedded, *_ = self.bert(text, attention_mask=mask)

        embedded = self.dropout(embedded)
        
        # preds: (batch, step, tag_dim)
        preds = self.fc(embedded)
        return preds

In [19]:
TAG_DIM = len(UD_TAGS.vocab)
# TAG_DIM = len(PTB_TAGS.vocab)

DROPOUT = 0.25
TAG_PAD_IDX = UD_TAGS.vocab.stoi[UD_TAGS.pad_token]
# TAG_PAD_IDX = PTB_TAGS.vocab.stoi[PTB_TAGS.pad_token]


tagger = PoSTagger(bert, TAG_DIM, DROPOUT).to(device)
preds = tagger(batch_text)

print(batch_text.size())
print(preds.size())

torch.Size([100, 73])
torch.Size([100, 73, 18])


In [20]:
# Check if data are mixed across different samples in a batch.
tagger.eval()
preds_012 = tagger(batch_text[0:3, :])
preds_123 = tagger(batch_text[1:4, :])

(preds_012[1:] - preds_123[:2]).abs().max()

tensor(0., device='cuda:3', grad_fn=<MaxBackward1>)

## Training the Model

In [21]:
def count_parameters(model):
    """
    Count trainable parameters. 
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(tagger):,} trainable parameters')

The model has 109,496,082 trainable parameters


In [22]:
LEARNING_RATE = 5e-5

optimizer = optim.AdamW(tagger.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss(ignore_index=TAG_PAD_IDX, reduction='mean')

In [23]:
def train_epoch(tagger, iterator, optimizer, loss_func):
    tagger.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch in iterator:
        # Forward pass
        text, text_lens = batch.text
        tags, tags_lens = batch.udtags
        preds = tagger(text)

        # Calculate loss
        preds_flattened = preds.view(-1, preds.size(-1))
        tags_flattened = tags.flatten()
        loss = loss_func(preds_flattened, tags_flattened)

        # Backward propagation
        optimizer.zero_grad()
        loss.backward()
        # Update weights
        optimizer.step()
        # Accumulate loss and acc
        epoch_loss += loss.item()
        non_padding = (tags_flattened != loss_func.ignore_index)
        epoch_acc += (preds_flattened.argmax(dim=-1) == tags_flattened)[non_padding].sum().item() / non_padding.sum().item()
    return epoch_loss/len(iterator), epoch_acc/len(iterator)

def eval_epoch(tagger, iterator, loss_func):
    tagger.eval()
    epoch_loss = 0
    epoch_acc = 0
    with torch.no_grad():
        for batch in iterator:
            # Forward pass
            text, text_lens = batch.text
            tags, tags_lens = batch.udtags
            preds = tagger(text)

            # Calculate loss
            preds_flattened = preds.view(-1, preds.size(-1))
            tags_flattened = tags.flatten()
            loss = loss_func(preds_flattened, tags_flattened)
            
            # Accumulate loss and acc
            epoch_loss += loss.item()
            non_padding = (tags_flattened != loss_func.ignore_index)
            epoch_acc += (preds_flattened.argmax(dim=-1) == tags_flattened)[non_padding].sum().item() / non_padding.sum().item()
    return epoch_loss/len(iterator), epoch_acc/len(iterator)

In [24]:
import time
N_EPOCHS = 10
best_valid_loss = np.inf

for epoch in range(N_EPOCHS):
    t0 = time.time()
    train_loss, train_acc = train_epoch(tagger, train_iterator, optimizer, loss_func)
    valid_loss, valid_acc = eval_epoch(tagger, valid_iterator, loss_func)
    epoch_secs = time.time() - t0

    epoch_mins, epoch_secs = int(epoch_secs // 60), int(epoch_secs % 60)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(tagger.state_dict(), 'models/tut2-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 50s
	Train Loss: 0.522 | Train Acc: 85.65%
	 Val. Loss: 0.276 |  Val. Acc: 91.74%
Epoch: 02 | Epoch Time: 0m 49s
	Train Loss: 0.137 | Train Acc: 96.05%
	 Val. Loss: 0.242 |  Val. Acc: 92.54%
Epoch: 03 | Epoch Time: 0m 50s
	Train Loss: 0.095 | Train Acc: 97.24%
	 Val. Loss: 0.250 |  Val. Acc: 92.53%
Epoch: 04 | Epoch Time: 0m 50s
	Train Loss: 0.069 | Train Acc: 98.05%
	 Val. Loss: 0.243 |  Val. Acc: 92.59%
Epoch: 05 | Epoch Time: 0m 51s
	Train Loss: 0.050 | Train Acc: 98.53%
	 Val. Loss: 0.246 |  Val. Acc: 93.00%
Epoch: 06 | Epoch Time: 0m 51s
	Train Loss: 0.039 | Train Acc: 98.86%
	 Val. Loss: 0.244 |  Val. Acc: 93.25%
Epoch: 07 | Epoch Time: 0m 50s
	Train Loss: 0.030 | Train Acc: 99.16%
	 Val. Loss: 0.262 |  Val. Acc: 93.62%
Epoch: 08 | Epoch Time: 0m 50s
	Train Loss: 0.024 | Train Acc: 99.32%
	 Val. Loss: 0.274 |  Val. Acc: 93.43%
Epoch: 09 | Epoch Time: 0m 50s
	Train Loss: 0.019 | Train Acc: 99.45%
	 Val. Loss: 0.286 |  Val. Acc: 93.24%
Epoch: 10 | Epoch T

In [25]:
tagger.load_state_dict(torch.load('models/tut2-model.pt'))

valid_loss, valid_acc = eval_epoch(tagger, valid_iterator, loss_func)
test_loss, test_acc = eval_epoch(tagger, test_iterator, loss_func)

print(f'Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Val. Loss: 0.242 | Val. Acc: 92.54%
Test Loss: 0.260 | Test Acc: 91.54%
