In [1]:
from IPython.core.debugger import set_trace as bk
from pathlib import Path
from functools import partial
from tqdm import tqdm
import torch
import pyarrow as pa
import nlp
from transformers import ElectraTokenizerFast
hf_fast_tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-generator")
from fastai2.text.all import *

In [62]:
class HF_TokenizeTfm():
  """Tokenize text column(s) of an `nlp` dataset into token ids with `nlp.Dataset.map`.

  The instance itself is the function handed to `map` (see `__call__`).

  Args:
    hf_dset: the `nlp` dataset to transform.
    cols: a column name, a list/tuple of column names (each column is replaced by its ids),
      or a dict mapping input column name -> output column name.
    hf_tokenizer: HuggingFace tokenizer; its `tokenize` and `convert_tokens_to_ids` are used.
    remove_original: if True, `map` is asked to drop all of `hf_dset`'s original columns.
  """

  def __init__(self, hf_dset, cols, hf_tokenizer, remove_original=False):
    if isinstance(cols, str): cols = [cols]
    if isinstance(cols, (list, tuple)): cols = {c:c for c in cols}
    assert isinstance(cols, dict), f"`cols` should be a column name, a list of names or a dict, but got {type(cols)}"
    self.hf_dset, self.cols, self.tokenizer = hf_dset, cols, hf_tokenizer
    self.remove_original = remove_original
    # Without an explicit cache file name, `map` names its cache after a hash of the pickled
    # function, so passing an equivalent function reuses the cache. The tokenizer can't be
    # pickled (it is dropped in `__getstate__`), so keep picklable attributes describing it,
    # which keeps tfms that use different tokenizers from sharing a cache.
    self.tokenizer_config = hf_tokenizer.pretrained_init_configuration
    # `pretrained_init_configuration` is defined per tokenizer *class*, so it can't tell two
    # tokenizers of the same class apart (e.g. cased vs. uncased vocab); also record the class
    # name, checkpoint name and vocab size when the tokenizer exposes them.
    self.tokenizer_id = (type(hf_tokenizer).__name__,
                         getattr(hf_tokenizer, 'name_or_path', None),
                         getattr(hf_tokenizer, 'vocab_size', None))

  def __call__(self, example):
    "Store the token ids of each input column of `example` under its output column; return `example`."
    for in_col, out_col in self.cols.items():
      example[out_col] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(example[in_col]))
    return example

  def __getstate__(self):
    "Exclude unpicklable members (the tokenizer); copy `__dict__` so the original instance is not modified."
    state = self.__dict__.copy() 
    state['tokenizer'] = None 
    return state

  def map(self, **kwargs):
    "Run this transform over `hf_dset` with `nlp.Dataset.map`, forwarding `kwargs` to it."
    if self.remove_original:
      # `remove_columns` is filled in here, so callers must not pass it themselves
      assert 'remove_columns' not in kwargs, "You have specified to remove all original columns."
      return self.hf_dset.map(self, remove_columns=self.hf_dset.column_names, **kwargs)
    else:
      return self.hf_dset.map(self, **kwargs)
      

In [66]:
cola = nlp.load_dataset('glue', 'cola')

# Tokenize the 'sentence' column of every split into a new 'text_idxs' column,
# dropping the split's original columns.
tokenized_cola = {
  split: HF_TokenizeTfm(dset, {'sentence':'text_idxs'}, hf_fast_tokenizer, remove_original=True).map()
  for split, dset in cola.items()
}

8551it [00:01, 6405.28it/s]
1043it [00:00, 9191.29it/s]
1063it [00:00, 8597.16it/s]


# 1. Integration with fastai

