In [1]:
import os
import sys


# Add the parent directory to the system path
sys.path.append("..")

from model import Attention, GatedAttention, AdditiveAttention
from WSI_dataloader import collate, BreastWSIDataset, BreastEmbeddingDataset

# ==================================================================================
 
import numpy as np
import math
import matplotlib.pyplot as plt

# ============================== Torch Imports =====================================

import torch
import torch.utils.data as data_utils
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn

import pickle

In [2]:
USE_TENSORBOARD = False

DATASET_HDF5 = "/media/mdastorage/breast_5x_dataset.h5"
MODEL_WEIGHTS_FILE = "../model_weights/additiveAttentionMIL_aug2.pt"
SLIDE_DATA_FILE = "../slide_data/breast.csv"


CUDA_DEVICE = "cuda:0"

# Check availability first: torch.cuda.init() raises on hosts without a usable
# GPU, which previously made the `cuda is False` path unreachable. Fall back to
# the CPU so the notebook can still run there.
cuda = torch.cuda.is_available()
if cuda:
    torch.cuda.init()
device = torch.device(CUDA_DEVICE if cuda else "cpu")

In [3]:
dataset = BreastEmbeddingDataset(DATASET_HDF5)

# Indices of the held-out test split.
# NOTE: pickle.load can execute arbitrary code — only load a trusted,
# locally produced indices file.
with open('../test_indices.pkl', 'rb') as indices_file:
    test_indices = pickle.load(indices_file)

test_dataset = torch.utils.data.Subset(dataset, indices=test_indices)

786 786 786
<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>
Dataset Fetched!


In [4]:
model = AdditiveAttention().to(device)
# map_location remaps tensors saved on another device (a different GPU index,
# or a GPU checkpoint loaded on a CPU host) onto the device selected above.
model.load_state_dict(torch.load(MODEL_WEIGHTS_FILE, map_location=device))

<All keys matched successfully>

In [5]:
def count_labels(label, pos, neg):
    """Increment the negative counter when ``label == 0``, else the positive one.

    Returns:
        The updated ``(pos, neg)`` pair.
    """
    is_negative = label == 0
    return (pos, neg + 1) if is_negative else (pos + 1, neg)


# ========================================== Validation Functions =====================================================

def valid_epoch_embeddings(model, dataloader, epoch, fold="", get_scores=False):
    """Evaluate ``model`` on every bag yielded by ``dataloader``.

    Each batch is unpacked as ``(data, coords, label, path)``; only
    ``data[0, :, :]`` is fed to the model.

    Args:
        model: MIL model exposing ``calculate_objective`` (returns
            ``(loss, attention_weights)``) and ``calculate_classification_error``
            (returns ``(error, Y_hat, Y_prob)``).
        dataloader: iterable of ``(data, coords, label, path)`` batches.
        epoch: run tag; when it equals ``"validation"`` and ``USE_TENSORBOARD``
            is set, summary metrics are written to ``writer``.
        fold: unused; kept for call compatibility.
        get_scores: if True, also collect attention weights and predictions.

    Returns:
        ``(valid_loss, valid_correct, probs, true_labels)``. When ``get_scores``
        is True, also ``attention_scores`` (``path[0] -> (coords[0],
        attention_weights)``) and ``Y_hats`` (predictions in iteration order).
    """
    model.eval()
    valid_loss, valid_correct = 0., 0
    probs, true_labels = [], []
    attention_scores = {}
    Y_hats = []

    # Inference only: without no_grad every forward pass builds an autograd
    # graph, and the attention tensors kept in ``attention_scores`` would hold
    # those graphs alive for the whole test set.
    with torch.no_grad():
        for data, coords, label, path in dataloader:
            bag_label = label[0]
            embedding = data[0, :, :]
            if cuda:
                # NOTE(review): the CUDA path passes the full ``label`` tensor,
                # while the CPU path passes ``label[0]``. Kept as-is because the
                # model's expected target shape is not visible here.
                embedding, bag_label = embedding.to(device), label.to(device)

            loss, attention_weights = model.calculate_objective(embedding, bag_label)
            _, Y_hat, Y_prob = model.calculate_classification_error(embedding, bag_label)

            if get_scores:
                attention_scores[path[0]] = (coords[0], attention_weights)

            valid_loss += loss.item()
            valid_correct += (Y_hat == bag_label).sum().item()
            Y_hats.append(Y_hat)
            probs.append(Y_prob.item())
            true_labels.append(label[0].item())

    if epoch == "validation" and USE_TENSORBOARD:
        accuracy, auc, f1, recall = calculate_metrics(probs, true_labels)
        writer.add_text('loss', "{:.4f}".format(valid_loss/len(dataloader)))
        writer.add_text('auc', "{:.4f}".format(auc))
        writer.add_text('accuracy',"{:.4f}".format(accuracy))
        writer.add_text('f1',"{:.4f}".format(f1))
        writer.add_text('recall',"{:.4f}".format(recall))

    if get_scores:
        return valid_loss, valid_correct, probs, true_labels, attention_scores, Y_hats
    else:
        return valid_loss, valid_correct, probs, true_labels


