Merge pull request #46 from texttron/fast_tokenizer
to use Fast tokenizer
MXueguang committed Jun 27, 2022
2 parents 2eb0472 + 602133e commit b8f3390
Showing 7 changed files with 26 additions and 16 deletions.
15 changes: 9 additions & 6 deletions src/tevatron/arguments.py
@@ -99,12 +99,15 @@ def __post_init__(self):
self.dataset_split = 'train'
self.dataset_language = 'default'
if self.train_dir is not None:
-files = os.listdir(self.train_dir)
-self.train_path = [
-    os.path.join(self.train_dir, f)
-    for f in files
-    if f.endswith('jsonl') or f.endswith('json')
-]
+if os.path.isdir(self.train_dir):
+    files = os.listdir(self.train_dir)
+    self.train_path = [
+        os.path.join(self.train_dir, f)
+        for f in files
+        if f.endswith('jsonl') or f.endswith('json')
+    ]
+else:
+    self.train_path = [self.train_dir]
else:
self.train_path = None

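For readers skimming the hunk above, here is a self-contained sketch of the `train_dir` resolution it introduces; `resolve_train_paths` is a hypothetical helper name used only for illustration:

```python
import os
from typing import List, Optional

def resolve_train_paths(train_dir: Optional[str]) -> Optional[List[str]]:
    """Hypothetical standalone version of the updated __post_init__ logic."""
    if train_dir is None:
        return None
    if os.path.isdir(train_dir):
        # A directory: collect every .json / .jsonl file inside it.
        return [
            os.path.join(train_dir, f)
            for f in os.listdir(train_dir)
            if f.endswith('jsonl') or f.endswith('json')
        ]
    # A single file path: wrap it in a list so downstream code sees the same type.
    return [train_dir]
```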
4 changes: 2 additions & 2 deletions src/tevatron/data.py
@@ -30,7 +30,7 @@ def __init__(
self.total_len = len(self.train_data)

def create_one_example(self, text_encoding: List[int], is_query=False):
-item = self.tok.encode_plus(
+item = self.tok.prepare_for_model(

eugene-yang commented on Jul 1, 2022:

I recently tried to train a DPR model with the latest version. Looks like it's supposed to be .encode_plus?

MXueguang (Author) commented on Jul 1, 2022:

Hi @eugene-yang, did any error occur here?
I am making this change in order to use the fast tokenizer, since fast tokenizers don't support List[int] input for encode_plus.

eugene-yang commented on Jul 1, 2022:

This is what I got -- I haven't really dug into the trace, but it looks like the input to the method is not yet encoded into token ids. So it should probably be doing the encoding instead of just adding special token ids.

  File "/expscratch/eyang/workspace/adapter/transformers/src/transformers/trainer.py", line 1317, in train
    return inner_training_loop(
  File "/expscratch/eyang/workspace/adapter/transformers/src/transformers/trainer.py", line 1528, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/hltcoe/eyang/.conda/envs/adapter/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/home/hltcoe/eyang/.conda/envs/adapter/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 570, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/hltcoe/eyang/.conda/envs/adapter/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/hltcoe/eyang/.conda/envs/adapter/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/expscratch/eyang/workspace/adapter/tevatron/src/tevatron/data.py", line 53, in __getitem__
    encoded_query = self.create_one_example(qry, is_query=True)
  File "/expscratch/eyang/workspace/adapter/tevatron/src/tevatron/data.py", line 33, in create_one_example
    item = self.tok.prepare_for_model(
  File "/expscratch/eyang/workspace/adapter/transformers/src/transformers/tokenization_utils_base.py", line 3016, in prepare_for_model
    sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
  File "/expscratch/eyang/workspace/adapter/transformers/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py", line 176, in build_inputs_with_special_tokens
    return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
TypeError: can only concatenate list (not "str") to list

MXueguang (Author) commented on Jul 2, 2022:

Ah, I see. You are using the {'query': string, 'positives': string, 'negatives': string} format as your training data, right? I forgot to consider the on-the-fly tokenization option... yes, a quick fix is to change that line back to encode_plus. I'll have a patch to handle that case ASAP.

eugene-yang commented on Jul 2, 2022:

Ah right, I should have mentioned that. Thanks for helping! Tevatron is really a great tool!

text_encoding,
truncation='only_first',
max_length=self.data_args.q_max_len if is_query else self.data_args.p_max_len,
@@ -95,7 +95,7 @@ def __len__(self):

def __getitem__(self, item) -> Tuple[str, BatchEncoding]:
text_id, text = (self.encode_data[item][f] for f in self.input_keys)
-encoded_text = self.tok.encode_plus(
+encoded_text = self.tok.prepare_for_model(
text,
max_length=self.max_len,
truncation='only_first',
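To make the trade-off discussed in the thread above concrete, here is a rough sketch (using 'bert-base-uncased' purely as an example checkpoint) of the two input shapes: `prepare_for_model` consumes pre-tokenized id lists, which is what the fast tokenizer path needs, while a raw string still has to go through `encode_plus`, which is why the on-the-fly tokenization case broke:

```python
from transformers import AutoTokenizer

# Example checkpoint only; any model with a fast tokenizer behaves the same way.
tok = AutoTokenizer.from_pretrained('bert-base-uncased')

# Pre-tokenized input (the cached-dataset path): a list of token ids.
ids = tok.encode('what is dense retrieval?', add_special_tokens=False)

# prepare_for_model adds special tokens and truncates without re-tokenizing,
# and it accepts the List[int] input that a fast tokenizer's encode_plus rejects.
item = tok.prepare_for_model(
    ids,
    truncation='only_first',
    max_length=32,
    padding=False,
)

# Raw-string input (the on-the-fly path): prepare_for_model expects ids, so a
# plain string triggers the TypeError shown in the traceback above; encode_plus
# is the call that handles strings directly.
item_from_text = tok.encode_plus(
    'what is dense retrieval?',
    truncation='only_first',
    max_length=32,
    padding=False,
)
```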
2 changes: 1 addition & 1 deletion src/tevatron/datasets/preprocessor.py
@@ -54,4 +54,4 @@ def __call__(self, example):
add_special_tokens=False,
max_length=self.text_max_length,
truncation=True)
-return {'text_id': docid, 'text': text}
+return {'text_id': docid, 'text': text}
3 changes: 1 addition & 2 deletions src/tevatron/driver/encode.py
@@ -52,8 +52,7 @@ def main():
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-cache_dir=model_args.cache_dir,
-use_fast=False,
+cache_dir=model_args.cache_dir
)

model = DenseModel.load(
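A side note on the dropped `use_fast=False` argument: with the flag removed, `AutoTokenizer` falls back to its default of `use_fast=True` and returns the Rust-backed fast tokenizer whenever the checkpoint ships one. A quick way to check which implementation you got (again using an example checkpoint):

```python
from transformers import AutoTokenizer

# Default (use_fast=True): the fast tokenizer is returned when available.
fast_tok = AutoTokenizer.from_pretrained('bert-base-uncased')
print(fast_tok.is_fast)   # True

# Forcing the slow, pure-Python implementation, as the old code did.
slow_tok = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=False)
print(slow_tok.is_fast)   # False
```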
10 changes: 8 additions & 2 deletions src/tevatron/driver/train.py
@@ -2,6 +2,7 @@
import os
import sys

+import torch
from transformers import AutoConfig, AutoTokenizer
from transformers import (
HfArgumentParser,
@@ -66,8 +67,7 @@ def main():
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-cache_dir=model_args.cache_dir,
-use_fast=False,
+cache_dir=model_args.cache_dir
)
model = DenseModel.build(
model_args,
@@ -78,7 +78,13 @@

train_dataset = HFTrainDataset(tokenizer=tokenizer, data_args=data_args,
cache_dir=data_args.data_cache_dir or model_args.cache_dir)
+if training_args.local_rank > 0:
+    print("Waiting for main process to perform the mapping")
+    torch.distributed.barrier()
train_dataset = TrainDataset(data_args, train_dataset.process(), tokenizer)
+if training_args.local_rank == 0:
+    print("Loading results from main process")
+    torch.distributed.barrier()

trainer_cls = GCTrainer if training_args.grad_cache else Trainer
trainer = trainer_cls(
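The new `local_rank` checks implement the usual "main process first" barrier pattern: non-main ranks block until rank 0 has finished the dataset mapping, then reuse its cached result. A minimal sketch of the same pattern, with a hypothetical `run_on_main_first` helper and assuming the default process group is already initialized:

```python
import torch.distributed as dist

def run_on_main_first(local_rank: int, fn):
    """Hypothetical helper mirroring the barrier placement in the diff above."""
    if local_rank > 0:
        # Non-main ranks wait here until rank 0 reaches its barrier below.
        dist.barrier()
    result = fn()  # rank 0 does the real work; the other ranks hit the cache
    if local_rank == 0:
        # Rank 0 is done; release the waiting ranks.
        dist.barrier()
    return result
```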
3 changes: 2 additions & 1 deletion src/tevatron/faiss_retriever/__main__.py
@@ -19,7 +19,7 @@

def search_queries(retriever, q_reps, p_lookup, args):
if args.batch_size > 0:
-all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size)
+all_scores, all_indices = retriever.batch_search(q_reps, args.depth, args.batch_size, args.quiet)
else:
all_scores, all_indices = retriever.search(q_reps, args.depth)

@@ -56,6 +56,7 @@ def main():
parser.add_argument('--depth', type=int, default=1000)
parser.add_argument('--save_ranking_to', required=True)
parser.add_argument('--save_text', action='store_true')
+parser.add_argument('--quiet', action='store_true')

args = parser.parse_args()

5 changes: 3 additions & 2 deletions src/tevatron/faiss_retriever/retriever.py
@@ -2,6 +2,7 @@
import faiss

import logging
+from tqdm import tqdm

logger = logging.getLogger(__name__)

@@ -17,11 +18,11 @@ def add(self, p_reps: np.ndarray):
def search(self, q_reps: np.ndarray, k: int):
return self.index.search(q_reps, k)

-def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int):
+def batch_search(self, q_reps: np.ndarray, k: int, batch_size: int, quiet: bool=False):
num_query = q_reps.shape[0]
all_scores = []
all_indices = []
-for start_idx in range(0, num_query, batch_size):
+for start_idx in tqdm(range(0, num_query, batch_size), disable=quiet):
nn_scores, nn_indices = self.search(q_reps[start_idx: start_idx + batch_size], k)
all_scores.append(nn_scores)
all_indices.append(nn_indices)
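The `tqdm(..., disable=quiet)` wrapper changes nothing about the search itself; it only adds an optional progress bar over the query batches, which `--quiet` turns off. A toy, self-contained version of the same batching-plus-progress pattern:

```python
import numpy as np
from tqdm import tqdm

def iterate_batches(x: np.ndarray, batch_size: int, quiet: bool = False):
    # Same pattern as batch_search: tqdm wraps the batch start indices and
    # disable=quiet suppresses the progress bar entirely.
    for start_idx in tqdm(range(0, x.shape[0], batch_size), disable=quiet):
        yield x[start_idx: start_idx + batch_size]

dummy_queries = np.random.rand(1000, 768).astype('float32')
for batch in iterate_batches(dummy_queries, batch_size=128, quiet=False):
    pass  # each batch would be passed to index.search(batch, k) in the retriever
```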

0 comments on commit b8f3390
