Use a smaller corpus: WikiText-2

Compared with the PTB dataset used for pretraining word2vec in Section 15.3, WikiText-2:

(i) retains the original punctuation, making it suitable for next sentence prediction; 

(ii) retains the original case and numbers; 

(iii) is more than twice as large.

In [1]:
import os
import random
import torch
from d2l import torch as d2l

  warn(f"Failed to load image Python extension: {e}")


Get data

In [2]:
#@save
# Register the WikiText-2 archive under the key 'wikitext-2' so that
# `d2l.download_extract('wikitext-2', ...)` can find it. The value is a
# (URL, hash string) pair; the two adjacent string literals are concatenated
# into a single URL.
# NOTE(review): the second element is presumably the SHA-1 checksum d2l uses
# to validate the download -- confirm against d2l's download helper.
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

#@save
def _read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # Uppercase letters are converted to lowercase ones
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs

In [3]:
# Toy corpus for trying out the helpers without downloading WikiText-2: each
# string plays the role of one line of the raw corpus, with sentences
# delimited by ' . '.
# NOTE(review): the third string carries an `f` prefix but contains no
# placeholders, so the prefix has no effect.
text = ['The following function generates training examples for next sentence prediction from the input paragraph by invoking the _get_next_sentence function . Here paragraph is a list of sentences , where each sentence is a list of tokens . ',
        'Paragraphs with at least two sentences are retained . To split sentences , we only use the period as the delimiter for simplicity . We leave discussions of more complex sentence splitting techniques in the exercises at the end of this section . ',
        f'In the following , we begin by implementing helper functions for the two BERT pretraining tasks: next sentence prediction and masked language modeling . In order to generate training examples for the masked language modeling task from a BERT input sequence, we define the following _replace_mlm_tokens function. In its inputs , tokens is a list of tokens representing a BERT input sequence , candidate_pred_positions is a list of token indices of the BERT input sequence excluding those of special tokens ( special tokens are not predicted in the masked language modeling task ) , and num_mlm_preds indicates the number of predictions ( recall 15% random tokens to predict) . ',
        'These helper functions will be invoked later when transforming the raw text corpus into the dataset of the ideal format to pretrain BERT . Following the definition of the masked language modeling task in Section 15.8.5.1 , at each prediction position , the input may be replaced by a special “ <mask> ” token or a random token , or remain unchanged . In the end, the function returns the input tokens after possible replacement , the token indices where predictions take place and labels for these predictions .']

# Same filtering and splitting as `_read_wiki`, without the shuffle: keep the
# lines that yield at least two ' . '-separated pieces, then lowercase and
# split them into sentence strings.
# strip: Return a copy of the string with leading and trailing whitespace removed. If chars is given and not None, remove characters in chars instead.
# Because strip() removes the trailing space, the final ' . ' no longer
# matches the delimiter, so the last sentence of each paragraph keeps its ' .'.
paragraphs = [line.strip().lower().split(' . ')
                for line in text if len(line.split(' . ')) >= 2]
paragraphs

