In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [2]:
from pathlib import Path
import shutil

In [3]:
p_tr = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/2018n2c2_track2_training/")
p_dev = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/gold_standard_test/")

# p_tr = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/2018n2c2_track2_training/")

fids = []

fids.extend([f for f in p_tr.glob("*.ann")])
fids.extend([f for f in p_dev.glob("*.ann")])

len(fids)

505

In [4]:
p = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mimic")
p.mkdir(parents=True, exist_ok=True)

xx = [shutil.copyfile(fn, p/f"{fn.name}") for fn in fids]
xx = [shutil.copyfile(fn.parent/f"{fn.stem}.txt", p/f"{fn.stem}.txt") for fn in fids]

In [5]:
from pathlib import Path
import pickle as pkl
from collections import defaultdict, Counter
from itertools import permutations, combinations
from functools import reduce
import numpy as np
import os
import re
import sys
import json

# https://github.com/uf-hobi-informatics-lab/NLPreprocessing (git clone this repo to local)
sys.path.append("./NLPreprocessing/")
sys.path.append("./NLPreprocessing/text_process")
from annotation2BIO import pre_processing, read_annotation_brat, generate_BIO
MIMICIII_PATTERN = ""
from sentence_tokenization import logger as l1
from annotation2BIO import logger as l2
l1.disabled = True
l2.disabled = True

def pkl_save(data, file):
    with open(file, "wb") as f:
        pkl.dump(data, f)

        
def pkl_load(file):
    with open(file, "rb") as f:
        data = pkl.load(f)
    return data


def load_text(ifn):
    with open(ifn, "r") as f:
        txt = f.read()
    return txt


def save_text(text, ofn):
    with open(ofn, "w") as f:
        f.write(text)

In [6]:
def create_entity_to_sent_mapping(nnsents, entities, idx2e, fn):
    loc_ens = []
    
    ll = len(nnsents)
    mapping = defaultdict(list)
    for idx, each in enumerate(entities):
        en_label = idx2e[idx]
        en_s = each[2][0]
        en_e = each[2][1]
        new_en = []
        
        i = 0
        while i < ll and nnsents[i][1][0] < en_s:
            i += 1
        s_s = nnsents[i][1][0]
        s_e = nnsents[i][1][1]

        if en_s == s_s:
            mapping[en_label].append(i)

            while i < ll and s_e < en_e:
                i += 1
                s_e = nnsents[i][1][1]
            if s_e == en_e:
                 mapping[en_label].append(i)
            else:
                mapping[en_label].append(i)
                print(fn)
                print("last index not match ", each)
        else:
            mapping[en_label].append(i)
            print(fn)
            print("first index not match ", each)

            while i < ll and s_e < en_e:
                i += 1
                s_e = nnsents[i][1][1]
            if s_e == en_e:
                 mapping[en_label].append(i)
            else:
                mapping[en_label].append(i)
                print(fn)
                print("last index not match ", each)
    return mapping

    
def __ann_info(ann):
    en_info = ann.split(" ")
    return en_info[0], int(en_info[1]), int(en_info[-1])


