## adding some libraries

In [1]:
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
import argparse
import h5py
import os
import random
import time
import pickle
import gc
import math
import json
import logging
import collections
import pandas as pd
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm
from tqdm import tqdm; tqdm.monitor_interval = 0  # noqa
from collections import namedtuple
from transformers import BertConfig, BertTokenizer, RobertaConfig, RobertaTokenizer
from transformers.tokenization_bert import whitespace_tokenize

In [2]:
logger = logging.getLogger(__name__)


NQExample = collections.namedtuple("NQExample", [
    "qas_id", "question_text", "doc_tokens", "orig_answer_text",
    "start_position", "end_position", "long_position",
    "short_is_impossible", "long_is_impossible", "crop_start"])

Crop = collections.namedtuple("Crop", ["example_id","unique_id", "example_index", "doc_span_index",
    "tokens", "token_to_orig_map", "token_is_max_context",
    "input_ids", "attention_mask", "token_type_ids",
    "paragraph_len", "start_position", "end_position", "long_position",
    "short_is_impossible", "long_is_impossible"])

LongAnswerCandidate = collections.namedtuple('LongAnswerCandidate', [
    'start_token', 'end_token', 'top_level'])

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

PrelimPrediction = collections.namedtuple("PrelimPrediction",
    ["crop_index", "start_index", "end_index", "start_logit", "end_logit"])

NbestPrediction = collections.namedtuple("NbestPrediction", [
    "text", "start_logit", "end_logit",
    "start_index", "end_index",
    "orig_doc_start", "orig_doc_end", "crop_index"])

RawResult = namedtuple("RawResult", ["unique_id", "start_logits", "end_logits",
    "long_logits"])

UNMAPPED = -123
CLS_INDEX = 0

In [3]:
def get_add_tokens(do_enumerate):
    tags = ['Dd', 'Dl', 'Dt', 'H1', 'H2', 'H3', 'Li', 'Ol', 'P', 'Table', 'Td', 'Th', 'Tr', 'Ul']
    opening_tags = [f'<{tag}>' for tag in tags]
    closing_tags = [f'</{tag}>' for tag in tags]
    added_tags = opening_tags + closing_tags
    # See `nq_to_sqaud.py` for special-tokens
    special_tokens = ['<P>', '<Table>']
    if do_enumerate:
        for special_token in special_tokens:
            for j in range(11):
              added_tags.append(f'<{special_token[1: -1]}{j}>')

    add_tokens = ['Td_colspan', 'Th_colspan', '``', '\'\'', '--']
    add_tokens = add_tokens + added_tags
    return add_tokens

## this function helps us to create candidates dict for the first time and load it lately

In [4]:
def read_candidates(candidate_files, do_cache=True):
    assert isinstance(candidate_files, (tuple, list)), candidate_files
    for fn in candidate_files:
        assert os.path.exists(fn), f'Missing file {fn}'
    cache_fn = 'candidates.pkl'

    candidates = {}
    if not os.path.exists(cache_fn):
        for fn in candidate_files:
            with open(fn) as f:
                for line in tqdm(f):
                    entry = json.loads(line)
                    example_id = str(entry['example_id'])
                    cnds = entry.pop('long_answer_candidates')
                    cnds = [LongAnswerCandidate(c['start_token'], c['end_token'],
                            c['top_level']) for c in cnds]
                    candidates[example_id] = cnds

        if do_cache:
            with open(cache_fn, 'wb') as f:
                pickle.dump(candidates, f)
    else:
        print(f'Loading from cache: {cache_fn}')
        with open(cache_fn, 'rb') as f:
            candidates = pickle.load(f)

    return candidates

## this just check whether a charater is white space

In [5]:
def is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False

In [6]:
def read_nq_examples(input_file_or_data, is_training):
    """Read a NQ json file into a list of NQExample. Refer to `nq_to_squad.py`
       to convert the `simplified-nq-t*.jsonl` files to NQ json."""
    if isinstance(input_file_or_data, str):
        with open(input_file_or_data, "r", encoding='utf-8') as f:
            input_data = json.load(f)["data"]

    else:
        input_data = input_file_or_data
    for entry_index, entry in enumerate(tqdm(input_data, total=len(input_data))):
        # if entry_index >= 2:
        #     break
        assert len(entry["paragraphs"]) == 1
        paragraph = entry["paragraphs"][0]
        paragraph_text = paragraph["context"]
        doc_tokens = []
        char_to_word_offset = []
        prev_is_whitespace = True
        for c in paragraph_text:
            if is_whitespace(c):
                prev_is_whitespace = True
            else:
                if prev_is_whitespace:
                    doc_tokens.append(c)
                else:
                    doc_tokens[-1] += c
                prev_is_whitespace = False
            char_to_word_offset.append(len(doc_tokens) - 1)
        assert len(paragraph["qas"]) == 1
        qa = paragraph["qas"][0]
        start_position = None
        end_position = None
        long_position = None
        orig_answer_text = None
        short_is_impossible = False
        long_is_impossible = False
        if is_training:
            short_is_impossible = qa["short_is_impossible"]
            short_answers = qa["short_answers"]
            if len(short_answers) >= 2:
                # logger.info(f"Choosing leftmost of "
                #     f"{len(short_answers)} short answer")
                short_answers = sorted(short_answers, key=lambda sa: sa["answer_start"])
                short_answers = short_answers[0: 1]
            if not short_is_impossible:
                answer = short_answers[0]
                orig_answer_text = answer["text"]
                answer_offset = answer["answer_start"]
                answer_length = len(orig_answer_text)
                start_position = char_to_word_offset[answer_offset]
                end_position = char_to_word_offset[
                    answer_offset + answer_length - 1]
                # Only add answers where the text can be exactly
                # recovered from the document. If this CAN'T
                # happen it's likely due to weird Unicode stuff
                # so we will just skip the example.
                #
                # Note that this means for training mode, every
                # example is NOT guaranteed to be preserved.
                actual_text = " ".join(doc_tokens[start_position:
                    end_position + 1])
                cleaned_answer_text = " ".join(
                    whitespace_tokenize(orig_answer_text))
                if actual_text.find(cleaned_answer_text) == -1:
                    logger.warning(
                        "Could not find answer: '%s' vs. '%s'",
                        actual_text, cleaned_answer_text)
                    continue
            else:
                start_position = -1
                end_position = -1
                orig_answer_text = ""

            long_is_impossible = qa["long_is_impossible"]
            long_answers = qa["long_answers"]
            if (len(long_answers) != 1) and not long_is_impossible:
                raise ValueError(f"For training, each question"
                    f" should have exactly 1 long answer.")

            if not long_is_impossible:
                long_answer = long_answers[0]
                long_answer_offset = long_answer["answer_start"]
                long_position = char_to_word_offset[long_answer_offset]
            else:
                long_position = -1

            # print(f'Q:{question_text}')
            # print(f'A:{start_position}, {end_position},
            # {orig_answer_text}')
            # print(f'R:{doc_tokens[start_position: end_position]}')

            if not short_is_impossible and not long_is_impossible:
                assert long_position <= start_position

            if not short_is_impossible and long_is_impossible:
                assert False, f'Invalid pair short, long pair'

        example = NQExample(
            qas_id=qa["id"],
            question_text=qa["question"],
            doc_tokens=doc_tokens,
            orig_answer_text=orig_answer_text,
            start_position=start_position,
            end_position=end_position,
            long_position=long_position,
            short_is_impossible=short_is_impossible,
            long_is_impossible=long_is_impossible,
            crop_start=qa["crop_start"])

        yield example

## Here feel like we create list of sub - documents using technique called SLIDING WINDOW with stride equal to doc_stride

In [7]:
def get_spans(doc_stride, max_tokens_for_doc, max_len):
    doc_spans = []
    start_offset = 0
    while start_offset < max_len:
        length = max_len - start_offset
        if length > max_tokens_for_doc:
            length = max_tokens_for_doc
        doc_spans.append(DocSpan(start=start_offset, length=length))
        if start_offset + length == max_len:
            break
        start_offset += min(length, doc_stride)
    return doc_spans

In [8]:
def convert_examples_to_crops(examples_gen, tokenizer, max_seq_length,
                              doc_stride, max_query_length, is_training,
                              cls_token='[CLS]', sep_token='[SEP]', pad_id=0,
                              sequence_a_segment_id=0,
                              sequence_b_segment_id=1,
                              cls_token_segment_id=0,
                              pad_token_segment_id=0,
                              mask_padding_with_zero=True,
                              p_keep_impossible=None,
                              sep_token_extra=False):
    """Loads a data file into a list of `InputBatch`s."""
    assert p_keep_impossible is not None, '`p_keep_impossible` is required'
    unique_id = 1000000000
    num_short_pos, num_short_neg = 0, 0
    num_long_pos, num_long_neg = 0, 0
    sub_token_cache = {}
    
    crops = []
    for example_index, example in enumerate(examples_gen):
        if example_index % 1000 == 0 and example_index > 0:
            logger.info('Converting %s: short_pos %s short_neg %s'
                ' long_pos %s long_neg %s',
                example_index, num_short_pos, num_short_neg,
                num_long_pos, num_long_neg)

        query_tokens = tokenizer.tokenize(example.question_text)
        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        # this takes the longest!
        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        
        for i, token in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = sub_token_cache.get(token)
            if sub_tokens is None:
                sub_tokens = tokenizer.tokenize(token)
                sub_token_cache[token] = sub_tokens
            tok_to_orig_index.extend([i for _ in range(len(sub_tokens))])
            all_doc_tokens.extend(sub_tokens)

        tok_start_position = None
        tok_end_position = None
#         if is_training and example.short_is_impossible:
#             tok_start_position = -1
#             tok_end_position = -1

#         if is_training and not example.short_is_impossible:
#             tok_start_position = orig_to_tok_index[example.start_position]
#             if example.end_position < len(example.doc_tokens) - 1:
#                 tok_end_position = orig_to_tok_index[
#                     example.end_position + 1] - 1
#             else:
#                 tok_end_position = len(all_doc_tokens) - 1
#             tok_long_position = None
#         if is_training and example.long_is_impossible:
#             tok_long_position = -1

