Skip to content

Commit

Permalink
Merge pull request #13 from shoarora/tokenize-padding
Browse files (browse the repository at this point in the history)
Move padding to forward
  • Loading branch information
shoarora committed Apr 7, 2020
2 parents 655e357 + c587390 commit 86f31ba
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 22 deletions.
24 changes: 15 additions & 9 deletions lmtuners/datasets/pretokenized.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch
from lmtuners.utils import mask_tokens
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import ConcatDataset, Dataset

from lmtuners.utils import mask_tokens


class PreTokenizedFileDataset(Dataset):
def __init__(self, path):
Expand Down Expand Up @@ -47,20 +49,24 @@ def __init__(self,
self.rand_replace = rand_replace

def __call__(self, examples):
inputs, attention_masks, special_tokens_masks, token_type_ids = zip(*examples)
inputs = torch.stack(inputs).long()
attention_masks = torch.stack(attention_masks).long()
special_tokens_masks = torch.stack(special_tokens_masks)
inputs, attention_masks, special_tokens_masks, token_type_ids = zip(
*examples)
inputs = pad_sequence(inputs, batch_first=True, padding_value=self.pad_token_id).long()
attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0).long()
special_tokens_masks = pad_sequence(special_tokens_masks, batch_first=True, padding_value=1)

if token_type_ids[0] is not None:
token_type_ids = torch.stack(token_type_ids).long()
token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=1).long()
else:
token_type_ids = None

if self.mlm:
inputs, labels = mask_tokens(inputs, special_tokens_masks,
self.pad_token_id, self.mask_token_id,
self.vocab_size, self.mlm_prob,
inputs, labels = mask_tokens(inputs,
special_tokens_masks,
self.pad_token_id,
self.mask_token_id,
self.vocab_size,
self.mlm_prob,
rand_replace=self.rand_replace)
return inputs, labels, attention_masks, token_type_ids
else:
Expand Down
28 changes: 15 additions & 13 deletions lmtuners/utils/tokenize_and_cache_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from tokenizers import BertWordPieceTokenizer
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence


def tokenize_and_cache_data(data_dir,
Expand All @@ -22,7 +23,6 @@ def tokenize_and_cache_data(data_dir,
tokenizer = BertWordPieceTokenizer(tokenizer_path)

tokenizer.enable_truncation(max_length=max_length)
tokenizer.enable_padding(max_length=max_length)

num_tokens = 0
num_examples = 0
Expand All @@ -39,7 +39,7 @@ def tokenize_and_cache_data(data_dir,
num_tokens += result['num_tokens']

pbar.set_description(
f"{num_tokens} tokens, {num_examples} examples, {num_tokens/(num_examples*max_length)} non-pad tokens"
f"{num_tokens} tokens, {num_examples} examples"
)


Expand Down Expand Up @@ -105,10 +105,10 @@ def add_example(encoded):
random.shuffle(indices)
indices = [(indices[i], indices[i+1]) for i in range(0, len(indices)-1, 2)]
for i, j in indices:
_ids = ids[i] + ids[j][:1]
_attention_mask = attention_masks[i] + attention_masks[j][:1]
_special_tokens_mask = special_tokens_masks[i] + special_tokens_masks[j][:1]
_token_type_ids = ([0] * len(ids[i])) + ([1] * len(ids[j]))
_ids = ids[i] + ids[j][1:]
_attention_mask = attention_masks[i] + attention_masks[j][1:]
_special_tokens_mask = special_tokens_masks[i] + special_tokens_masks[j][1:]
_token_type_ids = ([0] * len(ids[i])) + ([1] * len(ids[j][1:]))
new_ids.append(_ids)
new_attention_masks.append(_attention_mask)
new_special_tokens_masks.append(_special_tokens_mask)
Expand All @@ -118,15 +118,17 @@ def add_example(encoded):
special_tokens_masks = new_special_tokens_masks
token_type_ids = new_token_type_ids

ids = [torch.tensor(i, dtype=torch.int32) for i in ids]
attention_masks = [torch.tensor(i, dtype=torch.bool) for i in attention_masks]
special_tokens_masks = [torch.tensor(i, dtype=torch.bool) for i in special_tokens_masks]
token_type_ids = [torch.tensor(i, dtype=torch.int8) for i in token_type_ids]

torch.save(
{
'ids':
torch.tensor(ids, dtype=torch.int16),
'attention_masks':
torch.tensor(attention_masks, dtype=torch.bool),
'special_tokens_masks':
torch.tensor(special_tokens_masks, dtype=torch.bool),
'token_type_ids': torch.tensor(token_type_ids, dtype=torch.int8)
'ids': pad_sequence(ids, batch_first=True, padding_value=tokenizer.token_to_id('[PAD]')),
'attention_masks': pad_sequence(attention_masks, batch_first=True, padding_value=0),
'special_tokens_masks': pad_sequence(special_tokens_masks, batch_first=True, padding_value=1),
'token_type_ids': pad_sequence(token_type_ids, batch_first=True, padding_value=1)
}, output_file)

return {'num_tokens': num_tokens, 'num_examples': num_examples}
Expand Down

0 comments on commit 86f31ba

Please sign in to comment.