In [1]:
import os
import json
import pandas as pd
import spacy
import numpy as np
from tqdm.notebook import tqdm
from spacy.tokens import DocBin
from spacy.vocab import Vocab
from spacy.scorer import Scorer
from spacy.training import Example
from sklearn.metrics import f1_score, precision_score, recall_score, precision_recall_fscore_support

In [2]:
def load_and_run_model(base_path, file, model):
    # Load model components and test data
    nlp = spacy.load(base_path + file + '/' + model + '/model-best')
    doc_bin = DocBin().from_disk(base_path + file + '/test.spacy')
    docs = list(doc_bin.get_docs(nlp.vocab))
    tok2vec = nlp.get_pipe('tok2vec')
    spancat = nlp.get_pipe('spancat')
    scorer = Scorer(nlp)
    
    label_rev_map = {v: k for k, v in spancat._label_map.items()}
    label_rev_map[spancat._negative_label_i] = 'no_label'

    # Run inference
    examples = []
    true_spans = []
    pred_spans = []
    fp_count = 0
    fp_cols = ['mild', 'moderate', 'severe', 'present']
    
    for doc in docs:
        # Run pipeline
        prediction = nlp(doc.text)
        example = Example(prediction, doc)
        examples.append(example)

        # Run components separately, to access predictions for specific spans
        pred_doc = tok2vec(doc)
        indices, scores = spancat.predict([pred_doc])
        for span in pred_doc.spans['sc']:
            
            # Find index to corresponding prediction
            target_value = [span.start, span.end]
            span_index = None
            for i, value in enumerate(indices.data):
                if np.array_equal(value, target_value):
                    span_index = i
                    break
                    
            # Access prediction for gold span        
            gold_span_preds = scores[[span_index]]
            predicted_label = label_rev_map[gold_span_preds.argmax()]
            
            # Store span labels for PRF calculations
            true_spans.append(span.label_)
            pred_spans.append(predicted_label)

        for span in prediction.spans['sc']:
            # Find false positive labels
            span.label_ = span.label_.replace('not_present', 'normal')
            if any(fp_col in span.label_ for fp_col in fp_cols):
                if not any(fp_col in sp.label_.replace('not_present', 'normal') for sp in doc.spans['sc'] for fp_col in fp_cols):
                    fp_count += 1
    
    # Assess performance from inference
    scores = scorer.score(examples)
    scores['fp_manual'] = fp_count / len(pred_spans)
    return examples, scores, true_spans, pred_spans

def _tuple_overlap(tL, tR):
    # tL: tuple(begin, end)
    # tR: tuple(begin, end)
    tLrange = set(range(len(tL)))
    tRrange = set(range(len(tR)))
               
    InterSection = len(tLrange.intersection(tRrange))
    Union = len(tLrange.union(tRrange))

    return InterSection/Union if Union>0 else np.nan


def span_overlap_counter(examples):
    overlaplist = []
    for case in examples:
        span_set_spancat = set()
        span_set_labeled = set()
        
        for span in case.reference.spans['sc']:
            span_set_labeled.add(span)
            
        for span in case.predicted.spans['sc']:
            span_set_spancat.add(span)

        jaccard_indices = []
        for span_l in span_set_labeled:
            _jaccard_indices = []
            for span_s in span_set_spancat:
                _jaccard_indices.append(_tuple_overlap(span_l, span_s))
            try:
                max_ = max(_jaccard_indices)
            except:
                max_ = np.nan
            jaccard_indices.append(max_)
        overlaplist.append(jaccard_indices)
    return overlaplist


def span_overlap_counter_reverse(examples):
    overlaplist = []
    for case in examples:
        span_set_spancat = set()
        span_set_labeled = set()
        
        for span in case.reference.spans['sc']:
            span_set_labeled.add(span)
            
        for span in case.predicted.spans['sc']:
            span_set_spancat.add(span)

        jaccard_indices = []
        for span_s in span_set_spancat:
            _jaccard_indices = []
            for span_l in span_set_labeled:
                _jaccard_indices.append(_tuple_overlap(span_l, span_s))
            try:
                max_ = max(_jaccard_indices)
            except:
                max_ = np.nan
            jaccard_indices.append(max_)
        overlaplist.append(jaccard_indices)
    return overlaplist


