In [1]:
from IPython.display import display, HTML
# Widen the notebook's main container to 90% of the window so long lines fit.
display(
    HTML("<style>.container { width:90% !important; }</style>")
)

In [2]:
from pathlib import Path
import shutil

In [3]:
p_tr = Path("../ClinicalTransformerMRC/example_datasets/2018n2c2/raw_data/train")
p_dev = Path("../ClinicalTransformerMRC/example_datasets/2018n2c2/raw_data/test/")

# Every brat annotation file: training split first, then the split used as dev/test.
fids = [*p_tr.glob("*.ann"), *p_dev.glob("*.ann")]

len(fids)

10

In [5]:
from pathlib import Path
import pickle as pkl
from collections import defaultdict, Counter
from itertools import permutations, combinations
from functools import reduce
import numpy as np
import os
import re
import sys
import json

# https://github.com/uf-hobi-informatics-lab/NLPreprocessing (git clone this repo to local)
sys.path.append("./NLPreprocessing/")
sys.path.append("./NLPreprocessing/text_process")
from annotation2BIO import pre_processing, read_annotation_brat, generate_BIO
MIMICIII_PATTERN = ""
from sentence_tokenization import logger as l1
from annotation2BIO import logger as l2
l1.disabled = True
l2.disabled = True

def pkl_save(data, file):
    """Pickle ``data`` into ``file``, overwriting it (binary mode)."""
    with open(file, mode="wb") as handle:
        pkl.dump(data, handle)

        
def pkl_load(file):
    """Return the object pickled in ``file``.

    Security: unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(file, mode="rb") as handle:
        return pkl.load(handle)


def load_text(ifn):
    """Return the whole content of text file ``ifn`` (platform default encoding)."""
    with open(ifn, mode="r") as handle:
        return handle.read()


def save_text(text, ofn):
    """Write ``text`` to ``ofn``, replacing any previous content (platform default encoding)."""
    with open(ofn, mode="w") as handle:
        handle.write(text)

In [6]:
def create_entity_to_sent_mapping(nnsents, entities, idx2e, fn):
    """Map every entity to the indices of its first and last word in ``nnsents``.

    Args:
        nnsents: flat list of word records. ``w[1]`` is compared against the
            entity's character offsets, so it is expected to hold
            ``(char_start, char_end)`` of the word.
        entities: entity tuples as produced by ``load_annotation_brat``, where
            ``each[2]`` is ``(char_start, char_end)``.
        idx2e: maps an entity's position in ``entities`` to its brat id (e.g. ``"T1"``).
        fn: file name. It is only printed in the mismatch diagnostics.

    Returns:
        defaultdict(list) mapping brat id -> ``[first_word_idx, last_word_idx]``.
        The first index is the first word starting at or after the entity start.
        The last index is the first word (from there on) ending at or after the
        entity end. Both are clamped to the last word. A diagnostic is printed
        whenever a boundary does not coincide exactly with a word boundary.

    Raises:
        ValueError: if ``nnsents`` is empty but ``entities`` is not.
    """
    mapping = defaultdict(list)
    if not nnsents:
        if entities:
            raise ValueError(f"{fn}: cannot map {len(entities)} entities onto an empty word list")
        return mapping

    last = len(nnsents) - 1
    for idx, each in enumerate(entities):
        en_label = idx2e[idx]
        en_s, en_e = each[2][0], each[2][1]

        # Bounding by `last` (not len) keeps `i` a valid index: an entity that
        # starts after the final word's start clamps to the last word instead
        # of reading past the end of the list.
        i = 0
        while i < last and nnsents[i][1][0] < en_s:
            i += 1
        mapping[en_label].append(i)
        if nnsents[i][1][0] != en_s:
            print(fn)
            print("first index not match ", each)

        # Advance to the first word that ends at or after the entity end.
        s_e = nnsents[i][1][1]
        while i < last and s_e < en_e:
            i += 1
            s_e = nnsents[i][1][1]
        mapping[en_label].append(i)
        if s_e != en_e:
            print(fn)
            print("last index not match ", each)
    return mapping

    
def __ann_info(ann):
    """Split a brat span descriptor such as ``"Drug 100 105"`` into ``(type, start, end)``.

    Only the first and the last offsets are kept, so a discontinuous span like
    ``"Drug 100 105;110 115"`` collapses to ``("Drug", 100, 115)``.
    """
    pieces = ann.split(" ")
    label, first, final = pieces[0], pieces[1], pieces[-1]
    return label, int(first), int(final)


def load_annotation_brat(ann_file, rep=False):
    """
    Load a brat standoff annotation (.ann) file.

    Args:
        ann_file: path of the .ann file.
        rep: unused; kept only for backward compatibility with existing callers.

    Returns a 5-tuple:
        entity_id2index_map -> {'T1': 0}   (entity id -> index into the entity list)
        entites   -> [('anticoagulant medications', 'Drug', (1000, 1025), 'T1'), ...]
                     i.e. (text, type, (start_offset, end_offset), entity_id)
        relations -> [('Route-Drug', 'Arg1:T3', 'Arg2:T2'), ...]
                     i.e. (relation_type, first argument, last argument)
        events    -> [{'trigger': ['Alcohol', 'T3'], 'events': [['Status', 'T4'], ...]}, ...]
        attrs     -> [('StatusTimeVal', 'T2', 'current'), ...]; a binary attribute
                     line such as "A1 <tab> Negated T1" is stored with value True.
    Lines with any other id prefix (e.g. '#' notes) are ignored.
    """
    # map the entity id (e.g., T1) to its index in the entities list
    entity_id2index_map = {}
    entites = []
    relations = []
    events = []
    attrs = []

    with open(ann_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            anns = line.split("\t")
            ann_id = anns[0]
            if ann_id.startswith("T"):
                # T1 <tab> LivingStatus 25 30 <tab> lives
                entity_words = anns[-1]
                t_type, offset_s, offset_e = __ann_info(anns[1])
                entites.append((entity_words, t_type, (offset_s, offset_e), ann_id))
                entity_id2index_map[ann_id] = len(entites) - 1

            elif ann_id.startswith("A"):
                # A1 <tab> StatusTimeVal T2 current   (multi-valued attribute)
                # A2 <tab> Negated T2                 (binary attribute, no value)
                att_parts = anns[1].split()
                att_type, env_id = att_parts[0], att_parts[1]
                att_val = att_parts[2] if len(att_parts) > 2 else True
                attrs.append((att_type, env_id, att_val))

            elif ann_id.startswith("E"):
                # E2 <tab> Alcohol:T3 Status:T4 Amount:T5 Frequency:T6 Type:T10
                envs = anns[1].split(" ")
                single_event = {
                    "trigger": envs[0].strip().split(":"),
                    "events": [each.strip().split(":") for each in envs[1:]],
                }
                events.append(single_event)

            elif ann_id.startswith("R"):
                # R2 <tab> Strength-Drug Arg1:T6 Arg2:T5
                rel_parts = anns[1].split()
                relations.append((rel_parts[0], rel_parts[1], rel_parts[-1]))

    return entity_id2index_map, entites, relations, events, attrs

In [7]:
"""
sample mrc format data:
a list of dict as
  {
    "context": "Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer .",
    "end_position": [
      0,
      23
    ],
    "entity_label": "LOC",
    "impossible": false,
    "qas_id": "4.3",
    "query": "location entities are the name of politically or geographically defined locations such as cities, provinces, countries, international regions, bodies of water, mountains, etc.",
    "span_position": [
      "0;0",
      "23;23"
    ],
    "start_position": [
      0,
      23
    ]
  },
  {
    "context": "EU rejects German call to boycott British lamb .",
    "end_position": [],
    "entity_label": "PER",
    "impossible": true,
    "qas_id": "0.2",
    "query": "person entities are named persons or family.",
    "span_position": [],
    "start_position": []
  }, ...

