# Setup

In [91]:
from os import path
from pprint import PrettyPrinter
from typing import Optional

# Notebook directory: the first entry of IPython's `_dh` directory history.
# Fall back to the current working directory when `_dh` is not defined
# (e.g. when this code runs outside an IPython kernel) instead of raising KeyError.
__DIR__ = globals().get("_dh", [path.abspath(path.curdir)])[0]

# Data root: the `_data` folder one level above the notebook, as a relative path.
data_dir = path.relpath(path.join(__DIR__, "..", "_data"))

# Shared pretty-printer used for the configuration dump and the decoded samples.
pp = PrettyPrinter(indent=2, width=120)

In [2]:
# Settings (run switches)
_testing = False           # True: use the small test dataset and a 3-step run
_colab_install = True      # True: on Colab, fetch (if missing) and run colab_install.sh
_pm_log_sections = False   # True: pm_log_section() prints section banners

# Parameters
# -- data
dataset = path.join(data_dir, "wiki", "20220301.en.1gb")

# -- tokenizer
base_model = "bert-base-cased"
max_length = 128
vocab_size = 20_000

tokenize_params = {"batched": True, "num_proc": 4}
tokenizer_dir = path.join(data_dir, "pretrain", "tokenizer")

# -- masked language model
mlm_probability = 0.15
bert_config = {}
training_args = {
    "optim": "adamw_torch",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 64,
    "eval_accumulation_steps": 10,
    "evaluation_strategy": "steps",
    "logging_steps": 10000,
    "save_steps": 10000,
    "save_total_limit": 3,
}
# Cap on the number of validation examples used during training; None disables the cap.
max_eval_samples: Optional[int] = 5000
model_dir = path.join(data_dir, "pretrain", "model")

In [3]:
if _testing:
    # Testing mode: switch to the small test dataset and cap the run at a few steps.
    dataset = path.join(data_dir, "wiki", "20220301.en.test")

    training_args.update(dict(
        max_steps = 3,
        logging_steps = 1,
    ))

    max_eval_samples = 1000

# Capture the dataset directory only AFTER the testing override has been applied.
# `dataset` is later rebound to the loaded splits, and `ds_dir` is what the loading
# cells read from, so it must reflect the test path when `_testing` is set.
ds_dir = dataset

## Process settings / parameters

In [49]:
# Colab
try:
    import google.colab
    
    # Wrap output text
    from IPython.display import HTML, display
    
    def set_css():
        display(HTML('''
        <style>
            pre {
                white-space: pre-wrap;
            }
        </style>
        '''))
        get_ipython().events.register('pre_run_cell', set_css)
    
    if _colab_install:
        colab_install_script = path.join(__DIR__, "..", "colab_install.sh")

        if not path.isfile(colab_install_script):
            script_url = ("https://raw.githubusercontent.com/"
                            "yenson-lau/pii-remediation/main/colab_install.sh")
            !wget $script_url -O $colab_install_script

        !bash $colab_install_script

        print()

except ModuleNotFoundError:
    pass

ERROR: source file is older than destination: b2://pretrain/tokenizer/special_tokens_map.json with a time of 1664377989380 cannot be synced to local://pretrain/tokenizer/special_tokens_map.json with a time of 1664469263465, unless a valid newer_file_mode is provided



In [4]:
from collections import OrderedDict

# Pick the banner implementation once, based on the `_pm_log_sections` switch.
if not _pm_log_sections:
    def pm_log_section(message):
        """Section logging is disabled: ignore `message` and do nothing."""
        return None
else:
    def pm_log_section(message):
        """Print `message` as a section banner surrounded by blank lines."""
        print(f"\n[===== {message} =====]\n")

if _testing:
    pm_log_section("Running on testing mode")

# Snapshot of every run parameter, kept in declaration order for the dump below.
config = OrderedDict([
    ("dataset", dataset),

    ("base_model", base_model),
    ("max_length", max_length),
    ("vocab_size", vocab_size),

    ("tokenize_params", tokenize_params),
    ("tokenizer_dir", tokenizer_dir),

    ("mlm_probability", mlm_probability),
    ("bert_config", bert_config),
    ("training_args", training_args),
    ("max_eval_samples", max_eval_samples),
    ("model_dir", model_dir),
])

print(("TESTING " if _testing else "") + "Parameters:")
pp.pprint(config)