In [3]:
@delegates()
class HF_Dataloader(TfmdDL):
  """`TfmdDL` that pads each batch and can order samples by length.

  Args:
    dataset: dataset whose samples are tuples; the length of a sample's first element is
      used when `sort` is True.
    pad_idx: id used for padding. If not None, a `pad_input_chunk` `before_batch` is
      installed (replacing any `before_batch` passed in `kwargs`).
    sort: if True, `get_idxs` orders indices by descending length of the first element.
  """

  def __init__(self, dataset, pad_idx, sort=True, **kwargs):
    if pad_idx is not None:
      kwargs['before_batch'] = partial(pad_input_chunk, pad_idx=pad_idx, pad_first=False)
    if sort:
      # one pass over the dataset to cache sample lengths used by `get_idxs`
      self.lens = [ len(sample[0]) for sample in dataset ]
    store_attr(self, 'pad_idx,sort')
    super().__init__(dataset, **kwargs)

  def get_idxs(self):
    idxs = super().get_idxs()
    if not self.sort : return idxs
    # NOTE: a full (stable) sort discards any shuffling from `super().get_idxs()`,
    # except among samples of equal length.
    return sorted(idxs, key=lambda i: self.lens[i], reverse=True)

  def new(self, dataset=None, **kwargs):
    """Create a loader of the same kind, over `dataset` (defaults to this loader's dataset).

    Callers may pass a different dataset (e.g. to build a validation/test loader from this
    one) or none at all, so the dataset argument must be honored and optional.
    `pad_idx`/`sort` default to this loader's values unless given explicitly.
    """
    if dataset is None: dataset = self.dataset
    kwargs.setdefault('pad_idx', self.pad_idx)
    kwargs.setdefault('sort', self.sort)
    return super().new(dataset=dataset, **kwargs)

class HF_Dataset(FilteredBase):
  """Wrap an `nlp` dataset so fastai can index it and decode (show) its samples.

  Args:
    hf_dset: the `nlp` dataset. When needed, its format is set in place to torch for the
      columns in `cols`.
    cols: dict mapping column name -> callable applied to that column's value in
      `__getitem__` (a tensor class such as `TensorText`, or `noop`). A list of names is
      converted: with `n_inp==1`, one column -> `TensorText`, two columns ->
      (`TensorText`, `TensorCategory`); with any other `n_inp`, every column -> `noop`.
    hf_tokenizer: HuggingFace tokenizer; required only to decode `TensorText` (e.g. `show_batch`).
    pretty_show: if True, decode ids into a string with pad tokens removed; otherwise join the
      raw tokens with spaces.
    n_inp: stored as `n_inp` (presumably read by fastai to split inputs from targets).
  """
  
  def __init__(self, hf_dset, cols, hf_tokenizer=None, pretty_show=False, n_inp=1):
    
    # some default setting for tensor type used in decoding
    if isinstance(cols, list): 
      if n_inp==1: 
        if len(cols)==1: cols = {cols[0]: TensorText}
        elif len(cols)==2: cols = {cols[0]: TensorText, cols[1]: TensorCategory}
      else: cols = { c: noop for c in cols }
    assert isinstance(cols, dict), f"Can't infer tensor types for columns {cols}; pass `cols` as a dict of column name -> tensor class."
    
    # make dataset output pytorch tensor. Also re-format when it is already torch-formatted
    # but doesn't output every column `__getitem__` reads (e.g. the same dataset was
    # wrapped earlier with fewer columns).
    fmt = hf_dset.format
    if fmt['type'] != 'torch' or not set(cols).issubset(fmt.get('columns') or []):
      hf_dset.set_format( type='torch', columns=list(cols.keys()) )

    # store attributes
    store_attr(self, "hf_dset,cols,n_inp,hf_tokenizer,pretty_show")

  def __getitem__(self, idx):
    # one element per column, in `cols` order, each wrapped by that column's tensor class
    sample = self.hf_dset[idx]
    return tuple( tensor_cls(sample[col]) for col, tensor_cls in self.cols.items() )

  def __len__(self): return len(self.hf_dset)

  @property
  def column_names(self): return list(self.cols.keys())

  # `full` is accepted for signature compatibility but unused here
  def decode(self, o, full=True): 
    return tuple( self._decode(o_) for o_ in o )

  @typedispatch
  def _decode(self, t:TensorText): 
    assert self.hf_tokenizer, "You should give huggingface tokenizer if you want to show batch."
    if self.pretty_show: text = self.hf_tokenizer.decode([idx for idx in t if idx != self.hf_tokenizer.pad_token_id])
    else: text = ' '.join(self.hf_tokenizer.convert_ids_to_tokens(t))
    return TitledStr(text)

  @typedispatch
  def _decode(self, t:LMTensorText): return self._decode[TensorText](self, t)

  @typedispatch
  def _decode(self, t:TensorCategory): return Category(t.item())
  
