update evaluate script (codalab ver)
ymcui committed Apr 27, 2022
1 parent 937ef90 commit c0eb1b6
Showing 1 changed file with 12 additions and 15 deletions.
baseline/cmrc2018_evaluate.py (27 changes: 12 additions & 15 deletions)
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 '''
 Evaluation script for CMRC 2018
-version: v5 - special
 Note:
+v6: compatible for both original and SQuAD-style CMRC 2018 datasets. support python3.
 v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
 v5: formatted output, add usage description
 v4: fixed segmentation issues
@@ -14,14 +14,12 @@
 import argparse
 import json
 import sys
-reload(sys)
-sys.setdefaultencoding('utf8')
 import nltk
 import pdb
 
 # split Chinese with English
 def mixed_segmentation(in_str, rm_punc=False):
-    in_str = str(in_str).decode('utf-8').lower().strip()
+    in_str = str(in_str).lower().strip()
     segs_out = []
     temp_str = ""
     sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
@@ -30,7 +28,7 @@ def mixed_segmentation(in_str, rm_punc=False):
     for char in in_str:
         if rm_punc and char in sp_char:
             continue
-        if re.search(ur'[\u4e00-\u9fa5]', char) or char in sp_char:
+        if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
             if temp_str != "":
                 ss = nltk.word_tokenize(temp_str)
                 segs_out.extend(ss)
@@ -46,10 +44,9 @@
 
     return segs_out
 
-
 # remove punctuation
 def remove_punctuation(in_str):
-    in_str = str(in_str).decode('utf-8').lower().strip()
+    in_str = str(in_str).lower().strip()
     sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
                ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
                '「','」','(',')','-','~','『','』']
@@ -82,22 +79,22 @@ def evaluate(ground_truth_file, prediction_file):
     em = 0
     total_count = 0
     skip_count = 0
-    for instance in ground_truth_file["data"]:
-        #context_id = instance['context_id'].strip()
-        #context_text = instance['context_text'].strip()
-        for para in instance["paragraphs"]:
+
+    data_list = ground_truth_file['data'] if 'data' in ground_truth_file else ground_truth_file
+    for instance in data_list:
+        para_list = instance['paragraphs'] if 'paragraphs' in instance else [instance]
+        for para in para_list:
             for qas in para['qas']:
                 total_count += 1
-                query_id = qas['id'].strip()
-                query_text = qas['question'].strip()
-                answers = [x["text"] for x in qas['answers']]
+                query_id = qas['id'] if 'id' in qas else qas['query_id']
+                answers = [x['text'] if isinstance(x, dict) else x for x in qas['answers']]
 
                 if query_id not in prediction_file:
                     sys.stderr.write('Unanswered question: {}\n'.format(query_id))
                     skip_count += 1
                     continue
 
-                prediction = str(prediction_file[query_id]).decode('utf-8')
+                prediction = str(prediction_file[query_id])
                 f1 += calc_f1_score(answers, prediction)
                 em += calc_em_score(answers, prediction)
 
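The substantive change in evaluate() is the format fallback: data_list and para_list let one loop walk either a SQuAD-style file ({"data": [...]} with "paragraphs") or the original CMRC 2018 layout (a bare list of paragraph-like objects carrying "qas"), and each answer may be a plain string or a {"text": ...} dict. Below is a minimal sketch of the two ground-truth shapes the updated loop accepts; the field values and the iter_questions helper are invented for illustration and are not part of the commit.

# Hypothetical example data; real CMRC 2018 files carry many more fields.

# SQuAD-style ground truth: "data" -> "paragraphs" -> "qas",
# question key "id", answers given as dicts with a "text" field.
squad_style = {
    "data": [
        {"paragraphs": [
            {"qas": [
                {"id": "DEV_0_QUERY_0", "answers": [{"text": "北京"}]}
            ]}
        ]}
    ]
}

# Original CMRC 2018 ground truth: a bare list of paragraph-like objects,
# question key "query_id", answers given as plain strings.
original_style = [
    {"qas": [
        {"query_id": "DEV_0_QUERY_0", "answers": ["北京"]}
    ]}
]

# Same traversal as the updated evaluate(): both shapes yield identical pairs.
def iter_questions(ground_truth_file):
    data_list = ground_truth_file['data'] if 'data' in ground_truth_file else ground_truth_file
    for instance in data_list:
        para_list = instance['paragraphs'] if 'paragraphs' in instance else [instance]
        for para in para_list:
            for qas in para['qas']:
                query_id = qas['id'] if 'id' in qas else qas['query_id']
                answers = [x['text'] if isinstance(x, dict) else x for x in qas['answers']]
                yield query_id, answers

assert list(iter_questions(squad_style)) == list(iter_questions(original_style))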