In [1]:
import os

import torch

# Paths may be overridden through environment variables so the notebook can run
# on machines other than the original author's; defaults keep the old values.
checkpoint = os.environ.get("SQUAD2_CHECKPOINT", "/home/pzhu/data/qa/squad2_model")
predict_file = os.environ.get("SQUAD2_PREDICT_FILE", "data/squad2/dev-v2.0.json")
# Fall back to CPU when no GPU is present instead of failing later in model.to(device).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
from pytorch_transformers import XLNetForQuestionAnswering

# Load the fine-tuned QA weights, place them on the target device (Module.to
# returns the module itself) and switch to inference mode (e.g. no dropout).
model = XLNetForQuestionAnswering.from_pretrained(checkpoint).to(device)
model.eval()
print("loaded")

loaded


In [3]:
from xlnet_qa.squad2_reader import SQuAD2Reader

# Evaluation-mode reader: parse the dev file into examples, features and the
# tensor dataset consumed by the DataLoader below.
reader = SQuAD2Reader(is_training=False)
(examples, features, datasets) = reader.squad_data(predict_file)

In [4]:
from tqdm import tqdm
import torch
from torch.utils.data import SequentialSampler, DataLoader
from xlnet_qa.utils_squad import RawResultExtended, write_predictions_extended

# In-order iteration, one dataset item per batch, so batch i is dataset item i.
sampler = SequentialSampler(data_source=datasets)
dataloader = DataLoader(dataset=datasets, batch_size=1, sampler=sampler)

def to_list(tensor):
    """Convert ``tensor`` to a (nested) Python list after detaching it and moving it to the CPU."""
    cpu_tensor = tensor.detach().to("cpu")
    return cpu_tensor.tolist()

In [22]:
# Take the first batch (batch_size=1) and move every tensor to the device.
data = tuple(t.to(device) for t in next(iter(dataloader)))
# data[3] indexes the *feature* list (a long context can yield several features
# per example), as in upstream run_squad -- NOTE(review): confirm SQuAD2Reader
# follows that convention. The example must therefore be looked up through the
# feature rather than with the same index.
feature_index = data[3].item()
feature = features[feature_index]
example = examples[feature.example_index]

In [23]:
# Question, original whitespace tokens and gold answer, one per line, then the gold span.
print(example.question_text, example.doc_tokens, example.orig_answer_text, sep="\n")
print(example.start_position, example.end_position)

In what country is Normandy located?
['The', 'Normans', '(Norman:', 'Nourmands;', 'French:', 'Normands;', 'Latin:', 'Normanni)', 'were', 'the', 'people', 'who', 'in', 'the', '10th', 'and', '11th', 'centuries', 'gave', 'their', 'name', 'to', 'Normandy,', 'a', 'region', 'in', 'France.', 'They', 'were', 'descended', 'from', 'Norse', '("Norman"', 'comes', 'from', '"Norseman")', 'raiders', 'and', 'pirates', 'from', 'Denmark,', 'Iceland', 'and', 'Norway', 'who,', 'under', 'their', 'leader', 'Rollo,', 'agreed', 'to', 'swear', 'fealty', 'to', 'King', 'Charles', 'III', 'of', 'West', 'Francia.', 'Through', 'generations', 'of', 'assimilation', 'and', 'mixing', 'with', 'the', 'native', 'Frankish', 'and', 'Roman-Gaulish', 'populations,', 'their', 'descendants', 'would', 'gradually', 'merge', 'with', 'the', 'Carolingian-based', 'cultures', 'of', 'West', 'Francia.', 'The', 'distinct', 'cultural', 'and', 'ethnic', 'identity', 'of', 'the', 'Normans', 'emerged', 'initially', 'in', 'the', 'first', 'half'

In [26]:
# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    outputs = model(input_ids=data[0],
                    attention_mask=data[1],
                    token_type_ids=data[2],
                    cls_index=data[4],
                    p_mask=data[5])