[['the following function generates training examples for next sentence prediction from the input paragraph by invoking the _get_next_sentence function',
  'here paragraph is a list of sentences , where each sentence is a list of tokens .'],
 ['paragraphs with at least two sentences are retained',
  'to split sentences , we only use the period as the delimiter for simplicity',
  'we leave discussions of more complex sentence splitting techniques in the exercises at the end of this section .'],
 ['in the following , we begin by implementing helper functions for the two bert pretraining tasks: next sentence prediction and masked language modeling',
  'in order to generate training examples for the masked language modeling task from a bert input sequence, we define the following _replace_mlm_tokens function. in its inputs , tokens is a list of tokens representing a bert input sequence , candidate_pred_positions is a list of token indices of the bert input sequence excluding those of speci

Generate next sentence prediction (NSP) examples

In [4]:
# input: a sentence and consecutive sentence, and the paragraph we want to sample for the false sample
# #sentence and #next_sentence = 1

def _get_next_sentence(sentence, next_sentence, paragraphs):
    if random.random() < 0.5:   # whether generate a negative sample
        is_next = True
    else:
        # `paragraphs` is a list of lists of lists
        next_sentence = random.choice(random.choice(paragraphs)) # first randomly choose a row of para, then randomly choose a sentence within the row
        is_next = False
    return sentence, next_sentence, is_next

In [None]:
# Smoke test: roughly half of the calls return the pair unchanged with True;
# the others replace the second element by a random sentence drawn from
# `paragraphs` and return False.
_get_next_sentence(['asd'], ['fdf'], paragraphs)

In [None]:
# Same call again: the outcome is random, so repeated runs can switch between
# the positive (True) and the negative (False) case.
_get_next_sentence(['asd'], ['fdf'], paragraphs)

In [7]:
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    """Build next sentence prediction examples from one paragraph.

    Args:
        paragraph: A list of sentences; consecutive sentences form the
            candidate pairs.
        paragraphs: The corpus from which negative samples are drawn.
        vocab: Not used by this function.
        max_len: Maximum length of a BERT input sequence during pretraining.

    Returns:
        A list of `(tokens, segments, is_next)` tuples, one per consecutive
        sentence pair whose combined length fits into `max_len`.
    """
    # A BERT input carries 1 '<cls>' token and 2 '<sep>' tokens on top of the
    # two sentences.
    num_special_tokens = 3
    examples = []
    for current, following in zip(paragraph, paragraph[1:]):
        tokens_a, tokens_b, is_next = _get_next_sentence(
            current, following, paragraphs)
        fits = len(tokens_a) + len(tokens_b) + num_special_tokens <= max_len
        if fits:
            tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
            examples.append((tokens, segments, is_next))
    return examples

In [6]:
# Smoke test with three dummy "sentences"; the vocab argument (2) is a
# placeholder because the function never reads it.
# NOTE(review): the output recorded below holds (tokens_a, tokens_b, is_next)
# triples rather than (tokens, segments, is_next), so it appears to come from
# an earlier variant of the function -- re-run the cell to refresh it.
_get_nsp_data_from_paragraph(['fds', 'ers', 'azc'], paragraphs, 2, 1000)

[('fds', 'paragraphs with at least two sentences are retained', False),
 ('ers',
  'in the following , we begin by implementing helper functions for the two bert pretraining tasks: next sentence prediction and masked language modeling',
  False)]

Generate masked language modeling (MLM) examples

In [8]:
#tokens: a list of tokens representing a BERT input sequence
# candidate_pred_positions:  a list of token indices of the BERT input sequence excluding 
#                            those of special tokens (special tokens are not predicted in the masked language modeling task)
# num_mlm_preds: indicates the number of predictions (recall 15% random tokens to predict)

def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                        vocab):
    # For the input of a masked language model, make a new copy of tokens and
    # replace some of them by '<mask>' or random tokens
    mlm_input_tokens = tokens[:]
    pred_positions_and_labels = []
    # Shuffle for getting 15% random tokens for prediction in the masked
    # language modeling task
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token = None
        # 80% of the time: replace the word with the '<mask>' token
        if random.random() < 0.3:
            masked_token = '<mask>'
        else:
            # 10% of the time: keep the word unchanged
            if random.random() < 0.1:
                masked_token = tokens[mlm_pred_position]
            # 10% of the time: replace the word with a random word
            else:
                masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append(
            (mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels

In [38]:
#test

class ss:
    """Minimal vocabulary stand-in used to try out `_replace_mlm_tokens`."""

    def __init__(self, vocab):
        # Expose the token list under the attribute name that
        # `_replace_mlm_tokens` reads when drawing a random replacement.
        self.idx_to_token = vocab


a = ss(["ab", 'gr', 'cv'])

In [40]:
# test: request 3 predictions among candidate positions 1-7 of a 9-token
# sequence, using the stub vocabulary `a`. The result is a pair: the token
# list after possible replacement, and a list of (position, original token)
# tuples for the selected positions.
_replace_mlm_tokens(tokens=[3,4,2,3,5,7,4,2,3], candidate_pred_positions=[1, 2, 3, 4, 5, 6, 7], num_mlm_preds=3, vocab=a)

([3, 4, 2, '<mask>', 5, 7, 4, 'ab', 3], [(7, 2), (2, 2), (3, 3)])

In [44]:
#@save
def _get_mlm_data_from_tokens(tokens, vocab):
    """Create masked language modeling inputs and labels for one sequence.

    Args:
        tokens: A list of strings forming a BERT input sequence.
        vocab: Vocabulary; indexing it with a list of tokens is expected to
            give the corresponding indices. It is also forwarded to
            `_replace_mlm_tokens`.

    Returns:
        A tuple `(input_ids, pred_positions, label_ids)`: the indices of the
        corrupted sequence, the predicted positions in ascending order, and
        the indices of the original tokens at those positions.
    """
    # Special tokens are not predicted in the masked language modeling task,
    # so they are left out of the candidate positions.
    special_tokens = ('<cls>', '<sep>')
    candidate_pred_positions = [
        position for position, token in enumerate(tokens)
        if token not in special_tokens]
    # 15% of random tokens are predicted (at least one); the count is based
    # on the full sequence length, special tokens included.
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    # Order the predictions by position before separating the two columns.
    pred_positions_and_labels.sort(key=lambda pair: pair[0])
    pred_positions, mlm_pred_labels = [], []
    for position, label in pred_positions_and_labels:
        pred_positions.append(position)
        mlm_pred_labels.append(label)
    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]

Pad the inputs to a fixed length

In [45]:
#@save
def _pad_bert_inputs(examples, max_len, vocab):
    max_num_mlm_preds = round(max_len * 0.15)
    all_token_ids, all_segments, valid_lens,  = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    nsp_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments,
         is_next) in examples:
        all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (
            max_len - len(token_ids)), dtype=torch.long))
        all_segments.append(torch.tensor(segments + [0] * (
            max_len - len(segments)), dtype=torch.long))
        # `valid_lens` excludes count of '<pad>' tokens
        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
        all_pred_positions.append(torch.tensor(pred_positions + [0] * (
            max_num_mlm_preds - len(pred_positions)), dtype=torch.long))
        # Predictions of padded tokens will be filtered out in the loss via
        # multiplication of 0 weights
        all_mlm_weights.append(
            torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (
                max_num_mlm_preds - len(pred_positions)),
                dtype=torch.float32))
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (
            max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))
        nsp_labels.append(torch.tensor(is_next, dtype=torch.long))
    return (all_token_ids, all_segments, valid_lens, all_pred_positions,
            all_mlm_weights, all_mlm_labels, nsp_labels)

