In [28]:
import ast
import torch
import gloria
import pandas as pd
import matplotlib.pyplot as plt
from gloria import builder, utils
from PIL import Image
import numpy as np
from gloria.datasets.mimic_data import ImaGenomeDataModule, MimicCxrFiler, ImaGenomeFiler
from gloria.datasets.mimic_for_gloria import GloriaCollateFn, normalize, original_tensor_to_numpy_image
import os
import cv2
from torch import nn
from jupyter_innotater import *
import pickle as pkl
from tqdm import tqdm
from torchmetrics import AUROC


def bbox_to_mask(bbox, image_shape):
    """Rasterize an inclusive ``[x0, y0, x1, y1]`` box into a boolean mask.

    Rows ``y0..y1`` and columns ``x0..x1`` (both ends inclusive) are True in a
    ``torch.bool`` tensor of shape ``image_shape``; everything else is False.
    Coordinates follow Python slice semantics, so the ``[-1, -1, -1, -1]``
    sentinel yields an all-False mask.
    """
    x_start, y_start = bbox[0], bbox[1]
    x_end, y_end = bbox[2], bbox[3]
    mask = torch.zeros(image_shape, dtype=torch.bool)
    # +1 because the box's far edge is inclusive.
    mask[y_start:y_end + 1, x_start:x_end + 1] = True
    return mask


def mask_to_bbox(box_mask):
    """Return the tight inclusive ``[x0, y0, x1, y1]`` box around True pixels.

    ``box_mask`` is a 2-D boolean tensor indexed ``[row, col]``. Returns the
    sentinel ``[-1, -1, -1, -1]`` when no pixel is set.
    """
    if box_mask.sum() == 0:
        return [-1, -1, -1, -1]
    row_idx, col_idx = box_mask.nonzero(as_tuple=True)
    x_min, x_max = col_idx.min().item(), col_idx.max().item()
    y_min, y_max = row_idx.min().item(), row_idx.max().item()
    return [x_min, y_min, x_max, y_max]


def process_bboxes(model, image_shape, bboxes):
    """Map boxes from original-image coordinates into the model's input frame.

    Each box is rasterized to a mask of ``image_shape`` and pushed through the
    test-mode ``GloriaCollateFn.process_img`` preprocessing. It is then
    re-boxed from the processed mask's positive pixels. Boxes that vanish after
    preprocessing come back as ``[-1, -1, -1, -1]`` (see ``mask_to_bbox``).
    """
    collate = GloriaCollateFn(model.cfg, 'test')

    def _transform(bbox):
        # Rasterize, then convert to the numpy image format process_img expects.
        mask_image = original_tensor_to_numpy_image(bbox_to_mask(bbox, image_shape))
        processed = collate.process_img([mask_image], 'cpu')
        return mask_to_bbox(processed[0, 0] > 0)

    return [_transform(bbox) for bbox in bboxes]


def get_batch(cfg, texts, imgs, device):
    """Collate raw image tensors and their texts into a test-mode batch.

    Images are converted with ``original_tensor_to_numpy_image`` and passed,
    together with ``texts``, to ``GloriaCollateFn.get_batch``. The collate
    function is built for ``device``.
    """
    collate = GloriaCollateFn(cfg, 'test', device=device)
    numpy_imgs = list(map(original_tensor_to_numpy_image, imgs))
    return collate.get_batch(numpy_imgs, texts)


def plot_attn_maps(attn_maps, imgs, sents, epoch_idx=0, batch_idx=0, nvis=1):
    """Compose attention maps over their images into one PIL image.

    Delegates to ``utils.build_attention_images``; ``nvis`` is forwarded
    (presumably the number of samples to visualize — confirm). ``epoch_idx``
    and ``batch_idx`` are unused and kept only for call compatibility.

    Returns:
        ``PIL.Image`` built from the composed array, or ``None`` when
        ``build_attention_images`` returned no image.
    """
    img_grid, _unused = utils.build_attention_images(
        imgs, attn_maps, nvis=nvis, sentences=sents,
    )
    return None if img_grid is None else Image.fromarray(img_grid)


