In [1]:
from gloria.datasets.visualization_utils import *
import pandas as pd
import os
import skimage
from torch import nn


def get_attn_overlay(attn, image_shape):
    new_attn = torch.tensor(attn)
#     new_attn = nn.Upsample(size=image_shape)(new_attn.reshape(1, 1, *new_attn.shape))[0, 0]
    new_attn = new_attn.unsqueeze(-1).expand(*new_attn.shape, 3)
    new_attn = skimage.transform.pyramid_expand(
        new_attn, sigma=20, upscale=image_shape[0] // new_attn.shape[0], multichannel=True)
    new_attn = torch.tensor(new_attn[:, :, 0])
    new_attn = nn.Upsample(size=image_shape)(new_attn.reshape(1, 1, *new_attn.shape))[0, 0]
    return new_attn

In [6]:
paths = {
    'original_test': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_19_58_28/test_outputs_0/',
    'original_test_shufflebboxes': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_21_00_23/test_outputs_0/',
    'original_test_randsent': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_20_00_44/test_outputs_0/',
    'original_test_randbboxes': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_20_21_48/test_outputs_0/',
    'original': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_11_05_50/val_outputs_0/',
    'original_shufflebboxes': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_16_00_31/val_outputs_0/',
    'original_randsent': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_10_51_31/val_outputs_0/',
    'original_randbboxes': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_11_22_47/val_outputs_0/',
    'retrained': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_19_50_12/val_outputs_0/',
    'retrained_shufflebboxes': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_07_00_24_55/val_outputs_0/',
    'retrained_randsent': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_22_49_28/val_outputs_0/',
    'retrained_randbboxes': '/scratch/mcinerney.de/gloria_outputs3/output/gloria_pretrain_1.0/2021_12_06_23_02_34/val_outputs_0/',
#     'retrained_noattn_loss': '/scratch/mcinerney.de/gloria_outputs/output3/gloria_pretrain_1.0//val_outputs_0/'
}


In [7]:
dfs = {k: pd.read_csv(os.path.join(path, 'sentences.csv')) for k, path in paths.items()}

In [8]:
one_lung_selector = OrSelector(
             RowBBoxSelector(contains={'left lung'}, does_not_contain={'right lung'}),
             RowBBoxSelector(contains={'right lung'}, does_not_contain={'left lung'})
         )
abnormal_selector = RowLabelAndContextSelector(contains={('abnormal', 'yes')})
selectors = {'all': None, 'one_lung': one_lung_selector, 'abnormal': abnormal_selector}
example_df = dfs['original']
selector_counts = {name: len(example_df[example_df.apply(selector, axis=1)])
                   if selector is not None else len(example_df)
                   for name, selector in selectors.items()}
example_df = dfs['original_test']
selector_counts_test = {name: len(example_df[example_df.apply(selector, axis=1)])
                   if selector is not None else len(example_df)
                   for name, selector in selectors.items()}
score_rows = []
for k, df in dfs.items():
    score_rows.append({'model': k})
    for name, selector in selectors.items():
        rows = df[df.apply(selector, axis=1)] if selector is not None else df
        score_rows[-1]['%s_auroc' % name] = rows.auroc.mean()
        score_rows[-1]['%s_avg_precision' % name] = rows.avg_precision.mean()
    #     info = path_and_rows_to_info(path, rows=rows)
    #     plot_info(get_attn_overlay, info, path=path)
#     plt.close()
scores_df = pd.DataFrame(score_rows)

In [9]:
latex_table = ''
for i, row in scores_df.iterrows():
    formatted_row_scores = [
        row.all_auroc,
        row.all_avg_precision,
        row.one_lung_auroc,
        row.abnormal_auroc,
        row.abnormal_avg_precision
    ]
    formatted_row_scores = ' & '.join(['%f.2' % s for s in formatted_row_scores])
    latex_table += formatted_row_scores + ' \\\\\n'
print(latex_table)
print('val counts', selector_counts)
print('test counts', selector_counts_test)
scores_df

0.690680.2 & 0.516796.2 & 0.654764.2 & 0.695122.2 & 0.482862.2 \\
0.687235.2 & 0.516795.2 & 0.654054.2 & 0.686493.2 & 0.477410.2 \\
0.681315.2 & 0.509981.2 & 0.649710.2 & 0.674153.2 & 0.467818.2 \\
0.631343.2 & 0.452842.2 & 0.599361.2 & 0.619506.2 & 0.411582.2 \\
0.643243.2 & 0.437247.2 & 0.617523.2 & 0.644002.2 & 0.423967.2 \\
0.639547.2 & 0.436247.2 & 0.613919.2 & 0.637499.2 & 0.419597.2 \\
0.633378.2 & 0.430593.2 & 0.605384.2 & 0.627724.2 & 0.410859.2 \\
0.608374.2 & 0.417647.2 & 0.574359.2 & 0.601929.2 & 0.398219.2 \\
0.570353.2 & 0.396059.2 & 0.562269.2 & 0.566088.2 & 0.376516.2 \\
0.562796.2 & 0.388834.2 & 0.548002.2 & 0.556074.2 & 0.364952.2 \\
0.552478.2 & 0.378293.2 & 0.536577.2 & 0.543405.2 & 0.351110.2 \\
0.553149.2 & 0.386717.2 & 0.539586.2 & 0.548057.2 & 0.367481.2 \\

val counts {'all': 6288, 'one_lung': 1112, 'abnormal': 2611}
test counts {'all': 2496, 'one_lung': 285, 'abnormal': 748}


Unnamed: 0,model,all_auroc,all_avg_precision,one_lung_auroc,one_lung_avg_precision,abnormal_auroc,abnormal_avg_precision
0,original_test,0.69068,0.516796,0.654764,0.3867,0.695122,0.482862
1,original_test_shufflebboxes,0.687235,0.516795,0.654054,0.39007,0.686493,0.47741
2,original_test_randsent,0.681315,0.509981,0.64971,0.378776,0.674153,0.467818
3,original_test_randbboxes,0.631343,0.452842,0.599361,0.351233,0.619506,0.411582
4,original,0.643243,0.437247,0.617523,0.319102,0.644002,0.423967
5,original_shufflebboxes,0.639547,0.436247,0.613919,0.315988,0.637499,0.419597
6,original_randsent,0.633378,0.430593,0.605384,0.308945,0.627724,0.410859
7,original_randbboxes,0.608374,0.417647,0.574359,0.269468,0.601929,0.398219
8,retrained,0.570353,0.396059,0.562269,0.292682,0.566088,0.376516
9,retrained_shufflebboxes,0.562796,0.388834,0.548002,0.277389,0.556074,0.364952


In [20]:
matching_scores = []
for base in ['original', 'original_test', 'retrained']:
    matching_scores.append({'model': base})
    combined = pd.merge(dfs[base], dfs['%s_randsent' % base].rename(
        columns=lambda x: 'neg_' + x), left_on='dicom_sent_id', right_on='neg_dicom_sent_id')
    for n, selector in selectors.items():
        rows = combined[combined.apply(selector, axis=1)] if selector is not None else combined
        matching_scores[-1]['%s_local_acc' % n] = (rows.local_sims > rows.neg_local_sims).mean()
        matching_scores[-1]['%s_global_acc' % n] = (rows.global_sims > rows.neg_global_sims).mean()
matching_scores = pd.DataFrame(matching_scores)
matching_scores

Unnamed: 0,model,all_local_acc,all_global_acc,one_lung_local_acc,one_lung_global_acc,abnormal_local_acc,abnormal_global_acc
0,original,0.540712,0.737754,0.381295,0.768885,0.440444,0.789353
1,original_test,0.544071,0.71875,0.414035,0.750877,0.421123,0.779412
2,retrained,0.797233,0.877704,0.835432,0.903777,0.816928,0.893527


In [31]:
correlation_rows = []
for base in ['original', 'original_test', 'retrained']:
    correlation_rows.append({'model': base})
    local_bce = np.log((pd.concat([dfs[base].local_sims / 5, 1 - (dfs['%s_randsent' % base].local_sims / 5)]) + 1) / 2)
    global_bce = np.log((pd.concat([dfs[base].global_sims, 1 - dfs['%s_randsent' % base].global_sims]) + 1) / 2)
#     local_bce = (pd.concat([dfs[base].local_sims / 5, 1 - (dfs['%s_randsent' % base].local_sims / 5)]) + 1) / 2
#     global_bce = (pd.concat([dfs[base].global_sims, 1 - dfs['%s_randsent' % base].global_sims]) + 1) / 2
    avg_precision = pd.concat([dfs[base].avg_precision, dfs['%s_randsent' % base].avg_precision])
    auroc = pd.concat([dfs[base].auroc, dfs['%s_randsent' % base].auroc])
    attn_entropy = pd.concat([dfs[base].attn_entropy, dfs['%s_randsent' % base].attn_entropy])
    for name, bce in [('local', local_bce), ('global', global_bce)]:
        correlation_rows[-1]['%s_matching_bce_and_attn_entropy' % name] = bce.corr(attn_entropy)
        correlation_rows[-1]['%s_matching_bce_and_avg_precision' % name] = bce.corr(avg_precision)
        correlation_rows[-1]['%s_matching_bce_and_auroc' % name] = bce.corr(auroc)
    correlation_rows[-1]['attn_entropy_and_auroc'] = attn_entropy.corr(auroc)
    correlation_rows[-1]['attn_entropy_and_avg_precision'] = attn_entropy.corr(avg_precision)
    correlation_rows[-1]['auroc_and_avg_precision'] = auroc.corr(avg_precision)
correlation_df = pd.DataFrame(correlation_rows)
correlation_df


Unnamed: 0,model,local_matching_bce_and_attn_entropy,local_matching_bce_and_avg_precision,local_matching_bce_and_auroc,global_matching_bce_and_attn_entropy,global_matching_bce_and_avg_precision,global_matching_bce_and_auroc,attn_entropy_and_auroc,attn_entropy_and_avg_precision,auroc_and_avg_precision
0,original,0.093094,0.030391,0.061908,-0.026888,-0.041506,-0.059104,0.093125,0.046038,0.386823
1,original_test,0.080232,0.026347,0.053078,-0.010838,-0.046573,-0.044436,0.091763,0.028466,0.620504
2,retrained,-0.025474,-0.056814,-0.058437,0.005352,-0.04769,-0.082232,0.059014,0.042902,0.004479


In [None]:
import os
from jupyter_innotater import *
import pandas as pd


def visualize(model_paths, selector=None):
    min_list_of_files = None
    for model_path in model_paths.values():
        fs = set(os.listdir(os.path.join(model_path, 'sentence_figures')))
        df = pd.read_csv(os.path.join(model_path, 'sentences.csv'))
        if selector is not None:
            df = df[df.apply(selector, axis=1)]
        fs = fs.intersection(set([dicom_sent_id + '.jpg' for dicom_sent_id in set(df.dicom_sent_id)]))
        if min_list_of_files is None:
            min_list_of_files = fs
        else:
            min_list_of_files = min_list_of_files.intersection(fs)
    min_list_of_files = sorted(list(min_list_of_files))
    innotations = [
        innotation
        for model_name, model_path in model_paths.items()
        for innotation in [
            TextInnotation([model_name] * len(min_list_of_files)),
            TextInnotation(min_list_of_files),
            ImageInnotation(min_list_of_files, path=os.path.join(model_path, 'sentence_figures'))
        ]
    ]
    return Innotater(
        innotations,
        []
    )

In [None]:
visualize(paths, selector=selector)

In [None]:
from omegaconf import OmegaConf
import gloria
from gloria.datasets.mimic_for_gloria import GloriaCollateFn
from gloria.lightning.callbacks import EvaluateLocalization
from gloria.lightning.pretrain_model import PretrainModel

def get_instance(dicom_id, sent_id, dataset, replace_sent_with=None):
    rows = dataset.df[dataset.df.dicom_id == dicom_id]
    instance = dataset.get_item_from_rows(rows)
    instance = dataset.add_objects(instance, sent_id=sent_id)
    if replace_sent_with is not None:
        patient_id = next(iter(instance.keys()))
        study_id = next(iter(instance[patient_id].keys()))
        instance[patient_id][study_id]['sentence'] = replace_sent_with
        instance[patient_id][study_id]['objects'][dicom_id]['sent_to_bboxes'][sent_id]['sentence'] = replace_sent_with
    return instance

def display(instance):
    patient_id = next(iter(instance.keys()))
    study_id = next(iter(instance[patient_id].keys()))
    print(instance[patient_id][study_id]['sentence'])
    img = instance[patient_id][study_id]['images'][dicom_id]
    plt.imshow(to_rgb(img))
    

In [None]:
cfg = OmegaConf.load('configs/chexpert_pretrain_imagenome_val_config.yaml')
dm = gloria.builder.build_data_module(cfg)
collate_fn = GloriaCollateFn(cfg, "test")
el = EvaluateLocalization(collate_fn)
module = PretrainModel.load_from_checkpoint('./pretrained/chexpert_resnet50.ckpt')
module.eval()
valid = dm.dm.get_dataset('valid')

In [None]:
dicom_id = '2533ab2a-565051cb-35201672-0267a457-d931d20f'
sent_id = '55609974|8'
instance = get_instance(dicom_id, sent_id, valid)
display(instance)
el.evaluate_and_save(
    path='experiments2', instances=[instance], pl_module=module,
    save_full_data=True, plot=True, plot_attn_overlay_mode='pyramid')

In [None]:
neg_instance = get_instance(dicom_id, sent_id, valid, replace_sent_with='Donkeys like to fly kites.')
display(neg_instance)
el.evaluate_and_save(
    path='experiments4', instances=[neg_instance], pl_module=module,
    save_full_data=True, plot=True, plot_attn_overlay_mode='pyramid')

In [None]:
visualize({'regular': 'experiments', 'reversed': 'experiments_neg'})