In [1]:
import itertools
import os
import re
import xml.etree.ElementTree as ET

import numpy as np
import numpy as npx  # NOTE(review): looks like a typo for `np`; kept in case something else uses `npx`
import pandas as pd
import yaml

In [2]:
class XmlRecord:
    """
    One parsed XML case file plus the features derived from it.

    The methods rely on these notebook-level names being defined before they run:
    ``pdfToMatlabToMLTranslation``, ``mlLabels``, ``semiologyRegexDict``,
    ``hippocampalSclerosisRegex``, ``stripClassificationStep`` and ``regexYamlMatch``.
    """

    def __init__(self, fileLocation, surgicalPatients=None, maxHospitalNumbers=None):
        """
        Initialises an object for an XML file, storing the information extracted from the XML.

        If surgicalPatients and maxHospitalNumbers are given, look up the record in the
        surgical DB and add operation details to the object (see lookUpSurgicalResult).

        Raises:
            OSError / xml.etree.ElementTree.ParseError: from ET.parse for an unreadable
                or malformed file.
            AttributeError: the <demographics>/<identifier> or <events>/<count>
                elements are missing.
        """
        self._fileLocation = fileLocation
        self._caseElement = ET.parse(self._fileLocation).getroot()

        self._meetingElement = self._caseElement.find('meeting')
        self._demographicsElement = self._caseElement.find('demographics')
        self._eventsElement = self._caseElement.find('events')
        MRN = self._demographicsElement.find('identifier').text

        self._dataDict = {
            "MRN1": MRN,
            "MRN2": None,
            "MRN3": None,
            "MRN4": None,
            "IDP": None,  # CHECK THIS! (not filled in by any code visible here)
        }

        self._eventCount = int(self._eventsElement.find('count').text)
        self._rawClassificationArray = self.extractRawClassificationArray()

        self.lookUpSurgicalResult(surgicalPatients, maxHospitalNumbers)
        self._hippocampalSclerosis = self.extractHippocampalSclerosis(self._fileLocation)

    def extractRawClassificationArray(self):
        """
        Return the classification text of each event as a list of lists of strings.

        Takes the first ``self._eventCount`` children of the <events> element. For each
        one, reads its <classificationLength> value n and collects the text of its
        first n child elements.
        """
        # NOTE(review): children are taken in document order, so this assumes the
        # <count> / <classificationLength> elements come after the real entries.
        # Confirm against the XML schema.
        rawClassificationArray = []
        for _, event in zip(range(self._eventCount), self._eventsElement):
            eventClassificationLength = int(event.find('classificationLength').text)
            rawClassificationArray.append(
                [classification.text
                 for _, classification in zip(range(eventClassificationLength), event)]
            )
        return rawClassificationArray

    def RawClassificationToChecklists(self, rawClassificationArray=None):
        """
        Map every classification step onto the PDF checklist and ML feature labels.

        Each non-empty step is checked, in order:
          1. case-insensitive equality with any PDF label, using either the raw step
             or the step after stripClassificationStep;
          2. if nothing matched, the semiology regexes (regexYamlMatch). Every
             returned key is set in the ML feature dict;
          3. whether its stripped text contains 'as above'.
        Steps matched by none of these are collected as not found.

        Args:
            rawClassificationArray: list of per-event lists of step strings. Defaults
                to the array extracted from this file when None.

        Returns:
            dict with these keys:
              - 'pdfFeatureDict', 'mlFeatureDict': label -> bool;
              - 'stepsFound', 'stepsNotFound', 'asAboveFound': lists of step strings;
              - 'stepsFoundByRegex': list of [keysFound, step] pairs.
            Each value is also stored on the instance as ``self._<key>``.
        """
        if rawClassificationArray is None:
            rawClassificationArray = self._rawClassificationArray

        translation = pdfToMatlabToMLTranslation
        # Map each 'pdf' label to the first 'ml' label listed for it. Built once here
        # instead of filtering the DataFrame on every match; the result is the same as
        # the former `.values[0]` lookup.
        pdfToMl = {}
        for pdfLabel, mlLabel in zip(translation['pdf'], translation['ml']):
            pdfToMl.setdefault(pdfLabel, mlLabel)
        lowerPdfLabels = [(pdfLabel, pdfLabel.lower()) for pdfLabel in translation['pdf']]

        # Every label starts as False. dict.fromkeys has no size cap, unlike the former
        # zip(labels, np.zeros(1000)), which silently left out labels beyond the 1000th.
        pdfFeatureDict = dict.fromkeys(translation['pdf'], False)
        mlFeatureDict = dict.fromkeys(mlLabels, False)

        stepsFound = []
        stepsNotFound = []
        stepsFoundByRegex = []
        asAboveFound = []

        for classificationStep in itertools.chain.from_iterable(rawClassificationArray):
            if not classificationStep:  # skip empty / missing (None) step text
                continue
            found = False
            lowerStep = classificationStep.lower()
            strippedStep = stripClassificationStep(classificationStep)

            # 1. Direct label match. There is no early exit, so every matching label is recorded.
            for pdfLabel, lowerLabel in lowerPdfLabels:
                if lowerLabel == lowerStep or lowerLabel == strippedStep:
                    pdfFeatureDict[pdfLabel] = True
                    mlFeatureDict[pdfToMl[pdfLabel]] = True
                    found = True
                    stepsFound.append(classificationStep)

            # 2. Fall back to the semiology regexes.
            if not found:
                keysFound = regexYamlMatch(semiologyRegexDict, classificationStep)
                if keysFound:
                    stepsFoundByRegex.append([keysFound, classificationStep])
                    for keyFound in keysFound:
                        mlFeatureDict[keyFound] = True
                    found = True

            # 3. 'as above' steps (presumably referring back to an earlier step) are kept
            #    in their own list. They used to go into stepsFound, which left
            #    asAboveFound permanently empty.
            if 'as above' in strippedStep:
                found = True
                asAboveFound.append(classificationStep)

            if not found:
                stepsNotFound.append(classificationStep)

        self._pdfFeatureDict = pdfFeatureDict
        self._mlFeatureDict = mlFeatureDict
        self._stepsFound = stepsFound
        self._stepsNotFound = stepsNotFound
        self._stepsFoundByRegex = stepsFoundByRegex
        self._asAboveFound = asAboveFound

        return {'pdfFeatureDict': pdfFeatureDict,
                'mlFeatureDict': mlFeatureDict,
                'stepsFound': stepsFound,
                'stepsNotFound': stepsNotFound,
                'stepsFoundByRegex': stepsFoundByRegex,
                'asAboveFound': asAboveFound}

    def extractHippocampalSclerosis(self, fileLocation):
        """
        Return True if any pattern in the global ``hippocampalSclerosisRegex`` list
        occurs anywhere in the raw text of ``fileLocation``, else False.
        """
        with open(fileLocation) as file:
            filetext = file.read()
        return any(re.search(regex, filetext) for regex in hippocampalSclerosisRegex)

    def lookUpSurgicalResult(self, surgicalPatients, maxHospitalNumbers):
        '''
        Look up this record's MRN in the surgical patient table and store the result in _dataDict.

        surgicalPatients must hold the hospital numbers in integer-labelled columns
        0 .. maxHospitalNumbers-1, as built by extractSurgicalPatientsDf. It must also
        have the columns 'boolean', 'OP Type', 'OP Date', 'Side' and 'IC'. A hospital
        number matches if it equals MRN1, or MRN1 with a leading '0'.

        Sets these _dataDict entries:
          - 'Had surgery': True or False;
          - 'Entirely Seizure-Free': taken from the 'boolean' column;
          - 'OP Type', 'OP Date', 'Side', 'IC'.
        All entries except 'Had surgery' are None when no record is found. When either
        argument is None, nothing is looked up and none of these keys are added.
        Returns None.

        Raises:
            ValueError: one column has a single match after an earlier column already
                produced a record.
        '''
        if surgicalPatients is None or maxHospitalNumbers is None:
            # The constructor documents the lookup as optional; range(None) used to crash here.
            return

        mrn = self._dataDict['MRN1']
        candidates = [] if mrn is None else [mrn, '0' + mrn]

        foundRecord = False
        recordNumber = None
        for columnNumber in range(maxHospitalNumbers):
            # Positional row numbers of matches in this column.
            columnRecords = np.flatnonzero(
                surgicalPatients[columnNumber].isin(candidates).to_numpy()
            )
            if len(columnRecords) == 1 and not foundRecord:
                recordNumber = columnRecords[0]
                foundRecord = True
            elif len(columnRecords) == 1 and foundRecord:
                raise ValueError('Multiple MRNs matched')
            elif len(columnRecords) > 1 and not foundRecord:
                # If multiple records are found, choose the last matching row (assumed to be the latest).
                print('Found multiple records in surgical DB matching MRN', self._dataDict['MRN1'], '. Choosing latest record.')
                recordNumber = columnRecords[-1]
                foundRecord = True
            # NOTE(review): several matches in a later column, after a record was already
            # found, are silently ignored, whereas a single later match raises.
            # Confirm this asymmetry is intended.

        surgicalFields = ('OP Type', 'OP Date', 'Side', 'IC')
        self._dataDict['Had surgery'] = foundRecord
        if foundRecord:
            # iloc: recordNumber is a row position, which need not equal the index label.
            record = surgicalPatients.iloc[recordNumber]
            self._dataDict['Entirely Seizure-Free'] = record['boolean']
            for field in surgicalFields:
                self._dataDict[field] = record[field]
        else:
            self._dataDict['Entirely Seizure-Free'] = None
            for field in surgicalFields:
                self._dataDict[field] = None

