In [1]:
import torch
import gloria
import pandas as pd
import matplotlib.pyplot as plt
from gloria import builder, utils
from PIL import Image
import numpy as np
from mimic_data import ImaGenomeDataModule, MimicCxrFiler, ImaGenomeFiler, normalize
import os
import cv2
from torch import nn
from jupyter_innotater import *
import pickle as pkl
from tqdm import tqdm
from torchmetrics import AUROC

# def normalize(x):
#     return x

def process_img(model, imgs, device):
    """Convert raw single-channel images into a batch of GLoRIA-ready tensors.

    Each image is rescaled with ``normalize`` (assumed to map onto [0, 1], as
    ``to_rgb`` also assumes — TODO confirm against mimic_data), turned into
    uint8 grey levels in [0, 255], resized with ``model._resize_img`` and
    passed through the model's test-split transformation.

    Args:
        model: loaded GLoRIA model (provides ``cfg`` and ``_resize_img``).
        imgs: iterable of 2-D images (tensors or arrays). Boolean masks are
            also passed through here by ``process_bboxes``.
        device: device the stacked batch is moved to.

    Returns:
        Tensor stacking the transformed images, on ``device``.
    """
    transform = builder.build_transformation(model.cfg, split="test")

    all_imgs = []
    for x in imgs:
        # Map to [0, 255] grey levels. The former ``(normalize(x) * 2 - 1) * 255``
        # produced values in [-255, 255]. Casting those to uint8 wraps the
        # negatives around, which corrupted every pixel below mid-grey.
        # Clipping guards against any residual out-of-range values.
        x = np.asarray(normalize(x), dtype=np.float64)
        x = np.clip(x * 255, 0, 255).astype(np.uint8)
        x = model._resize_img(x, model.cfg.data.image.imsize)
        img = Image.fromarray(x).convert("RGB")
        img = transform(img)
        # The transform already returns a tensor. as_tensor avoids the
        # copy-construct UserWarning raised by torch.tensor(tensor).
        all_imgs.append(torch.as_tensor(img))

    all_imgs = torch.stack(all_imgs).to(device)

    return all_imgs


def bbox_to_mask(bbox, image_shape):
    """Rasterise an inclusive ``[x1, y1, x2, y2]`` box into a boolean mask.

    Args:
        bbox: sequence whose first four items are the left, top, right and
            bottom pixel coordinates (inclusive), as produced by ``mask_to_bbox``.
        image_shape: 2-D shape ``(height, width)`` of the mask to build.

    Returns:
        ``torch.bool`` tensor of ``image_shape`` that is True inside the box.
        Parts of the box outside the image are ignored. A box entirely outside
        the image, including the ``[-1, -1, -1, -1]`` sentinel, gives an empty
        mask.
    """
    x1, y1, x2, y2 = (int(v) for v in bbox[:4])
    box_mask = torch.zeros(image_shape, dtype=torch.bool)
    # Clamp both slice ends at 0. A negative index would otherwise wrap around
    # to the far end of the axis and select the wrong rows or columns.
    box_mask[max(y1, 0):max(y2 + 1, 0), max(x1, 0):max(x2 + 1, 0)] = True
    return box_mask


def mask_to_bbox(box_mask):
    """Return the tight inclusive ``[x1, y1, x2, y2]`` box around a boolean mask.

    Args:
        box_mask: 2-D boolean tensor.

    Returns:
        List of four Python ints (left, top, right, bottom), or
        ``[-1, -1, -1, -1]`` when the mask has no True pixel.
    """
    if not box_mask.any():
        return [-1, -1, -1, -1]
    rows, cols = torch.nonzero(box_mask, as_tuple=True)
    left, right = cols.min().item(), cols.max().item()
    top, bottom = rows.min().item(), rows.max().item()
    return [left, top, right, bottom]