class HF_Datasets(FilteredBase):
  """A group of `HF_Dataset`s keyed by split name.

  Every split is wrapped with the same extra arguments. Positional subsets (`subset(i)`)
  follow the insertion order of the dict passed in.
  """
  _dl_type,_dbunch_type = HF_Dataloader,DataLoaders

  def __init__(self, hs_dsets: dict, *args, **kwargs):
    wrapped = {}
    for split_name, split_dset in hs_dsets.items():
      wrapped[split_name] = HF_Dataset(split_dset, *args, **kwargs)
    self.hs_dsets = wrapped

  def subset(self, i):
    return [*self.hs_dsets.values()][i]

  def __getitem__(self, split):
    return self.hs_dsets[split]

  @property
  def n_subsets(self):
    return len(self.hs_dsets)

In [4]:
cola_dsets = HF_Datasets(tokenized_cola,
                         cols=['text_idxs', 'label'],
                         hf_tokenizer=hf_fast_tokenizer,
                         pretty_show=True)
cola_dls = cola_dsets.dataloaders(bs=32,
                                  pad_idx=hf_fast_tokenizer.pad_token_id)
cola_dls.show_batch()

Unnamed: 0,text,category
0,"everybody who has ever, worked in any office which contained any typewriter which had ever been used to type any letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean.",1
1,"hank plays the guitar and finds arrangements for all the old folk songs which are still sung in these hills, and ernie writes down all the old folk songs which are still sung in these hills.",1
2,"playing with matches is ; lots of fun, but doing, so and emptying gasoline from one can to another at the same time is a sport best reserved for arsons.",1
3,"in january 2002, a dull star in an obscure constellation suddenly became 600, 000 times more luminous than our sun, temporarily making it the brightest star in our galaxy.",1
4,"which folks up at corporate headquarters do you think that the sooner you solve this problem, the quicker you'll be able to tell t to buzz off?",0
5,"one of the jewish children is a spunky girl, who gave a black eye to the kid with the german roots before the start of the war.",1
6,"will put a picture of bill on your desk before tomorrow, this girl in the red coat will put a picture of bill on your desk before tomorrow.",0
7,"the dumplings which sasha is gobbling down faster than i can reheat the meatballs are extremely tasty, if i do say so.",1
8,"sam picked those packages up which are to be mailed tomorrow rest might, but he didn't want to do so until it had stopped raining.",1


# 2. Aggregate samples of a HuggingFace/nlp dataset
Take the traditional language-modeling task as an example.