#         if is_training and not example.long_is_impossible:
#             tok_long_position = orig_to_tok_index[example.long_position]

        # For Bert: [CLS] question [SEP] paragraph [SEP]
        special_tokens_count = 3
        if sep_token_extra:
            # For Roberta: <s> question </s> </s> paragraph </s>
            special_tokens_count += 1
        max_tokens_for_doc = max_seq_length - len(query_tokens) - special_tokens_count
        assert max_tokens_for_doc > 0
        # We can have documents that are longer than the maximum
        # sequence length. To deal with this we do a sliding window
        # approach, where we take chunks of the up to our max length
        # with a stride of `doc_stride`.
        doc_spans = get_spans(doc_stride, max_tokens_for_doc, len(all_doc_tokens))
        for doc_span_index, doc_span in enumerate(doc_spans):
            # Tokens are constructed as: CLS Query SEP Paragraph SEP
            tokens = []
            token_to_orig_map = UNMAPPED * np.ones((max_seq_length, ), dtype=np.int32)
            token_is_max_context = np.zeros((max_seq_length, ), dtype=np.bool)
            token_type_ids = []
            short_is_impossible = example.short_is_impossible
            start_position = None
            end_position = None
            special_tokens_offset = special_tokens_count - 1
            doc_offset = len(query_tokens) + special_tokens_offset
#             if is_training and not short_is_impossible:
#                 doc_start = doc_span.start
#                 doc_end = doc_span.start + doc_span.length - 1
#                 if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
#                     start_position = 0
#                     end_position = 0
#                     short_is_impossible = True
#                 else:
#                     start_position = tok_start_position - doc_start + doc_offset
#                     end_position = tok_end_position - doc_start + doc_offset

            long_is_impossible = example.long_is_impossible
            long_position = None
#             if is_training and not long_is_impossible:
#                 doc_start = doc_span.start
#                 doc_end = doc_span.start + doc_span.length - 1
#                 # out of span
#                 if not (tok_long_position >= doc_start and tok_long_position <= doc_end):
#                     long_position = 0
#                     long_is_impossible = True
#                 else:
#                     long_position = tok_long_position - doc_start + doc_offset

            # drop impossible samples
            if long_is_impossible:
                if np.random.rand() > p_keep_impossible:
                    continue

            # CLS token at the beginning
            tokens.append(cls_token)
            token_type_ids.append(cls_token_segment_id)
            # p_mask.append(0)  # can be answer

            # Query
            tokens += query_tokens
            token_type_ids += [sequence_a_segment_id] * len(query_tokens)
            # p_mask += [1] * len(query_tokens)  # can not be answer

            # SEP token
            tokens.append(sep_token)
            token_type_ids.append(sequence_a_segment_id)
            # p_mask.append(1)  # can not be answer
            if sep_token_extra:
                tokens.append(sep_token)
                token_type_ids.append(sequence_a_segment_id)
                # p_mask.append(1)

            # Paragraph
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                # We add `example.crop_start` as the original document
                # is already shifted
                token_to_orig_map[len(tokens)] = tok_to_orig_index[
                    split_token_index] + example.crop_start

                token_is_max_context[len(tokens)] = check_is_max_context(doc_spans,
                    doc_span_index, split_token_index)
                tokens.append(all_doc_tokens[split_token_index])
                token_type_ids.append(sequence_b_segment_id)
                # p_mask.append(0)  # can be answer

            paragraph_len = doc_span.length

            # SEP token
            tokens.append(sep_token)
            token_type_ids.append(sequence_b_segment_id)
            # p_mask.append(1)  # can not be answer

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(pad_id)
                attention_mask.append(0 if mask_padding_with_zero else 1)
                token_type_ids.append(pad_token_segment_id)
                
            # reduce memory, only input_ids needs more bits
            input_ids = np.array(input_ids, dtype=np.int32)
            attention_mask = np.array(attention_mask, dtype=np.bool)
            token_type_ids = np.array(token_type_ids, dtype=np.uint8)

#             if is_training and short_is_impossible:
#                 start_position = CLS_INDEX
#                 end_position = CLS_INDEX

#             if is_training and long_is_impossible:
#                 long_position = CLS_INDEX

            if example_index in (0, 10):
                # too spammy otherwise
                if doc_span_index in (0, 5):
                    logger.info("*** Example ***")
                    logger.info("unique_id: %s" % (unique_id))
                    logger.info("example_index: %s" % (example_index))
                    logger.info("doc_span_index: %s" % (doc_span_index))
                    logger.info("tokens: %s" % " ".join(tokens))
                    logger.info("input_ids: %s" % input_ids)
                    logger.info("attention_mask: %s" % np.uint8(attention_mask))
                    logger.info("token_type_ids: %s" % token_type_ids)
#                     if is_training and short_is_impossible:
#                         logger.info("short impossible example")
#                     if is_training and long_is_impossible:
#                         logger.info("long impossible example")
#                     if is_training and not short_is_impossible:
#                         answer_text = " ".join(tokens[start_position: end_position + 1])
#                         logger.info("start_position: %d" % (start_position))
#                         logger.info("end_position: %d" % (end_position))
#                         logger.info("answer: %s" % (answer_text))

            if short_is_impossible:
                num_short_neg += 1
            else:
                num_short_pos += 1

            if long_is_impossible:
                num_long_neg += 1
            else:
                num_long_pos += 1

            crop = Crop(
                example_id = example.qas_id,
                unique_id=unique_id,
                example_index=example_index,
                doc_span_index=doc_span_index,
                tokens=tokens,
                token_to_orig_map=token_to_orig_map,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                paragraph_len=paragraph_len,
                start_position=start_position,
                end_position=end_position,
                long_position=long_position,
                short_is_impossible=short_is_impossible,
                long_is_impossible=long_is_impossible)
            crops.append(crop)
            unique_id += 1

    return crops

In [9]:
def check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index

In [10]:
def clean_text(tok_text):
    # De-tokenize WordPieces that have been split off.
    tok_text = tok_text.replace(" ##", "")
    tok_text = tok_text.replace("##", "")

    # Clean whitespace
    tok_text = tok_text.strip()
    tok_text = " ".join(tok_text.split())
    return tok_text

In [11]:
def get_nbest(prelim_predictions, crops, example, n_best_size):
    seen, nbest = set(), []
    for pred in prelim_predictions:
        if len(nbest) >= n_best_size:
            break
        crop = crops[pred.crop_index]
        orig_doc_start, orig_doc_end = -1, -1
        # non-null
        if pred.start_index > 0:
            # Long answer has no end_index. We still generate some text to check
            if pred.end_index == -1:
                tok_tokens = crop.tokens[pred.start_index: pred.start_index + 11]
            else:
                tok_tokens = crop.tokens[pred.start_index: pred.end_index + 1]
            tok_text = " ".join(tok_tokens)
            tok_text = clean_text(tok_text)

            orig_doc_start = int(crop.token_to_orig_map[pred.start_index])
            if pred.end_index == -1:
                orig_doc_end = orig_doc_start + 10
            else:
                orig_doc_end = int(crop.token_to_orig_map[pred.end_index])

            final_text = tok_text
            if final_text in seen:
                continue

        else:
            final_text = ""

        seen.add(final_text)
        nbest.append(NbestPrediction(
            text=final_text,
            start_logit=pred.start_logit, end_logit=pred.end_logit,
            start_index=pred.start_index, end_index=pred.end_index,
            orig_doc_start=orig_doc_start, orig_doc_end=orig_doc_end,
            crop_index=pred.crop_index))

    # Degenerate case. I never saw this happen.
    if len(nbest) in (0, 1):
        nbest.insert(0, NbestPrediction(text="empty",
            start_logit=0.0, end_logit=0.0,
            start_index=-1, end_index=-1,
            orig_doc_start=-1, orig_doc_end=-1,
            crop_index=UNMAPPED))

    assert len(nbest) >= 1
    return nbest

## this helps us write all predictions

