In [1]:
import os
import json
import pandas as pd
import spacy
import numpy as np
from tqdm.notebook import tqdm
from spacy.tokens import DocBin, Doc, Span, SpanGroup
from spacy.vocab import Vocab
from spacy.matcher import Matcher
from spacy.scorer import Scorer, PRFScore
from spacy.training import Example
from sklearn.metrics import f1_score, precision_score, recall_score
import matplotlib.pyplot as plt

In [2]:
def load_and_run_regex(base_path, file):
    # Load pattern matcher and test data
    folder = os.path.dirname(os.path.dirname(base_path))
    regex_file = folder + '/regex/' + file + '.txt'
    doc_file = base_path + file + '/test.spacy'
    
    nlp = spacy.blank('nl')    
    doc_bin = DocBin().from_disk(doc_file)
    docs = list(doc_bin.get_docs(nlp.vocab))
    matcher = Matcher(nlp.vocab)

    # Read regex patterns
    patterns = []
    with open(regex_file, encoding = "ISO-8859-1") as f:
        for l in f.readlines():
            patterns.append(json.loads(l)) 

    # And add them to the matcher
    for pat in patterns:
        matcher.add(pat['label'], [pat['pattern']])

    # Run inference
    examples = []
    for doc in docs:
        prediction = nlp(doc.text)
        matches = matcher(prediction)
        spans = []
        for match in matches:
            match_span = prediction[match[1]:match[2]]
            match_span.label_ = prediction.vocab.strings[match[0]]
            spans.append(match_span)
        prediction.spans['sc'] = spans
        example = Example(prediction, doc)
        examples.append(example)
    
    return examples

def _tuple_overlap(tL, tR):
    # tL: tuple(begin, end)
    # tR: tuple(begin, end)
    tLrange = set(range(len(tL)))
    tRrange = set(range(len(tR)))
               
    InterSection = len(tLrange.intersection(tRrange))
    Union = len(tLrange.union(tRrange))

    return InterSection/Union if Union>0 else np.nan


def span_overlap_counter(examples):
    overlaplist = []
    for case in examples:
        span_set_spancat = set()
        span_set_labeled = set()
        
        for span in case.reference.spans['sc']:
            span_set_labeled.add(span)
            
        for span in case.predicted.spans['sc']:
            span_set_spancat.add(span)

        jaccard_indices = []
        for span_l in span_set_labeled:
            _jaccard_indices = []
            for span_s in span_set_spancat:
                _jaccard_indices.append(_tuple_overlap(span_l, span_s))
            try:
                max_ = max(_jaccard_indices)
            except:
                max_ = np.nan
            jaccard_indices.append(max_)
        overlaplist.append(jaccard_indices)
    return overlaplist


def span_overlap_counter_reverse(examples):
    overlaplist = []
    for case in examples:
        span_set_spancat = set()
        span_set_labeled = set()
        
        for span in case.reference.spans['sc']:
            span_set_labeled.add(span)
            
        for span in case.predicted.spans['sc']:
            span_set_spancat.add(span)

        jaccard_indices = []
        for span_s in span_set_spancat:
            _jaccard_indices = []
            for span_l in span_set_labeled:
                _jaccard_indices.append(_tuple_overlap(span_l, span_s))
            try:
                max_ = max(_jaccard_indices)
            except:
                max_ = np.nan
            jaccard_indices.append(max_)
        overlaplist.append(jaccard_indices)
    return overlaplist


