In [69]:
import json
import pandas as pd
import os
import numpy as np
from spacy.cli.debug_data import debug_data, debug_data_cli, _get_span_characteristics

In [2]:
# Column layout of the per-entity label-count table.
columns = ['Entity', 'Cases', 'Any label', 'normal', 'mild', 'moderate', 'severe', 'present']
# Entity names. A comma was missing after 'tricuspid_regurgitation', so Python's
# implicit string concatenation fused it with 'mitral_regurgitation' into one bogus entry.
rows = ['aortic_stenosis', 'aortic_regurgitation', 'lv_dias_func', 'lv_sys_func', 'rv_sys_func', 'lv_dil', 'rv_dil',
        'tricuspid_regurgitation', 'mitral_regurgitation', 'pe', 'wma']
label_count_df = pd.DataFrame(columns=columns)

# Maps each full span label (e.g. '<entity>_<severity>') to the number of spans carrying it.
label_count_dict = {}

with open('/training/echo/text_mining/datasets/spancat/reduced_labels/merged_labels.jsonl', 'r') as f:
    # One JSON record per line; every record is expected to have a 'spans' list
    # whose items each carry a 'label' string.
    for line in f:
        json_line = json.loads(line)
        for span in json_line['spans']:
            label = span['label']
            label_count_dict[label] = label_count_dict.get(label, 0) + 1

In [3]:
# Regroup the flat {label: count} mapping into {entity: {severity: count}}.
dict_per_type = {}
for full_label, value in label_count_dict.items():
    parts = full_label.split('_')
    if len(parts) == 1:
        # A label without any '_' has no severity suffix; it is counted as 'present'.
        severity = 'present'
        entity = parts[0]
    elif len(parts) > 2 and parts[-2:] == ['not', 'present']:
        # '<entity>_not_present' is reported in the 'normal' column.
        severity = 'normal'
        entity = '_'.join(parts[:-2])
    else:
        # Everything else: the last token is the severity, the rest is the entity.
        severity = parts[-1]
        entity = '_'.join(parts[:-1])
    per_entity = dict_per_type.setdefault(entity, {})
    # Accumulate rather than assign: two different labels can map to the same
    # (entity, severity) pair (e.g. '<entity>_normal' and '<entity>_not_present'),
    # and plain assignment would silently drop one of the two counts.
    per_entity[severity] = per_entity.get(severity, 0) + value

In [4]:
# Build the per-entity span label-count table from dict_per_type.
columns = ['Entity', 'Cases', 'Any label', 'normal', 'mild', 'moderate', 'severe', 'present']

records = []
for entity, values in dict_per_type.items():
    # The running total is deliberately not called `sum`: rebinding that name in a
    # notebook shadows the builtin for every cell executed afterwards.
    any_label = 0
    record = {'Entity': entity}
    for label, value in values.items():
        any_label += value
        record[label] = value
    record['Any label'] = any_label
    records.append(record)

# Construct the frame once instead of growing it row by row. Columns without a
# value for an entity (including 'Cases', which is never filled here) stay NaN.
label_count_df = pd.DataFrame(records, columns=columns)

In [5]:
# Show the per-entity span label counts derived from merged_labels.jsonl.
label_count_df

Unnamed: 0,Entity,Cases,Any label,normal,mild,moderate,severe,present
0,rv_sys_func,,2640,1932,445.0,199.0,64.0,
1,lv_sys_func,,5247,3113,1042.0,495.0,494.0,
2,lv_dil,,2469,1925,256.0,94.0,52.0,142.0
3,lv_dias_func,,1597,536,665.0,263.0,133.0,
4,tricuspid_valve_native_regurgitation,,1954,1422,294.0,165.0,73.0,
5,wma,,1334,421,,,,913.0
6,aortic_valve_native_regurgitation,,2318,1657,501.0,123.0,37.0,
7,mitral_valve_native_regurgitation,,2902,1793,814.0,228.0,67.0,
8,aortic_valve_native_stenosis,,1850,1582,111.0,73.0,84.0,
9,rv_dil,,1723,1370,165.0,75.0,28.0,85.0


# Document-level statistics

