In [68]:
import os
from functools import partial
import numpy as np

In [None]:
import tensorflow as tf

In [78]:
import jax

In [None]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

In [None]:
from datasets import Dataset, load_dataset

In [None]:
# Scratch check: build a tiny Dataset whose 'text' column holds ragged lists of token ids.
# Named toy_* so it does not shadow the tf.data pipeline `ds` built in later cells.
toy_data = {'text': [[1, 2, 3], [2, 4]], 'label': [0, 1]}
toy_ds = Dataset.from_dict(toy_data)

In [39]:
# Source corpus for this pipeline: only the training split of the HF dataset is loaded.
dataset_name = "rotten_tomatoes"
dataset = load_dataset(path=dataset_name, split="train")

In [None]:
# Tokenizer for GPT-J-6B. The cache directory defaults to the original machine-specific path;
# set the TRANSFORMERS_CACHE_DIR environment variable to use another location elsewhere.
cache_dir = os.environ.get('TRANSFORMERS_CACHE_DIR', '/nas/xd/.cache/torch/transformers/')
model_name = 'EleutherAI/gpt-j-6B'
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

In [None]:
def tokenize_function(examples):
    """Tokenize the 'text' column of a batch of examples with the notebook-level `tokenizer`."""
    return tokenizer(examples['text'])

# Batched map as in run_clm_flax.py; the original columns are dropped so only tokenizer outputs remain.
tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=None,
                                remove_columns=dataset.column_names)  # run_clm_flax.py
# Only 'input_ids' is serialized to TFRecords later, so the attention mask is discarded.
tokenized_dataset = tokenized_dataset.remove_columns('attention_mask')