In [28]:
unique_id = int(feature.unique_id)
# The first five model outputs map, in order, onto these fields; [0] selects the
# single item of the batch before converting it to Python lists/floats.
result = RawResultExtended(
    unique_id=unique_id,
    **{field: to_list(outputs[pos][0])
       for pos, field in enumerate(("start_top_log_probs",
                                    "start_top_index",
                                    "end_top_log_probs",
                                    "end_top_index",
                                    "cls_logits"))})

In [31]:
# Display the raw per-feature output: top-k start/end scores and indices plus the CLS logit.
result

RawResultExtended(unique_id=1000000000, start_top_log_probs=[0.00584795419126749, 0.00584795419126749, 0.005847953725606203, 0.005847953725606203, 0.005847953725606203], start_top_index=[164, 153, 47, 15, 11], end_top_log_probs=[0.005847962573170662, 0.005847966764122248, 0.0058479635044932365, 0.0058479649014770985, 0.005847963970154524, 0.005847962573170662, 0.0058479611761868, 0.0058479635044932365, 0.005847962107509375, 0.0058479611761868, 0.005847962573170662, 0.0058479611761868, 0.0058479635044932365, 0.005847962107509375, 0.0058479611761868, 0.005847959779202938, 0.0058479611761868, 0.005847960710525513, 0.005847962107509375, 0.0058479611761868, 0.005847959779202938, 0.0058479611761868, 0.005847960710525513, 0.005847959313541651, 0.0058479611761868], end_top_index=[56, 46, 64, 46, 46, 46, 51, 47, 52, 75, 42, 47, 46, 45, 52, 52, 52, 54, 51, 144, 47, 50, 52, 16, 69], cls_logits=-0.8711456656455994)

In [55]:
import collections
from xlnet_qa.utils_squad import get_final_text, _compute_softmax

