In [None]:
from fastai2.torch_basics import *
from fastai2.data.all import *
from fastai2.text.core import *
from fastai2.text.data import *



In [None]:
def _maybe_first(o): return o[0] if isinstance(o, tuple) else o
def _get_lengths(ds):
    "Lengths of `ds.items` as reported by the `Tokenizer` attached to `ds`, or `None` if it has none."
    tokenizer = _get_tokenizer(ds)
    return None if tokenizer is None else tokenizer.get_lengths(ds.items)
def _get_tokenizer(ds):
    "Return the `Tokenizer` found in `ds.tokenizer` (directly, or the first one in a list of transforms), else `None`."
    tok = getattr(ds, 'tokenizer', None)
    if isinstance(tok, Tokenizer): return tok
    if not isinstance(tok, (list,L)): return None
    # Pick the first `Tokenizer` among the listed transforms, if any
    return next((t for t in tok if isinstance(t, Tokenizer)), None)
            
@delegates()
class LMDataLoader(TfmdDL):
    "`DataLoader` for language modeling: serves `bs` contiguous rows of the concatenated texts, `seq_len` tokens at a time."
    def __init__(self, dataset, lens=None, cache=2, bs=64, seq_len=72, num_workers=0, **kwargs):
        # `_maybe_first` keeps only the text when a sample is a tuple
        self.items = ReindexCollection(dataset, cache=cache, tfm=_maybe_first)
        self.seq_len = seq_len
        # Prefer caller-supplied lengths, then the tokenizer's, and only as a last resort measure each item
        if lens is None: lens = _get_lengths(dataset)
        if lens is None: lens = [len(o) for o in self.items]
        # Reuse `items`' index list so `lens[i]` stays the length of `items[i]` when the order changes
        self.lens = ReindexCollection(lens, idxs=self.items.idxs)
        # The "-1" is to allow for final label, we throw away the end that's less than bs
        corpus = round_multiple(sum(lens)-1, bs, round_down=True)
        self.bl = corpus//bs #bl stands for batch length: tokens in each of the `bs` rows
        # Ceil-divide each row into `seq_len` windows; the last window holds whatever remains
        self.n_batches = self.bl//(seq_len) + int(self.bl%seq_len!=0)
        self.last_len = self.bl - (self.n_batches-1)*seq_len
        self.make_chunks()
        super().__init__(dataset=dataset, bs=bs, num_workers=num_workers, **kwargs)
        # One item per (row, batch) pair; assigned after `super().__init__`, presumably so the base class doesn't reset it
        self.n = self.n_batches*bs

    @delegates(DataLoader.new)
    def new(self, dataset=None, cls=None, **kwargs):
        "Create a loader like this one; geometry (`bl`, `n_batches`, `last_len`, `n`) is rebuilt for the new `bs`/`seq_len`."
        # Cached lengths are only valid for the same dataset; a new dataset gets measured again.
        # `.coll` is the un-reindexed list, matching the fresh (unshuffled) order of the new loader.
        if dataset is None: kwargs.setdefault('lens', self.lens.coll)
        # Keep this loader's `seq_len` unless the caller overrides it
        kwargs.setdefault('seq_len', self.seq_len)
        return super().new(dataset, cls, **kwargs)

    def make_chunks(self):
        "(Re)build the token stream over `items` in their current order."
        self.chunks = Chunks(self.items, self.lens)

    def shuffle_fn(self,idxs):
        "Shuffle the order of documents (not the sample indices, which are returned as-is) and rebuild the stream."
        self.items.shuffle()
        self.make_chunks()
        return idxs

    def create_item(self, seq):
        "Return `(input, target)` for item `seq`: row `seq%bs` of batch `seq//bs`, target shifted by one token."
        if seq>=self.n: raise IndexError
        # The final batch may be shorter than `seq_len`
        sl = self.last_len if seq//self.bs==self.n_batches-1 else self.seq_len
        # Row `seq%bs` starts at offset `row*bl` in the stream; each batch advances by `seq_len`
        st = (seq%self.bs)*self.bl + (seq//self.bs)*self.seq_len
        # Take one extra token so the target can be the input shifted by one
        txt = self.chunks[st : st+sl+1]
        return LMTensorText(txt[:-1]),txt[1:]

In [None]:
bs,sl = 4,3
ints = L([0,1,2,3,4],[5,6,7,8,9,10],[11,12,13,14,15,16,17,18],[19,20],[21,22]).map(tensor)
dl = LMDataLoader(ints, bs=bs, seq_len=sl)
# 23 tokens: one is held back for the last target (22), rounded down to a multiple of bs (20)
# -> bs=4 rows of bl=5 tokens; seq_len=3 gives one full batch and a final batch of length 2.
test_eq(list(dl),
    [[tensor([[0, 1, 2], [5, 6, 7], [10, 11, 12], [15, 16, 17]]),
      tensor([[1, 2, 3], [6, 7, 8], [11, 12, 13], [16, 17, 18]])],
     [tensor([[3, 4], [8,  9], [13, 14], [18, 19]]),
      tensor([[4, 5], [9, 10], [14, 15], [19, 20]])]])

# `new()` without overrides must reproduce the same geometry and batches
dl_new = dl.new()
test_eq(dl_new.seq_len, dl.seq_len)
test_eq(dl_new.n, dl.n)
test_eq(dl.one_batch(),dl_new.one_batch())