def load_annotation_brat(ann_file, rep=False):
    """
    load annotation data
    entity_id2index_map -> {'T1': 0}
    entites -> ('T1', 'anticoagulant medications', 'Drug', (1000, 1025))
    relations -> ('Route-Drug', 'T3', 'T2')
    
    T31	NoDisposition 8778 8786	depakote
    E31	NoDisposition:T31 
    T32	Disposition 8684 8697	anti-htn meds
    E32	Disposition:T32 
    A1	Certainty E32 Certain
    A2	Actor E32 Physician
    A3	Action E32 Stop
    """
    # map the entity id (e.g., T1) to its index in entities list
    entity_id2index_map = dict()
    entites = []
    relations = []
    events = []
    attrs = []
    
    with open(ann_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            anns = line.split("\t")
            ann_id = anns[0]
            if ann_id.startswith("T"):
                t_type = anns[-1]
                entity_words, offset_s, offset_e = __ann_info(anns[1])
                entites.append((t_type,  entity_words, (offset_s, offset_e), ann_id))
                entity_id2index_map[ann_id] = len(entites) - 1
            elif ann_id.startswith("A"):
                att_type, env_id, att_val = anns[1].strip().split(" ")
                attrs.append((att_type, env_id, att_val))
            elif ann_id.startswith("E"):
                # E2	Alcohol:T3 Status:T4 Amount:T5 Frequency:T6 Type:T10
                single_event = dict()
                envs = anns[1].split(" ")
                trigger = envs[0].strip().split(":")
                tas = []
                for each in envs[1:]:
                    tas.append(each.strip().split(":"))
                
                single_event["trigger"] = trigger
                single_event["events"] = tas
                    
                events.append(single_event)
            elif ann_id.startswith("R"):
                relation, tail, head = anns[1].split()[0], anns[1].split()[1], anns[1].split()[-1]
                
        
                relations.append((relation, tail, head))

    return entity_id2index_map, entites, relations, events, attrs

In [28]:
entities = {'Dosage', 'Route', 'Form', 'Reason', 'ADE', 'Duration', 'Strength', 'Drug', 'Frequency'}
tails = {'Dosage', 'Route', 'Form', 'Reason', 'ADE', 'Duration', 'Strength', 'Frequency'}
p_test = Path("/data/datasets/zehao/2022_n2c2/att/res_067_testA_e2e_2")

all_entity_types = set()
nested_entities_count = 0
for i, fn in enumerate(p_test.glob("*.ann")):
    ann_fn = p_test / f"{fn.name}"
    e2i, ens, _, evns, attrs = load_annotation_brat(ann_fn)
    all_entity_types.update(set([e[1] for e in ens]))
    entities = {}
    for each in ens:
        entity_id, entity_type, span, text = each[-1], each[1], each[2], each[0]
        if span in entities:
            entities[span].append((entity_id, entity_type, text))
        else:
            entities[span] = [(entity_id, entity_type, text)]
    
    nested_entities = [entity_list for entity_list in entities.values() if len(entity_list) > 1]
    nested_entities_count += len(nested_entities)
    
print(nested_entities_count)

524


In [8]:
"""
sample data:
a list of dict as
  {
    "context": "Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer .",
    "end_position": [
      0,
      23
    ],
    "entity_label": "LOC",
    "impossible": false,
    "qas_id": "4.3",
    "query": "location entities are the name of politically or geographically defined locations such as cities, provinces, countries, international regions, bodies of water, mountains, etc.",
    "span_position": [
      "0;0",
      "23;23"
    ],
    "start_position": [
      0,
      23
    ]
  },
  {
    "context": "EU rejects German call to boycott British lamb .",
    "end_position": [],
    "entity_label": "PER",
    "impossible": true,
    "qas_id": "0.2",
    "query": "person entities are named persons or family.",
    "span_position": [],
    "start_position": []
  }, ...

"""


# create trigger template
# make sure the label id in qas_id is consistent for each type of entity
entity_template = {
    "Drug": "Drug events include names, brand names and collective names of prescription substances and over-the-counter medications",
    "Strength": "Strength is the amount of drug in a given dosage, for example, 50 mg",
    "Form": "Form is the physical form of given drug or medication, for example, pill, tablet, capsule, powder or injection", 
    "Dosage": "Dosage is the amount of a medication used in each administration", 
    "Frequency":"Frequency indicates how often each dose of the medication should be taken", 
    "Route": "Route is the path by which a drug is taken into the body", 
    "Duration": "Duration indicates how long the medication is to be administered", 
    "Reason": " Reason indicates the medical reason for which the medication is given",
    "ADE": "ADEs are injuries resulting from a medical intervention related to drugs"
}

attribute_template_new = {
    "Strength": "What is the active ingredient amount of {}",
    "Form": "What is the physical form of {}", 
    "Dosage": "What is the amount of {} taken", 
    "Frequency":"How often each dose of {} should be taken", 
    "Route": "What is the path of {} taken into the body", 
    "Duration": "How long to take {}", 
    "Reason": "What is the medical reason for giving {}",
    "ADE": "What are the injuries resulting from the use of {}"
}

attribute_template = {
    "Strength": "What is the amount of {} in a given dosage",
    "Form": "What is the physical form of {}", 
    "Dosage": "What is the amount of {} used in each administration", 
    "Frequency":"How often each dose of {} should be taken", 
    "Route": "What is the path of {} taken into the body", 
    "Duration": "How long is {} to be administered", 
    "Reason": "What is the medical reason for giving {}",
    "ADE": "What are the injuries resulting from a medical intervention related to {}"
}

entity_id = {
    "Drug": 1, "Strength": 2, "Form": 3, "Dosage": 4, "Frequency": 5, "Route": 6, "Duration": 7, "Reason": 8, "ADE": 9
}

tail_id = {
    "Strength": 1, "Form": 2, "Dosage": 3, "Frequency": 4, "Route": 5, "Duration": 6, "Reason": 7, "ADE": 8
}

In [9]:
def get_sent(sents, idx1, idx2):
    if idx1 == idx2:
        return sents[idx1]
    elif idx1 == idx2 + 1:
        raise Exception(f"{idx1} {idx2} - entity not in the same sentence")
    else:
        raise Exception(f"{idx1} {idx2} - the entity has word spread in >2 sentences")
        
        
def to_json(data, p, fn="train"):
    import json
    
    ofn = p / f"mrc-ner.{fn}"
    
    with open(ofn, "w") as f:
        json.dump(data, f, indent=2)

def to_json_0(data, p, fn="train"):
    import json
    
    ofn = p / f"mrc-ner_0.{fn}"
    
    with open(ofn, "w") as f:
        json.dump(data, f, indent=2)


def to_json_1(data, p, fn="train"):
    import json
    
    ofn = p / f"mrc-ner_1.{fn}"
    
    with open(ofn, "w") as f:
        json.dump(data, f, indent=2)

def to_json_2(data, p, fn="train"):
    import json
    
    ofn = p / f"mrc-ner_2.{fn}"
    
    with open(ofn, "w") as f:
        json.dump(data, f, indent=2)


In [10]:
##single sentence as a sample

training_data = []


for i, fn in enumerate(p_tr.glob("*.ann")):

    txt_fn = p_tr / f"{fn.stem}.txt"
    ann_fn = p_tr / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_entities = set()

    entity_sent_idx_mappings = defaultdict(list)

    # sentence with entities
    for en in ens:

        entype = en[1]
        type_id = entity_id[entype]
        
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_entities.add(sent_idx1)
    #         sent_with_entities.add(sent_idx2)
        
        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join([e[0] for e in sent_text])
            start = word_info1[3][1]
            end = word_info2[3][1]
            start_end = f"{start};{end}"
        except Exception as ex:
            print(fn.name)
            print(word_info1, word_info2)
            print(context)
            print(start, end, start_end)
            print(entype, type_id)
        
        # key will be sent_id, type
        # data will be tuple (context, entype, type_id, start, end, start_end)
        entity_sent_idx_mappings[(sent_idx1, entype)].append(
            (context, entype, type_id, start, end, start_end))


    # for i in sent_ids:
    for i in sent_with_entities:
        file_id = f"{fn.stem}"
        sent_id = f"{i}"
        sent_i_context = " ".join(e[0] for e in nsents[i])
        for k, v in entity_template.items():
            tid = entity_id[k]
            
            if (i, k) in entity_sent_idx_mappings:
                d = {
                        "context": sent_i_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": False,
                        "qas_id": f"{file_id}.{sent_id}.{sent_id}.{tid}",
                        "query": v,
                        "span_position": [],
                        "start_position": []
                        }
                
                entities = entity_sent_idx_mappings[(i, k)]
        
                for ent in entities:
                    assert ent[0] == sent_i_context, f"expect context: {ent[0]} but get {sent_i_context}"
                    assert ent[1] == k, f"expect en type: {ent[1]} but get {k}"
                    s, e, se = ent[-3:]
                    d["start_position"].append(s)
                    d["end_position"].append(e)
                    d["span_position"].append(se)
                
                training_data.append(d)
            else:
                d = {
                        "context": sent_i_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": True,
                        "qas_id": f"{file_id}.{sent_id}.{sent_id}.{tid}",
                        "query": v,
                        "span_position": [],
                        "start_position": []
                    }
                training_data.append(d)

161477.ann
last index not match  ('O2', 'Drug', (1701, 1703), 'T30')
112832.ann
last index not match  ('1mg', 'Strength', (9936, 9939), 'T150')
112832.ann
first index not match  ('MWF', 'Frequency', (9939, 9942), 'T151')
112832.ann
last index not match  ('MWF', 'Frequency', (9939, 9942), 'T151')
106361.ann
last index not match  ('2L', 'Dosage', (11456, 11458), 'T9')
106361.ann
first index not match  ('NC', 'Route', (11458, 11460), 'T10')
106361.ann
last index not match  ('NC', 'Route', (11458, 11460), 'T10')
106361.ann
last index not match  ('2L', 'Dosage', (11665, 11667), 'T12')
106361.ann
first index not match  ('NC', 'Route', (11667, 11669), 'T13')
106361.ann
last index not match  ('NC', 'Route', (11667, 11669), 'T13')
114144.ann
first index not match  ('O2', 'Drug', (3232, 3234), 'T161')
114144.ann
last index not match  ('O2', 'Drug', (3232, 3234), 'T161')
103293.ann
last index not match  ('1', 'Dosage', (5573, 5574), 'T71')
103293.ann
first index not match  ('2', 'Dosage', (9896, 

In [11]:
pout = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mrc_entity_new")
# pout = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mrc_trigger")
pout.mkdir(parents=True, exist_ok=True)

In [12]:
to_json(training_data, pout, "train")

In [13]:
##single sentence as a sample

dev_data = []


for i, fn in enumerate(p_dev.glob("*.ann")):

    txt_fn = p_dev / f"{fn.stem}.txt"
    ann_fn = p_dev / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, _, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_entities = set()

    entity_sent_idx_mappings = defaultdict(list)

    # sentence with entities
    for en in ens:

        entype = en[1]
        type_id = entity_id[entype]
        
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_entities.add(sent_idx1)
    #         sent_with_entities.add(sent_idx2)
        
        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join([e[0] for e in sent_text])
            start = word_info1[3][1]
            end = word_info2[3][1]
            start_end = f"{start};{end}"
        except Exception as ex:
            print(fn.name)
            print(word_info1, word_info2)
            print(context)
            print(start, end, start_end)
            print(entype, type_id)
        
        # key will be sent_id, type
        # data will be tuple (context, entype, type_id, start, end, start_end)
        entity_sent_idx_mappings[(sent_idx1, entype)].append(
            (context, entype, type_id, start, end, start_end))

    # for i in sent_ids:
    for i in sent_with_entities:
        file_id = f"{fn.stem}"
        sent_id = f"{i}"
        sent_i_context = " ".join(e[0] for e in nsents[i])
        for k, v in entity_template.items():
            tid = entity_id[k]
            if (i, k) in entity_sent_idx_mappings:
                d = {
                        "context": sent_i_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": False,
                        "qas_id": f"{file_id}.{sent_id}.{sent_id}.{tid}",
                        "query": v,
                        "span_position": [],
                        "start_position": []
                        }
                
                entities = entity_sent_idx_mappings[(i, k)]
        
                for ent in entities:
                    assert ent[0] == sent_i_context, f"expect context: {ent[0]} but get {sent_i_context}"
                    assert ent[1] == k, f"expect en type: {ent[1]} but get {k}"
                    s, e, se = ent[-3:]
                    d["start_position"].append(s)
                    d["end_position"].append(e)
                    d["span_position"].append(se)
                
                dev_data.append(d)
            else:
                d = {
                        "context": sent_i_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": True,
                        "qas_id": f"{file_id}.{sent_id}.{sent_id}.{tid}",
                        "query": v,
                        "span_position": [],
                        "start_position": []
                    }
                dev_data.append(d)

125281.ann
first index not match  ('Levaquin', 'Drug', (5311, 5319), 'T41')
125281.ann
last index not match  ('Levaquin', 'Drug', (5311, 5319), 'T41')
185982.ann
last index not match  ('O2', 'Drug', (1049, 1051), 'T78')
107902.ann
last index not match  ('p.o', 'Route', (5493, 5496), 'T36')
107902.ann
last index not match  ('p.o', 'Route', (5555, 5558), 'T45')
107902.ann
last index not match  ('p.o', 'Route', (5625, 5628), 'T52')
107902.ann
last index not match  ('p.o', 'Route', (6209, 6212), 'T85')
189637.ann
last index not match  ('Warfarin', 'Drug', (10053, 10061), 'T2')
189637.ann
first index not match  ('5 mg', 'Strength', (10061, 10065), 'T3')
114044.ann
first index not match  ('1', 'Dosage', (1571, 1572), 'T53')
114044.ann
last index not match  ('1', 'Dosage', (1571, 1572), 'T53')
114044.ann
first index not match  ('3', 'Dosage', (1588, 1589), 'T82')
114044.ann
last index not match  ('3', 'Dosage', (1588, 1589), 'T82')
106588.ann
first index not match  ('1', 'Dosage', (1548, 1549

In [None]:
to_json(dev_data, pout, "dev")

In [None]:
##single sentence as a sample
##test dataset
test_data = []


for i, fn in enumerate(p_dev.glob("*.ann")):

    txt_fn = p_dev / f"{fn.stem}.txt"
    ann_fn = p_dev / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, _, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_entities = set()

    entity_sent_idx_mappings = defaultdict(list)

    # sentence with entities
    for en in ens:

        entype = en[1]
        type_id = entity_id[entype]
        
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_entities.add(sent_idx1)
    #         sent_with_entities.add(sent_idx2)

    # for i in sent_ids:
    for i in sent_with_entities:
        file_sent_id = f"{fn.stem}_{i}"
        sent_i_context = " ".join(e[0] for e in nsents[i])
        for k, v in entity_template.items():
            tid = entity_id[k]
            d = {
                    "context": sent_i_context,
                    "end_position": [],
                    "entity_label": k,
                    "impossible": True,
                    "qas_id": f"{file_sent_id}.{tid}",
                    "query": v,
                    "span_position": [],
                    "start_position": []
                    }
            test_data.append(d)
          

In [None]:
to_json(test_data, pout, "test")

In [10]:
def find_sent_id(offset, sent_bound):
    for i in sent_bound:
        if offset[0]>= sent_bound[0] and offset[1] <= sent_bound[1]:
            sent_id = i
    return i


In [12]:
##single sentence as a sample
# relation extraction

training_data = []


for i, fn in enumerate(p_tr.glob("*.ann")):

    txt_fn = p_tr / f"{fn.stem}.txt"
    ann_fn = p_tr / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_drugs = set()

    drug_lists = []
    attributes_lists = []
    tail_ann_idx_mappings = defaultdict(list)
    relation_idx_mappings = defaultdict(list)


    # sentence with entities
    for en in ens:
        entype = en[1]
        type_id = entity_id[entype]
        text = en[0]
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_drugs.add(sent_idx1)
        
        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join([e[0] for e in sent_text])
            start = word_info1[3][1]
            end = word_info2[3][1]
            start_end = f"{start};{end}"
        except Exception as ex:
            print(fn.name)
            print(word_info1, word_info2)
            print(context)
            print(start, end, start_end)
            print(entype, type_id)
            
            # key will be sent_id, type
            # data will be tuple (context, entype, type_id, start, end, start_end)
        if entype == "Drug":
            drug_lists.append(
                [sent_idx1, text, context, start, end, start_end, en[-1]])
        else:
            tail_ann_idx_mappings[(en[-1])].append((sent_idx1, text, context, type_id, start, end, start_end))

    for pair in relations:
        relation_idx_mappings[(pair[-1].split(":")[1],pair[0].split('-')[0])].append((pair[1].split(":")[1]))
       
        


    # for i in sent_ids:
    # for i in sent_with_drugs:
    #     file_sent_id = f"{fn.stem}_{i}"
    #     sent_i_context = " ".join(e[0] for e in nsents[i])
    #     drug_entities = entity_sent_idx_mappings[(i, "Drug")]

    for drug_ent in drug_lists:
        sent_id = drug_ent[0]
        drug_text = drug_ent[1]
        len_sent = len(nsents[sent_id])
        file_sent_id = f"{fn.stem}.{sent_id}"
        drug_sent_context = " ".join(e[0] for e in nsents[drug_ent[0]])
        ann_id = drug_ent[-1]
        for k, v in attribute_template.items():
            tid = tail_id[k]
            
            if (ann_id, k) in relation_idx_mappings:
                
                attributes = relation_idx_mappings[(ann_id, k)]

                sstart=[]
                eend=[]
                sspan=[]
        
                for ent in attributes:
                    tail_ann_id = ent
                    tail_sent_id = tail_ann_idx_mappings[(tail_ann_id)][0][0]
                    tail_context = tail_ann_idx_mappings[(tail_ann_id)][0][2]
                    if sent_id == tail_sent_id:
                        s, e, se = tail_ann_idx_mappings[(tail_ann_id)][0][-3:]
                        sstart.append(s)
                        eend.append(e)
                        sspan.append(se)
                        d = {
                                "context": drug_sent_context,
                                "end_position": eend,
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": sspan,
                                "start_position": sstart
                            }

                    if sent_id < tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] + len_sent
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2] + len_sent
                        se= f"{s};{e}"
                        dd = {
                                "context": drug_sent_context + ' ' + tail_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        training_data.append(dd)
                    if sent_id > tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] 
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2]
                        se= f"{s};{e}"
                        dd = {
                                "context": tail_context + ' ' + drug_sent_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        training_data.append(dd)
                training_data.append(d)
                 
            else:
                d = {
                        "context": drug_sent_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": True,
                        "qas_id": f"{file_sent_id}.{sent_id}.{tid}",
                        "query": v.format(drug_text),
                        "span_position": [],
                        "start_position": []
                    }
                training_data.append(d)

KeyboardInterrupt: 

In [15]:
pout = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mrc_relation")
# pout = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mrc_trigger")
pout.mkdir(parents=True, exist_ok=True)

In [None]:
to_json(training_data, pout, "train")

In [None]:
##single sentence as a dev sample
# relation extraction

dev_data = []


for i, fn in enumerate(p_dev.glob("*.ann")):

    txt_fn = p_dev / f"{fn.stem}.txt"
    ann_fn = p_dev / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_drugs = set()

    drug_lists = []
    attributes_lists = []
    tail_ann_idx_mappings = defaultdict(list)
    relation_idx_mappings = defaultdict(list)


    # sentence with entities
    for en in ens:
        entype = en[1]
        type_id = entity_id[entype]
        text = en[0]
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_drugs.add(sent_idx1)
        
        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join([e[0] for e in sent_text])
            start = word_info1[3][1]
            end = word_info2[3][1]
            start_end = f"{start};{end}"
        except Exception as ex:
            print(fn.name)
            print(word_info1, word_info2)
            print(context)
            print(start, end, start_end)
            print(entype, type_id)
            
            # key will be sent_id, type
            # data will be tuple (context, entype, type_id, start, end, start_end)
        if entype == "Drug":
            drug_lists.append(
                [sent_idx1, text, context, start, end, start_end, en[-1]])
        else:
            tail_ann_idx_mappings[(en[-1])].append((sent_idx1, text, context, type_id, start, end, start_end))

    for pair in relations:
        relation_idx_mappings[(pair[-1].split(":")[1],pair[0].split('-')[0])].append((pair[1].split(":")[1]))
       
        


    # for i in sent_ids:
    # for i in sent_with_drugs:
    #     file_sent_id = f"{fn.stem}_{i}"
    #     sent_i_context = " ".join(e[0] for e in nsents[i])
    #     drug_entities = entity_sent_idx_mappings[(i, "Drug")]

    for drug_ent in drug_lists:
        sent_id = drug_ent[0]
        drug_text = drug_ent[1]
        len_sent = len(nsents[sent_id])
        file_sent_id = f"{fn.stem}.{sent_id}"
        drug_sent_context = " ".join(e[0] for e in nsents[drug_ent[0]])
        ann_id = drug_ent[-1]
        for k, v in attribute_template.items():
            tid = tail_id[k]
            
            if (ann_id, k) in relation_idx_mappings:
                
                attributes = relation_idx_mappings[(ann_id, k)]

                sstart=[]
                eend=[]
                sspan=[]
        
                for ent in attributes:
                    tail_ann_id = ent
                    tail_sent_id = tail_ann_idx_mappings[(tail_ann_id)][0][0]
                    tail_context = tail_ann_idx_mappings[(tail_ann_id)][0][2]
                    if sent_id == tail_sent_id:
                        s, e, se = tail_ann_idx_mappings[(tail_ann_id)][0][-3:]
                        sstart.append(s)
                        eend.append(e)
                        sspan.append(se)
                        d = {
                                "context": drug_sent_context,
                                "end_position": eend,
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": sspan,
                                "start_position": sstart
                            }

                    if sent_id < tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] + len_sent
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2] + len_sent
                        se= f"{s};{e}"
                        dd = {
                                "context": drug_sent_context + ' ' + tail_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        dev_data.append(dd)
                    if sent_id > tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] 
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2]
                        se= f"{s};{e}"
                        dd = {
                                "context": tail_context + ' ' + drug_sent_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        dev_data.append(dd)
                dev_data.append(d)
                 
            else:
                d = {
                        "context": drug_sent_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": True,
                        "qas_id": f"{file_sent_id}.{sent_id}.{tid}",
                        "query": v.format(drug_text),
                        "span_position": [],
                        "start_position": []
                    }
                dev_data.append(d)

In [None]:
to_json(dev_data, pout, "dev")

In [13]:
##single sentence as a test sample
# relation extraction

test_data = []

p_pred = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/exp/ner/results/test_1")
# for i, fn in enumerate(p_dev.glob("*.ann")):
for i, fn in enumerate(p_dev.glob("*.txt")):

    txt_fn = p_dev / f"{fn.name}"
    # ann_fn = p_dev / f"{fn.stem}.ann"
    ann_fn = p_pred / f"{fn.stem}.ann"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_drugs = set()

    drug_lists = []
    attributes_lists = []
    tail_ann_idx_mappings = defaultdict(list)
    relation_idx_mappings = defaultdict(list)


    # sentence with entities
    for en in ens:
        entype = en[1]
        type_id = entity_id[entype]
        text = en[0]
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_drugs.add(sent_idx1)
        
        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join([e[0] for e in sent_text])
            start = word_info1[3][1]
            end = word_info2[3][1]
            start_end = f"{start};{end}"
        except Exception as ex:
            print(fn.name)
            print(word_info1, word_info2)
            print(context)
            print(start, end, start_end)
            print(entype, type_id)
            
            # key will be sent_id, type
            # data will be tuple (context, entype, type_id, start, end, start_end)
        if entype == "Drug":
            # drug_lists.append([sent_idx1, text, context, start, end, start_end, en[-1]])
            drug_lists.append([sent_idx1, text, context, start, end, start_end, en[-2], en[-1]])

        else:
            tail_ann_idx_mappings[(en[-1])].append((sent_idx1, text, context, type_id, start, end, start_end))

    for pair in relations:
        relation_idx_mappings[(pair[-1].split(":")[1],pair[0].split('-')[0])].append((pair[1].split(":")[1]))
       
        


    # for i in sent_ids:
    # for i in sent_with_drugs:
    #     file_sent_id = f"{fn.stem}_{i}"
    #     sent_i_context = " ".join(e[0] for e in nsents[i])
    #     drug_entities = entity_sent_idx_mappings[(i, "Drug")]

    for drug_ent in drug_lists:
        sent_id = drug_ent[0]
        drug_text = drug_ent[1]
        drug_ann = drug_ent[-1]
        len_sent = len(nsents[sent_id])
        file_sent_id = f"{fn.stem}.{sent_id}"
        drug_sent_context = " ".join(e[0] for e in nsents[drug_ent[0]])
        ann_id = drug_ent[-1]
        for k, v in attribute_template.items():
            tid = tail_id[k]
            
            if (ann_id, k) in relation_idx_mappings:
                
                attributes = relation_idx_mappings[(ann_id, k)]

                sstart=[]
                eend=[]
                sspan=[]
        
                for ent in attributes:
                    tail_ann_id = ent
                    tail_sent_id = tail_ann_idx_mappings[(tail_ann_id)][0][0]
                    tail_context = tail_ann_idx_mappings[(tail_ann_id)][0][2]
                    if sent_id == tail_sent_id:
                        s, e, se = tail_ann_idx_mappings[(tail_ann_id)][0][-3:]
                        sstart.append(s)
                        eend.append(e)
                        sspan.append(se)
                        d = {
                                "context": drug_sent_context,
                                "end_position": eend,
                                "entity_label": [k, drug_text, drug_ent[-2][0],drug_ent[-2][1]],
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": sspan,
                                "start_position": sstart
                            }

                    if sent_id < tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] + len_sent
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2] + len_sent
                        se= f"{s};{e}"
                        dd = {
                                "context": drug_sent_context + ' ' + tail_context,
                                "end_position": [e],
                                "entity_label": [k, drug_text, drug_ent[-2][0],drug_ent[-2][1]],
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        test_data.append(dd)
                    if sent_id > tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] 
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2]
                        se= f"{s};{e}"
                        dd = {
                                "context": tail_context + ' ' + drug_sent_context,
                                "end_position": [e],
                                "entity_label": [k, drug_text, drug_ent[-2][0],drug_ent[-2][1]],
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        test_data.append(dd)
                test_data.append(d)
                 
            else:
                d = {
                        "context": drug_sent_context,
                        "end_position": [],
                        "entity_label": [k, drug_text, drug_ent[-2][0],drug_ent[-2][1]],
                        "impossible": True,
                        "qas_id": f"{file_sent_id}.{sent_id}.{tid}",
                        "query": v.format(drug_text),
                        "span_position": [],
                        "start_position": []
                    }
                test_data.append(d)

In [14]:
to_json(test_data, pout, "test")

In [16]:
def remap_index_to_wordindex (token_start, token_end, tokens, ccontext, nnsents, sent_id):

    corpora_records = ccontext.split()
    word_2_char_mapping={}
    char_cursor=0
    for ind in range(len(corpora_records)):
        if(len(corpora_records[ind])>0):#the last space will not be considered
            start=char_cursor
            end=char_cursor+len(corpora_records[ind])
            word_2_char_mapping[ind]=[start,end]
            char_cursor=char_cursor+len(corpora_records[ind])+1#consider the white-space length
    # print(ccontext)
    # print(word_2_char_mapping)
    start_char_span=tokens.token_to_chars(token_start)
    end_char_span=tokens.token_to_chars(token_end)
    # print(start_char_span,end_char_span)
    for each_word in word_2_char_mapping:
        start = word_2_char_mapping[each_word][0]
        end = word_2_char_mapping[each_word][1]
        
        if(start_char_span[0]>=start and start_char_span[1]<=end):
            print('a')
            s_char = each_word

        if(end_char_span[0]>=start and end_char_span[1]<=end):
            print('b')
            e_char = each_word
    # print(sent_id, s_char, e_char, ccontext)
    for i, sent in enumerate(nnsents):
        if (sent_id, s_char) == sent[3]:
            print('c')
            s_offset = sent[1][0]
            # print(s_offset)
            # word_idx_s = i 
        if (sent_id, e_char) == sent[3]:
            print('d')
            e_offset = sent[1][1]
            # print(e_offset)
            # word_idx_e = i
    return s_offset, e_offset 

In [None]:
#convert the output to brat format
#NER task
BRAT_TEMPLATE_T = "{}\t{} {} {}\t{}"
output_template_t = BRAT_TEMPLATE_T

from tokenizers import BertWordPieceTokenizer
p_test = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/test_data/gold_standard_test")

mrc_entity_dir = Path('/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mrc_entity/mrc-ner.dev')

entity_pred_dir = Path('/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/exp/ner/pred')

entity_model = 'pred_1203_bert-large-cased_2_4_3e-5_20'

entity_pred_fn = entity_pred_dir / f"{entity_model}.json"

p_output = Path('/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/exp/ner/results/test_1')

bert_path = "/home/alexgre/projects/transformer_pretrained_models/bert-large-cased"
vocab_file = os.path.join(bert_path, "vocab.txt")
tokenizer = BertWordPieceTokenizer(vocab_file)

with open(entity_pred_fn, "r") as f:
    entity_preds = json.load(f)

with open(mrc_entity_dir, "r") as f:
    mrc_entity_fn = json.load(f)

file_suffix = 'ann'

for i, fn in enumerate(p_test.glob("*.txt")):
    
    txt_fn = p_test / f"{fn.name}"
    txt_contents = open(txt_fn,'r').read()
    
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    output_fn = p_output / "{}.{}".format(fn.stem, file_suffix)
    
    nsents, sent_bound = generate_BIO(sents, [], file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))

    entity_pred = []
    entities_T = []
    print(fn.stem)
    # print(nnsents)
    
  
    
    for i, pred in enumerate(entity_preds):
        sample_idx = str(pred['sample_idx'][0][0])[:6]
        sample_idx = int(sample_idx)
        if sample_idx == int(fn.stem):
            entity_pred.append([i,pred])
    
    for pred_e in entity_pred:
        idx = pred_e[0]
        ent_r = pred_e[1]
        sample_idx = str(ent_r['sample_idx'][0][0])[6:]
        sent_id = int(sample_idx)
        # print(sent_id)
        context = mrc_entity_fn[idx]['context']
        query = mrc_entity_fn[idx]['query']
        tokens = tokenizer.encode(query, context, add_special_tokens=True)

        ens_r = ent_r['en']
        
        for en_r in ens_r:
            token_s = en_r[0]
            token_e = en_r[1] - 1
            # print(en_r[2], token_s,token_e,context)
            ent_s, ent_e = remap_index_to_wordindex (token_s, token_e, tokens, context, nnsents, sent_id)
            ent_type = en_r[3]
            entity_word = txt_contents[ent_s: ent_e]
            entities_T.append((ent_type, ent_s, ent_e, entity_word))
    
    output_t = []
    for i, entity_T in enumerate(entities_T):
        type, offset_s, offset_e, text = entity_T
        if "\n" in text:
            text = text.replace("\n", " ")
        formatted_output_t = output_template_t.format("T{}".format(i), type, offset_s, offset_e, text)
        output_t.append (formatted_output_t)

    with open(output_fn, "w") as f:
        formatted_output = "\n".join(output_t)
        f.write(formatted_output)
        f.write("\n")


