In [5]:
import os
import json
import pandas as pd
import spacy
import numpy as np
from tqdm.notebook import tqdm
from spacy.tokens import DocBin
from spacy.vocab import Vocab
from spacy.scorer import Scorer
from spacy.training import Example

In [6]:
def _load_and_run_model(base_path, file, model):
    # Load model components and test data
    nlp = spacy.load(base_path + file + '/' + model + '/model-best')
    doc_bin = DocBin().from_disk(base_path + file + '/test.spacy')
    docs = list(doc_bin.get_docs(nlp.vocab))
    scorer = Scorer(nlp)

    # Run inference
    examples = []
    for doc in docs:
        prediction = nlp(doc.text)
        example = Example(prediction, doc)
        examples.append(example)

    # Assess performance from inference
    scores = scorer.score(examples)
    return examples, scores

def _tuple_overlap(tL, tR):
    # tL: tuple(begin, end)
    # tR: tuple(begin, end)
    tLrange = set(range(len(tL)))
    tRrange = set(range(len(tR)))
               
    InterSection = len(tLrange.intersection(tRrange))
    Union = len(tLrange.union(tRrange))

    return InterSection/Union if Union>0 else np.nan


def span_overlap_counter(examples):
    overlaplist = []
    for case in examples:
        span_set_spancat = set()
        span_set_labeled = set()
        
        for span in case.reference.spans['sc']:
            span_set_labeled.add(span)
            
        for span in case.predicted.spans['sc']:
            span_set_spancat.add(span)

        jaccard_indices = []
        for span_l in span_set_labeled:
            _jaccard_indices = []
            for span_s in span_set_spancat:
                _jaccard_indices.append(_tuple_overlap(span_l, span_s))
            try:
                max_ = max(_jaccard_indices)
            except:
                max_ = np.nan
            jaccard_indices.append(max_)
        overlaplist.append(jaccard_indices)
    return overlaplist


def span_overlap_counter_reverse(examples):
    overlaplist = []
    for case in examples:
        span_set_spancat = set()
        span_set_labeled = set()
        
        for span in case.reference.spans['sc']:
            span_set_labeled.add(span)
            
        for span in case.predicted.spans['sc']:
            span_set_spancat.add(span)

        jaccard_indices = []
        for span_s in span_set_spancat:
            _jaccard_indices = []
            for span_l in span_set_labeled:
                _jaccard_indices.append(_tuple_overlap(span_l, span_s))
            try:
                max_ = max(_jaccard_indices)
            except:
                max_ = np.nan
            jaccard_indices.append(max_)
        overlaplist.append(jaccard_indices)
    return overlaplist

In [10]:
base_path = '/training/echo/text_mining/reduced_labels/spacy_data/'

df = pd.DataFrame(columns=['entity', 'p_t_06', 'p_t_08', 'p_t_10', 'p_w_06', 'p_w_08', 'p_w_10', 'p_m_06', 'p_m_08', 'p_m_10', 
                           'r_t_06', 'r_t_08', 'r_t_10', 'r_w_06', 'r_w_08', 'r_w_10', 'r_m_06', 'r_m_08', 'r_m_10',
                           'f_t_06', 'f_t_08', 'f_t_10', 'f_w_06', 'f_w_08', 'f_w_10', 'f_m_06', 'f_m_08', 'f_m_10',
                           'jaccard_06', 'jaccard_08', 'jaccard_10', 'jaccard_rev_06', 'jaccard_rev_08', 'jaccard_rev_10'])

files = [x for x in os.listdir(base_path) if not x.startswith('.')]

# Iterate over all abnormalities
for file in tqdm(files):
    data = {'entity': file}

    # Iterate over model versions
    for model in [x for x in os.listdir(base_path + file) if not x.endswith('.spacy')]:
        
        # Extract model version
        nw = model.split('_')[-1]

        # Load model and data, run inference
        examples, scores = _load_and_run_model(base_path, file, model)

        # Assess PRF
        for metric in ['p', 'r', 'f']:
            # Token-based PRF
            data[f'{metric}_t_{nw}'] = scores[f'token_{metric}']
            
            # Weighted PRF (identical to PRF reported in meta.json)
            data[f'{metric}_w_{nw}'] = scores[f'spans_sc_{metric}']
            
            # Macro PRF
            data[f'{metric}_m_{nw}'] = np.mean([v[f'{metric}'] for _, v in scores['spans_sc_per_type'].items()])    

        # Assess Jaccard index
        OverlapJaccardIndices = span_overlap_counter(examples)
        OverlapJaccardIndicesRev = span_overlap_counter_reverse(examples)
        data[f'jaccard_{nw}'] = np.nanmean([_v for v in OverlapJaccardIndices for _v in v])
        data[f'jaccard_rev_{nw}'] = np.nanmean([_v for v in OverlapJaccardIndicesRev for _v in v])
        
    # Add data row
    df.loc[len(df)] = data

  0%|          | 0/11 [00:00<?, ?it/s]

In [16]:
df['f_w_max'] = np.nanmax(df[['f_w_06', 'f_w_08', 'f_w_10']], axis=1)

df['nw_max'] = np.nanargmax(df[['f_w_06', 'f_w_08', 'f_w_10']], axis=1)
df['nw_max'].replace({0: '06', 1: '08', 2: '10'}, inplace=True)

df['jaccard_max'] = df.apply(lambda x: round(x['jaccard_' + str(x['nw_max'])], 2), axis=1)
df['jaccard_rev_max'] = df.apply(lambda x: round(x['jaccard_rev_' + str(x['nw_max'])], 2), axis=1)