def write_predictions_extended(example, feature, result, n_best_size,
                               max_answer_length, start_n_top, end_n_top, tokenizer,
                               verbose=True):
    """Decode XLNet's top-k start/end candidates for a single feature into answer text.

    Notebook-local, single-feature variant of the XLNet post-processing step. It
    shadows the ``write_predictions_extended`` imported from
    ``xlnet_qa.utils_squad`` earlier in this notebook and, despite its name,
    writes no files: it optionally prints a summary and returns the best answer.

    Args:
        example: example providing ``doc_tokens`` (original whitespace tokens).
        feature: feature providing ``paragraph_len``, ``token_is_max_context``,
            ``tokens`` and ``token_to_orig_map``.
        result: ``RawResultExtended`` with ``start_top_log_probs``,
            ``start_top_index``, ``end_top_log_probs``, ``end_top_index`` and
            ``cls_logits``.
        n_best_size: maximum number of distinct answer strings kept.
        max_answer_length: longest allowed span, counted in feature tokens.
        start_n_top: number of start candidates read from ``result``.
        end_n_top: number of end candidates per start candidate; the end lists
            in ``result`` are read as ``start_n_top * end_n_top`` flat entries.
        tokenizer: tokenizer providing ``convert_tokens_to_string`` and
            ``do_lower_case``.
        verbose: if True (default), print the null score, the best answer and
            the n-best list.

    Returns:
        ``(best_text, score_null)``: text of the highest-scoring candidate span
        (``""`` when no span is valid) and ``min(1000000, result.cls_logits)``.
    """
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["start_index", "end_index", "start_log_prob", "end_log_prob"])
    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])

    # Null-answer score is the CLS logit; the large sentinel only matters if this
    # is ever extended to take the minimum over several features of one example.
    score_null = min(1000000, result.cls_logits)

    # NOTE(review): the sample output shows positive values (~0.0058) in the
    # *_log_probs fields, i.e. they look like probabilities rather than log
    # probabilities; the sum-based ranking below is kept as-is -- confirm intent.
    prelim_predictions = []
    for i in range(start_n_top):
        start_log_prob = result.start_top_log_probs[i]
        start_index = result.start_top_index[i]
        for j in range(end_n_top):
            # End candidates are stored flattened: end_n_top entries per start.
            j_index = i * end_n_top + j
            end_log_prob = result.end_top_log_probs[j_index]
            end_index = result.end_top_index[j_index]

            # Throw out invalid spans: indices at/after paragraph_len - 1 are
            # treated as outside the paragraph (e.g. in the question), the start
            # token must have its max context in this doc span, and the span
            # must be non-empty and at most max_answer_length tokens long.
            if start_index >= feature.paragraph_len - 1:
                continue
            if end_index >= feature.paragraph_len - 1:
                continue
            if not feature.token_is_max_context.get(start_index, False):
                continue
            if end_index < start_index:
                continue
            if end_index - start_index + 1 > max_answer_length:
                continue

            prelim_predictions.append(
                _PrelimPrediction(start_index=start_index,
                                  end_index=end_index,
                                  start_log_prob=start_log_prob,
                                  end_log_prob=end_log_prob))

    # Best first; sort is stable, so ties keep their generation order.
    prelim_predictions.sort(key=lambda p: p.start_log_prob + p.end_log_prob,
                            reverse=True)

    seen_predictions = set()
    nbest = []
    for pred in prelim_predictions:
        if len(nbest) >= n_best_size:
            break

        # Map the feature-token span back to the original whitespace tokens
        # (BERT-style un-tokenization) and reconcile it with the detokenized text.
        tok_tokens = feature.tokens[pred.start_index:pred.end_index + 1]
        orig_doc_start = feature.token_to_orig_map[pred.start_index]
        orig_doc_end = feature.token_to_orig_map[pred.end_index]
        orig_tokens = example.doc_tokens[orig_doc_start:orig_doc_end + 1]
        tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

        # Collapse and trim whitespace (str.split() also drops leading/trailing).
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)

        final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
                                    False)
        if final_text in seen_predictions:
            continue
        seen_predictions.add(final_text)

        nbest.append(
            _NbestPrediction(text=final_text,
                             start_log_prob=pred.start_log_prob,
                             end_log_prob=pred.end_log_prob))

    # In rare edge cases no span survives the filters; use an empty placeholder
    # answer so the code below always has at least one entry.
    if not nbest:
        nbest.append(
            _NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))

    # nbest is in descending score order, so its first entry is the best span.
    best_non_null_entry = nbest[0]
    total_scores = [entry.start_log_prob + entry.end_log_prob for entry in nbest]
    probs = _compute_softmax(total_scores)

    nbest_json = [
        collections.OrderedDict([
            ("text", entry.text),
            ("probability", prob),
            ("start_log_prob", entry.start_log_prob),
            ("end_log_prob", entry.end_log_prob),
        ])
        for entry, prob in zip(nbest, probs)
    ]

    if verbose:
        print("=" * 80)
        print(score_null)
        print(best_non_null_entry.text)
        print(nbest_json)
    return best_non_null_entry.text, score_null


In [56]:
# Decode this single feature: keep up to 20 distinct answers, each at most 30
# feature tokens long, using the model's own beam sizes.
write_predictions_extended(example, feature, result,
                           n_best_size=20,
                           max_answer_length=30,
                           start_n_top=model.config.start_n_top,
                           end_n_top=model.config.end_n_top,
                           tokenizer=reader.tokenizer)

-0.8711456656455994
their
[OrderedDict([('text', 'their'), ('probability', 0.33333333354029393), ('start_log_prob', 0.005847953725606203), ('end_log_prob', 0.0058479611761868)]), OrderedDict([('text', 'their name to Normandy, a'), ('probability', 0.33333333354029393), ('start_log_prob', 0.005847953725606203), ('end_log_prob', 0.0058479611761868)]), OrderedDict([('text', 'Normans (Norman:'), ('probability', 0.3333333329194122), ('start_log_prob', 0.005847953725606203), ('end_log_prob', 0.005847959313541651)])]
