In [None]:
import sys
import os
# Make the sibling SGMF-main checkout importable (wsi_core, utils, ...) by
# prepending it to sys.path; nothing is inserted if it is already there.
current_dir = os.path.abspath(os.getcwd())
sgmf_dir = os.path.abspath(os.path.join(current_dir, '..', 'SGMF-main'))
sys.path[:0] = [] if sgmf_dir in sys.path else [sgmf_dir]

from wsi_core.WholeSlideImage import WholeSlideImage

from scipy.stats import percentileofscore
import math
from utils.file_utils import save_hdf5
from scipy.stats import percentileofscore
from utils.utils import *
import h5py
import torch
import numpy as np
import pandas as pd

from torch_geometric.nn import GCNConv
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.explain.config import ModelConfig, ExplanationType
from torch_geometric.typing import NodeType
from torch_geometric.data import Data
from globals import *
from utils_local import ordered_yaml
import yaml
from parser import parse_gnn_model
import openslide
import random
import matplotlib.pyplot as plt

In [None]:

# Seed every RNG the notebook uses (torch CPU/CUDA, numpy, stdlib random) and
# force deterministic cuDNN kernels for reproducible runs.
seed = 42
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_fn(seed)
del _seed_fn

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
def drawHeatmap(scores, coords, slide_path=None, wsi_object=None, vis_level=-1, top_n=15, **kwargs):
    """Render a per-patch score heatmap over a whole-slide image.

    Args:
        scores: per-patch scores, either a torch tensor or anything ``np.array``
            accepts. Tensors are detached and moved to CPU first.
        coords: patch coordinates, passed unchanged to ``visHeatmap``.
        slide_path: slide file used to build a ``WholeSlideImage`` when
            ``wsi_object`` is not supplied.
        wsi_object: an already-constructed ``WholeSlideImage`` to reuse.
        vis_level: pyramid level to render at. A negative value selects the
            level the slide reports as best for a 32x downsample.
        top_n: currently unused; kept so existing callers keep working.
        **kwargs: forwarded verbatim to ``wsi_object.visHeatmap``.

    Returns:
        Whatever ``wsi_object.visHeatmap`` returns.

    Raises:
        ValueError: if neither ``slide_path`` nor ``wsi_object`` is given.
    """
    if wsi_object is None:
        if slide_path is None:
            raise ValueError("drawHeatmap needs either slide_path or wsi_object")
        wsi_object = WholeSlideImage(slide_path)
        print(wsi_object.name)
    wsi = wsi_object.getOpenSlide()
    if vis_level < 0:
        vis_level = wsi.get_best_level_for_downsample(32)
    if isinstance(scores, torch.Tensor):
        # detach() first: .numpy() refuses tensors that still require grad.
        scores = scores.detach().cpu().numpy()
    else:
        scores = np.array(scores)
    heatmap = wsi_object.visHeatmap(scores, coords, vis_level=vis_level, **kwargs)
    return heatmap


In [None]:
def weight_coord(graph_path, model):
    """Explain the model's graph-level regression output with GNNExplainer.

    Args:
        graph_path: path to a ``torch.save``-d graph object that exposes the
            keys ``'x'``, ``'edge_latent'`` and ``'centroid'`` and has ``.to()``.
        model: the trained GNN. It is called as ``model(data)`` to produce the
            prediction that the explanation targets.

    Returns:
        ``(node_weight, coord)``. ``node_weight`` is a 1-D numpy array with
        the explainer's node mask min-max rescaled to [-1, 1] and then passed
        through a signed cube root, which spreads values near zero.
        ``coord`` is the graph's ``'centroid'`` entry as a numpy array.

    Note:
        This reads the notebook-level global ``device``.
    """
    # NOTE: torch.load unpickles arbitrary objects; only load trusted graph files.
    data = torch.load(graph_path).to(device)
    x = data['x']
    edge_index = data['edge_latent']
    coord = data['centroid']

    explainer = GNNExplainer(epochs=1000, lr=0.001)
    explainer_config = {
        'explanation_type': ExplanationType.model,
        'node_mask_type': 'object',
        'edge_mask_type': None
    }
    model_config = ModelConfig(mode='regression', task_level='graph', return_type='raw')
    explainer.connect(explainer_config, model_config)

    # The target must be a constant. If it carried autograd history, each
    # explainer epoch's loss.backward() would traverse the same forward graph
    # again, and PyTorch raises once that graph has been freed. PyG's own
    # Explainer.get_prediction runs the model under no_grad for this reason.
    with torch.no_grad():
        prediction = model(data)
    explanation = explainer.forward(model=model, x=x, edge_index=edge_index, target=prediction)

    node_weight = explanation.node_mask.detach().view(-1).cpu().numpy()
    print(node_weight.max(), node_weight.min())

    # Min-max rescale to [-1, 1]; a constant mask has no spread, so map it to 0
    # rather than dividing by zero.
    w_min, w_max = node_weight.min(), node_weight.max()
    span = w_max - w_min
    if span > 0:
        node_weight = ((node_weight - w_min) / span - 0.5) * 2
    else:
        node_weight = np.zeros_like(node_weight)
    # Signed cube root. `x ** (1/3)` returns NaN for negative floats, which
    # discarded every below-midpoint score.
    node_weight = np.cbrt(node_weight)
    print(node_weight.max(), node_weight.min())
    return node_weight, coord.cpu().numpy()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the GNN from the survival-analysis training config.
