# Imports

In [None]:
# %env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch

In [None]:
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
pd.options.display.max_columns = None

In [None]:
%run ../utils/__init__.py

In [None]:
config_logging(logging.WARN)

# Utils

In [None]:
from pycocoevalcap.bleu.bleu_scorer import BleuScorer
from pycocoevalcap.rouge.rouge import Rouge
from sklearn.metrics import precision_recall_fscore_support as prf1s

In [None]:
def calculate_rouge(gt, gen):
    """Score a single generated report against its ground truth with ``Rouge``.

    Both ``gt`` and ``gen`` must be ``str`` (enforced with ``assert``). The
    generated text is passed as the one-element first argument of
    ``Rouge.calc_score`` and the ground truth as the one-element second
    argument; whatever ``calc_score`` returns is returned unchanged.
    """
    for text in (gt, gen):
        assert isinstance(text, str)
    return Rouge().calc_score([gen], [gt])

In [None]:
def calculate_bleu(gt, gen):
    """Score a single generated report against its ground truth with ``BleuScorer``.

    Both ``gt`` and ``gen`` must be ``str`` (enforced with ``assert``). A fresh
    ``BleuScorer(4)`` receives the pair ``(gen, [gt])`` and the first element
    of its ``compute_score()`` result is returned (presumably the BLEU-1..4
    values -- confirm against pycocoevalcap).
    """
    for text in (gt, gen):
        assert isinstance(text, str)

    bleu_scorer = BleuScorer(4)
    bleu_scorer += (gen, [gt])
    # compute_score() yields a pair; only its first element is exposed here.
    ngram_scores, _ = bleu_scorer.compute_score()
    return ngram_scores

# Examples vs metrics

## Utils

In [None]:
%run ../metrics/report_generation/chexpert.py
%run ../metrics/report_generation/mirqi.py
%run -n ../eval_report_generation_mirqi.py
# %run ../metrics/report_generation/nlp/rouge.py
# %run ../metrics/report_generation/nlp/bleu.py

In [None]:
def calculate_chexpert(gt, gen, verbose=False, diseases=None):
    """Compute per-label precision/recall/F1 from the CheXpert labeler output.

    Args:
        gt: ground-truth report text.
        gen: generated report text.
        verbose: if True, print the binarized label matrix.
        diseases: optional iterable of names from ``CHEXPERT_DISEASES``; when
            given, only those label columns are scored.

    Returns:
        ``(precision, recall, f1, raw_labels)`` where the first three come from
        ``precision_recall_fscore_support`` and ``raw_labels`` is the labeler
        output before binarization and before any ``diseases`` filtering.
    """
    # Row 0 is the generated report, row 1 the ground truth (shape: 2, 14).
    raw_labels = apply_labeler_to_column([gen, gt])

    # Binarize on a copy so the raw values can still be returned:
    # -2 becomes 0 and -1 becomes 1 (presumably "not mentioned" and
    # "uncertain" respectively -- confirm against the labeler).
    binary_labels = raw_labels.copy()
    for raw_value, binary_value in ((-2, 0), (-1, 1)):
        binary_labels[binary_labels == raw_value] = binary_value

    if verbose:
        print('Chexpert labels: \n', binary_labels)

    if diseases is not None:
        selected_columns = [CHEXPERT_DISEASES.index(name) for name in diseases]
        binary_labels = binary_labels[:, selected_columns]

    generated_row = binary_labels[0, :]
    target_row = binary_labels[1, :]
    # A leading axis of size 1 is added so each label is scored separately
    # over a single sample; ground truth goes first (y_true), then generated.
    precision, recall, f1, _ = prf1s(
        np.expand_dims(target_row, 0),
        np.expand_dims(generated_row, 0),
        zero_division=0,
    )
    return precision, recall, f1, raw_labels

In [None]:
def calculate_mirqi(gt, gen, verbose=False):
    """Run MIRQI on one (ground truth, generated) report pair.

    Args:
        gt: ground-truth report text.
        gen: generated report text.
        verbose: if True, print the extracted attributes.

    Returns:
        ``(scores, attributes)``: the result of ``MIRQI`` and the attribute
        list, where index 0 belongs to the generated report and index 1 to
        the ground truth.
    """
    # Reports are submitted as [generated, ground truth] (shape: 2, 1).
    raw_attributes = _call_mirqi_for_reports([gen, gt])
    attributes = _attributes_to_list(raw_attributes.squeeze())

    if verbose:
        print('MIRQI attributes: \n', attributes)

    gt_attributes = attributes[1]
    gen_attributes = attributes[0]
    # MIRQI takes the ground-truth attributes first, then the generated ones.
    scores = MIRQI([gt_attributes], [gen_attributes])

    return scores, attributes

