In [298]:
import xml.etree.ElementTree as ET
import random
import os
import codecs
import collections
import numpy as np
import sys
import time
import json
import re

## Convert the i2b2 XML annotations to JSON

In [301]:
# Directory that receives the generated <split>.json files.
processed_json_dir = "resources/i2b2/"
# Input directory holding the XML annotation files of each split.
paths = dict(
    test="resources/i2b2/test/",
    train="resources/i2b2/train/",
    dev="resources/i2b2/dev/",
)

In [302]:
# XML entity reference -> placeholder token (applied to raw XML before parsing).
code_dict = {
    "&quot;": "_quot_",
    "&amp;": "_amp_",
    "&gt;": "_gt_",
    "&lt;": "_lt_",
    "&apos;": "_apos_",
}
# Literal character -> the same placeholder token.
symbol_dict = {
    "&": "_amp_",
    '"': "_quot_",
    "<": "_lt_",
    ">": "_gt_",
    "'": "_apos_",
}
# Placeholder token -> literal character, followed by each entity reference
# -> literal character.  Insertion order matters: callers apply these
# replacements sequentially in iteration order.
reverse_dict = {token: char for char, token in symbol_dict.items()}
reverse_dict.update(
    (entity, reverse_dict[token]) for entity, token in code_dict.items()
)
reverse_dict

{'_amp_': '&',
 '_quot_': '"',
 '_lt_': '<',
 '_gt_': '>',
 '_apos_': "'",
 '&quot;': '"',
 '&amp;': '&',
 '&gt;': '>',
 '&lt;': '<',
 '&apos;': "'"}

In [305]:
# Corpus-wide tallies; data_process() adds to these on every call, so they
# accumulate across splits.
entities = {}   # entity type label -> occurrences (EVENT, TIMEX3, SECTIME)
rels = {}       # TLINK type label -> occurrences (empty labels included)
sectimerels = eerels = etrels = ttrels = 0  # TLINK counts by endpoint kinds
timex_sizes = {}  # not referenced elsewhere in the visible code

def file_name(file_dir):
    """Return the base names of all ``.xml`` files found under *file_dir*.

    The tree is walked recursively with ``os.walk`` and names are returned in
    walk order.  Only the bare file name is kept (no directory part), so a
    match inside a subdirectory cannot be reopened by joining it with
    *file_dir* alone.  A missing directory yields an empty list.
    """
    return [
        fname
        for _dirpath, _subdirs, fnames in os.walk(file_dir)
        for fname in fnames
        if os.path.splitext(fname)[1] == '.xml'
    ]

# Diagnostics for entities whose offsets cannot be repaired; filled with
# str.format(file=<xml file name>, span=<annotated slice>, **entity).
_EVENT_MSG = ("offset correction file {file}| event:{id}| start:{start}| end:{end}| "
              "span:{span}*| ent_text:{text}*")
_TIMEX_MSG = ("offset correction file {file}| timex: {text}| span:{span}| "
              "start:{start}, end:{end}")
_SECTIME_MSG = ("offset correction file {file}| sectime: {text}| span:{span}| "
                "start:{start}, end:{end}")


def _unescape(s):
    """Map placeholder tokens (and entity references) in *s* back to characters."""
    for token, char in reverse_dict.items():  # order matters; see reverse_dict
        s = s.replace(token, char)
    return s


def _read_annotation_xml(path):
    """Read one annotation file and return its parsed XML root element.

    Entity references and stray ``&`` characters are replaced by placeholder
    tokens before parsing so ElementTree never has to resolve them; use
    ``_unescape`` to restore them.  Raises OSError if the file cannot be read.
    """
    chunks = []
    with open(path, "r", encoding="utf-8") as fh:
        for line in fh:
            for entity, token in code_dict.items():
                line = line.replace(entity, token)
            chunks.append(line.replace("&", "_amp_"))
    parser = ET.XMLParser(encoding="utf-8")
    return ET.fromstring("".join(chunks), parser=parser)


def _align_offsets(ent, text, template, fname):
    """Ensure ``text[start:end] == ent['text']``, repairing *ent* in place.

    Returns False if the annotated offsets already match and True if they did
    not and ``offset_correction`` found a replacement.  If no repair is found,
    the diagnostic built from *template* is printed and AssertionError is
    raised with that message.
    """
    span = text[int(ent['start']):int(ent['end'])]
    if span == ent['text']:
        return False
    # Built before any repair so it reports the original (bad) offsets.
    message = template.format(file=fname, span=span, **ent)
    fixed = offset_correction(ent, text)  # None when no shift reproduces the text
    if fixed is None:
        print(message)
    else:
        ent['start'], ent['end'] = fixed
    assert text[int(ent['start']):int(ent['end'])] == ent['text'], message
    return True


