In [1]:
import os, xml.etree.ElementTree as ET, numpy as np

In [2]:
START_CDATA = "<TEXT><![CDATA["
END_CDATA   = "]]></TEXT>"

TAGS        = ['MEDICATION', 'OBSEE', 'SMOKER', 'HYPERTENSION', 'PHI', 'FAMILY_HIST']

def read_xml_file(xml_path, PHI_tag_type='ALL_CHILDREN', match_text=True):
    with open(xml_path, mode='r') as f:
        lines = f.readlines()
        text, in_text = [], False
        for i, l in enumerate(lines):
            if START_CDATA in l:
                text.append(list(l[l.find(START_CDATA) + len(START_CDATA):]))
                in_text = True
            elif END_CDATA in l:
                text.append(list(l[:l.find(END_CDATA)]))
                break
            elif in_text:
                if xml_path.endswith('180-03.xml') and '0808' in l and 'Effingham' in l:
                    print("Adjusting known error")
                    l = l[:9] + ' ' * 4 + l[9:]
#                 elif xml_path.endswith('188-05.xml') and 'Johnson & Johnson' in l:
#                     print("Adjusting known error")
#                     l = l.replace('&', 'and')
                text.append(list(l))
        
    pos_transformer = {}
    
    linear_pos = 1
    for line, sentence in enumerate(text):
        for char_pos, char in enumerate(sentence):
            pos_transformer[linear_pos] = (line, char_pos)
            linear_pos += 1
        
    xml_parsed = ET.parse(xml_path)
    tag_containers = xml_parsed.findall('TAGS')
    assert len(tag_containers) == 1, "Found multiple tag sets!"
    tag_container = tag_containers[0]
    
    PHI_tags = tag_container.getchildren() if PHI_tag_type == 'ALL_CHILDREN' else tag_container.findall('PHI')
    PHI_labels = [['O'] * len(sentence) for sentence in text]
    for PHI_tag in PHI_tags:
        base_label = PHI_tag.attrib['TYPE']
        start_pos, end_pos, PHI_text = PHI_tag.attrib['start'], PHI_tag.attrib['end'], PHI_tag.attrib['text']
        start_pos, end_pos = int(start_pos)+1, int(end_pos)
        PHI_text = ' '.join(PHI_text.split())
#         if PHI_text == "0808 O’neil’s Court":
#             print("Adjusting known error")
#             end_pos -= 4
        if PHI_text == 'Johnson and Johnson' and xml_path.endswith('188-05.xml'):
            print("Adjusting known error")
            PHI_text = 'Johnson & Johnson'
        

        (start_line, start_char), (end_line, end_char) = pos_transformer[start_pos], pos_transformer[end_pos]
            
        obs_text = []
        for line in range(start_line, end_line+1):
            t = text[line]
            s = start_char if line == start_line else 0
            e = end_char if line == end_line else len(t)
            obs_text.append(''.join(t[s:e+1]).strip())
        obs_text = ' '.join(obs_text)
        obs_text = ' '.join(obs_text.split())
              
        if match_text: assert obs_text == PHI_text, (
            ("Texts don't match! %s v %s" % (PHI_text, obs_text)) + '\n' + str((
                start_pos, end_pos, line, s, e, t, xml_path
            ))
        )
        
        PHI_labels[end_line][end_char]     = 'I-%s' % base_label
        PHI_labels[start_line][start_char] = 'B-%s' % base_label
        
        for line in range(start_line, end_line+1):
            t = text[line]
            s = start_char+1 if line == start_line else 0
            e = end_char-1 if line == end_line else len(t)-1
            for i in range(s, e+1): PHI_labels[line][i] = 'I-%s' % base_label

    return text, PHI_labels
    
def merge_into_words(text_by_char, all_labels_by_char):
    assert len(text_by_char) == len(all_labels_by_char), "Incorrect # of sentences!"
    
    N = len(text_by_char)
    
    text_by_word, all_labels_by_word = [], []
    
    for sentence_num in range(N):
        sentence_by_char = text_by_char[sentence_num]
        labels_by_char   = all_labels_by_char[sentence_num]
        
        assert len(sentence_by_char) == len(labels_by_char), "Incorrect # of chars in sentence!"
        S = len(sentence_by_char)
        
        if labels_by_char == (['O'] * len(sentence_by_char)):
            sentence_by_word = ''.join(sentence_by_char).split()
            labels_by_word   = ['O'] * len(sentence_by_word)
        else: 
            sentence_by_word, labels_by_word = [], []
            text_chunks, labels_chunks = [], []
            s = 0
            for i in range(S):
                if i == S-1:
                    text_chunks.append(sentence_by_char[s:])
                    labels_chunks.append(labels_by_char[s:])
                elif labels_by_char[i] == 'O': continue
                else:
                    if i > 0 and labels_by_char[i-1] == 'O':
                        text_chunks.append(sentence_by_char[s:i])
                        labels_chunks.append(labels_by_char[s:i])
                        s = i
                    if labels_by_char[i+1] == 'O' or labels_by_char[i+1][2:] != labels_by_char[i][2:]:
                        text_chunks.append(sentence_by_char[s:i+1])
                        labels_chunks.append(labels_by_char[s:i+1])
                        s = i+1
                
            for text_chunk, labels_chunk in zip(text_chunks, labels_chunks):
                assert len(text_chunk) == len(labels_chunk), "Bad Chunking (len)"
                assert len(text_chunk) > 0, "Bad chunking (len 0)" + str(text_chunks) + str(labels_chunks)
                
                labels_set = set(labels_chunk)
                assert labels_set == set(['O']) or (len(labels_set) <= 3 and 'O' not in labels_set), (
                    ("Bad chunking (contents) %s" % ', '.join(labels_set))+ str(text_chunks) + str(labels_chunks)
                )
                
                text_chunk_by_word = ''.join(text_chunk).split()
                W = len(text_chunk_by_word)
                if W == 0: 
