From 9b49dc8a95a0e3129a4fdbfba3cdb312737c212f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 5 Oct 2021 16:08:58 +0200 Subject: [PATCH] Fixing question-answering with long contexts (#13873) * Tmp. * Fixing BC for question answering with long context. * Capping model_max_length to avoid tf overflow. * Bad workaround bugged roberta. * Fixing name. --- .../pipelines/question_answering.py | 218 ++++++++++-------- tests/test_modeling_led.py | 5 + tests/test_modeling_reformer.py | 1 + tests/test_pipelines_question_answering.py | 73 ++++++ 4 files changed, 199 insertions(+), 98 deletions(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index cf8c550e162fc9..f4e73917a77b3d 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -248,7 +248,13 @@ def __call__(self, *args, **kwargs): return super().__call__(examples[0], **kwargs) return super().__call__(examples, **kwargs) - def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question_len=64, max_seq_len=384): + def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None): + + if max_seq_len is None: + max_seq_len = min(self.tokenizer.model_max_length, 384) + if doc_stride is None: + doc_stride = min(max_seq_len // 4, 128) + if not self.tokenizer.is_fast: features = squad_convert_examples_to_features( examples=[example], @@ -277,7 +283,6 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question return_offsets_mapping=True, return_special_tokens_mask=True, ) - # When the input is too long, it's converted in a batch of inputs with overflowing tokens # and a stride of overlap between the inputs. If a batch of inputs is given, a special output # "overflow_to_sample_mapping" indicate which member of the encoded batch belong to which original batch sample. 
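As a rough illustration of the splitting behaviour the new defaults above rely on (this snippet is not part of the patch; the checkpoint name and the exact lengths are assumptions), a fast-tokenizer call along these lines turns one long question/context pair into several overlapping spans, each of which becomes one feature:

    from transformers import AutoTokenizer

    # Any fast tokenizer works; this checkpoint is only an example.
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")

    question = "Where was HuggingFace founded ?"
    context = "HuggingFace was founded in Paris. " * 50  # deliberately longer than one window

    # Same defaulting logic as the new preprocess():
    max_seq_len = min(tokenizer.model_max_length, 384)
    doc_stride = min(max_seq_len // 4, 128)

    encoded = tokenizer(
        question,
        context,
        truncation="only_second",        # only the context gets truncated
        max_length=max_seq_len,
        stride=doc_stride,               # overlap between consecutive spans
        return_overflowing_tokens=True,  # keep the overflow as extra spans
        return_offsets_mapping=True,
    )

    # One entry per span; the pipeline now builds one feature (and runs one
    # forward pass) per span instead of silently dropping the overflowing text.
    print(len(encoded["input_ids"]))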
@@ -308,12 +313,15 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question token_type_ids_span_idx = ( encoded_inputs["token_type_ids"][span_idx] if "token_type_ids" in encoded_inputs else None ) + submask = p_mask[span_idx] + if isinstance(submask, np.ndarray): + submask = submask.tolist() features.append( SquadFeatures( input_ids=input_ids_span_idx, attention_mask=attention_mask_span_idx, token_type_ids=token_type_ids_span_idx, - p_mask=p_mask[span_idx].tolist(), + p_mask=submask, encoding=encoded_inputs[span_idx], # We don't use the rest of the values - and actually # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample @@ -330,26 +338,41 @@ def preprocess(self, example, padding="do_not_pad", doc_stride=128, max_question qas_id=None, ) ) - return {"features": features, "example": example} + + split_features = [] + for feature in features: + fw_args = {} + others = {} + model_input_names = self.tokenizer.model_input_names + + for k, v in feature.__dict__.items(): + if k in model_input_names: + if self.framework == "tf": + tensor = tf.constant(v) + if tensor.dtype == tf.int64: + tensor = tf.cast(tensor, tf.int32) + fw_args[k] = tf.expand_dims(tensor, 0) + elif self.framework == "pt": + tensor = torch.tensor(v) + if tensor.dtype == torch.int32: + tensor = tensor.long() + fw_args[k] = tensor.unsqueeze(0) + else: + others[k] = v + split_features.append({"fw_args": fw_args, "others": others}) + return {"features": split_features, "example": example} def _forward(self, model_inputs): features = model_inputs["features"] example = model_inputs["example"] - model_input_names = self.tokenizer.model_input_names - fw_args = {k: [feature.__dict__[k] for feature in features] for k in model_input_names} - - if self.framework == "tf": - fw_args = {k: tf.constant(v) for (k, v) in fw_args.items()} - start, end = self.model(fw_args)[:2] - start, end = start.numpy(), end.numpy() - elif self.framework == "pt": - # Retrieve the score for the context tokens only (removing question tokens) - fw_args = {k: torch.tensor(v, device=self.device) for (k, v) in fw_args.items()} - # On Windows, the default int type in numpy is np.int32 so we get some non-long tensors. - fw_args = {k: v.long() if v.dtype == torch.int32 else v for (k, v) in fw_args.items()} + starts = [] + ends = [] + for feature in features: + fw_args = feature["fw_args"] start, end = self.model(**fw_args)[:2] - start, end = start.cpu().numpy(), end.cpu().numpy() - return {"start": start, "end": end, "features": features, "example": example} + starts.append(start) + ends.append(end) + return {"starts": starts, "ends": ends, "features": features, "example": example} def postprocess( self, @@ -360,90 +383,89 @@ def postprocess( ): min_null_score = 1000000 # large and positive answers = [] - start_ = model_outputs["start"][0] - end_ = model_outputs["end"][0] - feature = model_outputs["features"][0] example = model_outputs["example"] - # Ensure padded tokens & question tokens cannot belong to the set of candidate answers. 
- undesired_tokens = np.abs(np.array(feature.p_mask) - 1) - - if feature.attention_mask is not None: - undesired_tokens = undesired_tokens & feature.attention_mask - - # Generate mask - undesired_tokens_mask = undesired_tokens == 0.0 - - # Make sure non-context indexes in the tensor cannot contribute to the softmax - start_ = np.where(undesired_tokens_mask, -10000.0, start_) - end_ = np.where(undesired_tokens_mask, -10000.0, end_) - - # Normalize logits and spans to retrieve the answer - start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True))) - end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True))) - - if handle_impossible_answer: - min_null_score = min(min_null_score, (start_[0] * end_[0]).item()) - - # Mask CLS - start_[0] = end_[0] = 0.0 - - starts, ends, scores = self.decode(start_, end_, top_k, max_answer_len, undesired_tokens) - if not self.tokenizer.is_fast: - char_to_word = np.array(example.char_to_word_offset) - - # Convert the answer (tokens) back to the original text - # Score: score from the model - # Start: Index of the first character of the answer in the context string - # End: Index of the character following the last character of the answer in the context string - # Answer: Plain text of the answer - for s, e, score in zip(starts, ends, scores): - answers.append( - { - "score": score.item(), - "start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(), - "end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(), - "answer": " ".join( - example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1] - ), - } - ) - else: - # Convert the answer (tokens) back to the original text - # Score: score from the model - # Start: Index of the first character of the answer in the context string - # End: Index of the character following the last character of the answer in the context string - # Answer: Plain text of the answer - question_first = bool(self.tokenizer.padding_side == "right") - enc = feature.encoding - - # Sometimes the max probability token is in the middle of a word so: - # - we start by finding the right word containing the token with `token_to_word` - # - then we convert this word in a character span with `word_to_chars` - sequence_index = 1 if question_first else 0 - for s, e, score in zip(starts, ends, scores): - try: - start_word = enc.token_to_word(s) - end_word = enc.token_to_word(e) - start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0] - end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1] - except Exception: - # Some tokenizers don't really handle words. Keep to offsets then. - start_index = enc.offsets[s][0] - end_index = enc.offsets[e][1] - - answers.append( - { - "score": score.item(), - "start": start_index, - "end": end_index, - "answer": example.context_text[start_index:end_index], - } - ) + for i, (feature_, start_, end_) in enumerate( + zip(model_outputs["features"], model_outputs["starts"], model_outputs["ends"]) + ): + feature = feature_["others"] + # Ensure padded tokens & question tokens cannot belong to the set of candidate answers. 
+ undesired_tokens = np.abs(np.array(feature["p_mask"]) - 1) + + if feature_["fw_args"].get("attention_mask", None) is not None: + undesired_tokens = undesired_tokens & feature_["fw_args"]["attention_mask"].numpy() + + # Generate mask + undesired_tokens_mask = undesired_tokens == 0.0 + + # Make sure non-context indexes in the tensor cannot contribute to the softmax + start_ = np.where(undesired_tokens_mask, -10000.0, start_) + end_ = np.where(undesired_tokens_mask, -10000.0, end_) + + # Normalize logits and spans to retrieve the answer + start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True))) + end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True))) + + if handle_impossible_answer: + min_null_score = min(min_null_score, (start_[0] * end_[0]).item()) + + # Mask CLS + start_[0, 0] = end_[0, 0] = 0.0 + + starts, ends, scores = self.decode(start_, end_, top_k, max_answer_len, undesired_tokens) + if not self.tokenizer.is_fast: + char_to_word = np.array(example.char_to_word_offset) + + # Convert the answer (tokens) back to the original text + # Score: score from the model + # Start: Index of the first character of the answer in the context string + # End: Index of the character following the last character of the answer in the context string + # Answer: Plain text of the answer + for s, e, score in zip(starts, ends, scores): + token_to_orig_map = feature["token_to_orig_map"] + answers.append( + { + "score": score.item(), + "start": np.where(char_to_word == token_to_orig_map[s])[0][0].item(), + "end": np.where(char_to_word == token_to_orig_map[e])[0][-1].item(), + "answer": " ".join(example.doc_tokens[token_to_orig_map[s] : token_to_orig_map[e] + 1]), + } + ) + else: + # Convert the answer (tokens) back to the original text + # Score: score from the model + # Start: Index of the first character of the answer in the context string + # End: Index of the character following the last character of the answer in the context string + # Answer: Plain text of the answer + question_first = bool(self.tokenizer.padding_side == "right") + enc = feature["encoding"] + + # Sometimes the max probability token is in the middle of a word so: + # - we start by finding the right word containing the token with `token_to_word` + # - then we convert this word in a character span with `word_to_chars` + sequence_index = 1 if question_first else 0 + for s, e, score in zip(starts, ends, scores): + try: + start_word = enc.token_to_word(s) + end_word = enc.token_to_word(e) + start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0] + end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1] + except Exception: + # Some tokenizers don't really handle words. Keep to offsets then. 
+ start_index = enc.offsets[s][0] + end_index = enc.offsets[e][1] + + answers.append( + { + "score": score.item(), + "start": start_index, + "end": end_index, + "answer": example.context_text[start_index:end_index], + } + ) if handle_impossible_answer: answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""}) - - answers = sorted(answers, key=lambda x: x["score"], reverse=True)[:top_k] + answers = sorted(answers, key=lambda x: x["score"], reverse=True)[:top_k] if len(answers) == 1: return answers[0] return answers diff --git a/tests/test_modeling_led.py b/tests/test_modeling_led.py index bfad0388b169eb..cb0861acdb8aba 100644 --- a/tests/test_modeling_led.py +++ b/tests/test_modeling_led.py @@ -162,6 +162,11 @@ def get_config(self): attention_window=self.attention_window, ) + def get_pipeline_config(self): + config = self.get_config() + config.max_position_embeddings = 100 + return config + def prepare_config_and_inputs_for_common(self): config, inputs_dict = self.prepare_config_and_inputs() global_attention_mask = torch.zeros_like(inputs_dict["input_ids"]) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 3d80be0f15c009..de62e3ed5b358a 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -189,6 +189,7 @@ def get_config(self): def get_pipeline_config(self): config = self.get_config() config.vocab_size = 100 + config.max_position_embeddings = 100 config.axial_pos_shape = (4, 25) config.is_decoder = False return config diff --git a/tests/test_pipelines_question_answering.py b/tests/test_pipelines_question_answering.py index 209b22be5f81c4..cd0e7acde158bc 100644 --- a/tests/test_pipelines_question_answering.py +++ b/tests/test_pipelines_question_answering.py @@ -87,6 +87,12 @@ def run_pipeline_test(self, model, tokenizer, feature_extractor): outputs, [{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)} for i in range(20)] ) + # Very long context require multiple features + outputs = question_answerer( + question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." * 20 + ) + self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)}) + @require_torch def test_small_model_pt(self): question_answerer = pipeline( @@ -121,6 +127,73 @@ def test_large_model_pt(self): self.assertEqual(nested_simplify(outputs), {"score": 0.979, "start": 27, "end": 32, "answer": "Paris"}) + @slow + @require_torch + def test_large_model_issue(self): + qa_pipeline = pipeline( + "question-answering", + model="mrm8488/bert-multi-cased-finetuned-xquadv1", + ) + outputs = qa_pipeline( + { + "context": "Yes Bank founder Rana Kapoor has approached the Bombay High Court, challenging a special court's order from August this year that had remanded him in police custody for a week in a multi-crore loan fraud case. Kapoor, who is currently lodged in Taloja Jail, is an accused in the loan fraud case and some related matters being probed by the CBI and Enforcement Directorate. A single bench presided over by Justice S K Shinde on Tuesday posted the plea for further hearing on October 14. In his plea filed through advocate Vijay Agarwal, Kapoor claimed that the special court's order permitting the CBI's request for police custody on August 14 was illegal and in breach of the due process of law. Therefore, his police custody and subsequent judicial custody in the case were all illegal. 
Kapoor has urged the High Court to quash and set aside the special court's order dated August 14. As per his plea, in August this year, the CBI had moved two applications before the special court, one seeking permission to arrest Kapoor, who was already in judicial custody at the time in another case, and the other, seeking his police custody. While the special court refused to grant permission to the CBI to arrest Kapoor, it granted the central agency's plea for his custody. Kapoor, however, said in his plea that before filing an application for his arrest, the CBI had not followed the process of issuing him a notice under Section 41 of the CrPC for appearance before it. He further said that the CBI had not taken prior sanction as mandated under section 17 A of the Prevention of Corruption Act for prosecuting him. The special court, however, had said in its order at the time that as Kapoor was already in judicial custody in another case and was not a free man the procedure mandated under Section 41 of the CrPC need not have been adhered to as far as issuing a prior notice of appearance was concerned. ADVERTISING It had also said that case records showed that the investigating officer had taken an approval from a managing director of Yes Bank before beginning the proceedings against Kapoor and such a permission was a valid sanction. However, Kapoor in his plea said that the above order was bad in law and sought that it be quashed and set aside. The law mandated that if initial action was not in consonance with legal procedures, then all subsequent actions must be held as illegal, he said, urging the High Court to declare the CBI remand and custody and all subsequent proceedings including the further custody as illegal and void ab-initio. In a separate plea before the High Court, Kapoor's daughter Rakhee Kapoor-Tandon has sought exemption from in-person appearance before a special PMLA court. Rakhee has stated that she is a resident of the United Kingdom and is unable to travel to India owing to restrictions imposed due to the COVID-19 pandemic. According to the CBI, in the present case, Kapoor had obtained a gratification or pecuniary advantage of ₹ 307 crore, and thereby caused Yes Bank a loss of ₹ 1,800 crore by extending credit facilities to Avantha Group, when it was not eligible for the same", + "question": "Is this person invovled in fraud?", + } + ) + self.assertEqual( + nested_simplify(outputs), + {"answer": "an accused in the loan fraud case", "end": 294, "score": 0.001, "start": 261}, + ) + + @slow + @require_torch + def test_large_model_course(self): + question_answerer = pipeline("question-answering") + long_context = """ +🤗 Transformers: State of the Art NLP + +🤗 Transformers provides thousands of pretrained models to perform tasks on texts such as classification, information extraction, +question answering, summarization, translation, text generation and more in over 100 languages. +Its aim is to make cutting-edge NLP easier to use for everyone. + +🤗 Transformers provides APIs to quickly download and use those pretrained models on a given text, fine-tune them on your own datasets and +then share them with the community on our model hub. At the same time, each python module defining an architecture is fully standalone and +can be modified to enable quick research experiments. + +Why should I use transformers? + +1. Easy-to-use state-of-the-art models: + - High performance on NLU and NLG tasks. + - Low barrier to entry for educators and practitioners. 
+ - Few user-facing abstractions with just three classes to learn. + - A unified API for using all our pretrained models. + - Lower compute costs, smaller carbon footprint: + +2. Researchers can share trained models instead of always retraining. + - Practitioners can reduce compute time and production costs. + - Dozens of architectures with over 10,000 pretrained models, some in more than 100 languages. + +3. Choose the right framework for every part of a model's lifetime: + - Train state-of-the-art models in 3 lines of code. + - Move a single model between TF2.0/PyTorch frameworks at will. + - Seamlessly pick the right framework for training, evaluation and production. + +4. Easily customize a model or an example to your needs: + - We provide examples for each architecture to reproduce the results published by its original authors. + - Model internals are exposed as consistently as possible. + - Model files can be used independently of the library for quick experiments. + +🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration +between them. It's straightforward to train your models with one before loading them for inference with the other. +""" + question = "Which deep learning libraries back 🤗 Transformers?" + outputs = question_answerer(question=question, context=long_context) + + self.assertEqual( + nested_simplify(outputs), + {"answer": "Jax, PyTorch and TensorFlow", "end": 1919, "score": 0.971, "start": 1892}, + ) + @slow @require_tf def test_large_model_tf(self):
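For reference, a minimal end-to-end sketch of the behaviour these tests exercise (the model choice is an example, and the explicit doc_stride / max_seq_len overrides are assumptions; the pipeline also applies the capped defaults on its own):

    from transformers import pipeline

    question_answerer = pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",  # example checkpoint
    )

    long_context = "HuggingFace was founded in Paris. " * 200  # far beyond one window

    # The context is split into overlapping features, the model is run once per
    # feature, and the best-scoring span across all features is returned.
    output = question_answerer(
        question="Where was HuggingFace founded ?",
        context=long_context,
        doc_stride=128,    # optional; mirrors the preprocess() defaults above
        max_seq_len=384,
    )
    print(output)  # e.g. {'score': ..., 'start': ..., 'end': ..., 'answer': 'Paris'}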