def process_bboxes(model, image_shape, bboxes):
    """Map boxes from original-image coordinates into the model's input frame.

    Each box is rasterised to a mask and sent through the same preprocessing
    as the image (``process_img`` on CPU). The tight box around the
    positive-valued pixels of the first channel is then read back.

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes, one per input box. A box that
        vanishes in preprocessing comes back as ``[-1, -1, -1, -1]``.
    """
    def _map_one(bbox):
        # Rasterise, preprocess like the image, threshold, then re-extract.
        processed = process_img(model, [bbox_to_mask(bbox, image_shape)], 'cpu') > 0
        return mask_to_bbox(processed[0, 0])

    return [_map_one(bbox) for bbox in bboxes]


def get_batch(model, texts, imgs, device):
    """Build a GLoRIA input batch from raw texts and raw images.

    The text side comes from ``model.process_text``. The preprocessed image
    tensor is added under the ``'imgs'`` key.
    """
    text_batch = model.process_text(texts, device)
    image_tensor = process_img(model, imgs, device)
    text_batch['imgs'] = image_tensor
    return text_batch


def plot_attn_maps(attn_maps, imgs, sents, epoch_idx=0, batch_idx=0, nvis=1):
    """Render attention maps over images with GLoRIA's visualisation helper.

    ``epoch_idx`` and ``batch_idx`` are accepted for signature compatibility
    but are not used.

    Returns:
        A ``PIL.Image`` of the attention grid, or None when
        ``utils.build_attention_images`` produced no image.
    """
    img_set, _ = utils.build_attention_images(
        imgs,
        attn_maps,
        nvis=nvis,
        sentences=sents,
    )
    return None if img_set is None else Image.fromarray(img_set)