def plot_attention_from_raw(images, reports, model, filename='attention.jpg', device='cuda'):
    """Run ``model`` on raw images/reports and save the attention visualization.

    Each report is truncated to start at its ``FINDINGS:`` section when one is
    present. The attention image for every sample in the batch is written to
    ``filename``.

    Args:
        images: image tensors accepted by ``get_batch``.
        reports: report strings, one per image.
        model: object exposing ``cfg``, ``__call__(batch)`` and ``get_attn_maps``.
        filename: output path for the rendered attention image.
        device: device for the batch (previously hard-coded to 'cuda').

    Returns:
        The saved ``PIL.Image``.

    Raises:
        ValueError: if no attention image could be built.
    """
    reports = [report[report.index('FINDINGS:'):] if 'FINDINGS:' in report else report for report in reports]
    batch = get_batch(model.cfg, reports, images, device)
    # Visualization only: skip building an autograd graph.
    with torch.no_grad():
        img_emb_l, img_emb_g, text_emb_l, text_emb_g, sents = model(batch)
        attn_maps = model.get_attn_maps(img_emb_l, text_emb_l, sents)
    im = plot_attn_maps(attn_maps, batch['imgs'].cpu(), sents, nvis=len(images))
    if im is None:
        raise ValueError('no attention image was built; nothing to save to %r' % filename)
    im.save(filename)
    return im


