In [26]:
import jax
import jax.numpy as jnp
import optax
import flax
import math 
import random

from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard

import numpy as np

from tqdm.notebook import tqdm

from transformers import FlaxAutoModelForCausalLM
from transformers import AutoTokenizer

from datasets import Dataset

In [31]:
MAX_SEQ_LENGTH = 9
TRAINING_SEED = 20
BATCH_SIZE = 512

In [29]:
model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2", 
                                        seed=TRAINING_SEED, 
                                        dtype=jnp.dtype("bfloat16"))

Downloading:   0%|          | 0.00/312M [00:00<?, ?B/s]



In [30]:
tokenizer = AutoTokenizer.from_pretrained(
    "../tokenizer/tokenizer_40000"
)
print(tokenizer.eos_token)

<|endoftext|>


In [32]:
seq1 = "1 2 3 4 5 6 7 8"
seq2 = "a b c d e f g h"
seq3 = "h g f e d c b a"
print(tokenizer(seq1).tokens())
print(tokenizer(seq2).tokens())
print(tokenizer(seq3).tokens())

data_list = [seq1 for _ in range(1024)]
data_list += [seq2 for _ in range(1024)]
data_list += [seq3 for _ in range(1024)]

random.shuffle(data_list)

data_dict = {"text": data_list}

dataset = Dataset.from_dict(data_dict)

['1', 'Ġ2', 'Ġ3', 'Ġ4', 'Ġ5', 'Ġ6', 'Ġ7', 'Ġ8']
['a', 'Ġb', 'Ġc', 'Ġd', 'Ġe', 'Ġf', 'Ġg', 'Ġh']
['h', 'Ġg', 'Ġf', 'Ġe', 'Ġd', 'Ġc', 'Ġb', 'Ġa']


I want my dataset iterator to take in parameters (dataset, tokenizer, number of training tokens, batch_size, seqlen). 

In [55]:
from itertools import islice
import sys

# this toy dataset assumes all dataset examples are exactly `seqlen` 
# long when tokenized and that batch_size divides training_tokens
class ToyDataIterator(): 
    def __init__(self, dataset, tokenizer, training_tokens, 
                seqlen, batch_size): 
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.seqlen = seqlen
        self.training_tokens = training_tokens
        self.batch_size = batch_size
        
    def __iter__(self): 
        # keeps looping until number of training tokens reached
        tokens_so_far = 0 
        while True: 
            iterator = iter(self.dataset)
            while iterator: 
                batch = list(islice(iterator, self.batch_size))
                tokens = tokenizer([x["text"] for x in batch],
                                  max_length=self.seqlen, 
                                  truncation=True, 
                                  return_tensors="np")
                yield tokens["input_ids"]
                tokens_so_far += self.batch_size*self.seqlen
                if tokens_so_far > self.training_tokens: 
                    return

In [56]:
iterator = ToyDataIterator(dataset, tokenizer, training_tokens=2**20, 
                          seqlen=8, batch_size=512)

for _ in tqdm(iterator): 
    pass

0it [00:00, ?it/s]

{'input_ids': array([[  72,  353,  283, ...,  278,  313,  262],
       [  17,  476,  715, ..., 1396, 1761, 1776],
       [  65,  313,  278, ...,  283,  353,  440],
       ...,
       [  17,  476,  715, ..., 1396, 1761, 1776],
       [  17,  476,  715, ..., 1396, 1761, 1776],
       [  17,  476,  715, ..., 1396, 1761, 1776]]), 'attention_mask': array([[1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1]])}
{'input_ids': array([[  17,  476,  715, ..., 1396, 1761, 1776],
       [  72,  353,  283, ...,  278,  313,  262],
       [  72,  353,  283, ...,  278,  313,  262],
       ...,
       [  17,  476,  715, ..., 1396, 1761, 1776],
       [  17,  476,  715, ..., 1396, 1761, 1776],
       [  72,  353,  283, ...,  278,  313,  262]]), 'attention_mask': array([[1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
      

IndexError: list index out of range