In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell execution, so edits to project code (e.g. `miso`)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import openslide
from openslide.deepzoom import DeepZoomGenerator

from PIL import Image

import matplotlib.pyplot as plt

import pandas as pd

import torch
from torch.utils.data import DataLoader, Subset

from pathlib import Path

import numpy as np

from miso.data.loading import DatasetSlideSubsample
from miso.models import SimpleMIL

from scipy.stats import pearsonr, spearmanr

from tqdm.notebook import tqdm

import pickle as pkl

In [None]:
# Fill in both paths before running the rest of the notebook:
#   LOCAL_PATH   - directory that contains the `miso` repo
#                  (cells below read files under `LOCAL_PATH / 'miso/assets/...'`).
#   PATH_TO_DATA - directory that contains the Her2ST data
#                  (cells below read `processed_data/` and `images/HE/<slide>.jpg` under it).
# Left as Path(''), both resolve relative to the current working directory.
LOCAL_PATH = Path('')
PATH_TO_DATA = Path('')

# Load models

In [None]:
# Gene names come from the `gene` column of the metrics table; the resulting
# array is used below for the models' output dimension and to look up the
# column index of a gene in the predictions.
metrics_csv = LOCAL_PATH / 'miso/assets/outputs_distil/miso_sr_her2st/metrics.csv'
metrics_table = pd.read_csv(metrics_csv)
genes = metrics_table['gene'].values

In [None]:
# Build one SimpleMIL per fold and restore its trained weights from
# `repeat_0_fold_<i>/model.pth`.
N_FOLDS = 4  # checkpoints for folds 0..3 are loaded

models = []
for i in range(N_FOLDS):
    model = SimpleMIL(
        input_dim=768,
        output_dim=len(genes),
        hidden=[1024, 512],
        activation=torch.nn.ReLU(),
        agg_method='mean',
        device='cuda'
    )

    # Deserialize the checkpoint onto the CPU first; `load_state_dict` then
    # copies the values into the model's existing parameters.
    state_dict = torch.load(
        LOCAL_PATH / f'miso/assets/outputs_distil/miso_sr_her2st/repeat_0_fold_{i}/model.pth',
        map_location='cpu'
    )
    model.load_state_dict(state_dict)

    # The models are only used for inference in this notebook: switch off
    # training-only behaviour (e.g. dropout, batch-norm statistics updates).
    model.eval()
    models.append(model)

# Load data

In [None]:
# Features and counts are both read from the same `processed_data` folder;
# the gene list ships with the miso repo assets.
processed_dir = PATH_TO_DATA / 'processed_data'
gene_list_csv = LOCAL_PATH / 'miso/assets/genes_her2st_all.csv'

dataset = DatasetSlideSubsample(
    path_to_feats=processed_dir,
    path_to_counts=processed_dir,
    path_to_gene_list=gene_list_csv,
    normalization='raw_reads'
)

In [None]:
# Load the benchmark splits. The context manager closes the file handle as
# soon as the pickle has been read (the bare `open(...)` left it dangling).
# NOTE: unpickling can execute arbitrary code - only load split files that
# come from a trusted source (here: the repo's own assets).
with open(LOCAL_PATH / 'miso/assets/splits_benchmark_her2st.pkl', 'rb') as splits_file:
    split_ids = pkl.load(splits_file)

# Run inference on a sample

In [None]:
fold_idx = 2    # fold index
sample_idx = 6  # sample index within the test fold

# Use the model that belongs to the chosen fold.
model = models[fold_idx]

# Entry 2 of a fold's split holds the ids indexed by `sample_idx`
# (presumably the test slides - confirm against the split file layout).
test_id = split_ids[fold_idx][2][sample_idx]

# Keep only the dataset items whose slide name matches the selected slide.
belongs_to_slide = (pd.Series(dataset.slide_names) == test_id).values
all_indices = np.arange(len(dataset))
sample_set = Subset(dataset, all_indices[belongs_to_slide])

In [None]:
# Run the selected model over every item of the chosen slide, batch by batch,
# and collect predictions, labels and coordinates as NumPy arrays.
BATCH_SIZE = 128

dataloader = DataLoader(sample_set, BATCH_SIZE, shuffle=False)

# Inference only: make sure training-only behaviour (e.g. dropout) is off.
model.eval()

pred = []
label = []
coords_patches = []
for data in tqdm(dataloader):
    # Move the features to the GPU once (the input was previously sent twice).
    features = data['features'].float().cuda()
    with torch.no_grad():
        # Only the model's `mlp` sub-module is applied here, not the full
        # forward pass, so the per-item outputs are kept un-aggregated.
        batch_pred = model.mlp(features).cpu().numpy()
    pred.append(batch_pred)
    # The first entry along the last axis of coords/labels is dropped.
    # NOTE(review): presumably an index/id column - confirm against the dataset.
    coords_patches.append(data['coords'][..., 1:])
    label.append(data['labels'][..., 1:])