def calculate_false_label_proportion(true_spans, pred_spans):
    fp = 0
    total = len(true_spans)
    for true, pred in zip(true_spans, pred_spans):
        if true != pred and pred != 'no_label':
            fp += 1
    return fp / total

In [3]:
base_path = '/training/echo/text_mining/spancat_models/reduced_labels/spacy_data/'

df = pd.DataFrame(columns=['entity', 
                           'p_w_06', 'p_w_08', 'p_w_10', 'p_m_06', 'p_m_08', 'p_m_10', 
                           'r_w_06', 'r_w_08', 'r_w_10', 'r_m_06', 'r_m_08', 'r_m_10',
                           'f_w_06', 'f_w_08', 'f_w_10', 'f_m_06', 'f_m_08', 'f_m_10',
                           'jaccard_06', 'jaccard_08', 'jaccard_10', 'jaccard_rev_06', 'jaccard_rev_08', 'jaccard_rev_10',
                          'pgold_w_06', 'pgold_w_08', 'pgold_w_10', 'pgold_m_06', 'pgold_m_08', 'pgold_m_10', 'pgold_mi_06', 'pgold_mi_08', 'pgold_mi_10',
                          'rgold_w_06', 'rgold_w_08', 'rgold_w_10', 'rgold_m_06', 'rgold_m_08', 'rgold_m_10', 'rgold_mi_06', 'rgold_mi_08', 'rgold_mi_10',
                          'fgold_w_06', 'fgold_w_08', 'fgold_w_10', 'fgold_m_06', 'fgold_m_08', 'fgold_m_10', 'fgold_mi_06', 'fgold_mi_08', 'fgold_mi_10', 
                           'fp', 'fp_manual'])

files = [x for x in os.listdir(base_path) if not x.startswith('.')]

# Iterate over all abnormalities
for file in tqdm(files):
    data = {'entity': file}

    # Iterate over model versions
    for model in [x for x in os.listdir(base_path + file) if not x.endswith('.spacy')]:
        
        # Extract model version
        nw = model.split('_')[-1]

        # Load model and data, run inference
        examples, scores, true_spans, pred_spans = load_and_run_model(base_path, file, model)
        
        # Table 3
        scores_w = precision_recall_fscore_support(true_spans, pred_spans, average='weighted', zero_division=0) # Weighted PRF for gold spans
        scores_m = precision_recall_fscore_support(true_spans, pred_spans, average='macro', zero_division=0) # Macro PRF for gold spans
        scores_mi = precision_recall_fscore_support(true_spans, pred_spans, average='micro', zero_division=0) # Micro PRF for gold spans
        
        data[f'pgold_w_{nw}'] = scores_w[0]
        data[f'pgold_m_{nw}'] = scores_m[0]
        data[f'pgold_mi_{nw}'] = scores_mi[0]

        data[f'rgold_w_{nw}'] = scores_w[1]
        data[f'rgold_m_{nw}'] = scores_m[1]
        data[f'rgold_mi_{nw}'] = scores_mi[1]

        data[f'fgold_w_{nw}'] = scores_w[2]
        data[f'fgold_m_{nw}'] = scores_m[2]
        data[f'fgold_mi_{nw}'] = scores_mi[2]
        
        # Assess PRF for all spans
        for metric in ['p', 'r', 'f']:
            # Table 4
            data[f'{metric}_w_{nw}'] = scores[f'spans_sc_{metric}'] # Weighted PRF (identical to PRF reported in meta.json)
            data[f'{metric}_m_{nw}'] = np.mean([v[f'{metric}'] for _, v in scores['spans_sc_per_type'].items()]) # Macro PRF 

        # Assess Jaccard index (Table 5)
        OverlapJaccardIndices = span_overlap_counter(examples)
        OverlapJaccardIndicesRev = span_overlap_counter_reverse(examples)
        data[f'jaccard_{nw}'] = np.nanmean([_v for v in OverlapJaccardIndices for _v in v])
        data[f'jaccard_rev_{nw}'] = np.nanmean([_v for v in OverlapJaccardIndicesRev for _v in v])

        # Calculate false label proportion (Table ?)
        #data['fp'] = calculate_false_label_proportion(true_spans, pred_spans)
        data['fp_manual'] = scores['fp_manual']
        
    # Add data row
    df.loc[len(df)] = data

  0%|          | 0/12 [00:00<?, ?it/s]