In [12]:
def write_predictions(examples_gen, all_crops, all_results, n_best_size,
                      max_answer_length, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file, verbose_logging,
                      short_null_score_diff, long_null_score_diff):
    """Write final predictions to the json file and log-odds of null if needed."""
    logger.info("Writing predictions to: %s" % output_prediction_file)
    logger.info("Writing nbest to: %s" % output_nbest_file)

    # create indexes
    example_index_to_crops = collections.defaultdict(list)
    for crop in all_crops:
        example_index_to_crops[crop.example_index].append(crop)
    unique_id_to_result = {result.unique_id: result for result in all_results}

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()
    short_num_empty, long_num_empty = 0, 0
    for example_index, example in enumerate(examples_gen):
        if example_index % 1000 == 0 and example_index > 0:
            logger.info(f'[{example_index}]: {short_num_empty} short and {long_num_empty} long empty')

        crops = example_index_to_crops[example_index]
        short_prelim_predictions, long_prelim_predictions = [], []
        for crop_index, crop in enumerate(crops):
            assert crop.unique_id in unique_id_to_result, f"{crop.unique_id}"
            result = unique_id_to_result[crop.unique_id]
            # get the `n_best_size` largest indexes
            # https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array#23734295
            start_indexes = np.argpartition(result.start_logits, -n_best_size)[-n_best_size:]
            start_indexes = [int(x) for x in start_indexes]
            end_indexes = np.argpartition(result.end_logits, -n_best_size)[-n_best_size:]
            end_indexes = [int(x) for x in end_indexes]

            # create short answers
            for start_index in start_indexes:
                if start_index >= len(crop.tokens):
                    continue
                # this skips [CLS] i.e. null prediction
                if crop.token_to_orig_map[start_index] == UNMAPPED:
                    continue
                if not crop.token_is_max_context[start_index]:
                    continue

                for end_index in end_indexes:
                    if end_index >= len(crop.tokens):
                        continue
                    if crop.token_to_orig_map[end_index] == UNMAPPED:
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue

                    short_prelim_predictions.append(PrelimPrediction(
                        crop_index=crop_index,
                        start_index=start_index,
                        end_index=end_index,
                        start_logit=result.start_logits[start_index],
                        end_logit=result.end_logits[end_index]))

            long_indexes = np.argpartition(result.long_logits, -n_best_size)[-n_best_size:].tolist()
            for long_index in long_indexes:
                if long_index >= len(crop.tokens):
                    continue
                # this skips [CLS] i.e. null prediction
                if crop.token_to_orig_map[long_index] == UNMAPPED:
                    continue
                # TODO(see--): Is this needed?
                # -> Yep helps both short and long by about 0.1
                if not crop.token_is_max_context[long_index]:
                    continue
                long_prelim_predictions.append(PrelimPrediction(
                    crop_index=crop_index,
                    start_index=long_index, end_index=-1,
                    start_logit=result.long_logits[long_index],
                    end_logit=result.long_logits[long_index]))

        short_prelim_predictions = sorted(short_prelim_predictions,
            key=lambda x: x.start_logit + x.end_logit, reverse=True)

        short_nbest = get_nbest(short_prelim_predictions, crops,
            example, n_best_size)

        short_best_non_null = None
        for entry in short_nbest:
            if short_best_non_null is None:
                if entry.text != "":
                    short_best_non_null = entry

        long_prelim_predictions = sorted(long_prelim_predictions,
            key=lambda x: x.start_logit, reverse=True)

        long_nbest = get_nbest(long_prelim_predictions, crops,
            example, n_best_size)

        long_best_non_null = None
        for entry in long_nbest:
            if long_best_non_null is None:
                if entry.text != "":
                    long_best_non_null = entry

        nbest_json = {'short': [], 'long': []}
        for kk, entries in [('short', short_nbest), ('long', long_nbest)]:
            for i, entry in enumerate(entries):
                output = {}
                output["text"] = entry.text
                output["start_logit"] = entry.start_logit
                output["end_logit"] = entry.end_logit
                output["start_index"] = entry.start_index
                output["end_index"] = entry.end_index
                output["orig_doc_start"] = entry.orig_doc_start
                output["orig_doc_end"] = entry.orig_doc_end
                nbest_json[kk].append(output)

        assert len(nbest_json['short']) >= 1
        assert len(nbest_json['long']) >= 1

        # We use the [CLS] score of the crop that has the maximum positive score
        # long_score_diff = min_long_score_null - long_best_non_null.start_logit
        # Predict "" if null score - the score of best non-null > threshold
        try:
            crop_unique_id = crops[short_best_non_null.crop_index].unique_id
            start_score_null = unique_id_to_result[crop_unique_id].start_logits[CLS_INDEX]
            end_score_null = unique_id_to_result[crop_unique_id].end_logits[CLS_INDEX]
            short_score_null = start_score_null + end_score_null
            short_score_diff = short_score_null - (short_best_non_null.start_logit +
                short_best_non_null.end_logit)

            if short_score_diff > short_null_score_diff:
                final_pred = ("", -1, -1)
                short_num_empty += 1
            else:
                final_pred = (short_best_non_null.text, short_best_non_null.orig_doc_start,
                    short_best_non_null.orig_doc_end)
        except Exception as e:
            print(e)
            final_pred = ("", -1, -1)
            short_num_empty += 1

        try:
            long_score_null = unique_id_to_result[crops[
                long_best_non_null.crop_index].unique_id].long_logits[CLS_INDEX]
            long_score_diff = long_score_null - long_best_non_null.start_logit
            scores_diff_json[example.qas_id] = {'short_score_diff': short_score_diff,
                'long_score_diff': long_score_diff}

            if long_score_diff > long_null_score_diff:
                final_pred += ("", -1)
                long_num_empty += 1
                # print(f"LONG EMPTY: {round(long_score_null, 2)} vs "
                #     f"{round(long_best_non_null.start_logit, 2)} (th {long_null_score_diff})")

            else:
                final_pred += (long_best_non_null.text, long_best_non_null.orig_doc_start)

        except Exception as e:
            print(e)
            final_pred += ("", -1)
            long_num_empty += 1

        all_predictions[example.qas_id] = final_pred
        all_nbest_json[example.qas_id] = nbest_json

    if output_prediction_file is not None:
        with open(output_prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=2))

    if output_nbest_file is not None:
        with open(output_nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=2))

    if output_null_log_odds_file is not None:
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=2))

    logger.info(f'{short_num_empty} short and {long_num_empty} long empty of'
        f' {example_index}')
    return all_predictions

In [13]:
def convert_preds_to_df(preds, candidates):
  num_found_long, num_searched_long = 0, 0
  df = {'example_id': [], 'PredictionString': []}
  for example_id, pred in preds.items():
    short_text, start_token, end_token, long_text, long_token = pred
    df['example_id'].append(example_id + '_short')
    short_answer = ''
    if start_token != -1:
      # +1 is required to make the token inclusive
      short_answer = f'{start_token}:{end_token + 1}'
    df['PredictionString'].append(short_answer)

    # print(entry['document_text'].split(' ')[start_token: end_token + 1])
    # find the long answer
    long_answer = ''
    found_long = False
    min_dist = 1_000_000
    if long_token != -1:
      num_searched_long += 1
      for candidate in candidates[example_id]:
        cstart, cend = candidate.start_token, candidate.end_token
        dist = abs(cstart - long_token)
        if dist < min_dist:
          min_dist = dist
        if long_token == cstart:
          long_answer = f'{cstart}:{cend}'
          found_long = True
          break

      if found_long:
        num_found_long += 1
      else:
        logger.info(f"Not found: {min_dist}")

    df['example_id'].append(example_id + '_long')
    df['PredictionString'].append(long_answer)

  df = pd.DataFrame(df)
  print(f'Found {num_found_long} of {num_searched_long} (total {len(preds)})')
  return df

## here is our Model !!

In [14]:
import tensorflow as tf
from tensorflow.keras import layers as L

from transformers import TFBertMainLayer, TFBertPreTrainedModel, TFRobertaMainLayer, TFRobertaPreTrainedModel
from transformers.modeling_tf_utils import get_initializer


class TFBertForNaturalQuestionAnswering(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name='bert')
        self.initializer = get_initializer(config.initializer_range)
        self.qa_outputs = L.Dense(config.num_labels,
            kernel_initializer=self.initializer, name='qa_outputs')
        self.long_outputs = L.Dense(1, kernel_initializer=self.initializer,
            name='long_outputs')

    def call(self, inputs, **kwargs):
        outputs = self.bert(inputs, **kwargs)
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, -1)
        end_logits = tf.squeeze(end_logits, -1)
        long_logits = tf.squeeze(self.long_outputs(sequence_output), -1)
        return start_logits, end_logits, long_logits


class TFRobertaForNaturalQuestionAnswering(TFRobertaPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.roberta = TFRobertaMainLayer(config, name='roberta')
        self.initializer = get_initializer(config.initializer_range)
        self.qa_outputs = L.Dense(config.num_labels,
            kernel_initializer=self.initializer, name='qa_outputs')
        self.long_outputs = L.Dense(1, kernel_initializer=self.initializer,
            name='long_outputs')

    def call(self, inputs, **kwargs):
        outputs = self.roberta(inputs, **kwargs)
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, -1)
        end_logits = tf.squeeze(end_logits, -1)
        long_logits = tf.squeeze(self.long_outputs(sequence_output), -1)
        return start_logits, end_logits, long_logits

In [15]:
def enumerate_tags(text_split):
  """Reproduce the preprocessing from:
  A BERT Baseline for the Natural Questions (https://arxiv.org/pdf/1901.08634.pdf)

  We introduce special markup tokens in the doc-ument  to  give  the  model
  a  notion  of  which  part of the document it is reading.  The special
  tokens we introduced are of the form “[Paragraph=N]”,“[Table=N]”, and “[List=N]”
  at the beginning ofthe N-th paragraph,  list and table respectively
  in the document. This decision was based on the observation that the first
  few paragraphs and tables in the document are much more likely than the rest
  of the document to contain the annotated answer and so the model could benefit
  from knowing whether it is processing one of these passages.

  We deviate as follows: Tokens are only created for the first 10 times. All other
  tokens are the same. We only add `special_tokens`. These two are added as they
  make 72.9% + 19.0% = 91.9% of long answers.
  (https://github.com/google-research-datasets/natural-questions)
  """
  special_tokens = ['<P>', '<Table>']
  special_token_counts = [0 for _ in range(len(special_tokens))]
  for index, token in enumerate(text_split):
    for special_token_index, special_token in enumerate(special_tokens):
      if token == special_token:
        cnt = special_token_counts[special_token_index]
        if cnt <= 10:
          text_split[index] = f'<{special_token[1: -1]}{cnt}>'
        special_token_counts[special_token_index] = cnt + 1

  return text_split

## let's fly to stuff can turn natural questions to SQUAD !!

