## Imports

In [1]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
import os

In [3]:
import matplotlib
# Render figures on a white background.
matplotlib.rcParams.update({'figure.facecolor': 'white'})

In [4]:
# Never truncate columns when displaying wide DataFrames.
pd.set_option('display.max_columns', None)

## Load data

In [5]:
%run ../datasets/__init__.py

In [6]:
# Keyword arguments forwarded to prepare_data_classification.
# NOTE(review): max_samples=None presumably means "no cap" — confirm in ../datasets.
dataset_kwargs = dict(
    dataset_name='iu-x-ray',
    dataset_type='all',
    max_samples=None,
)
dataloader = prepare_data_classification(**dataset_kwargs)

# Keep a direct handle on the underlying dataset and show its size.
dataset = dataloader.dataset
len(dataset)

7426

## Compare runtime CheXpert labelers vs. holistic (precomputed) labels

### Load holistic chexpert

In [7]:
%run ../datasets/iu_xray.py

In [8]:
# Read the precomputed per-report CheXpert labels and collapse the raw codes:
# -1 becomes 1 and -2 becomes 0 (applied in that order).
# NOTE(review): -1 / -2 are presumably the "uncertain" / "unmentioned" codes — confirm.
fpath = os.path.join(DATASET_DIR, 'reports', 'reports_with_chexpert_labels.csv')
df = (
    pd.read_csv(fpath, index_col=0)
    .replace(-1, 1)
    .replace(-2, 0)
)
df.head()

Unnamed: 0,Reports,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Lesion,Lung Opacity,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,filename
0,the cardiac silhouette and mediastinum size ar...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.xml
1,the cardiomediastinal silhouette is within nor...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.xml
2,both lungs are clear and expanded . heart and ...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,100.xml
3,there is xxxx increased opacity within the rig...,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1000.xml
4,interstitial markings are diffusely prominent ...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1001.xml


### Calculate CheXpert labels with the light labeler

In [8]:
%run ../metrics/report_generation/labeler_correctness/light_labeler.py
%run ../utils/nlp.py

In [9]:
# Build the light CheXpert labeler from the dataset vocabulary.
# NOTE(review): the name `labeler` is also assigned in the full-labeler section,
# so later cells use whichever of the two sections was executed last.
labeler = ChexpertLightLabeler(dataloader.dataset.get_vocab())
labeler

<__main__.ChexpertLightLabeler at 0x7fe2393a3f28>

In [10]:
# Reader built on the same vocabulary; its text_to_idx method is used in the
# next cell to encode the raw report strings.
report_reader = ReportReader(dataloader.dataset.get_vocab())

In [12]:
# Encode each report text with the reader, preserving the row order of df.
reports = [report_reader.text_to_idx(text) for text in df['Reports']]
len(reports)

3826

In [14]:
%%time

# Label every encoded report with the labeler built above; the saved output
# shows a (3826, 14) result and roughly 58 s of wall time for this call.
labels = labeler(reports)
labels.shape

CPU times: user 525 ms, sys: 43.4 ms, total: 568 ms
Wall time: 58.2 s


(3826, 14)

In [20]:
# Collapse the raw label codes in place, in this order: -2 -> 0, then -1 -> 1.
# NOTE(review): presumably "unmentioned" and "uncertain" respectively — confirm
# against the labeler implementation.
for raw_code, binary_value in ((-2, 0), (-1, 1)):
    labels[labels == raw_code] = binary_value
labels

array([[1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.]])

### Calculate with full-labeler

In [9]:
%run ../metrics/report_generation/labeler_correctness/full_labeler.py
%run ../utils/nlp.py

In [10]:
# Build the full CheXpert labeler from the dataset vocabulary.
# NOTE(review): this rebinds `labeler`, which the light-labeler section also
# assigns; later cells use whichever of the two sections was executed last.
labeler = ChexpertFullLabeler(dataloader.dataset.get_vocab())
labeler

<__main__.ChexpertFullLabeler at 0x7fa8d46db9b0>

In [11]:
# Reader built on the dataset vocabulary; its text_to_idx method is used in the
# next cell to encode the raw report strings.
report_reader = ReportReader(dataloader.dataset.get_vocab())

In [12]:
# Encode each report text with the reader, preserving the row order of df.
reports = list(map(report_reader.text_to_idx, df['Reports']))
len(reports)

3826

In [13]:
%%time

# Label every encoded report with the labeler built above; the saved output
# shows a (3826, 14) result and about 15 min 25 s of wall time for this call.
labels = labeler(reports)
labels.shape

CPU times: user 136 ms, sys: 30.6 ms, total: 166 ms
Wall time: 15min 25s


(3826, 14)

In [14]:
# Collapse the raw label codes in place, in this order: -2 -> 0, then -1 -> 1.
# NOTE(review): presumably "unmentioned" and "uncertain" respectively — confirm
# against the labeler implementation.
for raw_code, binary_value in ((-2, 0), (-1, 1)):
    labels[labels == raw_code] = binary_value
labels

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.]])

### Calculate metrics between runtime and holistic

