In [1]:
from IPython.core.debugger import set_trace as bk
import os
from pathlib import Path
from functools import partial
import json
from tqdm import tqdm
import torch
import pyarrow as pa
import nlp
from transformers import ElectraTokenizerFast
hf_fast_tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-generator")
from fastai2.text.all import *

In [2]:
class HF_TokenizeTfm():
  """
  Args:
    `hf_dset` (`nlp.Dataset`)
    `cols`: with one of the following signature:
      - `cols`(`List[str]`): tokenize the col into col
      - `cols`(`Dict[str]`): tokenize the col(key) and into col(value)
    `hf_tokenizer`: tokenizer of HuggingFace/Transformers.
    `remove_original`: after tokenization, remove all original columns to save cache size.
  """
  def __init__(self, hf_dset, cols, hf_tokenizer, remove_original=False):
    if isinstance(cols, list): cols = {c:c for c in cols}
    assert isinstance(cols, dict)
    self.hf_dset, self.cols, self.tokenizer = hf_dset, cols, hf_tokenizer
    self.remove_original = remove_original
    if remove_original:
      for in_col,out_col in cols.items(): assert in_col !=  out_col
    """
    If don't specify cache file name, it will be hashed binary of pickled function that
    passed to `map`, so if you pass the same function, it knows to use cache.
    But tokenizer can't be pickled, so use tokenizer config to make tfms use different 
    tokenizers unique.  
    """
    self.tokenizer_config = hf_tokenizer.pretrained_init_configuration
  
  def __call__(self, example):
    for in_col, out_col in self.cols.items():
      example[out_col] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(example[in_col]))
    return example

  def __getstate__(self):
    "specify something you don't want pickle here, remember to use copy to not modfiy orginal instance"
    state = self.__dict__.copy() 
    state['tokenizer'] = None 
    return state

  def map(self, **kwargs):
    if self.remove_original:
      assert 'remove_columns' not in kwargs, "You have specified to remove all original columns."
      return self.hf_dset.map(self, remove_columns=self.hf_dset.column_names, **kwargs)
    else:
      return self.hf_dset.map(self, **kwargs)
      

In [3]:
cola = nlp.load_dataset('glue', 'cola')

tokenized_cola = {}
for split, dset in cola.items():
  tokenized_cola[split] = HF_TokenizeTfm(dset, {'sentence':'text_idxs'}, hf_fast_tokenizer).map(remove_columns=['sentence', 'idx'])

8551it [00:01, 4834.76it/s]
1043it [00:00, 9376.89it/s]
1063it [00:00, 9970.14it/s]


# 1. Integration with fastai

In [4]:
@delegates()
class HF_Dataloader(TfmdDL):
  """
  Args:
    `dataset`: any class that output a tuple, which has token ids in its first element, from `__getitem__`.
    `pad_idx` (`int`): If sepcified, pad texts to the longest text in the batch.
    `sort` (`Optional[bool]`, default: `True`): Sort the samples with their length, thus samples of similar legnth collated into a batch and we can pad less. Notice if it is True, the shuffle will be overrided and not shuffle.
    `filterout` (`Optional[callable(*args) -> bool]`, , default: `None`): if not `None`, judege whether exclude this sample with this sample(`tuple`) as args
    `cache_file` (`Optional[str]`, default: `None`): A name of json file to store the computed record of results of sort or filterout.   
  """
  def __init__(self, dataset, pad_idx, sort=True, filterout=None, cache_file=None, **kwargs):
    if pad_idx is not None:
      kwargs['before_batch'] = partial(pad_input_chunk, pad_idx=pad_idx, pad_first=False)

    cache_file = Path(cache_file) if cache_file else None
    if cache_file and cache_file.exists():
      with cache_file.open(mode='r') as f: self.samples = json.load(f)
    elif sort or filterout:
      if cache_file: cache_file.touch()
      try:
        if filterout is None: filterout = lambda *args: False
        self.samples = [ (i,len(sample[0])) for i, sample in tqdm(enumerate(dataset), leave=False) if not filterout(*sample) ]
        if sort: self.samples.sort(key=lambda t:t[1], reverse=True)
      except Exception as e:
        os.remove(cache_file)
        raise e
      if cache_file:
        with cache_file.open(mode='w') as f: json.dump(self.samples, f)
    else:
      self.samples = [ (i,None) for i in range(len(dataset))]

    store_attr(self, 'pad_idx,sort,filterout,cache_file')
    super().__init__(dataset, **kwargs)
    if sort: self.shuffle=False
    self.n = len(self.samples)
  
  def create_item(self, i): return self.dataset[self.samples[i][0]]

  def new(self, dataset, **kwargs):
    return super().new(dataset=dataset, pad_idx=self.pad_idx, sort=self.sort, filterout=self.filterout, **kwargs)

