import sys

In [1]:
import sys
sys.path.append('../../thai2transformers')

In [2]:
import torch
import transformers
from transformers import AutoTokenizer , DataCollatorForLanguageModeling



In [3]:
transformers.__version__

'4.6.1'

In [4]:
tokenizer = AutoTokenizer.from_pretrained('airesearchth/wangchanberta-base-wiki-20210520-spm')

In [5]:
import math
from typing import List, Dict, Union, Optional, Tuple, Any
from dataclasses import dataclass
import numpy as np
import torch
import random
from bisect import bisect
from transformers.data.data_collator import DataCollatorForLanguageModeling, _collate_batch, tolist
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

SPECIAL_TOKEN_NAMES = ['bos_token', 'eos_token', 'sep_token', 'cls_token', 'pad_token']

@dataclass
class DataCollatorForSpanLevelMask(DataCollatorForLanguageModeling):
    """
    Data collator used for span-level masked language modeling
     
    adapted from NGramMaskGenerator class
    
    https://github.com/microsoft/DeBERTa/blob/11fa20141d9700ba2272b38f2d5fce33d981438b/DeBERTa/apps/tasks/mlm_task.py#L36
    and
    https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py

    """
    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    max_gram: int = 3
    keep_prob: float = 0.0
    mask_prob: float = 1.0
    max_preds_per_seq: int = None
    max_seq_len: int = 510

    def __init__(self, tokenizer, mlm=True, mlm_probability=0.15, *args, **kwargs):
        super().__init__(tokenizer, mlm=mlm, mlm_probability=mlm_probability)

        assert self.mask_prob + self.keep_prob <= 1, \
            f'The prob of using [MASK]({self.mask_prob}) and the prob of using original token({self.keep_prob}) should between [0,1]'

        if self.max_preds_per_seq is None:
            self.max_preds_per_seq = math.ceil(self.max_seq_len * self.mlm_probability / 10) * 10
            self.mask_window = int(1 / self.mlm_probability) # make ngrams per window sized context
        self.vocab_words = list(self.tokenizer.get_vocab().keys())
        self.vocab_mapping = self.tokenizer.get_vocab()
        
        self.special_tokens = [self.tokenizer.special_tokens_map[name] for name in  SPECIAL_TOKEN_NAMES]
#         print(' self.special_tokens', self.special_tokens)
        self.ngrams = np.arange(1, self.max_gram + 1, dtype=np.int64)
        _pvals = 1. / np.arange(1, self.max_gram + 1)
        self.pvals = _pvals / _pvals.sum(keepdims=True)

    def _choice(self, rng, data, p):
        cul = np.cumsum(p)
        x = rng.random()*cul[-1]
        id = bisect(cul, x)
        return data[id]

    def _per_token_mask(self, idx, tokens, rng, mask_prob, keep_prob):
        label = tokens[idx]
        mask = self.tokenizer.mask_token
        rand = rng.random()
        if rand < mask_prob:
            new_label = mask
        elif rand < mask_prob + keep_prob:
            new_label = label
        else:
            new_label = rng.choice(self.vocab_words)

        tokens[idx] = new_label

        return label

    def _mask_tokens(self, tokens: List[str], rng=random, **kwargs):

        indices = [i for i in range(len(tokens)) if tokens[i] not in self.special_tokens]
