In [1]:
from IPython.core.debugger import set_trace as bk
from pathlib import Path
from functools import partial
import torch
import nlp
from tqdm import tqdm
from transformers import ElectraTokenizerFast
hf_fast_tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-generator")
from fastai2.text.all import *

In [31]:
cola = nlp.load_dataset('glue', 'cola', cache_dir='~/tmp')['validation']

class HF_TokenizeTfm():
  def __init__(self, tokenizer, cols):
    if isinstance(cols, list): cols = {c:c for c in cols}
    assert isinstance(cols, dict)
    self.cols = cols
  def __call__(self, example):
    for in_col, out_col in self.cols.items():
      example[out_col] = hf_fast_tokenizer.convert_tokens_to_ids(hf_fast_tokenizer.tokenize(example[in_col]))
    return example

tokenized_cola = cola.map(HF_TokenizeTfm(hf_fast_tokenizer, {'sentence':'text_idxs'}),
                          remove_columns=['sentence'])

1043it [00:00, 8892.32it/s]


In [34]:


class LMTransform():
  def __init__(self, max_len, ds_size, text_col, x_text_col='x_text', y_text_col='y_text'):
    assert text_col != x_text_col and text_col != y_text_col
    self._max_len = max_len + 1
    self.last_idx = ds_size - 1
    self.text_col, self.x_text_col, self.y_text_col = text_col, x_text_col, y_text_col
    self.residual_len, self.new_text = self._max_len, []

  def __call__(self, b, idxs):
    self.x_texts, self.y_texts =  [], []
    for text in tqdm(b[self.text_col], leave=False):
      self._accumulate(text)
    if self.last_idx in idxs and len(self.new_text) >= 2:
      self.x_texts.append(self.new_text[:-1])
      self.y_texts.append(self.new_text[1:])
    new_b = {self.x_text_col: self.x_texts, self.y_text_col: self.y_texts}
    for key in b: new_b[key] = [None]*len(self.x_texts) # map require the returned includes original columns
    return new_b

  def _accumulate(self, text):
    "text: a list of indices"
    usable_len = len(text)
    cursor = 0
    while usable_len != 0:
      use_len = min(usable_len, self.residual_len)
      self.new_text += text[cursor:cursor+use_len]
      self.residual_len -= use_len
      usable_len -= use_len
      cursor += use_len
      if self.residual_len == 0:
        self.x_texts.append(self.new_text[:-1])
        self.y_texts.append(self.new_text[1:])
        self.new_text = []
        self.residual_len = self._max_len
    
lm_dataset = tokenized_cola.map(LMTransform(20, len(tokenized_cola), 'text_idxs'), 
                        batched=True, batch_size=3, with_indices=True, remove_columns=tokenized_cola.column_names)

0%|          | 0/348 [00:00<?, ?it/s]
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 0/3 [00:00<?, ?it/s][A
[A
  0%|          | 

In [35]:
print('Original dataset:')
print('num of samples:', len(cola_dataset))
print('second to last sentence:', cola_dataset[-2]['sentence'])
print('          last sentence:', cola_dataset[-1]['sentence'])
print('LM dataset:')
print('num of sampels:', len(lm_dataset))
print('last text (x):', hf_fast_tokenizer.decode(lm_dataset[-1]['x_text']))
print('last text (y):', hf_fast_tokenizer.decode(lm_dataset[-1]['y_text']))

Original dataset:
num of samples: 1043
second to last sentence: John arranged for himself to get the prize.
          last sentence: John talked to Bill about himself.
LM dataset:
num of sampels: 481
last text (x): . john talked to bill about himself
last text (y): john talked to bill about himself.


In [55]:
class HF_Dataset(FilteredBase):
  
  def __init__(self, hf_dset, cols, n_inp=1, hf_tokenizer=None, pretty_show=False)
    
    # some default setting for tensor type used in decoding
    if isinstance(cols, list): 
      if n_inp==1: 
        if len(cols)==1: cols = {cols[0]: TensorText}
        elif len(cols)==2: cols = {cols[0]: TensorText, cols[1]: TensorCategory}
      else: cols = { c: noop for c in cols }
    assert isinstance(cols, dict)
    
    # make dataset output pytorch tensor
    if hf_ds.format['type'] != 'torch': 
      hf_ds.set_format( type='torch', cols=list(cols.keys()) )

    # store attributes
    store_attr(self, "hf_ds,cols,n_inp,hf_tokenizer,pretty_show")

  def __getitem__(self, idx):
    sample = self.hf_ds[idx]
    return tuple( tensor_cls(sample[col]) for col, tensor_cls in self.cols.items() )

  def decode(self, o, full=True): 
    return tuple( self._decode(o_) ) for o_ in o)

  @typedispatch
  def _decode(self, t:TensorText): 
    if self.pretty_show: text = self.hf_tokenizer.decode([idx for idx in t if idx != self.hf_tokenizer.pad_token_id])
    else: text = ' '.join(self.hf_tokenizer.convert_ids_to_tokens(t))
    return TitledStr(text)

  @typedispatch
  def _decode(self, t:LMTensorText): return self._decode[TensorText](self, t)

  @typedispatch
  def _decode(self, t:TensorCategory): return Category(t.item())
  
def HF_datasets(FilteredBase)
  def __init__(self, hs_dsets: dict):
    self.hs_dsets = hs_dsets
  def subset(self, i): return list(self.hs_dsets.values)[i]
  def __getitem__(self, split): return self.hs_dsets[split]
  @property
  def n_subsets(self): return len(self.hs_dsets)

@delegates()
class HF_Dataloader(TfmdDL):
  
  def __init__(self, dataset, pad_idx=None, sort=True, **kwargs):
    if pad_idx is not None:
      kwargs['before_batch'] = partial(pad_input_chunk, pad_idx=pad_idx, pad_first=False)
    super().__init__(dataset, **kwargs)
    self.sort = sort
    if sort:
      col = list(cols.keys())[0]
      self.lens = [ len(sample[col]) for sample in dataset ]
  
  def get_idxs(self):
    idxs = super().get_idxs()
    return sorted(idxs, key=lambda i: self.lens[i], reverse=True)

  def new(self, **kwargs):
    return super().new(dataset=self.dataset, sort=self.sort, **kwargs)

In [56]:
cola_dl = HF_Dataloader(HF_Dataset(tokenized_cola, ['text_idxs', 'label']),         
                        pad_idx=hf_fast_tokenizer.pad_token_id)
cola_dl.show_batch()

AttributeError: 'TensorText' object has no attribute 'truncate'

In [7]:
lm_dl = HF_Dataloader(lm_dataset, {'x_text':LMTensorText, 'y_text':TensorText}, sort=False)
lm_dl.show_batch()

NameError: name 'HF_Dataloader' is not defined