class HF_Dataset(FilteredBase):
  """
  Args:
    `hf_dset` (`nlp.Dataset`)
    `cols`: with one of the following signature:
    - `cols`(`List[str]`): 
      - if of length 1, regard the 1st element as text
      - if of length 2, regrad the 1st element as text, 2nd as category
    - `cols`(`Dict[Fastai2 Semantic Tesor]`): {`inp_col`:tensor type}: output sample as tuple of values of `inp_col` in order, and encode/decode with the tensor type,
    `hf_tokenizer`: tokenizer of HuggingFace/Transformers
    `pretty_show` (`Optional[bool]`, default:`False`): Show the original sentence instead of tokens.
  """
  def __init__(self, hf_dset, cols, hf_tokenizer=None, pretty_show=False, n_inp=1):
    
    # some default setting for tensor type used in decoding
    if isinstance(cols, list): 
      if n_inp==1: 
        if len(cols)==1: cols = {cols[0]: TensorText}
        elif len(cols)==2: cols = {cols[0]: TensorText, cols[1]: TensorCategory}
      else: cols = { c: noop for c in cols }
    assert isinstance(cols, dict)
    
    # make dataset output pytorch tensor
    if hf_dset.format['type'] != 'torch': 
      hf_dset.set_format( type='torch', columns=list(cols.keys()) )

    # store attributes
    store_attr(self, "hf_dset,cols,n_inp,hf_tokenizer,pretty_show")

  def __getitem__(self, idx):
    sample = self.hf_dset[idx]
    return tuple( tensor_cls(sample[col]) for col, tensor_cls in self.cols.items() )

  def __len__(self): return len(self.hf_dset)

  @property
  def column_names(self): return list(self.cols.keys())

  def decode(self, o, full=True): 
    return tuple( self._decode(o_) for o_ in o )

  @typedispatch
  def _decode(self, t:torch.Tensor): assert False, "You didn't specify a tensor type, thus not be able to decode and show."

  @typedispatch
  def _decode(self, t:TensorText): 
    assert self.hf_tokenizer, "You should give huggingface tokenizer if you want to show batch."
    if self.pretty_show: text = self.hf_tokenizer.decode([idx for idx in t if idx != self.hf_tokenizer.pad_token_id])
    else: text = ' '.join(self.hf_tokenizer.convert_ids_to_tokens(t))
    return TitledStr(text)

  @typedispatch
  def _decode(self, t:LMTensorText): return self._decode[TensorText](self, t)

  @typedispatch
  def _decode(self, t:TensorCategory): return Category(t.item())
  
class HF_Datasets(FilteredBase):
  _dl_type,_dbunch_type = HF_Dataloader,DataLoaders
  def __init__(self, hs_dsets: dict, *args, **kwargs):
    """
    Args:
      `hs_dsets` (`Dict[nlp.Dataset]`): the order of dict items will be the order of `HF_Dataloader`s  
    """
    self.hs_dsets = { split: HF_Dataset(dset, *args, **kwargs) for split, dset in hs_dsets.items()}
  def subset(self, i): return list(self.hs_dsets.values())[i]
  def __getitem__(self, split): return self.hs_dsets[split]
  @property
  def n_subsets(self): return len(self.hs_dsets)
  def dataloaders(self, *args, cache_files=None, device='cpu', **kwargs):
    """
    Args:
      `*args, **kwargs`: `for FilteredBase.dataloaders`
      `cache_files` (`Optional[str]`, default:`None`): cache file names for `HF_Dataloader`s
      `device` (`Optional[str]`, default:`'cpu'`): cuz will read a batch for test when creating `Dataloader`, so I set the default device to cpu to less the memory burden of cuda:0 
    """
    dl_kwargs = kwargs.pop('dl_kwargs', [{} for _ in range(len(self.hs_dsets))])
    if cache_files:
      assert len(cache_files) == len(self.hs_dsets)
      for i, dl_kwarg in enumerate(dl_kwargs): dl_kwarg['cache_file'] = cache_files[i]
    return super().dataloaders(*args, dl_kwargs=dl_kwargs, device=device, **kwargs)