def _parse_tlinks(tags, doc_id, tlinktype):
    """Return the TLINKs of one document whose type is non-empty.

    Side effects: counts every TLINK type (empty ones too) in ``rels``, bumps
    the global endpoint-kind counters for kept links, and adds kept link types
    to *tlinktype*.
    """
    global sectimerels, eerels, etrels, ttrels
    tlinks = []
    for node in tags.findall("TLINK"):
        link = {
            'id': doc_id + "_" + str(node.attrib['id']),
            'head': node.attrib['fromID'],
            'tail': node.attrib['toID'],
            'type': node.attrib['type'],
        }
        rels[link['type']] = rels.get(link['type'], 0) + 1
        if link['type'] == '':
            continue
        head, tail = link['head'], link['tail']
        # Endpoint kind is read from the id prefix: S(ectime), E(vent), T(imex).
        if head.startswith("S") or tail.startswith("S"):
            sectimerels += 1
        elif head.startswith("E"):
            if tail.startswith("E"):
                eerels += 1
            elif tail.startswith("T"):
                etrels += 1
        elif head.startswith("T"):
            if tail.startswith("E"):
                etrels += 1
            elif tail.startswith("T"):
                ttrels += 1
        tlinktype.add(link['type'])
        tlinks.append(link)
    return tlinks


def data_process(split):
    """Convert every XML annotation file of *split* into a single JSON file.

    Reads the ``.xml`` files under ``paths[split]``.  For each document it
    extracts the text, the EVENT / TIMEX3 / SECTIME entities and the TLINK
    relations, and repairs entity offsets that do not reproduce the annotated
    text.  The list of documents is written to
    ``processed_json_dir + split + ".json"``.

    Also updates the module-level tallies (``entities``, ``rels`` and the
    TLINK counters) and prints the label inventories seen in this split.
    Files that cannot be opened are reported and skipped.  AssertionError is
    raised if an entity's offsets cannot be repaired.
    """
    corrected = 0
    in_dir = paths[split]
    data = []
    timemods = set()
    tlinktype = set()
    eventtype = set()
    eventmod = set()
    eventpol = set()
    sectime_types = set()

    for f in file_name(in_dir):
        try:
            root = _read_annotation_xml(os.path.join(in_dir, f))
        except OSError:
            # Report unreadable files and skip them instead of going on to
            # use a file handle that was never opened.
            print("\n" + in_dir + f + " NOT FOUND")
            continue

        doc_id = f[:-4]  # file name without its ".xml" extension
        text = _unescape(root.find("TEXT").text.replace("\n", " "))
        tags = root.find("TAGS")

        tlinks = _parse_tlinks(tags, doc_id, tlinktype)

        events = []
        for node in tags.findall("EVENT"):
            e = {
                "id": doc_id + "_" + str(node.attrib['id']),
                "start": node.attrib['start'],
                "end": node.attrib["end"],
                "modality": node.attrib["modality"].upper(),
                "polarity": node.attrib["polarity"].upper(),
                "type": node.attrib['type'],
            }
            entities[e['type']] = entities.get(e['type'], 0) + 1
            if e["type"] == '':
                continue
            e["text"] = _unescape(node.attrib['text'])
            if e['text'].strip() == "":
                continue  # entities without surface text are not kept
            if _align_offsets(e, text, _EVENT_MSG, f):
                corrected += 1
            eventtype.add(e["type"])
            eventmod.add(e["modality"])
            eventpol.add(e["polarity"])
            events.append(e)

        times = []
        for node in tags.findall("TIMEX3"):
            # Named ``timex`` rather than ``time`` to avoid shadowing the module.
            timex = {
                "id": doc_id + "_" + str(node.attrib['id']),
                "start": node.attrib['start'],
                "end": node.attrib['end'],
                "val": node.attrib['val'],
                "mod": node.attrib['mod'],
                "type": node.attrib["type"],
            }
            entities[timex['type']] = entities.get(timex['type'], 0) + 1
            timex["text"] = _unescape(node.attrib["text"])
            if timex["text"].strip() == "":
                continue
            if _align_offsets(timex, text, _TIMEX_MSG, f):
                corrected += 1
            timemods.add(timex["mod"])
            times.append(timex)

        sectimes = []
        for node in tags.findall("SECTIME"):
            sectime = {
                "id": doc_id + "_" + str(node.attrib['id']),
                "start": node.attrib['start'],
                "end": node.attrib['end'],
                "val": node.attrib['dvalue'],
                "type": node.attrib["type"],
            }
            entities[sectime['type']] = entities.get(sectime['type'], 0) + 1
            sectime["text"] = _unescape(node.attrib["text"])
            if sectime["text"].strip() == "":
                continue
            # Diagnostics now describe the SECTIME itself, not a stale TIMEX.
            if _align_offsets(sectime, text, _SECTIME_MSG, f):
                corrected += 1
            sectime_types.add(sectime["type"])
            sectimes.append(sectime)

        data.append({
            "file_id": f,
            "text": text,
            "entities": {"events": events, "timex": times, "sectimes": sectimes},
            "relations": tlinks,
        })

    with open(processed_json_dir + split + ".json", 'w', encoding="utf-8") as fp:
        json.dump(data, fp)
    print(f"corrected {corrected}")
    for label, values in (("time MODS", timemods),
                          ("tlinktypes", tlinktype),
                          ("EVEnt types", eventtype),
                          ("event mods", eventmod),
                          ("event polarity", eventpol),
                          ("sectime types", sectime_types)):
        print("*" * 80)
        print("{}:{}".format(label, values))

    