def regex_scorer(examples, span_key, **cfg):
    score = PRFScore()
    score_gold = PRFScore() 
    score_per_type = dict()
    score_gold_per_type = dict()
    total_pred = 0
    total_pred_gold = 0
    fp_count = 0
    fp_cols = ['mild', 'moderate', 'severe', 'present']
    for example in examples:
        pred_doc = example.predicted
        gold_doc = example.reference

        # Find all labels
        labels = set([k.label_ for k in gold_doc.spans[span_key]])
        
        # If labeled, find all labels in pred
        if len(pred_doc.spans[span_key])>0:
            labels |= set([k.label_ for k in pred_doc.spans[span_key]])
            
        # Set up all labels for per type scoring and prepare gold per type
        gold_per_type: Dict[str, Set] = {label: set() for label in labels}
        
        for label in labels:
            if label not in score_per_type:
                score_per_type[label] = PRFScore()
                score_gold_per_type[label] = PRFScore()
                
        # Find all predidate labels
        gold_spans = set()
        pred_spans = set()
        pred_spans_gold = set()
        for span in gold_doc.spans[span_key]:
            gold_span: Tuple
            gold_span = (span.label_, span.start, span.end - 1)
            gold_spans.add(gold_span)
            gold_per_type[span.label_].add(gold_span)                
        pred_per_type: Dict[str, Set] = {label: set() for label in labels}        
        pred_gold_per_type: Dict[str, Set] = {label: set() for label in labels} 
        if len(pred_doc.spans[span_key])>0:
            for span in pred_doc.spans[span_key]:
                total_pred += 1
                pred_span: Tuple
                pred_span = (span.label_, span.start, span.end - 1)
                pred_spans.add(pred_span)
                pred_per_type[span.label_].add(pred_span)  

                span.label_ = span.label_.replace('not_present', 'normal')
                if any(fp_col in span.label_ for fp_col in fp_cols):
                    if not any(fp_col in sp.label_.replace('not_present', 'normal') for sp in gold_doc.spans[span_key] for fp_col in fp_cols):
                        fp_count += 1
            
            for span in example.get_aligned_spans_x2y(pred_doc.spans[span_key]):
                total_pred_gold += 1
                pred_span_gold: Tuple
                pred_span_gold = (span.label_, span.start, span.end - 1)
                pred_spans_gold.add(pred_span_gold)  
                pred_gold_per_type[span.label_].add(pred_span)    
        # Scores per label
        for k, v in score_per_type.items():
            if k in pred_per_type:
                v.score_set(pred_per_type[k], gold_per_type[k])
        for k, v in score_gold_per_type.items():
            if k in pred_gold_per_type:
                v.score_set(pred_gold_per_type[k], gold_per_type[k])
        # Score for all labels
        score_gold.score_set(pred_spans_gold, gold_spans)
        score.score_set(pred_spans, gold_spans)
    # Assemble final result
    final_scores: Dict[str, Any] = {
        "pgold": None,
        "rgold": None,
        "fgold": None,
        "p": None,
        "r": None,
        "f": None,
    }
    final_scores["score_per_type"] = None
    final_scores["score_gold_per_type"] = None
    if len(score) > 0:
        final_scores["pgold"] = score_gold.precision
        final_scores["rgold"] = score_gold.recall
        final_scores["fgold"] = score_gold.fscore
        final_scores["fpgold"] = score_gold.fp / len(examples)
        final_scores["p"] = score.precision
        final_scores["r"] = score.recall
        final_scores["f"] = score.fscore
        final_scores["fp"] = score.fp / len(examples)    
        final_scores["fp_manual"] = fp_count / total_pred
        final_scores["score_per_type"] = {
            k: v.to_dict() for k, v in score_per_type.items()
        }      
        final_scores["score_gold_per_type"] = {
            k: v.to_dict() for k, v in score_per_type.items()
        }    
    return final_scores

In [16]:
base_path = '/home/jovyan/work/projects/echo_text_mining/spancat_models/reduced_labels/spacy_data/'

df = pd.DataFrame(columns=['entity', 'p_w', 'p_m', 'pgold_w', 'pgold_m', 'r_w', 'r_m', 'rgold_w', 'rgold_m', 'f_w', 'f_m', 'fgold_w', 'fgold_m',
                           'jaccard', 'jaccard_rev', 'fp', 'fpgold', 'fp_manual'])

files = [x for x in os.listdir(base_path) if not x.startswith('.')]

