Skip to content

Commit

Permalink
Merge pull request huggingface#12 from stevezheng23/dev/zheng/quac
Browse files Browse the repository at this point in the history
Dev/zheng/quac
  • Loading branch information
stevezheng23 committed Oct 29, 2019
2 parents 4906c01 + 9c58687 commit 8b67c21
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 3 additions & 0 deletions examples/run_quac_kd.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def predict(args, model, tokenizer, prefix=""):
result = result_lookup[feature.unique_id]
feature.start_targets = _compute_softmax(result.kd_start_logits)
feature.end_targets = _compute_softmax(result.kd_end_logits)
updated_features.append(feature)

torch.save(updated_features, updated_features_file)

Expand Down Expand Up @@ -471,6 +472,8 @@ def main():
help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--do_predict", action='store_true',
help="Whether to run predict on the dev set.")
parser.add_argument("--evaluate_during_training", action='store_true',
help="Rul evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action='store_true',
Expand Down
4 changes: 2 additions & 2 deletions examples/utils_quac.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,9 +948,9 @@ def write_predictions_v2(all_examples, all_features, all_results, n_best_size,

if null_score_threshold is not None:
if nbest_json[0]["text"] == "CANNOTANSWER" and nbest_json[0]["probability"] > null_score_threshold:
all_predictions[qas_id] = "CANNOTANSWER"
all_predictions[example.qas_id] = "CANNOTANSWER"
else:
all_predictions[qas_id] = nbest_json[0]["text"] if nbest_json[0]["text"] != "CANNOTANSWER" else nbest_json[1]["text"]
all_predictions[example.qas_id] = nbest_json[0]["text"] if nbest_json[0]["text"] != "CANNOTANSWER" else nbest_json[1]["text"]
else:
if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
Expand Down

0 comments on commit 8b67c21

Please sign in to comment.