def plot_attention_from_raw(images, reports, model, filename='attention.jpg', device='cuda'):
    """Run GLoRIA on raw images and reports, and save an attention visualisation.

    Each report is trimmed to start at ``'FINDINGS:'`` when that marker is
    present.

    Args:
        images: raw images, one per report.
        reports: report strings.
        model: loaded GLoRIA model.
        filename: output image path.
        device: device to run the model on. Defaults to ``'cuda'``, the value
            previously hard-coded here.

    Returns:
        The saved ``PIL.Image``.

    Raises:
        RuntimeError: if no attention image could be built.
    """
    reports = [report[report.index('FINDINGS:'):] if 'FINDINGS:' in report else report for report in reports]
    batch = get_batch(model, reports, images, device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        img_emb_l, _, text_emb_l, _, sents = model(batch)
        attn_maps = model.get_attn_maps(img_emb_l, text_emb_l, sents)
    im = plot_attn_maps(attn_maps, batch['imgs'].cpu(), sents, nvis=len(images))
    if im is None:
        raise RuntimeError('no attention image was built; nothing saved to %s' % filename)
    im.save(filename)
    return im


def draw_bounding_boxes(image, bboxes, color=(255, 0, 0), thickness=None):
    """Draw rectangle outlines for ``[x1, y1, x2, y2]`` boxes onto ``image``.

    Args:
        image: numpy image array. cv2 draws into it in place, and the result
            is also returned.
        bboxes: iterable of boxes. The ``[-1, -1, -1, -1]`` "no box" sentinel
            from ``mask_to_bbox`` is skipped; previously it drew a stray mark
            in the top-left corner.
        color: colour passed to ``cv2.rectangle``.
        thickness: line width in pixels. Defaults to about 1% of the image
            height, and is never below 1.

    Returns:
        The image with the boxes drawn.
    """
    if thickness is None:
        thickness = max(image.shape[0] // 100, 1)
    for bbox in bboxes:
        coords = [int(v) for v in bbox[:4]]
        if coords == [-1, -1, -1, -1]:
            continue
        # cv2 requires integer point tuples. Box coordinates may arrive as
        # lists or as numpy/torch scalars.
        image = cv2.rectangle(image, tuple(coords[:2]), tuple(coords[2:]), color, thickness)
    return image


def get_bounding_boxes_mask(image_shape, bboxes):
    """Boolean array that is True on the drawn box outlines (not filled boxes).

    The boxes are drawn with value 1 onto a zero canvas of ``image_shape``.
    """
    canvas = draw_bounding_boxes(np.zeros(image_shape), bboxes, color=1)
    outline = canvas == 1
    return outline


def show_attention_from_raw(batch, model):
    """Return one 2-D attention map for the first sample of ``batch``.

    The model runs on ``batch``, its local attention maps are taken for the
    first sample, and they are summed over their first axis. That axis is
    presumably the word axis — TODO confirm against gloria's get_attn_maps.

    Returns:
        numpy array holding the summed map.
    """
    # Inference only: avoid building (and holding) an autograd graph per call.
    with torch.no_grad():
        img_emb_l, _, text_emb_l, _, sents = model(batch)
        attn_maps = model.get_attn_maps(img_emb_l, text_emb_l, sents)
    return attn_maps[0][0].sum(0).cpu().detach().numpy()


def to_rgb(image):
    """Turn a 2-D single-channel tensor into an (H, W, 3) uint8 numpy image.

    The tensor is rescaled with ``normalize``, multiplied by 255, truncated to
    int and replicated across three channels.
    """
    grey_levels = (normalize(image) * 255).int()
    three_channel = grey_levels.unsqueeze(-1).expand(*image.shape, 3)
    return np.array(three_channel.cpu(), dtype=np.uint8)


def process_instance(instance, model, plot=True):
    """Run GLoRIA on every annotated sentence of one study and collect the results.

    Only the first patient, first study and first image (dicom) of
    ``instance`` are used (``next(iter(...))`` at each level). For each
    sentence id in ``instance['objects'][dicom_id]['sent_to_bboxes']``, in
    sorted order, it:
      - records the sentence, bbox names, labels and contexts;
      - maps the sentence's original-image boxes into the processed-image
        frame (``process_bboxes``);
      - computes a summed attention map (``show_attention_from_raw``);
      - upsamples that map to the processed image size.

    Args:
        instance: nested dict ``{patient_id: {study_id: {'images': ...,
            'objects': ...}}}``. The schema is inferred from the accesses
            below — TODO confirm against mimic_data.
        model: loaded GLoRIA model.
        plot: if True, print each sentence's metadata and show a 1x3 figure.
            The panels are the original image with boxes, the processed image
            with mapped boxes, and the attention map with mapped boxes.

    Returns:
        dict of parallel per-sentence lists:
          - ``bbox_names``;
          - ``new_bboxes``: boxes in processed-image coordinates;
          - ``attentions``: upsampled 2-D attention tensors;
          - ``images``: three RGB uint8 arrays per sentence, in the panel
            order above;
          - ``sents``, ``sent_ids``, ``labels`` and ``contexts``.
    """
    # Drill down to the first patient / study / dicom of the instance.
    patient_id = next(iter(instance.keys()))
    study_id = next(iter(instance[patient_id].keys()))
    instance = instance[patient_id][study_id]
    dicom_id = next(iter(instance['images'].keys()))
    image = instance['images'][dicom_id]
    sent_ids = sorted(list(instance['objects'][dicom_id]['sent_to_bboxes'].keys()))
    sents, bbox_names, new_bboxes, attentions, images, labels, contexts = [], [], [], [], [], [], []
    for sent_id in sent_ids:
        sent_info = instance['objects'][dicom_id]['sent_to_bboxes'][sent_id]
        sents.append(sent_info['sentence'])
        bbox_names.append(sent_info['bboxes'])
        # Boxes in original-image pixel coordinates.
        sent_bboxes = sent_info['coords_original']
        labels.append(sent_info['labels'])
        contexts.append(sent_info['contexts'])
        sent_images = []
        if plot:
            print('sentence:', sents[-1])
            print('bbox names:', bbox_names[-1])
            print('labels:', labels[-1])
            print('context:', contexts[-1])
            # `axes` only exists when plot=True; every later use is guarded the same way.
            fig, axes = plt.subplots(1, 3)
        # Panel 1: original image with the original boxes.
        image1 = draw_bounding_boxes(to_rgb(image), sent_bboxes)
        sent_images.append(image1)
        if plot:
            axes[0].imshow(image1)
        # NOTE(review): device is hard-coded to 'cuda'; this fails on CPU-only machines.
        batch = get_batch(model, [sents[-1]], [image], 'cuda')
        new_sent_bboxes = process_bboxes(model, image.shape, sent_bboxes)
        new_bboxes.append(new_sent_bboxes)
        # Panel 2: first channel of the preprocessed image with the mapped boxes.
        image2 = batch['imgs'][0, 0]
        image2 = draw_bounding_boxes(to_rgb(image2), new_sent_bboxes)
        sent_images.append(image2)
        if plot:
            axes[1].imshow(image2)
        # Panel 3: attention map upsampled to the processed image's height/width.
        attn = torch.tensor(show_attention_from_raw(batch, model))
        attn = attn.reshape(1, 1, *attn.shape)
        new_attn = nn.Upsample(size=image2.shape[:2], mode="bilinear")(attn)
        attentions.append(new_attn[0, 0])
        new_attn = draw_bounding_boxes(to_rgb(new_attn[0, 0]), new_sent_bboxes)
        sent_images.append(new_attn)
        if plot:
            axes[2].imshow(new_attn)
            plt.show()
        images.append(sent_images)
    return dict(
        bbox_names=bbox_names,
        new_bboxes=new_bboxes,
        attentions=attentions,
        images=images,
        sents=sents,
        sent_ids=sent_ids,
        labels=labels,
        contexts=contexts,
    )


def get_and_save_instance_results(path, dataset, model, num_examples=None):
    """Process the first ``num_examples`` dataset items and cache the results under ``path``.

    When ``path/sentences.csv`` already exists it is loaded and returned
    unchanged. ``num_examples`` and ``model`` are then ignored.

    Otherwise the following are written, one row or file per sentence:
      - ``path/bbox_images{0,1,2}/<dicom_sent_id>.jpg``: the three panels
        from ``process_instance``;
      - ``path/attentions/<dicom_sent_id>.np.npy``: the upsampled attention
        map;
      - ``path/sentences.csv``: per-sentence metadata. List-valued columns
        are stored as ``str`` of the list.

    Args:
        path: output / cache directory (created if missing).
        dataset: indexable dataset of nested instance dicts.
        model: loaded GLoRIA model.
        num_examples: number of items to process. Defaults to the whole
            dataset.

    Returns:
        The per-sentence ``pandas.DataFrame``.
    """
    csv_path = os.path.join(path, 'sentences.csv')
    if os.path.exists(csv_path):
        return pd.read_csv(csv_path)
    if num_examples is None:
        # Default to the whole dataset. The original called len(model), which
        # raises because the model has no length.
        num_examples = len(dataset)
    for subdir in ('bbox_images0', 'bbox_images1', 'bbox_images2', 'attentions'):
        os.makedirs(os.path.join(path, subdir), exist_ok=True)
    info = []
    for i in tqdm(range(num_examples), total=num_examples):
        instance = dataset[i]
        patient_id = next(iter(instance.keys()))
        study_id = next(iter(instance[patient_id].keys()))
        dicom_id = next(iter(instance[patient_id][study_id]['images'].keys()))
        outs = process_instance(instance, model, plot=False)
        for sent_id, sent, bbox_names, bboxes, sent_labels, sent_contexts, sent_images, attention in zip(
                outs['sent_ids'], outs['sents'], outs['bbox_names'], outs['new_bboxes'],
                outs['labels'], outs['contexts'], outs['images'], outs['attentions']):
            dicom_sent_id = 'dicom_%s_sent_%s' % (dicom_id, sent_id)
            info.append([
                patient_id,
                study_id,
                dicom_id,
                sent_id,
                dicom_sent_id,
                sent,
                str(bbox_names),
                str(bboxes),
                str(sent_labels),
                str(sent_contexts),
            ])
            for image_idx in range(3):
                Image.fromarray(sent_images[image_idx]).save(
                    os.path.join(path, 'bbox_images%d' % image_idx, dicom_sent_id + '.jpg'))
            # np.save appends '.npy' to names that lack it, so the former
            # '<id>.np' argument already produced '<id>.np.npy'. The name is
            # spelled out here so existing caches remain loadable.
            np.save(os.path.join(path, 'attentions', dicom_sent_id + '.np.npy'), np.asarray(attention))
    df = pd.DataFrame(info, columns=[
        'patient_id',
        'study_id',
        'dicom_id',
        'sent_id',
        'dicom_sent_id',
        'sentence',
        'bbox_names',
        'bboxes',
        'sent_labels',
        'sent_contexts',
    ])
    df.to_csv(csv_path)
    return df


def annotate(path, dataset, model, num_examples=None, labels=None):
    """Build an Innotater widget for labelling cached per-sentence results.

    Each row shows:
      - the sentence;
      - its ``(label, context): {bbox names}`` entities, with bbox names
        sorted so the text is identical across sessions;
      - the processed-image and attention panels saved by
        ``get_and_save_instance_results``.

    Args:
        path, dataset, model, num_examples: forwarded to
            ``get_and_save_instance_results``.
        labels: existing per-sentence class indices. Defaults to all 0
            ("Unselected").

    Returns:
        ``(innotater, labels)``. ``labels`` is the list handed to the widget.
        It is presumably updated by the widget as annotations are made,
        since the notebook pickles it afterwards.

    Raises:
        ValueError: if ``labels`` does not have one entry per sentence.
    """
    df = get_and_save_instance_results(path, dataset, model, num_examples=num_examples)
    if labels is None:
        labels = [0] * len(df)
    elif len(labels) != len(df):
        raise ValueError('expected %d labels (one per sentence), got %d' % (len(df), len(labels)))
    sentences = df.sentence.tolist()
    entities = []
    for _, row in df.iterrows():
        ent_to_bbox = {}
        # NOTE(review): eval() executes whatever is in these CSV cells. The CSV
        # is written by get_and_save_instance_results, so do not point this at
        # untrusted files.
        for label, context, bbox in zip(eval(row.sent_labels), eval(row.sent_contexts), eval(row.bbox_names)):
            ent_to_bbox.setdefault((label, context), set()).add(bbox)
        # Sort set members: str(set) order varies between interpreter runs
        # (hash randomisation), which changed the text annotators saw.
        entities.append(''.join(
            str(k) + ': {' + ', '.join(repr(b) for b in sorted(v, key=str)) + '}\n'
            for k, v in ent_to_bbox.items()))
    files = [name + '.jpg' for name in df.dicom_sent_id.tolist()]
    classes = ['0 - Unselected', '1 - Positive', '2 - Ambiguous', '3 - Negative']
    return Innotater(
        [
            TextInnotation(sentences),
            TextInnotation(entities),
            ImageInnotation(files, path=os.path.join(path, 'bbox_images1'), width=200, height=200),
            ImageInnotation(files, path=os.path.join(path, 'bbox_images2'), width=200, height=200),
        ],
        MultiClassInnotation(labels, classes=classes),
    ), labels


def compute_auroc(sent_attention, sent_bboxes):
    """Pixel-level AUROC of an attention map against the union of ground-truth boxes.

    Args:
        sent_attention: 2-D attention map (tensor or numpy array, e.g. as
            loaded with ``np.load``).
        sent_bboxes: iterable of ``[x1, y1, x2, y2]`` boxes in the map's pixel
            frame. ``[-1, -1, -1, -1]`` sentinels contribute nothing.

    Returns:
        The torchmetrics AUROC value, or None when the boxes cover no pixel
        or every pixel. AUROC is undefined without both classes.
    """
    # np.load returns an ndarray, on which torch.zeros_like fails. The boxes
    # are rasterised on CPU, so keep the scores there too.
    sent_attention = torch.as_tensor(sent_attention).detach().cpu()
    label_segmentation = torch.zeros(sent_attention.shape, dtype=torch.bool)
    for bbox in sent_bboxes:
        label_segmentation = label_segmentation | bbox_to_mask(bbox, sent_attention.shape)
    num_positive = int(label_segmentation.sum())
    if num_positive == 0 or num_positive == label_segmentation.numel():
        return None
    try:
        # torchmetrics >= 0.11 requires the task argument.
        metric = AUROC(task='binary')
    except (TypeError, ValueError):
        # Older torchmetrics reject the keyword.
        metric = AUROC()
    return metric(sent_attention.reshape(-1), label_segmentation.reshape(-1).long())


def compute_metrics(path, dataset, model, num_examples=None, scores_file='gold_scores.pkl'):
    """Compute the mean attention-vs-box AUROC over all cached sentences.

    Args:
        path, dataset, model, num_examples: forwarded to
            ``get_and_save_instance_results``. Attention maps are read from
            ``path/attentions``.
        scores_file: where ``{'aurocs': [...]}`` is pickled. The default is
            the previously hard-coded file in the current directory.

    Returns:
        The mean AUROC, or None when no sentence had a usable box.
        The mean is also printed.
    """
    import ast

    df = get_and_save_instance_results(path, dataset, model, num_examples=num_examples)
    aurocs = []
    for _, row in df.iterrows():
        # Attention maps are saved as '<dicom_sent_id>.np.npy' inside
        # path/attentions. The former '<id>.jpg' lookup in cwd never existed.
        attention_file = os.path.join(path, 'attentions', row.dicom_sent_id + '.np.npy')
        sent_attention = torch.as_tensor(np.load(attention_file))
        # The bboxes column holds str(list of boxes). Parse it instead of
        # iterating its characters.
        sent_bboxes = ast.literal_eval(row.bboxes)
        auroc = compute_auroc(sent_attention, sent_bboxes)
        if auroc is not None:
            aurocs.append(auroc)
    mean_auroc = sum(aurocs) / len(aurocs) if aurocs else None
    print(mean_auroc)
    with open(scores_file, 'wb') as f:
        pkl.dump({'aurocs': aurocs}, f)
    return mean_auroc


In [2]:
# Storage locations and PhysioNet account used for the downloads below.
MIMIC_CXR_DIR = '/scratch/mcinerney.de/mimic-cxr'
IMAGENOME_DIR = '/scratch/mcinerney.de/imagenome'
PHYSIONET_USER = 'dmcinerney'

mimic_cxr_filer = MimicCxrFiler(download_directory=MIMIC_CXR_DIR, physio_username=PHYSIONET_USER)
imagenome_filer = ImaGenomeFiler(
    download_directory=IMAGENOME_DIR,
    physio_username=PHYSIONET_USER,
    # Reuse the password held by the MIMIC-CXR filer rather than asking again.
    physio_password=mimic_cxr_filer.password,
)

dm = ImaGenomeDataModule(
    mimic_cxr_filer,
    imagenome_filer,
    batch_size=8,
    num_workers=5,
    collate_fn=None,
    get_images=True,
    get_reports=True,
    force=False,
    parallel=False,
    num_preprocessing_workers=os.cpu_count(),
    chunksize=1,
    split_slices='gold',
    gold_test=False,
)

dm.prepare_data()

  dicom_ids = set(gold_object_attribute_with_coordinates_df.image_id.str.replace('.dcm', ''))


downloaded


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3821/3821 [00:00<00:00, 88915.22it/s]


Downloading gold (500):

Filter dicoms so view position is '['PA', 'AP']':


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 98517.97it/s]



Save dicoms to pytorch files:


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 37483.28it/s]



