## Building an attention heatmap from a pretrained model

This is a very verbose implementation. In many cases, using the `Processor` will be more appropriate when processing large batches of slides.  

#### Run inference on a slide from scratch. 

- Download a WSI
- Create OpenSlideWSI
- Run segmentation
- Run patch coordinate extraction
- Run patch features extraction
- Run slide feature extraction and retrieve the attention scores

Note: This tutorial uses Threads, which is not public yet. Stay tuned!


In [None]:
import os
import torch 
import h5py 
from PIL import Image
import geopandas as gpd
from IPython.display import display
from huggingface_hub import snapshot_download

# scipy is optional here: if it is missing, tell the user how to install it
# and carry on. Only ImportError is caught so that Ctrl-C (KeyboardInterrupt)
# and SystemExit are not swallowed.
try:
    from scipy.stats import rankdata
except ImportError:
    print('Please install scipy: `pip install scipy`')

from trident.wsi_objects.WSI import OpenSlideWSI
from trident.segmentation_models.load import segmentation_model_factory
from trident.patch_encoder_models.load import encoder_factory as patch_encoder_factory
from trident.slide_encoder_models.load import encoder_factory as slide_encoder_factory

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

# a. Download a WSI
# OUTPUT_DIR is the root folder handed to every later step as its output
# location; DEVICE is the torch device string used for all models.
OUTPUT_DIR = "tutorial-3/"
DEVICE = "cuda:0"
WSI_FNAME = '394140.svs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# `allow_patterns` restricts the snapshot to the single slide file, and the
# call returns the local folder the file was downloaded into.
local_wsi_dir = snapshot_download(
    repo_id="MahmoodLab/unit-testing",
    repo_type='dataset',
    local_dir=os.path.join(OUTPUT_DIR, 'wsis'),
    allow_patterns=[WSI_FNAME],
)

# b. Create OpenSlideWSI
# lazy_init=False presumably opens the slide right away rather than on first
# use -- confirm against the OpenSlideWSI implementation.
wsi_path = os.path.join(local_wsi_dir, WSI_FNAME)
slide = OpenSlideWSI(slide_path=wsi_path, lazy_init=False)

# c. Run segmentation
# Tissue is segmented with the "hest" model at target_mag=10; OUTPUT_DIR is
# passed as the job directory.
segmentation_model = segmentation_model_factory("hest", device=DEVICE)
geojson_contours = slide.segment_tissue(
    segmentation_model=segmentation_model,
    target_mag=10,
    job_dir=OUTPUT_DIR,
)

# d. Run patch coordinate extraction
# 512-pixel patches with an overlap of 448 (presumably a 64-pixel stride) give
# a dense grid of scores for the heatmap -- confirm the units of `overlap`.
coords_path = slide.extract_tissue_coords(
    target_mag=20,
    patch_size=512,
    save_coords=OUTPUT_DIR,
    overlap=448,
)

# e. Run patch feature extraction
# The encoder name is defined once so the feature directory always matches
# the model that produced the features.
PATCH_ENCODER_NAME = "conch_v15"
patch_encoder = patch_encoder_factory(PATCH_ENCODER_NAME)
patch_encoder.eval()
patch_encoder.to(DEVICE)
patch_features_path = slide.extract_patch_features(
    patch_encoder=patch_encoder,
    coords_path=coords_path,
    save_features=os.path.join(OUTPUT_DIR, f"features_{PATCH_ENCODER_NAME}"),
    device=DEVICE,
)

# f. Run slide feature extraction
slide_encoder = slide_encoder_factory("threads").to(DEVICE).eval()

# Read back the patch features, their coordinates, and the attributes stored
# on the 'coords' dataset from the file produced by patch feature extraction.
with h5py.File(patch_features_path, 'r') as f:
    coords = f['coords'][:]
    patch_features = f['features'][:]
    coords_attrs = dict(f['coords'].attrs)

# unsqueeze(0) adds a leading dimension of size 1 to both tensors, since a
# single slide is processed here.
batch = {
    'features': torch.from_numpy(patch_features).float().to(DEVICE).unsqueeze(0),
    'coords': torch.from_numpy(coords).to(DEVICE).unsqueeze(0),
    'attributes': coords_attrs,
}

# Generate slide-level features
# - no_grad: this is inference only; it avoids building an autograd graph and
#   keeps `.numpy()` from failing on outputs that would otherwise require grad.
# - autocast follows the device type of DEVICE instead of assuming CUDA, and
#   is enabled only when the encoder does not run in float32.
autocast_device_type = torch.device(DEVICE).type
use_autocast = slide_encoder.precision != torch.float32
with torch.no_grad(), torch.autocast(device_type=autocast_device_type, enabled=use_autocast):
    features, attention = slide_encoder(batch, return_raw_attention=True)

# Move the results to host memory and drop the size-1 dimensions.
features = features.cpu().numpy().squeeze()
attention = attention.cpu().numpy().squeeze()


#### Visualize the heatmap using the extracted attention scores

In [None]:
from trident.Visualization import visualize_heatmap

# Render the scores of a single attention head over the slide at
# visualization level 2, with score normalization switched on.
first_head_scores = attention[:, 0]  # get attention for the first attention head.
heatmap = visualize_heatmap(
    wsi=slide,
    scores=first_head_scores,
    coords=coords,
    vis_level=2,
    coords_attrs=coords_attrs,
    normalize=True,
)

# Show the image inline, then save it in the output directory under a name
# derived from the slide.
heatmap_path = os.path.join(OUTPUT_DIR, f'{slide.name}_heatmap.jpg')
display(heatmap)
heatmap.save(heatmap_path)