In [4]:
df

Unnamed: 0,entity,p_w_06,p_w_08,p_w_10,p_m_06,p_m_08,p_m_10,r_w_06,r_w_08,r_w_10,...,fgold_w_08,fgold_w_10,fgold_m_06,fgold_m_08,fgold_m_10,fgold_mi_06,fgold_mi_08,fgold_mi_10,fp,fp_manual
0,rv_dil,0.925852,0.944196,0.951835,0.735818,0.955034,0.742884,0.87666,0.802657,0.787476,...,0.889123,0.875654,0.617692,0.754865,0.573332,0.876894,0.80303,0.787879,,0.015152
1,aortic_stenosis,0.864238,0.891697,0.892473,0.851292,0.863274,0.867674,0.786145,0.743976,0.75,...,0.848225,0.853161,0.629022,0.602891,0.616769,0.786145,0.743976,0.75,,0.006024
2,pe,0.893617,0.935673,0.905556,0.315992,0.187135,0.302381,0.702929,0.669456,0.682008,...,0.71252,0.743161,0.242167,0.155945,0.218207,0.702929,0.669456,0.682008,,0.004184
3,aortic_regurgitation,0.944316,0.954869,0.965087,0.727276,0.730792,0.712014,0.853249,0.842767,0.811321,...,0.905885,0.884206,0.535125,0.532867,0.509319,0.854167,0.84375,0.810417,,0.002083
4,lv_dil,0.849885,0.838235,0.854497,0.930521,0.488377,0.511951,0.823266,0.765101,0.722595,...,0.830974,0.811082,0.763564,0.42904,0.428664,0.821429,0.763393,0.720982,,0.006696
5,merged_labels,0.847892,0.840834,0.853057,0.693673,0.575646,0.729104,0.615839,0.609057,0.613433,...,0.65139,0.663798,0.599363,0.512913,0.640361,0.616107,0.609123,0.613706,,0.001964
6,lv_syst_func,0.791623,0.760314,0.793959,0.434855,0.418297,0.435125,0.749257,0.767096,0.729435,...,0.853789,0.829904,0.412705,0.420447,0.409802,0.749505,0.767327,0.729703,,0.006931
7,mitral_regurgitation,0.968815,0.957895,0.945674,0.722228,0.947261,0.925614,0.903101,0.881783,0.910853,...,0.930021,0.946911,0.569404,0.725706,0.728678,0.904031,0.882917,0.911708,,0.005758
8,wma,0.586873,0.618834,0.000245,0.587024,0.614089,0.000246,0.605578,0.549801,0.083665,...,0.706895,0.142159,0.511328,0.484996,0.092459,0.605578,0.549801,0.083665,,0.035857
9,tricuspid_regurgitation,0.929825,0.925595,0.954225,0.847541,0.86188,0.486626,0.880886,0.861496,0.750693,...,0.911992,0.799323,0.683904,0.666394,0.367706,0.880886,0.861496,0.750693,,0.00277


In [5]:
# Choose entity-specific best model with regards to negative weight
# Based on highest F-score
df['f_w_max'] = np.nanmax(df[['f_w_06', 'f_w_08', 'f_w_10']], axis=1)
df['nw_max'] = np.nanargmax(df[['f_w_06', 'f_w_08', 'f_w_10']], axis=1)
df['nw_max'].replace({0: '06', 1: '08', 2: '10'}, inplace=True)

# Pick corresponding performance metrics
df['jaccard_max'] = df.apply(lambda x: round(x['jaccard_' + str(x['nw_max'])], 2), axis=1)
df['jaccard_rev_max'] = df.apply(lambda x: round(x['jaccard_rev_' + str(x['nw_max'])], 2), axis=1)

for metric in ['p', 'r', 'f']:
    df[f'{metric}'] = df.apply(lambda x: str(round(x[f'{metric}_w_' + str(x['nw_max'])], 2)) + ' (' + str(round(x[f'{metric}_m_' + str(x['nw_max'])], 2)) + ')', axis=1)
    df[f'{metric}gold'] = df.apply(lambda x: str(round(x[f'{metric}gold_w_' + str(x['nw_max'])], 2)) + ' (' + str(round(x[f'{metric}gold_m_' + str(x['nw_max'])], 2)) + ')', axis=1)