def draw_bounding_boxes(image, bboxes, color=(255, 0, 0), thickness=None):
    """Draw rectangle outlines for ``bboxes`` onto ``image`` with OpenCV.

    Args:
        image: numpy image array. ``cv2.rectangle`` draws into it and it is
            also returned.
        bboxes: iterable of inclusive ``[x0, y0, x1, y1]`` boxes.
            - Coordinates are converted to Python ints, as cv2 requires.
            - Boxes whose far corner is negative lie entirely outside the image
              and are skipped. This covers the ``[-1, -1, -1, -1]`` sentinel
              that ``mask_to_bbox`` returns for an empty mask, which would
              otherwise be drawn as a corner artifact.
        color: colour passed to ``cv2.rectangle``.
        thickness: line thickness in pixels. Defaults to 1% of the image
            height, with a minimum of 1.

    Returns:
        The image with the boxes drawn.
    """
    if thickness is None:
        thickness = max(1, image.shape[0] // 100)
    for bbox in bboxes:
        x0, y0, x1, y1 = (int(round(c)) for c in bbox[:4])
        if x1 < 0 or y1 < 0:
            continue
        image = cv2.rectangle(image, (x0, y0), (x1, y1), color, thickness)
    return image


def get_bounding_boxes_mask(image_shape, bboxes):
    """Boolean mask of the *outlines* of ``bboxes`` on a canvas of ``image_shape``.

    Each box is drawn with value 1 on a zero float canvas via
    ``draw_bounding_boxes``. The result is True wherever the canvas equals 1.
    Interiors are not filled.
    """
    canvas = draw_bounding_boxes(np.zeros(image_shape), bboxes, color=1)
    return canvas == 1


def show_attention_from_raw(batch, model):
    """Return the first sample's attention map reduced over its leading axis.

    Runs ``model`` on ``batch`` without autograd and asks it for attention
    maps. ``attn_maps[0][0]`` is then summed over axis 0 (presumably the word
    axis — TODO confirm). Callers in this notebook reshape the result as a 2-D
    ``(h, w)`` map.

    Returns:
        ``numpy.ndarray`` on the CPU.
    """
    # Pure inference, called once per sentence: don't build autograd graphs.
    with torch.no_grad():
        img_emb_l, img_emb_g, text_emb_l, text_emb_g, sents = model(batch)
        attn_maps = model.get_attn_maps(img_emb_l, text_emb_l, sents)
    return attn_maps[0][0].sum(0).cpu().detach().numpy()


def to_rgb(image):
    """Convert a single-channel tensor into an 8-bit, 3-channel numpy image.

    The tensor is rescaled with ``normalize`` (assumed to map into [0, 1] —
    TODO confirm) and multiplied by 255. It is then truncated to ints and
    replicated across a new trailing channel axis of size 3.
    """
    scaled = (normalize(image) * 255).int()
    rgb = scaled.unsqueeze(-1).expand(*image.shape, 3)
    return np.array(rgb.cpu(), dtype=np.uint8)


def process_instance(instance, model, plot=True):
    """Run the model on one study and collect per-sentence boxes and attention.

    ``instance`` is a nested dict ``{patient_id: {study_id: {...}}}``. Only the
    first patient, study and dicom image are used. Every sentence id in
    ``instance['objects'][dicom_id]['sent_to_bboxes']`` is processed in sorted
    order, and for each one this:
      1. draws the sentence's ``coords_original`` boxes on the raw image;
      2. builds a batch for that sentence alone and maps the boxes into the
         preprocessed image frame (``process_bboxes``), then draws them on
         the preprocessed image;
      3. computes the attention map, bilinearly upsamples it to the
         preprocessed image size, and draws the mapped boxes on it.
    With ``plot=True`` the sentence metadata is printed and the three panels
    are shown.

    Returns:
        dict of parallel per-sentence lists:
        - ``bbox_names``
        - ``new_bboxes``: boxes in the preprocessed frame.
        - ``attentions``: upsampled 2-D tensors.
        - ``images``: ``[raw_with_boxes, processed_with_boxes,
          attention_with_boxes]`` RGB numpy arrays.
        - ``sents``, ``sent_ids``, ``labels``, ``contexts``.
    """
    # Only the first patient / study / dicom of the instance is processed.
    patient_id = next(iter(instance.keys()))
    study_id = next(iter(instance[patient_id].keys()))
    instance = instance[patient_id][study_id]
    dicom_id = next(iter(instance['images'].keys()))
    image = instance['images'][dicom_id]
    # String sort: e.g. '59382057|11' comes before '59382057|6'.
    sent_ids = sorted(list(instance['objects'][dicom_id]['sent_to_bboxes'].keys()))
    sents, bbox_names, new_bboxes, attentions, images, labels, contexts = [], [], [], [], [], [], []
    for sent_id in sent_ids:
        sent_info = instance['objects'][dicom_id]['sent_to_bboxes'][sent_id]
        sents.append(sent_info['sentence'])
        bbox_names.append(sent_info['bboxes'])
        sent_bboxes = sent_info['coords_original']
        labels.append(sent_info['labels'])
        contexts.append(sent_info['contexts'])
        sent_images = []
        if plot:
            print('sentence:', sents[-1])
            print('bbox names:', bbox_names[-1])
            print('labels:', labels[-1])
            print('context:', contexts[-1])
            fig, axes = plt.subplots(1, 3)
        # Panel 1: original boxes on the raw image.
        image1 = draw_bounding_boxes(to_rgb(image), sent_bboxes)
        sent_images.append(image1)
        if plot:
            axes[0].imshow(image1)
        # NOTE(review): device is hard-coded to 'cuda'; this fails on CPU-only
        # machines even though the model-loading cell falls back to CPU.
        batch = get_batch(model.cfg, [sents[-1]], [image], 'cuda')
        new_sent_bboxes = process_bboxes(model, image.shape, sent_bboxes)
        new_bboxes.append(new_sent_bboxes)
        # Panel 2: mapped boxes on the preprocessed model input image.
        image2 = batch['imgs'][0, 0]
        image2 = draw_bounding_boxes(to_rgb(image2), new_sent_bboxes)
        sent_images.append(image2)
        if plot:
            axes[1].imshow(image2)
        # Panel 3: attention upsampled to the preprocessed image size. image2 is
        # now an (H, W, 3) RGB array, so shape[:2] is (H, W). The attention map
        # gets leading batch/channel dims for the 4-D bilinear Upsample.
        attn = torch.tensor(show_attention_from_raw(batch, model))
        attn = attn.reshape(1, 1, *attn.shape)
        new_attn = nn.Upsample(size=image2.shape[:2], mode="bilinear")(attn)
        # The raw upsampled map is stored; new_attn is then rebound to the RGB
        # visualization with boxes drawn.
        attentions.append(new_attn[0, 0])
        new_attn = draw_bounding_boxes(to_rgb(new_attn[0, 0]), new_sent_bboxes)
        sent_images.append(new_attn)
        if plot:
            axes[2].imshow(new_attn)
            plt.show()
        images.append(sent_images)
    return dict(
        bbox_names=bbox_names,
        new_bboxes=new_bboxes,
        attentions=attentions,
        images=images,
        sents=sents,
        sent_ids=sent_ids,
        labels=labels,
        contexts=contexts,
    )


def get_and_save_instance_results(path, dataset, model, num_examples=None):
    """Compute (or load cached) per-sentence results for the first examples.

    If ``<path>/sentences.csv`` exists, it is loaded and returned, and
    ``dataset``, ``model`` and ``num_examples`` are ignored. Delete the
    directory to recompute with different settings.

    Otherwise the first ``num_examples`` items of ``dataset`` (all of them when
    ``None``) are run through ``process_instance``. For every sentence this
    writes:

    * ``bbox_images{0,1,2}/<dicom_sent_id>.jpg``: boxes drawn on the raw
      image, on the preprocessed image, and on the upsampled attention map;
    * ``attentions/<dicom_sent_id>.npy``: the upsampled attention map;
    * ``sentences.csv``: one metadata row per sentence. List-valued columns
      are stored as their ``str()`` repr.

    Returns:
        DataFrame with one row per sentence. The cached and freshly computed
        paths return the same columns and a 0..n-1 integer index.
    """
    columns = [
        'patient_id',
        'study_id',
        'dicom_id',
        'sent_id',
        'dicom_sent_id',
        'sentence',
        'bbox_names',
        'bboxes',
        'sent_labels',
        'sent_contexts',
    ]
    csv_path = os.path.join(path, 'sentences.csv')
    if os.path.exists(csv_path):
        # The CSV is written together with its index; restore it as the index
        # rather than surfacing an extra 'Unnamed: 0' column.
        return pd.read_csv(csv_path, index_col=0)

    if num_examples is None:
        # Default to the whole dataset (this previously read len(model)).
        num_examples = len(dataset)
    for subdir in ('bbox_images0', 'bbox_images1', 'bbox_images2', 'attentions'):
        os.makedirs(os.path.join(path, subdir), exist_ok=True)

    info = []
    for i in tqdm(range(num_examples), total=num_examples):
        instance = dataset[i]
        patient_id = next(iter(instance.keys()))
        study_id = next(iter(instance[patient_id].keys()))
        dicom_id = next(iter(instance[patient_id][study_id]['images'].keys()))
        outs = process_instance(instance, model, plot=False)
        for sent_id, sent, bbox_names, bboxes, sent_labels, sent_contexts, sent_images, attention in zip(
                outs['sent_ids'], outs['sents'], outs['bbox_names'], outs['new_bboxes'],
                outs['labels'], outs['contexts'], outs['images'], outs['attentions']):
            dicom_sent_id = 'dicom_%s_sent_%s' % (dicom_id, sent_id)
            info.append([
                patient_id,
                study_id,
                dicom_id,
                sent_id,
                dicom_sent_id,
                sent,
                str(bbox_names),
                str(bboxes),
                str(sent_labels),
                str(sent_contexts),
            ])
            # sent_images is [raw, preprocessed, attention] -> bbox_images0/1/2.
            for k in range(3):
                Image.fromarray(sent_images[k]).save(
                    os.path.join(path, 'bbox_images%d' % k, dicom_sent_id + '.jpg'))
            np.save(os.path.join(path, 'attentions', dicom_sent_id), attention)
    df = pd.DataFrame(info, columns=columns)
    df.to_csv(csv_path)
    return df


def get_ent_to_bbox(row):
    """Group a sentence row's bbox names by ``(label, context)`` entity.

    ``row.sent_labels``, ``row.sent_contexts`` and ``row.bbox_names`` are the
    ``str()``-serialized parallel lists stored in ``sentences.csv``. They are
    parsed with ``ast.literal_eval`` instead of ``eval``, so a malformed or
    tampered CSV cannot execute code.

    Returns:
        dict mapping ``(label, context)`` to the set of bbox names, with keys in
        first-seen order.
    """
    labels = ast.literal_eval(row.sent_labels)
    contexts = ast.literal_eval(row.sent_contexts)
    bbox_names = ast.literal_eval(row.bbox_names)
    ent_to_bbox = {}
    for label, context, bbox in zip(labels, contexts, bbox_names):
        ent_to_bbox.setdefault((label, context), set()).add(bbox)
    return ent_to_bbox


def annotate(path, dataset, model, num_examples=None, labels=None, selector=None):
    """Build a jupyter_innotater widget for grading per-sentence attention.

    Each row of the ``get_and_save_instance_results`` frame (possibly cached)
    is shown with:
    - its sentence;
    - a ``(label, context): {bbox names}`` summary;
    - the preprocessed image with boxes (``bbox_images1``);
    - the attention map with boxes (``bbox_images2``);
    - a 4-way class selector.

    Args:
        path, dataset, model, num_examples: forwarded to
            ``get_and_save_instance_results``.
        labels: initial class index per row. Defaults to all 0
            ('0 - Unselected'). Presumably its length must match the row
            count — confirm.
        selector: optional row predicate. When given, only the rows where it
            returns True are passed as ``indexes``.

    Returns:
        ``(innotater, labels)``. ``labels`` is the same list given to
        ``MultiClassInnotation`` (presumably edited in place by the widget —
        confirm before relying on it).
    """
    df = get_and_save_instance_results(path, dataset, model, num_examples=num_examples)
    if labels is None:
        labels = [0] * len(df)

    def _entity_summary(row):
        # One "('label', 'context'): {bbox names}" line per entity.
        return ''.join('%s: %s\n' % (ent, bbox_set) for ent, bbox_set in get_ent_to_bbox(row).items())

    sentences = df.sentence.tolist()
    entities = [_entity_summary(row) for _, row in df.iterrows()]
    files = [name + '.jpg' for name in df.dicom_sent_id.tolist()]
    classes = ['0 - Unselected', '1 - Positive', '2 - Ambiguous', '3 - Negative']
    indexes = None
    if selector is not None:
        indexes = df[df.apply(selector, axis=1)].index.tolist()
    inputs = [
        TextInnotation(sentences),
        TextInnotation(entities),
        ImageInnotation(files, path=os.path.join(path, 'bbox_images1'), width=200, height=200),
        ImageInnotation(files, path=os.path.join(path, 'bbox_images2'), width=200, height=200),
    ]
    targets = MultiClassInnotation(labels, classes=classes)
    return Innotater(inputs, targets, indexes=indexes), labels


def compute_auroc(sent_attention, sent_bboxes):
    """Pixel-level AUROC of an attention map against the union of ``sent_bboxes``.

    Pixels inside any box (``bbox_to_mask``) are positives and all others are
    negatives. The flattened attention values are the scores.

    Returns:
        The AUROC tensor, or ``None`` when the boxes cover no pixels or cover
        every pixel. AUROC is undefined unless both classes are present.
    """
    label_segmentation = torch.zeros_like(sent_attention, dtype=torch.bool)
    for bbox in sent_bboxes:
        label_segmentation = label_segmentation | bbox_to_mask(bbox, sent_attention.shape)
    num_positive = int(label_segmentation.sum())
    if not 0 < num_positive < label_segmentation.numel():
        return None
    # NOTE(review): newer torchmetrics releases require AUROC(task='binary');
    # the bare constructor matches the version this notebook was run with.
    return AUROC()(sent_attention.reshape(-1), label_segmentation.reshape(-1).long())


def compute_metrics(path, dataset, model, num_examples=None, selector=None, scores_path='gold_scores.pkl'):
    """Print and return the mean pixel AUROC of attention maps vs. gold boxes.

    Per-sentence results come from ``get_and_save_instance_results``, cached
    under ``path``. Rows can be filtered with ``selector``, a per-row boolean
    predicate. Sentences whose boxes give no valid AUROC are skipped. The
    per-sentence AUROCs are pickled to ``scores_path`` as ``{'aurocs': [...]}``.

    NOTE: ``scores_path`` defaults to 'gold_scores.pkl' in the current working
    directory. Consecutive calls for different ``path``s therefore overwrite
    each other's scores unless distinct ``scores_path`` values are passed.

    Returns:
        The mean AUROC, or ``None`` if no sentence produced one.
    """
    df = get_and_save_instance_results(path, dataset, model, num_examples=num_examples)
    if selector is not None:
        df = df[df.apply(selector, axis=1)]
    aurocs = []
    for _, row in df.iterrows():
        attention_file = os.path.join(path, 'attentions', row.dicom_sent_id + '.npy')
        sent_attention = torch.tensor(np.load(attention_file))
        # Boxes are stored as a repr'd list of lists; parse without eval.
        auroc = compute_auroc(sent_attention, ast.literal_eval(row.bboxes))
        if auroc is not None:
            aurocs.append(auroc)
    if aurocs:
        mean_auroc = sum(aurocs) / len(aurocs)
        print(mean_auroc)
    else:
        mean_auroc = None
        print('no sentence produced a valid AUROC')
    with open(scores_path, 'wb') as f:
        pkl.dump({'aurocs': aurocs}, f)
    return mean_auroc


In [2]:
# Local download root and PhysioNet account used for MIMIC-CXR and ImaGenome.
DATA_ROOT = '/scratch/mcinerney.de'
PHYSIO_USERNAME = 'dmcinerney'

mimic_cxr_filer = MimicCxrFiler(
    download_directory=os.path.join(DATA_ROOT, 'mimic-cxr'),
    physio_username=PHYSIO_USERNAME)
imagenome_filer = ImaGenomeFiler(
    download_directory=os.path.join(DATA_ROOT, 'imagenome'),
    physio_username=PHYSIO_USERNAME,
    # Same account, so reuse the MIMIC-CXR filer's password.
    physio_password=mimic_cxr_filer.password)

dm = ImaGenomeDataModule(
    mimic_cxr_filer,
    imagenome_filer,
    batch_size=8,
    num_workers=5,
    collate_fn=None,
    get_images=True,
    get_reports=True,
    force=False,
    parallel=False,
    num_preprocessing_workers=os.cpu_count(),
    chunksize=1,
    split_slices='gold',
    gold_test=False,
)

dm.prepare_data()

  dicom_ids = set(gold_object_attribute_with_coordinates_df.image_id.str.replace('.dcm', ''))


downloaded


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3821/3821 [00:00<00:00, 86748.02it/s]


Downloading gold None (500):
not parallelizing


  0%|                                                                                                                                                                               | 0/500 [00:00<?, ?it/s]

Setting one record's processing to verbose to serve as an example.

Filter dicoms so view position is '['PA', 'AP']':



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1620.05it/s][A



Save dicoms to pytorch files:



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 717.34it/s][A



Save reports:



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1582.76it/s][A





100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:01<00:00, 446.30it/s]


