# Imports

In [None]:
import os
from collections import Counter

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
%run ../utils/__init__.py
config_logging(logging.INFO)

In [None]:
%run ../datasets/common/constants.py

In [None]:
# Every entry of CHEXPERT_DISEASES except the first one
# (presumably the first entry is the "No Finding" label -- TODO confirm).
ACTUAL_DISEASES = CHEXPERT_DISEASES[1:]
ACTUAL_DISEASES

# Load model (DELETEME)

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
run_id = RunId('1215_174443', debug=False, task='cls')

In [None]:
compiled_model = load_compiled_model(run_id)
compiled_model.metadata['model_kwargs']

# Load data (DELETEME)

In [None]:
%run ../datasets/__init__.py

In [None]:
# Options forwarded verbatim to prepare_data_report_generation.
dataset_kwargs = dict(
    dataset_name='iu-x-ray',
    hierarchical=True,
    dataset_type='all',
    image_size=(256, 256),
    frontal_only=False,
    max_samples=None,
)
dataloader = prepare_data_report_generation(**dataset_kwargs)
len(dataloader.dataset)

# Analyze sentence vs diseases

In [None]:
from medai.datasets.iu_xray import DATASET_DIR as IU_DIR

In [None]:
# Load the per-sentence table used by the rest of the analysis; later cells
# read its 'sentence' and 'appearances' columns and one column per disease
# in ACTUAL_DISEASES.
fpath = os.path.join(IU_DIR, 'reports', 'sentences_with_chexpert_labels.csv')
df = pd.read_csv(fpath)
print(len(df))
df.head()

In [None]:
TOTAL_APPEARANCES = df['appearances'].sum()
TOTAL_APPEARANCES

## Utils

In [None]:
def remove_non_covered_info(df):
    """Drop the sentences flagged as obfuscated or as time-related.

    Only rows whose 'obfuscated' and 'time' flags are both False are kept.
    If one of the flag columns is missing, a message is printed and that
    filter is skipped.

    Args:
        df: DataFrame of sentences, optionally carrying the boolean
            'obfuscated' and 'time' columns.

    Returns:
        The filtered DataFrame; the input is not modified.
    """
    def _keep_rows_where(frame, key, value):
        # Check the frame that is actually being filtered, not the enclosing
        # `df`, so the helper does not depend on closure state.
        if key not in frame.columns:
            print(f'Key not found in df: {key}')
            return frame
        return frame.loc[frame[key] == value]

    df = _keep_rows_where(df, 'obfuscated', False)
    df = _keep_rows_where(df, 'time', False)
    return df

In [None]:
def collect_sentences_for_disease(target_disease, remove_other=True, remove_useless_info=True):
    """Group the sentences of the global `df` by their label for one disease.

    Also prints a list of (label value, number of sentences) pairs.

    Args:
        target_disease: Name of the disease column to group by. When
            `remove_other` is True it must be one of ACTUAL_DISEASES.
        remove_other: If True, keep only the sentences whose label is -2
            for every other disease in ACTUAL_DISEASES (presumably -2
            means "not mentioned" -- TODO confirm against the labeler).
        remove_useless_info: If True, filter the sentences with
            remove_non_covered_info first.

    Returns:
        A Series indexed by the label values of `target_disease`; each
        entry is the list of matching sentences sorted by length.

    Raises:
        ValueError: If `remove_other` is True and `target_disease` is not
            in ACTUAL_DISEASES.
    """
    only_df = df

    if remove_other:
        # Keep only sentences that do not mention other diseases
        other_diseases = list(ACTUAL_DISEASES)
        other_diseases.remove(target_disease)
        only_df = only_df.loc[(only_df[other_diseases] == -2).all(axis=1)]

    if remove_useless_info:
        only_df = remove_non_covered_info(only_df)

    grouped = only_df.groupby(target_disease)['sentence'].apply(
        lambda sentences: sorted(sentences, key=len),
    )
    # Series.iteritems() was removed in pandas 2.0; items() works on every version.
    print([(valuation, len(sentences)) for valuation, sentences in grouped.items()])

    return grouped

## Annotate non-covered info

### Obfuscated sentences

Sentences that contain the token `xxxx`.

In [None]:
def contains_obfuscated(sentence):
    """Return True if the text 'xxxx' occurs anywhere in `sentence`."""
    marker = 'xxxx'
    return marker in sentence

In [None]:
# Flag every sentence that contains the obfuscation token.
df['obfuscated'] = df['sentence'].map(contains_obfuscated)
df.head()

In [None]:
# Share of the corpus affected by obfuscation, weighted by 'appearances'.
obf_df = df[df['obfuscated']]
obf_appears = obf_df['appearances'].sum()
n_sentences = len(obf_df)
perc = obf_appears / TOTAL_APPEARANCES * 100
print(
    'Obfuscation: '
    f'sentences={n_sentences:,}, appearances={obf_appears:,} ({perc:.2f}%)'
)