In [3]:
def defineRemovalLists():
    """
    Build the ordered list of substrings that stripClassificationStep deletes.

    Each "special" and "basic" token is emitted as-is, then wrapped in (), {} and []
    in that order. The "bracket-only" tokens are emitted only in their three wrapped
    forms. The order is fixed because the removals are applied one after another, so
    earlier deletions can change what later entries match.

    Returns:
        list of str: 140 entries, including some duplicates.
    """
    # Combinations of the h / vt / vtr markers with different separators.
    # NOTE: the first two rows are identical; kept so the returned list is unchanged.
    specialTokens = [
        'h/vt', 'h/vtr', 'vt/h', 'vtr/h',
        'h/vt', 'h/vtr', 'vt/h', 'vtr/h',
        'h /vt', 'h /vtr', 'vt /h', 'vtr /h',
        'h;vt', 'h;vtr', 'vt;h', 'vtr;h',
        'h; vt', 'h; vtr', 'vt; h', 'vtr; h',
        'h,vt', 'h,vtr', 'vt,h', 'vtr,h',
        'h, vt', 'h, vtr', 'vt, h', 'vtr, h',
        '',  # contributes the empty-bracket variants '()', '{}', '[]'
    ]
    basicTokens = ['sz', ' >', '>']
    # Removed only when bracketed (deleting a bare 'h' or 'r' would presumably
    # mangle ordinary words).
    bracketOnlyTokens = ['h', 'vt', 'vtr', 'r']
    bracketPairs = (('(', ')'), ('{', '}'), ('[', ']'))

    def bracketed(token):
        return [opening + token + closing for opening, closing in bracketPairs]

    removals = []
    for token in specialTokens + basicTokens:
        removals += [token] + bracketed(token)
    for token in bracketOnlyTokens:
        removals += bracketed(token)
    return removals

