# Try abnormality matcher

In [None]:
import os
from collections import Counter, defaultdict
import importlib

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Notebook-wide figure defaults: white figure background, (15, 5) figure size.
matplotlib.rcParams.update({
    'figure.facecolor': 'white',
    'figure.figsize': (15, 5),
})

In [None]:
import pandas as pd

# Do not cap the number of columns shown when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

In [None]:
%run ../../utils/__init__.py
config_logging(logging.INFO)

In [None]:
%run ../../datasets/common/constants.py
%run ../../datasets/vocab/__init__.py

In [None]:
from medai.datasets import iu_xray, mimic_cxr
# Dataset directories, taken from each dataset module's DATASET_DIR attribute.
# One of them is selected as `dataset_dir` further down.
IU_DIR = iu_xray.DATASET_DIR
MIMIC_DIR = mimic_cxr.DATASET_DIR

# Inspect sentences

In [None]:
%run ../../metrics/report_generation/abn_match/chexpert.py
%run ../../metrics/report_generation/abn_match/textray.py
%run ../../metrics/report_generation/chexpert.py
%run -n ../../eval_report_generation_chexpert_labeler.py

In [None]:
# Dataset whose reports are inspected in the rest of the notebook.
# Swap the two lines to use MIMIC_DIR instead of IU_DIR.
dataset_dir = IU_DIR
# dataset_dir = MIMIC_DIR

In [None]:
# Load the vocabulary from the dataset's 'reports' folder ('v4' is presumably
# a vocabulary version — confirm in load_vocab, which is defined outside this
# notebook). The trailing expression displays the vocabulary size.
vocab = load_vocab(os.path.join(dataset_dir, 'reports'), 'v4')
len(vocab)

In [None]:
# Read the per-sentence CSV from the 'reports' folder and preview three rows.
# Later cells read a 'sentence' column and per-label columns from it.
fpath = os.path.join(dataset_dir, 'reports', 'sentences_with_chexpert_labels.csv')
SENTENCES_DF = pd.read_csv(fpath)
SENTENCES_DF.head(3)

In [None]:
def add_suffix(col):
    """Return ``col`` with '-gt' appended if it is in CHEXPERT_LABELS, else ``col`` unchanged."""
    return f'{col}-gt' if col in CHEXPERT_LABELS else col
# Rename the columns in place through add_suffix, then preview the result.
SENTENCES_DF.rename(columns=add_suffix, inplace=True)
SENTENCES_DF.head(3)

In [None]:
# Plain Python list built from the 'sentence' column; the trailing expression
# displays how many sentences there are.
sentences = list(SENTENCES_DF['sentence'])
len(sentences)

In [None]:
# Labeler under evaluation, built on the vocabulary and placed on CPU.
labeler = ChexpertLighterLabeler(vocab, use_idx=False, device='cpu')
# Alternative labeler to try instead:
# labeler = TextRayLabeler(vocab, device='cpu')
labeler

In [None]:
%%time

# Label a single sentence (index 13) and display the shape of the labels.
# NOTE(review): a later cell asserts len(labels) == len(SENTENCES_DF), which a
# single-sentence call presumably cannot satisfy — confirm whether the whole
# `sentences` list is meant to be labeled before running those cells.
labels, lung_locations = labeler.label_report(sentences[13])
labels.shape

In [None]:
# Labels returned by label_report.
labels

In [None]:
# The sentence that was passed to label_report.
sentences[13]

In [None]:
# Second value returned by label_report.
lung_locations

In [None]:
# When the labeler returns only 13 label columns, prepend an all-zero column
# (presumably the missing "No Finding" label — confirm) so there are 14.
if labels.size(1) == 13:
    # new_zeros inherits dtype and device from `labels`, so torch.cat never
    # has to combine a float column with labels of another dtype.
    nf_column = labels.new_zeros((labels.size(0), 1))
    labels = torch.cat((nf_column, labels), dim=1)
labels.size()

In [None]:
# Move the labels to CPU and convert them to a NumPy array; there must be
# exactly one row of labels per row of SENTENCES_DF.
labels = labels.cpu().numpy()
assert len(labels) == len(SENTENCES_DF)

In [None]:
# Place the generated labels next to the columns of SENTENCES_DF, sharing its
# index. Column names come from labels_with_suffix('gen') (defined outside
# this notebook); later cells read them as '<label>-gen'.
full_df = pd.concat([
    SENTENCES_DF,
    pd.DataFrame(labels, index=SENTENCES_DF.index, columns=labels_with_suffix('gen')),
], axis=1)

# The column-wise concat must not add or drop rows.
assert len(full_df) == len(labels)

In [None]:
# Copy of full_df with every -2 replaced by 0 and every -1 replaced by 1
# (presumably "unmentioned" -> negative and "uncertain" -> positive — confirm
# against the labeler's value convention).
full_df2 = full_df.replace({-2: 0, -1: 1})

In [None]:
# Metrics computed by _calculate_metrics (defined outside this notebook).
# Every [1:] slice skips the first entry (presumably the "No Finding" label —
# confirm) before it is displayed and averaged.
acc, precision, recall, f1, roc_auc, pr_auc = _calculate_metrics(full_df2)
precision[1:], precision[1:].mean(), recall[1:], recall[1:].mean(), f1[1:], f1[1:].mean()

In [None]:
# Inspect sentences where ground truth and generated labels disagree for one
# disease (entry 9 of labeler.diseases).
target = labeler.diseases[9]
colgt = f'{target}-gt'
colgen = f'{target}-gen'

