# Setup

In [None]:
from os import path
from typing import Optional

# Directory used as the anchor for all relative paths below.
# IPython keeps its directory history in the `_dh` global; its first entry is
# used here. Outside IPython (e.g. the notebook exported to a plain script)
# `_dh` does not exist and a direct lookup would raise KeyError, so fall back
# to the current working directory.
__DIR__ = globals().get("_dh", [path.abspath(path.curdir)])[0]
# Data lives in the sibling `_data` folder, expressed relative to the cwd.
data_dir = path.relpath(path.join(__DIR__, "..", "_data"))

In [None]:
# ---- Settings ---------------------------------------------------------------
_testing = False            # True -> apply the small test overrides in the next cell
_colab_install = True       # True -> attempt the Colab install script
_pm_log_sections = False    # True -> pm_log_section prints section banners

# ---- Parameters -------------------------------------------------------------
# Corpus directory (expected to hold the per-split JSON files).
dataset = path.join(data_dir, "wiki", "20220301.en.1gb")

# Tokenizer / model sizing.
base_model = "bert-base-cased"
max_length = 128
vocab_size = 20_000

# Extra keyword arguments forwarded to the dataset `map` call during tokenization.
tokenize_params = {"batched": True, "num_proc": 4}
tokenizer_dir = path.join(data_dir, "pretrain", "tokenizer")

# Masked-language-model training.
mlm_probability = 0.15
bert_config = {}            # extra keyword arguments for BertConfig
training_args = {           # keyword arguments for TrainingArguments
    "optim": "adamw_torch",
    "num_train_epochs": 3,
    "per_device_train_batch_size": 64,
    "eval_accumulation_steps": 10,
    "evaluation_strategy": "steps",
    "logging_steps": 5000,
    "save_steps": 5000,
    "save_total_limit": 5,
}
# Upper bound on the number of evaluation examples; None disables the cap.
max_eval_samples: Optional[int] = 2000
model_dir = path.join(data_dir, "pretrain", "model")

In [None]:
if _testing:
    # Testing mode: point at the test corpus and shrink the run so it
    # finishes after a handful of optimizer steps.
    dataset = path.join(data_dir, "wiki", "20220301.en.test")

    training_args.update(max_steps=3, logging_steps=1)

    max_eval_samples = 1000

## Process settings / parameters

In [None]:
from pprint import pprint
from collections import OrderedDict

if _colab_install:
    try:
        import google.colab
        
        colab_install_script = path.join(__DIR__, "..", "colab_install.sh")

        if not path.isfile(colab_install_script):
            script_url = "https://raw.githubusercontent.com/yenson-lau/pii-remediation/main/colab_install.sh"
            !wget $script_url -O $colab_install_script

        !bash $colab_install_script

        print()

    except ModuleNotFoundError:
        pass

# The flag is consulted once, when this cell runs, to pick the implementation.
if not _pm_log_sections:
    def pm_log_section(message):
        """Section logging is disabled: accept `message` and do nothing."""
        return None
else:
    def pm_log_section(message):
        """Print `message` as a section banner framed by blank lines."""
        banner = f"\n[===== {message} =====]\n"
        print(banner)

if _testing:
    pm_log_section("Running on testing mode")

# Gather every setting into one ordered mapping so the effective configuration
# can be printed as a single record.
config = OrderedDict([
    ("dataset", dataset),

    ("base_model", base_model),
    ("max_length", max_length),
    ("vocab_size", vocab_size),

    ("tokenize_params", tokenize_params),
    ("tokenizer_dir", tokenizer_dir),

    ("mlm_probability", mlm_probability),
    ("bert_config", bert_config),
    ("training_args", training_args),
    ("max_eval_samples", max_eval_samples),
    ("model_dir", model_dir),
])

# Header is prefixed with "TESTING " when the test overrides are active.
print(("TESTING " if _testing else "") + "Parameters:")
pprint(config, indent=2)


