In [33]:
from transformers import AutoTokenizer

In [34]:
%%time
pretrained_dir = "pretrained/google/electra-small-discriminator"
tokenizer = AutoTokenizer.from_pretrained(pretrained_dir, model_max_length=10)
print(f"""{repr(tokenizer)}
model_input_names={repr(tokenizer.model_input_names)}
""")

PreTrainedTokenizerFast(name_or_path='pretrained/google/electra-small-discriminator', vocab_size=30522, model_max_len=10, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
model_input_names=['input_ids', 'token_type_ids', 'attention_mask']

CPU times: user 31.2 ms, sys: 0 ns, total: 31.2 ms
Wall time: 26.5 ms


# Truncation problems
- Should the question or the passage be truncated?
- If the passage is truncated, the answer span can be truncated as well. How do we get the indexes of a partial answer span that is spread over multiple passage chunks?
- [`BatchEncoding.char_to_token`](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding.char_to_token) does not work well with truncation: if the original string is truncated, characters in the truncated part cannot be mapped to token space.

In [35]:
# Toy extractive-QA batch: six passages, each paired with the question "one".
# The first three passages are identical; only their answer spans differ.
contexts = ["one two three one two three"] * 3 + [
    "interdependence one",
    "one excommunication two",
    "one two excommunication",
]
questions = ["one" for _ in contexts]

# Answer spans as character offsets into the matching context.
# Both bounds are inclusive: answer_end holds the index of the LAST character
# of the answer, not one past it.
answer_start = [0, 14, 0, 0, 4, 8]
answer_end = [2, 16, 12, 14, 18, 22]

# Encode the (context, question) pairs with truncation and padding disabled.
es = tokenizer(contexts, text_pair=questions, truncation=False, padding=False)
print(repr(es))

{'input_ids': [[101, 2028, 2048, 2093, 2028, 2048, 2093, 102, 2028, 102], [101, 2028, 2048, 2093, 2028, 2048, 2093, 102, 2028, 102], [101, 2028, 2048, 2093, 2028, 2048, 2093, 102, 2028, 102], [101, 6970, 3207, 11837, 4181, 3401, 2028, 102, 2028, 102], [101, 2028, 4654, 9006, 23041, 21261, 2048, 102, 2028, 102], [101, 2028, 2048, 4654, 9006, 23041, 21261, 102, 2028, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [36]:
# Convert each character-level answer span into a token-level slice of the
# encoded input, then decode that slice back to text to check the mapping.
input_ids = es["input_ids"]
for i, (start_char, end_char) in enumerate(zip(answer_start, answer_end)):
    # char_to_token(batch_index, char_index) resolves the character against
    # the first sequence of the pair (the context) by default.
    j = es.char_to_token(i, start_char)
    end_token = es.char_to_token(i, end_char)

    if j is None or end_token is None:
        # No token covers this character (e.g. it was truncated away), so
        # there is no slice to take. Report it instead of failing with a
        # TypeError on `None + 1`.
        print(
            f"i={i}: chars [{start_char}, {end_char}] cannot be mapped to tokens "
            f"(start token={j}, end token={end_token})"
        )
        continue

    # answer_end is inclusive while Python slice ends are exclusive.
    k = end_token + 1
    _ids = input_ids[i][j:k]
    a = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(_ids))
    print(f"j={j}, k={k}, _ids={_ids}, a={a}")

j=1, k=2, _ids=[2028], a=one
j=4, k=5, _ids=[2028], a=one
j=1, k=4, _ids=[2028, 2048, 2093], a=one two three
j=1, k=6, _ids=[6970, 3207, 11837, 4181, 3401], a=interdependence
j=2, k=6, _ids=[4654, 9006, 23041, 21261], a=excommunication
j=3, k=7, _ids=[4654, 9006, 23041, 21261], a=excommunication
