In [0]:
try: import fastai2
except: 
  !git clone https://github.com/richardyy1188/Pretrain-MLM-and-finetune-on-GLUE-with-fastai.git
  %pip install -q fastai2 transformers tqdm

In [0]:
%cd Pretrain-MLM-and-finetune-on-GLUE-with-fastai

from IPython.core.debugger import set_trace as bk
from functools import partial
import pickle
from tqdm import tqdm
import torch
from fastai2.text.all import *
from transformers import ElectraTokenizer
hf_tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-generator")
from _utils.hf_transformers_integration import HF_Tokenizer, HF_TextBlock, HFModelWrapper
from _utils.demo_data import load_demo_dataframe

/content/Pretrain-MLM-and-finetune-on-GLUE-with-fastai


In [0]:
# Toy corpus taken from the lyrics of "Avicii - Waiting For Love":
# seven training lines followed by two validation lines.
data = {
  'text': [
    "Monday left me broken",
    "Tuesday I was through with hoping",
    "Wednesday my empty arms were open",
    "Thursday waiting for love, waiting for love",
    "Thank the stars it's Friday",
    "I'm burning like a fire gone wild on Saturday",
    "Guess I won't be coming to church on Sunday",
    "I'll be waiting for love, waiting for love",
    "To come around",
  ],
  'is_valid': [False]*7 + [True]*2,
}
# Whitespace-separated word count of every line (not the tokenizer's token count).
data['length'] = [len(words) for words in map(str.split, data['text'])]
df = pd.DataFrame(data)
df

Unnamed: 0,text,is_valid,length
0,Monday left me broken,False,4
1,Tuesday I was through with hoping,False,6
2,Wednesday my empty arms were open,False,6
3,"Thursday waiting for love, waiting for love",False,7
4,Thank the stars it's Friday,False,5
5,I'm burning like a fire gone wild on Saturday,False,9
6,Guess I won't be coming to church on Sunday,False,9
7,"I'll be waiting for love, waiting for love",True,8
8,To come around,True,3


# 1. TextDataloader