# Stack the per-batch results along the first axis.
pred = np.concatenate(pred)
label = np.concatenate(label)
coords_patches = np.concatenate(coords_patches)

In [None]:
# Load the H&E image of the selected slide as an array. The context manager
# closes the underlying file once the pixels have been copied into `img`.
with Image.open(PATH_TO_DATA / f'images/HE/{test_id}.jpg') as slide_image:
    img = np.array(slide_image)

In [None]:
# Pick a gene and check performance

gene = 'CD3D'

# Column index of the gene in the prediction/label arrays.
gene_matches = np.flatnonzero(genes == gene)
if gene_matches.size == 0:
    raise ValueError(f"Gene {gene!r} is not in the list of predicted genes")
gene_idx = gene_matches[0]

# Tile-level prediction = mean of `pred` over axis 1. It is computed in this
# cell because this is the first place it is needed; without it the cell
# fails with a NameError on a fresh kernel.
tile_level_preds = np.mean(pred, axis=1)

# Pearson correlation between tile-level predictions and labels for the gene.
pearsonr(tile_level_preds[:, gene_idx], label[:, gene_idx])

# Slide-level visualisation (no super-resolution)

In [None]:
# Collapse axis 1 (presumably the patches-within-a-tile axis - confirm):
# predictions are averaged, coordinates reduced to their per-axis minimum.
tile_level_preds = pred.mean(axis=1)
tile_coords = coords_patches.min(axis=1)

In [None]:
# Overlay the tile-level predictions for the selected gene on the H&E slide.
gene_tile_preds = tile_level_preds[:, gene_idx]

fig, ax = plt.subplots(1, 1)
fig.set_size_inches(12, 8)
ax.imshow(img)
ax.set_xticks([])
ax.set_yticks([])
# Pass the raw values plus a colormap: matplotlib applies the same linear
# min-max scaling as the manual `(p - min) / (max - min)`, but does not
# divide by zero when every prediction is identical.
# Column 1 of the coordinates is plotted on x and column 0 on y.
ax.scatter(
    tile_coords[:, 1],
    tile_coords[:, 0],
    marker='o',
    c=gene_tile_preds,
    cmap='inferno',
    s=50,
    alpha=0.7,
)
ax.set_title(f'{gene}: tile-level predictions')
plt.show()

# Patch-level visualisation

In [None]:
# Index of the tile to zoom into: the one with the TILE_RANK-th highest
# prediction for the selected gene. The rank is clamped to the number of
# tiles so slides with fewer than TILE_RANK tiles fall back to the
# lowest-ranked tile instead of raising an IndexError.
TILE_RANK = 50
tiles_by_pred = np.argsort(tile_level_preds[:, gene_idx])  # ascending order
n = tiles_by_pred[-min(TILE_RANK, len(tiles_by_pred))]

In [None]:
# Crop a 224 x 224 pixel window (112 px either side) from the slide image
# around the coordinates of the selected tile.
# NOTE(review): the bounds are not clamped - for a tile closer than 112 px to
# the top/left border the slice start becomes negative; confirm this cannot
# happen for the slides used here.
tile_row = tile_coords[n, 0]
tile_col = tile_coords[n, 1]
row_window = slice(int(tile_row - 112), int(tile_row + 112))
col_window = slice(int(tile_col - 112), int(tile_col + 112))
img_tile = img[row_window, col_window]

In [None]:
# Side-by-side view of the selected tile: raw crop on the left, the same crop
# with the patch-level predictions for the selected gene on the right.
pred_patch = pred[n]
# Patch coordinates relative to the tile's minimum coordinates, shifted by 7.
# NOTE(review): the +7 offset presumably moves each marker to the centre of
# its patch - confirm against the patch size used for feature extraction.
coords_patch = coords_patches[n] - tile_coords[n] + 7
gene_patch_preds = pred_patch[:, gene_idx]

fig, (ax0, ax1) = plt.subplots(1, 2)
fig.set_size_inches(10, 5)

ax0.imshow(img_tile)
ax0.set_xticks([])
ax0.set_yticks([])
ax0.set_title('Tile')

ax1.imshow(img_tile)
# Pass the raw values plus a colormap: matplotlib applies the same linear
# min-max scaling as the manual `(p - min) / (max - min)`, but does not
# divide by zero when every patch prediction is identical.
ax1.scatter(
    coords_patch[:, 1],
    coords_patch[:, 0],
    marker='o',
    c=gene_patch_preds,
    cmap='inferno',
    s=100,
    alpha=0.7,
)
ax1.set_xticks([])
ax1.set_yticks([])
ax1.set_title(f'{gene}: patch-level predictions')
plt.show()