In [None]:
def calculate_metrics(gt, gen, diseases=None, only_present=True, verbose=False):
    """Compute BLEU, ROUGE, CheXpert and MIRQI metrics for one generated report.

    Args:
        gt: ground-truth report text.
        gen: generated report text.
        diseases: optional list of names from ``CHEXPERT_DISEASES``; forwarded
            to ``calculate_chexpert`` to restrict which labels are scored.
        only_present: if True, precision/recall/F1 are averaged only over the
            labels whose raw values do not sum to -4 across the two reports
            (i.e. labels that are not -2 in both). If False, they are averaged
            over every scored label.
        verbose: forwarded to ``calculate_chexpert`` and ``calculate_mirqi``.

    Returns:
        dict with keys ``'bleu'``, ``'rouge'``, ``'f1'``, ``'prec'``,
        ``'recall'``, ``'MIRQI-f'``, ``'MIRQI-p'`` and ``'MIRQI-r'``.
    """
    results = {}

    bleu = calculate_bleu(gt, gen)
    rouge = calculate_rouge(gt, gen)

    results.update({
        'bleu': np.mean(bleu),
        'rouge': rouge,
    })

    precision, recall, f1, raw_labels = calculate_chexpert(
        gt, gen, verbose=verbose, diseases=diseases,
    )

    if only_present:
        if diseases is not None:
            # calculate_chexpert returns the raw labels for every disease, while
            # precision/recall/f1 only cover the requested ones; select the same
            # columns so the mask lines up with the metric arrays.
            diseases_idx = [CHEXPERT_DISEASES.index(d) for d in diseases]
            raw_labels = raw_labels[:, diseases_idx]

        # A column sum of -4 means the label is -2 in both reports.
        present = raw_labels.sum(axis=0) != -4
        precision = precision[present]
        recall = recall[present]
        f1 = f1[present]

    # NOTE: if the mask leaves no labels, numpy's mean of an empty array is nan.
    results.update({
        'f1': f1.mean(),
        'prec': precision.mean(),
        'recall': recall.mean(),
    })

    mirqi_values, _ = calculate_mirqi(gt, gen, verbose=verbose)
    for key in ('MIRQI-f', 'MIRQI-p', 'MIRQI-r'):
        # Each MIRQI entry is indexed per report pair; only one pair is scored.
        results[key] = mirqi_values[key][0]

    return results

## Compute samples

In [None]:
# diseases = ['Cardiomegaly', 'Pneumothorax']

In [None]:
gt = 'heart size is mildly enlarged . small right pneumothorax is seen .'
gens = [
    'heart size is normal . no pneumothorax is seen .',
    'mild cardiomegaly . pneumothorax on right lung .',
    'mild cardiomegaly . pneumothorax on right lung , bibasilar opacities and edema .',
    'cardiac silhouette is moderately enlarged . left pneumothorax observed .',
    'the cardiac silhouette is enlarged . no pneumothorax .',
    'the cardiac silhouette is enlarged . no pneumothorax is seen.',
#     'the cardiac silhouette is enlarged . pneumothorax observed .',
#     'cardiac silhouette is mildly enlarged . small pneumothorax on right side .',
    # 'cardiomediastinal silhouettes are within normal limits . lungs are clear without focal consolidation , pneumothorax , or pleural effusion . stable calcified granulomas . bony thorax is unremarkable .',
]

In [None]:
df = pd.DataFrame.from_records([
    calculate_metrics(gt, gen)
    for gen in gens
], index=gens)
df

In [None]:
scores, attributes = calculate_mirqi(gt, gens[1], verbose=True)
scores

In [None]:
MIRQI([attributes[1]], [attributes[0]])

In [None]:
MIRQI_v2([attributes[1]], [attributes[0]])

In [None]:
attributes[1]

In [None]:
gt = 'heart size is mildly enlarged . small right pneumothorax is seen . bibasilar opacities .'
gens = [
    'heart size is normal . no pneumothorax is seen . no opacities .',
    'mild cardiomegaly . pneumothorax on right side .',
    'cardiac silhouette is moderately enlarged . left pneumothorax observed . patchy opacities .',
#     'the cardiac silhouette is enlarged . pneumothorax observed .',
#     'cardiac silhouette is mildly enlarged . small pneumothorax on right side .',
]

In [None]:
df = pd.DataFrame.from_records([
    calculate_metrics(gt, gen)
    for gen in gens
], index=gens)
df