### Time-related sentences

Sentences that refer to an earlier study or to the patient's history, e.g. "shown again", "given history", etc.

In [None]:
# Text fragments treated as a reference to time or to a previous study.
_TIME_MENTIONS = {
    'unchanged',
    'improved',
    'given history',
    'previous',
    'with prior',
    'no change',
    'prior exam',
    'consistent with prior',
    'prior study',
    'compared to prior',
    'from the prior',
    'prior',
    'has been removed',
    'have been removed',
    'interval',
    'persistent',
    'remain',
    'stable',
    'now',
    'again',
}

In [None]:
# Inspect every sentence that contains 'again'.
l = df.loc[df['sentence'].str.contains('again'), 'sentence'].tolist()
len(l), l

In [None]:
def mentions_time(sentence):
    """Return True if any entry of _TIME_MENTIONS occurs in `sentence`.

    NOTE(review): this is a plain substring test, so an entry also matches
    inside a longer word (e.g. 'now' matches 'known').
    """
    for time_mention in _TIME_MENTIONS:
        if time_mention in sentence:
            return True
    return False

In [None]:
# Flag every sentence that contains one of the time-related fragments.
df['time'] = df['sentence'].map(mentions_time)
df.head()

In [None]:
# Share of the corpus flagged as time-related, weighted by 'appearances'.
time_df = df[df['time']]
time_appears = time_df['appearances'].sum()
n_sentences = len(time_df)
perc = time_appears / TOTAL_APPEARANCES * 100
print(
    'Time: '
    f'sentences={n_sentences:,}, appearances={time_appears:,} ({perc:.2f}%)'
)

## Heart

In [None]:
grouped = collect_sentences_for_disease('Cardiomegaly')

In [None]:
grouped[1]

In [None]:
grouped = collect_sentences_for_disease('Enlarged Cardiomediastinum')

In [None]:
grouped[1]

## Lungs

In [None]:
grouped = collect_sentences_for_disease('Lung Opacity')

In [None]:
# Sentences labeled 1 that contain none of the listed location/pattern words.
[
    s
    for s in grouped[1]
    if not any(
        k in s
        for k in (
            'right', 'left', 'apic',
            'bilateral', 'bibasilar',
            'interstitial', 'perihilar',
        )
    )
]

In [None]:
grouped = collect_sentences_for_disease('Pleural Other', True, True)

In [None]:
grouped[1]

## Others

In [None]:
grouped = collect_sentences_for_disease('Support Devices')

In [None]:
grouped[1]

## Sentences with more than one disease

In [None]:
# Rows where more than one disease column holds the value 1 or -1.
subdf = df.loc[df[ACTUAL_DISEASES].isin([1, -1]).sum(axis=1) > 1]
subdf.head(1)

In [None]:
# Appearance-weighted share of the sentences that involve several diseases.
n_appear = subdf['appearances'].sum()
perc = n_appear / TOTAL_APPEARANCES * 100
print(
    'More than 1 disease: '
    f'sentences={len(subdf):,}, appearances={n_appear:,} ({perc:.2f}%)'
)

In [None]:
list(subdf['sentence'])

# Debug RG-templates model

In [None]:
import torch

In [None]:
%run ../datasets/vocab/__init__.py
%run ../utils/nlp.py

In [None]:
vocab = load_vocab('iu_xray')
len(vocab)

In [None]:
report_reader = ReportReader(vocab)

In [None]:
%run ../models/report_generation/templates/__init__.py

In [None]:
model = create_rg_template_model('chex-v1', ACTUAL_DISEASES, vocab)
model

In [None]:
# Two hand-written label rows of 13 entries each
# (presumably one entry per disease in ACTUAL_DISEASES -- TODO confirm).
labels = torch.tensor(
    [
        [1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1],
        [1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0],
    ],
    dtype=torch.long,
)
labels.size()

In [None]:
# Run the template model on the labels and convert each output to text.
reports = model(labels)
list(map(report_reader.idx_to_text, reports))

# Debug chexpert-labeler

Check that made-up sentences are labeled correctly by the CheXpert labeler.

In [None]:
%run ../metrics/report_generation/chexpert.py

In [None]:
# Hand-written sentences to run through the labeler; only the uncommented
# entries are used, the commented-out ones are kept as alternatives to try.
sentences = [
#     'there are pulmonary nodules or mass identified',
#     'one or more airspace opacities can be seen',
#     'pulmonary edema is seen',
#     'there is focal consolidation',
#     'there is evidence of pneumonia',
#     'no atelectasis',
#     'pleural effusion is seen',
#     'pleural thickening is present',
#     'a fracture is identified',
    'a device is seen',
]

In [None]:
temp_df = pd.DataFrame(sentences, columns=['s'])

In [None]:
labels = apply_labeler_to_column(temp_df, 's')
labels