# Extending BERT to another language/domain

Based on:

- https://arxiv.org/abs/2112.14569
- https://github.com/huggingface/transformers/issues/2691
- https://huggingface.co/docs/tokenizers/main/en/quicktour#training-the-tokenizer

In [26]:
# Hugging Face model ID of the pretrained checkpoint that gets extended.
BASE_MODEL = "bert-base-multilingual-uncased"
PUSH_TO_HUB = True # push to Hugging Face Hub (requires logging in)
# Directory whose files are used to train the WordPiece tokenizer.
TRAIN_DATA_DIR = "data/txt"
# Directory whose files are used for the BERT fine-tuning; same corpus by default.
BERT_TRAIN_DATA_DIR = TRAIN_DATA_DIR
VOCAB_SIZE = 32000 # size of the (WordPiece) vocabulary inferred from training data

In [11]:
import glob

# Every entry directly under TRAIN_DATA_DIR is used as tokenizer training input.
train_files = glob.glob(f"{TRAIN_DATA_DIR}/*")

#### 1. Build a new WordPiece vocabulary for the target corpus

In [None]:
import os
from tokenizers import processors, BertWordPieceTokenizer

special_tokens = ["[UNK]", "[PAD]", "[SEP]", "[MASK]", "[CLS]"]

# Case and diacritics are kept: lowercasing and accent stripping are both off.
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=False,
    strip_accents=False,
    lowercase=False,
)
tokenizer.train(
    files=train_files,
    vocab_size=VOCAB_SIZE,
    min_frequency=2,
    limit_alphabet=1000,
    wordpieces_prefix="##",
    special_tokens=special_tokens,
)
# A freshly trained tokenizer has no post-processor, so add the BERT template
# ([CLS] A [SEP] / [CLS] A [SEP] B [SEP]) explicitly. The IDs are read from the
# trained vocabulary rather than assumed from the position in `special_tokens`.
tokenizer.post_processor = processors.TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[
        (token_str, tokenizer.token_to_id(token_str))
        for token_str in special_tokens[1:]
    ],
)

# Save both the single-file JSON tokenizer and the vocab directory; the
# directory is what the `transformers` tokenizers are loaded from later.
tok_save_dir = f"tokenizer/tokenizer-trained{VOCAB_SIZE}"
os.makedirs(tok_save_dir, exist_ok=True)
tokenizer.save(f"{tok_save_dir}.json")
tokenizer.save_model(tok_save_dir)

In [33]:
# CHECK TOKENIZATION: encoding and decoding again must give back the exact
# input text (diacritics and capitalization must survive).
sample = "Одӥг гужем ӝытэ со чорыганы ӝутказ Кам дуре. Татарстанысь Голюшурмае вуим но, бордамы эшшо олокӧня"
result = tokenizer.encode(sample, add_special_tokens=True)
print(sample)
print(result.tokens)
decoded = tokenizer.decode(result.ids)
print(decoded)
assert sample == decoded

Одӥг гужем ӝытэ со чорыганы ӝутказ Кам дуре. Татарстанысь Голюшурмае вуим но, бордамы эшшо олокӧня
['[CLS]', 'Одӥг', 'гужем', 'ӝытэ', 'со', 'чорыганы', 'ӝутказ', 'Кам', 'дуре', '.', 'Татарстанысь', 'Голюшур', '##ма', '##е', 'вуим', 'но', ',', 'бордамы', 'эшшо', 'олокӧня', '[SEP]']
Одӥг гужем ӝытэ со чорыганы ӝутказ Кам дуре. Татарстанысь Голюшурмае вуим но, бордамы эшшо олокӧня


#### 2. Match the two vocabularies

In [14]:
from transformers import AutoTokenizer, BertTokenizer

tok_old = AutoTokenizer.from_pretrained(BASE_MODEL)
tok_new = BertTokenizer.from_pretrained(tok_save_dir)