In [5]:
# tokenized_cola is {'train':nlp.Dataset, 'validation':nlp.Dataset, 'test':nlp.Dataset}
cola_dsets = HF_Datasets(tokenized_cola, ['text_idxs', 'label'], hf_fast_tokenizer, pretty_show=True)
cola_dls = cola_dsets.dataloaders(bs=32, pad_idx=hf_fast_tokenizer.pad_token_id)
cola_dls.show_batch()

823it [00:00, 4065.32it/s]

Unnamed: 0,text,category
0,"everybody who has ever, worked in any office which contained any typewriter which had ever been used to type any letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean.",1
1,"hank plays the guitar and finds arrangements for all the old folk songs which are still sung in these hills, and ernie writes down all the old folk songs which are still sung in these hills.",1
2,"playing with matches is ; lots of fun, but doing, so and emptying gasoline from one can to another at the same time is a sport best reserved for arsons.",1
3,"in january 2002, a dull star in an obscure constellation suddenly became 600, 000 times more luminous than our sun, temporarily making it the brightest star in our galaxy.",1
4,"which folks up at corporate headquarters do you think that the sooner you solve this problem, the quicker you'll be able to tell t to buzz off?",0
5,"i finally worked up enough courage to ask which people up at corporate headquarters the sooner i solve this problem, the quicker i'll get free of.",0
6,"ron wanted to wear a tuxedo to the party, but wear a tuxedo to the party caspar couldn't decide whether to.",0
7,"the dumplings which sasha is gobbling down faster than i can reheat the meatballs are extremely tasty, if i do say so.",1
8,"sam picked those packages up which are to be mailed tomorrow rest might, but he didn't want to do so until it had stopped raining.",1


# 2. Aggregate samples of HuggingFace/nlp dataset
~ take traditional language model task for example ~

In [6]:
class AggregateTransform():
  """
  Inherit this class and implement `accumulate` and `create_example`
  """
  def __init__(self, hf_dset, inp_cols, out_cols, init_attrs, drop_last=False):
    """
    Args:
      `hf_dset` (`nlp.Dataset`)
      `inp_cols` (`List[str]`)
      `out_cols` (`List[str]`)
      `init_attrs` (`List[str]`): name of attributes of children class that need to be their initial status when starts to aggregate dataset. i.e. Those defined in `__init__` and the value will changed during `accumulate`
      `drop_last` (`Optional[bool]`, default: `False`): whether to drop the last accumulated sample.
    """
    self.hf_dset = hf_dset
    self.inp_cols, self.out_cols =  inp_cols, out_cols
    # batched map need dataset be in python format
    hf_dset.set_format(type=None, columns=inp_cols)
    # dealing with last sample
    self.last_idx = len(hf_dset) - 1
    self.drop_last = drop_last
    # for reset
    self.init_attrs = init_attrs
    self.original_vals = [deepcopy(getattr(self, attr)) for attr in init_attrs]  

  def __call__(self, b, indices):
    # `nlp.Dataset.map` first test with several samples which affects our attrs, so we need to reinitialize.
    if 0 in indices: # reset
      for attr,val in zip(self.init_attrs, self.original_vals): setattr(self, attr, deepcopy(val))

    self.new_b = { c:[] for c in self.out_cols }
    for z in zip(*b.values()):
      self.accumulate(*z)
    
    # whehther put last example when it is last batch of `map`
    if not self.drop_last and self.last_idx in indices: 
      self.commit_example(self.create_example())

    return self.new_b

  def commit_example(self, example):
    if example is None: return
    for col,val in example.items():
      self.new_b[col].append(val) 

  def accumulate(self, *args):
    """
    Given a example, do `self.commit_example(self.create_example()) when a new aggregated sample is ready.`
    Args:
      `args`: nlp.Dataset[i][inp_col] for inp_col in self.inp_cols
    """ 
    raise NotImplementedError
  
  def create_example(self): 
    """
    When it is ready, create a sample (Dict[Any])
    """
    raise NotImplementedError

  def map(self, batch_size=1000, test_batch_size=20, **kwargs):
    """
    `batch_size`: see `nlp.Dataset.map`
    `test_batch_size` (`int`, default=`20`): we infer the new schema of the aggregated dataset by the outputs of testing that passed first `test_batch_size` samples to aggregate. Depending how many sample aggreagted can you have a sample, this number might need to be higher.
    """
    test_inputs, test_indices = self.hf_dset[:test_batch_size], list(range(test_batch_size))
    test_output = self(test_inputs,test_indices)
    for col,val in test_output.items(): assert val, f"Didn't get any example in test, you might want to try larger `test_batch_size` than {test_batch_size}"
    assert sorted(self.out_cols) == sorted(test_output.keys()), f"Output columns are {self.out_cols}, but get example with {list(test_output.keys())}"
    arrow_schema = pa.Table.from_pydict(test_output).schema
    return self.hf_dset.map(function=self, batched=True, batch_size=batch_size, with_indices=True,
                            arrow_schema=arrow_schema, **kwargs)