In [3]:
# The 'gold' split prepared above; it is evaluated and annotated in the cells below.
gold = dm.get_dataset('gold')
gold.df

Unnamed: 0.1,Unnamed: 0,subject_id,study_id,dicom_id,path,ViewPosition
0,3352,18771968,53209617,67bd451d-6f695b32-b8ce2be9-23c30cae-f6f94270,files/p18/p18771968/s53209617/67bd451d-6f695b3...,AP
1,187,10319873,59547133,13e74403-fb5ec524-1c5d0384-744cdb87-66121ad5,files/p10/p10319873/s59547133/13e74403-fb5ec52...,PA
2,1210,12988419,55086498,1d29f52b-c02f74b1-4e71b918-8878515b-f980dae0,files/p12/p12988419/s55086498/1d29f52b-c02f74b...,PA
3,327,10680329,54197100,7b06a9f9-7fa05d00-3b464e66-8c729c40-0ec1ff52,files/p10/p10680329/s54197100/7b06a9f9-7fa05d0...,PA
4,1170,12957707,59788377,6bcf7ea1-7d6d22d2-acc8b8c1-846af6af-78841c71,files/p12/p12957707/s59788377/6bcf7ea1-7d6d22d...,PA
...,...,...,...,...,...,...
495,2802,17033197,55486789,810da8a0-8a89650f-74545e4d-9cdbf128-5546f6d7,files/p17/p17033197/s55486789/810da8a0-8a89650...,PA
496,32,10063856,54814005,4bb710ab-ab7d4781-568bcd6e-5079d3e6-7fdb61b6,files/p10/p10063856/s54814005/4bb710ab-ab7d478...,AP
497,3043,17938416,57945085,370f4921-7a6644ba-5761cb31-6bee2248-c19e4971,files/p17/p17938416/s57945085/370f4921-7a6644b...,PA
498,1580,14256117,56473110,3de8b3bb-75a4d283-f0fa8221-26a993ae-5ab5fbe3,files/p14/p14256117/s56473110/3de8b3bb-75a4d28...,AP