In [32]:
class AggregateTransform():
  """Base class for batched `nlp.Dataset.map` transforms that merge input samples into new ones.

  Subclasses implement:
    - `accumulate(*values)`: called once per input sample with its `inp_cols` values (in
      `inp_cols` order); may call `commit_example` any number of times.
    - `create_example()`: builds an example (dict keyed by `out_cols`, or None to skip it)
      from whatever has been accumulated so far.

  Args:
    hf_dset: the `nlp` dataset to transform. NOTE: its format is changed IN PLACE to python
      format restricted to `inp_cols`, which affects every other holder of this object.
    inp_cols: input column names whose values are passed to `accumulate`.
    out_cols: column names of the produced examples.
    init_attrs: names of attributes (set by the subclass before calling this `__init__`)
      that are restored to their initial values whenever a batch containing index 0 starts.
    drop_last: if True, whatever is still accumulated after the last sample is discarded.
  """
  def __init__(self, hf_dset, inp_cols, out_cols, init_attrs, drop_last=False):
    self.hf_dset = hf_dset
    self.inp_cols, self.out_cols =  inp_cols, out_cols
    # batched map need dataset be in python format
    hf_dset.set_format(type=None, columns=inp_cols)
    # dealing with last sample
    self.last_idx = len(hf_dset) - 1
    self.drop_last = drop_last
    # for reset
    self.init_attrs = init_attrs
    self.original_vals = [deepcopy(getattr(self, attr)) for attr in init_attrs]  

  def __call__(self, b, indices):
    "Process one batch `b` (dict of column -> list of values) at dataset positions `indices`."
    # `nlp.Dataset.map` first test with several samples which affects our attrs, so we need to reinitialize.
    if 0 in indices: # reset
      for attr,val in zip(self.init_attrs, self.original_vals): setattr(self, attr, deepcopy(val))

    self.new_b = { c:[] for c in self.out_cols }
    # Read the columns in `inp_cols` order so the values line up with `accumulate`'s
    # parameters regardless of the key order of the batch dict.
    for z in tqdm(list(zip(*(b[c] for c in self.inp_cols))), leave=False):
      self.accumulate(*z)
    
    # whether to put the last example when this is the last batch of `map`
    if not self.drop_last and self.last_idx in indices: 
      self.commit_example(self.create_example())

    return self.new_b

  def commit_example(self, example):
    "Append `example` (dict keyed by `out_cols`) to the current output batch; None is ignored."
    if example is None: return
    for col,val in example.items():
      self.new_b[col].append(val) 

  def accumulate(self, *args): raise NotImplementedError
  def create_example(self): raise NotImplementedError

  def map(self, batch_size=1000, test_batch_size=20, **kwargs):
    """Run the transform over `hf_dset` with a batched, index-aware `nlp.Dataset.map`.

    A dry run on the first `test_batch_size` samples (fewer if the dataset is smaller)
    checks the produced columns and infers the arrow schema of the output.
    """
    n_test = min(test_batch_size, len(self.hf_dset))
    test_inputs, test_indices = self.hf_dset[:n_test], list(range(n_test))
    test_output = self(test_inputs,test_indices)
    for col,val in test_output.items(): assert val, f"Didn't get any example in test, you might want to try larger `test_batch_size` than {test_batch_size}"
    assert sorted(self.out_cols) == sorted(test_output.keys()), f"Output columns are {self.out_cols}, but get example with {list(test_output.keys())}"
    arrow_schema = pa.Table.from_pydict(test_output).schema
    return self.hf_dset.map(function=self, batched=True, batch_size=batch_size, with_indices=True,
                            arrow_schema=arrow_schema, **kwargs)

class LMTransform(AggregateTransform):
  """Concatenate token-id lists and re-split them into fixed-length language-model examples.

  Ids are packed into chunks of `max_len + 1` consecutive ids (sample boundaries are
  ignored). Each chunk becomes an example whose `x_text_col` is the chunk without its last id
  and whose `y_text_col` is the chunk without its first id (y is x shifted by one).
  A leftover chunk is only emitted if it has at least 2 ids.
  """
  def __init__(self, hf_dset, max_len, text_col, x_text_col='x_text', y_text_col='y_text', **kwargs):
    self.text_col = text_col
    self.x_text_col = x_text_col
    self.y_text_col = y_text_col
    # one extra id so that both x and y end up `max_len` long
    self._max_len = max_len + 1
    # accumulator state; must exist before `super().__init__` snapshots it for resets
    self.residual_len = self._max_len
    self.new_text = []
    super().__init__(hf_dset, inp_cols=[text_col], out_cols=[x_text_col, y_text_col], init_attrs=['residual_len', 'new_text'], **kwargs)

  def accumulate(self, text): # *inp_cols
    "text: a list of indices"
    start, total = 0, len(text)
    while start < total:
      # take as many ids as fit into the chunk being built
      step = min(total - start, self.residual_len)
      self.new_text += text[start:start + step]
      start += step
      self.residual_len -= step
      if not self.residual_len:
        self.commit_example(self.create_example())

  def create_example(self):
    "Build an (x, y) example from the accumulated ids and reset the accumulators; None if fewer than 2 ids."
    collected = self.new_text
    self.new_text, self.residual_len = [], self._max_len
    if len(collected) < 2:
      return None # mark "don't commit this"
    return {self.x_text_col: collected[:-1], self.y_text_col: collected[1:]}

In [34]:
cola_val = tokenized_cola['validation']
# NOTE: LMTransform (via AggregateTransform) sets cola_val's format in place to python format
# restricted to 'text_idxs'; this also affects tokenized_cola['validation'] and the HF_Dataset
# already wrapping it in cola_dsets.
lm_dataset = LMTransform(cola_val, max_len=20, text_col='text_idxs').map()

print('Original dataset:')
print('num of samples:', len(cola['validation']))
print('second to last sentence:', cola['validation'][-2]['sentence'])
print('          last sentence:', cola['validation'][-1]['sentence'])
print('LM dataset:')
print('num of samples:', len(lm_dataset))
print('last text (x):', hf_fast_tokenizer.decode(lm_dataset[-1]['x_text']))
print('last text (y):', hf_fast_tokenizer.decode(lm_dataset[-1]['y_text']))

