In [1]:
from util.util import *
from tqdm import tqdm
import seaborn as sns

In [2]:
class JsonDataset:
    """QA dataset stored in a JSON file and loaded with ``load_json``.

    The loaded object is read through the key constants below: a list of items under
    ``DATA_KEY``, each with a document (``TITLE_KEY``, ``CONTEXT_KEY``) holding a list of
    question/answer entries under ``QAS_KEY``.
    """

    def __init__(self, dataset_file):
        self.dataset_file = dataset_file
        self.dataset = load_json(self.dataset_file)

    @staticmethod
    def _entity_tokens(text):
        """Return the lowercased whitespace tokens of ``to_entities(text)`` that start with '@entity'."""
        return [w for w in to_entities(text).lower().split() if w.startswith('@entity')]

    @staticmethod
    def _select_answer(qa, candidates, normalize, remove_notfound):
        """Choose the gold answer for one question.

        :param qa: question entry whose ``ANS_KEY`` list holds answers tagged with an origin
        :param candidates: set of candidate strings found in the passage
        :param normalize: maps a raw answer text to the candidate representation
        :param remove_notfound: if the "dataset" answer is not a candidate, fall back to a
            UMLS answer that is one; give up when there is none
        :return: the normalized answer ("" if there is no "dataset" answer and no fallback
            was attempted), or None if the question should be skipped
        """
        answer = ""
        for ans in qa[ANS_KEY]:
            if ans[ORIG_KEY] == "dataset":
                answer = normalize(ans[TXT_KEY])  # the last "dataset" answer wins
        if remove_notfound and answer not in candidates:
            fallback = None
            for ans in qa[ANS_KEY]:
                if ans[ORIG_KEY] == "UMLS":
                    umls_answer = normalize(ans[TXT_KEY])
                    if umls_answer in candidates:
                        fallback = umls_answer  # the last matching UMLS answer wins
            if fallback is None:
                return None
            answer = fallback
        return answer

    def json_to_plain(self, remove_notfound=False, stp="no-ent", include_q_cands=False):
        """Yield one flat record per question in the dataset.

        :param remove_notfound: if the "dataset" answer is not among the passage candidates,
            replace it by an equivalent UMLS answer that is; questions for which no such
            answer exists are skipped (not yielded)
        :param stp: "no-ent" | "ent"; whether to keep entity marks in passage, query and
            candidates. With "ent", a multiword entity is treated as a single token
        :param include_q_cands: whether to also add entities from the query to the list of
            candidate answers (honoured in both ``stp`` modes)
        :return: generator of dicts
                 {"c": [candidates, sorted, without duplicates],
                  "a": answer,
                  "p": passage (title + context),
                  "q": query,
                  "id": question id};
                 all text is lowercased and newlines are replaced by spaces
        :raises NotImplementedError: for an unknown ``stp`` (raised when iteration starts)
        :raises AssertionError: if a question ends up with an empty answer
        """
        normalizers = {
            "no-ent": lambda text: text.lower(),
            "ent": lambda text: plain_to_ent(text.lower()),
        }
        if stp not in normalizers:
            raise NotImplementedError("unsupported stp %r; expected 'no-ent' or 'ent'" % (stp,))
        normalize = normalizers[stp]

        for datum in self.dataset[DATA_KEY]:
            doc = datum[DOC_KEY]
            text = doc[TITLE_KEY] + " " + doc[CONTEXT_KEY]
            # Passage candidates and passage text depend only on the document, not on the
            # question, so they are computed once per document.
            passage_ents = set(self._entity_tokens(text))
            if stp == "no-ent":
                candidates = {ent_to_plain(e) for e in passage_ents}
                passage = remove_entity_marks(text)
            else:
                candidates = passage_ents
                passage = to_entities(text)
            passage = passage.replace("\n", " ").lower()

            for qa in doc[QAS_KEY]:
                query_ents = set(self._entity_tokens(" " + qa[QUERY_KEY])) if include_q_cands else set()
                if stp == "no-ent":
                    query_cands = {ent_to_plain(e) for e in query_ents}
                    query = remove_entity_marks(qa[QUERY_KEY])
                else:
                    query_cands = query_ents
                    query = to_entities(qa[QUERY_KEY])

                # Only passage candidates decide whether the answer counts as "found".
                answer = self._select_answer(qa, candidates, normalize, remove_notfound)
                if answer is None:
                    continue
                assert answer, "empty answer for question %r" % (qa[ID_KEY],)

                yield {
                    # The union drops entities present in both passage and query; sorting
                    # makes the order independent of the per-process string hash seed.
                    "c": sorted(candidates | query_cands),
                    "a": answer,
                    "p": passage,
                    "q": query.replace("\n", " ").lower(),
                    "id": qa[ID_KEY],
                }

In [3]:
# Directory prefix holding the dataset JSON files (e.g. "dev1.0.json"), despite the name.
# The trailing slash matters: file paths are built from it by plain string concatenation.
filename = "/data/medg/misc/phuongpm/"

In [4]:
# traindata = JsonDataset(filename + "train1.0.json").json_to_plain(remove_notfound=True)

In [5]:
# json_to_plain() returns a one-shot generator; materialize it so the cells below can be
# re-run (or devdata reused) without silently iterating over an exhausted generator.
devdata = list(JsonDataset(filename + "dev1.0.json").json_to_plain())

In [7]:
# Count dev questions whose (lowercased) gold answer occurs in the passage text.
# NOTE(review): this is a raw substring test, so e.g. "pain" also matches "painful".
total = found_doc_ans = 0
for d in devdata:
    total += 1
    found_doc_ans += int(d["a"] in d["p"])

In [8]:
# Number of dev QA items counted in the cell above.
total

6391

In [9]:
# How many of those items have their lowercased gold answer as a substring of the passage.
found_doc_ans

3888