In [1]:
import os, sys
# Restrict this process to physical GPU 3. The variable is set before `torch`
# is imported so that CUDA only ever enumerates that one device; inside the
# process it is then addressed as ordinal 0 (see `Args.device`).
os.environ["CUDA_VISIBLE_DEVICES"]="3"

import torch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
import torch.optim as optim
import torch.nn.functional as F


import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt


import PIL
# Raise Pillow's decompression-bomb pixel threshold so that very large slide
# images can be opened without tripping the size check.
# NOTE(review): `import PIL` by itself does not import the `PIL.Image`
# submodule; this assignment presumably works only because an earlier import
# (e.g. matplotlib) already loaded it. Consider `import PIL.Image` -- confirm.
PIL.Image.MAX_IMAGE_PIXELS = 933120000

from gnn import GNN
from visualize import generate_relevance, fetch_slide_image
from util import read_file, find_dataset_using_name


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args:
    """Static configuration for the CAM notebook; edit the values in place."""

    # ---- runtime -----------------------------------------------------------
    device = 0          # CUDA ordinal, used to build the "cuda:<n>" device string
    seed = 1078         # fed to every random number generator that gets seeded
    batch_size = 1
    num_workers = 0

    # ---- GNN architecture ----------------------------------------------------
    gnn = 'gin'
    num_layer = 3
    emb_dim = 64
    drop_ratio = 0
    jk = 'sum'
    graph_pooling = 'gmt'
    n_classes = 3
    fdim = 768          # passed to the model as its input feature dimension

    # ---- data selection ------------------------------------------------------
    dataset = ['cptac']
    data_config = 'ctranspath_files'   # sub-folder of each dataset root holding the graphs
    phase = 'cams'                     # selects the "<DATASET>_<phase>.txt" slide list
    patch_size = 256
    index = None # 0 - Normal | 1 - LUSC | 2 - LUAD

    # ---- checkpoint / output locations ---------------------------------------
    run_name = "Graph-Perciever_September-17"
    fold_idx = 2
    output = 'logs'