In [4]:
def stripClassificationStep(stepText, removalList=None):
    """
    Normalise a classification step for comparison with the PDF checklist labels.

    Lower-cases the text and deletes every entry of ``removalList``. The deletions are
    applied in list order, so earlier ones can affect later ones. Surrounding
    whitespace is then trimmed.

    Args:
        stepText: raw classification step text.
        removalList: substrings to delete. Defaults to the global ``fullRemovalList``
            built by ``defineRemovalLists()``.

    Returns:
        The normalised string.
    """
    if removalList is None:
        removalList = fullRemovalList
    stepText = stepText.lower()
    for stringToRemove in removalList:
        stepText = stepText.replace(stringToRemove, '')
    # Trim both ends, not only the right. Deleting a leading marker such as '(h)'
    # leaves a leading space that would otherwise block an exact label match.
    return stepText.strip()

In [5]:
def defineDataLabels():
    """
    Return a new list of the non-semiology column names for the results DataFrame, in order.

    'ILAE 1 at 1yr' is listed but is not filled in by any code visible here.
    """
    identifierColumns = ['MRN1', 'MRN2', 'MRN3', 'MRN4', 'IDP']
    surgeryColumns = [
        'Had surgery',
        'OP Date',
        'OP Type',
        'Side',
        'IC',
        'Entirely Seizure-Free',
        'ILAE 1 at 1yr',
    ]
    return identifierColumns + surgeryColumns