In [6]:
# Directory with the annotated .jsonl exports; the file name stem is used as the entity name below.
path = '/training/echo/text_mining/datasets/spancat/reduced_labels/'
# Keep every .jsonl export except the merged_labels one.
files = [
    fname
    for fname in os.listdir(path)
    if 'merged_labels' not in fname and fname.endswith('.jsonl')
]

In [7]:
# Ordinal score per document-level category; -1 marks a document in which
# no span matched any of the known label suffixes.
label_dict = {
    'No label': -1,
    'Normal': 0,
    'Present': 1,
    'Mild': 2,
    'Moderate': 3,
    'Severe': 4,
}
label_dict_rev = dict((score, name) for name, score in label_dict.items())

columns = ['Entity', 'Cases', 'Any label', 'Normal', 'Mild', 'Moderate', 'Severe', 'Present']
df = pd.DataFrame(columns=columns)


def _label_score(label):
    """Return the ordinal score of one span label, or -1 if no suffix matches.

    The checks run from the highest score down, so the first hit is the
    largest score the label qualifies for.
    """
    if label.endswith('severe'):
        return 4
    if label.endswith('moderate'):
        return 3
    if label.endswith('mild'):
        return 2
    if label == 'pe' or (label.endswith('present') and not label.endswith('not_present')):
        return 1
    if label.endswith('not_present') or label.endswith('normal'):
        return 0
    return -1


for file in files:
    with open(path + file, 'r') as f:
        entity = file.split('.jsonl')[0]
        count_dict = dict.fromkeys(label_dict, 0)
        nrows = 0
        for line in f:
            nrows += 1
            json_line = json.loads(line)
            # A document is classified by the highest-scoring label among its spans.
            highest_label_score = max(
                (_label_score(span['label']) for span in json_line['spans']),
                default=-1,
            )
            count_dict[label_dict_rev[highest_label_score]] += 1
        # Every document that did not end up as 'No label' carries some label.
        any_label = nrows - count_dict['No label']
        df.loc[len(df)] = {
            'Entity': entity,
            'Cases': nrows,
            'Any label': any_label,
            'Normal': count_dict['Normal'],
            'Mild': count_dict['Mild'],
            'Moderate': count_dict['Moderate'],
            'Severe': count_dict['Severe'],
            'Present': count_dict['Present'],
        }

In [8]:
# Order the rows alphabetically by entity and renumber the index from 0.
df = df.sort_values(by='Entity', ignore_index=True)

In [9]:
# Write the document-level count table to a LaTeX file, leaving out the row index.
df.to_latex('/training/echo/text_mining/output/table2_document_label_counts.tex', index=False)

# Span-level statistics

In [21]:
# Same source directory as the document-level section; re-listed so this section can be run on its own.
path = '/training/echo/text_mining/datasets/spancat/reduced_labels/'
# All .jsonl exports apart from the merged_labels one.
files = [
    jsonl_name
    for jsonl_name in os.listdir(path)
    if 'merged_labels' not in jsonl_name and jsonl_name.endswith('.jsonl')
]

In [47]:
# Severity categories counted per span (no 'No label' bucket at span level).
label_dict = {'Normal': 0,
              'Present': 1,
              'Mild': 2,
              'Moderate': 3,
              'Severe': 4}
label_dict_rev = {v: k for k, v in label_dict.items()}

columns = ['Entity', 'Cases', 'Total # of spans', 'Median span characters (IQR)', 'Median normal span characters (IQR)', 'Median non-normal span characters (IQR)', 'Normal', 'Mild', 'Moderate', 'Severe', 'Present']
df = pd.DataFrame(columns=columns)


def _median_iqr(lengths):
    """Format a list of span lengths as 'median (Q1-Q3)'.

    An empty list yields 'nan (nan-nan)' instead of calling numpy on it, so an
    entity without any normal (or without any non-normal) spans cannot break
    the whole table.
    """
    if len(lengths) == 0:
        return 'nan (nan-nan)'
    return f'{np.median(lengths)} ({np.quantile(lengths, 0.25)}-{np.quantile(lengths, 0.75)})'