In [16]:
def convert_nq_to_squad(args=None):
  np.random.seed(123)
  if args is None:
    parser = argparse.ArgumentParser()
    parser.add_argument('--fn', type=str, default='simplified-nq-train.jsonl')
    parser.add_argument('--version', type=str, default='v1.0.2')
    parser.add_argument('--prefix', type=str, default='nq')
    parser.add_argument('--p_val', type=float, default=0.1)
    parser.add_argument('--crop_len', type=int, default=2_500)
    parser.add_argument('--num_samples', type=int, default=1_000_000)
    parser.add_argument('--val_ids', type=str, default='val_ids.csv')
    parser.add_argument('--do_enumerate', action='store_true')
    parser.add_argument('--do_not_dump', action='store_true')
    parser.add_argument('--num_max_tokens', type=int, default=400_000)
    args = parser.parse_args()

  is_train = 'train' in args.fn
  if is_train:
    train_fn = f'{args.prefix}-train-{args.version}.json'
    val_fn = f'{args.prefix}-val-{args.version}.json'
    print(f'Converting {args.fn} to {train_fn} & {val_fn} ... ')
  else:
    test_fn = f'{args.prefix}-test-{args.version}.json'
    print(f'Converting {args.fn} to {test_fn} ... ')

  if args.val_ids:
    val_ids = set(str(x) for x in pd.read_csv(args.val_ids)['val_ids'].values)
  else:
    val_ids = set()

  entries = []
  smooth = 0.999
  total_split_len, long_split_len = 0., 0.
  long_end = 0.
  num_very_long, num_yes_no, num_short_dropped, num_trimmed = 0, 0, 0, 0
  num_short_possible, num_long_possible = 0, 0
  max_end_token = -1
  orig_data = {}
  with open(args.fn) as f:
    progress = tqdm(f, total=args.num_samples)
    entry = {}
    for kk, line in enumerate(progress):
      if kk >= args.num_samples:
        break

      data = json.loads(line)
      data_cpy = data.copy()
      example_id = str(data_cpy.pop('example_id'))
      data_cpy['document_text'] = ''
      orig_data[example_id] = data_cpy
      url = 'MISSING' if not is_train else data['document_url']
      # progress.write(f'############ {url} ###############')
      document_text = data['document_text']
      document_text_split = document_text.split(' ')
      # trim super long
      if len(document_text_split) > args.num_max_tokens:
        num_trimmed += 1
        document_text_split = document_text_split[:args.num_max_tokens]

      if args.do_enumerate:
        document_text_split = enumerate_tags(document_text_split)
      question = data['question_text']  # + '?'
      annotations = [None] if not is_train else data['annotations']
      assert len(annotations) == 1, annotations
      # User str keys!
      example_id = str(data['example_id'])
      candidates = data['long_answer_candidates']
      if not is_train:
        qa = {'question': question, 'id': example_id, 'crop_start': 0}
        context = ' '.join(document_text_split)

      else:
        long_answer = annotations[0]['long_answer']
        long_answer_len = long_answer['end_token'] - long_answer['start_token']
        total_split_len = smooth * total_split_len + (1. - smooth) * len(
            document_text_split)
        long_split_len = smooth * long_split_len + (1. - smooth) * \
            long_answer_len
        if long_answer['end_token'] > 0:
          long_end = smooth * long_end + (1. - smooth) * long_answer['end_token']

        if long_answer['end_token'] > max_end_token:
          max_end_token = long_answer['end_token']

        progress.set_postfix({'ltotal': int(total_split_len),
            'llong': int(long_split_len), 'long_end': round(long_end, 2)})

        short_answers = annotations[0]['short_answers']
        yes_no_answer = annotations[0]['yes_no_answer']
        if yes_no_answer != 'NONE':
          # progress.write(f'Skipping yes-no: {yes_no_answer}')
          num_yes_no += 1
          continue

        # print(f'Q: {question}')
        # print(f'L: {long_answer_str}')
        long_is_impossible = long_answer['start_token'] == -1
        if long_is_impossible:
          long_answer_candidate = np.random.randint(len(candidates))
        else:
          long_answer_candidate = long_answer['candidate_index']

        long_start_token = candidates[long_answer_candidate]['start_token']
        long_end_token = candidates[long_answer_candidate]['end_token']
        # generate crop based on tokens. Note that validation samples should
        # not be cropped as this won't reflect test set performance.
        if args.crop_len > 0 and example_id not in val_ids:
          crop_start = long_start_token - np.random.randint(int(args.crop_len * 0.75))
          if crop_start <= 0:
            crop_start = 0
            crop_start_len = -1
          else:
            crop_start_len = len(' '.join(document_text_split[:crop_start]))

          crop_end = crop_start + args.crop_len
        else:
          crop_start = 0
          crop_start_len = -1
          crop_end = 10_000_000

        is_very_long = False
        if long_end_token > crop_end:
          num_very_long += 1
          is_very_long = True
          # progress.write(f'{num_very_long}: Skipping very long answer {long_end_token}, {crop_end}')
          # continue

        document_text_crop_split = document_text_split[crop_start: crop_end]
        context = ' '.join(document_text_crop_split)
        # create long answer
        long_answers_ = []
        if not long_is_impossible:
          long_answer_pre_split = document_text_split[:long_answer[
              'start_token']]
          long_answer_start = len(' '.join(long_answer_pre_split)) - \
              crop_start_len
          long_answer_split = document_text_split[long_answer['start_token']:
              long_answer['end_token']]
          long_answer_text = ' '.join(long_answer_split)
          if not is_very_long:
            assert context[long_answer_start: long_answer_start + len(
                long_answer_text)] == long_answer_text, long_answer_text
          long_answers_ = [{'text': long_answer_text,
              'answer_start': long_answer_start}]

        # create short answers
        short_is_impossible = len(short_answers) == 0
        short_answers_ = []
        if not short_is_impossible:
          for short_answer in short_answers:
            short_start_token = short_answer['start_token']
            short_end_token = short_answer['end_token']
            if short_start_token >= crop_start + args.crop_len:
              num_short_dropped += 1
              continue
            short_answers_pre_split = document_text_split[:short_start_token]
            short_answer_start = len(' '.join(short_answers_pre_split)) - \
                crop_start_len
            short_answer_split = document_text_split[short_start_token: short_end_token]
            short_answer_text = ' '.join(short_answer_split)
            assert short_answer_text != ''

            # this happens if we crop and parts of the short answer overflow
            short_from_context = context[short_answer_start: short_answer_start + len(short_answer_text)]
            if short_from_context != short_answer_text:
              print(f'short diff: {short_from_context} vs {short_answer_text}')
            short_answers_.append({'text': short_from_context,
                'answer_start': short_answer_start})

        if len(short_answers_) == 0:
          short_is_impossible = True

        if not short_is_impossible:
          num_short_possible += 1
        if not long_is_impossible:
          num_long_possible += 1

        qa = {'question': question,
            'short_answers': short_answers_, 'long_answers': long_answers_,
            'id': example_id, 'short_is_impossible': short_is_impossible,
            'long_is_impossible': long_is_impossible,
            'crop_start': crop_start}

      paragraph = {'qas': [qa], 'context': context}
      entry = {'title': url, 'paragraphs': [paragraph]}
      entries.append(entry)

  progress.write('  ------------ STATS ------------------')
  progress.write(f'  Found {num_yes_no} yes/no, {num_very_long} very long'
      f' and {num_short_dropped} short of {kk} and trimmed {num_trimmed}')
  progress.write(f'  #short {num_short_possible} #long {num_long_possible}'
      f' of {len(entries)}')
  
  # shuffle to test remaining code
  np.random.shuffle(entries)

  if is_train:
    train_entries, val_entries = [], []
    for entry in entries:
      if entry['paragraphs'][0]['qas'][0]['id'] not in val_ids:
        train_entries.append(entry)
      else:
        val_entries.append(entry)

    for out_fn, entries in [(train_fn, train_entries), (val_fn, val_entries)]:
      if not args.do_not_dump:
        with open(out_fn, 'w') as f:
          json.dump({'version': args.version, 'data': entries}, f)
        progress.write(f'Wrote {len(entries)} entries to {out_fn}')

#       save val in competition csv format
      if 'val' in out_fn:
        val_example_ids, val_strs = [], []
        for entry in entries:
          example_id = entry['paragraphs'][0]['qas'][0]['id']
          short_answers = orig_data[example_id]['annotations'][0][
              'short_answers']
          sa_str = ''
          for si, sa in enumerate(short_answers):
            sa_str += f'{sa["start_token"]}:{sa["end_token"]}'
            if si < len(short_answers) - 1:
              sa_str += ' '
          val_example_ids.append(example_id + '_short')
          val_strs.append(sa_str)

          la = orig_data[example_id]['annotations'][0][
              'long_answer']
          la_str = ''
          if la['start_token'] > 0:
            la_str += f'{la["start_token"]}:{la["end_token"]}'
          val_example_ids.append(example_id + '_long')
          val_strs.append(la_str)

        val_df = pd.DataFrame({'example_id': val_example_ids,
            'PredictionString': val_strs})
        val_csv_fn = f'{args.prefix}-val-{args.version}.csv'
        val_df.to_csv(val_csv_fn, index=False, columns=['example_id',
            'PredictionString'])
        print(f'Wrote csv to {val_csv_fn}')

  else:
    if not args.do_not_dump:
      with open(test_fn, 'w') as f:
        json.dump({'version': args.version, 'data': entries}, f)
      progress.write(f'Wrote to {test_fn}')

  if args.val_ids:
    print(f'Using val ids from: {args.val_ids}')
  return entries

In [17]:
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys())
                  for conf in (BertConfig, )), ())

MODEL_CLASSES = {
    'bert': (BertConfig, TFBertForNaturalQuestionAnswering, BertTokenizer),
    'roberta': (RobertaConfig, TFRobertaForNaturalQuestionAnswering, RobertaTokenizer),
}


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

## below stuff helps us to make submissions !!!