"""


# Query templates used to build the MRC questions.
# Make sure the label id in qas_id is consistent for each type of entity:
# entity_id / attribute_id provide the numeric id appended to every qas_id.

# NER: one question per entity type.
entity_query_template = {
    "Drug": "Find the drug events including names, brand names, and collective names.",
    "Strength": "Find the strength events that are the amount of drug in a given dosage.",
    "Form": "Find the form events that are the physical form of given drug or medication.",
    "Dosage": "Find the dosage events that are the amount of a medication used in each administration.",
    "Frequency": "Find the frequency events that indicate how often each dose of the medication should be taken.",
    "Route": "Find the route events that are the path by which the medication is taken into the body.",
    "Duration": "Find the duration events that indicate how long the medication is to be administered.",
    "Reason": "Find the reason events that are the medical reason for which the medication is given.",
    "ADE": "Find the ADE events that are injuries resulting from a medical intervention related to drugs."
}

# RE: one question per attribute type; "{}" is filled with the head entity text.
attribute_query_template = {
    "Strength": "What is the active ingredient amount of {}",
    "Form": "What is the physical form of {}",
    "Dosage": "What is the amount of {} taken",
    "Frequency": "How often each dose of {} should be taken",
    "Route": "What is the path of {} taken into the body",
    "Duration": "How long to take {}",
    "Reason": "What is the medical reason for giving {}",
    "ADE": "What are the injuries resulting from the use of {}"
}

entity_id = {
    "Drug": 1, "Strength": 2, "Form": 3, "Dosage": 4, "Frequency": 5, "Route": 6, "Duration": 7, "Reason": 8, "ADE": 9
}

attribute_id = {
    "Strength": 1, "Form": 2, "Dosage": 3, "Frequency": 4, "Route": 5, "Duration": 6, "Reason": 7, "ADE": 8
}

# Entity types that act as relation heads in the RE cells.
head_entity = {"Drug"}

# Every templated label needs an id (and vice versa), and every question must be distinct.
assert set(entity_query_template) == set(entity_id)
assert set(attribute_query_template) == set(attribute_id)
assert len(set(entity_query_template.values())) == len(entity_query_template)

In [8]:
def get_sent(sents, idx1, idx2):
    """Return the sentence holding an entity whose first word is in sentence
    ``idx1`` and whose last word is in sentence ``idx2``.

    Raises:
        ValueError: if the entity does not lie inside a single sentence. The
            message tells apart the "next sentence" case, the "more than two
            sentences" case, and the case of an end that precedes the start.
            Callers that catch ``Exception`` still catch this.
    """
    if idx1 == idx2:
        return sents[idx1]
    if idx2 < idx1:
        raise ValueError(f"{idx1} {idx2} - entity end sentence precedes its start sentence")
    if idx2 == idx1 + 1:
        raise ValueError(f"{idx1} {idx2} - entity not in the same sentence")
    raise ValueError(f"{idx1} {idx2} - the entity has word spread in >2 sentences")
        
        
def to_json(data, p, fn="train"):
    """Dump ``data`` as indented JSON to ``<p>/mrc-ner.<fn>`` (``p`` must be a ``Path``)."""
    out_path = p / f"mrc-ner.{fn}"
    with open(out_path, "w") as out_file:
        json.dump(data, out_file, indent=2)

In [None]:
#create mrc format data using the brat annotation file
#NER task
# One question per (sentence containing an entity start, entity type); the
# answers are the word spans of that type inside the sentence.
training_data = []

for fn in sorted(p_tr.glob("*.ann")):  # sorted: glob order is filesystem-dependent
    file_id = fn.stem
    _, sents = pre_processing(p_tr / f"{fn.stem}.txt", MIMICIII_PATTERN, max_len=256)
    e2i, ens, _, _, _ = load_annotation_brat(fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, _ = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    # Flat word list; below, w[3][0] is used as the sentence index and w[3][1]
    # as the word index within that sentence.
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    sent_with_entities = set()
    # (sent_idx, entity type) -> [(context, entype, type_id, start, end, "start;end"), ...]
    entity_sent_idx_mappings = defaultdict(list)

    for en in ens:
        entype = en[1]
        type_id = entity_id[entype]
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]

        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
        except Exception as ex:
            # Entity crosses a sentence boundary: log it and skip it instead of
            # recording context/positions left over from a previous entity.
            print(fn.name, ex)
            print(word_info1, word_info2)
            print(entype, type_id)
            continue

        context = " ".join(w[0] for w in sent_text)
        start = word_info1[3][1]
        end = word_info2[3][1]
        sent_with_entities.add(sent_idx1)
        entity_sent_idx_mappings[(sent_idx1, entype)].append(
            (context, entype, type_id, start, end, f"{start};{end}"))

    for sent_idx in sorted(sent_with_entities):
        sent_id = f"{sent_idx}"
        sent_i_context = " ".join(w[0] for w in nsents[sent_idx])
        for k, v in entity_query_template.items():
            tid = entity_id[k]
            # .get avoids inserting empty lists into the defaultdict
            entities = entity_sent_idx_mappings.get((sent_idx, k), [])
            d = {
                "context": sent_i_context,
                "end_position": [],
                "entity_label": k,
                "impossible": not entities,
                "qas_id": f"{file_id}.{sent_id}.{sent_id}.{tid}",
                "query": v,
                "span_position": [],
                "start_position": [],
            }
            for ent in entities:
                assert ent[0] == sent_i_context, f"expect context: {ent[0]} but get {sent_i_context}"
                assert ent[1] == k, f"expect en type: {ent[1]} but get {k}"
                s, e, se = ent[-3:]
                d["start_position"].append(s)
                d["end_position"].append(e)
                d["span_position"].append(se)
            training_data.append(d)

In [13]:
# Write the NER training split as mrc-ner.train in the entity output directory.
pout = Path("../ClinicalTransformerMRC/example_datasets/2018n2c2/mrc_data/entity/")
pout.mkdir(exist_ok=True, parents=True)
to_json(data=training_data, p=pout, fn="train")

In [None]:
##create dev data
#NER task
# One question per (sentence containing an entity start, entity type); the
# answers are the word spans of that type inside the sentence.
dev_data = []

for fn in sorted(p_dev.glob("*.ann")):  # sorted: glob order is filesystem-dependent
    file_id = fn.stem
    _, sents = pre_processing(p_dev / f"{fn.stem}.txt", MIMICIII_PATTERN, max_len=256)
    e2i, ens, _, _, _ = load_annotation_brat(fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, _ = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    # Flat word list; below, w[3][0] is used as the sentence index and w[3][1]
    # as the word index within that sentence.
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    sent_with_entities = set()
    # (sent_idx, entity type) -> [(context, entype, type_id, start, end, "start;end"), ...]
    entity_sent_idx_mappings = defaultdict(list)

    for en in ens:
        entype = en[1]
        type_id = entity_id[entype]
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]

        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
        except Exception as ex:
            # Entity crosses a sentence boundary: log it and skip it instead of
            # recording context/positions left over from a previous entity.
            print(fn.name, ex)
            print(word_info1, word_info2)
            print(entype, type_id)
            continue

        context = " ".join(w[0] for w in sent_text)
        start = word_info1[3][1]
        end = word_info2[3][1]
        sent_with_entities.add(sent_idx1)
        entity_sent_idx_mappings[(sent_idx1, entype)].append(
            (context, entype, type_id, start, end, f"{start};{end}"))

    for sent_idx in sorted(sent_with_entities):
        sent_id = f"{sent_idx}"
        sent_i_context = " ".join(w[0] for w in nsents[sent_idx])
        for k, v in entity_query_template.items():
            tid = entity_id[k]
            # .get avoids inserting empty lists into the defaultdict
            entities = entity_sent_idx_mappings.get((sent_idx, k), [])
            d = {
                "context": sent_i_context,
                "end_position": [],
                "entity_label": k,
                "impossible": not entities,
                "qas_id": f"{file_id}.{sent_id}.{sent_id}.{tid}",
                "query": v,
                "span_position": [],
                "start_position": [],
            }
            for ent in entities:
                assert ent[0] == sent_i_context, f"expect context: {ent[0]} but get {sent_i_context}"
                assert ent[1] == k, f"expect en type: {ent[1]} but get {k}"
                s, e, se = ent[-3:]
                d["start_position"].append(s)
                d["end_position"].append(e)
                d["span_position"].append(se)
            dev_data.append(d)

In [12]:
# Write the NER dev split to the same relative directory as the NER training split,
# instead of a machine-specific absolute path.
pout = Path("../ClinicalTransformerMRC/example_datasets/2018n2c2/mrc_data/entity/")
pout.mkdir(parents=True, exist_ok=True)
to_json(dev_data, pout, "dev")

In [None]:
#create test data
#NER task
# Every entity query is emitted for each sentence that contains an entity start,
# all marked "impossible" (no answers are provided for the test split).
# NOTE(review): sentence selection uses the gold annotations of these files.
test_data = []

for fn in sorted(p_dev.glob("*.ann")):  # sorted: glob order is filesystem-dependent
    file_id = fn.stem
    _, sents = pre_processing(p_dev / f"{fn.stem}.txt", MIMICIII_PATTERN, max_len=256)
    e2i, ens, _, _, _ = load_annotation_brat(fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, _ = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    # sentence index (w[3][0]) of the first word of every entity
    sent_with_entities = set()
    for en in ens:
        s_idx, _ = mappings[en[-1]]
        sent_with_entities.add(nnsents[s_idx][3][0])

    for sent_idx in sorted(sent_with_entities):
        sent_id = f"{sent_idx}"
        sent_i_context = " ".join(w[0] for w in nsents[sent_idx])
        for k, v in entity_query_template.items():
            tid = entity_id[k]
            test_data.append({
                "context": sent_i_context,
                "end_position": [],
                "entity_label": k,
                "impossible": True,
                "qas_id": f"{file_id}.{sent_id}.{sent_id}.{tid}",
                "query": v,
                "span_position": [],
                "start_position": [],
            })

In [14]:
# Write the NER test split to the same relative directory as the NER training split,
# instead of a machine-specific absolute path.
pout = Path("../ClinicalTransformerMRC/example_datasets/2018n2c2/mrc_data/entity/")
pout.mkdir(parents=True, exist_ok=True)
to_json(test_data, pout, "test")

In [None]:
#create mrc format data using the brat annotation file
#RE task
# For every head entity (Drug) and every attribute type, ask one question whose
# answers are the tail entities linked to that head by a relation of that type.
training_data = []

for fn in sorted(p_tr.glob("*.ann")):  # sorted: glob order is filesystem-dependent
    file_id = fn.stem
    _, sents = pre_processing(p_tr / f"{fn.stem}.txt", MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, _, _ = load_annotation_brat(fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, _ = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    # Flat word list; below, w[3][0] is used as the sentence index and w[3][1]
    # as the word index within that sentence.
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    # (sent_idx, text, entity type, (char_start, char_end), ann_id) per head entity
    head_list = []
    # tail ann_id -> [(sent_idx, text, context, type_id, start, end, "start;end")]
    tail_ann_idx_mappings = defaultdict(list)
    # (head ann_id, tail entity type) -> [tail ann_id, ...]
    relation_ann_idx_mappings = defaultdict(list)

    for text, entype, en_offset, ann_id in ens:
        type_id = entity_id[entype]
        s_idx, e_idx = mappings[ann_id]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]

        if entype in head_entity:
            # Only the head's first sentence, text, type, offsets and id are used
            # below, so a head spilling into the next sentence is still kept.
            head_list.append((sent_idx1, text, entype, en_offset, ann_id))
            continue

        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
        except Exception as ex:
            # Tail crosses a sentence boundary: log it and drop it instead of
            # storing context/positions left over from a previous entity.
            print(fn.name, ex)
            print(word_info1, word_info2)
            print(entype, type_id)
            continue

        context = " ".join(w[0] for w in sent_text)
        start = word_info1[3][1]
        end = word_info2[3][1]
        tail_ann_idx_mappings[ann_id].append(
            (sent_idx1, text, context, type_id, start, end, f"{start};{end}"))

    # Relation tuples look like ("Strength-Drug", "Arg1:T6", "Arg2:T5"): the part
    # before "-" is the tail type, Arg1 is the tail id and Arg2 is the head id.
    for relation, arg1, arg2 in relations:
        tail_ent = relation.split("-")[0]
        tail_id = arg1.split(":")[1]
        head_id = arg2.split(":")[1]
        relation_ann_idx_mappings[(head_id, tail_ent)].append(tail_id)

    # NOTE(review): several heads in the same sentence produce the same qas_id for
    # a given attribute type; confirm downstream code tolerates duplicate ids.
    for head_sent_id, head_text, head_entity_type, head_offset, head_ann_id in head_list:
        head_offset_s, head_offset_e = head_offset
        len_head_sent = len(nsents[head_sent_id])
        head_sent_context = " ".join(w[0] for w in nsents[head_sent_id])

        for k, v in attribute_query_template.items():
            tid = attribute_id[k]
            query = v.format(head_text)

            if (head_ann_id, k) not in relation_ann_idx_mappings:
                # The head has no relation of type k -> unanswerable question.
                training_data.append({
                    "context": head_sent_context,
                    "end_position": [],
                    "head_entity": [head_entity_type, head_text, head_offset_s, head_offset_e],
                    "entity_label": k,
                    "impossible": True,
                    "qas_id": f"{file_id}.{head_sent_id}.{head_sent_id}.{tid}",
                    "query": query,
                    "span_position": [],
                    "start_position": [],
                })
                continue

            # (starts, ends, "start;end" spans) of tails inside the head's own sentence
            same_sent = ([], [], [])
            # tail sentence id -> (starts, ends, spans) inside the two-sentence context
            tail_sent_id_to_positions = defaultdict(lambda: ([], [], []))

            for tail_ann_id in relation_ann_idx_mappings[(head_ann_id, k)]:
                tail_infos = tail_ann_idx_mappings.get(tail_ann_id)
                if not tail_infos:
                    # Tail was dropped above (crosses sentences) or is not a tail-type entity.
                    continue
                tail_sent_id, _, _, _, s, e, se = tail_infos[0]

                if tail_sent_id == head_sent_id:
                    bucket = same_sent
                else:
                    if tail_sent_id > head_sent_id:
                        # Context is "head sentence + tail sentence": shift past the head words.
                        s += len_head_sent
                        e += len_head_sent
                    se = f"{s};{e}"
                    bucket = tail_sent_id_to_positions[tail_sent_id]
                bucket[0].append(s)
                bucket[1].append(e)
                bucket[2].append(se)

            # Answers located in the head's own sentence
            if same_sent[0]:
                starts, ends, spans = same_sent
                training_data.append({
                    "context": head_sent_context,
                    "end_position": ends,
                    "head_entity": [head_entity_type, head_text, head_offset_s, head_offset_e],
                    "entity_label": k,
                    "impossible": False,
                    "qas_id": f"{file_id}.{head_sent_id}.{head_sent_id}.{tid}",
                    "query": query,
                    "span_position": spans,
                    "start_position": starts,
                })

            # Answers located in another sentence: context is the two sentences in document order
            for tail_sent_id, (starts, ends, spans) in tail_sent_id_to_positions.items():
                tail_sent_context = " ".join(w[0] for w in nsents[tail_sent_id])
                if tail_sent_id < head_sent_id:
                    pair_context = f"{tail_sent_context} {head_sent_context}"
                else:
                    pair_context = f"{head_sent_context} {tail_sent_context}"
                training_data.append({
                    "context": pair_context,
                    "end_position": ends,
                    "head_entity": [head_entity_type, head_text, head_offset_s, head_offset_e],
                    "entity_label": k,
                    "impossible": False,
                    "qas_id": f"{file_id}.{head_sent_id}.{tail_sent_id}.{tid}",
                    "query": query,
                    "span_position": spans,
                    "start_position": starts,
                })

In [16]:
# Write the RE training split next to the NER data, using a relative path
# instead of a machine-specific absolute one.
pout = Path("../ClinicalTransformerMRC/example_datasets/2018n2c2/mrc_data/relation/")
pout.mkdir(parents=True, exist_ok=True)
to_json(training_data, pout, "train")

In [None]:
#create dev data
#RE task
# For every head entity (Drug) and every attribute type, ask one question whose
# answers are the tail entities linked to that head by a relation of that type.
dev_data = []

for fn in sorted(p_dev.glob("*.ann")):  # sorted: glob order is filesystem-dependent
    file_id = fn.stem
    _, sents = pre_processing(p_dev / f"{fn.stem}.txt", MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, _, _ = load_annotation_brat(fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, _ = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    # Flat word list; below, w[3][0] is used as the sentence index and w[3][1]
    # as the word index within that sentence.
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    # (sent_idx, text, entity type, (char_start, char_end), ann_id) per head entity
    head_list = []
    # tail ann_id -> [(sent_idx, text, context, type_id, start, end, "start;end")]
    tail_ann_idx_mappings = defaultdict(list)
    # (head ann_id, tail entity type) -> [tail ann_id, ...]
    relation_ann_idx_mappings = defaultdict(list)

    for text, entype, en_offset, ann_id in ens:
        type_id = entity_id[entype]
        s_idx, e_idx = mappings[ann_id]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]

        if entype in head_entity:
            # Only the head's first sentence, text, type, offsets and id are used
            # below, so a head spilling into the next sentence is still kept.
            head_list.append((sent_idx1, text, entype, en_offset, ann_id))
            continue

        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
        except Exception as ex:
            # Tail crosses a sentence boundary: log it and drop it instead of
            # storing context/positions left over from a previous entity.
            print(fn.name, ex)
            print(word_info1, word_info2)
            print(entype, type_id)
            continue

        context = " ".join(w[0] for w in sent_text)
        start = word_info1[3][1]
        end = word_info2[3][1]
        tail_ann_idx_mappings[ann_id].append(
            (sent_idx1, text, context, type_id, start, end, f"{start};{end}"))

    # Relation tuples look like ("Strength-Drug", "Arg1:T6", "Arg2:T5"): the part
    # before "-" is the tail type, Arg1 is the tail id and Arg2 is the head id.
    for relation, arg1, arg2 in relations:
        tail_ent = relation.split("-")[0]
        tail_id = arg1.split(":")[1]
        head_id = arg2.split(":")[1]
        relation_ann_idx_mappings[(head_id, tail_ent)].append(tail_id)

    # NOTE(review): several heads in the same sentence produce the same qas_id for
    # a given attribute type; confirm downstream code tolerates duplicate ids.
    for head_sent_id, head_text, head_entity_type, head_offset, head_ann_id in head_list:
        head_offset_s, head_offset_e = head_offset
        len_head_sent = len(nsents[head_sent_id])
        head_sent_context = " ".join(w[0] for w in nsents[head_sent_id])

        for k, v in attribute_query_template.items():
            tid = attribute_id[k]
            query = v.format(head_text)

            if (head_ann_id, k) not in relation_ann_idx_mappings:
                # The head has no relation of type k -> unanswerable question.
                dev_data.append({
                    "context": head_sent_context,
                    "end_position": [],
                    "head_entity": [head_entity_type, head_text, head_offset_s, head_offset_e],
                    "entity_label": k,
                    "impossible": True,
                    "qas_id": f"{file_id}.{head_sent_id}.{head_sent_id}.{tid}",
                    "query": query,
                    "span_position": [],
                    "start_position": [],
                })
                continue

            # (starts, ends, "start;end" spans) of tails inside the head's own sentence
            same_sent = ([], [], [])
            # tail sentence id -> (starts, ends, spans) inside the two-sentence context
            tail_sent_id_to_positions = defaultdict(lambda: ([], [], []))

            for tail_ann_id in relation_ann_idx_mappings[(head_ann_id, k)]:
                tail_infos = tail_ann_idx_mappings.get(tail_ann_id)
                if not tail_infos:
                    # Tail was dropped above (crosses sentences) or is not a tail-type entity.
                    continue
                tail_sent_id, _, _, _, s, e, se = tail_infos[0]

                if tail_sent_id == head_sent_id:
                    bucket = same_sent
                else:
                    if tail_sent_id > head_sent_id:
                        # Context is "head sentence + tail sentence": shift past the head words.
                        s += len_head_sent
                        e += len_head_sent
                    se = f"{s};{e}"
                    bucket = tail_sent_id_to_positions[tail_sent_id]
                bucket[0].append(s)
                bucket[1].append(e)
                bucket[2].append(se)

            # Answers located in the head's own sentence
            if same_sent[0]:
                starts, ends, spans = same_sent
                dev_data.append({
                    "context": head_sent_context,
                    "end_position": ends,
                    "head_entity": [head_entity_type, head_text, head_offset_s, head_offset_e],
                    "entity_label": k,
                    "impossible": False,
                    "qas_id": f"{file_id}.{head_sent_id}.{head_sent_id}.{tid}",
                    "query": query,
                    "span_position": spans,
                    "start_position": starts,
                })

            # Answers located in another sentence: context is the two sentences in document order
            for tail_sent_id, (starts, ends, spans) in tail_sent_id_to_positions.items():
                tail_sent_context = " ".join(w[0] for w in nsents[tail_sent_id])
                if tail_sent_id < head_sent_id:
                    pair_context = f"{tail_sent_context} {head_sent_context}"
                else:
                    pair_context = f"{head_sent_context} {tail_sent_context}"
                dev_data.append({
                    "context": pair_context,
                    "end_position": ends,
                    "head_entity": [head_entity_type, head_text, head_offset_s, head_offset_e],
                    "entity_label": k,
                    "impossible": False,
                    "qas_id": f"{file_id}.{head_sent_id}.{tail_sent_id}.{tid}",
                    "query": query,
                    "span_position": spans,
                    "start_position": starts,
                })

In [None]:
# Write the RE dev split next to the NER data, under the same relative
# example_datasets tree as the NER training split.
pout = Path("../ClinicalTransformerMRC/example_datasets/2018n2c2/mrc_data/relation/")
pout.mkdir(parents=True, exist_ok=True)
to_json(dev_data, pout, "dev")

In [None]:
# Create test data for the relation (RE) MRC task.
# One unlabelled query ("impossible": True, empty positions) is emitted per
# (head entity, attribute label) pair and per sentence pair; gold relations
# only decide which sentence pairs receive a query.
# NOTE(review): selecting cross-sentence contexts from gold *test* relations
# leaks annotation information into the test inputs -- confirm this is intended.
# NOTE(review): the imports cell brings in read_annotation_brat; confirm that
# load_annotation_brat is defined in an earlier cell.
test_data = []


def _unlabelled_relation_query(context, file_id, head_sent_id, tail_sent_id, tid,
                               head_entity, label, query):
    """Build one unlabelled MRC relation record (same keys/order as the dev records)."""
    return {
        "context": context,
        "end_position": [],
        "head_entity": list(head_entity),  # fresh list per record
        "entity_label": label,
        "impossible": True,
        "qas_id": f"{file_id}.{head_sent_id}.{tail_sent_id}.{tid}",
        "query": query,
        "span_position": [],
        "start_position": [],
    }


# sorted() makes the record order independent of filesystem listing order.
for fn in sorted(p_dev.glob("*.ann")):
    txt_fn = fn.with_suffix(".txt")
    ann_fn = fn
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)
    file_id = fn.stem

    head_list = []
    tail_ann_idx_mappings = defaultdict(list)
    relation_ann_idx_mappings = defaultdict(list)

    # Locate each entity's sentence, sentence context and word-level span.
    for en in ens:
        text, entype, en_offset, ann_id = en[0], en[1], en[2], en[-1]
        type_id = entity_id[entype]
        s_idx, e_idx = mappings[ann_id]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]

        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join(e[0] for e in sent_text)
            start = word_info1[3][1]
            end = word_info2[3][1]
        except Exception as ex:
            # Skip this entity instead of reusing context/start/end values left
            # over from a previous iteration.
            print(f"{fn.name}: skipping {ann_id} ({entype}, type_id={type_id}): {ex!r}")
            print(word_info1, word_info2)
            continue
        start_end = f"{start};{end}"

        if entype in head_entity:
            head_list.append(
                [sent_idx1, text, context, entype, en_offset, start, end, start_end, ann_id])
        else:
            tail_ann_idx_mappings[ann_id].append((sent_idx1, text, context, type_id, start, end, start_end))

    # (head ann id from r[-1], label prefix of r[0]) -> tail ann ids from r[1]
    for r in relations:
        head_id = r[-1].split(":")[1]
        tail_ent = r[0].split("-")[0]
        tail_id = r[1].split(":")[1]
        relation_ann_idx_mappings[(head_id, tail_ent)].append(tail_id)

    for head_ent in head_list:
        head_sent_id, head_text = head_ent[0], head_ent[1]
        head_entity_type = head_ent[3]
        head_offset_s, head_offset_e = head_ent[4]
        head_ann_id = head_ent[-1]
        head_sent_context = " ".join(e[0] for e in nsents[head_sent_id])
        head_info = [head_entity_type, head_text, head_offset_s, head_offset_e]

        for k, v in attribute_query_template.items():
            tid = attribute_id[k]
            query = v.format(head_text)

            # Find which sentences hold this pair's gold tails; cross-sentence
            # ids keep their first-seen order.
            has_same_sent_tail = False
            cross_sent_ids = []
            for tail_ann_id in relation_ann_idx_mappings.get((head_ann_id, k), []):
                tail_records = tail_ann_idx_mappings.get(tail_ann_id)
                if not tail_records:
                    continue  # tail entity was skipped or not recorded above
                tail_sent_id = tail_records[0][0]
                if tail_sent_id == head_sent_id:
                    has_same_sent_tail = True
                elif tail_sent_id not in cross_sent_ids:
                    cross_sent_ids.append(tail_sent_id)

            # Head-sentence query: when a tail shares the head's sentence, or
            # when this (head, label) pair has no usable tail at all.
            if has_same_sent_tail or not cross_sent_ids:
                test_data.append(_unlabelled_relation_query(
                    head_sent_context, file_id, head_sent_id, head_sent_id, tid, head_info, k, query))

            # One query per other sentence holding a tail; the two sentences
            # are concatenated in document order.
            for tail_sent_id in cross_sent_ids:
                tail_sent_context = " ".join(e[0] for e in nsents[tail_sent_id])
                if tail_sent_id < head_sent_id:
                    pair_context = tail_sent_context + " " + head_sent_context
                else:
                    pair_context = head_sent_context + " " + tail_sent_context
                test_data.append(_unlabelled_relation_query(
                    pair_context, file_id, head_sent_id, tail_sent_id, tid, head_info, k, query))

In [None]:
# Persist the relation-MRC test records under ../2018_n2c2/data/mrc_relation.
# NOTE(review): the RE->BRAT cell later reads ../2018_n2c2/mrc_relation/mrc-ner.test
# (no "data/" segment) -- confirm both paths refer to the same file.
pout = Path("..", "2018_n2c2", "data", "mrc_relation")
pout.mkdir(exist_ok=True, parents=True)
to_json(test_data, pout, "test")

In [23]:
def remap_index_to_wordindex(token_start, token_end, tokens, ccontext, nnsents, sent_id):
    """Map an inclusive wordpiece-token span in ``ccontext`` to document character offsets.

    Args:
        token_start: index of the first token of the span within ``tokens``.
        token_end: index of the last token of the span (inclusive) within ``tokens``.
        tokens: encoding of ``ccontext`` exposing ``token_to_chars(idx)``
            (e.g. the result of ``tokenizer.encode(ccontext, add_special_tokens=False)``).
        ccontext: the sentence text, words separated by single spaces.
        nnsents: flat word records; ``rec[3]`` is (sentence index, word index) and
            ``rec[1]`` is (start, end) character offsets in the source document.
        sent_id: sentence index that ``ccontext`` corresponds to.

    Returns:
        (start, end) character offsets of the words covering the span.

    Raises:
        ValueError: if a token has no character span, a token does not lie inside a
            single word, or the words are missing from ``nnsents``.
    """
    # [start, end) character span of each word; assumes single-space separators,
    # which holds for contexts built with " ".join(...).
    word_spans = []
    cursor = 0
    for word in ccontext.split():
        word_spans.append((cursor, cursor + len(word)))
        cursor += len(word) + 1

    start_char_span = tokens.token_to_chars(token_start)
    end_char_span = tokens.token_to_chars(token_end)
    if start_char_span is None or end_char_span is None:
        raise ValueError(
            f"token span {token_start}-{token_end} has no character offsets in: {ccontext!r}")

    s_word = e_word = None
    for idx, (w_start, w_end) in enumerate(word_spans):
        if start_char_span[0] >= w_start and start_char_span[1] <= w_end:
            s_word = idx
        if end_char_span[0] >= w_start and end_char_span[1] <= w_end:
            e_word = idx
    if s_word is None or e_word is None:
        raise ValueError(
            f"token span {token_start}-{token_end} does not align with words in: {ccontext!r}")

    s_offset = e_offset = None
    for word in nnsents:
        # tuple() so positions stored as lists also match the tuple key.
        pos = tuple(word[3])
        if pos == (sent_id, s_word):
            s_offset = word[1][0]
        if pos == (sent_id, e_word):
            e_offset = word[1][1]
    if s_offset is None or e_offset is None:
        raise ValueError(
            f"words ({sent_id}, {s_word})-({sent_id}, {e_word}) not found in nnsents")

    return s_offset, e_offset

In [None]:
#convert the output to brat format
#NER task
import os
from tokenizers import BertWordPieceTokenizer
from pathlib import Path
import json


# Define paths
data_root = Path("../2018_n2c2")
test_data_path = data_root / "test_data"
mrc_data_file = data_root / "mrc_entity/mrc-ner.test"
result_path = data_root / "exp/ner/pred"
entity_model_name = 'bert-large-cased'
entity_prediction_file = result_path / f"{entity_model_name}.json"
output_path = data_root / "exp/ner/results/"
bert_model_path = "../bert-large-cased"

vocab_file = os.path.join(bert_model_path, "vocab.txt")

# Initialize tokenizer.
# NOTE(review): BertWordPieceTokenizer lowercases by default although the model is
# cased; the token counts below must match the tokenizer that produced the
# predictions -- confirm before changing.
tokenizer = BertWordPieceTokenizer(vocab_file)

# Load entity predictions
with open(entity_prediction_file, "r") as f:
    entity_predictions = json.load(f)

# Load MRC format entity data (indexed in parallel with entity_predictions)
with open(mrc_data_file, "r") as f:
    mrc_entity_data = json.load(f)

# Define output format: "T<n>\t<type> <start> <end>\t<text>"
BRAT_TEMPLATE_T = "{}\t{} {} {}\t{}"
output_template = BRAT_TEMPLATE_T

output_file_suffix = 'ann'
output_path.mkdir(parents=True, exist_ok=True)

# Group predictions by document id once (instead of rescanning per file),
# keeping each prediction's index into mrc_entity_data.
entity_predictions_by_file = defaultdict(list)
for pred_idx, pred in enumerate(entity_predictions):
    entity_predictions_by_file[int(str(pred['sample_idx']))].append((pred_idx, pred))

for fn in test_data_path.glob("*.txt"):
    txt_fn = test_data_path / fn.name
    txt_data = load_text(txt_fn)  # closes the file handle
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    output_fn = output_path / f"{fn.stem}.{output_file_suffix}"

    nsents, sent_bound = generate_BIO(sents, [], file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]

    # Map token-level predictions back to original text and format in BRAT
    brat_entities = []
    for prediction_idex, entity_prediction in entity_predictions_by_file.get(int(fn.stem), []):
        sent_id = int(str(entity_prediction['head_sent_idx'][0]))

        context = mrc_entity_data[prediction_idex]['context']
        query = mrc_entity_data[prediction_idex]['query']
        query_tokens = tokenizer.encode(query, add_special_tokens=False)
        context_tokens = tokenizer.encode(context, add_special_tokens=False)

        for entity in entity_prediction['en']:
            # Shift model-input token indices to context-relative ones; presumably
            # the input is [CLS] query [SEP] context ... and entity[1] is an
            # exclusive end (hence the extra -1) -- TODO confirm.
            token_start = entity[0] - len(query_tokens) - 2
            token_end = entity[1] - len(query_tokens) - 3
            entity_start, entity_end = remap_index_to_wordindex(
                token_start, token_end, context_tokens, context, nnsents, sent_id)
            ent_type = entity[3]
            entity_word = txt_data[entity_start:entity_end].replace("\n", " ")
            brat_entities.append((ent_type, entity_start, entity_end, entity_word))

    # Write BRAT-formatted entities; every text-bound line needs a "T<n>" id,
    # which the template's first placeholder expects.
    with open(output_fn, "w") as f:
        for t_idx, entity in enumerate(brat_entities, start=1):
            f.write(output_template.format(f"T{t_idx}", *entity))
            f.write("\n")

In [None]:
#convert the output to brat format
#RE task
import os
from tokenizers import BertWordPieceTokenizer
from pathlib import Path
import json


data_root = Path("../2018_n2c2")
test_data_path = data_root / "test_data"
mrc_data_file = data_root / "mrc_relation/mrc-ner.test"
result_path = data_root / "exp/re/pred"
relation_model_name = 'bert-large-cased'
# Built from this cell's own model name (previously entity_model_name from the NER cell).
relation_prediction_file = result_path / f"{relation_model_name}.json"
# Separate directory so RE output does not overwrite the NER .ann files.
output_path = data_root / "exp/re/results/"
bert_model_path = "../bert-large-cased"
vocab_file = os.path.join(bert_model_path, "vocab.txt")

# Initialize tokenizer (default lowercasing kept; must match the tokenizer used
# to produce the predictions).
tokenizer = BertWordPieceTokenizer(vocab_file)

# Load relation predictions
with open(relation_prediction_file, "r") as f:
    relation_predictions = json.load(f)

# Load MRC format relation data (indexed in parallel with relation_predictions)
with open(mrc_data_file, "r") as f:
    mrc_relation_data = json.load(f)

# Define output format
BRAT_TEMPLATE_T = "{}\t{} {} {}\t{}"
output_template_t = BRAT_TEMPLATE_T
BRAT_TEMPLATE_R = "{}\t{}-{} Arg1:{} Arg2:{}"
output_template_r = BRAT_TEMPLATE_R
HEAD_TYPE = 'Drug'  # type written for every head entity and relation label suffix

file_suffix = 'ann'
output_path.mkdir(parents=True, exist_ok=True)

# Group predictions by document id once, keeping each prediction's index into
# mrc_relation_data.
relation_predictions_by_file = defaultdict(list)
for pred_idx, pred in enumerate(relation_predictions):
    relation_predictions_by_file[int(str(pred['sample_idx']))].append((pred_idx, pred))

for fn in test_data_path.glob("*.txt"):
    txt_fn = test_data_path / fn.name
    txt_data = load_text(txt_fn)  # closes the file handle

    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    output_fn = output_path / f"{fn.stem}.{file_suffix}"

    nsents, sent_bound = generate_BIO(sents, [], file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]

    # (head_text, head_offset_s, head_offset_e) -> [(type, start, end, text), ...]
    event_prediction = defaultdict(list)

    for prediction_idex, relation_prediction in relation_predictions_by_file.get(int(fn.stem), []):
        head_sent_id = int(relation_prediction['head_sent_idx'][0])
        tail_sent_id = int(relation_prediction['tail_sent_idx'][0])

        head_entity = mrc_relation_data[prediction_idex]['head_entity']
        head_text, head_offset_s, head_offset_e = head_entity[1], head_entity[2], head_entity[3]

        context = mrc_relation_data[prediction_idex]['context']
        query = mrc_relation_data[prediction_idex]['query']
        query_tokens = tokenizer.encode(query, add_special_tokens=False)
        context_tokens = tokenizer.encode(context, add_special_tokens=False)

        for tail_entity in relation_prediction['en']:
            ent_type = tail_entity[3]
            # Context-relative token indices (same offsets as the NER cell).
            token_start = tail_entity[0] - len(query_tokens) - 2
            token_end = tail_entity[1] - len(query_tokens) - 3

            if head_sent_id == tail_sent_id:
                entity_start, entity_end = remap_index_to_wordindex(
                    token_start, token_end, context_tokens, context, nnsents, head_sent_id)
            else:
                tail_sent = " ".join(e[0] for e in nsents[tail_sent_id])
                tail_sent_tokens = tokenizer.encode(tail_sent, add_special_tokens=False)
                if head_sent_id < tail_sent_id:
                    # Context is "head_sent tail_sent": also skip the head sentence's tokens.
                    head_sent = " ".join(e[0] for e in nsents[head_sent_id])
                    n_head_tokens = len(tokenizer.encode(head_sent, add_special_tokens=False))
                    token_start -= n_head_tokens
                    token_end -= n_head_tokens
                # (head after tail: context is "tail_sent head_sent", tail tokens come first)
                entity_start, entity_end = remap_index_to_wordindex(
                    token_start, token_end, tail_sent_tokens, tail_sent, nnsents, tail_sent_id)

            entity_word = txt_data[entity_start:entity_end].replace("\n", " ")
            event_prediction[(head_text, head_offset_s, head_offset_e)].append(
                (ent_type, entity_start, entity_end, entity_word))

    # Emit one T-line per head, then a T-line plus an R-line per predicted attribute.
    output_tr = []
    t_id = 1
    r_id = 1
    for (text_h, head_s, head_e), attributes in event_prediction.items():
        head_t_id = t_id
        output_tr.append(output_template_t.format(f"T{head_t_id}", HEAD_TYPE, head_s, head_e, text_h))
        t_id += 1

        for att_type, tail_s, tail_e, text_t in attributes:
            output_tr.append(output_template_t.format(f"T{t_id}", att_type, tail_s, tail_e, text_t))
            # Arg1 is the attribute (tail), Arg2 the head entity.
            output_tr.append(output_template_r.format(
                f"R{r_id}", att_type, HEAD_TYPE, f"T{t_id}", f"T{head_t_id}"))
            t_id += 1
            r_id += 1

    with open(output_fn, "w") as f:
        f.write("\n".join(output_tr))
        f.write("\n")