In [65]:
from tokenizers import BertWordPieceTokenizer
bert_path = "/home/alexgre/projects/transformer_pretrained_models/bert-large-cased"
vocab_file = os.path.join(bert_path, "vocab.txt")
tokenizer = BertWordPieceTokenizer(vocab_file)
query = 'what are the injuries resulting from a medical intervention related to sulfasalazine'

context = "sulfasalazine dc ' d due to concern for drug induced lupus ."
tokens = tokenizer.encode(query, context, add_special_tokens=True)
tokens_1 = tokenizer.encode(context,add_special_tokens=False)
start = tokens.token_to_chars(41)
print(tokens.tokens)
print(tokens.token_to_chars(35))
# print(tokens_1.tokens)

['[CLS]', 'what', 'are', 'the', 'injuries', 'resulting', 'from', 'a', 'medical', 'intervention', 'related', 'to', 'su', '##lf', '##asa', '##la', '##zine', '[SEP]', 'su', '##lf', '##asa', '##la', '##zine', 'd', '##c', "'", 'd', 'due', 'to', 'concern', 'for', 'drug', 'induced', 'l', '##up', '##us', '.', '[SEP]']
(56, 58)


In [17]:
#convert the output to brat format
#RE task
BRAT_TEMPLATE_T = "{}\t{} {} {}\t{}"
output_template_t = BRAT_TEMPLATE_T
BRAT_TEMPLATE_R = "{}\t{}-{} Arg1:{} Arg2:{}"
output_template_r = BRAT_TEMPLATE_R