Parameters:
OrderedDict([ ('dataset', '../_data/wiki/20220301.en.1gb'),
              ('base_model', 'bert-base-cased'),
              ('max_length', 128),
              ('vocab_size', 20000),
              ('tokenize_params', {'batched': True, 'num_proc': 4}),
              ('tokenizer_dir', '../_data/pretrain/tokenizer'),
              ('mlm_probability', 0.15),
              ('bert_config', {}),
              ( 'training_args',
                { 'eval_accumulation_steps': 10,
                  'evaluation_strategy': 'steps',
                  'logging_steps': 10000,
                  'num_train_epochs': 1,
                  'optim': 'adamw_torch',
                  'per_device_train_batch_size': 64,
                  'save_steps': 10000,
                  'save_total_limit': 3}),
              ('max_eval_samples', 5000),
              ('model_dir', '../_data/pretrain/model')])


# Load dataset

In [5]:
from datasets import Dataset, load_dataset

pm_log_section("Loading dataset")

# From here on `dataset` maps split name -> loaded split (it was a path string before).
dataset = {}
for split in ("train", "val"):
    # Use the plain JSON file when present, otherwise the gzipped one.
    data_file = path.join(ds_dir, f"{split}_data.json")
    if not path.isfile(data_file):
        data_file = f"{data_file}.gz"

    # The records are read from the JSON "data" field; the loader's result is indexed by "train".
    split_data = load_dataset("json", data_files=data_file, field="data")["train"]

    # Cap every non-training split at `max_eval_samples` examples (when a cap is set).
    if split != "train" and max_eval_samples is not None and len(split_data) > max_eval_samples:
        split_data = split_data.select(range(max_eval_samples))

    dataset[split] = split_data



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

# Tokenization

In [6]:
from transformers import BertTokenizerFast

# Announce the tokenization stage (prints a banner only when `_pm_log_sections` is set).
pm_log_section("Tokenizing")

In [7]:
def _iter_text_batches(ds, batch_size=1000):
    """Yield the `text` column of `ds` as consecutive lists of at most `batch_size` items.

    Streaming the corpus in slices avoids building one list with every training
    text in memory before tokenizer training starts.
    """
    for start in range(0, len(ds), batch_size):
        yield ds[start:start + batch_size]["text"]


if path.isdir(tokenizer_dir):
    # A tokenizer has already been saved: reuse it instead of retraining.
    tokenizer = BertTokenizerFast.from_pretrained(tokenizer_dir)
    # Apply the configured limit here too, so `max_length` is honoured even when
    # the cached tokenizer was saved with a different value.
    tokenizer.model_max_length = max_length
else:
    # Train a new vocabulary of `vocab_size` entries on the training texts,
    # starting from the base model's tokenizer.
    tokenizer = (BertTokenizerFast
                    .from_pretrained(base_model)
                    .train_new_from_iterator(_iter_text_batches(dataset["train"]), vocab_size))
    tokenizer.model_max_length = max_length

    tokenizer.save_pretrained(tokenizer_dir)

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [8]:
def tokenize_function(ex):
    """Run the notebook's `tokenizer` on the `text` field of `ex` with truncation enabled."""
    return tokenizer(ex["text"], truncation=True)


# Tokenize every split; all original columns are dropped so only tokenizer outputs remain.
tokenized_dataset = {
    split_name: split_data.map(
        tokenize_function,
        remove_columns=list(split_data.features),
        **tokenize_params,
    )
    for split_name, split_data in dataset.items()
}

      

#0:   0%|          | 0/2351 [00:00<?, ?ba/s]

  

#1:   0%|          | 0/2351 [00:00<?, ?ba/s]

#3:   0%|          | 0/2351 [00:00<?, ?ba/s]

#2:   0%|          | 0/2351 [00:00<?, ?ba/s]

       

#0:   0%|          | 0/2 [00:00<?, ?ba/s]

 

#1:   0%|          | 0/2 [00:00<?, ?ba/s]

#2:   0%|          | 0/2 [00:00<?, ?ba/s]

#3:   0%|          | 0/2 [00:00<?, ?ba/s]

# Train masked language model

In [9]:
import numpy as np
from transformers import (BertConfig,
                          BertForMaskedLM,
                          DataCollatorForLanguageModeling,
                          Trainer,
                          TrainingArguments)

pm_log_section("Training MLM")