df['fp'] = df['fp'].apply(lambda x: round(x, 3))

In [6]:
df = df.sort_values('entity').reset_index(drop=True)

In [7]:
df[['entity', 'nw_max']]

Unnamed: 0,entity,nw_max
0,aortic_regurgitation,6
1,aortic_stenosis,6
2,diastolic_dysfunction,6
3,lv_dil,6
4,lv_syst_func,6
5,merged_labels,10
6,mitral_regurgitation,6
7,pe,6
8,rv_dil,6
9,rv_syst_func,6


In [7]:
df

Unnamed: 0,entity,p_w_06,p_w_08,p_w_10,p_m_06,p_m_08,p_m_10,r_w_06,r_w_08,r_w_10,...,f_w_max,nw_max,jaccard_max,jaccard_rev_max,p,pgold,r,rgold,f,fgold
0,aortic_regurgitation,0.944316,0.954869,0.965087,0.727276,0.730792,0.712014,0.853249,0.842767,0.811321,...,0.896476,6,0.99,0.99,0.94 (0.73),0.97 (0.58),0.85 (0.62),0.85 (0.5),0.9 (0.67),0.91 (0.54)
1,aortic_stenosis,0.864238,0.891697,0.892473,0.851292,0.863274,0.867674,0.786145,0.743976,0.75,...,0.823344,6,0.96,0.97,0.86 (0.85),1.0 (0.78),0.79 (0.67),0.79 (0.53),0.82 (0.74),0.88 (0.63)
2,diastolic_dysfunction,0.901818,0.889286,0.911197,0.859334,0.852451,0.870997,0.849315,0.85274,0.808219,...,0.87478,6,0.98,0.98,0.9 (0.86),0.98 (0.77),0.85 (0.81),0.85 (0.64),0.87 (0.83),0.91 (0.7)
3,lv_dil,0.849885,0.838235,0.854497,0.930521,0.488377,0.511951,0.823266,0.765101,0.722595,...,0.836364,6,0.96,0.95,0.85 (0.93),1.0 (0.83),0.82 (0.85),0.82 (0.71),0.84 (0.89),0.9 (0.76)
4,lv_syst_func,0.791623,0.760314,0.793959,0.434855,0.418297,0.435125,0.749257,0.767096,0.729435,...,0.769857,6,0.95,0.95,0.79 (0.43),0.97 (0.49),0.75 (0.41),0.75 (0.36),0.77 (0.42),0.84 (0.41)
5,merged_labels,0.847892,0.840834,0.853057,0.693673,0.575646,0.729104,0.615839,0.609057,0.613433,...,0.713668,10,0.97,0.98,0.85 (0.73),0.77 (0.72),0.61 (0.6),0.61 (0.59),0.71 (0.65),0.66 (0.64)
6,mitral_regurgitation,0.968815,0.957895,0.945674,0.722228,0.947261,0.925614,0.903101,0.881783,0.910853,...,0.934804,6,0.99,1.0,0.97 (0.72),0.96 (0.59),0.9 (0.69),0.9 (0.55),0.93 (0.71),0.93 (0.57)
7,pe,0.893617,0.935673,0.905556,0.315992,0.187135,0.302381,0.702929,0.669456,0.682008,...,0.786885,6,0.96,0.97,0.89 (0.32),0.85 (0.3),0.7 (0.25),0.7 (0.21),0.79 (0.28),0.76 (0.24)
8,rv_dil,0.925852,0.944196,0.951835,0.735818,0.955034,0.742884,0.87666,0.802657,0.787476,...,0.900585,6,0.99,0.99,0.93 (0.74),0.98 (0.65),0.88 (0.71),0.88 (0.59),0.9 (0.72),0.93 (0.62)
9,rv_syst_func,0.897059,0.930952,0.930864,0.620974,0.465942,0.449998,0.875,0.80123,0.772541,...,0.885892,6,0.99,0.99,0.9 (0.62),0.94 (0.55),0.88 (0.66),0.88 (0.53),0.89 (0.64),0.91 (0.54)


