# Imports

In [None]:
import os
from collections import Counter
import importlib

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
%run ../../utils/__init__.py
config_logging(logging.INFO)

In [None]:
%run ../../datasets/common/constants.py

In [None]:
from medai.datasets import iu_xray, mimic_cxr
IU_DIR = iu_xray.DATASET_DIR
MIMIC_DIR = mimic_cxr.DATASET_DIR

# Find co-occurrences of opacity-related labels

## Load reports

In [None]:
# Dataset whose sentence-level CheXpert labels are loaded below.
# Swap the two lines to use MIMIC-CXR instead of IU X-Ray.
dataset_dir = IU_DIR
# dataset_dir = MIMIC_DIR

In [None]:
# `json` is otherwise only imported in a later cell; import it here so the
# notebook runs top to bottom.
import json

# Raw IU X-Ray report texts (always read from IU_DIR, independent of
# dataset_dir); print_raw_report looks reports up here by filename.
# JSON text is UTF-8, so do not rely on the platform's default encoding.
with open(os.path.join(IU_DIR, 'reports', 'reports.min.json'), encoding='utf-8') as f:
    RAW_REPORTS = json.load(f)
len(RAW_REPORTS)

In [None]:
# Sentence-level CheXpert labels for the currently selected dataset_dir.
reports_dir = os.path.join(dataset_dir, 'reports')
fpath = os.path.join(reports_dir, 'sentences_with_chexpert_labels.csv')
SENTENCES_DF = pd.read_csv(fpath)
SENTENCES_DF.head(3)