In [15]:
# Checkpoints under comparison: the main pretrained run and a second run
# (referred to as "random" throughout) used as a baseline.
PRETRAINED_CKPT = './data/ckpt/gloria_pretrain_1.0/2021_11_09_02_20_45/epoch=2-step=3125.ckpt'
RANDOM_CKPT = './data/ckpt/gloria_pretrain_1.0/2021_11_09_02_25_25/epoch=0-step=1041.ckpt'

device = "cuda" if torch.cuda.is_available() else "cpu"
gloria_model = gloria.load_gloria(name=PRETRAINED_CKPT, device=device)
# presumably makes the text pipeline consume the whole report — confirm
gloria_model.cfg.data.text.full_report = True
gloria_model_random = gloria.load_gloria(name=RANDOM_CKPT, device=device)
gloria_model_random.cfg.data.text.full_report = True


In [16]:
# Baseline: mean pixel AUROC for the "random" checkpoint on the first 10 gold examples.
path = 'annotations_random'
compute_metrics(path=path, dataset=gold, model=gloria_model_random, num_examples=10)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:14<00:00,  7.44s/it]


tensor(0.4512)


In [22]:
# Manual grading widget for the "random" checkpoint's attention maps.
path = 'annotations_random'
innotator, labels = annotate(path=path, dataset=gold, model=gloria_model_random, num_examples=10)
innotator