In [6]:
import pandas as pd
from math import log2

#print("Attention", attention_scores)

def get_slide_levels(slide_ids, magnification, csv_filename):
    '''
    Map each requested slide id to the pyramid level for ``magnification``.

    Reads ``csv_filename`` (columns ``id``, ``mpp``, ``max_level``). A slide's
    native magnification is taken as ``10 / mpp``. The number of levels to step
    down is ``int(log2(native / magnification))``, and the returned level is
    ``max_level`` minus that count.

    Rows whose id is not in ``slide_ids`` are ignored. Rows whose ``mpp``
    rounded to one decimal is not 0.2, 0.3 or 0.5 are skipped, with a printed
    notice.

    Returns:
        dict mapping slide id -> pyramid level.
    '''
    slide_info = pd.read_csv(csv_filename)
    slide_levels = {}

    for row in slide_info.itertuples(index=False):
        if row.id not in slide_ids:
            continue
        rounded_mpp = round(row.mpp, 1)
        if rounded_mpp not in [0.5, 0.2, 0.3]:
            # Print the 1-decimal value that actually failed the check.
            # Rounding to 0 decimals hid it (e.g. mpp 0.4 printed as 0).
            print("2", rounded_mpp)
            continue
        max_magnification = 10 / row.mpp
        # NOTE(review): int() truncates. An mpp slightly above a nominal value
        # (e.g. 0.2527 -> ~39.6x) steps down one level fewer than the nominal
        # 40x would. Left as-is so levels stay consistent with how patches were
        # extracted — confirm.
        levels_to_decrease = int(log2(max_magnification / magnification))
        slide_levels[row.id] = row.max_level - levels_to_decrease

    return slide_levels

In [7]:
from utils.gdc_api_utils import getTile
import matplotlib.cm as cm
import matplotlib as mpl

# Colormap evaluated at each patch's normalised attention score to tint the thumbnail.
color_map = cm.jet



def get_ROI_color(attention_scores, color, i):
    """Return patch ``i``'s attention score centred on the slide's mean score.

    Scores are shifted so their minimum is 0, and the mean of the shifted
    scores is then subtracted.

    Args:
        attention_scores: torch tensor of attention weights for one slide.
        color: unused; kept for call compatibility.
        i: index of the entry to return.
    """
    scores = attention_scores.cpu().detach().numpy()
    # Only a min-shift is applied (no division by the score range), so
    # uniform scores cannot trigger a 0/0.
    shifted_scores = scores - np.min(scores)
    centred_scores = shifted_scores - np.mean(shifted_scores)
    return centred_scores[i]


def get_deviation_scores(attention_scores, i):
    """Return patch ``i``'s attention score min-max normalised to [0, 1].

    Any constant shift of the scores (e.g. subtracting their mean) cancels out
    in min-max normalisation, so only the min and max are used.

    Args:
        attention_scores: torch tensor of attention weights for one slide. It is
            flattened, so (N, 1), (1, N) and (N,) layouts are all accepted.
        i: index of the patch.

    Returns:
        The normalised score, or 0.0 when all scores are equal (avoids the NaN
        that 0/0 would produce).
    """
    scores = np.ravel(attention_scores.cpu().detach().numpy())
    low, high = np.min(scores), np.max(scores)
    if high == low:
        return 0.0
    return ((scores - low) / (high - low))[i]