#         print('debug: indices to be able to be masked', indices)
        
        unigrams = [ [idx] for idx in indices ]
        num_to_predict = min(self.max_preds_per_seq, max(1, int(round(len(tokens) * self.mlm_probability))))
           
        offset = 0
        mask_grams = np.array([False]*len(unigrams))
        while offset < len(unigrams):
            n = self._choice(rng, self.ngrams, p=self.pvals)
            ctx_size = min(n * self.mask_window, len(unigrams)-offset)
            m = rng.randint(0, ctx_size-1)
            s = offset + m
            e = min(offset + m + n, len(unigrams))
            offset = max(offset+ctx_size, e)
            mask_grams[s:e] = True

        target_labels = [None]*len(tokens)
        w_cnt = 0
        for m,word in zip(mask_grams, unigrams):
            if m:
                for idx in word:
                    label = self._per_token_mask(idx, tokens, rng, self.mask_prob, self.keep_prob)
                    target_labels[idx] = label
                    w_cnt += 1
                if w_cnt >= num_to_predict:
                    break

        target_labels = [self.vocab_mapping[x] if x else -100 for x in target_labels]
        return tokens, target_labels    


    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = []
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probabilityability`)
        # probability_matrix = torch.full(labels.shape, self.mlm_probabilityability)
        # if special_tokens_mask is None:
        #     special_tokens_mask = [
        #         self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        #     ]
        #     special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        # else:
        #     special_tokens_mask = special_tokens_mask.bool()

#         print('inputs', inputs.shape, inputs)
        inputs_masked = []
        
        for i, input in enumerate(inputs):
#             print('input',input)
            input_tokens = self.tokenizer.convert_ids_to_tokens(input)
            

            input_masked, _labels = self._mask_tokens(input_tokens)
#             print('DEBUG: input_masked', input_masked)
            input_masked_ids = self.tokenizer.convert_tokens_to_ids(input_masked)
            inputs_masked.append(input_masked_ids)
#             print('_labels, ', _labels)
#             print('inputs_masked, ', input_masked_ids)
            labels.append(_labels)
      
        return inputs_masked, labels


In [6]:
probability_matrix = torch.full((5,5), 0.5)
probability_matrix


tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])

In [7]:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices

tensor([[False, False, False,  True, False],
        [ True,  True,  True,  True, False],
        [ True,  True, False, False,  True],
        [ True, False,  True, False, False],
        [ True,  True, False, False,  True]])

In [8]:
tokenizer.special_tokens_map

{'bos_token': '<s>',
 'eos_token': '</s>',
 'unk_token': '<unk>',
 'sep_token': '</s>',
 'pad_token': '<pad>',
 'cls_token': '<s>',
 'mask_token': '<mask>',
 'additional_special_tokens': "['<s>NOTUSED', '</s>NOTUSED', '▁']"}

In [9]:
_data_collator = DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              keep_prob=0.0,
                                              mask_prob=1.0,
                                              max_seq_len=510)

In [10]:
# test_data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
#                                               mlm=True,
#                                               mlm_probability=0.15,
#                                               max_gram=3,
#                                               keep_prob=0.0,
#                                               mask_prob=1.0,
#                                               max_seq_len=510)

In [25]:
# text = "We present SpanBERT, a pre-training method that is designed to better represent and predict spans of text."
text = """ภาษาอินโด-ยูโรเปียนดั้งเดิม ภาษาอินโด-ยูโรเปียนดั้งเดิมคือภาษาดั้งเดิมที่เป็นที่สนใจมากที่สุด 
รวมถึงเข้าใจมากที่สุดอีกด้วย งานส่วนใหญ่ของนักภาษาศาสตร์ในช่วงศตวรรษที่ 19 มักจะเป็นเรื่องบูรณะภาษานี้ และภาษาลูกหลาน เช่นภาษาเจอร์แมนิกดั้งเดิม"""
tokens = tokenizer.tokenize(text)
print(tokens, len(tokens))


['▁', 'ภาษา', 'อินโด', '-', 'ยูโรเปียน', 'ดั้งเดิม', '▁', 'ภาษา', 'อินโด', '-', 'ยูโรเปียน', 'ดั้งเดิม', 'คือ', 'ภาษา', 'ดั้งเดิม', 'ที่เป็น', 'ที่สนใจ', 'มากที่สุด', '▁', 'รวมถึง', 'เข้าใจ', 'มากที่สุด', 'อีกด้วย', '▁', 'งาน', 'ส่วนใหญ่', 'ของนัก', 'ภาษาศาสตร์', 'ในช่วง', 'ศตวรรษที่', '▁', '19', '▁', 'มักจะเป็น', 'เรื่อง', 'บูรณะ', 'ภาษานี้', '▁', 'และภาษา', 'ลูกหลาน', '▁', 'เช่น', 'ภาษา', 'เจอร์แมนิก', 'ดั้งเดิม'] 45


In [26]:
inputs_1 = tokenizer.encode_plus(text, return_tensors='pt')['input_ids'].squeeze(0)
inputs_2 = tokenizer.encode_plus(text, return_tensors='pt')['input_ids'].squeeze(0)
inputs_1, inputs_2

(tensor([    5,     8,   213,  7203,    31, 12146,  1389,     8,   213,  7203,
            31, 12146,  1389,    33,   213,  1389,   328,  8624,   484,     8,
           383,  2477,   484,   535,     8,   166,   258,  6350,  8318,   254,
           472,     8,   368,     8,  7828,    85,  3757, 11278,     8,  4979,
          8452,     8,    61,   213, 23420,  1389,     6]),
 tensor([    5,     8,   213,  7203,    31, 12146,  1389,     8,   213,  7203,
            31, 12146,  1389,    33,   213,  1389,   328,  8624,   484,     8,
           383,  2477,   484,   535,     8,   166,   258,  6350,  8318,   254,
           472,     8,   368,     8,  7828,    85,  3757, 11278,     8,  4979,
          8452,     8,    61,   213, 23420,  1389,     6]))

In [27]:
(inputs_1, inputs_2)

(tensor([    5,     8,   213,  7203,    31, 12146,  1389,     8,   213,  7203,
            31, 12146,  1389,    33,   213,  1389,   328,  8624,   484,     8,
           383,  2477,   484,   535,     8,   166,   258,  6350,  8318,   254,
           472,     8,   368,     8,  7828,    85,  3757, 11278,     8,  4979,
          8452,     8,    61,   213, 23420,  1389,     6]),
 tensor([    5,     8,   213,  7203,    31, 12146,  1389,     8,   213,  7203,
            31, 12146,  1389,    33,   213,  1389,   328,  8624,   484,     8,
           383,  2477,   484,   535,     8,   166,   258,  6350,  8318,   254,
           472,     8,   368,     8,  7828,    85,  3757, 11278,     8,  4979,
          8452,     8,    61,   213, 23420,  1389,     6]))

In [28]:
# torch.stack((inputs['input_ids'],), dim=0)

In [29]:
res = _data_collator((inputs_1, inputs_2))
print(res.keys())

dict_keys(['input_ids', 'labels'])


In [30]:
print(res['input_ids'])

[[5, 8, 213, 7203, 31, 12146, 1389, 8, 24004, 24004, 31, 12146, 1389, 33, 213, 1389, 328, 8624, 24004, 8, 383, 2477, 484, 535, 8, 24004, 24004, 6350, 8318, 254, 472, 8, 368, 8, 7828, 85, 24004, 24004, 8, 4979, 8452, 8, 61, 213, 23420, 1389, 6], [5, 8, 213, 7203, 31, 24004, 1389, 8, 213, 24004, 31, 12146, 1389, 33, 213, 1389, 328, 8624, 24004, 8, 383, 2477, 484, 535, 8, 166, 258, 24004, 24004, 24004, 472, 8, 368, 8, 7828, 85, 3757, 11278, 8, 4979, 8452, 8, 24004, 213, 23420, 1389, 6]]


In [31]:
print(res['labels'])

[[-100, -100, -100, -100, -100, -100, -100, -100, 213, 7203, -100, -100, -100, -100, -100, -100, -100, -100, 484, -100, -100, -100, -100, -100, -100, 166, 258, -100, -100, -100, -100, -100, -100, -100, -100, -100, 3757, 11278, -100, -100, -100, -100, -100, -100, -100, -100, -100], [-100, -100, -100, -100, -100, 12146, -100, -100, -100, 7203, -100, -100, -100, -100, -100, -100, -100, -100, 484, -100, -100, -100, -100, -100, -100, -100, -100, 6350, 8318, 254, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 61, -100, -100, -100, -100]]


## Profiling Data Collator


In [32]:
%load_ext line_profiler
%load_ext memory_profiler


In [33]:
import glob, os
from torch.utils.data.dataloader import DataLoader

from torch.utils.data.sampler import RandomSampler, SequentialSampler


from thai2transformers.datasets import MLMDataset

In [34]:
TRAIN_DATA_PATH = '../../dataset/split/thwiki-for-ddp_6.11.2020/train'

In [35]:
glob.glob(os.path.join(TRAIN_DATA_PATH, '*'))

['../../dataset/split/thwiki-for-ddp_6.11.2020/train/train_debug.txt']

In [36]:
!wc -l ../../dataset/split/thwiki-for-ddp_6.11.2020/train/train_debug.txt

   20000 ../../dataset/split/thwiki-for-ddp_6.11.2020/train/train_debug.txt


In [37]:
%%time 

train_dataset = MLMDataset(tokenizer,
                           TRAIN_DATA_PATH,
                           510)


[INFO] Build features (parallel).

[INFO] Start groupping results.
[INFO] Done.
CPU times: user 324 ms, sys: 101 ms, total: 425 ms
Wall time: 8.27 s


### Data loader with bz=1

In [38]:
train_sampler = SequentialSampler(train_dataset)

In [44]:
data_collator_subword_mlm = DataCollatorForLanguageModeling(tokenizer,
                                                        pad_to_multiple_of=8)

data_loader_subword_mlm = DataLoader(
            train_dataset,
            batch_size=8,
            sampler=train_sampler,
            collate_fn=data_collator_subword_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

In [45]:
data_collator_span_mlm =  DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.15,
                                              max_gram=3,
                                              keep_prob=0.0,
                                              mask_prob=1.0,
                                              max_seq_len=510,
                                              pad_to_multiple_of=8)


data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=8,
            sampler=train_sampler,
            collate_fn=data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

In [46]:
%%timeit
list(data_loader_subword_mlm)
print('.', end='')

........1.54 s ± 63.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
%prun next(iter(data_loader_subword_mlm))

 

         529 function calls (526 primitive calls) in 0.002 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.001    0.001 data_collator.py:361(mask_tokens)
        3    0.000    0.000    0.000    0.000 {built-in method bernoulli}
        8    0.000    0.000    0.000    0.000 tokenization_utils_base.py:3128(<listcomp>)
        1    0.000    0.000    0.000    0.000 data_collator.py:195(_collate_batch)
        3    0.000    0.000    0.000    0.000 {method 'bool' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {built-in method empty}
        3    0.000    0.000    0.000    0.000 {built-in method full}
        1    0.000    0.000    0.000    0.000 {method 'tolist' of 'torch._C._TensorBase' objects}
        8    0.000    0.000    0.000    0.000 tokenization_utils_base.py:1225(all_special_tokens_extended)
       66    0.000    0.000    0.000    0.000 {method 'token_to

In [48]:
%%timeit
list(data_loader_span_mlm)
print('.', end='')

........20.1 s ± 672 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [49]:
%prun next(iter(data_loader_span_mlm))

 

         10404 function calls (10401 primitive calls) in 0.014 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        8    0.004    0.001    0.010    0.001 tokenization_utils_fast.py:275(convert_ids_to_tokens)
     1320    0.003    0.000    0.003    0.000 tensor.py:468(<lambda>)
     1312    0.002    0.000    0.002    0.000 {method 'id_to_token' of 'tokenizers.Tokenizer' objects}
        8    0.001    0.000    0.003    0.000 <ipython-input-20-f16f671911d4>:73(_mask_tokens)
     1313    0.001    0.000    0.001    0.000 {method 'token_to_id' of 'tokenizers.Tokenizer' objects}
        9    0.000    0.000    0.002    0.000 tokenization_utils_fast.py:220(convert_tokens_to_ids)
     1313    0.000    0.000    0.001    0.000 tokenization_utils_fast.py:242(_convert_token_to_id_with_added_voc)
     2648    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
        8    0.000    0.000    0.000    0.000 <ipython-inpu

### Data loader with bz=8

In [198]:
BZ=8

data_collator_subword_mlm = DataCollatorForLanguageModeling(tokenizer,
                                                        pad_to_multiple_of=8)

data_loader_subword_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=data_collator_subword_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

data_collator_span_mlm =  DataCollatorForSpanLevelMask(tokenizer=tokenizer,
                                              mlm=True,
                                              mlm_probability=0.5,
                                              max_gram=3,
                                              keep_prob=0.0,
                                              mask_prob=1.0,
                                              max_seq_len=510,
                                              pad_to_multiple_of=8)


data_loader_span_mlm = DataLoader(
            train_dataset,
            batch_size=BZ,
            sampler=train_sampler,
            collate_fn=data_collator_span_mlm,
            drop_last=False,
            num_workers=0,
            pin_memory=True,
        )

In [199]:
%%timeit
list(data_loader_subword_mlm)
print('.', end='')

........1.66 s ± 57.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [63]:
%%timeit
list(data_loader_span_mlm)
print('.', end='')

........20.4 s ± 425 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [62]:
%prun next(iter(data_loader_subword_mlm))

 

         529 function calls (526 primitive calls) in 0.002 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    0.001    0.001 data_collator.py:361(mask_tokens)
        8    0.000    0.000    0.000    0.000 tokenization_utils_base.py:3128(<listcomp>)
        3    0.000    0.000    0.000    0.000 {built-in method bernoulli}
        1    0.000    0.000    0.000    0.000 data_collator.py:195(_collate_batch)
        1    0.000    0.000    0.000    0.000 {method 'item' of 'torch._C._TensorBase' objects}
        8    0.000    0.000    0.000    0.000 tokenization_utils_base.py:1225(all_special_tokens_extended)
        3    0.000    0.000    0.000    0.000 {method 'bool' of 'torch._C._TensorBase' objects}
        3    0.000    0.000    0.000    0.000 {built-in method full}
        1    0.000    0.000    0.000    0.000 {built-in method tensor}
       10    0.000    0.000    0.000    0.000 tokenization_utils

In [64]:
%prun next(iter(data_loader_span_mlm))

 

         10286 function calls (10283 primitive calls) in 0.016 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        8    0.005    0.001    0.011    0.001 tokenization_utils_fast.py:275(convert_ids_to_tokens)
     1320    0.003    0.000    0.003    0.000 tensor.py:468(<lambda>)
     1312    0.002    0.000    0.002    0.000 {method 'id_to_token' of 'tokenizers.Tokenizer' objects}
        8    0.001    0.000    0.003    0.000 <ipython-input-20-f16f671911d4>:73(_mask_tokens)
     1313    0.001    0.000    0.001    0.000 {method 'token_to_id' of 'tokenizers.Tokenizer' objects}
        9    0.000    0.000    0.002    0.000 tokenization_utils_fast.py:220(convert_tokens_to_ids)
     1313    0.000    0.000    0.001    0.000 tokenization_utils_fast.py:242(_convert_token_to_id_with_added_voc)
     2648    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
        8    0.000    0.000    0.000    0.000 <ipython-inpu