In [1]:
import os
import re
import math
import random

import datasets
import spacy
import tokenizations
from collections.abc import Mapping

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from functools import partial

from transformers import  DataCollatorForWholeWordMask
from transformers.data.data_collator import tolist, _torch_collate_batch

from transformers import BertConfig, BertTokenizerFast, BertForMaskedLM
from transformers import TrainingArguments, Trainer
from transformers.integrations import WandbCallback, rewrite_logs

In [2]:
# Load spaCy's small English pipeline once. `BertDataProcessor._make_example`
# reads this module-level `pos_tagger` to POS-tag the decoded text.
pos_tagger = spacy.load('en_core_web_sm')

In [3]:
class BertDataProcessor():
  """Packs raw text into BERT-style examples annotated with POS ids.

  Instances are callable so they can be passed to `datasets.Dataset.map`
  (see `map`). Lines of text are tokenized and accumulated until the running
  token count reaches the current target length; the accumulated tokens are
  then emitted as one example. Each emitted example carries one POS id per
  input id (`pos_subword_info`), with `-1` at positions that have no tag.

  Args:
    hf_dset: dataset whose `text_col` column holds the documents.
    hf_tokenizer: tokenizer providing `tokenize`, `convert_tokens_to_ids`,
      `convert_ids_to_tokens`, `decode`, `cls_token_id` and `sep_token_id`.
    max_length: maximum number of ids per example, special tokens included.
    text_col: name of the text column fed to `__call__`.
    lines_delimiter: regex used to split a document into lines.
    minimize_data_size: if True, emit unpadded `input_ids`, `sentA_length`
      and `pos_subword_info`; otherwise emit padded `input_ids`,
      `input_mask` and `segment_ids`.
    apply_cleaning: if True, lines rejected by `filter_out` are skipped.
  """

  # Label stored where no POS tag applies: [CLS]/[SEP], subwords that could
  # not be aligned to a spaCy token, and tags missing from `pos_hash`.
  _NO_POS = -1

  def __init__(self, hf_dset, hf_tokenizer, max_length, text_col='text', lines_delimiter='\n', minimize_data_size=True, apply_cleaning=True):
    self.hf_tokenizer = hf_tokenizer
    # Accumulator state for the example currently being built.
    self._current_sentences = []
    self._current_length = 0
    self._max_length = max_length
    # Re-drawn after every emitted example (see `_create_example`).
    self._target_length = max_length

    self.hf_dset = hf_dset
    self.text_col = text_col
    self.lines_delimiter = lines_delimiter
    self.minimize_data_size = minimize_data_size
    self.apply_cleaning = apply_cleaning
    pos_classes = ['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']
    # POS tag string -> integer class id (position in `pos_classes`).
    self.pos_hash = {c: i for i, c in enumerate(pos_classes)}

  def map(self, **kwargs) -> "datasets.arrow_dataset.Dataset":
    """Runs this processor over `hf_dset` via `Dataset.map` and returns the result.

    Recognised keyword arguments:
      num_proc: worker count, defaults to `os.cpu_count()`.
      cache_file_name: output cache file. `.arrow` is appended when missing,
        and a bare file name (no `/`) is placed next to the dataset's first
        cache file when the dataset has one.
    Any other keyword argument is forwarded to `Dataset.map`.
    """
    num_proc = kwargs.pop('num_proc', os.cpu_count())
    cache_file_name = kwargs.pop('cache_file_name', None)
    if cache_file_name is not None:
      if not cache_file_name.endswith('.arrow'):
        cache_file_name += '.arrow'
      if '/' not in cache_file_name:
        cache_files = self.hf_dset.cache_files
        # Without any cache file there is no directory to anchor to; the bare
        # name is then passed through unchanged instead of raising IndexError.
        if cache_files:
          cache_dir = os.path.abspath(os.path.dirname(cache_files[0]['filename']))
          cache_file_name = os.path.join(cache_dir, cache_file_name)

    return self.hf_dset.map(
        function=self,
        batched=True,
        cache_file_name=cache_file_name,
        remove_columns=self.hf_dset.column_names,
        disable_nullable=True,
        input_columns=[self.text_col],
        writer_batch_size=10**4,
        num_proc=num_proc,
        **kwargs
    )

  def __call__(self, texts):
    """Converts a batch of documents into a dict of per-example column lists."""
    if self.minimize_data_size:
      new_example = {'input_ids': [], 'sentA_length': [], 'pos_subword_info': []}
    else:
      new_example = {'input_ids': [], 'input_mask': [], 'segment_ids': []}

    for text in texts:  # for every doc
      for line in re.split(self.lines_delimiter, text):  # for every paragraph
        if re.fullmatch(r'\s*', line):
          continue  # empty string or string with all space characters
        if self.apply_cleaning and self.filter_out(line):
          continue

        example = self.add_line(line)
        if example:
          for k, v in example.items():
            new_example[k].append(v)

      # Flush what is left at the end of the document so that examples never
      # span two documents.
      if self._current_length != 0:
        example = self._create_example()
        for k, v in example.items():
          new_example[k].append(v)

    return new_example

  def filter_out(self, line):
    """Returns True when `line` is shorter than 80 characters and should be skipped."""
    return len(line) < 80

  def clean(self, line):
    """Strips the line, turns newlines into spaces and removes empty `()` pairs."""
    # () is remainder after link in it filtered out
    return line.strip().replace("\n", " ").replace("()", "")

  def add_line(self, line):
    """Adds a line of text to the current example being built.

    Returns the finished example dict once the accumulated token count
    reaches the current target length, otherwise None.
    """
    line = self.clean(line)
    tokens = self.hf_tokenizer.tokenize(line, max_length=512, truncation=True)
    tokids = self.hf_tokenizer.convert_tokens_to_ids(tokens)
    self._current_sentences.append(tokids)
    self._current_length += len(tokids)
    if self._current_length >= self._target_length:
      return self._create_example()
    return None

  def _create_example(self):
    """Creates a pre-training example from the current list of sentences.

    Splits the accumulated sentences into two segments, trims them to
    `max_length`, resets the accumulator, draws the next target length and
    returns the dict built by `_make_example`.
    """
    # small chance to only have one segment as in classification tasks
    if random.random() < 0.1:
      first_segment_target_length = 100000
    else:
      # -3 due to not yet having [CLS]/[SEP] tokens in the input text
      first_segment_target_length = (self._target_length - 3) // 2

    first_segment = []
    second_segment = []
    for sentence in self._current_sentences:
      # the sentence goes to the first segment if (1) the first segment is
      # empty, (2) the sentence doesn't put the first segment over length or
      # (3) 50% of the time when it does put the first segment over length
      if (len(first_segment) == 0 or
          len(first_segment) + len(sentence) < first_segment_target_length or
          (len(second_segment) == 0 and
           len(first_segment) < first_segment_target_length and
           random.random() < 0.5)):
        first_segment += sentence
      else:
        second_segment += sentence

    # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
    first_segment = first_segment[:self._max_length - 2]
    second_segment = second_segment[:max(0, self._max_length -
                                         len(first_segment) - 3)]

    # prepare to start building the next example
    self._current_sentences = []
    self._current_length = 0
    # small chance for random-length instead of max_length-length example
    if random.random() < 0.05:
      self._target_length = random.randint(5, self._max_length)
    else:
      self._target_length = self._max_length

    return self._make_example(first_segment, second_segment)

  @staticmethod
  def _project_pos_to_subwords(pos_ids, b2a, no_pos=-1):
    """Gives every subword the POS id of the first word-level token aligned to it.

    Args:
      pos_ids: POS id per word-level (spaCy) token.
      b2a: for each subword, the list of word-level token indices aligned to
        it (second value returned by `tokenizations.get_alignments`).
      no_pos: id used for subwords whose alignment list is empty.

    Returns:
      A list with exactly `len(b2a)` ids.
    """
    return [pos_ids[owners[0]] if owners else no_pos for owners in b2a]

  def _make_example(self, first_segment, second_segment):
    """Builds the output dict for one example from token-id segments.

    Only `first_segment` ends up in the example; `second_segment` is accepted
    but not appended (the code that appended it was commented out in the
    original implementation).

    Returns:
      With `minimize_data_size`: `input_ids` ([CLS] + first_segment + [SEP]),
      `sentA_length` (its length) and `pos_subword_info` (one POS id per
      input id, `-1` where no tag applies).
      Otherwise: `input_ids`, `input_mask` and `segment_ids`, each padded
      with 0 up to `max_length`.
    """
    input_ids = [self.hf_tokenizer.cls_token_id] + first_segment + [self.hf_tokenizer.sep_token_id]

    bert_tokens = self.hf_tokenizer.convert_ids_to_tokens(first_segment)
    sentence = self.hf_tokenizer.decode(first_segment)

    with pos_tagger.select_pipes(enable=['morphologizer', 'tok2vec', 'tagger', 'attribute_ruler']):
      spacy_doc = pos_tagger(sentence)
    spacy_tokens = [t.text for t in spacy_doc]
    # Tags missing from `pos_hash` become "no label" rather than raising
    # KeyError in the middle of a dataset map.
    pos_ids = [self.pos_hash.get(t.pos_, self._NO_POS) for t in spacy_doc]

    # Align spaCy tokens with BERT subwords. Only the subword -> spaCy
    # direction is needed: it has one entry per subword, so the projected
    # labels always line up with `first_segment`, even when some tokens on
    # either side could not be aligned.
    _, b2a = tokenizations.get_alignments(spacy_tokens, bert_tokens)
    subword_pos = self._project_pos_to_subwords(pos_ids, b2a, self._NO_POS)
    # [CLS] and [SEP] carry no POS tag.
    pos_subword_info = [self._NO_POS] + subword_pos + [self._NO_POS]

    sentA_length = len(input_ids)
    segment_ids = [0] * sentA_length
    assert len(input_ids) == len(pos_subword_info)

    if self.minimize_data_size:
      return {
        'input_ids': input_ids,
        'sentA_length': sentA_length,
        'pos_subword_info': pos_subword_info
      }
    else:
      input_mask = [1] * len(input_ids)
      input_ids += [0] * (self._max_length - len(input_ids))
      input_mask += [0] * (self._max_length - len(input_mask))
      segment_ids += [0] * (self._max_length - len(segment_ids))
      return {
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
      }