In [6]:
def regexYamlMatch(semiologyRegexDict, textToSearch):
    """
    Return the keys of ``semiologyRegexDict`` whose regex list matches ``textToSearch``.

    Each value is a list of patterns. Every pattern is tried with ``re.match``, which
    is anchored at the start of the text, first against the lower-cased text and then
    against the original text. A key is reported once if any of its patterns matches.
    Keys are returned in the dict's iteration order.
    """
    def _anyPatternMatches(patterns):
        return any(
            re.match(pattern, textToSearch.lower()) or re.match(pattern, textToSearch)
            for pattern in patterns
        )

    return [key for key, patterns in semiologyRegexDict.items() if _anyPatternMatches(patterns)]

In [7]:
def flattenDictionary(dictionary):
    """
    Collapse a nested dict into a single level keyed by the innermost keys.

    Parent keys are discarded. If the same leaf key appears more than once, the last
    value wins, but the key keeps the position of its first appearance. A non-dict
    argument yields ``{None: argument}``.
    """
    def _leaves(node, key=None):
        if not isinstance(node, dict):
            yield key, node
            return
        for childKey, child in node.items():
            yield from _leaves(child, childKey)

    return dict(_leaves(dictionary))

In [8]:
def extractSurgicalPatientsDf(surgicalPatientRecordPath):
    """
    Load the surgical patient table and split each patient's hospital numbers into columns.

    The 'Hospital No' column may hold several comma-separated numbers. They are split
    into integer-labelled columns 0 .. maxHospitalNumbers-1, with whitespace trimmed
    and missing slots left as NA. These columns are placed in front of the original
    columns. This is the layout XmlRecord.lookUpSurgicalResult indexes into.

    Args:
        surgicalPatientRecordPath: path (or anything ``pd.read_excel`` accepts) to the
            spreadsheet, or an already-loaded DataFrame with a 'Hospital No' column.
            A DataFrame argument is not modified.

    Returns:
        (surgicalPatients, maxHospitalNumbers)
    """
    if isinstance(surgicalPatientRecordPath, pd.DataFrame):
        surgicalPatients = surgicalPatientRecordPath
    else:
        surgicalPatients = pd.read_excel(surgicalPatientRecordPath)  # patients who have had surgery

    # Split on ',' and trim each part, so that '1, 2' and '1,2' both give ['1', '2'].
    # Previously the column count came from a ', ' split while the stored columns came
    # from a ',' split. That could undercount the columns and left leading spaces that
    # never matched an MRN.
    hospitalNumberColumns = surgicalPatients['Hospital No'].str.split(',', expand=True)
    hospitalNumberColumns = hospitalNumberColumns.apply(lambda column: column.str.strip())
    maxHospitalNumbers = hospitalNumberColumns.shape[1]  # maximum hospital numbers per patient

    surgicalPatients = pd.concat([hospitalNumberColumns, surgicalPatients], axis=1)
    return surgicalPatients, maxHospitalNumbers

In [9]:
def log_progress(sequence, every=None, size=None, name='Items'):
    """
    Yield the items of ``sequence`` while showing an ipywidgets progress bar.

    Args:
        sequence: any iterable.
        every: refresh the display every ``every`` items. If omitted and the size is
            known, it is chosen for roughly 200 refreshes (every item when size <= 200).
        size: total number of items. Defaults to ``len(sequence)`` when available.
        name: label shown before the counter.

    Raises:
        ValueError: the size is unknown (no ``len`` and no ``size``) and ``every`` was
            not given. Because this is a generator, the error appears on the first
            ``next()``.

    The bar turns 'danger' if iteration raises or the generator is closed early. It
    turns 'success' once the sequence is exhausted.
    """
    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            is_iterator = True
    if size is not None:
        if every is None:
            every = 1 if size <= 200 else int(size / 200)  # every 0.5%
    elif every is None:
        # An explicit exception instead of `assert`, which is stripped under `python -O`.
        raise ValueError('sequence is iterator, set every')

    # Imported here, as before, so only callers of log_progress need ipywidgets/IPython.
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    if is_iterator:
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{name}: {index} / ?'.format(
                        name=name,
                        index=index
                    )
                else:
                    progress.value = index
                    label.value = u'{name}: {index} / {size}'.format(
                        name=name,
                        index=index,
                        size=size
                    )
            yield record
    except BaseException:
        # Deliberately broad: GeneratorExit and KeyboardInterrupt also mark the bar
        # before propagating.
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(
            name=name,
            index=str(index or '?')
        )

In [10]:
# Global file paths: set these before running the cells below.

# Folder listed for *.xml case files, e.g. '/Volumes/Encrypted/test_XMLs'.
xmlFolderDirectory = ''