In [8]:
## Table 3 - PRF scores for the exact gold spans
#df[['entity', 'fgold', 'rgold', 'pgold']].to_latex('/training/echo/text_mining/output/table3_spancat_pipeline_performance_goldspans.tex', index=False)
df[['entity', 'fgold', 'rgold', 'pgold']]

Unnamed: 0,entity,fgold,rgold,pgold
0,aortic_regurgitation,0.91 (0.54),0.85 (0.5),0.97 (0.58)
1,aortic_stenosis,0.88 (0.63),0.79 (0.53),1.0 (0.78)
2,diastolic_dysfunction,0.91 (0.7),0.85 (0.64),0.98 (0.77)
3,lv_dil,0.9 (0.76),0.82 (0.71),1.0 (0.83)
4,lv_syst_func,0.84 (0.41),0.75 (0.36),0.97 (0.49)
5,merged_labels,0.66 (0.64),0.61 (0.59),0.77 (0.72)
6,mitral_regurgitation,0.93 (0.57),0.9 (0.55),0.96 (0.59)
7,pe,0.76 (0.24),0.7 (0.21),0.85 (0.3)
8,rv_dil,0.93 (0.62),0.88 (0.59),0.98 (0.65)
9,rv_syst_func,0.91 (0.54),0.88 (0.53),0.94 (0.55)


In [9]:
## Table 4 - PRF scores for all predicted spans
#df[['entity', 'f', 'r', 'p']].to_latex('/training/echo/text_mining/output/table4_spancat_pipeline_performance.tex', index=False)
df[['entity', 'f', 'r', 'p']]

Unnamed: 0,entity,f,r,p
0,aortic_regurgitation,0.9 (0.67),0.85 (0.62),0.94 (0.73)
1,aortic_stenosis,0.82 (0.74),0.79 (0.67),0.86 (0.85)
2,diastolic_dysfunction,0.87 (0.83),0.85 (0.81),0.9 (0.86)
3,lv_dil,0.84 (0.89),0.82 (0.85),0.85 (0.93)
4,lv_syst_func,0.77 (0.42),0.75 (0.41),0.79 (0.43)
5,merged_labels,0.71 (0.65),0.61 (0.6),0.85 (0.73)
6,mitral_regurgitation,0.93 (0.71),0.9 (0.69),0.97 (0.72)
7,pe,0.79 (0.28),0.7 (0.25),0.89 (0.32)
8,rv_dil,0.9 (0.72),0.88 (0.71),0.93 (0.74)
9,rv_syst_func,0.89 (0.64),0.88 (0.66),0.9 (0.62)


In [10]:
## Table 5 - Jaccard similarity for all gold spans and for all predicted spans 
#df[['entity', 'jaccard_max', 'jaccard_rev_max']].to_latex('/training/echo/text_mining/output/table5_spancat_jaccard_labeltospan.tex', index=False)
df[['entity', 'jaccard_max', 'jaccard_rev_max']]

Unnamed: 0,entity,jaccard_max,jaccard_rev_max
0,aortic_regurgitation,0.99,0.99
1,aortic_stenosis,0.96,0.97
2,diastolic_dysfunction,0.98,0.98
3,lv_dil,0.96,0.95
4,lv_syst_func,0.95,0.95
5,merged_labels,0.97,0.98
6,mitral_regurgitation,0.99,1.0
7,pe,0.96,0.97
8,rv_dil,0.99,0.99
9,rv_syst_func,0.99,0.99


In [11]:
## Table x - False positives as a percentage of all predicted labels
df[['entity', 'fp_manual']].round(3).to_latex('/home/jovyan/work/projects/echo_text_mining/output/tablex_spancat_fp.tex', index=False)
df[['entity', 'fp_manual']]

Unnamed: 0,entity,fp_manual
0,aortic_regurgitation,0.002083
1,aortic_stenosis,0.006024
2,diastolic_dysfunction,0.030822
3,lv_dil,0.006696
4,lv_syst_func,0.006931
5,merged_labels,0.001964
6,mitral_regurgitation,0.005758
7,pe,0.004184
8,rv_dil,0.015152
9,rv_syst_func,0.026639