In [15]:
%run ../metrics/report_generation/chexpert.py
%run -n ../eval_report_generation_chexpert_labeler.py

In [17]:
def add_suffix(col):
    """Return ``col`` with a ``-gt`` suffix when it names a CheXpert label.

    Column names that are not members of ``CHEXPERT_LABELS`` are returned
    unchanged, so non-label columns keep their original names.
    """
    return f'{col}-gt' if col in CHEXPERT_LABELS else col
# Rename the label columns of df in place via add_suffix, then preview the result.
df.rename(columns=add_suffix, inplace=True)
df.head()

Unnamed: 0,Reports,No Finding-gt,Enlarged Cardiomediastinum-gt,Cardiomegaly-gt,Lung Lesion-gt,Lung Opacity-gt,Edema-gt,Consolidation-gt,Pneumonia-gt,Atelectasis-gt,Pneumothorax-gt,Pleural Effusion-gt,Pleural Other-gt,Fracture-gt,Support Devices-gt,filename
0,the cardiac silhouette and mediastinum size ar...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.xml
1,the cardiomediastinal silhouette is within nor...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.xml
2,both lungs are clear and expanded . heart and ...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,100.xml
3,there is xxxx increased opacity within the rig...,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1000.xml
4,interstitial markings are diffusely prominent ...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1001.xml


In [18]:
# Column names for the labeler output, produced by labels_with_suffix('gen')
# (the saved output shows them as '<label>-gen').
columns = labels_with_suffix('gen')

# `labels` was computed from df['Reports'] in row order, so row i of `labels`
# belongs to row i of df. pd.concat(axis=1) aligns on index *labels*, not on
# position, so the generated frame must reuse df's index. Otherwise any df index
# other than 0..n-1 would silently misalign rows and pad them with NaN. Passing
# the index also makes pandas raise if the two lengths ever differ.
generated_df = pd.DataFrame(labels, columns=columns, index=df.index)
full_df = pd.concat([df, generated_df], axis=1)
full_df

Unnamed: 0,Reports,No Finding-gt,Enlarged Cardiomediastinum-gt,Cardiomegaly-gt,Lung Lesion-gt,Lung Opacity-gt,Edema-gt,Consolidation-gt,Pneumonia-gt,Atelectasis-gt,Pneumothorax-gt,Pleural Effusion-gt,Pleural Other-gt,Fracture-gt,Support Devices-gt,filename,No Finding-gen,Enlarged Cardiomediastinum-gen,Cardiomegaly-gen,Lung Lesion-gen,Lung Opacity-gen,Edema-gen,Consolidation-gen,Pneumonia-gen,Atelectasis-gen,Pneumothorax-gen,Pleural Effusion-gen,Pleural Other-gen,Fracture-gen,Support Devices-gen
0,the cardiac silhouette and mediastinum size ar...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.xml,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,the cardiomediastinal silhouette is within nor...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,10.xml,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,both lungs are clear and expanded . heart and ...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,100.xml,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,there is xxxx increased opacity within the rig...,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1000.xml,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
4,interstitial markings are diffusely prominent ...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1001.xml,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3821,sternotomy sutures and bypass grafts have been...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,995.xml,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3822,heart size is normal and lungs are clear . no ...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,996.xml,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3823,calcified mediastinal xxxx . no focal areas of...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,997.xml,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3824,cardiomediastinal silhouette demonstrates norm...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,998.xml,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [19]:
# Compare the '-gt' and '-gen' label columns of full_df.
# NOTE(review): _calculate_metrics is not defined in this notebook; it presumably
# comes from one of the scripts loaded with %run above — confirm.
# The saved output shows five arrays with 14 values each.
acc, precision, recall, f1, roc_auc = _calculate_metrics(full_df)
acc, precision, recall, f1, roc_auc

(array([0.99895452, 0.99973863, 0.99947726, 1.        , 1.        ,
        1.        , 1.        , 0.99973863, 1.        , 0.99973863,
        1.        , 1.        , 1.        , 0.99973863]),
 array([0.99734396, 1.        , 0.99677419, 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 0.98148148,
        1.        , 1.        , 1.        , 1.        ]),
 array([1.        , 0.9974359 , 1.        , 1.        , 1.        ,
        1.        , 1.        , 0.98275862, 1.        , 1.        ,
        1.        , 1.        , 1.        , 0.99561404]),
 array([0.99867021, 0.9987163 , 0.99838449, 1.        , 1.        ,
        1.        , 1.        , 0.99130435, 1.        , 0.99065421,
        1.        , 1.        , 1.        , 0.9978022 ]),
 array([0.99913941, 0.99871795, 0.99968828, 1.        , 1.        ,
        1.        , 1.        , 0.99137931, 1.        , 0.99986748,
        1.        , 1.        , 1.        , 0.99780702]))

In [20]:
# Show the ROC-AUC array on its own (last element of the tuple displayed above).
roc_auc

array([0.99913941, 0.99871795, 0.99968828, 1.        , 1.        ,
       1.        , 1.        , 0.99137931, 1.        , 0.99986748,
       1.        , 1.        , 1.        , 0.99780702])