In [1]:
from IPython.core.debugger import set_trace as bk
from pathlib import Path
from functools import partial
import torch
import nlp
from tqdm import tqdm
from transformers import ElectraTokenizerFast
# Module-level fast ELECTRA tokenizer; the tokenizing and decoding code in
# later cells reads this global directly.
hf_fast_tokenizer = ElectraTokenizerFast.from_pretrained(
  "google/electra-small-generator",
)
from fastai2.text.all import *

In [2]:
# Load the 'cola' configuration of 'glue' through `nlp`, caching under ~/tmp.
cola = nlp.load_dataset(
  'glue',
  'cola',
  cache_dir='~/tmp',
)

def HF_TokenizeTfm(cols):
  """
  Build a per-example function that tokenizes text columns into token ids.

  cols: either a list of column names (each column is overwritten with its
        ids) or a dict mapping input column -> output column.
  Returns a function taking one example (a dict), which writes the ids
  produced by the module-level `hf_fast_tokenizer` into the output columns
  and returns that same example.
  """
  col_map = dict(zip(cols, cols)) if isinstance(cols, list) else cols
  assert isinstance(col_map, dict)

  def _to_ids(example):
    for src, dst in col_map.items():
      tokens = hf_fast_tokenizer.tokenize(example[src])
      example[dst] = hf_fast_tokenizer.convert_tokens_to_ids(tokens)
    return example

  return _to_ids

# Tokenize every split: 'sentence' is turned into token ids stored under
# 'text_idxs', and the raw 'sentence' column is dropped afterwards.
tokenized_cola = {
  split: dset.map(HF_TokenizeTfm({'sentence':'text_idxs'}), remove_columns=['sentence'])
  for split, dset in cola.items()
}

# 1. Integration with fastai

In [3]:
@delegates()
class HF_Dataloader(TfmdDL):
  "`TfmdDL` that can pad every batch and can serve samples ordered by the length of their first element."

  def __init__(self, dataset, pad_idx, sort=True, **kwargs):
    """
    dataset: indexable, iterable dataset; `sample[0]` is the element whose length is used for sorting.
    pad_idx: padding value; when not None, `pad_input_chunk(pad_idx=pad_idx, pad_first=False)`
             is installed as `before_batch` (replacing any `before_batch` given in `kwargs`).
    sort:    when True, every sample's length is computed up front (this iterates the
             whole dataset once) and `get_idxs` orders samples from longest to shortest.
    kwargs:  forwarded to `TfmdDL.__init__`.
    """
    if pad_idx is not None:
      kwargs['before_batch'] = partial(pad_input_chunk, pad_idx=pad_idx, pad_first=False)
    if sort:
      self.lens = [ len(sample[0]) for sample in dataset ]
    store_attr(self, 'pad_idx,sort')
    super().__init__(dataset, **kwargs)

  def get_idxs(self):
    "Indices from the parent class, reordered by descending sample length when `sort` is on."
    idxs = super().get_idxs()
    if not self.sort : return idxs
    # `sorted` is stable, so samples of equal length keep the order the parent produced
    return sorted(idxs, key=lambda i: self.lens[i], reverse=True)

  def new(self, dataset=None, **kwargs):
    """
    Create a dataloader of the same kind over `dataset` (this loader's own dataset when None).

    `pad_idx` and `sort` are carried over from this loader unless overridden in `kwargs`.
    The given `dataset` must be honoured: `dataloaders()` builds the loaders of the
    other splits by calling `new` with each split, so falling back to `self.dataset`
    unconditionally would make every split iterate over the first one.
    """
    if dataset is None: dataset = self.dataset
    # caller-supplied values win; this also avoids a duplicate-keyword TypeError
    kwargs = {'pad_idx': self.pad_idx, 'sort': self.sort, **kwargs}
    return super().new(dataset=dataset, **kwargs)

