In [29]:
%load_ext autoreload
%autoreload 2

import argparse
import os
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import yaml

from classifier import ClassifierLightning
from options import Options

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load features and model

In [31]:
# Inputs for this run. The defaults are the author's local paths; set ATTN_FEATURE_PATH /
# ATTN_MODEL_PATH in the environment to point the notebook elsewhere without editing it.
patient_id = Path('18-LSS0736')
feature_path = Path(os.environ.get('ATTN_FEATURE_PATH', '/Users/sophia.wagner/Downloads/439097.h5'))
model_path = Path(os.environ.get('ATTN_MODEL_PATH', '/Users/sophia.wagner/Documents/PhD/projects/2022_MSI_transformer/attention-user-study/multi-all-same_transformer_DACHS-QUASAR-RAINBOW-TCGA_histaugan_isMSIH/models/best_model_multi-all-same_transformer_DACHS-QUASAR-RAINBOW-TCGA_histaugan_isMSIH_fold3.ckpt'))

In [32]:
# Read tile features and coordinates; the context manager closes the HDF5 file afterwards.
# Mode 'r' is explicit because h5py < 3 defaulted to append mode.
with h5py.File(feature_path, 'r') as h5_file:
    feats_array = np.asarray(h5_file['feats'])
    coords_array = np.asarray(h5_file['coords'])

# Prepend a batch dimension of size 1 for the model.
features = torch.Tensor(feats_array).unsqueeze(0)
# (x, y) tile coordinates as plain Python ints. astype truncates toward zero like the former
# .int() cast, but without first rounding the values through float32.
coords = [tuple(xy) for xy in coords_array[:, :2].astype(np.int64).tolist()]

In [33]:
parser = Options()
# Parse an empty argument list so the notebook kernel's own command-line arguments are ignored.
args = parser.parser.parse_args([])

