# Pretrain Bert on MLM and NSP Simultaneously

Ref:

(i) https://stackoverflow.com/questions/70122842/bert-pre-training-mlm-nsp?rq=1


(ii) https://stackoverflow.com/questions/65646925/how-to-train-bert-from-scratch-on-a-new-domain-for-both-mlm-and-nsp


(iii) https://www.thepythoncode.com/article/pretraining-bert-huggingface-transformers-in-python

In [1]:
!pip install datasets transformers

Collecting datasets
  Downloading datasets-1.16.1-py3-none-any.whl (298 kB)
[K     |████████████████████████████████| 298 kB 5.4 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.14.1-py3-none-any.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 36.0 MB/s 
Collecting huggingface-hub<1.0.0,>=0.1.0
  Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)
[K     |████████████████████████████████| 61 kB 246 kB/s 
Collecting xxhash
  Downloading xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)
[K     |████████████████████████████████| 243 kB 38.8 MB/s 
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2021.11.1-py3-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 32.5 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 24.1 MB/s 
Collecting pyyaml
  Downloading Py

In [2]:
import nltk
from nltk.tokenize import sent_tokenize
from transformers import (
    BertTokenizer,
    BertTokenizerFast,
    BertConfig, 
    BertForPreTraining, 
    TextDatasetForNextSentencePrediction,
    DataCollatorForLanguageModeling,
    Trainer, 
    TrainingArguments
)
import torch
from datasets import load_dataset, concatenate_datasets

In [None]:
# For sentence tokenization
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [3]:
# CONFIGS

RANDOM_SEED=37

DATASET_LIMIT = 20_000

MODEL_MAX_LEN = 512

NSP_DATESET_PATH = 'nsp.txt'
MLM_TRAIN_DATESET_PATH = 'mlm_train.txt'
MLM_TEST_DATESET_PATH = 'mlm_test.txt'
MLM_MASKING_PROB = .15

MODEL_NAME = "bert-base-uncased"
# MODEL_NAME = "bert-base-multilingual-uncased"

MODEL_SAVE_PATH = MODEL_NAME

## Load Dataset

In [None]:
# wiki = load_dataset("wikipedia", "20200501.en", split="train")
# bookcorpus = load_dataset("bookcorpus", split="train")
# print(wiki.column_names, bookcorpus.column_names)
# # ['title', 'text'] ['text']

# wiki.remove_columns_("title")
# bert_dataset = concatenate_datasets([wiki, bookcorpus])


dataset = load_dataset("cc_news", split="train")

bert_dataset = dataset

Downloading:   0%|          | 0.00/1.75k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/932 [00:00<?, ?B/s]

Downloading and preparing dataset cc_news/plain_text (download: 805.98 MiB, generated: 1.88 GiB, post-processed: Unknown size, total: 2.67 GiB) to /root/.cache/huggingface/datasets/cc_news/plain_text/1.0.0/ae469e556251e6e7e20a789f93803c7de19d0c4311b6854ab072fecb4e401bd6...


Downloading:   0%|          | 0.00/845M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

Dataset cc_news downloaded and prepared to /root/.cache/huggingface/datasets/cc_news/plain_text/1.0.0/ae469e556251e6e7e20a789f93803c7de19d0c4311b6854ab072fecb4e401bd6. Subsequent calls will reuse this data.


In [None]:
bert_dataset

Dataset({
    features: ['title', 'text', 'domain', 'date', 'description', 'url', 'image_url'],
    num_rows: 708241
})

In [None]:
bert_dataset[0]

{'date': '2017-12-11 20:19:05',
 'description': "There's a surprising twist to Regina Willoughby's last season with\xa0Columbia City Ballet: It's also her 18-year-old daughter Melina's first season with the company.",
 'domain': 'www.pointemagazine.com',
 'image_url': 'https://pointe-img.rbl.ms/simage/https%3A%2F%2Fassets.rbl.ms%2F16807693%2F980x.png/2000%2C2000/3VnhNGWp75K4SwMx/img.png',
 'text': 'There\'s a surprising twist to Regina Willoughby\'s last season with Columbia City Ballet: It\'s also her 18-year-old daughter Melina\'s first season with the company. Regina, 40, will retire from the stage in March, just as her daughter starts her own career as a trainee. But for this one season, they\'re sharing the stage together.\nPerforming Side-By-Side In The Nutcracker\nRegina and Melina are not only dancing in the same Nutcracker this month, they\'re onstage at the same time: Regina is doing Snow Queen, while Melina is in the snow corps, and they\'re both in the Arabian divertissemen

### For MLM

In [None]:
# def split_string(str, limit, sep=" "):
#     """
#     Split a long string into list of substrings each of
#     which has length less than the given limit.
#     """
#     words = str.split()
#     words = list(filter(lambda x: len(x)<limit, words))
#     if max(map(len, words)) > limit:
#         raise ValueError("limit is too small")
#     res, part, others = [], words[0], words[1:]
#     for word in others:
#         if len(sep)+len(word) > limit-len(part):
#             res.append(part)
#             part = word
#         else:
#             part += sep+word
#     if part:
#         res.append(part)
#     return res

In [None]:
# # split the dataset into training (90%) and testing (10%)
# d = bert_dataset.train_test_split(test_size=0.1, seed=RANDOM_SEED)

### For NSP
(1) One sentence per line. 

(2) Blank lines between documents


ref: https://stackoverflow.com/questions/65646925/how-to-train-bert-from-scratch-on-a-new-domain-for-both-mlm-and-nsp

In [None]:
with open(NSP_DATESET_PATH, "w") as f:
  for document in bert_dataset[:DATASET_LIMIT]["text"]:
    # replace paragraph changes with fullstop for sentence segmentation
    document = document.replace('\n', ' ')
    for sentence in sent_tokenize(document):
      sentence = sentence.strip()
      if sentence != '':
          print(sentence, file=f)
    # line break for each document
    print('', file=f)

In [None]:
# with open(NSP_DATESET_PATH, "w") as f:
#   for document in bert_dataset[:DATASET_LIMIT]["text"]:
#     # replace paragraph changes with fullstop for sentence segmentation
#     document = document.replace('\n', ' ')
  
#     for sentence in sent_tokenize(document):
#       sentence = sentence.strip()
  
#       if sentence != '':
#         sentence_tokens = sentence.split(' ')
#         # filter successive space chars
#         sentence_tokens = list(filter(lambda token: token!='', sentence_tokens)) 

#         if len(sentence_tokens)<=MODEL_MAX_LEN:
#           # If string is less than the max model length
#           print(' '.join(sentence_tokens), file=f)
#         else:
#           splitted_substrings = split_string(sentence, MODEL_MAX_LEN)
#           for substring in splitted_substrings:
#             print(substring, file=f)
#     # line break for each document
#     print('', file=f)


# Remove final line breaks  
with open(NSP_DATESET_PATH) as f_input:
    data = f_input.read().rstrip('\n')

with open(NSP_DATESET_PATH, 'w') as f_output:    
    f_output.write(data)

## Tokenizer

In [4]:
bert_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, model_max_length=MODEL_MAX_LEN)
# bert_tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME, max_len=512)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

## Model

In [5]:
config = BertConfig(
    num_hidden_layers=4, 
    num_attention_heads=4,
    vocab_size = bert_tokenizer.vocab_size,
    max_position_embeddings=MODEL_MAX_LEN
)
model = BertForPreTraining(config)

In [6]:
device = torch.device('cuda')# and move our model over to the selected device
model.to(device)

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

#### NSP

In [7]:
dataset = TextDatasetForNextSentencePrediction(
    tokenizer=bert_tokenizer,
    file_path=NSP_DATESET_PATH,
    block_size = MODEL_MAX_LEN
)



### Truncate the tokens with Max model length

In [8]:
ending_sep_token_tensor = torch.tensor([102])

for i in dataset.examples:
    if len(i['input_ids'])>512:
        i['input_ids'] = torch.cat((i['input_ids'][:MODEL_MAX_LEN-1], ending_sep_token_tensor), 0)
        i['token_type_ids'] = i['token_type_ids'][:MODEL_MAX_LEN]

In [9]:
len(dataset.examples[2]['input_ids'])

300

In [None]:
# documents = [[]]
# with open(NSP_DATESET_PATH, encoding="utf-8") as f:
#     while True:
#         line = f.readline()
#         if not line:
#             break
#         line = line.strip()

#         # Empty lines are used as document delimiters
#         if not line and len(documents[-1]) != 0:
#             documents.append([])
#         tokens = bert_tokenizer.tokenize(line)
#         tokens = bert_tokenizer.convert_tokens_to_ids(tokens)
#         if tokens:
#             documents[-1].append(tokens)

In [None]:
# c = 0
# m = 0
# for doc in documents:
#     for sent in doc:
#         if len(sent)>512:
#             c+=1
#             if len(sent)>m:
#                 m=len(sent)

# print(c)
# print(m)

#### MLM


Use DataCollatorForLanguageModeling for masking and passing the labels that are generated from TextDatasetForNextSentencePrediction. DataCollatorForNextSentencePrediction has been removed, since it was doing the same thing with DataCollatorForLanguageModeling

In [10]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=bert_tokenizer, 
    mlm=True,
    mlm_probability= MLM_MASKING_PROB
)