def getPixelsInThumbnail(slide_id, magnification_level, coords_scores, show_original=False, thumbnail_level=9):
    """Tint a slide thumbnail by per-patch attention and plot it.

    The thumbnail is fetched with ``getTile(slide_id, thumbnail_level, 0, 0)``.
    Each patch's pixels are multiplied by the RGB of ``color_map`` at the
    patch's attention score, min-max normalised to [0, 1] across the slide.

    Args:
        slide_id: slide identifier passed to ``getTile``.
        magnification_level: pyramid level the patch coordinates refer to.
        coords_scores: ``(coords, scores)`` pair. ``coords[k][0]`` and
            ``coords[k][1]`` are the patch-grid indices of patch ``k``;
            ``scores`` is a torch tensor with one attention weight per patch.
        show_original: also plot the untinted thumbnail next to the overlay.
        thumbnail_level: pyramid level fetched as the thumbnail (default 9).
    """
    thumbnail = np.array(getTile(slide_id, thumbnail_level, 0, 0))
    if show_original:
        original_img = thumbnail.copy()

    # View in (x, y, channel) order so patch coordinates index it directly.
    # Writes through this view modify ``thumbnail``.
    img = thumbnail.transpose((1, 0, 2))

    # Thumbnail pixels covered by one patch. 512 is presumably the patch side in
    # pixels at ``magnification_level`` — TODO confirm.
    pixels_per_patch = int(512 / 2 ** (magnification_level - thumbnail_level))

    coords = coords_scores[0]
    flat_scores = np.ravel(coords_scores[1].cpu().detach().numpy())
    # Normalise once for the whole slide, not once per patch (that was O(n^2)
    # with a device->host copy each time). Uniform or empty scores map to 0
    # instead of NaN.
    weights = np.zeros_like(flat_scores)
    if flat_scores.size and np.max(flat_scores) > np.min(flat_scores):
        low = np.min(flat_scores)
        weights = (flat_scores - low) / (np.max(flat_scores) - low)

    for index in range(len(coords)):
        pixel_x = int(coords[index][0] * pixels_per_patch)
        pixel_y = int(coords[index][1] * pixels_per_patch)
        patch = img[pixel_x:pixel_x + pixels_per_patch, pixel_y:pixel_y + pixels_per_patch, :]
        red, green, blue, _ = color_map(weights[index])
        # NOTE(review): assumes a 3-channel thumbnail; an RGBA tile would not
        # broadcast against the RGB triple.
        np.multiply(patch, [red, green, blue], out=patch, casting="unsafe")

    fig = plt.figure()
    overlay_ax = fig.add_subplot(1, 2, 1)
    overlay_ax.imshow(thumbnail)
    overlay_ax.set_title("Attention overlay")
    if show_original:
        original_ax = fig.add_subplot(1, 2, 2)
        original_ax.imshow(original_img)
        original_ax.set_title("Original")
    plt.show()

In [10]:
def visualize_ROIs(attention_scores, Y_hats, img_path, show_original=True, magnification=5):
    """Plot the attention overlay for every slide in ``attention_scores``.

    Args:
        attention_scores: dict ``path -> (coords, attention_weights)``, as
            returned by ``valid_epoch_embeddings(..., get_scores=True)``.
        Y_hats: predictions in the same order as ``attention_scores``.
        img_path: unused; kept for call compatibility.
        show_original: also show the untinted thumbnail.
        magnification: magnification the patch coordinates were extracted at
            (default 5).
    """
    # NOTE(review): path[4] / path[6:] assume a fixed key layout (label
    # character at index 4, slide id from index 6) — confirm against the
    # dataset's paths.
    # Read the slide CSV once for all slides instead of once per slide.
    slide_levels = get_slide_levels(
        [path[6:] for path in attention_scores], magnification, SLIDE_DATA_FILE
    )

    for (path, scores), pred in zip(attention_scores.items(), Y_hats):
        label = path[4]
        slide_id = path[6:]
        print("Label:", label, "Predicted:", pred)
        print(slide_id)
        if slide_id not in slide_levels:
            # get_slide_levels drops slides with unsupported mpp; skip them
            # rather than abort the whole loop with a KeyError.
            print("No magnification level for slide", slide_id, "- skipped")
            continue
        getPixelsInThumbnail(slide_id, slide_levels[slide_id], scores, show_original)
    

def visualize_attention_scores_dist(attention_scores, Y_hats, img_path):
    """Plot a histogram over [0, 1] of each slide's attention weights.

    Args:
        attention_scores: dict ``path -> (coords, attention_weights)``, as
            returned by ``valid_epoch_embeddings(..., get_scores=True)``.
        Y_hats: unused; kept for call compatibility.
        img_path: unused; kept for call compatibility.
    """
    for path, scores in attention_scores.items():
        # Flatten so every patch's weight is counted. Indexing ``[0]`` only
        # yielded all weights for a (1, N) tensor, and a single value for (N, 1).
        # NOTE(review): for a multi-row (K, N) tensor this now pools all rows.
        weights = np.ravel(scores[1].cpu().detach().numpy())
        fig, ax = plt.subplots()
        ax.hist(weights, range=[0.0, 1.0])
        ax.set(title="Attention weights: {}".format(path[6:]),
               xlabel="attention weight", ylabel="patches")
        plt.show()

In [11]:
dataloader = data_utils.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=collate,
)

# One pass over the test split, keeping per-slide attention weights and
# predictions for plotting.
(_, _, _, _,
 attention_scores, Y_hats) = valid_epoch_embeddings(
    model, dataloader, "validation", get_scores=True
)
visualize_ROIs(attention_scores, Y_hats, "")


0
Label: 0 Predicted: tensor(0., device='cuda:0')
945983a7-108a-4e77-941a-b65bdb6575eb


UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7fe4b7d41170>