# Load the configuration from the YAML file; each entry is a mapping holding a 'value'
# (plus a description, dropped below).
with open(args.config_file, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Update the configuration with the values from the argument parser.
# setdefault: an option the parser defines but the YAML lacks is added instead of raising KeyError.
for arg_name, arg_value in vars(args).items():
    if arg_value is not None and arg_name != 'config_file':
        config.setdefault(arg_name, {})['value'] = arg_value

# Create a flat config without descriptions
config = {k: v['value'] for k, v in config.items()}

print('\n--- load options ---')
for name, value in sorted(config.items()):
    print(f'{name}: {str(value)}')

cfg = argparse.Namespace(**config)


--- load options ---
bs: 1
clini_info: {}
cohorts: ['TCGA']
criterion: BCEWithLogitsLoss
data_config: /home/ubuntu/projects/idkidc/data_config.yaml
debug: False
ext_cohorts: ['CPTAC']
feats: ctranspath
folds: 5
input_dim: 768
lr: 2e-05
model: transformer
name: test
norm: macenko
num_classes: 1
num_epochs: 1
num_tiles: -1
optimizer: AdamW
pad_tiles: False
project: hackathon
resume: None
save_dir: /home/ubuntu/logs
scheduler: None
seed: None
stop_criterion: loss
target: isMSIH
task: binary
wd: 2e-05


In [34]:
# pos_weight must be on cfg before construction; presumably it feeds the BCEWithLogitsLoss
# criterion from the config, so its value does not matter for inference.
cfg.pos_weight = torch.tensor([1.0])
classifier = ClassifierLightning(cfg)
# Map all checkpoint tensors to the CPU, whatever device they were saved from.
# NOTE(review): torch >= 2.6 defaults to weights_only=True; if this checkpoint stores
# non-tensor hyper-parameters, loading may need weights_only=False (trusted files only).
checkpoint = torch.load(model_path, map_location='cpu')
classifier.load_state_dict(checkpoint['state_dict'])
classifier.eval();

### Plotting utilities

In [39]:
def plot_scores(coords, scores, image, overlay=True, clamp=0.05, norm=True, colormap='RdBu', crop=False,
                downsample=None, tile_size=1013, crop_x_max=None):
    """Paint one score per tile onto a slide image and show it with matplotlib.

    Args:
        coords: sequence of (x, y) tile positions, one per score, in the units that
            ``downsample`` converts to ``image`` pixels (presumably level-0 slide pixels -- TODO confirm).
        scores: 1-D torch tensor with one score per tile (same length as ``coords``).
            The caller's tensor is not modified.
        image: array indexed as ``image[row, col, ...]``. Drawn underneath when ``overlay`` is
            True; otherwise only its height and width are used.
        overlay: draw ``image`` first and the score map on top of it at alpha 0.5.
        clamp: if truthy, clip scores to their [clamp, 1 - clamp] quantile range to damp outliers.
        norm: min-max rescale the (clamped) scores to [0, 1].
        colormap: matplotlib colormap name. For 'RdBu' the scores are inverted (1 - s), so high
            scores are drawn red.
        crop: restrict the view to the bounding box of all tiles.
        downsample: divisor mapping tile coordinates to ``image`` pixels. If None, the
            notebook-level global ``d`` is used (the original behaviour).
        tile_size: tile edge length in the same units as ``coords``.
        crop_x_max: optional right edge of the crop, in the same units as ``coords``.

    Raises:
        ValueError: if ``downsample`` is None and no global ``d`` exists.
    """
    if downsample is None:
        try:
            downsample = d  # legacy: read the notebook-level downsample factor
        except NameError:
            raise ValueError('plot_scores needs `downsample` (or a global `d`)') from None

    if clamp:
        low, high = torch.quantile(scores, clamp), torch.quantile(scores, 1 - clamp)
        # Out-of-place on purpose: the former clamp_() silently rewrote the caller's tensor,
        # so a later rollout.topk() saw clipped, tied values.
        scores = scores.clamp(low, high)

    if norm:
        # Min-max rescaling to [0, 1]; constant scores map to 0 instead of dividing by zero.
        # NOTE(review): replaces the undefined `NormalizeData` helper, assumed to be min-max.
        lowest = scores.min()
        span = scores.max() - lowest
        scores = (scores - lowest) / span if span > 0 else torch.zeros_like(scores)

    if crop:
        tile_xy = np.asarray(coords)
        (x_lo, y_lo), (x_hi, y_hi) = tile_xy.min(axis=0), tile_xy.max(axis=0)
        # Extend to the far edge of the last tiles; the old bound stopped at their top-left corner.
        x_hi, y_hi = x_hi + tile_size, y_hi + tile_size
        if crop_x_max is None and globals().get('pat_name') == '439042':
            crop_x_max = 69 * 1013  # legacy hard-coded crop for slide 439042
        if crop_x_max is not None:
            x_hi = crop_x_max
        y_min, y_max = int(round(y_lo / downsample)), int(round(y_hi / downsample))
        x_min, x_max = int(round(x_lo / downsample)), int(round(x_hi / downsample))
    else:
        y_min, y_max, x_min, x_max = 0, image.shape[0], 0, image.shape[1]

    # Rasterise: each tile fills its footprint (in image pixels) with its score.
    score_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
    tissue_mask = np.zeros((image.shape[0], image.shape[1]), dtype=bool)  # pixels covered by a tile
    values = (1 - scores) if colormap == 'RdBu' else scores
    for (x, y), value in zip(coords, values.tolist()):
        rows = slice(int(round(y / downsample)), int(round((y + tile_size) / downsample)))
        cols = slice(int(round(x / downsample)), int(round((x + tile_size) / downsample)))
        score_map[rows, cols] = value
        # Track coverage explicitly; deriving it from score >= 0 hid negative (unnormalised) scores.
        tissue_mask[rows, cols] = True

    window = (slice(y_min, y_max), slice(x_min, x_max))
    alpha = 1.0
    if overlay:
        plt.imshow(image[window])
        alpha = 0.5
    # Float map (no uint8 cast), so scores outside [0, 1] with norm=False no longer wrap or collapse to 0.
    plt.imshow(score_map[window], alpha=alpha * tissue_mask[window].astype(np.float32),
               cmap=colormap, interpolation='nearest')
    plt.axis('off')

### Attention-rollout utilities

In [35]:
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Combine per-layer attention maps into a single attention-rollout matrix.

    The residual connection is accounted for by adding the identity to each layer's
    matrix and re-normalising every row to sum to one. The normalised layers from
    ``start_layer`` onwards are then chained by batched matrix multiplication, each later
    layer applied on the left. Code adapted from https://github.com/samiraabnar/attention_flow

    Args:
        all_layer_matrices: list of (batch, tokens, tokens) tensors, one per layer.
        start_layer: index of the first layer included in the product.

    Returns:
        (batch, tokens, tokens) tensor with the rolled-out attention.
    """
    first = all_layer_matrices[0]
    batch_size, num_tokens = first.shape[0], first.shape[1]
    identity = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(first.device)

    row_normalised = []
    for layer in all_layer_matrices:
        with_residual = layer + identity
        row_normalised.append(with_residual / with_residual.sum(dim=-1, keepdim=True))

    joint = row_normalised[start_layer]
    for idx in range(start_layer + 1, len(row_normalised)):
        joint = row_normalised[idx].bmm(joint)
    return joint

In [36]:
def generate_rollout(model, input, start_layer=0):
    """Attention rollout from token 0 to every other token for one forward pass.

    The forward pass only refreshes each block's attention map, which is then read back via
    ``blk[0].fn.get_attention_map()``. Heads are averaged over dimension 1, and the
    per-layer maps are combined with ``compute_rollout_attention`` from ``start_layer``
    onwards. Row 0 of the result (presumably the class token) is returned with its own
    column dropped.

    Returns:
        Tensor of shape (batch, tokens - 1).
    """
    model(input)  # output unused: only the cached attention maps are needed
    head_means = [
        (attn.sum(dim=1) / attn.shape[1]).detach()
        for attn in (blk[0].fn.get_attention_map() for blk in model.transformer.layers)
    ]
    joint = compute_rollout_attention(head_means, start_layer=start_layer)
    return joint[:, 0, 1:]

### Compute attention scores

In [40]:
batch_rollout = generate_rollout(classifier.model, features, start_layer=0)
rollout = batch_rollout.squeeze(0)  # (1, n) -> (n,): presumably one score per tile

In [22]:
# plot_scores needs an image to draw on, but no slide thumbnail is loaded in this notebook,
# so use a blank canvas that covers every tile. TODO: swap in the real thumbnail and set `d` to its scale.
d = 32  # tile-coordinate units per canvas pixel, read by plot_scores (placeholder value)
TILE_SIZE = 1013  # tile edge length that plot_scores paints for every tile
tile_xy = np.asarray(coords)
canvas = np.ones((int(np.ceil((tile_xy[:, 1].max() + TILE_SIZE) / d)),
                  int(np.ceil((tile_xy[:, 0].max() + TILE_SIZE) / d)), 3), dtype=np.float32)
# Pass a copy: plot_scores may clamp its scores argument in place, and `rollout` is reused below.
plot_scores(coords, rollout.clone(), canvas, overlay=False, clamp=0.05)
plt.show()

<All keys matched successfully>

In [50]:
# Highest-scoring tiles, sorted highest first. k is capped at the number of tiles so that
# slides with fewer than TOP_K tiles do not make topk raise.
TOP_K = 5
values, indices = rollout.topk(min(TOP_K, rollout.numel()))

In [53]:
# Coordinates of the top-scoring tiles, in topk order (highest score first).
for tile_idx in indices.tolist():
    print(coords[tile_idx])

(17221, 62806)
(10130, 50650)
(18234, 61793)
(18234, 35455)
(16208, 61793)


In [45]:
# Display the raw rollout scores (the tensor repr elides the middle of long tensors).
rollout

tensor([2.0769e-04, 3.8338e-04, 9.3652e-05,  ..., 9.9520e-05, 9.9313e-05,
        1.0548e-04])