class HF_Dataset(FilteredBase):
  "Wrap a HuggingFace/nlp dataset so that indexing returns a tuple of typed tensors, one per selected column."

  def __init__(self, hf_dset, cols, hf_tokenizer=None, pretty_show=False, n_inp=1):
    """
    hf_dset:      HuggingFace/nlp dataset to wrap. NOTE: its output format is switched
                  in place to 'torch' (restricted to the selected columns) unless it
                  already reports the 'torch' format.
    cols:         dict {column name: callable applied to that column's value}, or a list
                  of column names. With `n_inp==1` a list of one column maps to
                  `TensorText` and a list of two columns maps to
                  (`TensorText`, `TensorCategory`); with any other `n_inp` every listed
                  column is passed through `noop`.
    hf_tokenizer: tokenizer used to decode `TensorText`; required only for decoding.
    pretty_show:  when True, decoding drops pad ids and uses `hf_tokenizer.decode`;
                  otherwise the individual tokens are joined with spaces.
    n_inp:        stored on the instance; also selects the defaults described for `cols`.

    Raises AssertionError when `cols` cannot be resolved to a dict.
    """
    # some default setting for tensor type used in decoding
    if isinstance(cols, list): 
      if n_inp==1: 
        if len(cols)==1: cols = {cols[0]: TensorText}
        elif len(cols)==2: cols = {cols[0]: TensorText, cols[1]: TensorCategory}
      else: cols = { c: noop for c in cols }
    # A list of unsupported length (with n_inp==1) or any other type ends up here
    assert isinstance(cols, dict), (
      "`cols` must be a dict {column: tensor type}, or a list of 1 or 2 column names "
      f"when n_inp==1 (a list of any length when n_inp!=1); got {cols!r} with n_inp={n_inp}"
    )
    
    # make dataset output pytorch tensor
    if hf_dset.format['type'] != 'torch': 
      hf_dset.set_format( type='torch', columns=list(cols.keys()) )

    # store attributes
    store_attr(self, "hf_dset,cols,n_inp,hf_tokenizer,pretty_show")

  def __getitem__(self, idx):
    "Return sample `idx` as a tuple with one entry per column, in the order of `cols`."
    sample = self.hf_dset[idx]
    return tuple( tensor_cls(sample[col]) for col, tensor_cls in self.cols.items() )

  def __len__(self): return len(self.hf_dset)

  @property
  def column_names(self): return list(self.cols.keys())

  def decode(self, o, full=True): 
    "Decode every element of the sample tuple `o` by type; `full` is accepted but not used here."
    return tuple( self._decode(o_) for o_ in o )

  # Token ids -> displayable text (needs the tokenizer)
  @typedispatch
  def _decode(self, t:TensorText): 
    assert self.hf_tokenizer, "You should give huggingface tokenizer if you want to show batch."
    if self.pretty_show: text = self.hf_tokenizer.decode([idx for idx in t if idx != self.hf_tokenizer.pad_token_id])
    else: text = ' '.join(self.hf_tokenizer.convert_ids_to_tokens(t))
    return TitledStr(text)

  # Language-model text is decoded exactly like plain text
  @typedispatch
  def _decode(self, t:LMTensorText): return self._decode[TensorText](self, t)

  # Category tensor -> `Category` built from its Python scalar
  @typedispatch
  def _decode(self, t:TensorCategory): return Category(t.item())
  
class HF_Datasets(FilteredBase):
  "Collection of `HF_Dataset`s keyed by split name, loaded with `HF_Dataloader`."
  _dl_type = HF_Dataloader
  _dbunch_type = DataLoaders

  def __init__(self, hs_dsets: dict, *args, **kwargs):
    "Wrap each dataset of `hs_dsets` in an `HF_Dataset`, forwarding `args`/`kwargs` to it."
    wrapped = {}
    for split_name, raw_dset in hs_dsets.items():
      wrapped[split_name] = HF_Dataset(raw_dset, *args, **kwargs)
    self.hs_dsets = wrapped

  def subset(self, i):
    "Return the `i`-th split, counted in the insertion order of the splits dict."
    splits = list(self.hs_dsets.values())
    return splits[i]

  def __getitem__(self, split):
    "Return the split stored under the name `split`."
    return self.hs_dsets[split]

  @property
  def n_subsets(self):
    "Number of splits held."
    return len(self.hs_dsets)

In [4]:
# Wrap every tokenized split, build padded dataloaders and preview one batch.
cola_dsets = HF_Datasets(
  tokenized_cola,
  ['text_idxs', 'label'],
  hf_fast_tokenizer,
  pretty_show=True,
)
cola_dls = cola_dsets.dataloaders(
  bs=32,
  pad_idx=hf_fast_tokenizer.pad_token_id,
)
cola_dls.show_batch()