# Size of the base model's vocabulary; needed later to build the mapping matrices.
old_vocab = tok_old.get_vocab()
old_vocab_size = len(old_vocab)

# Dump both vocabularies as "<token>\t<id>" lines. UTF-8 is set explicitly
# because the tokens are non-ASCII and the platform default encoding may not
# be able to represent them.
os.makedirs("tokenizer/vocabs", exist_ok=True)
for vocab_path, vocab in (
    ("tokenizer/vocabs/original.vocab", old_vocab),
    ("tokenizer/vocabs/modified.vocab", tok_new.get_vocab()),
):
    with open(vocab_path, "w", encoding="utf-8") as wf:
        # Names other than `id` are used so the builtin is not shadowed
        # for the rest of the notebook session.
        wf.writelines(f"{token}\t{token_id}\n" for token, token_id in vocab.items())

In [None]:
from tokenizer.match_vocabs import update_dict, matcher

# Load the two vocab dumps ("<token>\t<id>" per line) through the project helper.
new_voc = update_dict(vocab_pth="tokenizer/vocabs/modified.vocab")
base_voc = update_dict(vocab_pth="tokenizer/vocabs/original.vocab")
# Match the base vocabulary against the new one with matcher variant 1.
# NOTE(review): presumably this produces "tokenizer/vocabs/merged_matcher_f1.tsv",
# the mapping file read when the embeddings are initialized -- confirm in
# tokenizer/match_vocabs.py.
matcher(base_voc, new_voc, out_vocab="tokenizer/vocabs/merged", matcher=1)
# Alternative matcher variant, kept for reference:
# matcher(base_voc, new_voc, out_vocab="tokenizer/vocabs/original.vocab", matcher=2)

#### 3. Initialize new embeddings for the BERT model using unified vocabulary

In [7]:
import torch
from transformers import BertConfig, BertForMaskedLM

from model_utils.modify_model import get_random_embeds, get_mapping_matrices

# Output directory for the cached random embeddings and the modified model.
os.makedirs("models", exist_ok=True)

In [None]:
import pickle

# Config used only to draw the random embeddings for the new vocabulary.
conf = BertConfig(vocab_size=VOCAB_SIZE)
conf.num_hidden_layers = 8

# Reuse previously drawn random embeddings when a cache file exists, so that
# repeated runs start from the same initialization.
random_embeds_path = f"./models/random_embeds_{VOCAB_SIZE}.pkl"
try:
    # NOTE: unpickling can execute arbitrary code; this must only ever read
    # a cache file that this notebook wrote itself.
    with open(random_embeds_path, "rb") as cache_file:
        random_embeds = pickle.load(cache_file)
    print("Loaded saved random embeds")
except (FileNotFoundError, EOFError, pickle.UnpicklingError):
    # Missing or unreadable cache: draw new embeddings and store them.
    random_embeds = get_random_embeds(conf)
    with open(random_embeds_path, "wb") as cache_file:
        pickle.dump(random_embeds, cache_file)

model = BertForMaskedLM.from_pretrained(BASE_MODEL)
mapping_matrix, mask_matrix = get_mapping_matrices(
    mapping_file="tokenizer/vocabs/merged_matcher_f1.tsv",
    new_vocab_size=VOCAB_SIZE,
    old_vocab_size=old_vocab_size,
    use_bad_shift=False,
    use_one_to_one=False,
)
# Only the resulting values are needed, so no autograd graph is recorded.
with torch.no_grad():
    # Project the pretrained embeddings onto the new vocabulary ...
    new_embeds = mapping_matrix.matmul(model.bert.embeddings.word_embeddings.weight)
    # ... and add the random embeddings scaled by (1 - mask_matrix).
    # NOTE(review): presumably mask_matrix is 1 for new tokens that have a
    # match in the old vocabulary and 0 otherwise -- confirm in modify_model.
    new_embed_matrix = (1. - mask_matrix) * random_embeds + new_embeds