Save reports:


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 41610.16it/s]







In [3]:
# Load the 'gold' split. `gold.df` lists one row per image: subject/study/dicom ids, file path, view position.
gold = dm.get_dataset('gold')
gold.df

Unnamed: 0.1,Unnamed: 0,subject_id,study_id,dicom_id,path,ViewPosition
0,8,10020740,55522869,27776756-1d9ef4fc-cd8dd0ca-1453072f-12c0f484,files/p10/p10020740/s55522869/27776756-1d9ef4f...,AP
1,31,10037020,58400371,76289ac1-3ef7c087-3e77810d-63462e2c-20c0364c,files/p10/p10037020/s58400371/76289ac1-3ef7c08...,AP
2,32,10063856,54814005,4bb710ab-ab7d4781-568bcd6e-5079d3e6-7fdb61b6,files/p10/p10063856/s54814005/4bb710ab-ab7d478...,AP
3,39,10098993,52050071,01e55956-89f296bb-002ac02d-e08ee2a9-832f1cff,files/p10/p10098993/s52050071/01e55956-89f296b...,PA
4,48,10104308,52433992,749d7548-73506c3c-c2d571b0-609dd2f9-746e60a7,files/p10/p10104308/s52433992/749d7548-73506c3...,PA
...,...,...,...,...,...,...
495,3800,19966115,59650514,198be438-16dc1b2c-e4d95d59-25e6b0a8-9e815c12,files/p19/p19966115/s59650514/198be438-16dc1b2...,AP
496,3805,19969031,54877992,4e78a467-5eede0ee-476cb29e-af0db15d-69c4465c,files/p19/p19969031/s54877992/4e78a467-5eede0e...,AP
497,3807,19986230,50379095,27d2ca4c-93c1fe0c-3104a1a6-1f1118dd-14ae6eac,files/p19/p19986230/s50379095/27d2ca4c-93c1fe0...,PA
498,3814,19989918,53487857,5da30295-f1eadf3e-001ddb71-3b07c00c-2883f874,files/p19/p19989918/s53487857/5da30295-f1eadf3...,AP