In [3]:
from scipy.stats import zscore
from sklearn.preprocessing import MinMaxScaler, StandardScaler
def plot_heat_maps(graph, scores, prob, slide_root, patch_size=256, overlay=True, clamp=0.05, norm=True, colormap='RdBu_r', save_path=None):
    """Paint per-node relevance scores as a heat map on a downsampled slide image.

    The scores are z-scored, optionally clipped to the ``[clamp, 1 - clamp]``
    quantiles, rescaled to ``[-1, 1]`` and written onto the pixels covered by
    each node's patch. Pixels not covered by any patch are fully transparent.

    Parameters
    ----------
    graph
        Object exposing ``slide_path`` (element 0 is used), ``node_coords``
        (iterable of ``(x, y)`` pairs whose entries support ``.item()``) and
        ``y``. Presumably a single-slide torch_geometric batch -- TODO confirm.
    scores : torch.Tensor or array-like
        One relevance value per node, in the same order as
        ``graph.node_coords``. Tensors may live on any device and may require
        grad; they are detached and copied to the CPU here.
    prob, norm
        Accepted for interface compatibility; not used by this function.
    slide_root, save_path
        Forwarded unchanged to ``fetch_slide_image``.
    patch_size : int
        Factor that turns a node coordinate into a pixel offset before
        downsampling; also forwarded to ``fetch_slide_image``.
    overlay : bool
        When true, the slide image is drawn underneath the heat map.
    clamp : float
        Quantile used for symmetric clipping of the z-scores; a falsy value
        disables clipping.
    colormap
        Matplotlib colormap used for the heat map.

    Returns
    -------
    module
        ``matplotlib.pyplot`` with the newly created figure as the current
        figure. The caller is responsible for saving and closing it.
    """
    slide_path = graph.slide_path[0]
    coords = graph.node_coords

    # Fetch the tissue image at a fixed downsample and normalise it to [0, 1].
    downsample_factor = 16.0
    image = fetch_slide_image(slide_path, slide_root, patch_size, downsample_factor=downsample_factor, gt=graph.y, save_path=save_path)
    image = np.asarray(image.convert("RGB"))
    value_range = image.max() - image.min()
    if value_range > 0:
        image = (image - image.min()) / value_range
    else:
        # Uniform image: min-max scaling would divide by zero, so fall back to
        # plain 8-bit scaling (``convert("RGB")`` yields 8-bit channels).
        image = image / 255.0

    mask = np.zeros(image.shape[:2], dtype=bool)  # pixels covered by a patch
    heatmap = -np.ones(image.shape[:2], dtype=np.float32)

    offset = patch_size + 2  # 2 is for overlap
    d = downsample_factor

    # ``Tensor.numpy()`` raises for CUDA tensors and tensors that require
    # grad, so detach and move to the CPU first.
    if torch.is_tensor(scores):
        scores = scores.detach().cpu().numpy()
    else:
        scores = np.asarray(scores)

    # z-scoring constant input yields NaN; treat those as "no relevance".
    scores = np.nan_to_num(zscore(scores), nan=0)

    if clamp:
        lower, upper = np.quantile(scores, clamp), np.quantile(scores, 1 - clamp)
        scores = np.clip(scores, a_min=lower, a_max=upper)

    if scores.size > 0 and np.ptp(scores) > 0:
        scores = MinMaxScaler(feature_range=(-1, 1)).fit_transform(scores.reshape(-1, 1))
    else:
        # No contrast between nodes: min-max scaling would push every value to
        # -1 (the extreme colour). Keep them at the neutral midpoint instead.
        scores = np.zeros_like(scores)

    for (x, y), s in zip(coords, scores):
        x_px = (x * patch_size).item()
        y_px = (y * patch_size).item()
        rows = slice(round(y_px / d), round((y_px + offset) / d))
        cols = slice(round(x_px / d), round((x_px + offset) / d))
        mask[rows, cols] = True
        heatmap[rows, cols] = s

    plt.figure(figsize=(30, 30))
    if overlay:
        plt.imshow(image, alpha=1, cmap='gray')

    # Pin the colour scale to the [-1, 1] range the scores were scaled to, so
    # that 0 always maps to the colormap midpoint and slides are comparable.
    plt.imshow(heatmap, alpha=0.5 * mask, cmap=colormap, interpolation='nearest', vmin=-1, vmax=1)
    cbar = plt.colorbar(location='right', orientation='vertical')
    cbar.ax.tick_params(labelsize=40)
    plt.axis('off')

    return plt


In [4]:
def generate_cams(args, model, device, multiple_loaders, index=None):
    """Generate and save a relevance heat map for every slide in the given loaders.

    For each graph the relevance of every node is computed with
    ``generate_relevance``, rendered with ``plot_heat_maps`` and written to
    ``args.output_folder`` as
    ``"<slide>_<predicted class>(fold<fold_idx>)_cam.png"``.

    Parameters
    ----------
    args
        Configuration object; this function reads ``output_folder``,
        ``dataset_name`` and ``fold_idx`` from it.
    model
        Model passed to ``generate_relevance``; switched to eval mode here.
    device
        Device each graph is moved to before relevance is computed.
    multiple_loaders
        Iterable of data loaders. Each loader's ``dataset`` must expose the
        mappings ``classdict`` and ``to_be_predicted_classes``; their keys, in
        order, are used as class names.
    index : int, optional
        Class index to explain. When given it is forwarded to
        ``generate_relevance`` and also replaces the predicted class in the
        log line and the output file name.
    """
    model.eval()
    os.makedirs(args.output_folder, exist_ok=True)

    # Identical for every slide, so build it once.
    slide_root = os.path.join('/SeaExp/Rushin/datasets/', args.dataset_name.upper(), 'WSIs')

    for loader in multiple_loaders:

        true_labels = list(loader.dataset.classdict.keys())
        to_be_predicted_classes = list(loader.dataset.to_be_predicted_classes.keys())

        for graph in tqdm(loader, desc="Iteration"):

            graph = graph.to(device)
            slide_name = graph.slide_path[0]
            print(slide_name)

            # Graphs with a single node are skipped.
            # NOTE(review): presumably no meaningful attribution exists for them -- confirm.
            if graph.x.shape[0] == 1:
                continue

            transformer_attribution, output, y_pred = generate_relevance(model, graph, index=index)
            if index is not None:
                y_pred = index

            print("logits: ", output)

            prob = F.softmax(output, dim=1).squeeze()
            predicted_class = to_be_predicted_classes[y_pred]

            print("Slide: {}, True Class: {}, Predicted Class: {}(p={:.3f})".format(slide_name, true_labels[graph.y], predicted_class, prob[y_pred].item()))

            del output

            # Pass the probability of the explained class. Indexing with
            # `index` would be `prob[None]` (an unsqueeze, not a lookup)
            # whenever no explicit index was requested.
            heatmap_plt = plot_heat_maps(graph, scores=transformer_attribution, prob=prob[y_pred], slide_root=slide_root, clamp=0.05, save_path=args.output_folder, overlay=True)

            heatmap_plt.savefig(os.path.join(args.output_folder, "{}_{}(fold{})_cam.png".format(slide_name, predicted_class, args.fold_idx)))
            # Close the figure so figures do not accumulate across slides.
            heatmap_plt.close()