0%|          | 0/2 [00:00<?, ?it/s]Original dataset:
num of samples: 1043
second to last sentence: John arranged for himself to get the prize.
          last sentence: John talked to Bill about himself.
LM dataset:
num of sampels: 481
last text (x): . john talked to bill about himself
last text (y): john talked to bill about himself.


In [35]:
lm_dl = HF_Dataloader(
  HF_Dataset(lm_dataset,
             {'x_text': LMTensorText, 'y_text': TensorText},
             hf_tokenizer=hf_fast_tokenizer),
  pad_idx=hf_fast_tokenizer.pad_token_id,
  sort=False,
)
lm_dl.show_batch()

Unnamed: 0,text,text_
0,the sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey,sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey .
1,"the mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less .","mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less . as"
2,"you eat the most , you want the least . the more you would want , the less you would","eat the most , you want the least . the more you would want , the less you would eat"
3,". i demand that the more john eat , the more he pays . mary listen ##s to the grateful","i demand that the more john eat , the more he pays . mary listen ##s to the grateful dead"
4,", she gets depressed . the ang ##rier mary got , the more she looked at pictures . the higher","she gets depressed . the ang ##rier mary got , the more she looked at pictures . the higher the"
5,"stakes , the lower his expectations are . the more fred is ob ##no ##xious , the less attention you",", the lower his expectations are . the more fred is ob ##no ##xious , the less attention you should"
6,pay to him . john was lots more ob ##no ##xious than fred . the more people you give beer,to him . john was lots more ob ##no ##xious than fred . the more people you give beer to
7,", the more people get sick . the more does bill smoke , the more susan hates him . the","the more people get sick . the more does bill smoke , the more susan hates him . the more"
8,"pictures of him that appear in the news , the more embarrassed john becomes . every senator seems to become","of him that appear in the news , the more embarrassed john becomes . every senator seems to become more"


# 3. ELECTRA Dataloading

In [38]:
class ELECTRADataTransform(AggregateTransform):
  """Pack tokenized sentences into ELECTRA pre-training examples.

  Consecutive sentences (token-id lists from `in_col`) are accumulated until their total
  length reaches a target length. They are then split into one or two segments and emitted
  under `out_col` as `[CLS] seg_a [SEP]` or `[CLS] seg_a [SEP] seg_b [SEP]`, at most
  `max_length` ids long. The logic appears ported from the TensorFlow ELECTRA pre-training
  data builder; it uses the global `random` module, so output varies between runs unless
  that module is seeded.

  Args:
    hf_dset: the `nlp` dataset to transform.
    in_col: column holding each sample's token ids.
    out_col: column name of the produced examples.
    max_length: maximum number of ids in an example, including [CLS]/[SEP].
    cls_idx, sep_idx: ids of the [CLS] and [SEP] tokens.
  """
  
  def __init__(self, hf_dset, in_col, out_col, max_length, cls_idx, sep_idx):
    self.in_col, self.out_col = in_col, out_col
    self._current_sentences = []
    self._current_length = 0
    self._max_length = max_length
    self._target_length = max_length
    self.cls_idx, self.sep_idx = cls_idx, sep_idx
    super().__init__(hf_dset, inp_cols=[in_col], out_cols=[out_col], 
                    init_attrs=['_current_sentences', '_current_length', '_target_length'])

  # two functions required by AggregateTransform
  def accumulate(self, tokids):
    self.add_line(tokids)
  
  def create_example(self):
    """Build an example from the accumulated sentences, or None if nothing is accumulated.

    Returning None (which `commit_example` skips) matters at the end of the dataset: when
    the last sample itself completed an example, nothing is left, and building one anyway
    would emit a bare `[CLS] [SEP]` example.
    """
    if not self._current_sentences: return None
    input_ids = self._create_example()
    return {self.out_col: input_ids}

  def add_line(self, tokids):
    """Adds a line of text to the current example being built."""
    self._current_sentences.append(tokids)
    self._current_length += len(tokids)
    if self._current_length >= self._target_length:
      self.commit_example(self.create_example())

  def _create_example(self):
    """Creates a pre-training example from the current list of sentences."""
    # small chance to only have one segment as in classification tasks
    if random.random() < 0.1:
      first_segment_target_length = 100000
    else:
      # -3 due to not yet having [CLS]/[SEP] tokens in the input text
      first_segment_target_length = (self._target_length - 3) // 2

    first_segment = []
    second_segment = []
    for sentence in self._current_sentences:
      # the sentence goes to the first segment if (1) the first segment is
      # empty, (2) the sentence doesn't put the first segment over length or
      # (3) 50% of the time when it does put the first segment over length
      if (len(first_segment) == 0 or
          len(first_segment) + len(sentence) < first_segment_target_length or
          (len(second_segment) == 0 and
           len(first_segment) < first_segment_target_length and
           random.random() < 0.5)):
        first_segment += sentence
      else:
        second_segment += sentence

    # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
    first_segment = first_segment[:self._max_length - 2]
    second_segment = second_segment[:max(0, self._max_length -
                                         len(first_segment) - 3)]

    # prepare to start building the next example
    self._current_sentences = []
    self._current_length = 0
    # small chance for random-length instead of max_length-length example
    if random.random() < 0.05:
      self._target_length = random.randint(5, self._max_length)
    else:
      self._target_length = self._max_length

    return self._make_example(first_segment, second_segment)

  def _make_example(self, first_segment, second_segment):
    """Join two segments of token ids into one id list.

    Returns `[CLS] first [SEP]`, followed by `second [SEP]` when the second segment is non-empty.
    """
    input_ids = [self.cls_idx] + first_segment + [self.sep_idx]
    if second_segment:
      input_ids += second_segment + [self.sep_idx]
    return input_ids