In [None]:
cols = ['Lung Opacity', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Lung Lesion']
# Keep a sentence when at least one of `cols` is labelled
# positive (1) or uncertain (-1).
is_mentioned = SENTENCES_DF[cols].isin([1, -1]).any(axis=1)
d = SENTENCES_DF.loc[is_mentioned]
print(len(d))
d.head(3)

In [None]:
# From here on, load report-level labels from MIMIC-CXR.
dataset_dir = MIMIC_DIR

In [None]:
# Report-level CheXpert labels for the currently selected dataset_dir.
reports_dir = os.path.join(dataset_dir, 'reports')
fpath = os.path.join(reports_dir, 'reports_with_chexpert_labels.csv')
REPORTS_DF = pd.read_csv(fpath)
REPORTS_DF.head(3)

In [None]:
# Alternative label set for this report-level filter:
# cols = ['Lung Opacity', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Lung Lesion']
cols = ['Edema']
# Keep reports where any of `cols` is positive (1) or uncertain (-1).
d = REPORTS_DF.loc[REPORTS_DF[cols].isin([1, -1]).any(axis=1)]
# Stricter variant: at least two of `cols` uncertain.
# d = d.loc[(d[cols] == -1).sum(axis=1) >= 2]
print(len(d))
d.head(3)

In [None]:
# Texts of the reports currently selected in `d`: show the count and the first ten.
# NOTE(review): `l` is easy to confuse with `1`/`I` (PEP 8 E741); consider
# renaming it if no later cell depends on this global.
l = list(d['Reports'])
len(l), l[:10]

## Load images

In [None]:
%run ../../utils/images.py

In [None]:
import json

In [None]:
importlib.reload(iu_xray)

In [None]:
# IU X-Ray dataset built with 'all' (presumably every split — confirm in iu_xray);
# find_idxs and plot_images_for_report read samples from this global.
dataset = iu_xray.IUXRayDataset('all')
len(dataset)

In [None]:
def find_idxs(target_report_id, samples=None):
    """Return the indices of every sample whose report file is ``target_report_id``.

    Args:
        target_report_id: report filename to look for (e.g. ``'1000.xml'``).
        samples: sequence of sample dicts, each with a ``'report_filename'``
            key. Defaults to ``dataset.samples``, resolved at call time.

    Returns:
        List of matching indices in ascending order (empty if none match).
    """
    if samples is None:
        samples = dataset.samples
    return [
        idx
        for idx, sample in enumerate(samples)
        if sample['report_filename'] == target_report_id
    ]

In [None]:
def print_diseases(row, target, diseases=None):
    """Return the comma-separated names of diseases whose label is in ``target``.

    Despite its name, this does not print anything; it returns the string.

    Args:
        row: a labels row (e.g. one row of REPORTS_DF) indexable by disease name.
        target: a single label value (e.g. 1, -1 or 0), or a list, tuple, set
            or frozenset of such values.
        diseases: names of the label columns to inspect, in output order.
            Defaults to ``CHEXPERT_DISEASES``.

    Returns:
        Matching disease names joined with ``', '`` (empty string if none match).
    """
    if diseases is None:
        diseases = CHEXPERT_DISEASES
    # Sets must count as collections too; otherwise a set would be wrapped as
    # one value and nothing would ever match it.
    if not isinstance(target, (list, tuple, set, frozenset)):
        target = (target,)
    return ', '.join(
        disease
        for disease in diseases
        if row[disease] in target
    )

In [None]:
def print_raw_report(report_id):
    """Print the FINDINGS and IMPRESSION sections of a raw report, framed by '---'.

    Does nothing if ``report_id`` is not a key of RAW_REPORTS.
    """
    report = RAW_REPORTS.get(report_id)
    if report is not None:
        separator = '---'
        print(separator)
        for label, section in (('FINDINGS: ', 'findings'), ('IMPRESSION: ', 'impression')):
            print(label, report.get(section))
        print(separator)

In [None]:
def plot_images_for_report(report_id):
    """Print a report's text and CheXpert labels, then plot all of its images.

    Args:
        report_id: report filename, matched against ``sample['report_filename']``
            in ``dataset`` and against the ``filename`` column of REPORTS_DF
            (e.g. ``'1000.xml'``).

    If no image belongs to the report, prints 'No items found' and returns.
    If REPORTS_DF has no row for the report, the text/labels section is
    skipped with a message instead of raising IndexError.
    """
    items = [
        dataset[idx]
        for idx in find_idxs(report_id)
    ]

    if len(items) == 0:
        print('No items found')
        return

    # Print report info
    print(report_id)
    rows = REPORTS_DF.loc[REPORTS_DF['filename'] == report_id]
    if len(rows) == 0:
        # Can happen if REPORTS_DF was loaded from a different dataset than
        # `dataset`; previously this crashed on rows.iloc[0].
        print('No report row found!')
    else:
        if len(rows) > 1:
            print('More than one row!')
        row = rows.iloc[0]
        print('---')
        print(row['Reports'])
        print('---')
        print('Pos: ', print_diseases(row, 1))
        print('Unc: ', print_diseases(row, -1))
        print('Neg: ', print_diseases(row, 0))

    # Plot images side by side, one subplot per image
    n_rows = 1
    n_cols = len(items)

    plt.figure(figsize=(15, 5))

    for plt_idx, item in enumerate(items):
        plt.subplot(n_rows, n_cols, plt_idx + 1)
        # permute(1, 2, 0) moves the first axis last for imshow; presumably
        # item.image is (C, H, W) — confirm against the dataset class.
        plt.imshow(tensor_to_range01(item.image).permute(1, 2, 0))
        plt.title(item.image_fname)

## Display some images

In [None]:
plot_images_for_report('1000.xml')

In [None]:
plot_images_for_report('1001.xml')

In [None]:
# print_raw_report('1012.xml')
plot_images_for_report('1012.xml')

In [None]:
plot_images_for_report('983.xml')

In [None]:
rid = '984.xml'
print_raw_report(rid)
plot_images_for_report(rid)

In [None]:
rid = '877.xml'
print_raw_report(rid)
plot_images_for_report(rid)

In [None]:
d['filename']

## Show multiple samples

In [None]:
# `np` is not imported anywhere visible in this notebook; import it explicitly.
import numpy as np

cols = [
    'Lung Opacity', 'Atelectasis', 'Consolidation', 'Pneumonia', 'Lung Lesion',
    # 'Edema',
    # 'Enlarged Cardiomediastinum', 'Cardiomegaly',
]
n_cols = len(cols)

# Co-occurrence matrix. Cell [i, j] counts, among the reports where cols[i]
# is positive (1) or uncertain (-1), those where cols[j] is also positive or
# uncertain, formatted as "shared/total (pct%)". The diagonal is left as 0.
mentioned = REPORTS_DF[cols].isin([1, -1])

array = np.zeros((n_cols, n_cols), dtype=object)

for i, base in enumerate(cols):
    # Loop-invariant for the inner loop, so compute it once per base label.
    base_mask = mentioned[base]
    total = int(base_mask.sum())
    for j, other in enumerate(cols):
        if i == j:
            continue
        shared = int((base_mask & mentioned[other]).sum())
        # A base label never mentioned in any report would divide by zero.
        pct = f'{shared / total * 100:.0f}%' if total else 'n/a'
        array[i, j] = f'{shared:,}/{total:,} ({pct})'
coocurrences = pd.DataFrame(array, columns=cols, index=cols)
coocurrences