## DeepSpot inference from H&E images

Here, we provide an example of how to use the pretrained weights and perform inference using DeepSpot to predict spatial transcriptomics from H&E images.

You can download the pretrained weights at https://zenodo.org/records/14619853

In [None]:
import os
# Move the working directory one level up; every relative path in the cells
# below is resolved against it.
# NOTE(review): presumably the notebook lives one folder below the repository root — confirm.
# The change is relative, so re-running this cell moves up one more level each time.
os.chdir('../')

In [None]:
### download from zenodo
!wget -c https://zenodo.org/records/14638865/files/DeepSpot_pretrained_model_weights.zip?download=1

In [None]:
### unzip data
!unzip DeepSpot_pretrained_model_weights.zip

In [None]:
### you should see the available weights listed
# Lists the content of the folder extracted from the archive; the paths set in
# the parameter cell below (e.g. Colon_HEST1K/...) point inside this folder.
!ls -al DeepSpot_pretrained_model_weights

Import packages

In [None]:
from deepspot.utils.utils_image import predict_spatial_transcriptomics_from_image_path
from deepspot.utils.utils_image import get_morphology_model_and_preprocess
from deepspot.utils.utils_image import crop_tile

from deepspot import DeepSpot

import matplotlib.image as mpimg
from openslide import open_slide
import matplotlib.pyplot as plt
from tqdm import tqdm
import squidpy as sq
import anndata as ad
import pandas as pd
import numpy as np
import pyvips
import torch
import math
import yaml
import PIL

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

Here, we specify the input parameters. This information should be selected carefully, as it is based on the spatial transcriptomics training dataset. We continue with values based on our toy example from the COAD dataset.

In [None]:
out_folder = "example_data"
# Spots whose mean pixel intensity is above this value are flagged as white
# (background) and removed before prediction.
white_cutoff = 200  # recommended, but feel free to explore
# The full-resolution image size and the spot coordinates are divided by this factor.
downsample_factor = 10 # downsampling the image used for visualisation in squidpy
# Pretrained DeepSpot model, its hyperparameter file and the gene table,
# all taken from the unzipped weights folder.
model_weights = 'DeepSpot_pretrained_model_weights/Colon_HEST1K/final_model.pkl'
model_hparam = 'DeepSpot_pretrained_model_weights/Colon_HEST1K/top_param_overall.yaml'
gene_path = 'DeepSpot_pretrained_model_weights/Colon_HEST1K/info_highly_variable_genes.csv'
# Sample identifier; used to build the image path and stored as `sampleID` in the AnnData object.
sample = 'ZEN38'
image_path = f'example_data/data/image/{sample}_without_fud.jpg'

In [None]:
# Read the hyperparameters stored next to the pretrained weights.
with open(model_hparam, "r") as stream:
    config = yaml.safe_load(stream)

# n_mini_tiles        - number of non-overlapping subspots
# spot_diameter       - spot diameter
# spot_distance       - distance between spots
# image_feature_model - name of the model used for tile feature extraction
n_mini_tiles, spot_diameter, spot_distance, image_feature_model = (
    config[name]
    for name in ('n_mini_tiles', 'spot_diameter', 'spot_distance', 'image_feature_model')
)
image_feature_model

In [None]:
### Specify the weights for the pretrained model used for tile feature extraction
# NOTE(review): this is a machine-specific path into a local Hugging Face cache
# (a MahmoodLab/UNI blob) — point it at your own copy of the weights.
# It is resolved relative to the working directory set by the first cell.
image_feature_model_path = "../huggingface/hub/models--MahmoodLab--UNI/blobs/56ef09b44a25dc5c7eedc55551b3d47bcd17659a7a33837cf9abc9ec4e2ffb40"

In [None]:
# Gene table shipped with the weights; `isPredicted` marks the genes the model outputs.
genes = pd.read_csv(gene_path)
selected_genes_bool = genes['isPredicted'].values
genes_to_predict = genes[selected_genes_bool]