# Iterate over all abnormalities
for file in tqdm(files):
    if file == 'merged_labels':
        continue
    data = {'entity': file}

    # Load model and data, run inference
    examples = load_and_run_regex(base_path, file)
    
    # Assess PRF for all spans
    scores = regex_scorer(examples, 'sc')
    
    for metric in ['p', 'r', 'f']:    
        # Table 3
        data[f'{metric}gold_w'] = scores[f'{metric}gold'] # Weighted PRF for gold spans
        data[f'{metric}gold_m'] = np.mean([v[f'{metric}'] for _, v in scores['score_gold_per_type'].items()]) # Macro PRF for gold spans        
        
        # Table 4
        data[f'{metric}_w'] = scores[f'{metric}'] # Weighted PRF (identical to PRF reported in meta.json)
        data[f'{metric}_m'] = np.mean([v[f'{metric}'] for _, v in scores['score_per_type'].items()]) # Macro PRF

    # Assess Jaccard index (Table 5)
    OverlapJaccardIndices = span_overlap_counter(examples)
    OverlapJaccardIndicesRev = span_overlap_counter_reverse(examples)
    data[f'jaccard'] = round(np.nanmean([_v for v in OverlapJaccardIndices for _v in v]), 2)
    data[f'jaccard_rev'] = round(np.nanmean([_v for v in OverlapJaccardIndicesRev for _v in v]), 2)

    # Table x - False positive predictions
    data['fp'] = scores['fp']
    data['fpgold'] = scores['fpgold']
    data['fp_manual'] = scores['fp_manual']
    
    # Add data row
    df.loc[len(df)] = data

  0%|          | 0/12 [00:00<?, ?it/s]

In [17]:
df.sort_values('entity', inplace=True)
df.reset_index(drop=True, inplace=True)

In [18]:
df

Unnamed: 0,entity,p_w,p_m,pgold_w,pgold_m,r_w,r_m,rgold_w,rgold_m,f_w,f_m,fgold_w,fgold_m,jaccard,jaccard_rev,fp,fpgold,fp_manual
0,aortic_regurgitation,0.944812,0.924106,0.955814,0.924106,0.897275,0.866919,0.861635,0.866919,0.92043,0.893913,0.906284,0.893913,0.99,0.99,0.022936,0.017431,0.002208
1,aortic_stenosis,0.825444,0.769586,0.854037,0.769586,0.840361,0.745696,0.828313,0.745696,0.832836,0.75483,0.840979,0.75483,0.96,0.95,0.060762,0.048404,0.0
2,diastolic_dysfunction,0.607639,0.608597,0.615942,0.608597,0.599315,0.592657,0.582192,0.592657,0.603448,0.60039,0.598592,0.60039,0.85,0.85,0.116375,0.109166,0.041667
3,lv_dil,0.697495,0.855675,0.850602,0.855675,0.809843,0.85831,0.789709,0.85831,0.749482,0.854962,0.819026,0.854962,0.96,0.87,0.161689,0.063852,0.007707
4,lv_syst_func,0.331887,0.30922,0.371968,0.30922,0.155015,0.18582,0.139818,0.18582,0.211326,0.222092,0.20324,0.222092,0.84,0.82,0.317199,0.239959,0.023861
5,mitral_regurgitation,0.92549,0.918622,0.942675,0.918622,0.914729,0.892784,0.860465,0.892784,0.920078,0.905195,0.899696,0.905195,0.99,0.98,0.039135,0.027806,0.005882
6,pe,0.930818,0.256344,0.930818,0.256344,0.619247,0.186538,0.619247,0.186538,0.743719,0.210844,0.743719,0.210844,0.96,0.97,0.00652,0.00652,0.0
7,rv_dil,0.748663,0.883551,0.797619,0.883551,0.796964,0.833642,0.762808,0.833642,0.772059,0.853835,0.779825,0.853835,0.98,0.93,0.087796,0.063512,0.012478
8,rv_syst_func,0.962567,0.946152,0.966851,0.946152,0.368852,0.325782,0.358607,0.325782,0.533333,0.475168,0.523169,0.475168,0.98,0.99,0.007209,0.006179,0.0
9,tricuspid_regurgitation,0.905149,0.822552,0.91453,0.822552,0.925208,0.842265,0.889197,0.842265,0.915068,0.832219,0.901685,0.832219,0.99,0.98,0.036045,0.030896,0.00542


In [19]:
for metric in ['f', 'r', 'p']:
    df[metric] = df.apply(lambda x: str(round(x[f'{metric}_w'], 2)) + ' (' + str(round(x[f'{metric}_m'], 2)) + ')', axis=1)
    df[f'{metric}gold'] = df.apply(lambda x: str(round(x[f'{metric}gold_w'], 2)) + ' (' + str(round(x[f'{metric}gold_m'], 2)) + ')', axis=1)