In [18]:
def submit(args, model, tokenizer):
    csv_fn = 'submission.csv'
    # all_input_ids, all_attention_mask, all_token_type_ids, all_p_mask
    eval_dataset, crops, entries  = load_and_cache_crops(args, tokenizer, evaluate=True)
    args.eval_batch_size = args.per_tpu_eval_batch_size

    # pad dataset to multiple of `args.eval_batch_size`
    eval_dataset_length = len(eval_dataset[0])
    padded_length = math.ceil(eval_dataset_length / args.eval_batch_size) * args.eval_batch_size
    num_pad = padded_length - eval_dataset[0].shape[0]
    for ti, t in enumerate(eval_dataset):
        pad_tensor = tf.expand_dims(tf.zeros_like(t[0]), 0)
        pad_tensor = tf.repeat(pad_tensor, num_pad, 0)
        eval_dataset[ti] = tf.concat([t, pad_tensor], 0)

    # create eval dataset
    eval_ds = tf.data.Dataset.from_tensor_slices({
        'input_ids': tf.constant(eval_dataset[0]),
        'attention_mask': tf.constant(eval_dataset[1]),
        'token_type_ids': tf.constant(eval_dataset[2]),
        'example_index': tf.range(padded_length, dtype=tf.int32)

    })
    eval_ds = eval_ds.batch(batch_size=args.eval_batch_size, drop_remainder=True)
    # eval_ds = eval_ds.prefetch(tf.data.experimental.AUTOTUNE)
    # eval_ds = strategy.experimental_distribute_dataset(eval_ds)

    # eval
    print("***** Running evaluation *****")
    print("  Num examples =  ", eval_dataset_length)
    print("  Batch size =  ", args.eval_batch_size)

    @tf.function
    def predict_step(batch):
        outputs = model(batch, training=False)
        return outputs

    all_results = []
    tic = time.time()
    for batch_ind, batch in tqdm(enumerate(eval_ds), total=padded_length // args.per_tpu_eval_batch_size):
        example_indexes = batch['example_index']
        outputs = predict_step(batch)
        batched_start_logits = outputs[0].numpy()
        batched_end_logits = outputs[1].numpy()
        batched_long_logits = outputs[2].numpy()
        for i, example_index in enumerate(example_indexes):
            # filter out padded samples
            if example_index >= eval_dataset_length:
                continue

            eval_crop = crops[example_index]
            unique_id = int(eval_crop.unique_id)
            start_logits = batched_start_logits[i].tolist()
            end_logits = batched_end_logits[i].tolist()
            long_logits = batched_long_logits[i].tolist()

            result = RawResult(unique_id=unique_id,
                               start_logits=start_logits,
                               end_logits=end_logits,
                               long_logits=long_logits)
            all_results.append(result)

    eval_time = time.time() - tic
    print("  Evaluation done in total %f secs (%f sec per example)",
        eval_time, eval_time / padded_length)
    examples_gen = read_nq_examples(entries, is_training=False)
    preds = write_predictions(examples_gen, crops, all_results, args.n_best_size,
                              args.max_answer_length,
                              None, None, None,
                              args.verbose_logging,
                              args.short_null_score_diff_threshold, args.long_null_score_diff_threshold)
    del crops, all_results
    gc.collect()
    candidates = read_candidates(['../input/TensorFlow-2.0-Question-Answering/simplified-nq-test.jsonl'], do_cache=False)
    sub = convert_preds_to_df(preds, candidates).sort_values('example_id')
    sub.to_csv(csv_fn, index=False, columns=['example_id', 'PredictionString'])
    print(f'***** Wrote submission to {csv_fn} *****')
    result = {}
    return result

In [19]:
def get_convert_args():
    convert_args = argparse.Namespace()
    convert_args.fn = '../input/TensorFlow-2.0-Question-Answering/simplified-nq-test.jsonl'
    convert_args.version = 'v0.0.1'
    convert_args.prefix = 'nq'
    convert_args.num_samples = 1_000_000
    convert_args.val_ids = None
    convert_args.do_enumerate = False
    convert_args.do_not_dump = True
    convert_args.num_max_tokens = 400_000
    return convert_args

In [20]:
def load_and_cache_crops(args, tokenizer, evaluate=False):
    # Load data crops from cache or dataset file
    do_cache = False
    cached_crops_fn = 'cached_{}_{}.pkl'.format('test', str(args.max_seq_length))
    if os.path.exists(cached_crops_fn) and do_cache:
        print("Loading crops from cached file %s", cached_crops_fn)
        with open(cached_crops_fn, "rb") as f:
            crops = pickle.load(f)
    else:
        entries = convert_nq_to_squad(get_convert_args())
        examples_gen = read_nq_examples(entries, is_training=not evaluate)
        crops = convert_examples_to_crops(examples_gen=examples_gen,
                                          tokenizer=tokenizer,
                                          max_seq_length=args.max_seq_length,
                                          doc_stride=args.doc_stride,
                                          max_query_length=args.max_query_length,
                                          is_training=not evaluate,
                                          cls_token_segment_id=0,
                                          pad_token_segment_id=0,
                                          p_keep_impossible=args.p_keep_impossible if not evaluate else 1.0)
        if do_cache:
            with open(cached_crops_fn, "wb") as f:
                pickle.dump(crops, f)

    # stack
    all_input_ids = tf.stack([c.input_ids for c in crops], 0)
    all_attention_mask = tf.stack([c.attention_mask for c in crops], 0)
    all_token_type_ids = tf.stack([c.token_type_ids for c in crops], 0)

    if evaluate:
        dataset = [all_input_ids, all_attention_mask, all_token_type_ids]
    else:
        all_start_positions = tf.convert_to_tensor([f.start_position for f in crops], dtype=tf.int32)
        all_end_positions = tf.convert_to_tensor([f.end_position for f in crops], dtype=tf.int32)
        all_long_positions = tf.convert_to_tensor([f.long_position for f in crops], dtype=tf.int32)
        dataset = [all_input_ids, all_attention_mask, all_token_type_ids,
            all_start_positions, all_end_positions, all_long_positions]

    return dataset, crops, entries

In [21]:
# def main():
#     parser = argparse.ArgumentParser()

#     # Required parameters
#     parser.add_argument("--model_type", default="bert", type=str)
#     parser.add_argument("--model_config",
#         default="../input/transformers-cache/bert_large_uncased_config.json", type=str)
#     parser.add_argument("--checkpoint_dir", default="../input/nq-bert-uncased-68", type=str)
#     parser.add_argument("--vocab_txt", default="../input/transformers-cache/bert_large_uncased_vocab.txt", type=str)

#     # Other parameters
#     parser.add_argument('--short_null_score_diff_threshold', type=float, default=0.0)
#     parser.add_argument('--long_null_score_diff_threshold', type=float, default=0.0)
#     parser.add_argument("--max_seq_length", default=512, type=int)
#     parser.add_argument("--doc_stride", default=256, type=int)
#     parser.add_argument("--max_query_length", default=64, type=int)
#     parser.add_argument("--per_tpu_eval_batch_size", default=4, type=int)
#     parser.add_argument("--n_best_size", default=10, type=int)
#     parser.add_argument("--max_answer_length", default=30, type=int)
#     parser.add_argument("--verbose_logging", action='store_true')
#     parser.add_argument('--seed', type=int, default=42)
#     parser.add_argument('--p_keep_impossible', type=float,
#                         default=0.1, help="The fraction of impossible"
#                         " samples to keep.")
#     parser.add_argument('--do_enumerate', action='store_true')

#     args, _ = parser.parse_known_args()
#     assert args.model_type not in ('xlnet', 'xlm'), f'Unsupported model_type: {args.model_type}'

#     # Set seed
#     set_seed(args)

#     # Set cased / uncased
#     config_basename = os.path.basename(args.model_config)
#     if config_basename.startswith('bert'):
#         do_lower_case = 'uncased' in config_basename
#     elif config_basename.startswith('roberta'):
#         # https://github.com/huggingface/transformers/pull/1386/files
#         do_lower_case = False

#     # Set XLA
#     # https://github.com/kamalkraj/ALBERT-TF2.0/blob/8d0cc211361e81a648bf846d8ec84225273db0e4/run_classifer.py#L136
#     tf.config.optimizer.set_jit(True)
#     tf.config.optimizer.set_experimental_options({'pin_to_host_optimization': False})

#     print("Training / evaluation parameters ", args)
#     args.model_type = args.model_type.lower()
#     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
#     config = config_class.from_json_file(args.model_config)
#     tokenizer = tokenizer_class(args.vocab_txt, do_lower_case=do_lower_case)
#     tags = get_add_tokens(do_enumerate=args.do_enumerate)
#     num_added = tokenizer.add_tokens(tags, offset=1)
#     print(f"Added {num_added} tokens")
#     print("Evaluate the following checkpoint: ", args.checkpoint_dir)
#     weights_fn = os.path.join(args.checkpoint_dir, 'weights.h5')
#     model = model_class(config)
#     model(model.dummy_inputs, training=False)
#     model.load_weights(weights_fn)

#     # Evaluate
#     result = submit(args, model, tokenizer)
#     print("Result: {}".format(result))


# main()

In [22]:
parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument("--model_type", default="bert", type=str)
parser.add_argument("--model_config",
    default="../input/transformers_cache/bert_large_uncased_config.json", type=str)
parser.add_argument("--checkpoint_dir", default="../input/nq_bert_uncased_68", type=str)
parser.add_argument("--vocab_txt", default="../input/transformers_cache/bert_large_uncased_vocab.txt", type=str)

# Other parameters
parser.add_argument('--short_null_score_diff_threshold', type=float, default=0.0)
parser.add_argument('--long_null_score_diff_threshold', type=float, default=0.0)
parser.add_argument("--max_seq_length", default=512, type=int)
parser.add_argument("--doc_stride", default=256, type=int)
parser.add_argument("--max_query_length", default=64, type=int)
parser.add_argument("--per_tpu_eval_batch_size", default=4, type=int)
parser.add_argument("--n_best_size", default=10, type=int)
parser.add_argument("--max_answer_length", default=30, type=int)
parser.add_argument("--verbose_logging", action='store_true')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--p_keep_impossible', type=float,
                    default=0.1, help="The fraction of impossible"
                    " samples to keep.")
parser.add_argument('--do_enumerate', action='store_true')

args, _ = parser.parse_known_args()
args

Namespace(checkpoint_dir='../input/nq_bert_uncased_68', do_enumerate=False, doc_stride=256, long_null_score_diff_threshold=0.0, max_answer_length=30, max_query_length=64, max_seq_length=512, model_config='../input/transformers_cache/bert_large_uncased_config.json', model_type='bert', n_best_size=10, p_keep_impossible=0.1, per_tpu_eval_batch_size=4, seed=42, short_null_score_diff_threshold=0.0, verbose_logging=False, vocab_txt='../input/transformers_cache/bert_large_uncased_vocab.txt')

In [23]:
set_seed(args)