Unnamed: 0,text,category
0,"everybody who has ever, worked in any office which contained any typewriter which had ever been used to type any letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean.",1
1,"hank plays the guitar and finds arrangements for all the old folk songs which are still sung in these hills, and ernie writes down all the old folk songs which are still sung in these hills.",1
2,"playing with matches is ; lots of fun, but doing, so and emptying gasoline from one can to another at the same time is a sport best reserved for arsons.",1
3,"in january 2002, a dull star in an obscure constellation suddenly became 600, 000 times more luminous than our sun, temporarily making it the brightest star in our galaxy.",1
4,"which folks up at corporate headquarters do you think that the sooner you solve this problem, the quicker you'll be able to tell t to buzz off?",0
5,"the dumplings which sasha is gobbling down faster than i can reheat the meatballs are extremely tasty, if i do say so.",1
6,"will put a picture of bill on your desk before tomorrow, this girl in the red coat will put a picture of bill on your desk before tomorrow.",0
7,a burlap sack of potatoes with mealy skins fell on the professor of linguistics with the terrible taste in t - shirts from the twelfth story.,1
8,"one of the jewish children is a spunky girl, who gave a black eye to the kid with the german roots before the start of the war.",1


# 2. Aggregate samples of HuggingFace/nlp dataset
~ using the traditional language-modeling task as an example ~

In [5]:
class AggregateTransform():
  """
  Base for batched transforms applied through the dataset's own `map`.

  The instance itself is handed to `map` as the function (batched, with
  indices), so a subclass must be callable with `(batch, indices)`. Every
  column the dataset had at construction time is removed from the result,
  which is why output columns may not reuse an existing column name.
  """
  def __init__(self, hf_dset, inp_cols, out_cols):
    """
    hf_dset:  HuggingFace/nlp dataset; its format is reset in place to plain Python over all columns.
    inp_cols: column names that must already exist in `hf_dset`.
    out_cols: column names that must not exist in `hf_dset` yet.
    """
    for required in inp_cols:
      assert required in hf_dset.column_names, f"{required} is not in {hf_dset.column_names}"
    for produced in out_cols:
      assert produced not in hf_dset.column_names, "Because we will remove existing columns after transform, output columns can't share the same names with existing columns"
    # HuggingFace nlp dataset:
    # a batched map needs the dataset to be in plain Python format
    hf_dset.set_format(type=None, columns=hf_dset.column_names)
    self.original_cols = hf_dset.column_names
    self.hf_dset = hf_dset

  def map(self, batch_size=1000, **kwargs):
    "Run this transform over the dataset in batches of `batch_size`; extra `kwargs` go to the dataset's `map`."
    fixed = dict(
      function=self,
      batched=True,
      batch_size=batch_size,
      with_indices=True,
      remove_columns=self.original_cols,
    )
    return self.hf_dset.map(**fixed, **kwargs)

class LMTransform(AggregateTransform):
  """
  Turn a dataset of token-id lists into language-model samples.

  The id lists found in `text_col` are concatenated in dataset order and cut
  into chunks of `max_len + 1` ids. Each chunk yields one sample:
  `x_text_col` holds the chunk without its last id and `y_text_col` the chunk
  without its first id (i.e. `x` shifted by one position). Ids left over after
  the last full chunk are emitted as a final, shorter sample if there are at
  least two of them.
  """
  def __init__(self, hf_dset, max_len, text_col, x_text_col='x_text', y_text_col='y_text'):
    """
    hf_dset:    HuggingFace/nlp dataset (see `AggregateTransform` for how it is prepared).
    max_len:    length of every full `x`/`y` sample; must be at least 1.
    text_col:   existing column holding a list of token ids per sample.
    x_text_col: name of the output column for inputs.
    y_text_col: name of the output column for targets.

    Raises ValueError when `max_len` is smaller than 1 (0 would only produce
    empty samples and -1 would make the chunking loop spin forever).
    """
    if max_len < 1:
      raise ValueError(f"max_len must be at least 1, got {max_len}")
    super().__init__(hf_dset, [text_col], [x_text_col, y_text_col])
    self.text_col, self.x_text_col, self.y_text_col = text_col, x_text_col, y_text_col

    # State shared across batches: a chunk holds one extra id because x and y
    # are the same chunk shifted by one position
    self._max_len = max_len + 1
    self.last_idx = len(hf_dset) - 1
    self.x_texts, self.y_texts = [], []
    self._reset()

  def _reset(self):
    "Drop any partially filled chunk so that accumulation restarts from an empty one."
    self.residual_len, self.new_text = self._max_len, []

  def __call__(self, b, idxs):
    """
    Transform one batch.

    b:    dict {column name: list of values} for the batch.
    idxs: dataset indices of the samples in `b`.
    Returns a dict with the new `x`/`y` columns plus every original column
    filled with `None` (one entry per produced sample).
    """
    # A batch holding index 0 starts a pass over the dataset, so leftovers of
    # any earlier call must not leak into it.
    # NOTE(review): the `0/2` progress bar printed when mapping suggests the
    # dataset's `map` first probes this function on a couple of samples —
    # confirm against the installed `nlp` version.
    if 0 in idxs: self._reset()
    self.x_texts, self.y_texts =  [], []
    for text in tqdm(b[self.text_col], leave=False):
      self._accumulate(text)
    # last batch: flush the partially filled chunk, then clear it so the
    # transform can be mapped again
    if self.last_idx in idxs:
      if len(self.new_text) >= 2:
        self.x_texts.append(self.new_text[:-1])
        self.y_texts.append(self.new_text[1:])
      self._reset()
    new_b = {self.x_text_col: self.x_texts, self.y_text_col: self.y_texts}
    # map requires the returned dict to include the original columns
    for key in b: new_b[key] = [None]*len(self.x_texts)
    
    return new_b

  def _accumulate(self, text):
    "text: a list of indices; appended to the current chunk, emitting a sample each time the chunk fills up."
    usable_len = len(text)
    cursor = 0
    while usable_len != 0:
      # take as much of `text` as still fits into the current chunk
      use_len = min(usable_len, self.residual_len)
      self.new_text += text[cursor:cursor+use_len]
      self.residual_len -= use_len
      usable_len -= use_len
      cursor += use_len
      if self.residual_len == 0:
        # chunk is full: emit it and start a fresh one
        self.x_texts.append(self.new_text[:-1])
        self.y_texts.append(self.new_text[1:])
        self._reset()