for file in files:
    with open(path + file, 'r') as f:
        entity = file.split('.jsonl')[0]
        nrows = 0
        span_length = []       # all spans
        span_length_norm = []  # spans labelled normal / not_present
        span_length_sev = []   # every other span
        count_dict = {k: 0 for k in label_dict.keys()}
        for line in f:
            nrows += 1
            json_line = json.loads(line)
            for span in json_line['spans']:
                label = span['label']
                # NOTE(review): the '- 1' is kept from the original. If 'end' is an
                # exclusive character offset, end - start would already be the
                # character count -- confirm against the annotation export format.
                length = span['end'] - span['start'] - 1
                is_normal = label.endswith('not_present') or label.endswith('normal')

                span_length.append(length)
                if is_normal:
                    count_dict['Normal'] += 1
                    span_length_norm.append(length)
                else:
                    span_length_sev.append(length)
                    # The suffixes are mutually exclusive, so at most one branch applies.
                    if (label.endswith('present') and not label.endswith('not_present')) or label == 'pe':
                        count_dict['Present'] += 1
                    elif label.endswith('mild'):
                        count_dict['Mild'] += 1
                    elif label.endswith('moderate'):
                        count_dict['Moderate'] += 1
                    elif label.endswith('severe'):
                        count_dict['Severe'] += 1
        spans = len(span_length)
        data = {'Entity': entity, 'Cases': nrows, 'Total # of spans': spans,
                'Median span characters (IQR)': _median_iqr(span_length),
                'Median normal span characters (IQR)': _median_iqr(span_length_norm),
                'Median non-normal span characters (IQR)': _median_iqr(span_length_sev),
                'Normal': count_dict['Normal'], 'Mild': count_dict['Mild'],
                'Moderate': count_dict['Moderate'], 'Severe': count_dict['Severe'], 'Present': count_dict['Present']}
        df.loc[len(df)] = data

In [48]:
# Alphabetical entity order with a fresh 0..n-1 index.
df = df.sort_values(by='Entity', ignore_index=True)

In [49]:
# Show the span-level statistics table, sorted by entity.
df

Unnamed: 0,Entity,Cases,Total # of spans,Median span characters (IQR),Median normal span characters (IQR),Median non-normal span characters (IQR),Normal,Mild,Moderate,Severe,Present
0,aortic_regurgitation,5615,2607,19.0 (11.0-26.0),19.0 (17.0-27.0),11.0 (10.0-21.0),1849,562,146,50,0
1,aortic_stenosis,5000,1850,19.0 (17.0-26.75),19.0 (17.0-28.0),19.0 (11.0-23.0),1582,111,73,84,0
2,diastolic_dysfunction,5000,1597,30.0 (21.0-31.0),27.0 (21.0-33.0),30.0 (18.0-31.0),536,665,263,133,0
3,lv_dil,5000,2469,19.0 (17.0-29.0),19.0 (17.0-31.0),20.0 (18.0-22.0),1925,256,94,52,142
4,lv_syst_func,5000,5247,27.0 (21.0-36.0),27.0 (21.0-35.0),29.0 (20.0-38.0),3113,1042,495,494,0
5,mitral_regurgitation,5000,2902,19.0 (10.0-25.0),19.0 (16.0-27.0),11.0 (9.0-20.0),1793,814,228,67,0
6,pe,8686,1295,20.0 (15.0-27.0),19.0 (7.0-21.0),38.0 (24.0-58.0),987,158,55,50,45
7,rv_dil,8203,2812,19.0 (17.0-29.25),18.0 (17.0-34.0),20.0 (18.0-22.0),2195,294,132,50,141
8,rv_syst_func,5000,2640,25.0 (18.75-32.0),23.0 (17.0-31.0),27.0 (20.0-35.0),1932,445,199,64,0
9,tricuspid_regurgitation,5000,1954,19.0 (12.0-25.0),19.0 (17.0-28.0),10.0 (8.0-16.0),1422,294,165,73,0


In [None]:
#df.to_latex('/training/echo/text_mining/output/table2_span_label_counts.tex', index=False)

# Additional span characteristics table