### Training


In [11]:
training_args = TrainingArguments(
    output_dir= "results",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size= 16,
    save_steps=1000,
    save_on_each_node=True,
    save_total_limit=2,
    prediction_loss_only=True,
)

In [30]:
import torch
torch.cuda.empty_cache()

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()

***** Running training *****
  Num examples = 53003
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 16565


Step,Training Loss
500,8.0946
1000,7.6094
1500,7.4675
2000,7.3625
2500,7.2866
3000,7.2311
3500,7.1758
4000,7.1325
4500,7.0998
5000,7.0752


Saving model checkpoint to results/checkpoint-1000
Configuration saved in results/checkpoint-1000/config.json
Model weights saved in results/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to results/checkpoint-2000
Configuration saved in results/checkpoint-2000/config.json
Model weights saved in results/checkpoint-2000/pytorch_model.bin
Saving model checkpoint to results/checkpoint-3000
Configuration saved in results/checkpoint-3000/config.json
Model weights saved in results/checkpoint-3000/pytorch_model.bin
Deleting older checkpoint [results/checkpoint-1000] due to args.save_total_limit
Saving model checkpoint to results/checkpoint-4000
Configuration saved in results/checkpoint-4000/config.json
Model weights saved in results/checkpoint-4000/pytorch_model.bin
Deleting older checkpoint [results/checkpoint-2000] due to args.save_total_limit
Saving model checkpoint to results/checkpoint-5000
Configuration saved in results/checkpoint-5000/config.json
Model weights saved in resul

In [None]:
trainer.save_model(MODEL_SAVE_PATH)