In [310]:
def shift(ent, text, start, end, offset):
    """Try a fixed list of offset repairs for one misaligned entity span.

    Candidate ``(start, end)`` pairs are tried in this order:
      1. both bounds moved by *offset*;
      2. only the end moved by *offset*;
      3. only the start moved by *offset*;
      4. start moved by *offset*, end by ``2*offset``;
      5. start moved by *offset*, end by ``3*offset``;
      6. start moved by *offset*, end by ``offset - 5``;
      7. start moved by *offset*, end by ``offset - 10``.

    Returns ``(new_start, new_end, True)`` for the first candidate whose slice
    of *text* equals ``ent['text']``, otherwise ``(start, end, False)``.

    Candidates with a negative bound or with start > end are rejected.
    Python would otherwise wrap negative indices to the end of *text*, which
    can produce a spurious match with negative offsets.
    """
    target = ent['text']
    candidates = (
        (start + offset, end + offset),
        (start, end + offset),
        (start + offset, end),
        (start + offset, end + 2 * offset),
        (start + offset, end + 3 * offset),
        (start + offset, end + offset - 5),
        (start + offset, end + offset - 10),
    )
    for new_start, new_end in candidates:
        if 0 <= new_start <= new_end and text[new_start:new_end] == target:
            return new_start, new_end, True
    return start, end, False
    
def offset_correction(ent, text):
    """Search for corrected ``(start, end)`` offsets for *ent* within *text*.

    ``shift`` is tried with each offset in turn.  The offsets are first
    1, -1, 2, -2, 3, -3, 4, -4, 5, 6, -8, then every multiple of -5 from -5
    down to -120.  Each attempt starts from the entity's original offsets.
    Returns the first pair that reproduces ``ent['text']``.  If none does,
    prints an empty line and returns None.
    """
    base_start = int(ent['start'])
    base_end = int(ent['end'])
    trial_offsets = [1, -1, 2, -2, 3, -3, 4, -4, 5, 6, -8]
    trial_offsets += [-5 * k for k in range(1, 25)]
    for delta in trial_offsets:
        new_start, new_end, found = shift(ent, text, base_start, base_end, delta)
        if found:
            return new_start, new_end
    print()
    return None
        

In [311]:
# Build the JSON file for every split except 'test'.
for split in [name for name in paths if name != 'test']:
    print(split)
    data_process(split)

train
corrected 3243
********************************************************************************
time MODS:{'MIDDLE', 'MORE', 'NA', 'APPROX', 'END', 'START'}
********************************************************************************
tlinktypes:{'SIMULTANEOUS', 'AFTER', 'BEFORE'}
********************************************************************************
EVEnt types:{'CLINICAL_DEPT', 'OCCURRENCE', 'PROBLEM', 'EVIDENTIAL', 'TEST', 'TREATMENT'}
********************************************************************************
event mods:{'FACTUAL', 'PROPOSED', 'CONDITIONAL', 'HYPOTHETICAL', 'POSSIBLE'}
********************************************************************************
event polarity:{'POS', 'NEG'}
********************************************************************************
sectime types:{'DISCHARGE', 'ADMISSION'}
dev
corrected 269
********************************************************************************
time MODS:{'APPROX', 'END', 'START', 'NA'}
****