In [20]:
## Table 3 - PRF scores for the exact gold spans
df[['entity', 'fgold', 'rgold', 'pgold']].to_latex('/training/echo/text_mining/output/table3_regex_pipeline_performance_goldspans.tex', index=False)
df[['entity', 'fgold', 'rgold', 'pgold']]

Unnamed: 0,entity,fgold,rgold,pgold
0,aortic_regurgitation,0.91 (0.89),0.86 (0.87),0.96 (0.92)
1,aortic_stenosis,0.84 (0.75),0.83 (0.75),0.85 (0.77)
2,diastolic_dysfunction,0.6 (0.6),0.58 (0.59),0.62 (0.61)
3,lv_dil,0.82 (0.85),0.79 (0.86),0.85 (0.86)
4,lv_syst_func,0.2 (0.22),0.14 (0.19),0.37 (0.31)
5,mitral_regurgitation,0.9 (0.91),0.86 (0.89),0.94 (0.92)
6,pe,0.74 (0.21),0.62 (0.19),0.93 (0.26)
7,rv_dil,0.78 (0.85),0.76 (0.83),0.8 (0.88)
8,rv_syst_func,0.52 (0.48),0.36 (0.33),0.97 (0.95)
9,tricuspid_regurgitation,0.9 (0.83),0.89 (0.84),0.91 (0.82)


In [21]:
## Table 4 - PRF scores for all predicted spans
df[['entity', 'f', 'r', 'p']].to_latex('/training/echo/text_mining/output/table4_regex_pipeline_performance.tex', index=False)
df[['entity', 'f', 'r', 'p']]

Unnamed: 0,entity,f,r,p
0,aortic_regurgitation,0.92 (0.89),0.9 (0.87),0.94 (0.92)
1,aortic_stenosis,0.83 (0.75),0.84 (0.75),0.83 (0.77)
2,diastolic_dysfunction,0.6 (0.6),0.6 (0.59),0.61 (0.61)
3,lv_dil,0.75 (0.85),0.81 (0.86),0.7 (0.86)
4,lv_syst_func,0.21 (0.22),0.16 (0.19),0.33 (0.31)
5,mitral_regurgitation,0.92 (0.91),0.91 (0.89),0.93 (0.92)
6,pe,0.74 (0.21),0.62 (0.19),0.93 (0.26)
7,rv_dil,0.77 (0.85),0.8 (0.83),0.75 (0.88)
8,rv_syst_func,0.53 (0.48),0.37 (0.33),0.96 (0.95)
9,tricuspid_regurgitation,0.92 (0.83),0.93 (0.84),0.91 (0.82)


In [22]:
## Table 5 - Jaccard similarity for all gold spans and for all predicted spans 
df[['entity', 'jaccard', 'jaccard_rev']].to_latex('/training/echo/text_mining/output/table5_regex_jaccard_labeltospan.tex', index=False)
df[['entity', 'jaccard', 'jaccard_rev']]

Unnamed: 0,entity,jaccard,jaccard_rev
0,aortic_regurgitation,0.99,0.99
1,aortic_stenosis,0.96,0.95
2,diastolic_dysfunction,0.85,0.85
3,lv_dil,0.96,0.87
4,lv_syst_func,0.84,0.82
5,mitral_regurgitation,0.99,0.98
6,pe,0.96,0.97
7,rv_dil,0.98,0.93
8,rv_syst_func,0.98,0.99
9,tricuspid_regurgitation,0.99,0.98


In [23]:
# Table x - False positives
df[['entity', 'fp_manual']].round(3).to_latex('/home/jovyan/work/projects/echo_text_mining/output/tablex_regex_fp.tex', index=False)
df[['entity', 'fp_manual']].round(3)

Unnamed: 0,entity,fp_manual
0,aortic_regurgitation,0.002
1,aortic_stenosis,0.0
2,diastolic_dysfunction,0.042
3,lv_dil,0.008
4,lv_syst_func,0.024
5,mitral_regurgitation,0.006
6,pe,0.0
7,rv_dil,0.012
8,rv_syst_func,0.0
9,tricuspid_regurgitation,0.005