default_config_path = "training_config/survival_analysis.yml"
opt_path = CONFIG_DIR / default_config_path
with open(opt_path, mode='r') as f:
    loader, _ = ordered_yaml()
    config = yaml.load(f, loader)
    print(f"Loaded configs from {opt_path}")
model = parse_gnn_model(config["GNN"]).to(device)

slide_id = "B201714220-22"

# Find the fold whose test-results CSV lists this slide. Presumably this is
# the fold whose checkpoint did not train on it.
fold_num = 4
fold_id = 0
seed = 42
for i in range(1, fold_num + 1):
    test_fold = pd.read_csv(f'data/test_results/level1/slide_num148/fold_num_{fold_num}/test_fold/test_results_seed{seed}_fold_{i}.csv')
    if slide_id in test_fold['slide_id'].values:
        fold_id = i
        print(f"{slide_id}: {fold_id}")
        break
else:
    # Without this, fold_id stays 0 and an unrelated checkpoint name is loaded.
    raise ValueError(f"{slide_id} was not found in any of the {fold_num} test folds")

# NOTE(review): the fold CSVs are keyed by seed 42 but the checkpoint name says
# seed83 — confirm this pairing is intentional.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
state_dict = torch.load(f'data/model_save/level1/best_model_fold_{fold_id}_seed83.pt', map_location=device)
model.load_state_dict(state_dict)
model.eval()

level = "_level1"
# level = ""
graph_path = f"data/create_save{level}/graph_files/{slide_id}.pt"
slide_path = f"data/WSI_svs/SUR/{slide_id}.svs"
node_weight, coord = weight_coord(graph_path, model)


In [None]:
# Display the loaded model's module structure (rich repr of the cell's last expression).
model

In [None]:
# Heatmap rendering options, forwarded through drawHeatmap to visHeatmap.
# Other options (e.g. convert_to_percentiles=True) can be added to this dict.
patch_size = 3000
heatmap_args = dict(
    vis_level=3,
    cmap='coolwarm',
    blank_canvas=False,
    blur=False,
    binarize=False,
    custom_downsample=1,
    alpha=0.9,
    patch_size=(patch_size, patch_size),
)
heatmap = drawHeatmap(node_weight, coord, slide_path, **heatmap_args)

In [None]:
# Display the heatmap object produced by drawHeatmap in the previous cell.
heatmap

In [None]:
# Open the slide once. vis_region, vis_region_with_box and save_region bind
# this handle as their default `slide` argument when they are defined.
slide = openslide.open_slide(slide_path)

def vis_region(rank_num, slide=slide, patch_size=1024):
    """Display the level-0 patch at a given rank of the global ``node_weight``.

    Args:
        rank_num: 1-based rank counted from the highest weight when positive
            (1 = highest), or from the lowest weight when negative (-1 = lowest).
            0 is invalid.
        slide: OpenSlide-like handle to read from.
        patch_size: side length, in level-0 pixels, of the square region read.

    Raises:
        ValueError: if ``rank_num`` is 0 or its magnitude exceeds the node count.
    """
    # Rank on a NaN-free copy instead of overwriting the global array in place.
    weights = np.where(np.isnan(node_weight), 0, node_weight)
    n_nodes = len(weights)
    if rank_num == 0 or abs(rank_num) > n_nodes:
        raise ValueError(f"rank_num must be in [-{n_nodes}, -1] or [1, {n_nodes}], got {rank_num}")
    sorted_indices = np.argsort(weights)[::-1]  # descending by weight
    # Negative ranks index from the end of the descending order, i.e. the lowest weights.
    idx = sorted_indices[rank_num - 1] if rank_num > 0 else sorted_indices[rank_num]
    # NOTE(review): the coordinate is used as the region's top-left corner. If
    # `coord` holds patch centroids, the displayed region is offset — confirm.
    coord_selected = tuple(coord[idx].astype(int))

    patch = slide.read_region(coord_selected, 0, (patch_size, patch_size)).convert("RGB")
    plt.imshow(patch)
    plt.axis('off')
    plt.show()
    