Build the pretraining dataset

In [57]:
#@save
class _WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self, paragraphs, max_len):
        # Input `paragraphs[i]` is a list of sentence strings representing a
        # paragraph; while output `paragraphs[i]` is a list of sentences
        # representing a paragraph, where each sentence is a list of tokens
        paragraphs = [d2l.tokenize(
            paragraph, token='word') for paragraph in paragraphs]
        sentences = [sentence for paragraph in paragraphs
                     for sentence in paragraph]
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # Get data for the next sentence prediction task
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))
        # Get data for the masked language model task
        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                      + (segments, is_next))
                     for tokens, segments, is_next in examples]
        # Pad inputs
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(
            examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)

In [70]:
# Build the dataset from the toy corpus and print every field of one padded
# example under its name.
b = _WikiTextDataset(paragraphs, 100)
what = ['all_token_ids', 'all_segments', 'valid_lens', 'all_pred_positions', 'all_mlm_weights', 'all_mlm_labels', 'nsp_labels']
# Use dedicated loop names so the vocabulary stub `a` defined earlier is not
# overwritten; otherwise re-running the `_replace_mlm_tokens` test cell would
# receive a tensor as its vocab.
for field_name, field_value in zip(what, b[3]):
    print(f'{field_name}: \n{field_value}\n')

all_token_ids: 
tensor([ 1,  1, 12,  5,  0,  5,  2,  5,  5,  5,  2,  8, 12,  7,  7,  1,  5,  5,
         5,  5,  5,  5,  5,  5,  4,  5, 12,  5, 11, 12,  3,  5,  5,  2,  9, 12,
         5,  0,  5,  5,  5,  5,  0, 12, 10,  5,  5,  2,  5,  6,  5,  5,  2,  5,
         5,  5, 10,  5,  5,  0,  5,  5,  5,  4,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  3])

all_segments: 
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

valid_lens: 
64.0

all_pred_positions: 
tensor([ 1,  6, 10, 13, 15, 30, 33, 35, 47, 56,  0,  0,  0,  0,  0])

all_mlm_weights: 
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0

In [72]:
#@save
def load_data_wiki(batch_size, max_len):
    """Load the WikiText-2 dataset.

    Args:
        batch_size: Number of examples per minibatch.
        max_len: Maximum length of a BERT input sequence.

    Returns:
        A pair `(train_iter, vocab)`: a shuffling `DataLoader` over the
        pretraining examples and the vocabulary built by the dataset.
    """
    num_workers = d2l.get_dataloader_workers()
    extracted_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    dataset = _WikiTextDataset(_read_wiki(extracted_dir), max_len)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)
    return loader, dataset.vocab

In [73]:
# Build the data iterator on the full corpus and print the tensor shapes of
# the first minibatch only (hence the `break`).
batch_size, max_len = 512, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)

# Each minibatch unpacks into the seven fields returned by
# `_WikiTextDataset.__getitem__`, stacked along a leading batch dimension.
# The three MLM tensors have round(max_len * 0.15) columns.
for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) in train_iter:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    break

Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...
torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])


In [75]:
# Number of entries in the vocabulary returned by `load_data_wiki`.
len(vocab)

20256

In [82]:
# Look up the index stored for the token '(' in the vocabulary's
# token-to-index mapping.
vocab.token_to_idx['(']

45