#                     assert labels_set == set(['O']), "0-word chunking and non-0 label!" + str(
#                         text_chunks) + str(labels_chunks
#                     )
                    continue
                
                if labels_chunk[0] == 'O': labels_chunk_by_word = ['O'] * W
                elif W == 1:               labels_chunk_by_word = [labels_chunk[0]]
                elif W == 2:               labels_chunk_by_word = [labels_chunk[0], labels_chunk[-1]]
                else:                      labels_chunk_by_word = [
                        labels_chunk[0]
                    ] + [labels_chunk[1]] * (W - 2) + [
                        labels_chunk[-1]
                    ]
                    
                sentence_by_word.extend(text_chunk_by_word)
                labels_by_word.extend(labels_chunk_by_word)

        assert len(sentence_by_word) == len(labels_by_word), "Incorrect # of words in sentence!"    
        
        if len(sentence_by_word) == 0: continue
            
        text_by_word.append(sentence_by_word)
        all_labels_by_word.append(labels_by_word)
    return text_by_word, all_labels_by_word

def reprocess_PHI_labels(folders, base_path='.', PHI_tag_type='PHI', match_text=True, dev_set_size=None):
    all_texts_by_patient, all_labels_by_patient = {}, {}

    for folder in folders:
        folder_dir = os.path.join(base_path, folder)
        xml_filenames = [x for x in os.listdir(folder_dir) if x.endswith('xml')]
        for xml_filename in xml_filenames:
            patient_num = int(xml_filename[:3])
            xml_filepath = os.path.join(folder_dir, xml_filename)
            
            text_by_char, labels_by_char = read_xml_file(
                xml_filepath,
                PHI_tag_type=PHI_tag_type,
                match_text=match_text
            )
            text_by_word, labels_by_word = merge_into_words(text_by_char, labels_by_char)
            
            if patient_num not in all_texts_by_patient:
                all_texts_by_patient[patient_num] = []
                all_labels_by_patient[patient_num] = []
            
            all_texts_by_patient[patient_num].extend(text_by_word)
            all_labels_by_patient[patient_num].extend(labels_by_word)
            
    patients = set(all_texts_by_patient.keys())
    
    if dev_set_size is None: train_patients, dev_patients = list(patients), []
    else:
        N_train = int(len(patients) * (1-dev_set_size))
        patients_random = np.random.permutation(list(patients))
        train_patients = list(patients_random[:N_train])
        dev_patients   = list(patients_random[N_train:])
    
    train_texts, train_labels = [], []
    dev_texts, dev_labels = [], []
    
    for patient_num in train_patients:
        train_texts.extend(all_texts_by_patient[patient_num])
        train_labels.extend(all_labels_by_patient[patient_num])

    for patient_num in dev_patients:
        dev_texts.extend(all_texts_by_patient[patient_num])
        dev_labels.extend(all_labels_by_patient[patient_num])


    train_out_text_by_sentence = []
    for text, labels in zip(train_texts, train_labels):
        train_out_text_by_sentence.append('\n'.join('%s %s' % x for x in zip(text, labels)))
    dev_out_text_by_sentence = []
    for text, labels in zip(dev_texts, dev_labels):
        dev_out_text_by_sentence.append('\n'.join('%s %s' % x for x in zip(text, labels)))

    return '\n\n'.join(train_out_text_by_sentence), '\n\n'.join(dev_out_text_by_sentence)

In [3]:
final_train_text, final_dev_text = reprocess_PHI_labels(
    ['../../../training-PHI-Gold-Set1/', '../../../training-PHI-Gold-Set2/'], PHI_tag_type='ALL_CHILDREN',
    dev_set_size=0.1, match_text=True
)

  PHI_tags = tag_container.getchildren() if PHI_tag_type == 'ALL_CHILDREN' else tag_container.findall('PHI')