class LMTransform(AggregateTransform):
  def __init__(self, hf_dset, max_len, text_col, x_text_col='x_text', y_text_col='y_text', **kwargs):
    self.text_col, self.x_text_col, self.y_text_col = text_col, x_text_col, y_text_col
    self._max_len = max_len + 1
    self.residual_len, self.new_text = self._max_len, []
    super().__init__(hf_dset, inp_cols=[text_col], out_cols=[x_text_col, y_text_col], init_attrs=['residual_len', 'new_text'], **kwargs)
    

  def accumulate(self, text): # *inp_cols
    "text: a list of indices"
    usable_len = len(text)
    cursor = 0
    while usable_len != 0:
      use_len = min(usable_len, self.residual_len)
      self.new_text += text[cursor:cursor+use_len]
      self.residual_len -= use_len
      usable_len -= use_len
      cursor += use_len
      if self.residual_len == 0:
        self.commit_example(self.create_example())   

  def create_example(self):
    # when read all data, the accumulated new_text might be less than two characters.
    if len(self.new_text) >= 2: 
      example = {self.x_text_col:self.new_text[:-1], self.y_text_col:self.new_text[1:]}
    else:
      example = None # mark "don't commit this"
    # reset accumulators
    self.new_text = []
    self.residual_len = self._max_len

    return example

In [7]:
cola_val = tokenized_cola['validation']
lm_dataset = LMTransform(cola_val, max_len=20, text_col='text_idxs').map()

print('Original dataset:')
print('num of samples:', len(cola['validation']))
print('second to last sentence:', cola['validation'][-2]['sentence'])
print('          last sentence:', cola['validation'][-1]['sentence'])
print('LM dataset:')
print('num of sampels:', len(lm_dataset))
print('last text (x):', hf_fast_tokenizer.decode(lm_dataset[-1]['x_text']))
print('last text (y):', hf_fast_tokenizer.decode(lm_dataset[-1]['y_text']))

100%|██████████| 2/2 [00:00<00:00, 321.32it/s]Original dataset:
num of samples: 1043
second to last sentence: John arranged for himself to get the prize.
          last sentence: John talked to Bill about himself.
LM dataset:
num of sampels: 481
last text (x): . john talked to bill about himself
last text (y): john talked to bill about himself.



In [8]:
lm_dl = HF_Dataloader(HF_Dataset(lm_dataset, {'x_text':LMTensorText, 'y_text':TensorText},hf_fast_tokenizer), sort=False, pad_idx=hf_fast_tokenizer.pad_token_id)
lm_dl.show_batch()