Parameters:
OrderedDict([ ('dataset', '../_data/wiki/20220301.en.1gb'),
              ('base_model', 'bert-base-cased'),
              ('max_length', 128),
              ('vocab_size', 20000),
              ('tokenize_params', {'batched': True, 'num_proc': 4}),
              ('tokenizer_dir', '../_data/pretrain/tokenizer'),
              ('mlm_probability', 0.15),
              ('bert_config', {}),
              ( 'training_args',
                { 'eval_accumulation_steps': 10,
                  'evaluation_strategy': 'steps',
                  'logging_steps': 5000,
                  'num_train_epochs': 3,
                  'optim': 'adamw_torch',
                  'per_device_train_batch_size': 64,
                  'save_steps': 5000,
                  'save_total_limit': 5}),
              ('max_eval_samples', 2000),
              ('model_dir', '../_data/pretrain/model')])


# Load dataset

In [None]:
from datasets import Dataset, load_dataset

pm_log_section("Loading dataset")

# Take the corpus directory from `config` rather than from `dataset`: this cell
# rebinds `dataset` to the dict of loaded splits, so reading the path from that
# name would raise a TypeError as soon as the cell is executed a second time.
ds_dir = config["dataset"]
dataset = dict()
for split in ["train", "val"]:
    # Prefer the plain JSON file; otherwise try the ".gz" variant of the same name.
    data_file = path.join(ds_dir, f"{split}_data.json")
    if not path.isfile(data_file):
        data_file += ".gz"
    # Records are read from the file's "data" field; the loaded result is
    # taken from the "train" key of what `load_dataset` returns.
    dataset[split] = load_dataset("json", data_files=data_file, field="data")["train"]

    # Cap only the non-training splits, and only when a cap is configured
    # and the split is actually larger than it.
    if ((split != "train")
        and (max_eval_samples is not None)
        and (len(dataset[split]) > max_eval_samples)):

        dataset[split] = dataset[split].select(range(max_eval_samples))



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

# Tokenization

In [None]:
from transformers import BertTokenizerFast

# Announce the tokenization stage (prints nothing unless section logging is enabled).
pm_log_section("Tokenizing")

In [None]:
def _iter_text_batches(ds, batch_size=1000):
    """Yield the "text" column of `ds` as consecutive lists of `batch_size` rows.

    Streaming slices avoids building one Python list that holds every
    training text at once.
    """
    for start in range(0, len(ds), batch_size):
        yield ds[start : start + batch_size]["text"]

# Start from the base model's tokenizer and retrain it on the training texts
# with the configured vocabulary size. `length` is the total number of texts
# the iterator will produce.
tokenizer = (BertTokenizerFast
                .from_pretrained(base_model)
                .train_new_from_iterator(_iter_text_batches(dataset["train"]),
                                         vocab_size,
                                         length=len(dataset["train"])))
tokenizer.model_max_length = max_length

# Trailing ";" suppresses the cell's output of the returned value.
tokenizer.save_pretrained(tokenizer_dir);

In [None]:
def tokenize_function(ex):
    """Run the tokenizer on the "text" field of `ex` with truncation enabled."""
    return tokenizer(ex["text"], truncation=True)

# Tokenize every split, dropping all of its original columns so that only
# the tokenizer's outputs remain.
tokenized_dataset = {}
for split_name, split_ds in dataset.items():
    tokenized_dataset[split_name] = split_ds.map(
        tokenize_function,
        remove_columns=list(split_ds.features),
        **tokenize_params,
    )

        

#0:   0%|          | 0/2351 [00:00<?, ?ba/s]

#1:   0%|          | 0/2351 [00:00<?, ?ba/s]

#2:   0%|          | 0/2351 [00:00<?, ?ba/s]

#3:   0%|          | 0/2351 [00:00<?, ?ba/s]

        

#0:   0%|          | 0/1 [00:00<?, ?ba/s]

#1:   0%|          | 0/1 [00:00<?, ?ba/s]

#2:   0%|          | 0/1 [00:00<?, ?ba/s]

#3:   0%|          | 0/1 [00:00<?, ?ba/s]

# Train masked language model

In [None]:
import numpy as np
from transformers import (BertConfig,
                          BertForMaskedLM,
                          DataCollatorForLanguageModeling,
                          Trainer,
                          TrainingArguments)

pm_log_section("Training MLM")