In [0]:
@delegates()
class TextDataloader(TfmdDL):
  """
  `TfmdDL` for tokenized text that decides up front which dataset lines make up each sample.

  `agg_mode` selects how samples are built:
  * `None`: one dataset line per sample.
  * `'lines'`: consecutive whole lines are concatenated as long as they fit within `max_seq_len`.
  * `'window'`: lines are concatenated and cut into windows of `max_seq_len` tokens.
  * `'lm'`: like `'window'`, but every window holds one extra token so that `create_item`
    can return an (input, target shifted by one token) pair.

  Only the bookkeeping (`self.samples`) is computed in `__init__`; the actual tensors are
  sliced and concatenated lazily in `create_item`.
  """
  def __init__(self, dataset, max_seq_len=float('inf'), sort_by_len='desc', agg_mode=None, ignore_gt_maxlen=True, remove_heads=False, remove_tails=False, bos_idx_add=None, eos_idx_add=None, show_bar=True, samples=None, **kwargs):
    """
    dataset: indexable of tuples whose first element is a 1-d token tensor.
    max_seq_len: upper bound on the sample length, counting added bos/eos. `None` (only allowed
      when `agg_mode` is None) means no bound.
    sort_by_len: False, 'desc' or 'asc'; forced to False in 'window' and 'lm' modes.
    agg_mode: None, 'lines', 'window' or 'lm' (see class docstring).
    ignore_gt_maxlen: in None/'lines' modes, skip lines that are too long when True, raise ValueError when False.
    remove_heads / remove_tails: drop the first / last token of every line before using it.
    bos_idx_add / eos_idx_add: token id to put at the start / end of every sample (None adds nothing).
    show_bar: show a tqdm progress bar while building the records.
    samples: previously computed records (see `from_cache`); when given, nothing is recomputed.
    kwargs: forwarded to `TfmdDL.__init__`.
    """
    super().__init__(dataset, **kwargs)
    assert agg_mode in [None, 'lm', 'lines', 'window']
    assert not (agg_mode and max_seq_len is None)
    assert sort_by_len in [False, 'desc', 'asc']
    if agg_mode in ['window','lm']: sort_by_len=False # sorting makes no sense with these modes
    ignore_gt_maxlen = ignore_gt_maxlen and agg_mode in [None, 'lines'] and max_seq_len is not None
    # bos/eos are created with the device and dtype of the dataset's text so they can be concatenated with it
    first_text_tensor = dataset[0][0]
    device, dtype = first_text_tensor.device, first_text_tensor.dtype
    self.bos = torch.tensor([bos_idx_add] if bos_idx_add is not None else [], device=device, dtype=dtype)
    self.eos = torch.tensor([eos_idx_add] if eos_idx_add is not None else [], device=device, dtype=dtype)
    # Compare against None explicitly: token id 0 is a valid id and must not count as "nothing to add"
    self.add_bos_or_eos = bos_idx_add is not None or eos_idx_add is not None
    # only use [start:end] text to concatenate (if needed)
    self.start = 1 if remove_heads else 0
    self.end = -1 if remove_tails else None

    store_attr(self,'dataset,max_seq_len,sort_by_len,agg_mode,ignore_gt_maxlen,remove_heads,remove_tails,bos_idx_add,eos_idx_add,show_bar')
    
    if samples is not None: # Load from cache
      if sort_by_len: self.samples = sorted(samples, key=lambda s: s[0], reverse=(sort_by_len=='desc'))
      else: self.samples = samples
      self.n = len(samples)
      return

    self.samples = L()
    # residual_len will reset to initial_residual_len
    # `max_seq_len=None` means "unbounded"; use inf so the arithmetic below stays valid
    seq_limit = float('inf') if max_seq_len is None else max_seq_len
    # lm mode: max_seq_len text and 1 right-shift text, so take max_seq_len + 1 window
    self.initial_residual_len = seq_limit + 1 if agg_mode=='lm' else seq_limit
    # keep room for the bos / eos added to the final text
    if bos_idx_add is not None: self.initial_residual_len -= 1
    if eos_idx_add is not None: self.initial_residual_len -= 1
    # A non-positive budget would make `_accumulate_window` loop forever
    if agg_mode in ['window','lm'] and self.initial_residual_len <= 0:
      raise ValueError(f'max_seq_len={max_seq_len} leaves no room for text after reserving bos/eos in {agg_mode} mode')
    self.residual_len, self.new_sample = self.initial_residual_len, []

    for i, sample in tqdm(enumerate(dataset), desc='TextDataloader init:', total=len(dataset), disable=not show_bar):
      line_len = len(sample[0])
      if remove_heads: line_len -= 1
      if remove_tails: line_len -= 1
      
      if max_seq_len is not None and line_len > self.initial_residual_len and agg_mode in [None, 'lines']:
        if ignore_gt_maxlen: continue
        else: raise ValueError(f'The {i} th text line in dataset has length {line_len}(without removing head or tail, {len(sample[0])}), and is longer than max length {self.initial_residual_len}(without add bos or eos, {max_seq_len})')
        
      if agg_mode is None: self.samples.append( (line_len, i) )
      elif agg_mode == 'lines': self._accumulate_lines(i, line_len)
      else: self._accumulate_window(i, line_len)
    
    # flush the last, partially filled sample
    if agg_mode is not None and self.new_sample:
      if agg_mode == 'lines': self.samples.append((self.max_seq_len-self.residual_len, self.new_sample))
      else: self.samples.append(self.new_sample)

    # sort if needed
    if sort_by_len:
      self.samples.sort(key=lambda s: s[0], reverse=(sort_by_len=='desc'))
    # specify total number of samples
    self.n = len(self.samples)
      
  def _accumulate_lines(self, i, line_len):
    "Add line `i` to the sample being built if it still fits, otherwise close that sample and start a new one with line `i`."
    if line_len <= self.residual_len:
      self.new_sample.append(i)
      self.residual_len -= line_len
    else:
      # record is (sample length, [line indices])
      self.samples.append((self.max_seq_len-self.residual_len, self.new_sample))
      self.new_sample = [i]
      self.residual_len = self.initial_residual_len - line_len

  def _accumulate_window(self, i, line_len):
    "Spread the `line_len` usable tokens of line `i` over the current window and as many following windows as needed."
    usable_len = line_len
    cursor = self.start
    # `> 0` (not `!= 0`): a line shorter than the tokens removed from it has negative usable length and contributes nothing
    while usable_len > 0:
      use_len = min(usable_len, self.residual_len)
      self.new_sample.append((i, cursor, cursor+use_len))
      self.residual_len -= use_len
      usable_len -= use_len
      cursor += use_len
      if self.residual_len == 0:
        # window is full: store it and open a new one
        self.samples.append(self.new_sample)
        self.new_sample = []
        self.residual_len = self.initial_residual_len

  def create_item(self, s):
    "Build the `s`-th sample from the records in `self.samples`, adding bos/eos when configured."
    if self.agg_mode is None:
      # samples = [ (length, idx), ... ]
      idx = self.samples[s][1]
      sample = self.dataset[idx]
      line = sample[0][self.start:self.end]
      text = torch.cat([self.bos, line, self.eos]) if self.add_bos_or_eos else line
      # everything after the text in the dataset item is passed through unchanged
      return ( TensorText(text), *sample[1:] )
    elif self.agg_mode == 'lines':
      # samples = [ (length, [idx, idx, ...]) , ... ]
      agg = [ self.dataset[idx][0][self.start:self.end] for idx in self.samples[s][1] ]
      agg_text = concat(self.bos, *agg, self.eos) if self.add_bos_or_eos else concat(*agg)
      return (TensorText(agg_text), )
    else: # window or lm
      # samples = [ [(idx,start,end), ...], ... ]
      agg = [ self.dataset[idx][0][start:end] for idx,start,end in self.samples[s] ]
      agg_text = concat(self.bos, *agg, self.eos) if self.add_bos_or_eos else concat(*agg)
      if self.agg_mode == 'window':
        return (TensorText(agg_text), )
      else: # 'lm': input is the window without its last token, target is the window without its first token
        return (LMTensorText(agg_text[:-1]), TensorText(agg_text[1:]))

  def shuffle_fn(self, idxs):
    "Shuffle the sample records in place unless they are kept sorted by length; `idxs` is returned unchanged."
    if not self.sort_by_len: # notice sort_by_len in lm and window mode will be False
      self.samples.shuffle()
    return idxs

  def desc_sort(self):
    "Sort samples from longest to shortest (not available in 'window'/'lm' modes)."
    assert self.agg_mode not in ['window','lm'], f"Sorting by length makes no sense on aggregation mode {self.agg_mode}"
    self.samples.sort(key=lambda s: s[0], reverse=True)
    self.sort_by_len = 'desc'

  def asc_sort(self):
    "Sort samples from shortest to longest (not available in 'window'/'lm' modes)."
    assert self.agg_mode not in ['window','lm'], f"Sorting by length makes no sense on aggregation mode {self.agg_mode}"
    self.samples.sort(key=lambda s: s[0], reverse=False)
    self.sort_by_len = 'asc'

  def cache(self, file_path):
    "Save this dataloader (without its dataset, see `__getstate__`) to `file_path` with `torch.save`."
    torch.save(self, file_path)

  def __getstate__(self):
    "Return the state to pickle: a copy of `__dict__` (so the live instance is untouched) with the dataset left out."
    state = self.__dict__.copy()
    state['dataset'] = None
    return state

  #@delegates(TextDataloader.__init__) but we haven't evaluated TextDataloader
  @delegates(TfmdDL.new)
  @classmethod
  def from_cache(cls, file_path, dataset, **kwargs):
    """
    Load a dataloader saved by `cache` and attach `dataset` to it.

    `dataset` must be the dataset the cache was built from, in the same order, because the cached
    records refer to lines by index. `kwargs` are forwarded to `new`; arguments that would
    contradict the cached records are rejected with an AssertionError.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load cache files you created yourself.
    dl = torch.load(file_path)
    dl.dataset = dataset

    # Reject changes that would make arguments inconsistent with the loaded `self.samples` record
    for arg in ['max_seq_len','agg_mode','ignore_gt_maxlen','remove_heads','remove_tails']:
      assert arg not in kwargs, f"Specifying {arg} will make it inconsistent with cached internal record."
    if 'sort_by_len' in kwargs:
      assert not (dl.sort_by_len and not kwargs['sort_by_len']), f"Cached textdl is internal sorted, it can't restore orignal order."
    for arg in ['bos_idx_add','eos_idx_add']:
      if arg in kwargs: assert (kwargs[arg] is None) == (getattr(dl, arg) is None), f"You can't change whether to add head/eos from cached setting."
    # `new` guesses a validation dataloader is being created when drop_last is false, which may not be the case here,
    # so pass the cached setting explicitly
    kwargs['ignore_gt_maxlen'] = dl.ignore_gt_maxlen
    # Even with no extra kwargs, going through `new` refreshes the device and dtype of bos and eos for this dataset
    dl = dl.new(dataset, samples=dl.samples, **kwargs)
    # `new` creates the dataloader with do_setup=False, so set up newly passed transforms here.
    # Known to work for pad_input_chunk as before_batch; other transforms are untested.
    if kwargs.pop('do_setup', True):
      # read `verbose` once so it applies to every pipeline, not only the first one set up
      verbose = kwargs.pop('verbose', False)
      for nm in ['after_item','before_batch','after_batch']:
        if nm in kwargs:
          kwargs[nm] = Pipeline(kwargs.get(nm,None)) # wrapping an existing Pipeline again is harmless
          pv(f"Setting up {nm}: {kwargs[nm]}", verbose)
          kwargs[nm].setup(dl)
    return dl

  @delegates(TfmdDL.new)
  def new(self, dataset=None, **kwargs):
    """
    Create a dataloader with the same TextDataloader settings, overridden by `kwargs`.

    When the new dataloader does not drop the last batch it is assumed to be a validation/test
    loader, so `ignore_gt_maxlen` defaults to False for it; pass `ignore_gt_maxlen` explicitly to override.
    """
    cur_args = dict(max_seq_len=self.max_seq_len, sort_by_len=self.sort_by_len,agg_mode=self.agg_mode,ignore_gt_maxlen=self.ignore_gt_maxlen,remove_heads=self.remove_heads, remove_tails=self.remove_tails, bos_idx_add=self.bos_idx_add, eos_idx_add=self.eos_idx_add,show_bar=self.show_bar)
    
    # `kwargs` is a dict, so it must be queried with .get (getattr would always fall back to self.drop_last)
    if not kwargs.get('drop_last', self.drop_last): 
      cur_args['ignore_gt_maxlen'] = False # You can't discard data from dataset for validation, especially test set
    
    return super().new(dataset=dataset,
                       **merge(cur_args, kwargs)) # kwargs overwrite cur_args

# 2. Try

We'll try different parameters of `TextDataloader` to show its capabilities, but **!! this doesn't mean these are the best practices. !!**

In [0]:
# DataBlock with a single text block: texts are read from the 'text' column and
# tokenized with `hf_tokenizer`; ColSplitter() does the train/valid split
# (presumably on the `is_valid` column of `df` — confirm against fastai's default).
db = DataBlock(splitter=ColSplitter(),
              blocks=HF_TextBlock.from_df('text', hf_tokenizer),
              get_x=ColReader('text'),)

Default behavior:
* One line per sample.
* Samples are sorted by length, so samples of similar length end up in the same batch, which reduces the number of pad tokens.

In [0]:
# TextDataloader with its defaults: agg_mode=None (one line per sample) and
# sort_by_len='desc' (longest samples first); the progress bar is turned off.
default_dls = db.dataloaders(df, bs=4, dl_type=TextDataloader, show_bar=False)
default_dls.show_batch(max_n=4)
"""
We sort the sample by its length.
Observe that the 3rd sample of batch is Friday (9 tokens) but not Thursday (10 tokens), 
thus we can reduce number of pad need to add, 
becuase we have to make all samples in a batch the same legth.
"""
print('x batch size:', default_dls.one_batch()[0].shape)

Unnamed: 0,text
0,[CLS] i ' m burning like a fire gone wild on saturday [SEP]
1,[CLS] guess i won ' t be coming to church on sunday [SEP]
2,"[CLS] thursday waiting for love , waiting for love [SEP] [PAD] [PAD] [PAD]"
3,[CLS] thank the stars it ' s friday [SEP] [PAD] [PAD] [PAD] [PAD]


x batch size: torch.Size([4, 13])


**Window mode**
* Want to use broader context
* sliding context window
* Less padding (only the last sample may need pad tokens).
* Every sample has length `max_seq_len`, except possibly the last sample, which may be shorter.

In [0]:
# Window mode: the first and last token of every line are dropped, the lines are
# concatenated and cut into windows, and CLS / SEP are added around each window.
# With max_seq_len=15 and two added tokens, each window carries 13 text tokens.
window_dls = db.dataloaders(df, shuffle_train=False, bs=2,
                     dl_type=partial(TextDataloader,
                                     max_seq_len=15,
                                     agg_mode='window',
                                     remove_heads=True,
                                     remove_tails=True,
                                     bos_idx_add=hf_tokenizer.cls_token_id,
                                     eos_idx_add=hf_tokenizer.sep_token_id))
window_dls.show_batch(max_n=2)
"""
To use CLS...SEP format, first remove heads(CLS) and tails(SEP) for every line,
and then add bos(CLS) and eos(SEP) to the head and tail of concatenated sequence.
"""
print('x batch size:', window_dls.one_batch()[0].shape)

TextDataloader init:: 100%|██████████| 7/7 [00:00<00:00, 1287.22it/s]
TextDataloader init:: 100%|██████████| 2/2 [00:00<00:00, 724.22it/s]


Unnamed: 0,text
0,[CLS] monday left me broken tuesday i was through with hoping wednesday my empty [SEP]
1,"[CLS] arms were open thursday waiting for love , waiting for love thank the [SEP]"


x batch size: torch.Size([2, 15])


**Lines mode**
* Use this when you want to attend to a wider context but don't want sentences to be split apart.
* Lines are concatenated sequentially, and only whole lines are used.
* Note that `max_seq_len` is only an upper bound on the sample length, and increasing it does not necessarily increase the number of pad tokens used.

In [0]:
# Lines mode: whole lines are concatenated while they fit within max_seq_len=13.
# Only the first token of every line is dropped (the last one is kept), and a
# single CLS is added at the start of each concatenated sample.
lines_dls = db.dataloaders(df, shuffle_train=False, bs=2,
                     dl_type=partial(TextDataloader,
                                     max_seq_len=13,
                                     agg_mode='lines',
                                     remove_heads=True,
                                     bos_idx_add=hf_tokenizer.cls_token_id))
lines_dls.show_batch(max_n=2)
"""
To get CLS ... SEP ... SEP format, we remove head (CLS) for every line,
and add back an bos (CLS) to head of concated sample.
"""
print('x batch size:', lines_dls.one_batch()[0].shape)

TextDataloader init:: 100%|██████████| 7/7 [00:00<00:00, 1475.98it/s]
TextDataloader init:: 100%|██████████| 2/2 [00:00<00:00, 508.96it/s]


Unnamed: 0,text
0,[CLS] monday left me broken [SEP] tuesday i was through with hoping [SEP]
1,[CLS] i ' m burning like a fire gone wild on saturday [SEP]


x batch size: torch.Size([2, 13])


**(Traditional) Language model mode**
* Predict the i-th token of y using tokens 0 to i-1 of x.
* Sliding context window.
* Only the last sample may need pad tokens.
* Every sample has length `max_seq_len`, except possibly the last sample, which may be shorter.

In [0]:
# Language-model mode: windows of max_seq_len + 1 tokens are cut from the
# concatenated lines; x is the window without its last token and y the window
# without its first token, so both have length 7 here. No tokens are removed or added.
lm_dls = db.dataloaders(df, shuffle_train=False, bs=2,
                        dl_type=partial(TextDataloader,
                                        max_seq_len=7,
                                        agg_mode='lm',))
lm_dls.show_batch(max_n=2)
print('x batch size:', lm_dls.one_batch()[0].shape)

TextDataloader init:: 100%|██████████| 7/7 [00:00<00:00, 1413.86it/s]
TextDataloader init:: 100%|██████████| 2/2 [00:00<00:00, 1175.04it/s]


Unnamed: 0,text,text_
0,[CLS] monday left me broken [SEP] [CLS],monday left me broken [SEP] [CLS] tuesday
1,i was through with hoping [SEP] [CLS],was through with hoping [SEP] [CLS] wednesday


x batch size: torch.Size([2, 7])


# 3. Speed comparison to existing dataloader for text

Create the datasets first, so that the time spent creating them is not included in the measurements.

In [0]:
# Load the larger demo DataFrame and build its datasets once, outside of the timed code.
another_df = load_demo_dataframe()
print('Size of this demo dataset: ', len(another_df))
another_datasets = db.datasets(another_df)

Size of this demo dataset:  14489


## 3.1 Compare time for initialization

In [0]:
def dataloaders_from_db_and_datasets(db, dsets, path='.', verbose=False, **kwargs):
    """
    Create dataloaders from already-built `dsets`, using the settings stored on `db`.

    The keyword arguments given to `dsets.dataloaders` are `db.dls_kwargs`, overridden by
    `kwargs`, with `verbose` always taken from this function's `verbose` argument.
    `db.item_tfms` and `db.batch_tfms` are passed as `after_item` and `after_batch`.
    """
    dl_kwargs = dict(db.dls_kwargs)
    dl_kwargs.update(kwargs)
    dl_kwargs['verbose'] = verbose
    return dsets.dataloaders(
        path=path,
        after_item=db.item_tfms,
        after_batch=db.batch_tfms,
        **dl_kwargs,
    )
# Bind the DataBlock and the prebuilt datasets so that only dataloader options vary between timed calls.
get_dataloaders = partial(dataloaders_from_db_and_datasets, db, another_datasets)

In [0]:
# Initialization time: fastai's SortedDL vs. TextDataloader sorted by length (progress bar off).
%timeit get_dataloaders(dl_type=SortedDL)
%timeit get_dataloaders(dl_type=partial(TextDataloader, sort_by_len='desc', show_bar=False))

1 loop, best of 3: 8.27 s per loop
1 loop, best of 3: 7.39 s per loop


In [0]:
# Initialization time: fastai's LMDataLoader vs. TextDataloader in 'lm' mode (progress bar off).
%timeit get_dataloaders(dl_type=LMDataLoader)
%timeit get_dataloaders(dl_type=partial(TextDataloader, max_seq_len=72, agg_mode='lm',show_bar=False))

1 loop, best of 3: 7.53 s per loop
1 loop, best of 3: 7.53 s per loop


## 3.2 Compare time for load batches

In [0]:
# Re-create the dataloaders here: assignments made inside %timeit are local to it
# and do not reach the notebook namespace.
# Four progress bars appear because each of the two TextDataloader-based dls builds a train and a valid loader.
sorted_dls = get_dataloaders(dl_type=SortedDL)
my_sorted_dls = get_dataloaders(dl_type=partial(TextDataloader, sort_by_len='desc'))
LM_dls = get_dataloaders(dl_type=LMDataLoader)
my_LM_dls = get_dataloaders(dl_type=partial(TextDataloader, max_seq_len=72, agg_mode='lm',))

TextDataloader init:: 100%|██████████| 6938/6938 [00:04<00:00, 1720.30it/s]
TextDataloader init:: 100%|██████████| 7551/7551 [00:04<00:00, 1687.07it/s]
TextDataloader init:: 100%|██████████| 6938/6938 [00:04<00:00, 1706.60it/s]
TextDataloader init:: 100%|██████████| 7551/7551 [00:04<00:00, 1696.23it/s]


In [0]:
# Time one full pass over the training batches: SortedDL vs. sorted TextDataloader.
%timeit for b in sorted_dls.train: pass
%timeit for b in my_sorted_dls.train: pass

1 loop, best of 3: 5.44 s per loop
1 loop, best of 3: 5.76 s per loop


In [0]:
# Time one full pass over the training batches: LMDataLoader vs. TextDataloader in 'lm' mode.
%timeit for b in LM_dls.train: pass
%timeit for b in my_LM_dls.train: pass

1 loop, best of 3: 9.16 s per loop
1 loop, best of 3: 6.17 s per loop


# 4. Cache
So you don't need to initialize the dataloader from scratch every time.

Note that we mainly cache the internal record of which samples should be concatenated with which, not the dataset itself. If you want a cacheable dataset, take a look at huggingface/nlp.

You should pass the same dataset; in particular, the order of its samples must be the same as in the original one.

In [0]:
# Rebuild the datasets from the same DataFrame: a cached TextDataloader is saved
# without its dataset, so one has to be supplied again when loading.
same_datasets = db.datasets(df)

In [0]:
def dataloaders_from_cache(db, source, file_paths, path='.', device=None, **kwargs):
  """
  Rebuild a `DataLoaders` from cached `TextDataloader` files.

  db: DataBlock used to create the datasets from `source`.
  source: the data the cached dataloaders were originally built from (same content, same order).
  file_paths: one cache file per subset; the i-th file is paired with `datasets.subset(i)`.
  path: passed to `DataLoaders`.
  device: device for the `DataLoaders`; `default_device()` is used when None.
  kwargs: forwarded to `TextDataloader.from_cache` for every file.
  """
  # Honour an explicitly requested device; fall back to the default one only when none is given
  if device is None: device = default_device()
  datasets = db.datasets(source)
  dl_s = L()
  for i, f in enumerate(L(file_paths).map(Path)):
    dl_s.append( TextDataloader.from_cache(f, datasets.subset(i), **kwargs) )
  return DataLoaders(*dl_s, path=path, device=device)

## 4.1 Caching and loading from cache

In [0]:
# Save the default training dataloader, load it back with freshly built datasets and show a batch.
default_dls.train.cache('default_train.pth')
loaded_default_dl = TextDataloader.from_cache('default_train.pth', same_datasets.subset(0))
loaded_default_dl.show_batch()

Unnamed: 0,text
0,[CLS] i ' m burning like a fire gone wild on saturday [SEP]
1,[CLS] guess i won ' t be coming to church on sunday [SEP]
2,"[CLS] thursday waiting for love , waiting for love [SEP] [PAD] [PAD] [PAD]"
3,[CLS] thank the stars it ' s friday [SEP] [PAD] [PAD] [PAD] [PAD]


In [0]:
# Ordinary dataloader arguments such as bs and pin_memory can be changed when loading from cache.
print('Loaded:', TextDataloader.from_cache('default_train.pth', same_datasets.subset(0),
                                        bs=3,pin_memory=True).one_batch()[0].shape )

Loaded: torch.Size([3, 13])


In [0]:
# Cache the validation dataloader of the window-mode dls and load it with the validation subset (index 1).
window_dls.valid.cache('window_valid.pth')
loaded_window_dl = TextDataloader.from_cache('window_valid.pth', same_datasets.subset(1))
loaded_window_dl.show_batch()

Unnamed: 0,text
0,"[CLS] i ' ll be waiting for love , waiting for love to come [SEP]"
1,[CLS] around [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


You can also change some arguments of `TextDataloader` when loading from cache.

In [0]:
# The cached loader already adds a bos token, so replacing it with another token id is accepted:
# only switching between "add" and "don't add" is rejected by from_cache.
TextDataloader.from_cache('window_valid.pth', same_datasets.subset(1), 
                          bos_idx_add=hf_tokenizer.unk_token_id).show_batch()

Unnamed: 0,text
0,"[UNK] i ' ll be waiting for love , waiting for love to come [SEP]"
1,[UNK] around [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


To load a whole `DataLoaders` from cache, you can simply follow the `dataloaders_from_cache` helper defined above.

Actually, to make a `DataLoaders` from individual dataloaders, all you need to do is pass the dataloaders together with `path` and `device`. Easy.

In [0]:
# Cache both loaders of the lines-mode dls and rebuild a DataLoaders from the two cache files.
# The file order matters: the i-th file is paired with the i-th subset (train first, then valid).
lines_dls.train.cache('lines_train.pth')
lines_dls.valid.cache('lines_valid.pth')
loaded_lines_dls = dataloaders_from_cache(db, df, ['lines_train.pth','lines_valid.pth'])
loaded_lines_dls.show_batch()

Unnamed: 0,text
0,[CLS] monday left me broken [SEP] tuesday i was through with hoping [SEP]
1,[CLS] i ' m burning like a fire gone wild on saturday [SEP]


In [0]:
#bk()
dataloaders_from_cache(db, df, ['lines_train.pth','lines_valid.pth'], 
                       before_batch=partial(pad_input_chunk, pad_first=False, 
                                            pad_idx=hf_tokenizer.unk_token_id)).show_batch()

Unnamed: 0,text
0,[CLS] monday left me broken [SEP] tuesday i was through with hoping [SEP]
1,[CLS] i ' m burning like a fire gone wild on saturday [SEP]


In [0]:
# Same round trip for the language-model dls: cache train and valid, then rebuild from the files.
lm_dls.train.cache('lm_train.pth')
lm_dls.valid.cache('lm_valid.pth')
loaded_lm_dls = dataloaders_from_cache(db, df, ['lm_train.pth','lm_valid.pth'])
loaded_lm_dls.show_batch()

Unnamed: 0,text,text_
0,[CLS] monday left me broken [SEP] [CLS],monday left me broken [SEP] [CLS] tuesday
1,i was through with hoping [SEP] [CLS],was through with hoping [SEP] [CLS] wednesday


`TextDataloader` also rejects changes that would make it inconsistent with the loaded internal records.

So you are safe from accidentally mixing incompatible settings.

In [0]:
# Expected to fail with AssertionError: the lm loaders were cached without a bos token,
# so asking for one now would contradict the cached records and is rejected by from_cache.
dataloaders_from_cache(db, df, ['lm_train.pth','lm_valid.pth'],  
                       bos_idx_add=hf_tokenizer.unk_token_id).show_batch()

AssertionError: ignored