# Spreadsheet of surgical patients, read with pd.read_excel. Needs a 'Hospital No' column.
surgicalPatientRecordPath = ''

# CSV with 'pdf' and 'ml' columns mapping checklist labels to ML feature labels.
pdfToMatlabToMLTranslationPath = ''

# Header-less CSV whose first column lists the ML feature labels.
mlLabelsPath = ''

# YAML file with a nested 'semiology' regex mapping and a 'Hippocampal Sclerosis' regex list.
yamlFilePath = ''

In [None]:
#Global variables
# Translation table between PDF checklist labels ('pdf' column) and ML feature labels ('ml' column).
pdfToMatlabToMLTranslation = pd.read_csv(pdfToMatlabToMLTranslationPath)
# ML feature labels: the first column of a header-less CSV, read as strings.
mlLabels = pd.read_csv(mlLabelsPath, header=None, dtype=str)[0].values
with open(yamlFilePath, encoding='utf-8') as f:
    # safe_load: a bare yaml.load() without a Loader is unsafe on untrusted files
    # (older PyYAML) and raises TypeError on PyYAML >= 6.
    yamlFile = yaml.safe_load(f)
semiologyRegexDict = flattenDictionary(yamlFile['semiology'])
hippocampalSclerosisRegex = yamlFile['Hippocampal Sclerosis']
fullRemovalList = defineRemovalLists()
dataLabels = defineDataLabels()
surgicalPatients, maxHospitalNumbers = extractSurgicalPatientsDf(surgicalPatientRecordPath)

In [None]:
# Parse every XML case file and accumulate the classification-step statistics.
xmlObjects = []
stepsFound = []
stepsNotFound = []
stepsFoundByRegex = []  # flat list of [keysFound, step] pairs, one per regex-matched step
asAboveFound = []

# Sorted so the row order of the results is reproducible (os.listdir order is arbitrary).
for filename in log_progress(sorted(os.listdir(xmlFolderDirectory))):
    if not filename.endswith(".xml"):
        continue
    try:
        newXmlRecord = XmlRecord(os.path.join(xmlFolderDirectory, filename),
                                 surgicalPatients=surgicalPatients,
                                 maxHospitalNumbers=maxHospitalNumbers)
        xmlRecordDict = newXmlRecord.RawClassificationToChecklists()
    except Exception as error:
        # Best-effort batch: report the cause and skip this file. Unlike a bare
        # `except:`, this still lets KeyboardInterrupt stop the run.
        print('Error in processing file: ', filename, repr(error))
        continue

    xmlObjects.append(newXmlRecord)
    stepsFound += xmlRecordDict['stepsFound']
    stepsNotFound += xmlRecordDict['stepsNotFound']
    # Extend rather than append, so each entry is one step, like the other lists.
    # Appending made the regex summary count files instead of steps.
    stepsFoundByRegex += xmlRecordDict['stepsFoundByRegex']
    asAboveFound += xmlRecordDict['asAboveFound']

In [None]:
# Summary of how the accumulated classification steps were matched.
# The full lists are in stepsFoundByRegex and stepsNotFound for closer inspection.
directMatchCount = len(stepsFound)
print('Classification steps matched directly: ', directMatchCount)
regexMatchCount = sum(1 for entry in stepsFoundByRegex if entry)
print('Classification steps found by regex: ', regexMatchCount)
unmatchedCount = len(stepsNotFound)
print('Classification steps not matched: ', unmatchedCount)

In [None]:
# One results row per parsed XML file: the _dataDict fields, the ML feature flags,
# and the hippocampal sclerosis flag.
columns = dataLabels + list(mlLabels)
if 'Hippocampal Sclerosis' not in columns:
    # Not produced by defineDataLabels(). Writing to it via df[...] previously raised KeyError.
    columns.append('Hippocampal Sclerosis')

rows = []
for currentXml in xmlObjects:
    row = dict(currentXml._dataDict)
    row.update(currentXml._mlFeatureDict)  # ML keys with no column (e.g. 'delete') are dropped below
    row['Hippocampal Sclerosis'] = currentXml._hippocampalSclerosis
    rows.append(row)

# The DataFrame is built in a single step. The previous chained assignment
# `df[col][index] = value` does not write into df under pandas Copy-on-Write.
# `columns=` selects only the listed keys. dtype=object keeps every column object-typed,
# as the previous empty-then-fill construction did.
df = pd.DataFrame(rows, index=range(len(xmlObjects)), columns=columns, dtype=object)