model.bert.embeddings.word_embeddings = torch.nn.Embedding.from_pretrained(new_embed_matrix, freeze=False)
model.config.vocab_size = VOCAB_SIZE
# NOTE(review): this brings the MLM output layer in line with the new vocabulary
# size; check whether the decoder bias also needs remapping to the new token order.
model.resize_token_embeddings(VOCAB_SIZE)
model.save_pretrained("./models/bert_modif_emb")
print(model.bert.embeddings.word_embeddings.weight.size())

#### 4. Fine-tune the new model (with the new vocab) on masked language modeling (MLM)

In [17]:
from transformers import BertTokenizerFast

# Reload the trained vocabulary as a fast `transformers` tokenizer. From here on
# `tokenizer` refers to this object, with lowercasing and accent stripping disabled.
tokenizer = BertTokenizerFast.from_pretrained(tok_save_dir, do_lower_case=False, strip_accents=False)

In [25]:
# CHECK TOKENIZATION: with special tokens skipped, decoding must give back the
# exact input text (diacritics and capitalization must survive).
sample = "Одӥг гужем ӝытэ со чорыганы ӝутказ Кам дуре. Татарстанысь Голюшурмае вуим но, бордамы эшшо олокӧня"
result = tokenizer(sample, add_special_tokens=True)
print(sample)
print(result.tokens())
decoded = tokenizer.decode(result["input_ids"], skip_special_tokens=True)
print(decoded)
assert sample == decoded

Одӥг гужем ӝытэ со чорыганы ӝутказ Кам дуре. Татарстанысь Голюшурмае вуим но, бордамы эшшо олокӧня
['[CLS]', 'Одӥг', 'гужем', 'ӝытэ', 'со', 'чорыганы', 'ӝутказ', 'Кам', 'дуре', '.', 'Татарстанысь', 'Голюшур', '##ма', '##е', 'вуим', 'но', ',', 'бордамы', 'эшшо', 'олокӧня', '[SEP]']
Одӥг гужем ӝытэ со чорыганы ӝутказ Кам дуре. Татарстанысь Голюшурмае вуим но, бордамы эшшо олокӧня


In [None]:
from datasets import load_dataset
# Load every file under BERT_TRAIN_DATA_DIR with the "text" dataset builder.
bert_train_files = glob.glob(f"{BERT_TRAIN_DATA_DIR}/*")
dataset = load_dataset("text", data_files=bert_train_files)

In [45]:
def tokenize_function(examples):
    """Run the notebook-level ``tokenizer`` on the ``text`` column of a batch."""
    return tokenizer(examples["text"])

# Tokenize in batches; the raw "text" column is removed from the result.
tokenized_datasets = dataset.map(
    tokenize_function, batched=True, remove_columns=['text']
)

  0%|          | 0/4 [00:00<?, ?ba/s]