# Show the predicted genes ordered by their highly-variable rank (display only).
genes_to_predict.sort_values(by="highly_variable_rank")

In [None]:
# Read the H&E image into an array
image = mpimg.imread(image_path)

# Draw it on a fresh axes without ticks or axis labels
fig, ax = plt.subplots()
ax.imshow(image)
ax.axis('off')
plt.show()

We build the grid coordinates for the spots based on the image.

In [None]:
image = pyvips.Image.new_from_file(image_path)

# Regular grid of spot positions, kept more than one spot diameter away from the border.
# Note the naming: the `x_*` columns run along the image height and the
# `y_*` columns along the image width.
height_positions = range(spot_diameter + 1, image.height - spot_diameter - 1, spot_distance)
width_positions = range(spot_diameter + 1, image.width - spot_diameter - 1, spot_distance)

coord = pd.DataFrame(
    [
        [i, j, x, y]
        for i, x in enumerate(height_positions)
        for j, y in enumerate(width_positions)
    ],
    columns=['x_array', 'y_array', 'x_pixel', 'y_pixel'],
)
# String index, matching the default AnnData observation names
coord.index = coord.index.astype(str)

We select the spots that are located within the tissue.

In [None]:
# Offset from the stored grid position to the corner passed to crop_tile
# (presumably the grid stores spot centres — confirm against crop_tile).
half_spot = int(spot_diameter // 2)

# Mean pixel intensity of every candidate spot; high values indicate white background.
is_white = []
for _, row in tqdm(coord.iterrows(), total=len(coord)):
    x = row.x_pixel - half_spot
    y = row.y_pixel - half_spot

    main_tile = crop_tile(image, x, y, spot_diameter)
    main_tile = main_tile[:, :, :3]  # keep the first three channels (presumably drops alpha)
    is_white.append(np.mean(main_tile))

# Placeholder count matrix (spots x predicted genes). Zero-filled instead of
# np.empty so the placeholder never contains uninitialised memory.
counts = np.zeros((len(is_white), selected_genes_bool.sum()))

coord['is_white'] = is_white

We create the AnnData object; its count matrix is only a placeholder for now.

In [None]:
# Wrap the placeholder counts: one observation per grid spot, one variable per predicted gene.
adata = ad.AnnData(counts)
adata.var.index = genes[selected_genes_bool].gene_name.values
# Attach the grid coordinates; both indices are the string positions "0", "1", ...
adata.obs = adata.obs.merge(coord, left_index=True, right_index=True)
adata.obs['is_white'] = coord['is_white'].values
# 1 = mean intensity above the cutoff (background), 0 = tissue
adata.obs['is_white_bool'] = (coord['is_white'].values > white_cutoff).astype(int)
adata.obs['sampleID'] = sample
adata.obs['barcode'] = adata.obs.index
# Keep tissue spots only. Subsetting returns a view, and the following cells
# write to .obsm/.uns, so take an explicit copy instead of relying on the
# implicit copy-on-write (which also emits a warning).
adata = adata[adata.obs.is_white_bool == 0, ].copy()
adata

In [None]:
### CREATE IMAGE
img = open_slide(image_path)
n_level = len(img.level_dimensions) - 1  # index of the last pyramid level (0 based)

# Target size: full-resolution dimensions divided by the downsample factor
large_w, large_h = img.dimensions
new_w, new_h = (math.floor(side / downsample_factor) for side in (large_w, large_h))

# Read the whole last level, drop to RGB and resize it to the target size
whole_slide_image = img.read_region((0, 0), n_level, img.level_dimensions[n_level]).convert("RGB")
img_downsample = whole_slide_image.resize((new_w, new_h), PIL.Image.BILINEAR)

# Spot coordinates in (y_pixel, x_pixel) order, rescaled to the downsampled image
adata.obsm['spatial'] = adata.obs[["y_pixel", "x_pixel"]].values / downsample_factor

# 'spatial' entries holding the downsampled image under the library id 'library_id'
adata.uns['spatial'] = {
    'library_id': {
        'images': {
            'hires': np.array(img_downsample),
        },
    },
}

In [None]:
# Load the YAML file into a regular Python dictionary.
# `model_hparam` starts out as the path to the YAML file and is rebound to the
# parsed dictionary here; the guard keeps the cell safe to re-run, since
# open() would fail once the name already holds the dictionary.
if isinstance(model_hparam, str):
    with open(model_hparam, 'r') as yaml_file:
        model_hparam = yaml.safe_load(yaml_file)
model_hparam

Initialize DeepSpot and the pretrained pathology foundation model. This time, we compute the tile representation on the fly, which may take more time. The current implementation preprocesses a single spot per batch, but extending it to multiple spots might offer additional speed improvements. Contributions are welcome.

In [None]:
# The checkpoint stores the whole pickled model object (it is used directly via
# .to/.eval/.inverse_transform), not just a state dict. PyTorch >= 2.6 defaults
# to weights_only=True, which refuses to unpickle such objects, so the flag is
# set explicitly.
# NOTE(security): unpickling can execute arbitrary code — only load weights
# obtained from a source you trust.
model_expression = torch.load(model_weights, map_location=device, weights_only=False)
model_expression.to(device)
# Switch to evaluation mode for inference
model_expression.eval()

In [None]:
# Build the tile feature extractor named in the hyperparameters, together with
# its preprocessing function and feature dimension.
morphology_model, preprocess, feature_dim = get_morphology_model_and_preprocess(
    model_name=image_feature_model,
    device=device,
    model_path=image_feature_model_path,
)
# Place it on the selected device and switch to evaluation mode
morphology_model.to(device)
morphology_model.eval()

In [None]:
# Predict expression for the spots kept in `adata`, directly from the image file.
# The result replaces the placeholder `counts` and is still in the model's
# normalised scale (it is rescaled with inverse_transform in a later cell).
# NOTE(review): the meaning of `super_resolution` and `neighbor_radius` is defined
# in deepspot.utils.utils_image — confirm there before changing them.
counts = predict_spatial_transcriptomics_from_image_path(image_path, 
                                                        adata,
                                                        spot_diameter,
                                                        n_mini_tiles,
                                                        preprocess, 
                                                        morphology_model, 
                                                        model_expression, 
                                                        device,
                                                        super_resolution=False,
                                                        neighbor_radius=1)

##### Remember from the training notebook...
It is important that the `scaler` is the same as the one used during training, so that the predictions of DeepSpot can be rescaled back to their original ranges using the `inverse_transform` function.

##### IMPORTANT: Remember to manually rescale the values, as this is not done automatically.
```
expression_norm = model(X)
# expression_norm should be a np.array
expression = model.inverse_transform(expression_norm)
```

In [None]:
# Rescale the normalised predictions back to the original expression range;
# this step is manual and is not done by the prediction call.
counts = model_expression.inverse_transform(counts)
counts

You are free to explore other types of transformations that may enhance spatial transcriptomics predictions. The following are just a few examples.

In [None]:
# Replace negative predictions with zero, in place
negative_mask = counts < 0
counts[negative_mask] = 0

In [None]:
# New AnnData object carrying the predicted counts plus the spot metadata,
# gene annotation, image and coordinates of `adata`; copied so it is independent.
adata_predicted = ad.AnnData(
    counts,
    obs=adata.obs,
    var=adata.var,
    obsm=adata.obsm,
    uns=adata.uns,
).copy()
adata_predicted

In [None]:
# Genes to visualise, two panels per row on top of the downsampled image
marker_genes = ['MUC2', 'ITLN1', 'CLCA1', 'FCGBP']
sq.pl.spatial_scatter(
    adata_predicted,
    color=marker_genes,
    wspace=0,
    ncols=2,
    size=5,
)