# Collator that applies masking with probability `mlm_probability` when batching.
data_collator = DataCollatorForLanguageModeling(tokenizer = tokenizer,
                                                mlm_probability = mlm_probability)

# `bert_config` and `training_args` start out as plain dicts and are rebound here to
# the library objects. Convert only while they are still raw settings, so re-running
# this cell does not fail with "argument after ** must be a mapping".
if not isinstance(bert_config, BertConfig):
    bert_config = BertConfig(vocab_size = tokenizer.vocab_size, **bert_config)

# Model built from the config alone (no pretrained weights are loaded here).
model = BertForMaskedLM(config = bert_config)

if not isinstance(training_args, TrainingArguments):
    training_args = TrainingArguments(output_dir = model_dir,
                                      overwrite_output_dir = True,
                                      **training_args)

# Optional masked-token accuracy metric; currently disabled (see the Trainer call below).
# def compute_metrics(eval_preds):
#     idxs0, idxs1 = np.where(eval_preds.label_ids!=-100)

#     preds = np.argmax(eval_preds.predictions[idxs0, idxs1, :], axis=-1)
#     labels = eval_preds.label_ids[idxs0, idxs1]

#     acc = (preds==labels).sum()/len(preds)

#     return {"accuracy": acc}

trainer = Trainer(model = model,
                  args = training_args,
                  data_collator = data_collator,
                  # compute_metrics=compute_metrics,
                  train_dataset = tokenized_dataset["train"],
                  eval_dataset = tokenized_dataset["val"])

In [None]:
# Run the training loop configured by `training_args`.
trainer.train()
# Write the final model to `model_dir`, the directory the evaluation section loads from.
trainer.save_model(model_dir)

***** Running training *****
  Num examples = 9403586
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 146932
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
10000,5.9727,4.935154
20000,4.7029,4.35125
30000,4.2704,4.028666
40000,4.0094,3.825895
50000,3.8189,3.638723
60000,3.6626,3.555936
70000,3.5422,3.428748
80000,3.4416,3.319535
90000,3.3547,3.251397
100000,3.2839,3.193127


***** Running Evaluation *****
  Num examples = 5000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-10000
Configuration saved in ../_data/pretrain/model/checkpoint-10000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-10000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-20000
Configuration saved in ../_data/pretrain/model/checkpoint-20000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-20000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-30000
Configuration saved in ../_data/pretrain/model/checkpoint-30000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-30000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 8
Saving model checkpoint to ../_data/pretra

Step,Training Loss,Validation Loss
10000,5.9727,4.935154
20000,4.7029,4.35125
30000,4.2704,4.028666
40000,4.0094,3.825895
50000,3.8189,3.638723
60000,3.6626,3.555936
70000,3.5422,3.428748
80000,3.4416,3.319535
90000,3.3547,3.251397
100000,3.2839,3.193127


***** Running Evaluation *****
  Num examples = 5000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-140000
Configuration saved in ../_data/pretrain/model/checkpoint-140000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-140000/pytorch_model.bin
Deleting older checkpoint [../_data/pretrain/model/checkpoint-110000] due to args.save_total_limit


Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ../_data/pretrain/model
Configuration saved in ../_data/pretrain/model/config.json
Model weights saved in ../_data/pretrain/model/pytorch_model.bin


# Evaluation

In [10]:
from datasets import load_dataset
from transformers import BertTokenizerFast

pm_log_section("Evaluating MLM")

# Load the validation split again, this time without the `max_eval_samples` cap.
data_file = path.join(ds_dir, "val_data.json")
if not path.isfile(data_file):
    # Fall back to the gzipped file when the plain JSON is absent.
    data_file = data_file + ".gz"

val_dataset = load_dataset("json", data_files=data_file, field="data")["train"]

# Tokenize with the same function and map settings as the training data,
# dropping all original columns.
val_columns = list(val_dataset.features)
tokenized_val_dataset = val_dataset.map(tokenize_function,
                                        remove_columns=val_columns,
                                        **tokenize_params)



  0%|          | 0/1 [00:00<?, ?it/s]

loading file vocab.txt
loading file tokenizer.json
loading file added_tokens.json
loading file special_tokens_map.json
loading file tokenizer_config.json


        

#0:   0%|          | 0/294 [00:00<?, ?ba/s]

#1:   0%|          | 0/294 [00:00<?, ?ba/s]