In [None]:
gt = 'heart size is mildly enlarged . bibasilar interstitial opacities . no pneumothorax is seen .'
gens = [
    # 'heart size is normal . no opacities . no pneumothorax is seen .',
    'the cardiac silhouette is enlarged . pneumothorax is not observed . multiple opacities seen.',
#     'heart size is moderately enlarged . no pneumothorax is seen .',
#     'pneumothorax on right side . cardiac silhouette is mildly enlarged .',
#     'cardiac silhouette is mildly enlarged . pneumothorax on right side .',
]

In [None]:
calculate_metrics(gt, gens[0], verbose=True)

In [None]:
calculate_metrics(gt, gen1) #, diseases)

In [None]:
calculate_metrics(gt, gen2)

In [None]:
calculate_metrics(gt, gen3)

In [None]:
calculate_metrics(gt, gen4)

In [None]:
calculate_metrics(gt, gen5)

# Sample generated reports

In [None]:
%run ../metrics/__init__.py
%run ../utils/files.py

In [None]:
# run_name = '0612_035549'
# run_name = '0602_034645'
# run_name = '0601_031606'
run_name = '0612_233628'
# run_name = '0617_143104'
run_id = RunId(run_name, False, 'rg')
run_id

In [None]:
df = load_rg_outputs(run_id, free=True, labeled=True)
df.head()

## Check commonly generated reports

In [None]:
from collections import Counter

In [None]:
_LUNG_RELATED_DISEASES = (
    'Lung Lesion',
    'Lung Opacity',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
)

In [None]:
%run ../utils/nlp.py
%run ../datasets/common/constants.py

In [None]:
ACTUAL_DISEASES = CHEXPERT_DISEASES[1:]
actual_diseases_gen = [f'{d}-gen' for d in ACTUAL_DISEASES]
actual_diseases_gt = [f'{d}-gt' for d in ACTUAL_DISEASES]
lung_diseases_gen = [f'{d}-gen' for d in _LUNG_RELATED_DISEASES]

In [None]:
d = df
d = d.loc[d['dataset_type'] == 'train']
d = d.loc[(d[actual_diseases_gen] == 0).all(axis=1)]
# d = d.loc[(d[actual_diseases_gen] == 1).any(axis=1)]
# d = d.loc[(d[lung_diseases_gen] == 0).all(axis=1)]
d.head()

In [None]:
reports = list(d['generated'])
len(reports), len(set(reports))

In [None]:
reports_appearances = sorted(
    Counter(reports).items(),
    # key=lambda x: (1428 - len(x[0])) * 300000 + x[1],
    key=lambda x: x[1],
    reverse=True,
)
reports_appearances[:5]

In [None]:
reports_appearances[5:30]

In [None]:
s = [r for r in d['ground_truth'] if r.startswith('in comparison with the study of xxxx')]
len(s), len(d)

In [None]:
r = reports_appearances[1][0]
r

In [None]:
d.loc[d['ground_truth'] == 'no pneumonia , vascular congestion , or pleural effusion .'].head(2)# [actual_diseases_gen].head(1)

## Use NLP metrics

In [None]:
from pycocoevalcap.cider.cider_scorer import CiderScorer

In [None]:
d = df.copy()
d = d.loc[d['dataset_type'] == 'test']
d.head(2)

In [None]:
scorer_bleu = BleuScorer(4)
scorer_rouge = Rouge()
scorer_cider = CiderScorer(4)

all_rouge_scores = []

for index, row in d.iterrows():
    gen = str(row['generated'])
    gt = str(row['ground_truth'])
    
    scorer_bleu += (gen, [gt])
    scorer_cider += (gen, [gt])
    all_rouge_scores.append(scorer_rouge.calc_score([gen], [gt]))
    
bleus, all_bleu_scores = scorer_bleu.compute_score()
cider, all_cider_scores = scorer_cider.compute_score()
len(all_bleu_scores), len(all_cider_scores), len(all_rouge_scores)

In [None]:
all_bleu_scores = np.array(all_bleu_scores)
all_bleu_scores.shape

In [None]:
all_bleu_scores = all_bleu_scores.mean(axis=0)
len(all_bleu_scores)

In [None]:
d['bleu'] = all_bleu_scores
d['rouge'] = all_rouge_scores
d['cider'] = all_cider_scores
d.head(2)

In [None]:
cols = ['ground_truth', 'generated', 'bleu', 'rouge', 'cider']
d2 = d.sort_values(['bleu', 'rouge', 'cider'])[cols]
d2

In [None]:
list(d2['ground_truth'])

In [None]:
d2.head(40)

In [None]:
d2.loc[d2['ground_truth'].str.contains('no acute intrathoracic process')]

In [None]:
list(d2.loc[240745])