Unnamed: 0,text,text_
0,the sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey,sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey .
1,"the mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less .","mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less . as"
2,"you eat the most , you want the least . the more you would want , the less you would","eat the most , you want the least . the more you would want , the less you would eat"
3,". i demand that the more john eat , the more he pays . mary listen ##s to the grateful","i demand that the more john eat , the more he pays . mary listen ##s to the grateful dead"
4,", she gets depressed . the ang ##rier mary got , the more she looked at pictures . the higher","she gets depressed . the ang ##rier mary got , the more she looked at pictures . the higher the"
5,"stakes , the lower his expectations are . the more fred is ob ##no ##xious , the less attention you",", the lower his expectations are . the more fred is ob ##no ##xious , the less attention you should"
6,pay to him . john was lots more ob ##no ##xious than fred . the more people you give beer,to him . john was lots more ob ##no ##xious than fred . the more people you give beer to
7,", the more people get sick . the more does bill smoke , the more susan hates him . the","the more people get sick . the more does bill smoke , the more susan hates him . the more"
8,"pictures of him that appear in the news , the more embarrassed john becomes . every senator seems to become","of him that appear in the news , the more embarrassed john becomes . every senator seems to become more"


# 3. ELECTRA Dataloading

In [9]:
class ELECTRADataTransform(AggregateTransform):
  
  def __init__(self, hf_dset, in_col, out_col, max_length, cls_idx, sep_idx):
    self.in_col, self.out_col = in_col, out_col
    self._current_sentences = []
    self._current_length = 0
    self._max_length = max_length
    self._target_length = max_length
    self.cls_idx, self.sep_idx = cls_idx, sep_idx
    super().__init__(hf_dset, inp_cols=[in_col], out_cols=[out_col], 
                    init_attrs=['_current_sentences', '_current_length', '_target_length'])

  # two functions required by AggregateTransform
  def accumulate(self, tokids):
    self.add_line(tokids)
  
  def create_example(self):
    input_ids = self._create_example()
    return {self.out_col: input_ids}

  def add_line(self, tokids):
    """Adds a line of text to the current example being built."""
    self._current_sentences.append(tokids)
    self._current_length += len(tokids)
    if self._current_length >= self._target_length:
      self.commit_example(self.create_example())

  def _create_example(self):
    """Creates a pre-training example from the current list of sentences."""
    # small chance to only have one segment as in classification tasks
    if random.random() < 0.1:
      first_segment_target_length = 100000
    else:
      # -3 due to not yet having [CLS]/[SEP] tokens in the input text
      first_segment_target_length = (self._target_length - 3) // 2

    first_segment = []
    second_segment = []
    for sentence in self._current_sentences:
      # the sentence goes to the first segment if (1) the first segment is
      # empty, (2) the sentence doesn't put the first segment over length or
      # (3) 50% of the time when it does put the first segment over length
      if (len(first_segment) == 0 or
          len(first_segment) + len(sentence) < first_segment_target_length or
          (len(second_segment) == 0 and
           len(first_segment) < first_segment_target_length and
           random.random() < 0.5)):
        first_segment += sentence
      else:
        second_segment += sentence

    # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
    first_segment = first_segment[:self._max_length - 2]
    second_segment = second_segment[:max(0, self._max_length -
                                         len(first_segment) - 3)]

    # prepare to start building the next example
    self._current_sentences = []
    self._current_length = 0
    # small chance for random-length instead of max_length-length example
    if random.random() < 0.05:
      self._target_length = random.randint(5, self._max_length)
    else:
      self._target_length = self._max_length

    return self._make_example(first_segment, second_segment)

  def _make_example(self, first_segment, second_segment):
    """Converts two "segments" of text into a tf.train.Example."""
    input_ids = [self.cls_idx] + first_segment + [self.sep_idx]
    if second_segment:
      input_ids += second_segment + [self.sep_idx]
    return input_ids

edset = ELECTRADataTransform(cola_val, 'text_idxs', 'tokids', 50, hf_fast_tokenizer.cls_token_id, hf_fast_tokenizer.sep_token_id).map()

100%|██████████| 2/2 [00:00<00:00, 414.97it/s]