#3:   0%|          | 0/294 [00:00<?, ?ba/s]

#2:   0%|          | 0/294 [00:00<?, ?ba/s]

In [11]:
import numpy as np
import torch
from transformers import (BertConfig,
                          BertForMaskedLM,
                          DataCollatorForLanguageModeling,
                          Trainer,
                          TrainingArguments)


# Reload the model that was saved to `model_dir`.
model = BertForMaskedLM.from_pretrained(model_dir)

# Evaluation-only trainer: reuses `training_args` and `data_collator` from the
# training section, with the tokenized validation split as its eval dataset.
eval_trainer_kwargs = dict(
    model=model,
    args=training_args,
    data_collator=data_collator,
    eval_dataset=tokenized_val_dataset,
)
trainer = Trainer(**eval_trainer_kwargs)

loading configuration file ../_data/pretrain/model/config.json
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.22.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 20000
}

loading weights file ../_data/pretrain/model/pytorch_model.bin
All model checkpoint weights were used when initializing BertForMaskedLM.

All the weights of BertForMaskedLM were initialized from the model checkpoint at ../_data/pretrain/model.
If your task is similar to the task the model of the checkpoint was trained on,

In [12]:
# Evaluate on `tokenized_val_dataset`; the returned metrics are shown as the cell output.
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 1174433
  Batch size = 8
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 3.2420713901519775,
 'eval_runtime': 2554.7871,
 'eval_samples_per_second': 459.699,
 'eval_steps_per_second': 57.463}

## Random examples

In [103]:
# Seed both generators: NumPy selects the examples below, and the collator's random
# masking is assumed to draw from torch's global RNG (TODO confirm for the installed
# transformers version), so the masked inputs are repeatable as well.
np.random.seed(0)
torch.manual_seed(0)

# Pick 5 random validation examples.
sample_idxs = np.random.permutation(len(tokenized_val_dataset))[:5]
samples = [tokenized_val_dataset[int(i)] for i in sample_idxs]

# Collate the examples and move the batch to whichever device the model lives on,
# rather than assuming a GPU is present.
inputs = {k: v.to(model.device) for k, v in data_collator(samples).items()}

# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    preds = torch.argmax(model(**inputs).logits.cpu(), dim=-1)

In [104]:
decode_kwargs = dict(
    skip_special_tokens=False
)

# For each example print three lines: the original text, the collated model input,
# and the model's per-position prediction.
# The loop variable is `masked_ids` rather than `input`, so the builtin `input`
# is not shadowed in the notebook namespace.
for sample, masked_ids, pred in zip(samples, inputs["input_ids"], preds):
    len_sample = len(sample["input_ids"])
    # Keep positions 1 .. len_sample-2: drops the first and last token of the
    # unpadded example (presumably [CLS]/[SEP]) and any padding after it.
    body = slice(1, len_sample - 1)
    pp.pprint(tokenizer.decode(sample["input_ids"][body], **decode_kwargs))
    pp.pprint(tokenizer.decode(masked_ids[body], **decode_kwargs))
    pp.pprint(tokenizer.decode(pred[body], **decode_kwargs))
    print()

'Handball players at the 2016 Summer Olympics'
'Handball players at the 2016 Summer Olympics'
'Handball players at the 2016 Summer Olympics'

('Of special interest is the six petal rosette derived from the " seven overlapping circles " pattern, also known as " '
 'Sun of the Alps " from its frequent use in alpine folk art in the 17th and 18th century.')
('Of special [MASK] is the six petal rosette derived from the " [MASK] overlapping circles " pattern, also [MASK] as " '
 'Sun of [MASK] Alps " from [MASK] frequent use in alpine folk art in the 17thther 18th century.')
('Of special names is the six petal rosette derived from the " two overlapping circles " pattern, also known as " Sun '
 'of the Alps " from its frequent use in alpine folk art in the 17th and 18th century.')

("For example, in Scholasticism, it was believed that God was capable of performing any miracle so long as it didn't "
 'lead to a logical contradiction.')
("For example, in Scholasticism, it was believed that God 

In [77]:
import torch

# Remove any notebook-level variable named `input`, which would shadow the builtin.
# Only the "no such variable" case is ignored; other errors are no longer swallowed.
try:
    del input
except NameError:
    pass

# Ask PyTorch to release cached GPU memory that is no longer in use.
torch.cuda.empty_cache()