Innotater(children=(HBox(children=(VBox(children=(Textarea(value='No acute cardiopulmonary abnormality.', disa…

In [18]:
# Mean pixel AUROC for the pretrained checkpoint on the same 10 gold examples.
path = 'annotations'
compute_metrics(path=path, dataset=gold, model=gloria_model, num_examples=10)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:07<00:00,  6.73s/it]


tensor(0.6072)


In [19]:
# Manual grading widget for the pretrained checkpoint's attention maps.
path = 'annotations'
innotator, labels = annotate(path=path, dataset=gold, model=gloria_model, num_examples=10)
innotator

Innotater(children=(HBox(children=(VBox(children=(Textarea(value='No acute cardiopulmonary abnormality.', disa…

In [30]:
# Grade only sentences that carry an ('abnormal', 'yes') entity.
# The selector is built inline from get_ent_to_bbox (defined in the helper
# cell at the top) so this cell runs on a fresh kernel in top-to-bottom order.
path = 'annotations_random'
innotator, labels = annotate(path, gold, gloria_model_random, num_examples=10,
                             selector=lambda row: ('abnormal', 'yes') in get_ent_to_bbox(row))
innotator

Innotater(children=(HBox(children=(VBox(children=(Textarea(value=' lower lungs are grossly clear, though there…

In [29]:
# Grade only sentences that carry an ('abnormal', 'yes') entity.
# The selector is built inline from get_ent_to_bbox (defined in the helper
# cell at the top) so this cell runs on a fresh kernel in top-to-bottom order.
path = 'annotations'
innotator, labels = annotate(path, gold, gloria_model, num_examples=10,
                             selector=lambda row: ('abnormal', 'yes') in get_ent_to_bbox(row))
innotator

Innotater(children=(HBox(children=(VBox(children=(Textarea(value=' lower lungs are grossly clear, though there…

In [None]:
# Saving the widget labels is deliberately commented out, so an accidental
# re-run cannot overwrite previously saved annotations with fresh zeros.
# Uncomment intentionally, after checking which `path` is current.
# with open(os.path.join(path, 'labels.pkl'), 'wb') as f:
#     pkl.dump(labels, f)

In [None]:
# Resume grading with previously saved labels.
# NOTE(review): 'annotations/sentences.csv' was already cached by the
# 10-example run above, so num_examples=100 is ignored and the saved labels
# must match that cached row count.
path = 'annotations'
with open(os.path.join(path, 'labels.pkl'), 'rb') as f:
    labels = pkl.load(f)  # only load pickles this notebook wrote itself
innotator, labels = annotate(path=path, dataset=gold, model=gloria_model, num_examples=100, labels=labels)
innotator

In [None]:
np.sum(np.asarray(labels) != 0)  # sentences given any class other than '0 - Unselected'

In [None]:
np.sum(np.asarray(labels) == 1)  # '1 - Positive'

In [None]:
np.sum(np.asarray(labels) == 2)  # '2 - Ambiguous'

In [None]:
np.sum(np.asarray(labels) == 3)  # '3 - Negative'

In [None]:
# NOTE(review): hard-coded counts copied from the label-count cells above
# (59 + 35 + 7 = 101 labelled sentences). Recompute from `labels` instead, e.g.
# (np.array(labels) == k).sum() / (np.array(labels) != 0).sum(), so this ratio
# cannot go stale when the annotations change.
59 / 101

In [None]:
# NOTE(review): hard-coded counts (35 of 101 labelled sentences). Recompute
# from `labels` so this ratio cannot go stale when the annotations change.
35 / 101

In [None]:
# NOTE(review): hard-coded counts (7 of 101 labelled sentences). Recompute
# from `labels` so this ratio cannot go stale when the annotations change.
7 / 101

In [20]:
# Inspect the cached per-sentence results written for the pretrained checkpoint.
df = pd.read_csv(os.path.join('annotations', 'sentences.csv'))
df

Unnamed: 0.1,Unnamed: 0,patient_id,study_id,dicom_id,sent_id,dicom_sent_id,sentence,bbox_names,bboxes,sent_labels,sent_contexts
0,0,15184836,59382057,00046130-fd952ef0-57f2948d-491a16b4-5db3a18c,59382057|11,dicom_00046130-fd952ef0-57f2948d-491a16b4-5db3...,No acute cardiopulmonary abnormality.,"['cardiac silhouette', 'cardiac silhouette', '...","[[97, 109, 198, 189], [97, 109, 198, 189], [12...","['abnormal', 'normal', 'abnormal', 'normal', '...","['no', 'yes', 'no', 'yes', 'no', 'yes', 'no', ..."
1,1,15184836,59382057,00046130-fd952ef0-57f2948d-491a16b4-5db3a18c,59382057|6,dicom_00046130-fd952ef0-57f2948d-491a16b4-5db3...,The cardiomediastinal and hilar contours are w...,"['cardiac silhouette', 'cardiac silhouette', '...","[[97, 109, 198, 189], [97, 109, 198, 189], [12...","['enlarged cardiac silhouette', 'mediastinal d...","['no', 'no', 'no', 'yes', 'no', 'no', 'no', 'n..."
2,2,15184836,59382057,00046130-fd952ef0-57f2948d-491a16b4-5db3a18c,59382057|7,dicom_00046130-fd952ef0-57f2948d-491a16b4-5db3...,The lung fields are clear.,"['left lung', 'right lung']","[[128, 34, 216, 196], [24, 39, 114, 202]]","['lung opacity', 'lung opacity']","['no', 'no']"
3,3,15184836,59382057,00046130-fd952ef0-57f2948d-491a16b4-5db3a18c,59382057|8,dicom_00046130-fd952ef0-57f2948d-491a16b4-5db3...,"There is no pneumothorax, fracture or dislocat...","['left apical zone', 'left lung', 'right apica...","[[129, 34, 189, 69], [128, 34, 216, 196], [56,...","['pneumothorax', 'pneumothorax', 'pneumothorax...","['no', 'no', 'no', 'no']"
4,4,12930467,57384894,005043e2-a4e25d1d-aae26631-732a2db0-38412248,57384894|4,dicom_005043e2-a4e25d1d-aae26631-732a2db0-3841...,FINDINGS: The cardiomediastinal and hilar con...,"['cardiac silhouette', 'cardiac silhouette', '...","[[85, 82, 155, 152], [85, 82, 155, 152], [115,...","['enlarged cardiac silhouette', 'mediastinal d...","['no', 'no', 'no', 'yes', 'no', 'no', 'no', 'n..."
5,5,12930467,57384894,005043e2-a4e25d1d-aae26631-732a2db0-38412248,57384894|5,dicom_005043e2-a4e25d1d-aae26631-732a2db0-3841...,The lungs are clear.,"['left lung', 'right lung']","[[114, 0, 190, 169], [31, 0, 100, 164]]","['lung opacity', 'lung opacity']","['no', 'no']"
6,6,12930467,57384894,005043e2-a4e25d1d-aae26631-732a2db0-38412248,57384894|6,dicom_005043e2-a4e25d1d-aae26631-732a2db0-3841...,There is no pleural effusion or pneumothorax.,"['left costophrenic angle', 'left lung', 'left...","[[180, 153, 202, 175], [114, 0, 190, 169], [11...","['pleural effusion', 'pleural effusion', 'pneu...","['no', 'no', 'no', 'no', 'no', 'no']"
7,7,12930467,57384894,005043e2-a4e25d1d-aae26631-732a2db0-38412248,57384894|7,dicom_005043e2-a4e25d1d-aae26631-732a2db0-3841...,IMPRESSION: No acute cardiopulmonary process.,"['cardiac silhouette', 'cardiac silhouette', '...","[[85, 82, 155, 152], [85, 82, 155, 152], [114,...","['abnormal', 'normal', 'abnormal', 'normal', '...","['no', 'yes', 'no', 'yes', 'no', 'yes', 'no', ..."
8,8,10646008,53041609,009a6abf-9d2e4695-673ee589-a7c60144-2dd81769,53041609|10,dicom_009a6abf-9d2e4695-673ee589-a7c60144-2dd8...,Heart size is normal.,"['cardiac silhouette', 'cardiac silhouette']","[[73, 90, 155, 149], [73, 90, 155, 149]]","['enlarged cardiac silhouette', 'normal']","['no', 'yes']"
9,9,10646008,53041609,009a6abf-9d2e4695-673ee589-a7c60144-2dd81769,53041609|5,dicom_009a6abf-9d2e4695-673ee589-a7c60144-2dd8...,Bulging mediastinum projecting over the left m...,"['left lung', 'right lung']","[[113, 2, 188, 164], [-1, -1, -1, -1]]","['low lung volumes', 'low lung volumes']","['yes', 'yes']"


In [23]:
class RowContainsLabelAndContextSelector:
    """Row predicate usable as the ``selector`` of ``annotate``/``compute_metrics``.

    Calling an instance on a ``sentences.csv`` row returns True when
    ``get_ent_to_bbox(row)`` contains the entity ``(label, context)``.
    """

    def __init__(self, label, context):
        self.label, self.context = label, context

    def __call__(self, row):
        wanted = (self.label, self.context)
        return wanted in get_ent_to_bbox(row)


[10, 11, 12, 14, 21, 32, 33, 34, 47]