In [89]:
import os
import json
import pandas as pd
import spacy
import numpy as np
from spacy.tokens import DocBin
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy.scorer import Scorer
from spacy_experimental.coref.coref_scorer import score_span_predictions
from spacy.training import Example
import matplotlib.pyplot as plt
from collections import defaultdict
from spacy import displacy
from spacy.cli.debug_data import _get_span_characteristics, _compile_gold
from tqdm.notebook import tqdm

In [153]:
# Display names for the entity identifiers; used as row labels in the
# span-characteristics table.
replace_dict = {
    'aortic_stenosis': 'Aortic stenosis',
    'aortic_regurgitation': 'Aortic regurgitation',
    'diastolic_dysfunction': 'Diastolic dysfunction',
    # Typo fixed: was 'dysunction', which would end up verbatim in the exported table.
    'lv_syst_func': 'Left ventricular systolic dysfunction',
    'rv_syst_func': 'Right ventricular systolic dysfunction',
    'lv_dil': 'Left ventricular dilatation',
    'rv_dil': 'Right ventricular dilatation',
    'tricuspid_regurgitation': 'Tricuspid regurgitation',
    'mitral_regurgitation': 'Mitral regurgitation',
    'pe': 'Pericardial effusion',
    'wma': 'Wall motion abnormalities'
}

In [154]:
# One sub-directory of spacy_data/ per entity; hidden entries and the merged set are skipped.
files = [x for x in os.listdir('spacy_data/') if not x.startswith('.') and x != 'merged_labels']
cols = ['N_spans', 'Length', 'SD', 'BD']
names = ['Entity', 'Severity']
# Empty two-level (Entity, Severity) index; rows are added by .loc assignment below.
index = pd.MultiIndex(levels=[[], []], codes=[[], []], names=names)
df = pd.DataFrame(index=index, columns=cols)


def _severity_from_label(lab):
    """Map a span label to its severity category by suffix.

    Returns 'Normal', 'Present', 'Mild', 'Moderate' or 'Severe', or None when
    the label matches none of the known suffixes.
    """
    if lab.endswith('not_present') or lab.endswith('normal'):
        return 'Normal'
    if lab.endswith('present') or lab == 'pe':
        return 'Present'
    if lab.endswith('mild'):
        return 'Mild'
    if lab.endswith('moderate'):
        return 'Moderate'
    if lab.endswith('severe'):
        return 'Severe'
    return None


for file in tqdm(files):
    nlp = spacy.load('spacy_data/' + file + '/model_nw_10/model-best')
    doc_bin = DocBin().from_disk('spacy_data/' + file + '/full.spacy')
    docs = list(doc_bin.get_docs(nlp.vocab))
    # Pair each stored document (reference) with a fresh pipeline run on its text (prediction).
    examples = [Example(nlp(doc.text), doc) for doc in docs]

    factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
    gold_data = _compile_gold(examples, factory_names, nlp, make_proj=True)
    sc = _get_span_characteristics(examples, gold_data, 'sc')
    # Separate name for the display label so the loop variable is not rebound.
    entity_name = replace_dict[file]

    for lab in sc['labels']:
        label = _severity_from_label(lab)
        if label is None:
            # Previously `label` kept its value from the preceding iteration here
            # (or was unbound on the first one), so an unrecognised span label
            # overwrote another severity's row. Keep it visible under its raw name.
            label = lab

        df.loc[(entity_name, label), 'N_spans'] = sc['spans_per_type'][lab]
        df.loc[(entity_name, label), 'Length'] = round(sc['lengths'][lab], 2)
        df.loc[(entity_name, label), 'SD'] = round(sc['sd'][lab], 2)
        df.loc[(entity_name, label), 'BD'] = round(sc['bd'][lab], 2)

    # Per-entity summary row across all labels.
    df.loc[(entity_name, 'Overall'), 'N_spans'] = np.sum(list(sc['spans_per_type'].values()))
    df.loc[(entity_name, 'Overall'), 'Length'] = round(sc['avg_length'], 2)
    df.loc[(entity_name, 'Overall'), 'SD'] = round(sc['avg_sd'], 2)
    df.loc[(entity_name, 'Overall'), 'BD'] = round(sc['avg_bd'], 2)

  0%|          | 0/11 [00:00<?, ?it/s]

