pack_dataset.py
from itertools import chain
from functools import partial

# dict of empty lists used to carry the token remainder from one batch over to the next
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

def pack_dataset(dataset, chunk_length=2048):
    print(f"Chunking dataset into chunks of {chunk_length} tokens.")

    def chunk(sample, chunk_length=chunk_length):
        # use the global remainder so leftover tokens carry over between batches
        global remainder
        # concatenate all texts in the batch and prepend the remainder from the previous batch
        concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
        concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
        # total number of tokens in this batch
        batch_total_length = len(concatenated_examples[list(sample.keys())[0]])
        # largest multiple of chunk_length that fits in the batch,
        # e.g. 5000 tokens with chunk_length=2048 -> 4096, leaving 904 tokens as remainder
        if batch_total_length >= chunk_length:
            batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
        else:
            # batch is shorter than one chunk: emit no chunks and carry everything over
            batch_chunk_length = 0
        # split the concatenated tokens into chunks of chunk_length
        result = {
            k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
            for k, t in concatenated_examples.items()
        }
        # save the leftover tokens for the next batch
        remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
        # labels are a copy of input_ids for causal language modeling
        result["labels"] = result["input_ids"].copy()
        return result
    # chunk the already-tokenized dataset
    lm_dataset = dataset.map(
        partial(chunk, chunk_length=chunk_length),
        batched=True,
    )

    print(f"Total number of samples: {len(lm_dataset)}")
    return lm_dataset
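

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original script).
# Shows how pack_dataset might be called on a Hugging Face `datasets` Dataset
# that has already been tokenized and stripped of non-token columns, so the
# chunked batches can cleanly replace the existing columns. The toy corpus,
# the "gpt2" tokenizer, and chunk_length=512 are placeholder choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from datasets import Dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # toy corpus; any text dataset works
    raw = Dataset.from_dict({"text": ["Hello world."] * 1000})
    tokenized = raw.map(
        lambda batch: tokenizer(batch["text"]),
        batched=True,
        remove_columns=["text"],  # keep only token columns
    )

    packed = pack_dataset(tokenized, chunk_length=512)
    print(packed)  # every row now holds exactly 512 tokens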