In [10]:
e_dl = HF_Dataloader(HF_Dataset(edset, ['tokids'], hf_fast_tokenizer), pad_idx=hf_fast_tokenizer.pad_token_id, sort=False)
e_dl.show_batch()

Unnamed: 0,text
0,"[CLS] the sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey . [SEP] the mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less . as you eat the most , [SEP]"
1,"[CLS] the more you would want , the less you would eat . i demand that the more john eat , the more he pays . mary listen ##s to the grateful dead , she gets depressed . the ang ##rier mary got , the more she looked at [SEP]"
2,"[CLS] the higher the stakes , the lower his expectations are . john was lots more ob ##no ##xious than fred . [SEP] the more fred is ob ##no ##xious , the less attention you should pay to him . the more people you give beer to , the [SEP]"
3,"[CLS] the more does bill smoke , the more susan hates him . who does john visit sally because he likes ? [SEP] the more pictures of him that appear in the news , the more embarrassed john becomes . every senator seems to become more corrupt , as [SEP]"
4,[CLS] marianne did not leave . he could not ] have been working . he can not have been working . [SEP] you will believe bob . john has not kissed mary . i said that never in my life had i seen a place like bangor . mickey [SEP]
5,[CLS] there tended to be a lot of discussion . john tried to be a good boy . john is eager . we want john to win . [SEP] the box contained the ball from the tree . the tube was escaped by gas . water bubble ##d up [SEP]
6,[CLS] the tub leaked water . what the water did to the bottle was fill it . john owns the book . [SEP] what the water did to the whole bottle was fill it . the tank leaked the fluid free . john lay the ball in the box [SEP]
7,"[CLS] most people probably consider , even though the courts didn ' t actually find , klaus guilty of murder . [SEP] mary beautifully plays the violin . clearly , john probably will immediately learn french perfectly . sue gave to bill a book . the men will all [SEP]"
8,"[CLS] they represented seriously to the dean mary as a genuine linguist . us love they . it is nice to go abroad . [SEP] mary intended john to go abroad . i remembered having kissed mary . i can ' t believe fred won ' t , either [SEP]"


# 4. Test caching and filtering feature

In [11]:
l = 23
num = {}
for split in tokenized_cola:
  num[split] = reduce(lambda sum, sample: sum+(1 if len(sample['text_idxs'])==l else 0), 
                      tokenized_cola[split], 0)
print(num)

{'train': 26, 'validation': 2, 'test': 6}


In [12]:
for f in ['/tmp/cctrain.json','/tmp/ccval.json', '/tmp/cctest.json']
  if Path(f).exists(): os.remove(f)

SyntaxError: invalid syntax (<ipython-input-12-6a20dc8f1b48>, line 1)

In [13]:
ccola_dsets = HF_Datasets(tokenized_cola, ['text_idxs', 'label'], hf_fast_tokenizer, pretty_show=True)
ccola_dls = ccola_dsets.dataloaders(bs=32, pad_idx=hf_fast_tokenizer.pad_token_id, filterout=lambda inpids,label: len(inpids)==l, cache_files=['/tmp/cctrain.json','/tmp/ccval.json', '/tmp/cctest.json'])
ccola_dls.show_batch(max_n=2)

Unnamed: 0,text,category
0,"everybody who has ever, worked in any office which contained any typewriter which had ever been used to type any letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean.",1
1,"hank plays the guitar and finds arrangements for all the old folk songs which are still sung in these hills, and ernie writes down all the old folk songs which are still sung in these hills.",1


Test if we correctly filter out by checking the number of samples

In [14]:
for i, split in enumerate(tokenized_cola):
  assert ccola_dls[i].n == len(tokenized_cola[split])-num[split],f"{split}: {ccola_dls[i].n}, {len(tokenized_cola[split])}, {num[split]}"

This time we load the caches, and it should be fast and bars sholdn't appear

In [15]:
ccola_dls = ccola_dsets.dataloaders(bs=32, pad_idx=hf_fast_tokenizer.pad_token_id, filterout=lambda inpids,label: len(inpids)==l, cache_files=['/tmp/cctrain.json','/tmp/ccval.json', '/tmp/cctest.json'])