In [4]:
# Use the GPU when available. Note that several helpers above still hard-code 'cuda'.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
gloria_model = gloria.load_gloria(name='gloria_resnet18', device=device)
gloria_model

GLoRIA(
  (text_encoder): BertEncoder(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
              

In [6]:
# Quick run: build the annotation widget for the first 2 gold items. Results are cached under `path`.
path = 'annotations11'
innotator, labels = annotate(path, gold, gloria_model, num_examples=2)
innotator

  object_rows = self.gold_objects_df[self.gold_objects_df.image_id.str.replace('.dcm', '') == dicom_id]
  all_imgs.append(torch.tensor(img))
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.43s/it]


Innotater(children=(HBox(children=(VBox(children=(Textarea(value='No acute cardiopulmonary abnormality.', disa…

In [None]:
# commented out for safety
# with open(os.path.join(path, 'labels.pkl'), 'wb') as f:
#     pkl.dump(labels, f)

In [None]:
# Resume annotating the first 100 gold items, starting from previously saved labels.
# NOTE(review): pickle.load can execute arbitrary code. Only load label files this notebook wrote.
path = 'annotations'
with open(os.path.join(path, 'labels.pkl'), 'rb') as f:
    labels = pkl.load(f)
innotator, labels = annotate(path, gold, gloria_model, num_examples=100, labels=labels)
innotator

In [None]:
np.sum(np.asarray(labels) != 0)  # sentences given any class other than '0 - Unselected'

In [None]:
np.sum(np.asarray(labels) == 1)  # '1 - Positive'

In [None]:
np.sum(np.asarray(labels) == 2)  # '2 - Ambiguous'

In [None]:
np.sum(np.asarray(labels) == 3)  # '3 - Negative'

In [None]:
59 / 101  # NOTE(review): hard-coded counts, presumably copied from the count cells above — TODO compute from `labels`

In [None]:
35 / 101  # NOTE(review): hard-coded counts, presumably copied from the count cells above — TODO compute from `labels`

In [None]:
7 / 101  # NOTE(review): hard-coded counts, presumably copied from the count cells above — TODO compute from `labels`