In [10]:
# Re-cut the tokenized validation split into language-model samples of length 20
cola_val = tokenized_cola['validation']
lm_dataset = LMTransform(cola_val, max_len=20, text_col='text_idxs').map()

# Tail of the raw split next to the tail of the LM dataset
raw_val = cola['validation']
print('Original dataset:')
print('num of samples:', len(raw_val))
for heading, position in (('second to last sentence:', -2), ('          last sentence:', -1)):
  print(heading, raw_val[position]['sentence'])

print('LM dataset:')
print('num of sampels:', len(lm_dataset))
last_sample = lm_dataset[-1]
for heading, column in (('last text (x):', 'x_text'), ('last text (y):', 'y_text')):
  print(heading, hf_fast_tokenizer.decode(last_sample[column]))

0%|          | 0/2 [00:00<?, ?it/s]Original dataset:
num of samples: 1043
second to last sentence: John arranged for himself to get the prize.
          last sentence: John talked to Bill about himself.
LM dataset:
num of sampels: 481
last text (x): . john talked to bill about himself
last text (y): john talked to bill about himself.


In [11]:
# Inputs are typed as LM text and targets as plain text; samples keep their
# original order (no length sorting).
lm_hf_dset = HF_Dataset(
  lm_dataset,
  {'x_text':LMTensorText, 'y_text':TensorText},
  hf_fast_tokenizer,
)
lm_dl = HF_Dataloader(
  lm_hf_dset,
  sort=False,
  pad_idx=hf_fast_tokenizer.pad_token_id,
)
lm_dl.show_batch()

Unnamed: 0,text,text_
0,the sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey,sailors rode the breeze clear of the rocks . the weights made the rope stretch over the pull ##ey .
1,"the mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less .","mechanical doll wr ##ig ##gled itself loose . if you had eaten more , you would want less . as"
2,"you eat the most , you want the least . the more you would want , the less you would","eat the most , you want the least . the more you would want , the less you would eat"
3,". i demand that the more john eat , the more he pays . mary listen ##s to the grateful","i demand that the more john eat , the more he pays . mary listen ##s to the grateful dead"
4,", she gets depressed . the ang ##rier mary got , the more she looked at pictures . the higher","she gets depressed . the ang ##rier mary got , the more she looked at pictures . the higher the"
5,"stakes , the lower his expectations are . the more fred is ob ##no ##xious , the less attention you",", the lower his expectations are . the more fred is ob ##no ##xious , the less attention you should"
6,pay to him . john was lots more ob ##no ##xious than fred . the more people you give beer,to him . john was lots more ob ##no ##xious than fred . the more people you give beer to
7,", the more people get sick . the more does bill smoke , the more susan hates him . the","the more people get sick . the more does bill smoke , the more susan hates him . the more"
8,"pictures of him that appear in the news , the more embarrassed john becomes . every senator seems to become","of him that appear in the news , the more embarrassed john becomes . every senator seems to become more"