In [46]:
def group_texts(examples, chunk_size=128):
    """Concatenate a batch of tokenized texts and cut it into equal chunks.

    Args:
        examples: Mapping from column name to a list of per-example
            sequences; must contain ``"input_ids"``.
        chunk_size: Length of each produced chunk.

    Returns:
        A dict with the same columns, each a list of ``chunk_size``-long
        lists, plus ``"labels"`` holding a copy of the ``"input_ids"`` chunk
        list. A trailing remainder shorter than ``chunk_size`` is dropped.

    Raises:
        KeyError: If ``examples`` has no ``"input_ids"`` column.
    """
    from itertools import chain

    # chain.from_iterable flattens in linear time, whereas sum(lists, [])
    # re-copies the accumulated list on every addition.
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    # Every column is cut at the length of the first column, rounded down to
    # a whole number of chunks.
    first_column = next(iter(concatenated_examples.values()), [])
    total_length = (len(first_column) // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Regroup the tokenized examples, batch by batch, into fixed-size chunks with labels.
lm_datasets = tokenized_datasets.map(group_texts, batched=True)

  0%|          | 0/4 [00:00<?, ?ba/s]

In [60]:
# True selects DataCollatorForWholeWordMask, False DataCollatorForLanguageModeling.
whole_words_masking = True
# Passed to the chosen data collator as `mlm_probability`.
mlm_probability = 0.15

In [66]:
from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask

# Choose the collator class first; both variants take the same arguments.
if whole_words_masking:
    tokenizer.mask_token = "[MASK]"
    collator_cls = DataCollatorForWholeWordMask
else:
    collator_cls = DataCollatorForLanguageModeling
data_collator = collator_cls(tokenizer=tokenizer, mlm_probability=mlm_probability)

In [None]:
# these are custom sizes: the chunked corpus is divided into 11 equal parts,
# 10 of them go to training and a tenth of the training size to testing
n_rows = lm_datasets.num_rows["train"]
test_ratio = n_rows // 11

train_size = 10 * test_ratio
test_size = int(0.1 * train_size)

downsampled_dataset = lm_datasets["train"].train_test_split(
    train_size=train_size,
    test_size=test_size,
    seed=42,
)
print(f"Train size: {len(downsampled_dataset['train'])}\nTest size:  {len(downsampled_dataset['test'])}")

In [None]:
import transformers
# Raise the transformers log level to INFO.
transformers.logging.set_verbosity_info()

In [None]:
from huggingface_hub import notebook_login
# Log in to the Hugging Face Hub from within the notebook.
notebook_login()

In [None]:
import shutil
import os
# Helper method to run Tensorboard in Colab
def reinit_tensorboard(logs_base_dir, clear_log = True):
    """Show TensorBoard for ``logs_base_dir`` inside the notebook.

    Relies on IPython magics, so it only works in a notebook session.

    Args:
        logs_base_dir: Directory passed to ``%tensorboard --logdir``.
        clear_log: When true, delete ``logs_base_dir`` (ignoring errors) and
            recreate it empty before TensorBoard is started.
    """
    if clear_log:    
        shutil.rmtree(logs_base_dir, ignore_errors = True)
        os.makedirs(logs_base_dir, exist_ok=True)
    # Colab magic
    %reload_ext tensorboard
    %tensorboard --logdir {logs_base_dir} --reuse_port False

In [None]:
from transformers import TrainingArguments

batch_size = 20
# Show the training loss with every epoch. Clamped to at least 1 because a
# train split smaller than one batch would otherwise give logging_steps == 0.
logging_steps = max(1, len(downsampled_dataset["train"]) // batch_size)
model_name = BASE_MODEL.split("/")[-1]

# Output directory of the run; TensorBoard is pointed at its "runs" subfolder
# and existing logs are kept.
model_dir = f"vocab2-{model_name}-udm-tsa"
reinit_tensorboard(f"{model_dir}/runs", clear_log=False)

In [None]:
# Evaluation and checkpointing both happen once per epoch.
training_args = TrainingArguments(
    output_dir=model_dir,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    num_train_epochs=5,
    learning_rate=1e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Follow the notebook-level switch instead of always pushing.
    push_to_hub=PUSH_TO_HUB,
    # Mixed precision needs a CUDA device; fall back to fp32 elsewhere.
    fp16=torch.cuda.is_available(),
    report_to="tensorboard",
    logging_steps=logging_steps,
    save_strategy="epoch",
    hub_strategy="all_checkpoints",
    warmup_ratio=0.1,
    # resume_from_checkpoint=""
)

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    data_collator=data_collator,
    # Hand the tokenizer over so the new vocabulary is saved (and pushed)
    # with every checkpoint; without it the saved model has no matching
    # tokenizer files.
    tokenizer=tokenizer,
)

In [None]:
# Start the fine-tuning run configured by `training_args`.
trainer.train()