In [4]:
# Tokenizer shared by every processor built through `BertProcessor`.
hf_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# Factory with tokenizer and sequence length pre-bound; call as BertProcessor(dataset).
BertProcessor = partial(BertDataProcessor, hf_tokenizer=hf_tokenizer, max_length=128)

Downloading tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# Load the 'train' split of the `wikitext` dataset ('default' config); files are
# cached under the current working directory.
wiki = datasets.load_dataset('wikitext', 'default', cache_dir='./')['train']

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading and preparing dataset None/wikitext-103-raw-v1 to /home/pasitt/work/ALBEF/parquet/wikitext-103-raw-v1-7bb180478b704b56/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/157M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/157M [00:00<?, ?B/s]

In [None]:
# Peek at a single raw record to see what the dataset rows look like.
wiki[21]

In [None]:
# Tokenize, pack and POS-annotate the corpus with 4 worker processes; the result
# is written to a cache file named after the 128-token sequence length.
e_wiki = BertProcessor(wiki).map(cache_file_name="bert_wikitext_128.arrow", num_proc=4)

In [None]:
# Display the processed dataset summary.
e_wiki

In [None]:
# Inspect the per-subword POS ids of the first processed example.
e_wiki[0]['pos_subword_info']

In [None]:
# Show the POS tag -> integer id table used to encode `pos_subword_info`.
BertProcessor(wiki).pos_hash