# Set cased / uncased
config_basename = os.path.basename(args.model_config)
if config_basename.startswith('bert'):
    do_lower_case = 'uncased' in config_basename
elif config_basename.startswith('roberta'):
    # https://github.com/huggingface/transformers/pull/1386/files
    do_lower_case = False

# Set XLA
# https://github.com/kamalkraj/ALBERT-TF2.0/blob/8d0cc211361e81a648bf846d8ec84225273db0e4/run_classifer.py#L136
tf.config.optimizer.set_jit(True)
tf.config.optimizer.set_experimental_options({'pin_to_host_optimization': False})

print("Training / evaluation parameters ", args)
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_json_file(args.model_config)
tokenizer = tokenizer_class(args.vocab_txt, do_lower_case=do_lower_case)
tags = get_add_tokens(do_enumerate=args.do_enumerate)
num_added = tokenizer.add_tokens(tags, offset=1)
print(f"Added {num_added} tokens")
print("Evaluate the following checkpoint: ",args.checkpoint_dir)
weights_fn = os.path.join(args.checkpoint_dir, 'weights.h5')
model = model_class(config)
model(model.dummy_inputs, training=False)
model.load_weights(weights_fn)

Training / evaluation parameters  Namespace(checkpoint_dir='../input/nq_bert_uncased_68', do_enumerate=False, doc_stride=256, long_null_score_diff_threshold=0.0, max_answer_length=30, max_query_length=64, max_seq_length=512, model_config='../input/transformers_cache/bert_large_uncased_config.json', model_type='bert', n_best_size=10, p_keep_impossible=0.1, per_tpu_eval_batch_size=4, seed=42, short_null_score_diff_threshold=0.0, verbose_logging=False, vocab_txt='../input/transformers_cache/bert_large_uncased_vocab.txt')
Added 33 tokens
Evaluate the following checkpoint:  ../input/nq_bert_uncased_68


## Everything from scratch

In [27]:
eval_dataset, crops, entries  = load_and_cache_crops(args, tokenizer, evaluate=True)
eval_dataset

Converting ../input/TensorFlow-2.0-Question-Answering/simplified-nq-test.jsonl to nq-test-v0.0.1.json ... 


HBox(children=(FloatProgress(value=0.0, max=1000000.0), HTML(value='')))


  ------------ STATS ------------------
  Found 0 yes/no, 0 very long and 0 short of 345 and trimmed 0
  #short 0 #long 0 of 346


HBox(children=(FloatProgress(value=0.0, max=346.0), HTML(value='')))