In [None]:
list(d2['ground_truth'])

In [None]:
d2.head(60)

In [None]:
list(d2.loc[242324])

## Use chexpert

In [None]:
%run ../metrics/report_generation/chexpert.py

In [None]:
# rr = ['no pneumonia , vascular congestion , or pleural effusion']
# rr = ['no acute cardiopulmonary process']
rr = ["""in comparison with the study of xxxx ,
        the monitoring and support devices are unchanged .
        continued enlargement of the cardiac silhouette with
        pulmonary vascular congestion and bilateral pleural effusions
        with compressive atelectasis at the bases"""]
labels = apply_labeler_to_column(rr)
labels

In [None]:
list(zip(CHEXPERT_DISEASES, labels[0]))

# Sample real reports

## Raw reports

Findings + impression

In [None]:
import json

In [None]:
%run ../datasets/preprocess/iu_xray.py

In [None]:
# reports = load_raw_reports()
with open(os.path.join(REPORTS_DIR, 'reports.clean.v4.json')) as f:
    reports = json.load(f)
len(reports)

In [None]:
reports['1.xml']

In [None]:
def has_what_is_needed(report):
    """Tell whether a report has usable text and both image views.

    A report qualifies when all of the following hold:
      * ``findings`` and ``impression`` are both not None,
      * neither of them contains ``'xxxx'`` (case-insensitive),
      * at least one non-broken image whose ``side`` contains ``'frontal'``,
      * at least one non-broken image whose ``side`` does not.
    """
    if report['findings'] is None or report['impression'] is None:
        return False

    sections = (report['findings'], report['impression'])
    if any('xxxx' in section.lower() for section in sections):
        return False

    images = report['images']

    def _usable(image, want_frontal):
        # The side is checked first; 'broken' is only read when the side matches.
        return ('frontal' in image['side']) == want_frontal and not image['broken']

    has_frontal = any(_usable(image, True) for image in images)
    has_lateral = any(_usable(image, False) for image in images)
    return has_frontal and has_lateral

reports = {
    k: report
    for k, report in reports.items()
    if has_what_is_needed(report)
}
len(reports)

In [None]:
studies = list(reports.keys())
len(studies)

In [None]:
def plot_sample(study):
    """Show every image of a study side by side and print its report text.

    ``study`` is a key of the module-level ``reports`` dict. Each image is
    loaded from ``DATASET_DIR/images/<id>.png`` in mode ``'L'`` and drawn in
    its own subplot; a warning is printed for images flagged as broken (they
    are still loaded). Finally the ``indication``, ``findings`` and
    ``impression`` fields are printed (``None`` when a field is absent).
    """
    report_meta = reports[study]
    image_metas = report_meta['images']

    # One row, one column per image.
    n_rows, n_cols = 1, len(image_metas)
    print(study)

    plt.figure(figsize=(n_cols * 5, n_rows * 5))
    for position, image_meta in enumerate(image_metas, start=1):
        image_id = image_meta['id']
        if image_meta['broken']:
            print(f'WARNING: {image_id} is broken')
        image_pos = image_meta['side']
        title = f'{image_id} ({image_pos})'

        print(title)

        image = load_image(
            os.path.join(DATASET_DIR, 'images', f'{image_id}.png'),
            'L',
        )
        plt.subplot(n_rows, n_cols, position)
        plt.title(title)
        plt.imshow(image, cmap='gray')

    for key in ('indication', 'findings', 'impression'):
        value = report_meta.get(key)
        print(f'{key}: {value}')

In [None]:
plot_sample('10.xml')

In [None]:
plot_sample('922.xml')

In [None]:
d = d.loc[d['filename'].isin(list(reports))]
print(len(d))
d.head(2)

In [None]:
l = list(d['filename'])

In [None]:
CANDIDATES = ['3959.xml', '2532.xml', '1057.xml']

In [None]:
ll = list([l[2], l[5], l[12], l[14], l[40]])
ll

In [None]:
plot_sample('3095.xml')

## Clean reports

In [None]:
%run ../datasets/iu_xray.py

In [None]:
fpath = os.path.join(DATASET_DIR, 'reports', 'reports_with_chexpert_labels.csv')
chexpert_df = pd.read_csv(fpath)
chexpert_df.head(2)

In [None]:
d = chexpert_df
# d = d.loc[((d['Pneumothorax'] == 1) & (d['Cardiomegaly'] == 1))]
d = d.loc[d['Consolidation'] == 1]
d = d.sort_values('Reports', key=lambda x: x.str.len(), ascending=True)
print(len(d))
d.head(2)

In [None]:
list(d['Reports'])