d = full_df
# Alternative filter: generated label is -1 while the ground truth is not.
# d = d[((d[colgt] != -1) & (d[colgen] == -1))]
# Active filter: ground truth is -1 while the generated label is not.
d = d[((d[colgt] == -1) & (d[colgen] != -1))]
d = d[['sentence', colgt, colgen]]
print(len(d))
d.head(5)

In [None]:
# Selected sentences in sorted order.
sorted(d['sentence'])

# Inspect modifiers

In [None]:
from collections import Counter
from nltk.corpus import stopwords

In [None]:
# NLTK English stopwords plus three punctuation tokens.
STOPWORDS = set(stopwords.words('english')) | {',', '.', '/'}
len(STOPWORDS)

In [None]:
def to_word_counter(sentences, remove_stop=False, stop_words=None):
    """Count whitespace-separated words over a collection of sentences.

    Args:
        sentences: iterable of strings; each one is split with ``str.split()``.
        remove_stop: if True, words found in ``stop_words`` are not counted.
        stop_words: container of words to skip when ``remove_stop`` is True.
            Defaults to the notebook-level ``STOPWORDS`` set.

    Returns:
        List of ``(word, count)`` tuples sorted by decreasing count; words
        with the same count keep the order in which they were first seen.
    """
    if not remove_stop:
        # Nothing is filtered; the global STOPWORDS is never touched.
        skipped = ()
    elif stop_words is None:
        skipped = STOPWORDS
    else:
        skipped = stop_words

    counts = Counter(
        word
        for sentence in sentences
        for word in sentence.split()
        if word not in skipped
    )
    # most_common() is a stable sort by count, descending.
    return counts.most_common()

In [None]:
# Re-read the sentences CSV: SENTENCES_DF was renamed in place earlier in the
# notebook, and the cells below index columns by the plain label name.
fpath = os.path.join(dataset_dir, 'reports', 'sentences_with_chexpert_labels.csv')
SENTENCES_DF = pd.read_csv(fpath)
SENTENCES_DF.head(3)

In [None]:
# Keep only the sentences whose 'Lung Opacity' value equals 1.
target = 'Lung Opacity'
d = SENTENCES_DF
d = d[d[target] == 1]
# Alternative filter that also keeps the rows valued -1:
# d = d[((d[target] == 1) | (d[target] == -1))]
print(len(d))
d.head(3)

In [None]:
# Selected sentences in sorted order.
sorted(d['sentence'])

In [None]:
# Word frequencies over the selected sentences with stopwords removed;
# display the ten most frequent.
wc = to_word_counter(d['sentence'], remove_stop=True)
wc[:10]

In [None]:
# Full (word, count) list.
wc

In [None]:
# TODO: candidate word lists that could be useful; nothing in this notebook
# reads them yet.
AMOUNTS = [
    'innumerable',
    'multiple',
    'three',
    'a few',
]
SIZE = [
    'NUMBER',
    'large',
    'small',
    'moderate sized',
    'width',
    'diameter',
]
COMPARISON = [
    'than',  # e.g. right larger than left
]

In [None]:
# Words and phrases seen around lung-opacity mentions. Notes of the form
# `stem*` are kept from the original list; they appear to mark word families.
_OPACITY_MODIFIERS = {
    # Side and lobe (lung*, lobe*)
    'left', 'right', 'lobe', 'lobes',
    # Bases
    'basal', 'base', 'bases', 'basilar', 'bibasilar',
    # lateral*
    'bilaterally', 'bilateral', 'lateral',
    # Zones
    'lower', 'upper', 'midlung', 'middle', 'central',
    'biapical', 'apex', 'apical',
    # Other regions (hilar* | hilum, lingula*)
    'perihilar', 'costophrenic', 'retrocardiac', 'lingula', 'lingular',
    'anterior', 'posterior',

    # Degree
    'mild', 'minimal', 'slightly', 'small',
    # Pattern and distribution (segmental*)
    'patchy', 'streaky', 'bandlike', 'reticular',
    'focal', 'diffuse', 'scattered', 'subsegmental',
    'parenchymal', 'interstitial', 'alveolar',
    'chronic', 'prominent',

    # Appearance
    'calcified', 'discrete', 'poorly defined', 'vague',
    'subtle', 'asymmetric', 'strandy', 'shaped', 'rotated',
    'irregular', 'coarse', 'residual', 'maximal thickness',
    'thin', 'smooth',
    'ring shaped', 'wedge shaped', 'lobulated', 'central lucency',
}

In [None]:
# Degree words (adjective and adverb forms) collected for mediastinum-related
# mentions.
_MEDIAST_MODIFIERS = {
    'minimal',
    'borderline',
    'mild',
    'mildly',
    'moderate',
    'moderately',
    'slight',
    'slightly',
    'significantly',
    'severe',
    'severely',
}

In [None]:
import re

# Match modifiers as whole words/phrases. A plain substring test (`m in s`)
# also fires inside unrelated words, e.g. 'thin' in 'within', 'right' in
# 'upright' or 'base' in 'baseline'. Matching stays case-sensitive.
_opacity_modifier_re = re.compile(
    r'\b(?:'
    + '|'.join(re.escape(m) for m in sorted(_OPACITY_MODIFIERS))
    + r')\b'
)

# Sorted sentences that contain at least one opacity modifier.
[
    s for s in sorted(d['sentence'])
    if _opacity_modifier_re.search(s)
]

In [None]:
# Scratch note: sample multi-word location phrases. The tuple is only
# displayed by the notebook, never stored.
'right middle lobe', 'right lower lobe', 'left base'