<a href="https://colab.research.google.com/github/zhangguanheng66/tutorials/blob/bert_question_answer/fine_tune_BERT_question_answer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%shell

rm -r /usr/local/lib/python3.6/dist-packages/torch*
pip install numpy
pip install --pre torch torchvision torchtext -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

In [None]:
import torch
import torchtext

# Fine-tuning BERT model for Question-Answer with Torchtext

We recently released a new torchtext library to address several issues raised by the OSS community (see the [torchtext 0.7.0 release note](https://github.com/pytorch/text/releases/tag/v0.7.0-rc3) and the 0.8.0 release note). To accelerate research, the new torchtext library provides reusable, orthogonal, correct, and performant building blocks (Vectors, Vocab, Tokenizer) for cutting-edge experimentation, based on deep knowledge of NLP research and its communities. In addition, torchtext is deeply integrated with a broad range of PyTorch capabilities, such as Just-in-Time (JIT) compilation, quantization, distributed training, and mobile, to enable a seamless research-to-production path for core end-to-end applications.

In this tutorial, we fine-tune a pretrained BERT model for the question-answer task with the new library.


## Step 1: Prepare datasets

We have revisited the basic components of the torchtext library, including the vocab, word vectors, and tokenizers backed by regular expressions and SentencePiece. These are the basic data processing building blocks for raw text strings.

### Tokenizer-vocabulary data processing pipeline

Download a vocab text file.

In [None]:
%%shell
rm *.txt
wget https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/bert_vocab.txt

Prepare the data pipeline for the question-answer dataset.

In [None]:
from torchtext.experimental.transforms import (
    basic_english_normalize,
    TextSequentialTransforms,
)
from torchtext.experimental.vocab import vocab_from_file
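# Build the vocab from the downloaded file and close the file handle afterwards.
with open('bert_vocab.txt', 'r') as vocab_file:
    vocab = vocab_from_file(vocab_file)
# Text fields: normalize and tokenize, then map tokens to vocab indices.
text_pipeline = TextSequentialTransforms(basic_english_normalize(), vocab)
# Answer positions: convert to long tensors.
pos_pipeline = lambda x: torch.tensor(x, dtype=torch.long)
qa_data_pipeline = {'context': text_pipeline, 'question': text_pipeline,
                    'answers': text_pipeline, 'ans_pos': pos_pipeline}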
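
To make the idea of a tokenizer-vocabulary pipeline concrete, here is a minimal pure-Python sketch of what such a transform does conceptually. The `toy_normalize` function, the `ToyVocab` class, and the five-entry token list are hypothetical stand-ins and not the torchtext implementation; the real `basic_english_normalize` applies more normalization rules than this sketch.

```python
import re


def toy_normalize(line):
    """Lowercase the line and split it into word and punctuation tokens.

    This is a simplified assumption, not torchtext's normalization rules.
    """
    return re.findall(r"[a-z0-9]+|[^\sa-z0-9]", line.lower())


class ToyVocab:
    """Hypothetical vocab: maps tokens to indices, unknown tokens to <unk>."""

    def __init__(self, tokens, unk_token="<unk>"):
        self.stoi = {tok: idx for idx, tok in enumerate(tokens)}
        self.unk_index = self.stoi[unk_token]

    def __call__(self, tokens):
        return [self.stoi.get(tok, self.unk_index) for tok in tokens]


toy_vocab = ToyVocab(["<unk>", "what", "is", "bert", "?"])


def toy_text_pipeline(line):
    # Chain the two steps: raw string -> tokens -> indices.
    return toy_vocab(toy_normalize(line))


print(toy_text_pipeline("What is BERT?"))
```

Chaining the steps in one callable is what lets the same pipeline be applied to the context, question, and answer fields alike.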

Check out raw SQuAD dataset as an iterator.

In [None]:
from torchtext.experimental.datasets.raw import SQuAD1
train, dev = SQuAD1()

Add SQuAD data and data processing pipelines (a.k.a. transforms) to the question answer dataset abstraction.

In [None]:
from torchtext.experimental.datasets.question_answer import QuestionAnswerDataset
train_data = QuestionAnswerDataset([item for item in train], vocab, qa_data_pipeline)
dev_data = QuestionAnswerDataset([item for item in dev], vocab, qa_data_pipeline)
# train_data[0]

### (Optional for tutorial) Word-vector embedding data processing pipeline

Word embeddings are a type of word representation in which words with similar meanings have similar representations. FastText and GloVe are well-established baseline word vectors in the NLP community. In the new torchtext library, a Vector object supports the mapping between tokens and their corresponding vector representations (i.e. word embeddings).

In [None]:
from torchtext.experimental.transforms import (
    basic_english_normalize,
    TextSequentialTransforms,
)
from torchtext.experimental.vectors import FastText
vector = FastText()
word_vector_pipeline = TextSequentialTransforms(basic_english_normalize(), vector)
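
As a rough illustration of the token-to-vector mapping, the sketch below uses a plain dictionary in place of a Vector object. The three-dimensional vectors are hand-picked illustrative values, not real FastText embeddings, and `lookup` and `cosine` are hypothetical helper names.

```python
import math

# Hand-picked toy vectors (assumption): real FastText vectors have hundreds of dimensions.
toy_vectors = {
    "king": [0.9, 0.8, 0.1],
    "queen": [0.9, 0.7, 0.2],
    "apple": [0.1, 0.2, 0.9],
}
UNK_VECTOR = [0.0, 0.0, 0.0]  # returned for tokens without a vector


def lookup(tokens):
    """Map each token to its vector, falling back to a zero vector."""
    return [toy_vectors.get(tok, UNK_VECTOR) for tok in tokens]


def cosine(u, v):
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


king, queen, apple = lookup(["king", "queen", "apple"])
# Words with similar meaning are expected to have a higher similarity.
print(cosine(king, queen) > cosine(king, apple))
```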

### (Optional for tutorial) SentencePiece data processing pipeline

SentencePiece is an unsupervised text tokenizer and detokenizer, mainly for neural network-based text generation systems where the vocabulary size is fixed before the neural model is trained. The sentencepiece transforms in torchtext support both subword units (e.g., byte-pair encoding (BPE)) and the unigram language model.

Here is an example that applies the SentencePiece transform to build a language modeling dataset. Although the pretrained BERT model was generated with a different vocabulary, the following LM dataset with the SentencePiece transform can be used to train the masked language model task (described in the BERT paper) from scratch.

In [None]:
# Need to land https://github.com/pytorch/text/pull/916
from torchtext.experimental.transforms import (
    load_pretrained_sp_model,
    sentencepiece_processor,
)
spm_filepath = load_pretrained_sp_model()[1]
spm_transform = sentencepiece_processor(spm_filepath)
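
For readers unfamiliar with byte-pair encoding, the following sketch shows a single BPE merge step on a toy corpus. It is a generic illustration of the BPE idea under simplified assumptions (a pre-split word-frequency table and the hypothetical helpers `most_frequent_pair` and `merge_pair`); it is not how SentencePiece is implemented.

```python
from collections import Counter


def most_frequent_pair(words):
    """Return the most frequent adjacent symbol pair.

    `words` maps a tuple of symbols to its frequency in the corpus.
    """
    pairs = Counter()
    for symbols, freq in words.items():
        for left, right in zip(symbols, symbols[1:]):
            pairs[(left, right)] += freq
    return max(pairs, key=pairs.get)


def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged


# Toy word-frequency table (assumption), with words split into characters.
toy_corpus = {("l", "o", "w"): 5, ("l", "o", "t"): 2, ("n", "e", "w"): 3}
best_pair = most_frequent_pair(toy_corpus)
toy_corpus = merge_pair(toy_corpus, best_pair)
print(best_pair, toy_corpus)
```

Repeating this step until the vocabulary reaches a target size is what fixes the vocabulary size before model training.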

Check out raw WikiText2 dataset as an iterator.

In [None]:
from torchtext.experimental.datasets.raw import WikiText2
train_iter, test_iter, valid_iter = WikiText2()

Add WikiText2 data and data processing pipeline (a.k.a. transform) to the language modeling dataset abstraction.

In [None]:
from torchtext.experimental.datasets.language_modeling import LanguageModelingDataset
wikitext2_train = LanguageModelingDataset(list(train_iter), None, spm_transform, False)
wikitext2_test = LanguageModelingDataset(list(test_iter), None, spm_transform, False)
wikitext2_valid = LanguageModelingDataset(list(valid_iter), None, spm_transform, False)

### JIT support for the data processing pipeline

The new building blocks in the torchtext library are compatible with torch.jit.script. TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process that has no Python dependency. The data processing pipelines above can be converted and run in JIT mode without a Python dependency.


In [None]:
jit_text_pipeline = torch.jit.script(text_pipeline.to_ivalue())
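
The save-and-reload workflow that TorchScript enables can be sketched with a small module. `ToyOffset` is a hypothetical stand-in for a scriptable transform and is unrelated to the torchtext pipeline above; the in-memory buffer is used here only to avoid writing a file.

```python
import io

import torch


class ToyOffset(torch.nn.Module):
    """Hypothetical scriptable transform: adds a fixed offset to token ids."""

    def __init__(self, offset: int):
        super().__init__()
        self.offset = offset

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return ids + self.offset


# Compile the module to TorchScript.
scripted = torch.jit.script(ToyOffset(2))

# Serialize to a buffer and load it back, as a separate process could do.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)

out = restored(torch.tensor([1, 2, 3]))
print(out)
```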

## Step 2: Data Iterator

The PyTorch data loading utility is the [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) class. It works with a map-style dataset, which implements the `__getitem__()` and `__len__()` protocols and represents a map from indices/keys to data samples. Before a batch is sent to the model, the collate_fn function processes the list of samples drawn by the DataLoader.
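
As a minimal sketch of these two pieces, the example below defines a toy map-style dataset and a collate function that pads variable-length sequences. `ToyDataset`, `toy_collate`, the sample ids, and the padding value of 0 are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Map-style dataset: implements __len__ and __getitem__."""

    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return torch.tensor(self.sequences[idx], dtype=torch.long)


def toy_collate(batch):
    # `batch` is a list of samples returned by __getitem__;
    # pad them to the length of the longest sample in the batch.
    return torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)


toy_loader = DataLoader(ToyDataset([[5, 6, 7], [8, 9], [10]]),
                        batch_size=3, collate_fn=toy_collate)
padded = next(iter(toy_loader))
print(padded.shape)
```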

In [None]:
from torch.utils.data import DataLoader

# [TODO] integrate with torchtext.experimental.transforms.PadTransform
# Need to land https://github.com/pytorch/text/pull/952

cls_id = vocab(['<cls>'])
sep_id = vocab(['<sep>'])
pad_id = vocab(['<pad>'])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def collate_batch(batch):
    seq_list, ans_pos_list, tok_type = [], [], []
    for (_context, _question, _answers, _ans_pos) in batch:
        _context, _question = torch.tensor(_context), torch.tensor(_question)
        qa_item = torch.cat((torch.tensor(cls_id), _question, torch.tensor(sep_id),
                             _context, torch.tensor(sep_id)))
        seq_list.append(qa_item)
        pos_list = [pos + _question.size(0) + 2 for pos in _ans_pos]
        ans_pos_list.append(pos_list)
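        # Segment ids: 0 for <cls> + question + <sep>, 1 for context + final <sep>,
        # so that tok_type has the same length as qa_item.
        tok_type.append(torch.cat((torch.zeros(_question.size(0) + 2),
                                   torch.ones(_context.size(0) + 1))))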
    _ans_pos_list = [torch.stack(list(pos)) for pos in zip(*ans_pos_list)]
    seq_list = torch.nn.utils.rnn.pad_sequence(seq_list, batch_first=True, padding_value=float(pad_id[0]))
    tok_type = torch.nn.utils.rnn.pad_sequence(tok_type, batch_first=True, padding_value=1.0)
    return seq_list.long(), _ans_pos_list, tok_type.long()

BATCH_SIZE = 16
dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

In [None]:
# Inspect a single batch.
for batch in dataloader:
    print(batch)
    break
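
The index arithmetic in the collate function can be easier to follow on plain lists. The sketch below packs one example as `<cls> question <sep> context <sep>` and shifts the answer position by the length of the prefix in front of the context. `pack_example` and the ids 101 and 102 are hypothetical; in this sketch the segment ids also cover the trailing separator so that both lists have the same length.

```python
def pack_example(question_ids, context_ids, ans_start, cls_id, sep_id):
    """Pack one question-context pair and shift the answer start position."""
    seq = [cls_id] + question_ids + [sep_id] + context_ids + [sep_id]
    # Number of tokens in front of the context: <cls> + question + <sep>.
    offset = len(question_ids) + 2
    # Segment ids: 0 for the question part, 1 for the context part.
    tok_type = [0] * offset + [1] * (len(context_ids) + 1)
    return seq, tok_type, ans_start + offset


question = [11, 12]
context = [21, 22, 23]
seq, tok_type, shifted_start = pack_example(question, context, ans_start=1,
                                            cls_id=101, sep_id=102)
# The shifted position still points at the same context token.
print(seq[shifted_start] == context[1])
```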


## Step 3: Model for Question-Answer Task

A BERT model was pretrained with the masked language modeling task and the next-sentence task, following the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805). torchtext.datasets.WikiText103 and BookCorpus were used to pre-train the model.

Here is the layout of the model for the question-answer task. On top of the BERT model, a linear layer projects each token's output to two scores, one for the start and one for the end position of the answer.

In [None]:
class QuestionAnswerTask(torch.nn.Module):
    """Contains a pretrained BERT model and a linear layer."""

    def __init__(self, bert_model):
        super(QuestionAnswerTask, self).__init__()
        self.bert_model = bert_model
        self.activation = torch.nn.functional.gelu
        # Project each token embedding to two scores: answer start and answer end.
        self.qa_span = torch.nn.Linear(bert_model.ninp, 2)

    def forward(self, src, token_type_input):
        output = self.bert_model(src, token_type_input)
        # transpose output (S, N, E) to (N, S, E)
        output = output.transpose(0, 1)
        output = self.activation(output)
        pos_output = self.qa_span(output)
        start_pos, end_pos = pos_output.split(1, dim=-1)
        start_pos = start_pos.squeeze(-1)
        end_pos = end_pos.squeeze(-1)
        return start_pos, end_pos
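
To check the tensor shapes in the forward pass, the sketch below replaces the BERT encoder with random values. The sizes `S`, `N`, and `E` and the random `encoder_output` are assumptions for illustration only.

```python
import torch

S, N, E = 6, 2, 8  # hypothetical sequence length, batch size, embedding size
encoder_output = torch.randn(S, N, E)  # stands in for the BERT output (S, N, E)
qa_span = torch.nn.Linear(E, 2)

# (S, N, E) -> (N, S, E) -> activation -> (N, S, 2)
logits = qa_span(torch.nn.functional.gelu(encoder_output.transpose(0, 1)))

# Split the last dimension into start and end scores, each (N, S, 1) -> (N, S).
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
print(start_logits.shape, end_logits.shape)
```

Each row of `start_logits` and `end_logits` holds one score per token position, which is the shape a cross-entropy loss over positions expects.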

The pretrained model is available here:

In [None]:
%%shell
rm *.pt
wget https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/full_ns_bert.pt
md5sum -c <<<"8070efa65373e0fb28ed23e1861c0def full_ns_bert.pt"

--2020-09-22 18:20:14--  https://pytorch.s3.amazonaws.com/models/text/torchtext_bert_example/full_ns_bert.pt
Resolving pytorch.s3.amazonaws.com (pytorch.s3.amazonaws.com)... 52.216.184.11
Connecting to pytorch.s3.amazonaws.com (pytorch.s3.amazonaws.com)|52.216.184.11|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 660517596 (630M) [binary/octet-stream]
Saving to: ‘full_ns_bert.pt’


2020-09-22 18:20:33 (33.7 MB/s) - ‘full_ns_bert.pt’ saved [660517596/660517596]

full_ns_bert.pt: OK




and can be loaded into the question-answer task:

In [None]:
pretrained_bert = torch.load('full_ns_bert.pt')
qa_model = QuestionAnswerTask(pretrained_bert)

## Step 4: Fine-tuning the Model

Next, we fine-tune the BERT model on the question-answer task using the SQuAD1 dataset.

In [None]:
import math
import time

def fine_tune(model, optimizer, criterion, batch_size, device):
    model.train()
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                            collate_fn=collate_batch)
    for idx, (seq_input, ans_pos, tok_type) in enumerate(dataloader):
        seq_input = seq_input.t().contiguous().to(device)
        tok_type = tok_type.t().contiguous().to(device)
        optimizer.zero_grad()
        start_pos, end_pos = model(seq_input, token_type_input=tok_type)
        target_start_pos, target_end_pos = ans_pos[0].to(device).split(1, dim=-1)
        target_start_pos = target_start_pos.squeeze(-1)
        target_end_pos = target_end_pos.squeeze(-1)
        loss = (criterion(start_pos, target_start_pos) + criterion(end_pos, target_end_pos)) / 2
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_loss += loss.item()
        if idx % log_interval == 0 and idx > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(epoch, idx,
                                                      len(train_data) // batch_size,
                                                      scheduler.get_last_lr()[0],
                                                      elapsed * 1000 / log_interval,
                                                      cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
            
# Hyperparameters
EPOCHS = 10 # epoch
LR = 0.5  # learning rate
BATCH_SIZE = 72 # batch size for training
  
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(qa_model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
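device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The inputs are moved to `device` inside fine_tune, so the model must live there too.
qa_model = qa_model.to(device)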
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    fine_tune(qa_model, optimizer, criterion, BATCH_SIZE, device)
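
The loss used in the training loop is the average of two cross-entropy terms, one over start positions and one over end positions. The pure-Python sketch below computes it for a single example; `cross_entropy` and `span_loss` are hypothetical helpers, and unlike torch.nn.CrossEntropyLoss they do not average over a batch.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of the target position (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def span_loss(start_logits, end_logits, start_target, end_target):
    """Average of the start-position and end-position cross-entropy losses."""
    return (cross_entropy(start_logits, start_target)
            + cross_entropy(end_logits, end_target)) / 2


# Uniform scores over 4 positions: the model has no preference yet.
uniform = span_loss([0.0] * 4, [0.0] * 4, start_target=1, end_target=2)
# Scores that favor the correct positions give a lower loss.
confident = span_loss([0.0, 5.0, 0.0, 0.0], [0.0, 0.0, 5.0, 0.0],
                      start_target=1, end_target=2)
print(uniform > confident)
```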