# Collator built from the trained tokenizer and the configured masking probability.
data_collator = DataCollatorForLanguageModeling(tokenizer = tokenizer,
                                                mlm_probability = mlm_probability)

# `bert_config` and `training_args` are rebound below from plain dicts to
# library objects. Unpack the keyword arguments from `config`, which still
# references the original dicts, so that re-running this cell does not try to
# `**`-unpack the objects created by the previous run.
bert_config = BertConfig(vocab_size = tokenizer.vocab_size, **config["bert_config"])
# The model is constructed from the config alone (no `from_pretrained` call).
model = BertForMaskedLM(config = bert_config)

training_args = TrainingArguments(output_dir = model_dir,
                                  overwrite_output_dir = True,
                                  **config["training_args"])

def compute_metrics(eval_preds):
    """Compute prediction accuracy over the positions that carry a label.

    Args:
        eval_preds: object exposing `predictions`, an array of scores indexed
            as [row, position, class], and `label_ids`, an array indexed as
            [row, position] in which the value -100 marks positions to skip.

    Returns:
        dict: ``{"accuracy": value}`` where ``value`` is a Python float giving
        the fraction of non-skipped positions whose highest-scoring class
        equals the label. It is NaN when every position is skipped, since the
        ratio is undefined (this avoids a division by zero and its warning).
    """
    labels = eval_preds.label_ids
    scored = labels != -100

    n_scored = int(scored.sum())
    if n_scored == 0:
        return {"accuracy": float("nan")}

    # Boolean-mask indexing keeps only the score vectors of scored positions.
    preds = np.argmax(eval_preds.predictions[scored], axis=-1)
    n_correct = int((preds == labels[scored]).sum())

    return {"accuracy": n_correct / n_scored}

In [None]:
# Wire the model, training arguments, collator, metric function and the
# tokenized train/val splits into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# NOTE(review): `train()` is called without a checkpoint to resume from, so an
# interrupted run presumably starts over on re-execution — confirm whether
# resuming from the saved checkpoints is wanted here.
trainer.train()
# Save the final model into the same directory used for the checkpoints.
trainer.save_model(model_dir)

***** Running training *****
  Num examples = 9403586
  Num Epochs = 3
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 440796
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Accuracy
5000,6.5319,5.548443,0.264958
10000,5.304,4.949771,0.301118
15000,4.8247,4.632231,0.322211
20000,4.5332,4.409497,0.345477
25000,4.3229,4.217476,0.35553
30000,4.1632,4.06336,0.368654
35000,4.0361,4.071036,0.362708
40000,3.9245,3.822901,0.392621
45000,3.8208,3.748804,0.399106
50000,3.7349,3.69564,0.407196


***** Running Evaluation *****
  Num examples = 2000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-5000
Configuration saved in ../_data/pretrain/model/checkpoint-5000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-5000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-10000
Configuration saved in ../_data/pretrain/model/checkpoint-10000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-10000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-15000
Configuration saved in ../_data/pretrain/model/checkpoint-15000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-15000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/

Step,Training Loss,Validation Loss,Accuracy
5000,6.5319,5.548443,0.264958
10000,5.304,4.949771,0.301118
15000,4.8247,4.632231,0.322211
20000,4.5332,4.409497,0.345477
25000,4.3229,4.217476,0.35553
30000,4.1632,4.06336,0.368654
35000,4.0361,4.071036,0.362708
40000,3.9245,3.822901,0.392621
45000,3.8208,3.748804,0.399106
50000,3.7349,3.69564,0.407196


***** Running Evaluation *****
  Num examples = 2000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-125000
Configuration saved in ../_data/pretrain/model/checkpoint-125000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-125000/pytorch_model.bin
Deleting older checkpoint [../_data/pretrain/model/checkpoint-100000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-130000
Configuration saved in ../_data/pretrain/model/checkpoint-130000/config.json
Model weights saved in ../_data/pretrain/model/checkpoint-130000/pytorch_model.bin
Deleting older checkpoint [../_data/pretrain/model/checkpoint-105000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 8
Saving model checkpoint to ../_data/pretrain/model/checkpoint-135000
Configuration saved in ../_data/pretrain/model/checkpoint-135000/

The Colab runtime was terminated at this point, so the training log above is incomplete.