edset = ELECTRADataTransform(
  cola_val,
  in_col='text_idxs',
  out_col='tokids',
  max_length=50,
  cls_idx=hf_fast_tokenizer.cls_token_id,
  sep_idx=hf_fast_tokenizer.sep_token_id,
).map()

0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/1000 [00:00<?, ?it/s][A
[A
  0%|          | 0/43 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00, 109.82it/s]


In [40]:
e_dl = HF_Dataloader(
  HF_Dataset(edset, ['tokids'], hf_tokenizer=hf_fast_tokenizer),
  pad_idx=hf_fast_tokenizer.pad_token_id,
  sort=False,
)
e_dl.show_batch()

Unnamed: 0,text
0,"[CLS] the sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey . the mechanical doll wr ##ig ##gled itself loose . [SEP] if you had eaten more , you would want less . as you eat the most , [SEP]"
1,"[CLS] the more you would want , the less you would eat . i demand that the more john eat , the more he pays . [SEP] mary listen ##s to the grateful dead , she gets depressed . the ang ##rier mary got , the more she looked [SEP]"
2,"[CLS] the higher the stakes , the lower his expectations are . the more fred is ob ##no ##xious , the less attention you should pay to him . [SEP] john was lots more ob ##no ##xious than fred . the more people you give beer to , the [SEP]"
3,"[CLS] the more does bill smoke , the more susan hates him . who does john visit sally because he likes ? [SEP] the more pictures of him that appear in the news , the more embarrassed john becomes . every senator seems to become more corrupt , as [SEP]"
4,[CLS] marianne did not leave . he could not ] have been working . he can not have been working . you will believe bob . [SEP] john has not kissed mary . i said that never in my life had i seen a place like bangor . mickey [SEP]
5,[CLS] there tended to be a lot of discussion . john tried to be a good boy . john is eager . [SEP] we want john to win . the box contained the ball from the tree . the tube was escaped by gas . water bubble ##d up [SEP]
6,[CLS] the tub leaked water . what the water did to the bottle was fill it . what the water did to the whole bottle was fill it . [SEP] the tank leaked the fluid free . john lay the ball in the box . john owns the book [SEP]
7,"[CLS] most people probably consider , even though the courts didn ' t actually find , klaus guilty of murder . [SEP] mary beautifully plays the violin . clearly , john probably will immediately learn french perfectly . sue gave to bill a book . the men will all [SEP]"
8,[CLS] they represented seriously to the dean mary as a genuine linguist . us love they . [SEP] it is nice to go abroad . mary intended john to go abroad . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