from tokenizers import BertWordPieceTokenizer
p_test = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/test_data/gold_standard_test")

mrc_relation_dir = Path('/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mrc_relation/mrc-ner.test')

relation_pred_dir = Path('/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/exp/re/pred')

relation_model = 'pred_1206_bert-large-cased_2_4_1e-5_20_e2e'

relation_pred_fn = relation_pred_dir / f"{relation_model}.json"

p_output = Path('/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/exp/re/results/test_e2e')

bert_path = "/home/alexgre/projects/transformer_pretrained_models/bert-large-cased"
vocab_file = os.path.join(bert_path, "vocab.txt")
tokenizer = BertWordPieceTokenizer(vocab_file)

with open(relation_pred_fn, "r") as f:
    relation_preds = json.load(f)

with open(mrc_relation_dir, "r") as f:
    mrc_relation_fn = json.load(f)

file_suffix = 'ann'

for i, fn in enumerate(p_test.glob("*.txt")):
    txt_fn = p_test / f"{fn.name}"
    txt_contents = open(txt_fn,'r').read()
    
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    output_fn = p_output / "{}.{}".format(fn.stem, file_suffix)
    
    nsents, sent_bound = generate_BIO(sents, [], file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    print(fn.name)

    # ann_fn = p_test / f"{fn.stem}.ann"
    
    # e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    # i2e = {v: k for k, v in e2i.items()}
    
    # mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)


    # sent_with_drugs = set()

    # drug_lists = []

    # sentence with entities
    # for en in ens:
    #     entype = en[1]
    #     type_id = entity_id[entype]
    #     text = en[0]
    #     s_idx, e_idx = mappings[en[-1]]
    #     word_info1 = nnsents[s_idx]
    #     word_info2 = nnsents[e_idx]
        
    #     sent_idx1 = word_info1[3][0]
    #     sent_idx2 = word_info2[3][0]
    #     sent_with_drugs.add(sent_idx1)
        
    #     try:
    #         sent_text = get_sent(nsents, sent_idx1, sent_idx2)
    #         context = " ".join([e[0] for e in sent_text])
    #         start = word_info1[3][1]
    #         end = word_info2[3][1]
    #         start_end = f"{start};{end}"
    #     except Exception as ex:
    #         print(fn.name)
    #         print(word_info1, word_info2)
    #         print(context)
    #         print(start, end, start_end)
    #         print(entype, type_id)
            
    #         # key will be sent_id, type
    #         # data will be tuple (context, entype, type_id, start, end, start_end)
    #     if entype == "Drug":
    #         drug_lists.append(
    #             [sent_idx1, text, context, start, end, start_end, en[-1]])

    relation_pred = []
    event_pred = defaultdict(list)
     
    for i, pred in enumerate(relation_preds):
        sample_idx = pred['sample_idx']
        sample_idx = int(sample_idx[0][0])
        if sample_idx == int(fn.stem):
            relation_pred.append([i,pred])
    
    # for drug_ent in drug_lists:
        
    #     head_text = drug_ent[1]
    #     head_offset_s = drug_ent[3]
    #     head_offset_e = drug_ent[4]
    #     head_ann = drug_ent[-1]

    for pred_r in relation_pred:
        idx = pred_r[0]
        re_r = pred_r[1]
        head_sent_id = int(re_r['head_sent_idx'][0][0])
        tail_sent_id = int(re_r['tail_sent_idx'][0][0])
        query = mrc_relation_fn[idx]['query']
        head_text = mrc_relation_fn[idx]['entity_label'][1]
        head_offset_s = mrc_relation_fn[idx]['entity_label'][2]
        head_offset_e = mrc_relation_fn[idx]['entity_label'][3]
        
        tokens = tokenizer.encode(query, context, add_special_tokens=True)

        tails_r = re_r['en']
        # head_sent_id = re_r['head_sent_idx'][0][0]
        # tail_sent_id = re_r['tail_sent_idx'][0][0]
        
        for en_r in tails_r: 
            ent_type = en_r[3]
            if head_sent_id == tail_sent_id:
                token_s = en_r[0]
                token_e = en_r[1] - 1
                context = " ".join(e[0] for e in nsents[head_sent_id])
                tokens = tokenizer.encode(query, context, add_special_tokens=True)
                # print(context)
                # print(query)
                # print(tokens.tokens)
                # print(head_sent_id, token_s,token_e)
                ent_s, ent_e = remap_index_to_wordindex (token_s, token_e, tokens, context, nnsents, head_sent_id)
            
            if head_sent_id < tail_sent_id:
                head_sent = " ".join(e[0] for e in nsents[head_sent_id])
                head_sent_token = tokenizer.encode(head_sent, add_special_tokens=False)
                query_token = tokenizer.encode(query, add_special_tokens=False)
                query_token_len = len(query_token)
                head_sent_len =  len(head_sent_token)
                if en_r[0]>= query_token_len + head_sent_len + 2:
                    token_s = en_r[0] - head_sent_len
                    token_e = en_r[1] -1 -  head_sent_len
                    context = " ".join(e[0] for e in nsents[tail_sent_id])
                    tokens = tokenizer.encode(query, context, add_special_tokens=True)
                    # print(head_sent_id, tail_sent_id)
                    # print(token_s,token_e)
                    ent_s, ent_e = remap_index_to_wordindex (token_s, token_e, tokens, context, nnsents, tail_sent_id) 

            if head_sent_id > tail_sent_id:
                tail_sent = " ".join(e[0] for e in nsents[tail_sent_id])
                tail_sent_token = tokenizer.encode(tail_sent, add_special_tokens=False)
                query_token = tokenizer.encode(query, add_special_tokens=False)
                query_token_len = len(query_token)
                tail_sent_len =  len(tail_sent_token)
                if en_r[1]-1 < query_token_len + tail_sent_len + 2:
                    token_s = en_r[0] 
                    token_e = en_r[1] -1
                    context = " ".join(e[0] for e in nsents[tail_sent_id])
                    tokens = tokenizer.encode(query, context, add_special_tokens=True)
                    # print(token_s,token_e)
                    ent_s, ent_e = remap_index_to_wordindex (token_s, token_e, tokens, context, nnsents, tail_sent_id) 

            entity_word = txt_contents[ent_s: ent_e]
            
            event_pred[(head_text, head_offset_s, head_offset_e)] .append((ent_type, ent_s, ent_e, entity_word))
    
    output_tr = []
    i = 1
    k = 1
    m = 1
    for  event in list(event_pred):
        head_ent = event

        text_h, head_s, head_e  = head_ent
        if "\n" in text_h:
            text_h = text_h.replace("\n", " ")
        formatted_output_head = output_template_t.format("T{}".format(i), 'Drug', head_s, head_e, text_h)
        output_tr.append (formatted_output_head)
        i= i + 1 

        for attributes in event_pred[event]:
            att_type, tail_s, tail_e, text_t  = attributes
            if "\n" in text_t:
                text_t = text_t.replace("\n", " ")
            formatted_output_att = output_template_t.format("T{}".format(i), att_type, tail_s, tail_e, text_t)
            formatted_output_rel = output_template_r.format("R{}".format(m), att_type, 'Drug', "T{}".format(i), "T{}".format(k))
            output_tr.append(formatted_output_att)
            output_tr.append(formatted_output_rel)
            i = i + 1
            m = m + 1
        k = i

    with open(output_fn, "w") as f:
        formatted_output = "\n".join(output_tr)
        f.write(formatted_output)
        f.write("\n")


110753.txt
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c
d
a
b
c

In [None]:
txt_fn = '/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/test_data/test/105585.txt'
txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
e2i, ens, _, evns, attrs = load_annotation_brat(ann_fn)
i2e = {v: k for k, v in e2i.items()}
nsents, sent_bound = generate_BIO(sents, [], file_id="", no_overlap=False, record_pos=True)
nnsents = [w for sent in nsents for w in sent]
print(nnsents)

In [38]:
p_result = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/exp/re/results/test_1")
p_resultt = Path("/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/exp/re/results/test_single_sentence")

for i, fn in enumerate(p_dev.glob("*.ann")):

    txt_fn = p_dev / f"{fn.stem}.txt"
    ann_fn = p_result / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))

    ann_file = p_result / f"{fn.name}"

    file_suffix = 'ann'
    output_fn = p_resultt / "{}.{}".format(fn.stem, file_suffix)
    output = []
    with open(ann_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            anns = line.split("\t")
            ann_id = anns[0]
            if ann_id.startswith("T"):
                output.append(line)
            if ann_id.startswith("R"):
                relation, tail, head = anns[1].split()[0], anns[1].split()[1], anns[1].split()[-1]
                head_ann= mappings[head.split(':')[1]][0]
                tail_ann= mappings[tail.split(':')[1]][0]
                head_sent_id = nnsents[head_ann][3][0]
                tail_sent_id = nnsents[tail_ann][3][0]
                if head_sent_id == tail_sent_id:
                    output.append(line)
    with open(output_fn, "w") as f:
        formatted_output = "\n".join(output)
        f.write(formatted_output)
        f.write("\n")


185982.ann
last index not match  ('O2', 'Drug', (1049, 1051), 'T65')
189637.ann
last index not match  ('Warfarin', 'Drug', (10053, 10061), 'T1')
110135.ann
first index not match  ('NS', 'Drug', (938, 940), 'T174')
110135.ann
last index not match  ('NS', 'Drug', (938, 940), 'T174')
183783.ann
first index not match  ('insulin', 'Drug', (6910, 6917), 'T101')
183783.ann
last index not match  ('insulin', 'Drug', (6910, 6917), 'T101')
113333.ann
last index not match  ('O2', 'Drug', (6776, 6778), 'T187')
102357.ann
last index not match  ('O2', 'Drug', (3123, 3125), 'T158')
112342.ann
first index not match  ('evo', 'Drug', (13632, 13635), 'T307')
112342.ann
last index not match  ('evo', 'Drug', (13632, 13635), 'T307')
104021.ann
last index not match  ('O2', 'Drug', (2768, 2770), 'T209')
183331.ann
first index not match  ('PRBCs', 'Drug', (4724, 4729), 'T15')
183331.ann
last index not match  ('PRBCs', 'Drug', (4724, 4729), 'T15')


In [26]:
fi = open('/data/datasets/cheng/mrc-for-ner-medical/2018_n2c2/data/mimic/100035.ann')
lines = fi.readlines()
print(lines[0])


T1	Reason 10179 10197	recurrent seizures



In [25]:
##single sentence as a sample
# relation extraction
# CSD > 0

training_data = []


for i, fn in enumerate(p_tr.glob("*.ann")):

    txt_fn = p_tr / f"{fn.stem}.txt"
    ann_fn = p_tr / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_drugs = set()

    drug_lists = []
    attributes_lists = []
    tail_ann_idx_mappings = defaultdict(list)
    relation_idx_mappings = defaultdict(list)


    # sentence with entities
    for en in ens:
        entype = en[1]
        type_id = entity_id[entype]
        text = en[0]
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_drugs.add(sent_idx1)
        
        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join([e[0] for e in sent_text])
            start = word_info1[3][1]
            end = word_info2[3][1]
            start_end = f"{start};{end}"
        except Exception as ex:
            print(fn.name)
            print(word_info1, word_info2)
            print(context)
            print(start, end, start_end)
            print(entype, type_id)
            
            # key will be sent_id, type
            # data will be tuple (context, entype, type_id, start, end, start_end)
        if entype == "Drug":
            drug_lists.append(
                [sent_idx1, text, context, start, end, start_end, en[-1]])
        else:
            tail_ann_idx_mappings[(en[-1])].append((sent_idx1, text, context, type_id, start, end, start_end))

    for pair in relations:
        relation_idx_mappings[(pair[-1].split(":")[1],pair[0].split('-')[0])].append((pair[1].split(":")[1]))
       
        


    # for i in sent_ids:
    # for i in sent_with_drugs:
    #     file_sent_id = f"{fn.stem}_{i}"
    #     sent_i_context = " ".join(e[0] for e in nsents[i])
    #     drug_entities = entity_sent_idx_mappings[(i, "Drug")]

    for drug_ent in drug_lists:
        sent_id = drug_ent[0]
        drug_text = drug_ent[1]
        len_sent = len(nsents[sent_id])
        file_sent_id = f"{fn.stem}.{sent_id}"
        drug_sent_context = " ".join(e[0] for e in nsents[drug_ent[0]])
        ann_id = drug_ent[-1]
        for k, v in attribute_template.items():
            tid = tail_id[k]
            
            if (ann_id, k) in relation_idx_mappings:
                
                attributes = relation_idx_mappings[(ann_id, k)]

                sstart=[]
                eend=[]
                sspan=[]
        
                for ent in attributes:
                    tail_ann_id = ent
                    tail_sent_id = tail_ann_idx_mappings[(tail_ann_id)][0][0]
                    tail_context = tail_ann_idx_mappings[(tail_ann_id)][0][2]
                    if sent_id < tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] + len_sent
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2] + len_sent
                        se= f"{s};{e}"
                        dd = {
                                "context": drug_sent_context + ' ' + tail_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        training_data.append(dd)
                    if sent_id > tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] 
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2]
                        se= f"{s};{e}"
                        dd = {
                                "context": tail_context + ' ' + drug_sent_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        training_data.append(dd)
                 
            else:
                d = {
                        "context": drug_sent_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": True,
                        "qas_id": f"{file_sent_id}.{sent_id}.{tid}",
                        "query": v.format(drug_text),
                        "span_position": [],
                        "start_position": []
                    }
                training_data.append(d)

161477.ann
last index not match  ('O2', 'Drug', (1701, 1703), 'T30')
112832.ann
last index not match  ('1mg', 'Strength', (9936, 9939), 'T150')
112832.ann
first index not match  ('MWF', 'Frequency', (9939, 9942), 'T151')
112832.ann
last index not match  ('MWF', 'Frequency', (9939, 9942), 'T151')
106361.ann
last index not match  ('2L', 'Dosage', (11456, 11458), 'T9')
106361.ann
first index not match  ('NC', 'Route', (11458, 11460), 'T10')
106361.ann
last index not match  ('NC', 'Route', (11458, 11460), 'T10')
106361.ann
last index not match  ('2L', 'Dosage', (11665, 11667), 'T12')
106361.ann
first index not match  ('NC', 'Route', (11667, 11669), 'T13')
106361.ann
last index not match  ('NC', 'Route', (11667, 11669), 'T13')
114144.ann
first index not match  ('O2', 'Drug', (3232, 3234), 'T161')
114144.ann
last index not match  ('O2', 'Drug', (3232, 3234), 'T161')
103293.ann
last index not match  ('1', 'Dosage', (5573, 5574), 'T71')
103293.ann
first index not match  ('2', 'Dosage', (9896, 

In [26]:
to_json_1(training_data, pout, "train")

In [27]:
##single sentence as a dev sample
# relation extraction
# CSD > 0
dev_data = []


for i, fn in enumerate(p_dev.glob("*.ann")):

    txt_fn = p_dev / f"{fn.stem}.txt"
    ann_fn = p_dev / f"{fn.name}"
    txt, sents = pre_processing(txt_fn, MIMICIII_PATTERN, max_len=256)
    e2i, ens, relations, evns, attrs = load_annotation_brat(ann_fn)
    i2e = {v: k for k, v in e2i.items()}
    nsents, sent_bound = generate_BIO(sents, ens, file_id="", no_overlap=False, record_pos=True)
    nnsents = [w for sent in nsents for w in sent]
    mappings = create_entity_to_sent_mapping(nnsents, ens, i2e, fn.name)

    num_sents = len(nsents)
    sent_ids = set(range(num_sents))
    sent_with_drugs = set()

    drug_lists = []
    attributes_lists = []
    tail_ann_idx_mappings = defaultdict(list)
    relation_idx_mappings = defaultdict(list)


    # sentence with entities
    for en in ens:
        entype = en[1]
        type_id = entity_id[entype]
        text = en[0]
        s_idx, e_idx = mappings[en[-1]]
        word_info1 = nnsents[s_idx]
        word_info2 = nnsents[e_idx]
        
        sent_idx1 = word_info1[3][0]
        sent_idx2 = word_info2[3][0]
        sent_with_drugs.add(sent_idx1)
        
        try:
            sent_text = get_sent(nsents, sent_idx1, sent_idx2)
            context = " ".join([e[0] for e in sent_text])
            start = word_info1[3][1]
            end = word_info2[3][1]
            start_end = f"{start};{end}"
        except Exception as ex:
            print(fn.name)
            print(word_info1, word_info2)
            print(context)
            print(start, end, start_end)
            print(entype, type_id)
            
            # key will be sent_id, type
            # data will be tuple (context, entype, type_id, start, end, start_end)
        if entype == "Drug":
            drug_lists.append(
                [sent_idx1, text, context, start, end, start_end, en[-1]])
        else:
            tail_ann_idx_mappings[(en[-1])].append((sent_idx1, text, context, type_id, start, end, start_end))

    for pair in relations:
        relation_idx_mappings[(pair[-1].split(":")[1],pair[0].split('-')[0])].append((pair[1].split(":")[1]))
       
        


    # for i in sent_ids:
    # for i in sent_with_drugs:
    #     file_sent_id = f"{fn.stem}_{i}"
    #     sent_i_context = " ".join(e[0] for e in nsents[i])
    #     drug_entities = entity_sent_idx_mappings[(i, "Drug")]

    for drug_ent in drug_lists:
        sent_id = drug_ent[0]
        drug_text = drug_ent[1]
        len_sent = len(nsents[sent_id])
        file_sent_id = f"{fn.stem}.{sent_id}"
        drug_sent_context = " ".join(e[0] for e in nsents[drug_ent[0]])
        ann_id = drug_ent[-1]
        for k, v in attribute_template.items():
            tid = tail_id[k]
            
            if (ann_id, k) in relation_idx_mappings:
                
                attributes = relation_idx_mappings[(ann_id, k)]

                sstart=[]
                eend=[]
                sspan=[]
        
                for ent in attributes:
                    tail_ann_id = ent
                    tail_sent_id = tail_ann_idx_mappings[(tail_ann_id)][0][0]
                    tail_context = tail_ann_idx_mappings[(tail_ann_id)][0][2]
                    if sent_id < tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] + len_sent
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2] + len_sent
                        se= f"{s};{e}"
                        dd = {
                                "context": drug_sent_context + ' ' + tail_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        training_data.append(dd)
                    if sent_id > tail_sent_id:
                        s= tail_ann_idx_mappings[(tail_ann_id)][0][-3] 
                        e= tail_ann_idx_mappings[(tail_ann_id)][0][-2]
                        se= f"{s};{e}"
                        dd = {
                                "context": tail_context + ' ' + drug_sent_context,
                                "end_position": [e],
                                "entity_label": k,
                                "impossible": False,
                                "qas_id": f"{file_sent_id}.{tail_sent_id}.{tid}",
                                "query": v.format(drug_text),
                                "span_position": [se],
                                "start_position": [s]
                            }
                        training_data.append(dd)
                 
            else:
                d = {
                        "context": drug_sent_context,
                        "end_position": [],
                        "entity_label": k,
                        "impossible": True,
                        "qas_id": f"{file_sent_id}.{sent_id}.{tid}",
                        "query": v.format(drug_text),
                        "span_position": [],
                        "start_position": []
                    }
                dev_data.append(d)

125281.ann
first index not match  ('Levaquin', 'Drug', (5311, 5319), 'T41')
125281.ann
last index not match  ('Levaquin', 'Drug', (5311, 5319), 'T41')
185982.ann
last index not match  ('O2', 'Drug', (1049, 1051), 'T78')
107902.ann
last index not match  ('p.o', 'Route', (5493, 5496), 'T36')
107902.ann
last index not match  ('p.o', 'Route', (5555, 5558), 'T45')
107902.ann
last index not match  ('p.o', 'Route', (5625, 5628), 'T52')
107902.ann
last index not match  ('p.o', 'Route', (6209, 6212), 'T85')
189637.ann
last index not match  ('Warfarin', 'Drug', (10053, 10061), 'T2')
189637.ann
first index not match  ('5 mg', 'Strength', (10061, 10065), 'T3')
114044.ann
first index not match  ('1', 'Dosage', (1571, 1572), 'T53')
114044.ann
last index not match  ('1', 'Dosage', (1571, 1572), 'T53')
114044.ann
first index not match  ('3', 'Dosage', (1588, 1589), 'T82')
114044.ann
last index not match  ('3', 'Dosage', (1588, 1589), 'T82')
106588.ann
first index not match  ('1', 'Dosage', (1548, 1549

In [24]:
to_json_1(dev_data, pout, "dev")