In [5]:
def implement(index=None):
    """Seed the run, build the test loaders and load the trained GNN for one fold.

    Configuration comes from a fresh ``Args`` instance. For every dataset named
    in ``Args.dataset`` a loader with batch size 1 is created; the model is
    then built and its weights are loaded from
    ``logs/<run_name>_fold_<fold_idx>/final_model_<run_name>_fold_<fold_idx>.pth``.

    The CAM-generation call at the end is currently commented out, so the
    function only reports that the weights loaded and the parameter count.

    Parameters
    ----------
    index : int, optional
        Target class index; when given it overrides ``Args.index`` on the
        instance used for this call.

    Raises
    ------
    ValueError
        If ``Args.dataset`` is empty (the model's class count is taken from the
        last dataset that was built).
    """
    args = Args()
    if index is not None:
        args.index = index

    if not args.dataset:
        raise ValueError("Args.dataset must name at least one dataset")

    ### set up seeds and gpu device
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    args.slide_feats_folder = {}

    fold_idx = args.fold_idx
    data_root = '/SeaExp/Rushin/datasets'
    test_loaders = []

    for item in args.dataset:

        args.dataset_name = item

        dataset_class = find_dataset_using_name(item)
        print(dataset_class)

        ### automatic dataloading and splitting
        dataset_dir = os.path.join(data_root, item.upper())
        root = os.path.join(dataset_dir, args.data_config)
        wsi_file = os.path.join(dataset_dir, '%s_%s.txt' % (item.upper(), args.phase))
        wsi_ids = read_file(wsi_file)

        dataset = dataset_class(root, wsi_ids, args.fdim, n_classes=args.n_classes, isTrain=False, transform=T.ToSparseTensor(remove_edge_index=False))
        test_loaders.append(DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers))

    log_path = os.path.join('logs', "{}_fold_{}".format(args.run_name, fold_idx))
    checkpoint_path = os.path.join(log_path, "final_model_{}_fold_{}.pth".format(args.run_name, fold_idx))

    # The class count comes from the last dataset built in the loop above.
    model = GNN(gnn_type = args.gnn, num_class = dataset.num_classes, num_layer = args.num_layer, input_dim = args.fdim, emb_dim = args.emb_dim, drop_ratio = args.drop_ratio, JK = args.jk, graph_pooling = args.graph_pooling).to(device)

    # map_location lets a checkpoint saved from a GPU load on whichever device
    # was selected above, including the CPU fallback.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("model weights loaded successfully")

    total_params = sum(p.numel() for p in model.parameters())
    print('Total params:', total_params)

    # CAM generation is disabled in this run; uncomment to produce the heat maps.
    # args.output_folder = os.path.join(args.output, args.run_name+"_{}_{}".format(args.patch_size, args.phase), args.dataset_name+f'_final_epoch')
    # generate_cams(args, model, device, test_loaders, args.index)
    # del model


In [6]:
# Run the pipeline for one target class index; per the note on `Args.index`
# the mapping is 0 - Normal | 1 - LUSC | 2 - LUAD. Alternative invocations are
# left commented out.
# implement(index=0)
# implement(index=1)
implement(index=1)

<class 'dataloaders.cptac.CptacDataset'>
model weights loaded successfully
Total params: 202292