def vis_region_with_box(rank_num, slide=slide, patch_size=1024, display_level=2):
    """Show a whole-slide overview with a dashed red box around a ranked patch.

    The patch is chosen by rank in the global ``node_weight``, and its location
    comes from the global ``coord``.

    Args:
        rank_num: 1-based rank counted from the highest weight when positive,
            or from the lowest weight when negative. 0 is invalid.
        slide: OpenSlide-like handle to read from.
        patch_size: box side length in level-0 pixels.
        display_level: pyramid level for the overview, clamped to the levels
            the slide provides.

    Raises:
        ValueError: if ``rank_num`` is 0 or its magnitude exceeds the node count.
    """
    # Rank on a NaN-free copy instead of overwriting the global array in place.
    weights = np.where(np.isnan(node_weight), 0, node_weight)
    n_nodes = len(weights)
    if rank_num == 0 or abs(rank_num) > n_nodes:
        raise ValueError(f"rank_num must be in [-{n_nodes}, -1] or [1, {n_nodes}], got {rank_num}")
    sorted_indices = np.argsort(weights)[::-1]  # descending by weight
    idx = sorted_indices[rank_num - 1] if rank_num > 0 else sorted_indices[rank_num]
    coord_selected = tuple(coord[idx].astype(int))

    # Clamp into [0, level_count - 1]; a negative level would otherwise index
    # level_downsamples from the end.
    level = max(0, min(display_level, slide.level_count - 1))
    downsample = slide.level_downsamples[level]
    # Map the level-0 location and size onto the display level.
    scaled_x = int(coord_selected[0] / downsample)
    scaled_y = int(coord_selected[1] / downsample)
    scaled_patch_size = int(patch_size / downsample)

    # Reads the full slide at `level` into memory.
    wsi_image = slide.read_region((0, 0), level, slide.level_dimensions[level]).convert("RGB")

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(wsi_image)
    rect = plt.Rectangle(
        (scaled_x, scaled_y),
        scaled_patch_size,
        scaled_patch_size,
        linewidth=20 * (1 / downsample),
        edgecolor='red',
        facecolor='none',
        linestyle='--'
    )
    ax.add_patch(rect)
    ax.set_title(f"Region of Rank {rank_num} (Level {level}, Downsample={downsample:.1f}x)")
    ax.axis('off')
    plt.show()

In [None]:
# Show the three highest-weighted patches, then the three lowest-weighted ones.
for rank in (1, 2, 3, -1, -2, -3):
    vis_region(rank)

In [None]:
# vis_region(1)
# vis_region_with_box(1)

In [None]:
# vis_region(2)
# vis_region_with_box(2)

In [None]:
def save_region(rank_num, pacth_name, slide=slide, patch_size=1024, save_path=f"/homeuser/home/xiazhixiang/CONCH/docs/{slide_id}"):
    """Save the level-0 patch at a given rank of the global ``node_weight`` as a JPEG.

    Args:
        rank_num: 1-based rank counted from the highest weight when positive,
            or from the lowest weight when negative. 0 is invalid.
        pacth_name: output file name inside ``save_path`` (the parameter name
            is misspelled but kept for keyword callers). The file is always
            encoded as JPEG, whatever its extension.
        slide: OpenSlide-like handle to read from.
        patch_size: side length, in level-0 pixels, of the square region read.
        save_path: output directory, created if missing. NOTE(review): the
            default is an absolute, user-specific path, and it captures the
            ``slide_id`` that was current when this function was defined.

    Raises:
        ValueError: if ``rank_num`` is 0 or its magnitude exceeds the node count.
    """
    # Rank on a NaN-free copy instead of overwriting the global array in place.
    weights = np.where(np.isnan(node_weight), 0, node_weight)
    n_nodes = len(weights)
    if rank_num == 0 or abs(rank_num) > n_nodes:
        raise ValueError(f"rank_num must be in [-{n_nodes}, -1] or [1, {n_nodes}], got {rank_num}")
    sorted_indices = np.argsort(weights)[::-1]  # descending by weight
    idx = sorted_indices[rank_num - 1] if rank_num > 0 else sorted_indices[rank_num]
    coord_selected = tuple(coord[idx].astype(int))

    patch = slide.read_region(coord_selected, 0, (patch_size, patch_size)).convert("RGB")
    os.makedirs(save_path, exist_ok=True)
    patch.save(os.path.join(save_path, pacth_name), "JPEG")

In [None]:
# Export the nine highest-ranked patches as rank_<k>.jpg.
for rank_num, file_name in ((k, f"rank_{k}.jpg") for k in range(1, 10)):
    save_region(rank_num, file_name)