Adjusting known error


In [4]:
test_text, _ = reprocess_PHI_labels(
    ['../../../testing-PHI-Gold-fixed'], PHI_tag_type='ALL_CHILDREN', match_text=False, dev_set_size=None
)

  PHI_tags = tag_container.getchildren() if PHI_tag_type == 'ALL_CHILDREN' else tag_container.findall('PHI')


In [5]:
print(final_train_text[:500])

Record O
date: O
2087-04-06 B-DATE

PROBLEMS O

Diabetes O
mellitus O

Hypertension O

Psoriasis O

Hysterectomy O
due O
to O
bleeding O

Gastrectomy O
partial, O
PUD O
2061 B-DATE

PVD: O
iliac O
disease, O
compensated O
distally O

MEDICATIONS O

one O
touch O
test O
strips O

ATENOLOL O
50MG O
1 O
Tablet(s) O
PO O
QD O
90 O
day(s) O

GLYBURIDE O
5MG O
0.5 O
Tablet(s) O
PO O
QD O

63 B-AGE
yo O
returns O
for O
med O
refills O
after O
1 O
1/2 O
yr O
hiatus. O
Has O
not O
kept O
up O
with O
rout


In [6]:
print(final_dev_text[:400])

Record O
date: O
2121-05-09 B-DATE

BCH B-HOSPITAL
EMERGENCY O
DEPT O
VISIT O

HOLCOMB,DENNIS B-PATIENT
833-12-06-0 B-MEDICALRECORD
VISIT O
DATE: O
05/09/21 B-DATE

This O
patient O
was O
seen, O
interviewed O
and O
examined O
by O
myself O
as O
well O

as O
Dr. O
Petty B-DOCTOR
whose O
I O
have O
reviewed O
and O
whose O
findings O
I O
have O

confirmed. O

HISTORY O
OF O
PRESENTING O
COMPLAINT: 


In [7]:
print(test_text[:400])

Record O
date: O
2069-04-07 B-DATE

Mr. O
Villegas B-PATIENT
is O
seen O
today. O
I O
have O
not O
seen O
him O
since O
November B-DATE
. O

About O
three O
weeks O
ago O
he O
stopped O
his O
Prednisone O
on O
his O
own O
because O

he O
was O
gaining O
weight. O
He O
does O
feel O
that O
his O
shoulders O
are O

definitely O
improved. O
It O
is O
unclear O
what O
he O
is O
actually O
taking, O
bu


In [8]:
labels = {}
for s in final_train_text, final_dev_text, test_text:
    for line in s.split('\n'):
        if line == '': continue
        label = line.split()[-1]
        assert label == 'O' or label.startswith('B-') or label.startswith('I-'), "label wrong! %s" % label
        if label not in labels: labels[label] = 1
        else: labels[label] += 1


In [9]:
labels

{'O': 777074,
 'B-DATE': 12451,
 'B-AGE': 1995,
 'B-DOCTOR': 4786,
 'B-HOSPITAL': 2306,
 'B-PATIENT': 2192,
 'B-MEDICALRECORD': 1032,
 'I-DOCTOR': 3471,
 'I-PATIENT': 1191,
 'B-IDNUM': 456,
 'I-HOSPITAL': 1827,
 'B-STREET': 349,
 'I-STREET': 711,
 'B-CITY': 651,
 'B-STATE': 501,
 'B-ZIP': 349,
 'B-PHONE': 524,
 'I-DATE': 1373,
 'B-ORGANIZATION': 205,
 'I-ORGANIZATION': 168,
 'B-PROFESSION': 413,
 'I-PROFESSION': 346,
 'I-PHONE': 100,
 'B-USERNAME': 356,
 'I-AGE': 10,
 'I-CITY': 171,
 'I-IDNUM': 30,
 'B-COUNTRY': 183,
 'I-MEDICALRECORD': 47,
 'I-COUNTRY': 21,
 'B-BIOID': 1,
 'B-LOCATION-OTHER': 17,
 'I-LOCATION-OTHER': 15,
 'B-EMAIL': 5,
 'B-FAX': 10,
 'I-FAX': 2,
 'B-DEVICE': 15,
 'B-HEALTHPLAN': 1,
 'I-HEALTHPLAN': 1,
 'I-STATE': 18,
 'B-URL': 2,
 'I-URL': 4,
 'I-DEVICE': 2}

In [10]:
f = open("../../../processed/label.txt", "w")
for label in reversed(sorted(labels)):
    f.write(label+"\n")
f.close()


with open('../../../processed/train.txt', mode='w') as f:
    f.write(final_train_text)
with open('../../../processed/dev.txt', mode='w') as f:
    f.write(final_dev_text)
with open('../../../processed/test.txt', mode='w') as f:
    f.write(test_text)