# Load rg-templates model

## Load CNN and rg-templates

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
rg_run_id = RunId('0612_012741', debug=False, task='rg')
cnn_name = re.match(r'.*cnn-(\d{4}-\d{6})', rg_run_id.name).group(1).replace('-', '_')
run_id = RunId(cnn_name, debug=False, task='cls')
run_id

In [None]:
compiled_model = load_compiled_model(run_id)
compiled_model.model.eval()
compiled_model.metadata['model_kwargs']

## Load outputs

In [None]:
%run ../metrics/__init__.py
%run ../metrics/report_generation/chexpert.py

In [None]:
results_folder = get_results_folder(rg_run_id)
outputs_path = os.path.join(results_folder, f'outputs-labeled-free.csv')
df = pd.read_csv(outputs_path)
df.head(2)

In [None]:
target = 'Cardiomegaly'

In [None]:
others = labels_with_suffix('gen') + labels_with_suffix('gt')
others.remove(f'{target}-gt')
others.remove(f'{target}-gen')

In [None]:
d = df
d = d.loc[((d[f'{target}-gen'] == 1) & (d[f'{target}-gt'] == 1) & ((d[others] == 0).all(axis=1)))]
print(len(d))
d.head(2)

In [None]:
d

## Load image and Grad-CAM

In [None]:
%run ../datasets/iu_xray.py
%run -n ../eval_rg_template.py
%run ../utils/__init__.py

In [None]:
transform = get_default_image_transform(
    (256, 256),
    norm_by_sample=False,
    mean=_DATASET_MEAN,
    std=_DATASET_STD,
)

In [None]:
image_name = 'CXR3993_IM-2044-1001' # Great example

In [None]:
image_fpath = os.path.join(DATASET_DIR, 'images', f'{image_name}.png')

In [None]:
image = load_image(image_fpath, 'RGB')
image = transform(image)
image.size()

In [None]:
plotable_image = tensor_to_range01(image).permute(1, 2, 0).detach().cpu().numpy()
plotable_image.shape

In [None]:
images = image.unsqueeze(0).cuda()
images.size()

In [None]:
out, embedding = compiled_model.model(images)
out = torch.sigmoid(out)
out.size(), embedding.size()

In [None]:
out

In [None]:
thresh = _get_threshold(run_id, 'pr', compiled_model.model.labels)
thresh

In [None]:
(out >= thresh).type(torch.uint8)

### Grad-CAM

In [None]:
%run ../training/classification/grad_cam.py

In [None]:
grad_cam = create_grad_cam(compiled_model.model)

In [None]:
attributions = calculate_attributions(grad_cam, images, 1, resize=False)
attributions.size()

In [None]:
image_size = images.size()[-2:]
attributions = interpolate(attributions, image_size, mode='bilinear', align_corners=False)
attributions.size()

In [None]:
heatmap = attributions.squeeze().detach().cpu().numpy()
heatmap.shape

In [None]:
n_rows = 1
n_cols = 2
plt.subplot(n_rows, n_cols, 1)
plt.imshow(plotable_image)

plt.subplot(n_rows, n_cols, 2)
plt.imshow(heatmap)
plt.colorbar()

In [None]:
from captum.attr import visualization

In [None]:
figure, axis = visualization.visualize_image_attr(
    np.expand_dims(heatmap, 2),
    plotable_image,
    method='blended_heat_map',
    # method='original_image',
    cmap='jet',
)

In [None]:
figure.savefig('/home/pdpino/downloads/iu-out-example-grad-cam.png', bbox_inches='tight')

In [None]:
type(figure)

In [None]:
row = d.loc[d['image_fname'] == image_name]
gt = str(row['ground_truth'].item())
gen = str(row['generated'].item())
gt, gen

In [None]:
compiled_model.model.labels

# Check data amounts

In [None]:
%run ../datasets/__init__.py

In [None]:
dataset_kwargs = {
    # 'dataset_name': 'mimic-cxr',
    'dataset_name': 'iu-x-ray',
    'dataset_type': 'test',
    'max_samples': None,
    'frontal_only': True,
    'reports_version': 'v4-1',
    'image_size': (256, 256),
}
dataloader = prepare_data_report_generation(**dataset_kwargs)
dataset = dataloader.dataset
len(dataset)

In [None]:
mimic_train = 237964
mimic_val = 1959
mimic_test = 3403
mimic_train + mimic_val + mimic_test

In [None]:
iu_train = 2638
iu_val = 336
iu_test = 337
iu_total = iu_train + iu_val + iu_test
iu_total

In [None]:
iu_train / iu_total, iu_val / iu_total, iu_test / iu_total