for metric in ['p', 'r', 'f']:
    df[f'{metric}'] = df.apply(lambda x: str(round(x[f'{metric}_w_' + str(x['nw_max'])], 2)) + ' (' + str(round(x[f'{metric}_m_' + str(x['nw_max'])], 2)) + ')', axis=1)

In [17]:
df.sort_values('entity', inplace=True)

In [18]:
df

Unnamed: 0,entity,p_t_06,p_t_08,p_t_10,p_w_06,p_w_08,p_w_10,p_m_06,p_m_08,p_m_10,...,jaccard_rev_06,jaccard_rev_08,jaccard_rev_10,f_w_max,nw_max,jaccard_max,jaccard_rev_max,p,r,f
3,aortic_regurgitation,0.919592,0.919592,0.919592,0.953162,0.969072,0.961832,0.726851,0.487992,0.481824,...,0.984377,0.990736,0.988368,0.888646,6,0.97,0.98,0.95 (0.73),0.83 (0.59),0.89 (0.65)
1,aortic_stenosis,0.918409,0.918409,0.918409,0.889286,0.884211,0.90146,0.808844,0.830746,0.89515,...,0.971602,0.967452,0.970827,0.818182,8,0.95,0.97,0.88 (0.83),0.76 (0.66),0.82 (0.73)
10,diastolic_dysfunction,0.918409,0.918409,0.918409,0.92549,0.906015,0.926923,0.884612,0.864572,0.922596,...,0.956255,0.954327,0.960298,0.874773,10,0.95,0.96,0.93 (0.92),0.83 (0.77),0.87 (0.83)
4,lv_dil,0.918409,0.918409,0.918409,0.813679,0.841232,0.834606,0.46069,0.694208,0.528561,...,0.91893,0.927243,0.910699,0.817972,8,0.92,0.93,0.84 (0.69),0.8 (0.68),0.82 (0.68)
5,lv_syst_func,0.918409,0.918409,0.918409,0.767179,0.775424,0.76,0.367394,0.372593,0.374272,...,0.903879,0.907641,0.90385,0.754032,6,0.9,0.9,0.77 (0.37),0.74 (0.36),0.75 (0.36)
6,mitral_regurgitation,0.918409,0.918409,0.918409,0.949791,0.946612,0.96614,0.703604,0.716748,0.712843,...,0.97258,0.974003,0.990233,0.921079,8,0.97,0.97,0.95 (0.72),0.9 (0.69),0.92 (0.7)
2,pe,0.9253,0.9253,0.9253,0.833333,0.903743,0.891192,0.266836,0.313543,0.306222,...,0.954478,0.976047,0.973228,0.787185,10,0.96,0.97,0.89 (0.31),0.7 (0.26),0.79 (0.28)
0,rv_dil,0.924163,0.924163,0.924163,0.9375,0.916155,0.913934,0.751793,0.692235,0.724523,...,0.989505,0.98883,0.988254,0.88189,8,0.99,0.99,0.92 (0.69),0.85 (0.67),0.88 (0.68)
9,rv_syst_func,0.918409,0.918409,0.918409,0.892857,0.926437,0.899358,0.694759,0.524931,0.714256,...,0.967988,0.970221,0.966893,0.881743,6,0.96,0.97,0.89 (0.69),0.87 (0.61),0.88 (0.63)
8,tricuspid_regurgitation,0.918409,0.918409,0.918409,0.925072,0.932515,0.929448,0.858377,0.850428,0.837897,...,0.982471,0.986085,0.987553,0.90678,6,0.97,0.98,0.93 (0.86),0.89 (0.81),0.91 (0.83)


In [19]:
## Table 4
#df[['entity', 'f', 'r', 'p']].to_latex('/training/echo/text_mining/output/table4_pipeline_performance.tex', index=False)
df[['entity', 'f', 'r', 'p']]

Unnamed: 0,entity,f,r,p
3,aortic_regurgitation,0.89 (0.65),0.83 (0.59),0.95 (0.73)
1,aortic_stenosis,0.82 (0.73),0.76 (0.66),0.88 (0.83)
10,diastolic_dysfunction,0.87 (0.83),0.83 (0.77),0.93 (0.92)
4,lv_dil,0.82 (0.68),0.8 (0.68),0.84 (0.69)
5,lv_syst_func,0.75 (0.36),0.74 (0.36),0.77 (0.37)
6,mitral_regurgitation,0.92 (0.7),0.9 (0.69),0.95 (0.72)
2,pe,0.79 (0.28),0.7 (0.26),0.89 (0.31)
0,rv_dil,0.88 (0.68),0.85 (0.67),0.92 (0.69)
9,rv_syst_func,0.88 (0.63),0.87 (0.61),0.89 (0.69)
8,tricuspid_regurgitation,0.91 (0.83),0.89 (0.81),0.93 (0.86)


In [20]:
## Table 5
#df[['entity', 'jaccard_max']].to_latex('/training/echo/text_mining/output/table5_jaccard_labeltospan.tex', index=False)
df[['entity', 'jaccard_max', 'jaccard_rev_max']]

Unnamed: 0,entity,jaccard_max,jaccard_rev_max
3,aortic_regurgitation,0.97,0.98
1,aortic_stenosis,0.95,0.97
10,diastolic_dysfunction,0.95,0.96
4,lv_dil,0.92,0.93
5,lv_syst_func,0.9,0.9
6,mitral_regurgitation,0.97,0.97
2,pe,0.96,0.97
0,rv_dil,0.99,0.99
9,rv_syst_func,0.96,0.97
8,tricuspid_regurgitation,0.97,0.98
