# Pretrain BERT on MLM and NSP Simultaneously

Ref:

(i) https://stackoverflow.com/questions/70122842/bert-pre-training-mlm-nsp?rq=1


(ii) https://stackoverflow.com/questions/65646925/how-to-train-bert-from-scratch-on-a-new-domain-for-both-mlm-and-nsp


(iii) https://www.thepythoncode.com/article/pretraining-bert-huggingface-transformers-in-python

In [None]:
!pip install datasets transformers

Collecting datasets
  Downloading datasets-1.17.0-py3-none-any.whl (306 kB)

In [None]:
import os
import nltk
from nltk.tokenize import sent_tokenize
from transformers import (
    BertTokenizer,
    BertTokenizerFast,
    BertConfig,
    BertModel,
    BertForPreTraining, 
    TextDatasetForNextSentencePrediction,
    DataCollatorForLanguageModeling,
    Trainer, 
    TrainingArguments
)
import torch
from datasets import load_dataset, concatenate_datasets

In [None]:
# For sentence tokenization
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

## Configuration

### Training

In [None]:
# CONFIGS

RANDOM_SEED=37

DATASET_LIMIT = 150_000

MODEL_MAX_LEN = 512

NSP_DATESET_PATH = 'nsp.txt'
MLM_MASKING_PROB = .15

MODEL_NAME = "bert-base-uncased"
# MODEL_NAME = "bert-base-multilingual-uncased"

# VOCAB_NAME = "bert-base-uncased"
VOCAB_NAME = "bert-base-multilingual-uncased"

# VOCAB = 'eng'
VOCAB = {
    'bert-base-uncased' : 'eng',
    'bert-base-multilingual-uncased': 'mult'
}[VOCAB_NAME]



### Drive Path

In [None]:
# MOUNTING DRIVE TO ACCESS DATASET
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# MOUNT PATH
DRIVE_PATH = os.path.join('drive','MyDrive','collab','research', 'bert_scratch')


MODEL_SAVE_PATH = os.path.join(DRIVE_PATH, f"{MODEL_NAME.replace('-','_')}_{VOCAB}_wiki_mlm_nsp")
print(MODEL_SAVE_PATH)

drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp


## Dataset

In [None]:
wiki = load_dataset("wikipedia", "20200501.en", split="train")
# bookcorpus = load_dataset("bookcorpus", split="train")
# print(wiki.column_names, bookcorpus.column_names)
# # ['title', 'text'] ['text']

# wiki.remove_columns_("title")
# bert_dataset = concatenate_datasets([wiki, bookcorpus])


# dataset = load_dataset("cc_news", split="train")

bert_dataset = wiki

Downloading and preparing dataset wikipedia/20200501.en (download: 16.99 GiB, generated: 17.07 GiB, post-processed: Unknown size, total: 34.06 GiB) to /root/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475...


Dataset wikipedia downloaded and prepared to /root/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475. Subsequent calls will reuse this data.


In [None]:
bert_dataset

Dataset({
    features: ['title', 'text'],
    num_rows: 6078422
})

In [None]:
bert_dataset[0]

{'text': 'Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People\'s Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangliuqing nianhua. For more than 400 years, Yangliuqing has in effect specialised in the creation of these woodcuts for the New Year.  wood block prints using vivid colourschemes to portray traditional scenes of children\'s games often interwoven with auspiciouse objects.\n\n, it had 27 residential communities () and 25 villages under its administration.\n\nShi Family Grand Courtyard\n\nShi Family Grand Courtyard (Tiānjīn Shí Jiā Dà Yuàn, 天津石家大院) is situated in Yangliuqing Town of Xiqing District, which is the former residence of wealthy merchant Shi Yuanshi - the 4th son of Shi Wancheng, one of the eight great masters in Tianjin. First built in 1875, it covers over 6,000 square mete

### NSP
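`TextDatasetForNextSentencePrediction` expects a plain-text input file with the following layout:

(1) One sentence per line.

(2) A blank line between documents.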


ref: https://stackoverflow.com/questions/65646925/how-to-train-bert-from-scratch-on-a-new-domain-for-both-mlm-and-nsp
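
To make the two rules concrete, here is a minimal sketch that writes two toy documents in this layout and reads them back. The sentences and the helper names `write_nsp_file` and `read_nsp_file` are illustrative assumptions; they are not part of this notebook or of transformers.

```python
import os
import tempfile

def write_nsp_file(documents, path):
    """Write one sentence per line, with a blank line between documents."""
    with open(path, "w", encoding="utf-8") as f:
        for i, sentences in enumerate(documents):
            if i > 0:
                f.write("\n")  # blank line separating two documents
            for sentence in sentences:
                f.write(sentence + "\n")

def read_nsp_file(path):
    """Parse the layout back into a list of documents (lists of sentences)."""
    documents = [[]]
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # An empty line closes the current document
                if documents[-1]:
                    documents.append([])
                continue
            documents[-1].append(line)
    return [doc for doc in documents if doc]

# Hypothetical toy corpus: two documents, two sentences each
toy_documents = [
    ["Yangliuqing is a market town.", "It is known for woodcuts."],
    ["BERT is pretrained on two tasks.", "One of them is NSP."],
]

toy_path = os.path.join(tempfile.mkdtemp(), "toy_nsp.txt")
write_nsp_file(toy_documents, toy_path)
round_trip = read_nsp_file(toy_path)
```

The document boundaries matter because the NSP negatives are drawn from other documents, so a missing blank line would merge two articles into one.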

In [None]:
with open(NSP_DATESET_PATH, "w", encoding="utf-8") as f:
  for document in bert_dataset[:DATASET_LIMIT]["text"]:
    # Replace newlines (paragraph breaks) with spaces before sentence segmentation
    document = document.replace('\n', ' ')
    for sentence in sent_tokenize(document):
      sentence = sentence.strip()
      if sentence != '':
        sentence_tokens = sentence.split(' ')
        # Drop the empty tokens produced by consecutive spaces
        sentence_tokens = [token for token in sentence_tokens if token != '']
        print(' '.join(sentence_tokens), file=f)
    # Blank line after each document
    print('', file=f)

In [None]:
# with open(NSP_DATESET_PATH, "w") as f:
#   for document in bert_dataset[:DATASET_LIMIT]["text"]:
#     # replace paragraph changes with fullstop for sentence segmentation
#     document = document.replace('\n', ' ')
  
#     for sentence in sent_tokenize(document):
#       sentence = sentence.strip()
  
#       if sentence != '':
#         sentence_tokens = sentence.split(' ')
#         # filter successive space chars
#         sentence_tokens = list(filter(lambda token: token!='', sentence_tokens)) 

#         if len(sentence_tokens)<=MODEL_MAX_LEN:
#           # If string is less than the max model length
#           print(' '.join(sentence_tokens), file=f)
#         else:
#           splitted_substrings = split_string(sentence, MODEL_MAX_LEN)
#           for substring in splitted_substrings:
#             print(substring, file=f)
#     # line break for each document
#     print('', file=f)
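
The commented-out variant above calls a `split_string` helper that is not defined anywhere in this notebook. A minimal sketch of what it might look like, assuming it should cut a sentence into substrings of at most `max_words` whitespace-separated words (the name comes from the cell above; the behaviour is a guess):

```python
def split_string(text, max_words):
    """Split text into substrings of at most max_words whitespace-separated words."""
    if max_words < 1:
        raise ValueError("max_words must be at least 1")
    words = text.split()
    return [
        " ".join(words[start:start + max_words])
        for start in range(0, len(words), max_words)
    ]

# Toy usage with a small limit for readability
chunks = split_string("a b c d e f g", 3)
```

Note that a word limit is only a proxy: one word can become several WordPiece tokens, so a chunk of `MODEL_MAX_LEN` words may still exceed `MODEL_MAX_LEN` tokens.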


#### Collapse final line breaks

In [None]:
# Remove final line breaks  
with open(NSP_DATESET_PATH) as f_input:
    data = f_input.read().rstrip('\n')

with open(NSP_DATESET_PATH, 'w') as f_output:    
    f_output.write(data)

## Tokenizer

In [None]:
# bert_tokenizer = BertTokenizer.from_pretrained(VOCAB_NAME, model_max_length=MODEL_MAX_LEN)

bert_tokenizer = BertTokenizerFast.from_pretrained(VOCAB_NAME, model_max_length=MODEL_MAX_LEN)

print(f'Using tokenizer: {VOCAB_NAME}')
print(f'Vocabulary size is: {bert_tokenizer.vocab_size}')

Using tokenizer: bert-base-multilingual-uncased
Vocabulary size is: 105879


In [None]:
dataset = TextDatasetForNextSentencePrediction(
    tokenizer=bert_tokenizer,
    file_path=NSP_DATESET_PATH,
    block_size = MODEL_MAX_LEN
)

Token indices sequence length is longer than the specified maximum sequence length for this model (527 > 512). Running this sequence through the model will result in indexing errors


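Some examples are longer than the model's maximum length (see the warning above). Truncate them to `MODEL_MAX_LEN` tokens, keeping `[SEP]` as the final token.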

In [None]:
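# Look up the [SEP] id from the tokenizer instead of hard-coding it
ending_sep_token_tensor = torch.tensor([bert_tokenizer.sep_token_id])

for example in dataset.examples:
    if len(example['input_ids']) > MODEL_MAX_LEN:
        # Keep the first MODEL_MAX_LEN-1 tokens and close the sequence with [SEP]
        example['input_ids'] = torch.cat((example['input_ids'][:MODEL_MAX_LEN-1], ending_sep_token_tensor), 0)
        example['token_type_ids'] = example['token_type_ids'][:MODEL_MAX_LEN]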
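
The truncation step can be illustrated on plain Python lists. This is a minimal sketch, not the notebook's code: the helper name `truncate_example`, the toy ids, and the tiny `max_len` are all assumptions chosen for readability.

```python
def truncate_example(input_ids, token_type_ids, max_len, sep_id):
    """Cut an over-long example to max_len and keep [SEP] as the final token."""
    if len(input_ids) <= max_len:
        return input_ids, token_type_ids
    return input_ids[:max_len - 1] + [sep_id], token_type_ids[:max_len]

# Toy example: 101 and 102 stand in for [CLS] and [SEP]
ids = [101, 7, 8, 102, 9, 10, 11, 12, 102]
types = [0, 0, 0, 0, 1, 1, 1, 1, 1]
new_ids, new_types = truncate_example(ids, types, max_len=6, sep_id=102)
```

Because only the tail is cut, it is the second segment that loses tokens; the first segment and its `[SEP]` stay intact as long as they fit within `max_len`.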

In [None]:
# documents = [[]]
# with open(NSP_DATESET_PATH, encoding="utf-8") as f:
#     while True:
#         line = f.readline()
#         if not line:
#             break
#         line = line.strip()

#         # Empty lines are used as document delimiters
#         if not line and len(documents[-1]) != 0:
#             documents.append([])
#         tokens = bert_tokenizer.tokenize(line)
#         tokens = bert_tokenizer.convert_tokens_to_ids(tokens)
#         if tokens:
#             documents[-1].append(tokens)

In [None]:
# c = 0
# m = 0
# for doc in documents:
#     for sent in doc:
#         if len(sent)>512:
#             c+=1
#             if len(sent)>m:
#                 m=len(sent)

# print(c)
# print(m)

## MLM


Use `DataCollatorForLanguageModeling` to apply the MLM masking and to pass along the NSP labels generated by `TextDatasetForNextSentencePrediction`. `DataCollatorForNextSentencePrediction` has been removed from transformers, since it duplicated what `DataCollatorForLanguageModeling` already does.
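
As a rough illustration of the masking itself: each non-special token is selected with probability `mlm_probability`; of the selected tokens, about 80% become `[MASK]`, 10% become a random token, and 10% stay unchanged, and every unselected position gets the label `-100` so the loss ignores it. The sketch below is a simplified pure-Python version of that rule. The function name, the toy ids (101, 102, 103) and the vocabulary size are assumptions, and the real collator works on batched tensors.

```python
import random

def mask_tokens(input_ids, special_ids, mask_id, vocab_size,
                mlm_probability=0.15, rng=None):
    """Return (masked_ids, labels) following the 80/10/10 masking rule."""
    if rng is None:
        rng = random.Random()
    masked = list(input_ids)
    labels = [-100] * len(input_ids)  # -100 marks positions ignored by the loss
    for i, token in enumerate(input_ids):
        # Never mask special tokens; select the rest with mlm_probability
        if token in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = token  # the model must predict the original token here
        roll = rng.random()
        if roll < 0.8:
            masked[i] = mask_id  # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return masked, labels

# Toy sequence: [CLS], fifty ordinary tokens, [SEP]
toy_ids = [101] + list(range(1000, 1050)) + [102]
masked_ids, labels = mask_tokens(
    toy_ids, special_ids={101, 102}, mask_id=103, vocab_size=30522,
    rng=random.Random(37),
)
```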

In [None]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=bert_tokenizer, 
    mlm=True,
    mlm_probability= MLM_MASKING_PROB
)

## Model

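To pretrain BERT from scratch, build the model from a `BertConfig` so that its weights are randomly initialised rather than loaded from a pretrained checkpoint. This config uses a reduced model with 4 hidden layers and 4 attention heads.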

In [None]:
config = BertConfig(
    num_hidden_layers=4, 
    num_attention_heads=4,
    vocab_size = bert_tokenizer.vocab_size,
    max_position_embeddings=MODEL_MAX_LEN
)
model = BertForPreTraining(config)
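
To get a feel for the size of this configuration, the encoder's parameter count can be estimated by hand. This is a back-of-the-envelope sketch: it assumes the `BertConfig` defaults for everything not set above (hidden size 768, intermediate size 3072, two token types) and leaves out the MLM and NSP heads. The function name is hypothetical.

```python
def bert_encoder_param_count(vocab_size, hidden=768, layers=4,
                             intermediate=3072, max_positions=512, type_vocab=2):
    """Count parameters of a BERT encoder: embeddings, layers, and pooler."""
    layer_norm = 2 * hidden  # weight and bias
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
    # Query, key, value and output projections, plus the attention LayerNorm
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    # Two dense layers of the feed-forward block, plus its LayerNorm
    feed_forward = (
        (hidden * intermediate + intermediate)
        + (intermediate * hidden + hidden)
        + layer_norm
    )
    pooler = hidden * hidden + hidden
    return embeddings + layers * (attention + feed_forward) + pooler

total = bert_encoder_param_count(vocab_size=105879)
word_embedding_share = (105879 * 768) / total
```

With the multilingual vocabulary, the word embedding matrix accounts for most of the encoder's parameters in a model this shallow.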

### GPU

In [None]:
# Select the GPU if one is available, then move the model to that device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(105879, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affin

### Training


In [None]:
training_args = TrainingArguments(
    output_dir= MODEL_SAVE_PATH,
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size= 8,
    save_steps=10000,
    save_on_each_node=True,
    save_total_limit=1,
    prediction_loss_only=True,
)

In [None]:
# import torch
# torch.cuda.empty_cache()

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# trainer.train()
# Resume training from an existing checkpoint
trainer.train(resume_from_checkpoint=os.path.join(MODEL_SAVE_PATH, "checkpoint-260000"))

Loading model from drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-260000).
***** Running training *****
  Num examples = 448093
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 280060
  Continuing training from checkpoint, will skip to saved global_step
  Continuing training from epoch 4
  Continuing training from global step 260000
  Will skip the first 4 epochs then the first 35952 batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` flag to your launch command, but you will resume the training on data already seen by your model.


| Step | Training Loss |
|------|---------------|
| 260500 | 3.0349 |
| 261000 | 3.0504 |
| 261500 | 3.1032 |
| 262000 | 3.11 |
| 262500 | 3.1291 |
| 263000 | 3.0888 |
| 263500 | 3.0696 |
| 264000 | 3.0995 |
| 264500 | 3.0813 |
| 265000 | 3.0721 |


Saving model checkpoint to drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-270000
Configuration saved in drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-270000/config.json
Model weights saved in drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-270000/pytorch_model.bin
Deleting older checkpoint [drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-260000] due to args.save_total_limit
Saving model checkpoint to drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-280000
Configuration saved in drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-280000/config.json
Model weights saved in drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/checkpoint-280000/pytorch_model.bin
Deleting older checkpoint [drive/MyDrive/collab/research/b

TrainOutput(global_step=280060, training_loss=0.21956682500094846, metrics={'train_runtime': 9067.1981, 'train_samples_per_second': 247.096, 'train_steps_per_second': 30.887, 'total_flos': 2.425860767707909e+17, 'train_loss': 0.21956682500094846, 'epoch': 5.0})
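
The step counts in the log above can be reproduced with a little arithmetic. A minimal sketch, assuming a single device and no gradient accumulation (both match the log); the variable names are illustrative:

```python
import math

num_examples = 448093   # "Num examples" in the log
batch_size = 8          # per_device_train_batch_size on one device
num_epochs = 5
resumed_step = 260000   # from checkpoint-260000

# The last batch of each epoch is partial, so round up
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

# Whole epochs already finished, and batches to skip in the current epoch
epochs_done, batches_to_skip = divmod(resumed_step, steps_per_epoch)
remaining_steps = total_steps - resumed_step

print(steps_per_epoch, total_steps, epochs_done, batches_to_skip, remaining_steps)
```

This matches "Total optimization steps = 280060" and "skip the first 4 epochs then the first 35952 batches". The reported `training_loss` of about 0.22 is much lower than the logged step losses of about 3.07. That would be consistent with the loss summed over only the resumed steps being averaged over all 280060 global steps, but this is an assumption about `Trainer` internals rather than something the notebook shows.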

In [None]:
MODEL_SAVE_PATH

'drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp'

In [None]:
trainer.save_model(MODEL_SAVE_PATH)

Saving model checkpoint to drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp
Configuration saved in drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/config.json
Model weights saved in drive/MyDrive/collab/research/bert_scratch/bert_base_uncased_mult_wiki_mlm_nsp/pytorch_model.bin


In [None]:
trainer.save_model(f'bert_base_uncased_{VOCAB}_wiki_mlm_nsp')

Saving model checkpoint to bert_base_uncased_mult_wiki_mlm_nsp
Configuration saved in bert_base_uncased_mult_wiki_mlm_nsp/config.json
Model weights saved in bert_base_uncased_mult_wiki_mlm_nsp/pytorch_model.bin