In [155]:
# Display order of the Severity level: the summary row first, then increasing severity.
severity_order = {'Overall': -1, 'Normal': 0, 'Mild': 1, 'Moderate': 2, 'Severe': 3, 'Present': 4}
severity_rank = df.index.get_level_values('Severity').map(severity_order)
# Sort on a temporary helper column with a stable algorithm, then drop the helper again.
df = (
    df.assign(Severity_order=severity_rank)
    .sort_values(by=['Severity_order'], kind='mergesort')
    .drop(columns='Severity_order')
)

In [156]:
# Sort on the Entity level only (no additional sorting on Severity), then wrap the
# frame in a Styler that shows two decimals. From here on `df` is a Styler, not a DataFrame.
entity_sorted = df.sort_index(level=['Entity'], sort_remaining=False)
df = entity_sorted.style.format(precision=2)

In [157]:
# Render the styled span-characteristics table.
df

Unnamed: 0_level_0,Unnamed: 1_level_0,N_spans,Length,SD,BD
Entity,Severity,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Aortic regurgitation,Overall,2607,2.47,2.62,1.25
Aortic regurgitation,Normal,1849,2.48,2.42,1.28
Aortic regurgitation,Mild,562,2.39,3.08,1.05
Aortic regurgitation,Moderate,146,2.68,2.9,1.45
Aortic regurgitation,Severe,50,2.37,4.33,1.8
Aortic stenosis,Overall,1850,2.48,2.6,1.35
Aortic stenosis,Normal,1582,2.48,2.4,1.31
Aortic stenosis,Mild,111,2.45,3.54,1.57
Aortic stenosis,Moderate,73,2.53,3.43,1.71
Aortic stenosis,Severe,84,2.39,4.33,1.52


In [160]:
# Export the table to LaTeX. `df` was rebound to the result of `df.style.format(...)`
# in an earlier cell, so this is that object's `to_latex`, not `DataFrame.to_latex`.
# NOTE(review): the traceback recorded under this cell complains about an `index`
# keyword that this call does not pass -- presumably stale output from an earlier
# version of the cell; confirm by re-running.
# NOTE(review): 'charateristics' in the file name looks like a typo; left unchanged
# because other documents may reference this path.
df.to_latex('/training/echo/text_mining/output/span_label_charateristics.tex')

TypeError: to_latex() got an unexpected keyword argument 'index'

# Train-test split

In [15]:
# Load the test-set identifiers; the 'input_hash' column is matched against each
# record's '_input_hash' below to decide whether a document belongs to the test split.
testid = pd.read_csv('/training/echo/analysis/datasets/test_echoid.csv')

In [16]:
# Source directory of the .jsonl exports, re-listed for the train/test section.
path = '/training/echo/text_mining/datasets/spancat/reduced_labels/'
# Every .jsonl export other than the merged_labels one.
files = [
    export_name
    for export_name in os.listdir(path)
    if 'merged_labels' not in export_name and export_name.endswith('.jsonl')
]

In [17]:
# Count spans per entity, split into train and test by the document's input hash.
columns = ['Entity', 'Train', 'Test']
df = pd.DataFrame(columns=columns)

# Build the lookup once. Testing membership against the raw `.values` array
# rescans the whole array for every line of every file.
test_hashes = set(testid['input_hash'].tolist())

for file in files:
    with open(path + file, 'r') as f:
        entity = file.split('.jsonl')[0]
        count_train = 0
        count_test = 0
        for line in f:
            json_line = json.loads(line)
            in_test = json_line['_input_hash'] in test_hashes
            n_spans = len(json_line['spans'])
            if in_test:
                count_test += n_spans
            else:
                count_train += n_spans
        data = {'Entity': entity, 'Train': count_train, 'Test': count_test}
        df.loc[len(df)] = data

In [18]:
# Sort by entity by rebinding `df` rather than mutating it in place, and renumber the
# index, matching the other sort cells in this notebook. Re-running the cell is then harmless.
df = df.sort_values('Entity').reset_index(drop=True)

In [19]:
# Write the per-entity train/test span counts to a LaTeX file, leaving out the row index.
df.to_latex('/training/echo/text_mining/output/table_entity_counts_traintest.tex', index=False)