<a href="https://colab.research.google.com/github/zhangguanheng66/tutorials/blob/bert_question_answer/fine_tune_BERT_question_answer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%shell

rm -r /usr/local/lib/python3.6/dist-packages/torch*
pip install numpy
pip install --pre torch torchtext -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html

In [1]:
import torch
import torchtext
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

True
cuda


# Fine-tuning BERT model for Question-Answer with Torchtext

We recently released a new torchtext library to address several issues raised by the OSS community (see the [torchtext 0.7.0 release note](https://github.com/pytorch/text/releases/tag/v0.7.0-rc3) and the 0.8.0 release note). To accelerate research, the new torchtext library provides reusable, orthogonal, correct, and performant building blocks (Vectors, Vocab, Tokenizer) for cutting-edge experimentation, informed by the needs of NLP research and its communities. In addition, torchtext is deeply integrated with a broad range of PyTorch capabilities, such as Just-in-Time (JIT) compilation, quantization, distributed training, and mobile, to enable a seamless research-to-production path for core end-to-end applications.

In this tutorial, we fine-tune a pretrained BERT model for a question-answering task with the new library.


## Step 1: Prepare datasets

We have revisited the basic components of the torchtext library, including the vocab, word vectors, and tokenizers backed by regular expressions and SentencePiece. These are the basic data processing building blocks for raw text strings.

### Tokenizer-vocabulary data processing pipeline

Download a vocab text file.

In [3]:
%%shell
rm bert_vocab.txt
wget https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/bert_vocab.txt 

rm: cannot remove 'bert_vocab.txt': No such file or directory
--2020-10-01 16:30:13--  https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/bert_vocab.txt
Resolving pytorch.s3.amazonaws.com (pytorch.s3.amazonaws.com)... 52.216.102.11
Connecting to pytorch.s3.amazonaws.com (pytorch.s3.amazonaws.com)|52.216.102.11|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 923169 (902K) [text/plain]
Saving to: ‘bert_vocab.txt’


2020-10-01 16:30:14 (2.30 MB/s) - ‘bert_vocab.txt’ saved [923169/923169]





Prepare the data pipeline for the question-answer dataset.

In [5]:
from torchtext.experimental.transforms import (
    basic_english_normalize,
    TextSequentialTransforms,
)
from torchtext.experimental.vocab import vocab_from_file
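# Build the vocabulary from the downloaded BERT vocab file.
with open('bert_vocab.txt', 'r') as f:
    vocab = vocab_from_file(f)
# Normalize and tokenize raw text, then map each token to its vocab index.
text_pipeline = TextSequentialTransforms(basic_english_normalize(), vocab)
# Answer positions are converted to long tensors.
pos_pipeline = lambda x: torch.tensor(x, dtype=torch.long)
# One transform per field of a question-answer record.
qa_data_pipeline = {'context': text_pipeline, 'question': text_pipeline,
                    'answers': text_pipeline, 'ans_pos': pos_pipeline}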
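
To make the idea of a tokenizer-vocabulary pipeline concrete, here is a minimal plain-Python sketch of the same pattern: normalize and split the text, then look each token up in a vocabulary. The names `normalize`, `ToyVocab`, and `sequential` are hypothetical stand-ins written for illustration, not torchtext APIs, and the toy normalizer is much simpler than `basic_english_normalize`.

```python
import re

def normalize(text):
    """Lowercase the text and keep runs of letters/digits as tokens (toy normalizer)."""
    return re.findall(r"[a-z0-9]+", text.lower())

class ToyVocab:
    """Map tokens to integer ids; unknown tokens map to the <unk> id."""

    def __init__(self, tokens, unk_token="<unk>"):
        self.itos = [unk_token] + [t for t in tokens if t != unk_token]
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        self.unk_id = self.stoi[unk_token]

    def __call__(self, tokens):
        return [self.stoi.get(tok, self.unk_id) for tok in tokens]

def sequential(*transforms):
    """Compose transforms left to right, like a sequential pipeline."""
    def pipeline(x):
        for transform in transforms:
            x = transform(x)
        return x
    return pipeline

toy_vocab = ToyVocab(["what", "is", "bert", "a", "model"])
toy_pipeline = sequential(normalize, toy_vocab)
toy_ids = toy_pipeline("What is BERT?")  # ids of "what", "is", "bert"
```

Because each stage is just a callable, the vocabulary lookup can be swapped for another final stage (for example a word-vector lookup) without changing the tokenizer.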

The datasets in `torchtext.experimental.datasets.raw` return iterators that yield the raw data. This way, users can define custom data processing pipelines and work on the raw data directly.

In [22]:
from torchtext.experimental.datasets.raw import SQuAD1
train, dev = SQuAD1()

Materialize the raw SQuAD data iterators. Pass the data and data processing pipelines (a.k.a. transforms) to the question answer dataset abstraction. `QuestionAnswerDataset` is an abstraction ([link](https://github.com/pytorch/text/blob/467ee98faba8e00b0e6acbf3132a723e08f36859/torchtext/experimental/datasets/question_answer.py#L12)) that applies the user-defined transform pipelines to the raw question-answer data.

In [23]:
from torchtext.experimental.datasets.question_answer import QuestionAnswerDataset
train_data = QuestionAnswerDataset(list(train), vocab, qa_data_pipeline)
dev_data = QuestionAnswerDataset(list(dev), vocab, qa_data_pipeline)
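
SQuAD stores each answer's start as a character offset into the context, whereas the model is trained on token positions. The sketch below shows one plausible way to convert between the two: count the tokens in the context prefix that precedes the answer. It is an illustration under that assumption, using whitespace splitting as a stand-in tokenizer, and is not necessarily how `QuestionAnswerDataset` implements the conversion. The helper name `char_start_to_token_span` is hypothetical.

```python
def char_start_to_token_span(context, answer, char_start, tokenize):
    """Convert a character-level answer start into (start, end) token indices."""
    # The number of tokens before the answer equals the token count of the prefix.
    start = len(tokenize(context[:char_start]))
    # The answer occupies as many tokens as its own tokenization yields.
    end = start + len(tokenize(answer)) - 1
    return start, end

toy_context = "BERT was released by Google in 2018."
toy_answer = "Google"
toy_char_start = toy_context.index(toy_answer)
toy_span = char_start_to_token_span(toy_context, toy_answer, toy_char_start, str.split)
```

This only works cleanly when the answer begins on a token boundary, which is why the same tokenizer has to be used for both the context and the answer.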

### (Optional for tutorial) Word-vector embedding data processing pipeline

---



Word embeddings are a type of word representation that allows words with similar meanings to have similar representations. FastText and GloVe are well-established baseline word vectors in the NLP community. In the new torchtext library, a Vector object maps tokens to their corresponding vector representations (i.e., word embeddings).

In [None]:
from torchtext.experimental.transforms import (
    basic_english_normalize,
    TextSequentialTransforms,
)
from torchtext.experimental.vectors import FastText
vector = FastText()
word_vector_pipeline = TextSequentialTransforms(basic_english_normalize(), vector)
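
As a rough illustration of "similar meaning, similar representation", the sketch below looks tokens up in a tiny hand-made table and compares them with cosine similarity. The three-dimensional vectors are made up for the example (real FastText vectors have hundreds of dimensions), and mapping unknown tokens to a zero vector is an assumption of this sketch.

```python
import math

# Hypothetical toy vectors, invented for illustration only.
toy_vectors = {
    "king": [0.9, 0.8, 0.1],
    "queen": [0.85, 0.82, 0.15],
    "banana": [0.1, 0.05, 0.95],
}
toy_unk = [0.0, 0.0, 0.0]  # assumed fallback for out-of-vocabulary tokens

def lookup(tokens):
    """Map each token to its vector, falling back to the zero vector."""
    return [toy_vectors.get(tok, toy_unk) for tok in tokens]

def cosine(u, v):
    """Cosine similarity; defined as 0.0 when either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

king, queen, banana = lookup(["king", "queen", "banana"])
sim_related = cosine(king, queen)     # close to 1 for these toy vectors
sim_unrelated = cosine(king, banana)  # much smaller
```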

### (Optional for tutorial) SentencePiece data processing pipeline

SentencePiece is an unsupervised text tokenizer and detokenizer, mainly for neural-network-based text generation systems where the vocabulary size is predetermined prior to model training. The sentencepiece transforms in torchtext support both subword units (e.g., byte-pair encoding (BPE)) and the unigram language model.

Here is an example that applies the SentencePiece transform to build a language modeling dataset. Although the pretrained BERT model was generated with a different vocabulary, the following LM dataset with the SentencePiece transform can be used to train a masked language model (the task described in the BERT paper) from scratch.

In [7]:
from torchtext.experimental.transforms import (
    PRETRAINED_SP_MODEL,
    sentencepiece_processor,
)
from torchtext.utils import download_from_url
spm_filepath = download_from_url(PRETRAINED_SP_MODEL['text_unigram_25000'])
spm_transform = sentencepiece_processor(spm_filepath)

text_unigram_25000.model: 100%|██████████| 678k/678k [00:00<00:00, 2.15MB/s]


Load the raw WikiText2 dataset as iterators.

In [8]:
from torchtext.experimental.datasets.raw import WikiText2
train_iter, test_iter, valid_iter = WikiText2()

wikitext-2-v1.zip: 100%|██████████| 4.48M/4.48M [00:00<00:00, 8.54MB/s]


Add the WikiText2 data and the data processing pipeline (a.k.a. transform) to the language modeling dataset abstraction.

In [9]:
from torchtext.experimental.datasets.language_modeling import LanguageModelingDataset
wikitext2_train = LanguageModelingDataset(list(train_iter), None, spm_transform, False)
wikitext2_test = LanguageModelingDataset(list(test_iter), None, spm_transform, False)
wikitext2_valid = LanguageModelingDataset(list(valid_iter), None, spm_transform, False)
len(wikitext2_train)

36718
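
For readers who want to see what the masked language model task does to a token sequence, here is a minimal sketch of the masking rule from the BERT paper: about 15% of positions are selected for prediction, and of those 80% are replaced by the mask token, 10% by a random token, and 10% are left unchanged. The function name `mask_tokens`, the `mask_id` value, and the vocabulary size are illustrative assumptions, not values taken from this tutorial's model.

```python
import random

def mask_tokens(token_ids, mask_id, vocab_size, mask_prob=0.15, seed=0):
    """Return (inputs, labels) for masked language modeling.

    labels holds the original id at selected positions and None elsewhere.
    """
    rng = random.Random(seed)
    inputs, labels = list(token_ids), [None] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if rng.random() >= mask_prob:
            continue  # position not selected for prediction
        labels[i] = tok
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id                    # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return inputs, labels

toy_ids = list(range(100, 200))
masked_inputs, masked_labels = mask_tokens(toy_ids, mask_id=4, vocab_size=25000)
```

The loss is then computed only at positions where the label is not `None`.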

### JIT support for the data processing pipeline

The new building blocks in the torchtext library are compatible with `torch.jit.script`. TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process that has no Python dependency. The data processing pipelines above can be converted and run in JIT mode without a Python dependency.


In [None]:
text_pipeline = text_pipeline.to_ivalue()
jit_text_pipeline = torch.jit.script(text_pipeline)

## Step 2: Data Iterator

The PyTorch data loading utility is the [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) class. It works with a map-style dataset, which implements the `__getitem__()` and `__len__()` protocols and represents a map from indices/keys to data samples. Before a batch is sent to the model, the `collate_fn` function processes the list of samples that the DataLoader draws for that batch.
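
The protocol can be sketched in plain Python, without PyTorch, to show how the pieces fit together. `ToyMapDataset`, `iterate_batches`, and `pad_collate` are hypothetical names for this illustration; the real `DataLoader` additionally handles shuffling, worker processes, and tensors.

```python
class ToyMapDataset:
    """A map-style dataset: indexable and with a known length."""

    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

def iterate_batches(dataset, batch_size, collate_fn):
    """Yield collated batches in index order (no shuffling, no workers)."""
    for start in range(0, len(dataset), batch_size):
        indices = range(start, min(start + batch_size, len(dataset)))
        yield collate_fn([dataset[i] for i in indices])

def pad_collate(batch, pad_value=0):
    """Right-pad variable-length lists so the batch is rectangular."""
    width = max(len(sample) for sample in batch)
    return [sample + [pad_value] * (width - len(sample)) for sample in batch]

toy_dataset = ToyMapDataset([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10], [11]])
toy_batches = list(iterate_batches(toy_dataset, batch_size=2, collate_fn=pad_collate))
```

Padding inside the collate step means each batch is only as wide as its own longest sample.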

In [24]:
from torch.utils.data import DataLoader

# [TODO] integrate with torchtext.experimental.transforms.PadTransform
# Need to land https://github.com/pytorch/text/pull/952

# Look up the ids of the special tokens; each lookup returns a one-element list.
cls_id = vocab(['<cls>'])
sep_id = vocab(['<sep>'])
pad_id = vocab(['<pad>'])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    seq_list, ans_pos_list, tok_type = [], [], []
    for (_context, _question, _answers, _ans_pos) in batch:
        _context, _question = torch.tensor(_context), torch.tensor(_question)
        # Sequence layout: <cls> question <sep> context <sep>
        qa_item = torch.cat((torch.tensor(cls_id), _question, torch.tensor(sep_id),
                             _context, torch.tensor(sep_id)))
        seq_list.append(qa_item)
        # Shift the answer positions past the <cls> + question + <sep> prefix.
        pos_list = [pos + _question.size(0) + 2 for pos in _ans_pos]
        ans_pos_list.append(pos_list)
        # Token type 0 for the question segment, 1 for the context segment.
        tok_type.append(torch.cat((torch.zeros((_question.size(0) + 2)),
                                   torch.ones((_context.size(0) + 1)))))
    # Regroup the positions by answer index; only the first answer is used as the target.
    _ans_pos_list = [torch.stack(list(pos)) for pos in zip(*ans_pos_list)]
    target_start_pos, target_end_pos = _ans_pos_list[0].split(1, dim=-1)
    target_start_pos = target_start_pos.squeeze(-1)
    target_end_pos = target_end_pos.squeeze(-1)
    # Pad to the longest sequence in the batch, then transpose to (seq_len, batch).
    seq_list = torch.nn.utils.rnn.pad_sequence(seq_list, batch_first=True, padding_value=float(pad_id[0]))
    seq_list = seq_list.long().t().contiguous()
    tok_type = torch.nn.utils.rnn.pad_sequence(tok_type, batch_first=True, padding_value=1.0)
    tok_type = tok_type.long().t().contiguous()
    return seq_list.to(device), target_start_pos.to(device), target_end_pos.to(device), tok_type.to(device)

BATCH_SIZE = 8
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
dev_dataloader = DataLoader(dev_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
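
The key bookkeeping in `collate_batch` is the sequence layout and the shift applied to the answer positions. The plain-Python sketch below reproduces that logic for a single example with made-up ids, so the offsets can be checked by hand. `build_qa_input` is a hypothetical helper written for this illustration and skips the padding and tensor conversion.

```python
def build_qa_input(question_ids, context_ids, ans_span, cls_id, sep_id):
    """Lay out <cls> question <sep> context <sep> and shift the answer span."""
    seq = [cls_id] + question_ids + [sep_id] + context_ids + [sep_id]
    # Segment 0 covers <cls> + question + <sep>; segment 1 covers context + <sep>.
    tok_type = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)
    # Answer positions are relative to the context, so shift them by the prefix length.
    offset = len(question_ids) + 2
    start, end = ans_span
    return seq, tok_type, (start + offset, end + offset)

toy_seq, toy_tok_type, toy_target = build_qa_input(
    question_ids=[11, 12, 13],
    context_ids=[21, 22, 23, 24, 25],
    ans_span=(1, 2),  # the answer is context tokens 22 and 23
    cls_id=1,
    sep_id=2,
)
```

After the shift, slicing the combined sequence with the target positions recovers the same answer tokens that the span selected in the context alone.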

## Step 3: Model for Question-Answer Task

A BERT model was pretrained with the masked language modeling task and the next-sentence prediction task, following the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805). torchtext.datasets.WikiText103 and BookCorpus were used to pretrain the model.

Here is the layout of the model for the question-answer task. On top of the BERT model, a linear layer projects each token's output to two scores, one for the start and one for the end of the answer.

In [25]:
class QuestionAnswerTask(torch.nn.Module):
    """Contains a pretrained BERT model and a linear layer."""

    def __init__(self, bert_model):
        super(QuestionAnswerTask, self).__init__()
        self.bert_model = bert_model
        self.activation = torch.nn.GELU()
        self.qa_span = torch.nn.Linear(bert_model.ninp, 2)

    def forward(self, src, token_type_input):
        output = self.bert_model(src, token_type_input)
        # transpose output (S, N, E) to (N, S, E)
        output = output.transpose(0, 1)
        output = self.activation(output)
        pos_output = self.qa_span(output)
        start_pos, end_pos = pos_output.split(1, dim=-1)
        start_pos = start_pos.squeeze(-1)
        end_pos = end_pos.squeeze(-1)
        return start_pos, end_pos
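
The model returns one start score and one end score per token. The tutorial only uses these scores in the training loss, but at prediction time a common way to turn them into an answer is to pick the pair with the highest combined score subject to the start not coming after the end. The sketch below illustrates that idea in plain Python; `best_span` and its `max_answer_len` parameter are assumptions for illustration and are not part of the tutorial's model.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end) maximizing start_logit + end_logit with start <= end."""
    best, best_score = (0, 0), float("-inf")
    for s, s_score in enumerate(start_logits):
        last = min(len(end_logits), s + max_answer_len)
        for e in range(s, last):
            score = s_score + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best

# Taking each argmax independently would give start=1, end=0, which is not a valid span.
toy_start = [0.1, 2.0, 0.3, 0.2]
toy_end = [3.0, 0.1, 1.5, 0.4]
toy_best = best_span(toy_start, toy_end)
```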

The pretrained model is available here

In [None]:
%%shell
rm -f ns_bert.pt model.py
wget https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/ns_bert.pt
md5sum -c <<<"f14abe2424ea321e66a82407ddab2dd4 ns_bert.pt"
wget https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/model.py
md5sum -c <<<"cbf74f9e864a988f25d09b51f5080168 model.py"

and can be loaded into the question-answer task.

In [26]:
from model import BertModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size, emsize, nhead, nhid, nlayers, dropout = 99230, 768, 12, 3072, 12, 0.2
pretrained_bert = BertModel(vocab_size, emsize, nhead, nhid, nlayers, dropout)
pretrained_bert.load_state_dict(torch.load('ns_bert.pt', map_location=device))
qa_model = QuestionAnswerTask(pretrained_bert).to(device)

## Step 4: Fine-tuning the Model

Next, we fine-tune the BERT model on the question-answer task using the SQuAD1 dataset.

In [29]:
import time
import math

def fine_tune(model, dataloader, optimizer, criterion, batch_size, device):
    # Note: the log line below reads `epoch` and `scheduler` from the notebook's global scope.
    model.train()
    total_loss = 0.
    log_interval = 20
    start_time = time.time()

    for idx, (seq_input, target_start_pos, target_end_pos, tok_type) in enumerate(dataloader):
        optimizer.zero_grad()
        # print(seq_input.size(), tok_type.size())
        start_pos, end_pos = model(seq_input, token_type_input=tok_type)
        loss = (criterion(start_pos, target_start_pos) + criterion(end_pos, target_end_pos)) / 2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_loss += loss.item()
        if idx % log_interval == 0 and idx > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(epoch, idx,
                                                      len(dataloader),  # already the number of batches
                                                      scheduler.get_last_lr()[0],
                                                      elapsed * 1000 / log_interval,
                                                      cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(model, dataloader, optimizer, criterion, batch_size, device):
    model.eval()
    total_loss = 0.
    ans_pred_tokens_samples = []

    with torch.no_grad():
        for idx, (seq_input, target_start_pos, target_end_pos, tok_type) in enumerate(dataloader):
            start_pos, end_pos = model(seq_input, token_type_input=tok_type)
            loss = (criterion(start_pos, target_start_pos)
                    + criterion(end_pos, target_end_pos)) / 2
            total_loss += loss.item()
    return total_loss / len(dataloader)
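
The loss used in both functions is the average of two cross-entropy terms, one over the start scores and one over the end scores, where each sequence position is treated as a class. The plain-Python sketch below computes that quantity for a single example so the arithmetic can be checked by hand; `cross_entropy` and `span_loss` are illustrative helpers, and unlike `torch.nn.CrossEntropyLoss` they do not average over a batch.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax at the target position (numerically stabilized)."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]

def span_loss(start_logits, end_logits, target_start, target_end):
    """Average of the start-position and end-position cross-entropy terms."""
    return (cross_entropy(start_logits, target_start)
            + cross_entropy(end_logits, target_end)) / 2

# With uniform scores over 4 positions, each term equals log(4).
uniform_loss = span_loss([0.0] * 4, [0.0] * 4, 1, 2)
# Confident, correct scores drive the loss toward zero.
confident_loss = span_loss([0.0, 10.0, 0.0, 0.0], [0.0, 0.0, 10.0, 0.0], 1, 2)
```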

In [None]:
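# Hyperparameters
EPOCHS = 10  # number of epochs
LR = 0.5  # initial learning rate
BATCH_SIZE = 72  # passed to fine_tune/evaluate; the DataLoaders above were built with batch size 8

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(qa_model.parameters(), lr=LR)
# With step_size=1, each scheduler.step() call multiplies the learning rate by gamma.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
eval_loss = None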

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    fine_tune(qa_model, train_dataloader, optimizer, criterion, BATCH_SIZE, device)
    _loss = evaluate(qa_model, dev_dataloader, optimizer, criterion, BATCH_SIZE, device)
    if eval_loss is not None and _loss > eval_loss:
        scheduler.step()
    else:
        eval_loss = _loss
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid loss {:5.2f} | '.format(epoch, (time.time() - epoch_start_time),
                                         _loss))
    print('-' * 89)
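
The loop above only calls `scheduler.step()` when the validation loss is worse than the best value seen so far, so the learning rate drops by a factor of `gamma` on each such epoch and otherwise stays put. The plain-Python sketch below simulates that rule on a made-up sequence of validation losses; `simulate_lr` is a hypothetical helper and the loss values are invented for illustration, not results from this tutorial.

```python
def simulate_lr(valid_losses, lr=0.5, gamma=0.1):
    """Return the learning rate in effect after each epoch's validation pass."""
    best_loss, history = None, []
    for loss in valid_losses:
        if best_loss is not None and loss > best_loss:
            lr *= gamma       # the effect of scheduler.step() with step_size=1
        else:
            best_loss = loss  # new best validation loss; keep the current rate
        history.append(lr)
    return history

# Invented validation losses: epochs 3 and 5 are worse than the best so far.
toy_history = simulate_lr([2.0, 1.5, 1.6, 1.4, 1.45])
```

With these inputs the rate stays at 0.5 for two epochs, drops to 0.05 after the third, and to 0.005 after the fifth.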