[<tf.Tensor: id=8164, shape=(13135, 512), dtype=int32, numpy=
 array([[  101,  2054,  3185, ...,  1019,  2156,   102],
        [  101,  2054,  3185, ...,  3336,     4,   102],
        [  101,  2054,  3185, ...,    29,    14,   102],
        ...,
        [  101,  2043,  2106, ...,    12,    26,   102],
        [  101,  2043,  2106, ..., 25592, 28748,   102],
        [  101,  2043,  2106, ...,     0,     0,     0]])>,
 <tf.Tensor: id=8165, shape=(13135, 512), dtype=bool, numpy=
 array([[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        ...,
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ..., False, False, False]])>,
 <tf.Tensor: id=8166, shape=(13135, 512), dtype=int32, numpy=
 array([[0, 0, 0, ..., 1, 1, 1],
        [0, 0, 0, ..., 1, 1, 1],
        [0, 0, 0, ..., 1, 1, 1],


In [28]:
# crops[3]

In [29]:
# type(examples_gen)

In [30]:
preference_index = 25
examples_gen = read_nq_examples(entries, is_training=False)
for _ in range(preference_index+1) :
    example = next(examples_gen)
example

HBox(children=(FloatProgress(value=0.0, max=346.0), HTML(value='')))

NQExample(qas_id='-1560939115027122612', question_text='why do female tarantulas live longer than males', doc_tokens=['Tarantula', '-', 'wikipedia', '<H1>', 'Tarantula', '</H1>', 'This', 'article', 'is', 'about', 'the', 'spider', 'family', 'Theraphosidae', '.', 'For', 'the', 'European', 'tarantula', 'wolf', 'spider', ',', 'see', 'Lycosa', 'tarantula', '.', 'For', 'other', 'uses', ',', 'see', 'Tarantula', '(', 'disambiguation', ')', '.', 'Not', 'to', 'be', 'confused', 'with', 'tarantella', '.', '<P>', '</P>', '<Table>', '<Tr>', '<Td>', '</Td>', '<Td>', 'This', 'article', 'needs', 'additional', 'citations', 'for', 'verification', '.', 'Please', 'help', 'improve', 'this', 'article', 'by', 'adding', 'citations', 'to', 'reliable', 'sources', '.', 'Unsourced', 'material', 'may', 'be', 'challenged', 'and', 'removed', '.', '(', 'February', '2016', ')', '(', 'Learn', 'how', 'and', 'when', 'to', 'remove', 'this', 'template', 'message', ')', '</Td>', '</Tr>', '</Table>', '<P>', '</P>', '<Table>',

In [31]:
args.eval_batch_size,eval_dataset_length=1,len(eval_dataset[0])
args.eval_batch_size,eval_dataset_length

(1, 13135)

# 762

In [32]:
i = 0
for crop in crops :
    if crop.example_index <= preference_index :

        if crop.example_index == preference_index :
            i += 1
        else :
            continue
    else :
        break
i

47

In [33]:
anfangen , fertig = 762,762+i

In [34]:
# create eval dataset
eval_ds = tf.data.Dataset.from_tensor_slices({
    'input_ids': tf.constant(eval_dataset[0][anfangen:fertig]),
    'attention_mask': tf.constant(eval_dataset[1][anfangen:fertig]),
    'token_type_ids': tf.constant(eval_dataset[2][anfangen:fertig]),
    'example_index': tf.range(anfangen,fertig, dtype=tf.int32)

})
eval_ds = eval_ds.batch(batch_size=args.eval_batch_size, drop_remainder=True)
eval_ds

<BatchDataset shapes: {input_ids: (1, 512), attention_mask: (1, 512), token_type_ids: (1, 512), example_index: (1,)}, types: {input_ids: tf.int32, attention_mask: tf.bool, token_type_ids: tf.int32, example_index: tf.int32}>

In [35]:
@tf.function
def predict_step(batch):
    outputs = model(batch, training=False)
    return outputs

In [36]:
all_results = []
tic = time.time()
for batch_ind, batch in tqdm(enumerate(eval_ds), total=fertig-anfangen):
    example_indexes = batch['example_index']
    outputs = predict_step(batch)
    batched_start_logits = outputs[0].numpy()
    batched_end_logits = outputs[1].numpy()
    batched_long_logits = outputs[2].numpy()
    for i, example_index in enumerate(example_indexes):
        # filter out padded samples
        if example_index >= eval_dataset_length:
            continue

        eval_crop = crops[example_index]
        unique_id = int(eval_crop.unique_id)
        print(f"unique_id {unique_id} of exmaple_index {example_index} was found")
        start_logits = batched_start_logits[i].tolist()
        end_logits = batched_end_logits[i].tolist()
        long_logits = batched_long_logits[i].tolist()

        result = RawResult(unique_id=unique_id,
                           start_logits=start_logits,
                           end_logits=end_logits,
                           long_logits=long_logits)
        all_results.append(result)

eval_time = time.time() - tic
eval_time

HBox(children=(FloatProgress(value=0.0, max=47.0), HTML(value='')))

unique_id 1000000762 was found
unique_id 1000000763 was found
unique_id 1000000764 was found
unique_id 1000000765 was found
unique_id 1000000766 was found
unique_id 1000000767 was found
unique_id 1000000768 was found
unique_id 1000000769 was found
unique_id 1000000770 was found
unique_id 1000000771 was found
unique_id 1000000772 was found
unique_id 1000000773 was found
unique_id 1000000774 was found
unique_id 1000000775 was found
unique_id 1000000776 was found
unique_id 1000000777 was found
unique_id 1000000778 was found
unique_id 1000000779 was found
unique_id 1000000780 was found
unique_id 1000000781 was found
unique_id 1000000782 was found
unique_id 1000000783 was found
unique_id 1000000784 was found
unique_id 1000000785 was found
unique_id 1000000786 was found
unique_id 1000000787 was found
unique_id 1000000788 was found
unique_id 1000000789 was found
unique_id 1000000790 was found
unique_id 1000000791 was found
unique_id 1000000792 was found
unique_id 1000000793 was found
unique_i

805.3355777263641

In [37]:
example_index_to_crops = collections.defaultdict(list)
for crop in crops[anfangen:fertig] :
    example_index_to_crops[crop.example_index].append(crop)
example_index_to_crops

defaultdict(list,
            {25: [Crop(example_id='-1560939115027122612', unique_id=1000000762, example_index=25, doc_span_index=0, tokens=['[CLS]', 'why', 'do', 'female', 'tara', '##nt', '##ula', '##s', 'live', 'longer', 'than', 'males', '[SEP]', 'tara', '##nt', '##ula', '-', 'wikipedia', '<H1>', 'tara', '##nt', '##ula', '</H1>', 'this', 'article', 'is', 'about', 'the', 'spider', 'family', 'the', '##raph', '##osi', '##dae', '.', 'for', 'the', 'european', 'tara', '##nt', '##ula', 'wolf', 'spider', ',', 'see', 'l', '##y', '##cos', '##a', 'tara', '##nt', '##ula', '.', 'for', 'other', 'uses', ',', 'see', 'tara', '##nt', '##ula', '(', 'di', '##sam', '##bi', '##gua', '##tion', ')', '.', 'not', 'to', 'be', 'confused', 'with', 'tara', '##nte', '##lla', '.', '<P>', '</P>', '<Table>', '<Tr>', '<Td>', '</Td>', '<Td>', 'this', 'article', 'needs', 'additional', 'citations', 'for', 'verification', '.', 'please', 'help', 'improve', 'this', 'article', 'by', 'adding', 'citations', 'to', 'reliable', 

In [38]:
unique_id_to_result = {result.unique_id: result for result in all_results}
unique_id_to_result.keys()

dict_keys([1000000762, 1000000763, 1000000764, 1000000765, 1000000766, 1000000767, 1000000768, 1000000769, 1000000770, 1000000771, 1000000772, 1000000773, 1000000774, 1000000775, 1000000776, 1000000777, 1000000778, 1000000779, 1000000780, 1000000781, 1000000782, 1000000783, 1000000784, 1000000785, 1000000786, 1000000787, 1000000788, 1000000789, 1000000790, 1000000791, 1000000792, 1000000793, 1000000794, 1000000795, 1000000796, 1000000797, 1000000798, 1000000799, 1000000800, 1000000801, 1000000802, 1000000803, 1000000804, 1000000805, 1000000806, 1000000807, 1000000808])

In [39]:
all_predictions = collections.OrderedDict()
example_index = preference_index
n_best_size = 10
max_answer_length = 30

part_of_crops = example_index_to_crops[example_index]
short_prelim_predictions, long_prelim_predictions = [], []
for crop_index, crop in enumerate(part_of_crops):
    result = unique_id_to_result[crop.unique_id]
    # get the `n_best_size` largest indexes
    # https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array#23734295
    start_indexes = np.argpartition(result.start_logits, -n_best_size)[-n_best_size:]
    start_indexes = [int(x) for x in start_indexes]
    end_indexes = np.argpartition(result.end_logits, -n_best_size)[-n_best_size:]
    end_indexes = [int(x) for x in end_indexes]

    # create short answers
    for start_index in start_indexes:
        if start_index >= len(crop.tokens):
            continue
        if crop.token_to_orig_map[start_index] == UNMAPPED:
            continue
        if not crop.token_is_max_context[start_index]:
            continue

        for end_index in end_indexes:
            if end_index >= len(crop.tokens):
                continue
            if crop.token_to_orig_map[end_index] == UNMAPPED:
                continue
            if end_index < start_index:
                continue
            length = end_index - start_index + 1
            if length > max_answer_length:
                continue

            short_prelim_predictions.append(PrelimPrediction(
                crop_index=crop_index,
                start_index=start_index,
                end_index=end_index,
                start_logit=result.start_logits[start_index],
                end_logit=result.end_logits[end_index]))

    long_indexes = np.argpartition(result.long_logits, -n_best_size)[-n_best_size:].tolist()
    for long_index in long_indexes:
        if long_index >= len(crop.tokens):
            continue
        if crop.token_to_orig_map[long_index] == UNMAPPED:
            continue
        if not crop.token_is_max_context[long_index]:
            continue
        long_prelim_predictions.append(PrelimPrediction(
            crop_index=crop_index,
            start_index=long_index, end_index=-1,
            start_logit=result.long_logits[long_index],
            end_logit=result.long_logits[long_index]))

short_prelim_predictions = sorted(short_prelim_predictions,
    key=lambda x: x.start_logit + x.end_logit, reverse=True)

short_prelim_predictions

[PrelimPrediction(crop_index=21, start_index=244, end_index=251, start_logit=5.561295032501221, end_logit=7.388420581817627),
 PrelimPrediction(crop_index=21, start_index=245, end_index=251, start_logit=5.364441394805908, end_logit=7.388420581817627),
 PrelimPrediction(crop_index=21, start_index=236, end_index=251, start_logit=4.8185296058654785, end_logit=7.388420581817627),
 PrelimPrediction(crop_index=21, start_index=246, end_index=251, start_logit=4.062827110290527, end_logit=7.388420581817627),
 PrelimPrediction(crop_index=21, start_index=244, end_index=252, start_logit=5.561295032501221, end_logit=3.1709303855895996),
 PrelimPrediction(crop_index=21, start_index=245, end_index=252, start_logit=5.364441394805908, end_logit=3.1709303855895996),
 PrelimPrediction(crop_index=21, start_index=236, end_index=252, start_logit=4.8185296058654785, end_logit=3.1709303855895996),
 PrelimPrediction(crop_index=21, start_index=238, end_index=251, start_logit=0.5436409115791321, end_logit=7.3884

In [40]:
short_nbest = get_nbest(short_prelim_predictions, part_of_crops,
    example, n_best_size)

short_nbest

[NbestPrediction(text='because they die relatively soon after maturing', start_logit=5.561295032501221, end_logit=7.388420581817627, start_index=244, end_index=251, orig_doc_start=4498, orig_doc_end=4504, crop_index=21),
 NbestPrediction(text='they die relatively soon after maturing', start_logit=5.364441394805908, end_logit=7.388420581817627, start_index=245, end_index=251, orig_doc_start=4499, orig_doc_end=4504, crop_index=21),
 NbestPrediction(text='males have much shorter lifespans than females because they die relatively soon after maturing', start_logit=4.8185296058654785, end_logit=7.388420581817627, start_index=236, end_index=251, orig_doc_start=4491, orig_doc_end=4504, crop_index=21),
 NbestPrediction(text='die relatively soon after maturing', start_logit=4.062827110290527, end_logit=7.388420581817627, start_index=246, end_index=251, orig_doc_start=4500, orig_doc_end=4504, crop_index=21),
 NbestPrediction(text='because they die relatively soon after maturing .', start_logit=5.

In [41]:
short_best_non_null = None
for entry in short_nbest:
    if short_best_non_null is None:
        if entry.text != "":
            short_best_non_null = entry
short_best_non_null

NbestPrediction(text='because they die relatively soon after maturing', start_logit=5.561295032501221, end_logit=7.388420581817627, start_index=244, end_index=251, orig_doc_start=4498, orig_doc_end=4504, crop_index=21)

In [42]:
long_prelim_predictions = sorted(long_prelim_predictions,
    key=lambda x: x.start_logit, reverse=True)
long_prelim_predictions

[PrelimPrediction(crop_index=21, start_index=177, end_index=-1, start_logit=9.943519592285156, end_logit=9.943519592285156),
 PrelimPrediction(crop_index=24, start_index=208, end_index=-1, start_logit=5.8917670249938965, end_logit=5.8917670249938965),
 PrelimPrediction(crop_index=22, start_index=383, end_index=-1, start_logit=5.884551048278809, end_logit=5.884551048278809),
 PrelimPrediction(crop_index=20, start_index=336, end_index=-1, start_logit=4.956899166107178, end_logit=4.956899166107178),
 PrelimPrediction(crop_index=22, start_index=271, end_index=-1, start_logit=4.336737632751465, end_logit=4.336737632751465),
 PrelimPrediction(crop_index=21, start_index=385, end_index=-1, start_logit=4.286718845367432, end_logit=4.286718845367432),
 PrelimPrediction(crop_index=23, start_index=219, end_index=-1, start_logit=4.286419868469238, end_logit=4.286419868469238),
 PrelimPrediction(crop_index=8, start_index=374, end_index=-1, start_logit=3.2629942893981934, end_logit=3.2629942893981934

In [43]:
long_nbest = get_nbest(long_prelim_predictions, part_of_crops,
    example, n_best_size)
long_nbest

[NbestPrediction(text="<P> a juvenile male ' s sex can be determined by", start_logit=9.943519592285156, end_logit=9.943519592285156, start_index=177, end_index=-1, orig_doc_start=4457, orig_doc_end=4467, crop_index=21),
 NbestPrediction(text='<P> females deposit 50 to 2000 eggs , depending on the', start_logit=5.8917670249938965, end_logit=5.8917670249938965, start_index=208, end_index=-1, orig_doc_start=5123, orig_doc_end=5133, crop_index=24),
 NbestPrediction(text='<P> females continue to molt after reaching maturity . female', start_logit=5.884551048278809, end_logit=5.884551048278809, start_index=383, end_index=-1, orig_doc_start=4829, orig_doc_end=4839, crop_index=22),
 NbestPrediction(text='<P> some tarantula species exhibit pronounced sexual dimor', start_logit=4.956899166107178, end_logit=4.956899166107178, start_index=336, end_index=-1, orig_doc_start=4377, orig_doc_end=4387, crop_index=20),
 NbestPrediction(text='<P> tarantulas may live for years ; most', start_logit=4.33673

In [44]:
long_best_non_null = None
for entry in long_nbest:
    if long_best_non_null is None :
        if entry.text != "":
            long_best_non_null = entry
long_best_non_null

NbestPrediction(text="<P> a juvenile male ' s sex can be determined by", start_logit=9.943519592285156, end_logit=9.943519592285156, start_index=177, end_index=-1, orig_doc_start=4457, orig_doc_end=4467, crop_index=21)

In [45]:
candidates = read_candidates(['../input/TensorFlow-2.0-Question-Answering/simplified-nq-test.jsonl'], do_cache=False)
candidates_of_example = candidates[example.qas_id]
candidates_of_example

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




[LongAnswerCandidate(start_token=45, end_token=96, top_level=True),
 LongAnswerCandidate(start_token=46, end_token=95, top_level=False),
 LongAnswerCandidate(start_token=98, end_token=234, top_level=True),
 LongAnswerCandidate(start_token=99, end_token=104, top_level=False),
 LongAnswerCandidate(start_token=108, end_token=119, top_level=False),
 LongAnswerCandidate(start_token=119, end_token=125, top_level=False),
 LongAnswerCandidate(start_token=125, end_token=134, top_level=False),
 LongAnswerCandidate(start_token=134, end_token=143, top_level=False),
 LongAnswerCandidate(start_token=143, end_token=152, top_level=False),
 LongAnswerCandidate(start_token=152, end_token=161, top_level=False),
 LongAnswerCandidate(start_token=161, end_token=170, top_level=False),
 LongAnswerCandidate(start_token=170, end_token=179, top_level=False),
 LongAnswerCandidate(start_token=179, end_token=191, top_level=False),
 LongAnswerCandidate(start_token=191, end_token=196, top_level=False),
 LongAnswerCan

In [46]:
long_token = long_best_non_null.orig_doc_start
min_dist = 1_000_000
long_answer = ''
num_long_found = 0
lyst_long_answers = []
for candidate in candidates_of_example :
    cstart, cend = candidate.start_token, candidate.end_token
    if cstart <= short_best_non_null.orig_doc_start and cend >= short_best_non_null.orig_doc_end:
      num_long_found += 1
      lyst_long_answers.append(' '.join(example.doc_tokens[cstart:cend]))
for candidate in candidates_of_example :
    cstart, cend = candidate.start_token, candidate.end_token
    dist = abs(cstart - long_token)
    if dist < min_dist:
      min_dist = dist
    if cstart == long_token:
      long_answer = ' '.join(example.doc_tokens[cstart:cend])
      break
      
print(f"QUESTION : {example.question_text}")
print(f"SHORT ANSWER : {short_best_non_null.text}")
print('------------------------------------------------------------------------------------------------------------------------')
if long_answer != '':
    print(f"BEST LONG ANSWER FOUND : {long_answer}")
else :
    print(f"no long answer found , min distance : {min_dist}")
print('------------------------------------------------------------------------------------------------------------------------')
print(f"THERE ARE {num_long_found} POSSIBLE LONG ANSWERS , WHICH CONTAIN SHORT ANSWER")
for long_answer in lyst_long_answers:
    print(long_answer)
    print('\n')

QUESTION : why do female tarantulas live longer than males
SHORT ANSWER : because they die relatively soon after maturing
------------------------------------------------------------------------------------------------------------------------
BEST LONG ANSWER FOUND : <P> A juvenile male 's sex can be determined by looking at a cast exuvia for exiandrous fusillae or spermathecae . Females possess spermathecae except for the species Sickius longibulbi and Encyocratella olivacea . Males have much shorter lifespans than females because they die relatively soon after maturing . Few live long enough for a postultimate molt , which is unlikely in natural habitats because they are vulnerable to predation , but has happened in captivity , though rarely . Most males do not live through this molt , as they tend to get their emboli , mature male sexual organs on pedipalps , stuck in the molt . Most tarantula fanciers regard females as more desirable as pets due to their much longer lifespans . Wil

In [47]:
# questions = [
#       'How long did it take to find the answer?',
#       'What\'s the answer to the great question?',
#       'What\'s the name of the computer?']
# paragraph = '''<p>The computer is named Deep Thought.</p>.
#                 <p>After 46 million years of training it found the answer.</p>
#                 <p>However, nobody was amazed. The answer was 42.</p>'''

In [48]:
# questions = [
#       'Who is she',
#       'How old am I',
#       'Why do I learn Machine Learning']
# paragraph = '''I am Dinh.
#                 Now I'm 26 years old.
#                 I study Machine Learning because I love state of the art algorithms'''

In [24]:
# questions = ['what is the difference between you and stars']
# paragraph = '''Do you know what is the difference between you and stars ?
#                 The stars are on the sky and you are in my heart !'''

In [36]:
# questions = ['who is Alan Turing',
#             'what did he do in World war 2',
#             'where did he work after the war',
#             'when did he join Max Newman',
#             'what is one of the first design for a stored-programm computer',
#             "why isn't he recognised in his home country as father of computer science"]
# paragraph = '''Alan Mathison Turing OBE FRS (/ˈtjʊərɪŋ/; 23 June 1912 – 7 June 1954) 
#             was an English mathematician, computer scientist, logician, cryptanalyst, 
#             philosopher and theoretical biologist.Turing was highly influential in 
#             the development of theoretical computer science, providing a formalisation of 
#             the concepts of algorithm and computation with the Turing machine, 
#             which can be considered a model of a general-purpose computer.
#             Turing is widely considered to be the father of theoretical computer science 
#             and artificial intelligence.Despite these accomplishments, he was not fully 
#             recognised in his home country during his lifetime, due to his homosexuality, 
#             and because much of his work was covered by the Official Secrets Act.During the Second World War, 
#             Turing worked for the Government Code and Cypher School (GC&CS) at Bletchley Park, 
#             Britain's codebreaking centre that produced Ultra intelligence. For a time he led Hut 8, 
#             the section that was responsible for German naval cryptanalysis. Here, he devised a 
#             number of techniques for speeding the breaking of German ciphers, including improvements 
#             to the pre-war Polish bombe method, an electromechanical machine that could find settings for the Enigma machine.
#             After the war, Turing worked at the National Physical Laboratory, 
#             where he designed the Automatic Computing Engine, 
#             which was one of the first designs for a stored-program computer. 
#             In 1948, Turing joined Max Newman's Computing Machine Laboratory at the Victoria University of Manchester, 
#             where he helped develop the Manchester computers and became interested in mathematical biology. 
#             He wrote a paper on the chemical basis of morphogenesis and predicted oscillating chemical 
#             reactions such as the Belousov–Zhabotinsky reaction, first observed in the 1960s.'''

In [55]:
questions = ['how many did he score for his country and club',
            'when was he born',
            'how much is his transfer to real madrid',
            'who is his biggest opponent']
paragraph = '''Cristiano Ronaldo dos Santos Aveiro GOIH ComM (born 5 February 1985) 
            is a Portuguese professional footballer who plays as a forward for Serie A club Juventus and captains the Portugal 
            national team. Often considered the best player in the world and widely regarded as one of the greatest players of 
            all time,[note 3] Ronaldo has won five Ballons d'Or[note 4] and four European Golden Shoes, both of which are 
            records for a European player. He has won 29 trophies in his career, including six league titles, five UEFA 
            Champions Leagues, one UEFA European Championship, and one UEFA Nations League. A prolific goalscorer, Ronaldo 
            holds the records for the most goals scored in the UEFA Champions League (128) and the joint-most goals scored in 
            the UEFA European Championship. He has scored over 700 senior career goals for club and country.

            Born and raised in Madeira, Ronaldo began his senior club career playing for Sporting CP, before signing with 
            Manchester United in 2003, aged 18. After winning the FA Cup in his first season, he helped United win three 
            successive Premier League titles, the UEFA Champions League, and the FIFA Club World Cup; at age 23, he won his 
            first Ballon d'Or. In 2009, Ronaldo was the subject of the then-most expensive association football transfer when 
            signed for Real Madrid in a transfer worth €94 million (£80 million). There, Ronaldo won 15 trophies, including 
            two La Liga titles, two Copas del Rey, and four UEFA Champions League titles, and became the club's all-time top 
            goalscorer. After joining Madrid, Ronaldo finished runner-up for the Ballon d'Or three times, behind Lionel 
            Messi—his perceived career rival—before winning back-to-back Ballons d'Or from 2013–2014 and again from 2016–2017. 
            After winning a third consecutive Champions League title in 2018, Ronaldo became the first player to win the trophy 
            five times. In 2018, he signed for Juventus in a transfer worth an initial €100 million (£88 million), the highest 
            ever paid by an Italian club and the highest ever paid for a player over 30 years old. With the Italian outfit, 
            he has won one Serie A and one Supercoppa Italiana.'''

In [65]:
questions = ["who is Ronaldo's girlfriend"]
paragraph = '''Born and raised in Madeira, DINH began his senior club career playing for Sporting CP, before signing with 
            Manchester United in 2003, aged 18. After winning the FA Cup in his first season, he helped United win three 
            successive Premier League titles, the UEFA Champions League, and the FIFA Club World Cup; at age 23, he won his 
            first Ballon d'Or. In 2009, Ronaldo was the subject of the then-most expensive association football transfer when 
            signed for Real Madrid in a transfer worth €94 million (£80 million). There, Ronaldo won 15 trophies, including 
            two La Liga titles, two Copas del Rey, and four UEFA Champions League titles, and became the club's all-time top 
            goalscorer. After joining Madrid, Ronaldo finished runner-up for the Ballon d'Or three times, behind Lionel 
            Messi—his perceived career rival—before winning back-to-back Ballons d'Or from 2013–2014 and again from 2016–2017. 
            After winning a third consecutive Champions League title in 2018, Ronaldo became the first player to win the trophy 
            five times. In 2018, he signed for Juventus in a transfer worth an initial €100 million (£88 million), the highest 
            ever paid by an Italian club and the highest ever paid for a player over 30 years old. With the Italian outfit, 
            he has won one Serie A and one Supercoppa Italiana.He hates Thien'''

In [66]:
def demo(questions,paragraph):
  for question in questions:
    question_tokens = tokenizer.tokenize(question)
    paragraph_tokens = tokenizer.tokenize(paragraph)
    tokens = ['[CLS]'] + question_tokens + ['[SEP]'] + paragraph_tokens + ['[SEP]']
    input_word_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_word_ids)
    input_type_ids = [0] * (1 + len(question_tokens) + 1) + [1] * (len(paragraph_tokens) + 1)

    input_word_ids, input_mask, input_type_ids = map(lambda t: tf.expand_dims(
        tf.convert_to_tensor(t, dtype=tf.int32), 0), (input_word_ids, input_mask, input_type_ids))
    outputs = model([input_word_ids, input_mask, input_type_ids])
    # using `[1:]` will enforce an answer. `outputs[:][0][0]` is the ignored '[CLS]' token logit.
    short_start = tf.argmax(outputs[0][0][1:]) + 1
    short_end = tf.argmax(outputs[1][0][1:]) + 1
    answer_tokens = tokens[short_start: short_end + 1]
    answer = tokenizer.convert_tokens_to_string(answer_tokens)
    print(f'Question: {question}')
    print(f'Answer: {answer}')
demo(questions,paragraph)

Question: who is Ronaldo's girlfriend
Answer: who is ronaldo ' s girlfriend


In [54]:
config

{
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "do_sample": false,
  "eos_token_ids": 0,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_beams": 1,
  "num_hidden_layers": 24,
  "num_labels": 2,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pad_token_id": 0,
  "pruned_heads": {},
  "repetition_penalty": 1.0,
  "temperature": 1.0,
  "top_k": 50,
  "top_p": 1.0,
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 30522
}