In [97]:
# mesh-transformer-jax, https://www.tensorflow.org/tutorials/load_data/tfrecord
def _int64_feature(value):
    """Wrap an iterable of ints in a tf.train.Feature backed by an Int64List."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)

def write_tfrecords(sequences, fp):
    """Write each token-id sequence as one serialized tf.train.Example to the TFRecord file `fp`.

    Every record stores its sequence under the "input_ids" feature key as an Int64List.
    """
    def _to_example(token_ids):
        features = tf.train.Features(feature={"input_ids": _int64_feature(token_ids)})
        return tf.train.Example(features=features)

    with tf.io.TFRecordWriter(fp) as writer:
        for record in map(_to_example, sequences):
            writer.write(record.SerializeToString())
            
def _parse_function(example_proto):  # https://zhuanlan.zhihu.com/p/552951305
    """Decode one serialized tf.train.Example into a dict of dense tensors.

    "input_ids" is parsed as a variable-length int64 feature; int64 values are cast to
    int32 and the sparse result is densified with 0 as the fill value.
    """
    schema = {"input_ids": tf.io.VarLenFeature(tf.int64)}
    parsed = tf.io.parse_single_example(example_proto, schema)

    def _densify(sparse):
        if sparse.dtype == tf.int64:
            sparse = tf.cast(sparse, dtype=tf.int32)
        # mesh-transformer-jax instead uses tf.sparse.to_dense(tf.sparse.reorder(sparse))
        return tf.sparse.to_dense(sparse, default_value=0)

    return {key: _densify(value) for key, value in parsed.items()}

def shard(data, batch_size=None):
    """Convert every leaf of a batch pytree to NumPy and split its leading axis into `batch_size`.

    Adapted from mesh-transformer-jax (mtj).

    Args:
        data: pytree (e.g. dict) whose leaves are eager tf.Tensors (anything with a
            `.numpy()` method) or array-likes, each with a leading batch axis.
        batch_size: tuple of ints, or a single int, whose product equals the leading
            dimension of every leaf, e.g. (gradient_accumulation_steps, per-step batch).
            Trailing dimensions are kept. If None, leaves are only converted to NumPy.

    Returns:
        A pytree with the same structure holding NumPy arrays.
    """
    if batch_size is not None:
        # Accept a bare int as well as a tuple so the shape concatenation below is always tuple + tuple.
        batch_size = (int(batch_size),) if np.ndim(batch_size) == 0 else tuple(batch_size)

    def _reshape(x):
        # Eager tf.Tensors expose .numpy(); plain arrays go through np.asarray.
        arr = x.numpy() if hasattr(x, 'numpy') else np.asarray(x)
        if batch_size is None:
            return arr
        return arr.reshape(batch_size + arr.shape[1:])

    # jax.tree_map is deprecated/removed in recent JAX; jax.tree_util.tree_map works across versions.
    return jax.tree_util.tree_map(_reshape, data)
    
def prefetch(dataset, n_prefetch=None):
    """Iterate `dataset`, converting each element's leaves to NumPy, optionally prefetching to device.

    Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py

    Args:
        dataset: iterable (e.g. a tf.data.Dataset) yielding pytrees whose leaves support
            the buffer protocol (eager tf.Tensors, NumPy arrays).
        n_prefetch: if truthy, the number of batches to prefetch onto device via
            `flax.jax_utils.prefetch_to_device`; None/0 disables device prefetching.

    Returns:
        A lazy iterator over the converted elements.
    """
    def _as_numpy(tree):
        # memoryview exposes each leaf's buffer so NumPy can wrap it.
        return jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), tree)

    ds_iter = map(_as_numpy, iter(dataset))
    if n_prefetch:
        # flax is not imported anywhere else in this notebook, so import it where it is needed.
        # NOTE(review): prefetch_to_device expects leaves with a leading local-device axis — confirm
        # the batches are sharded accordingly before enabling this.
        from flax import jax_utils
        ds_iter = jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter

In [74]:
# Serialize the token ids to a TFRecord file whose name records the dataset and example count.
sequences = tokenized_dataset['input_ids']
fp = f'{dataset_name}_train_{len(sequences)}.tfrecords'
write_tfrecords(sequences=sequences, fp=fp)

In [101]:
# Stream the TFRecord file back and decode each record with _parse_function (dense int32 "input_ids").
# Shuffling, e.g. ds.shuffle(buffer_size=min(1000, len(sequences))) as in flaxmodels
# (https://zhuanlan.zhihu.com/p/552951305), is currently disabled.
ds = tf.data.TFRecordDataset(fp).map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)

In [102]:
# Each yielded batch holds `gradient_accumulation_steps` micro-batches of
# `train_mbs_per_replica * dp_size` sequences each.
gradient_accumulation_steps = 8
train_mbs_per_replica = 2  # train_micro_batch_size_per_gpu in deepspeed
mp_size, dp_size = 8, 1
train_batch_size = (gradient_accumulation_steps, train_mbs_per_replica * dp_size)

# Pad to a fixed length: 80 covers rotten_tomatoes (longest sequence == 78). Grow it when the data
# has longer sequences, because padded_batch errors if an element exceeds padded_shapes.
max_len = max(80, max(map(len, sequences), default=0))

# Fixed-shape padding (mtj uses dense_to_ragged_batch instead; gpt-neo/inputs.py also calls
# ds.repeat()). Bound to a new name so re-running this cell does not re-batch an already batched `ds`.
batched_ds = ds.padded_batch(batch_size=int(np.prod(train_batch_size)),
                             padded_shapes={'input_ids': [max_len]},
                             padding_values={'input_ids': 0}, drop_remainder=True)
batched_ds = batched_ds.prefetch(10)  # mesh-transformer-jax

# shard() calls .numpy(), which only exists on eager tensors. Inside ds.map the function is traced,
# so it fails with AttributeError: 'Tensor' object has no attribute 'numpy'
# (see https://github.com/tensorflow/tensorflow/issues/27519). Reshaping therefore happens on the
# Python side of the iterator, as in matthias-wright/flaxmodels/training/stylegan2/data_pipeline.py.
ds_iter = map(partial(shard, batch_size=train_batch_size), iter(batched_ds))

In [103]:
# Pull one batch for inspection. With for/break, an exhausted ds_iter would silently leave a
# stale `batch` from an earlier run (or none at all); fail loudly instead.
batch = next(ds_iter, None)
assert batch is not None, "ds_iter is exhausted; re-run the batching cell to recreate it"