In [7]:
import os

import torch

# --- Run configuration -------------------------------------------------------
# Directory of the fine-tuned XLNet SQuAD 2.0 checkpoint. The default is the
# original, machine-specific location; set the SQUAD2_CHECKPOINT environment
# variable to run the notebook elsewhere without editing this cell.
checkpoint = os.environ.get(
    "SQUAD2_CHECKPOINT",
    "/home/pzhu/data/qa/squad2_model/checkpoint-193450",
)

# SQuAD v2.0 dev set to predict on.
predict_file = "data/dev-v2.0.json"

# Name of the base pretrained model. Not read by the cells visible here; kept
# because other cells may depend on it.
model_name = "xlnet-large-cased"

# Files written by the prediction step and read back by the evaluation step.
output_prediction_file = "data/predictions.json"
output_nbest_file = "data/nbest_predictions.json"
output_null_log_odds_file = "data/null_odds.json"

# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so that
# model.to(device) does not fail on machines without a GPU (much slower).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
from pytorch_transformers import XLNetForQuestionAnswering
# Restore the question-answering model from the checkpoint directory.
model = XLNetForQuestionAnswering.from_pretrained(checkpoint)
# Move the weights to the target device and switch to evaluation mode
# (both calls return the module itself, so they can be chained).
model = model.to(device).eval()
print("loaded")

loaded


In [3]:
from xlnet_qa.squad2_reader import SQuAD2Reader

# Build the project's SQuAD 2.0 reader with is_training=False (prediction run).
reader = SQuAD2Reader(is_training=False)
# squad_data(predict_file) yields three values. Judging only by how the later
# cells use them: `dataset` is fed to the DataLoader, `features` is indexed by
# the per-item index carried in each batch, and `examples` / `features` are
# handed to write_predictions_extended.
# NOTE(review): exact types are defined in xlnet_qa.squad2_reader — confirm there.
dataset, examples, features = reader.squad_data(predict_file)

In [4]:
from tqdm import tqdm
import torch
from torch.utils.data import SequentialSampler, DataLoader
from xlnet_qa.utils_squad import RawResultExtended, write_predictions_extended

# Number of features scored per forward pass.
EVAL_BATCH_SIZE = 8

# Iterate the dataset in its stored order so that batches are reproducible.
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=EVAL_BATCH_SIZE, sampler=sampler)

def to_list(tensor):
    """Return the contents of ``tensor`` as plain (nested) Python values.

    The tensor is detached from the autograd graph and copied to host memory
    before conversion, so it may live on any device and may require grad.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.tolist()

# Score every batch and collect one RawResultExtended per feature.
# Resetting the list here keeps the cell safe to re-run.
all_results = []
for batch in tqdm(dataloader, desc="Evaluating"):
    # Move each tensor of the batch onto the model's device.
    batch = tuple(tensor.to(device) for tensor in batch)
    # Positions 0-5 of the batch, named after the way they are consumed below.
    input_ids, attention_mask, token_type_ids, example_indices, cls_index, p_mask = batch[:6]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            cls_index=cls_index,
            p_mask=p_mask,
        )

    # example_indices[row] selects the entry of `features` that produced this
    # row; its unique_id ties the raw model outputs back to that feature.
    for row, example_index in enumerate(example_indices):
        eval_feature = features[example_index.item()]
        all_results.append(
            RawResultExtended(
                unique_id=int(eval_feature.unique_id),
                start_top_log_probs=to_list(outputs[0][row]),
                start_top_index=to_list(outputs[1][row]),
                end_top_log_probs=to_list(outputs[2][row]),
                end_top_index=to_list(outputs[3][row]),
                cls_logits=to_list(outputs[4][row]),
            )
        )

Evaluating: 100%|██████████| 1543/1543 [14:02<00:00,  1.98it/s]


In [8]:
# Positional literals of the call, given names for readability.
# NOTE(review): the names follow the upstream pytorch_transformers
# write_predictions_extended signature; confirm that the project copy in
# xlnet_qa.utils_squad uses the same positional order.
n_best_size = 20
max_answer_length = 30
version_2_with_negative = True
verbose_logging = False

# NOTE(review): `examples` and `features` are each wrapped in a one-element
# list here — confirm that this is what the project helper expects.
# The call stays the last expression of the cell so its return value is displayed.
write_predictions_extended(
    [examples],
    [features],
    all_results,
    n_best_size,
    max_answer_length,
    output_prediction_file,
    output_nbest_file,
    output_null_log_odds_file,
    predict_file,
    model.config.start_n_top,
    model.config.end_n_top,
    version_2_with_negative,
    reader.tokenizer,
    verbose_logging,
)

{'best_exact': 50.07159100480081,
 'best_exact_thresh': 0.0,
 'best_f1': 50.07580224037733,
 'best_f1_thresh': -0.8465023040771484,
 'has_ans_exact': 0.0020242914979757085,
 'has_ans_f1': 0.059967068372362826}

In [9]:
from xlnet_qa.utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
# Score the written predictions against the dev set with the project's SQuAD
# evaluation helper; the null-odds file is passed as the no-answer probabilities.
evaluate_options = EVAL_OPTS(
    data_file=predict_file,
    pred_file=output_prediction_file,
    na_prob_file=output_null_log_odds_file,
)
results = evaluate_on_squad(evaluate_options)

{
  "exact": 6.3505432493893705,
  "f1": 9.243533911491307,
  "total": 11873,
  "HasAns_exact": 0.20242914979757085,
  "HasAns_f1": 5.996706837236289,
  "HasAns_total": 5928,
  "NoAns_exact": 12.48107653490328,
  "NoAns_f1": 12.48107653490328,
  "NoAns_total": 5945,
  "best_exact": 50.07159100480081,
  "best_exact_thresh": 0.0,
  "best_f1": 50.07580224037733,
  "best_f1_thresh": -0.8465023040771484
}
