In [1]:
import numpy as np
import pandas as pd
import os
import glob
import h5py
import wandb

import string
import random

import torch 
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint


from torch import optim, utils, Tensor

import matplotlib.pyplot as plt
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay
import sklearn.linear_model


from itables import show
from src.model.SimpleMILModels import Attention, MaxMIL, AttentionResNet
from src.dataloaders.DataLoaders import RetCCLFeatureLoader, RetCCLFeatureLoaderMem


import zarr
import seaborn as sns

# NOTE(review): presumably toggles recomputation of the attention maps, but the
# flag is not referenced in the cells visible here — confirm where it is consumed.
recompute_attn = False

In [2]:
def sigmoid_array(x):
    """Element-wise logistic sigmoid, ``1 / (1 + exp(-x))``.

    The value is computed from ``exp(-|x|)``, which always lies in (0, 1], so
    large-magnitude inputs cannot overflow (the naive formula overflows in
    ``exp(-x)`` for very negative ``x``).

    Parameters
    ----------
    x : array_like
        Scalar, sequence or ``numpy`` array of real numbers.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Sigmoid of ``x`` with the same shape as the input; a scalar input
        yields a NumPy scalar.
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))
    # For x >= 0: 1 / (1 + exp(-x)); for x < 0 the algebraically identical
    # form exp(x) / (1 + exp(x)), both expressed through exp(-|x|).
    result = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # np.where always returns an array; unwrap 0-d results back to a scalar.
    return result[()] if result.ndim == 0 else result

In [3]:


import os
# Send both HTTP and HTTPS traffic through the proxy at www-int.dkfz-heidelberg.de.
# NOTE(review): presumably needed for the W&B API requests made below — confirm.
os.environ.update({
    'HTTP_PROXY': "http://www-int.dkfz-heidelberg.de:80",
    'HTTPS_PROXY': "http://www-int.dkfz-heidelberg.de:80",
})


In [4]:
# W&B run group to analyse. It filters the runs fetched from the API and names
# the output sub-directory the attention maps are written to.
group = "4ZFF4Q"

In [5]:

# W&B API client used to query the logged runs and their artifacts.
api = wandb.Api()


In [6]:
# All runs of the project whose W&B group equals `group`; one model is loaded per run further down.
runs = api.runs(path="psmirnov/TCGA_RetCLL_299_CT", filters={"group": group})

In [7]:
# Show the returned runs collection as a quick check that the query resolved.
runs

<Runs psmirnov/TCGA_RetCLL_299_CT>

Load in the features (not too much memory needed)

In [8]:
# Directory with one `<slide>.h5` file per slide; the datasets `feats` and
# `coords` are read from these files.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
path_to_extracted_features = '/home/p163v/histopathology/MASTER/UKHD_1/299/'

In [9]:

# Slide identifiers (file stems of the .h5 feature files) for which attention maps are produced.
slides_to_infer = ['30435-19_L2-1','30435-19_L2-2']

In [11]:
# Read the `feats` dataset of every slide fully into memory. Each HDF5 file is
# opened in a `with` block so its handle is closed as soon as the data is copied
# (the `[:]` slice materialises the dataset as a NumPy array before the close).
test_features = []
for file in slides_to_infer:
    with h5py.File(os.path.join(path_to_extracted_features, file + ".h5"), 'r') as feature_file:
        test_features.append(feature_file['feats'][:])

# Loss

We use the loss as the early-stopping criterion, so the `best_loss*` checkpoint of each run is the one loaded below.


In [12]:
# Accumulators filled by the per-run loop further down: the loaded models and,
# for each model, its list of per-slide attention arrays.
model_list = []
attention_list = []


def cv(x):
    """Return the coefficient of variation of ``x`` (standard deviation / mean)."""
    return np.std(x) / np.mean(x)

In [13]:
# Placeholder labels: the attention extraction below only consumes the features
# yielded by the loader and discards the label, so every slide gets a 0.
# Sized from `test_features` so the cell keeps working when slides are added.
# NOTE(review): assumes the loader expects one label per slide — confirm.
placeholder_labels = np.zeros(len(test_features), dtype=int)
test_data = RetCCLFeatureLoaderMem(test_features, placeholder_labels, patches_per_iter='all')

# batch_size=1: each iteration yields the data of a single slide.
RetCCLTest = DataLoader(test_data, batch_size=1, num_workers=1)


In [14]:
# Number of runs found in the group; the loop below loads one model per run.
len(runs)

5

In [15]:
# For every run in the group: locate its best-loss checkpoint, restore the model
# and collect the model's attention output for each slide served by RetCCLTest.

# Start from empty lists so that re-running this cell does not append duplicates.
model_list.clear()
attention_list.clear()

for run in runs:
    # Index the run's logged artifacts by the part of their name before the first '-'.
    arts_dict = {a.name.removesuffix(':' + a.version).split('-')[0]: a for a in run.logged_artifacts()}
    model_artifact = arts_dict['model']
    # The part after the first '-' (version suffix stripped) names the local
    # lightning_logs sub-folder that holds this run's checkpoints.
    checkpoint_folder_name = model_artifact.name.split('-')[1].removesuffix(':' + model_artifact.version)

    # Hyperparameters of *this* run (not of the first run in the group).
    # Not needed to restore the model; kept for inspection in later cells.
    args = run.config

    chkpt_pattern = 'lightning_logs/' + checkpoint_folder_name + '/checkpoints/best_loss*'
    # Sorted so the choice is deterministic if several files match.
    chkpt_matches = sorted(glob.glob(chkpt_pattern))
    if not chkpt_matches:
        raise FileNotFoundError('No checkpoint found matching ' + chkpt_pattern)
    chkpt_file = chkpt_matches[0]

    # load_from_checkpoint is a classmethod that builds and returns a new model,
    # so it is called on the class; constructing an instance first is unnecessary.
    # NOTE(review): relies on the checkpoint carrying the saved hyperparameters — confirm.
    model = Attention.load_from_checkpoint(chkpt_file, map_location=torch.device('cpu'))
    model.eval()
    model_list.append(model)

    # Inference only: no autograd graph is needed. The loader already yields
    # tensors, so they are moved with .to() instead of being re-wrapped in
    # torch.tensor(), which emits a copy-construct UserWarning.
    with torch.no_grad():
        model_attention = [
            model.attention_forward(x.to(model.device)).detach().cpu().numpy()
            for x, _ in RetCCLTest
        ]

    attention_list.append(model_attention)
    

  model_attention = [model.attention_forward(torch.tensor(x).to(model.device)).detach().cpu().numpy() for x,y in iter(RetCCLTest)]


In [16]:
# One entry per run: each entry is that model's list of per-slide attention arrays.
len(attention_list)

5

In [17]:
# Regroup the attention arrays by slide instead of by model: zip(*...) yields,
# for each slide, the tuple of arrays from all models, which np.hstack then
# joins into a single array per slide.
attention_combined = []
for per_slide_attention in zip(*attention_list):
    attention_combined.append(np.hstack(per_slide_attention))


In [18]:
# Shape check on the first slide: the mean over axis 1 of its stacked attention
# array is what gets flattened and written out per slide below.
np.mean(attention_combined[0], axis=1).shape



(1, 14948)

In [19]:
# Write one zarr store per slide containing the tile coordinates and the
# attention averaged over axis 1 of the slide's stacked attention array.
for jj, slidename in enumerate(slides_to_infer):
    print('Writing Attention Map ' + slidename)
    # Copy the coordinates into memory and close the HDF5 file right away
    # instead of leaving the handle open.
    with h5py.File(os.path.join(path_to_extracted_features, slidename + ".h5"), 'r') as feature_file:
        coords = feature_file['coords'][:]
    # mode='w' creates the store, replacing any existing one at this path.
    outarray_root = zarr.open("/home/p163v/histopathology/attention_maps/MASTER/" + group + "/" + slidename + "_per_tile_attention.zarr", mode='w')
    outarray_root['coords'] = coords
    # Flatten the averaged attention to 1-D before storing it.
    outarray_root['attn'] = np.mean(attention_combined[jj], axis=1).reshape(-1)


Writing Attention Map 30435-19_L2-1
Writing Attention Map 30435-19_L2-2
