Retrieve a patch of a specified size centered on every cell detection.

Based on https://github.com/choosehappy/QuPathGeoJSONImportExport/blob/master/classify_geojson_objects_in_wsi_centeroid_based_withlevel_mask.py

In [1]:
# --- tile / patch geometry ---
tilesize = 5000  # edge length of each tile loaded from the whole slide image
patchsize = 150  # edge length of the patch required by our DL model

# --- filtering and labelling ---
# tiles containing fewer objects than this are skipped entirely
minhits = 100
# class names as they will later appear in QuPath
classnames = [
    "Other",
    "Lymphocyte",
]
# one color value per class, in the same order as classnames
# (see selection of different color values at the bottom of the file)
colors = [
    -377282,
    -9408287,
]

# --- input / output locations ---
json_fname = "data/detections.json"  # input geojson file
json_annotated_fname = "data/detections_anno.json"  # target output geojson file
# whole slide image whose cells coincide with the objects in the json file
wsi_fname = "data/b32f-sag_L-lxn_pv-w03_2-cla-63x-proc.czi"

In [2]:
from pathlib import Path
from math import ceil

import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

import geojson
from shapely.geometry import shape
from shapely.strtree import STRtree
from shapely.geometry import Point
from shapely.geometry import Polygon

import bioformats
import javabridge
javabridge.start_vm(class_path=bioformats.JARS)

%matplotlib notebook

In [3]:
# Load the detection objects from the GeoJSON file named by json_fname.
# Files ending in ".gz" are decompressed on the fly.
import gzip  # imported here because the import cell does not bring gzip into scope

if json_fname.endswith(".gz"):
    # Open in text mode so geojson can parse the decompressed stream directly;
    # json.loads() no longer accepts an `encoding` keyword on Python 3.9+.
    with gzip.open(json_fname, "rt", encoding="utf-8") as f:
        allobjects = geojson.load(f)
else:
    with open(json_fname) as f:
        allobjects = geojson.load(f)

print("done loading")

done loading


In [4]:
# One shapely geometry per detection: use the nucleus outline when the object
# carries one, otherwise fall back to the object's main geometry.
allshapes = [
    shape(obj["nucleusGeometry"]) if "nucleusGeometry" in obj else shape(obj["geometry"])
    for obj in allobjects
]
# The centroid of each geometry is the point the spatial index works with.
allcenters = [geom.centroid for geom in allshapes]
print("done converting")

# Tag every centroid with its position in allobjects so that a hit returned
# by the tree can be traced back to the object it came from.
# NOTE(review): attaching attributes to geometries relies on Shapely 1.x
# behaviour; Shapely 2.x geometries are immutable -- confirm installed version.
for i, center in enumerate(allcenters):
    center.id = i

# Spatial index over the centroids, used for the per-tile lookups.
searchtree = STRtree(allcenters)
print("done building tree")

done converting
done building tree


In [13]:
# Open the whole slide image through Bio-Formats.
reader = bioformats.ImageReader(wsi_fname)
# Sample region as (x, y, width, height); try it with: reader.read(c=1, XYWH=XYWH)
XYWH = (3000, 2000, 1000, 1000)
print(reader.rdr.getSizeX())

# Image extent as (height, width): the tiling loop iterates y over
# wsi_shape[0] and x over wsi_shape[1].
wsi_shape = (reader.rdr.getSizeY(), reader.rdr.getSizeX())
# Number of tile rows / columns needed to cover the image.
nrow = ceil(wsi_shape[0] / tilesize)
ncol = ceil(wsi_shape[1] / tilesize)

# Detections and image are used at the same resolution.
scalefactor = 1
# Margin read around every tile so that patches centered close to the tile
# border can still be cut out completely.
paddingsize = patchsize // 2 * scalefactor

# int_coords = lambda x: np.array(x).round().astype(np.int32)

12960


In [None]:
# Walk over the slide tile by tile and, for every tile with at least `minhits`
# detections, cut a patchsize x patchsize patch around each detection centroid.
# shape_out / arr_out / id_out are rebuilt for every processed tile, so after
# the loop they describe the last tile that had enough hits.
# NOTE(review): treating hits as geometries carrying `.id` matches Shapely 1.x
# STRtree.query(); Shapely 2.x returns integer indices instead -- confirm version.
img_h, img_w = wsi_shape[0], wsi_shape[1]
step = round(tilesize * scalefactor)
readsize = tilesize + 2 * paddingsize  # tile plus a margin of paddingsize on each side
half = patchsize // 2

for y in tqdm(range(0, img_h, step), desc="outer", leave=False):
    for x in tqdm(range(0, img_w, step), desc=f"inner {y}", leave=False):

        tilepoly = Polygon([[x, y], [x + tilesize * scalefactor, y],
                            [x + tilesize * scalefactor, y + tilesize * scalefactor],
                            [x, y + tilesize * scalefactor]])
        hits = searchtree.query(tilepoly)

        if len(hits) < minhits:
            continue

        # The padded window starts at a negative offset for the first row/column
        # and overshoots the image at the last ones. Clamp the requested region
        # to the image and zero-pad the result back to the full window size so
        # that patch coordinates stay valid.
        x0, y0 = x - paddingsize, y - paddingsize
        rx0, ry0 = max(x0, 0), max(y0, 0)
        rx1, ry1 = min(x0 + readsize, img_w), min(y0 + readsize, img_h)
        XYWH = (rx0, ry0, rx1 - rx0, ry1 - ry0)
        tile = np.asarray(reader.read(c=0, XYWH=XYWH))
        pad_width = [(ry0 - y0, y0 + readsize - ry1), (rx0 - x0, x0 + readsize - rx1)]
        pad_width += [(0, 0)] * (tile.ndim - 2)  # leave any trailing axes untouched
        if any(before or after for before, after in pad_width):
            tile = np.pad(tile, pad_width, mode="constant")

        shape_out = []
        arr_out = np.zeros((len(hits), patchsize, patchsize))
        id_out = np.zeros((len(hits), 1))

        #---- get patches from hits within this tile and stick them (and their ids) into matrices
        # arr and id_row are views into arr_out / id_out, so slice-assigning fills the matrices
        for hit, arr, id_row in zip(hits, arr_out, id_out):
            px, py = hit.coords[0]  # this way is faster than using hit.x and hit.y, likely because of call stack overhead
            # centroid position inside the padded tile (column, row)
            c = int((px - x + paddingsize) // scalefactor)
            r = int((py - y + paddingsize) // scalefactor)
            patch = tile[r - half:r + half, c - half:c + half]

            shape_out.append(hit)
            arr[:] = patch
            id_row[:] = hit.id

In [None]:
# Preview: show up to 25 randomly chosen patches from arr_out in a 5x5 grid,
# each labelled with the id stored alongside it in id_out.
fig, axs = plt.subplots(5, 5, figsize=(10, 10))
axs = axs.ravel()

# Sample from the number of patches (first axis of arr_out) and never ask for
# more distinct indices than there are patches.
n_patches = arr_out.shape[0]
n_show = min(len(axs), n_patches)
rnd_idxs = np.sort(np.random.choice(n_patches, n_show, replace=False))

rnd_imgs = [arr_out[idx] for idx in rnd_idxs]
rnd_ids = [int(id_out[idx, 0]) for idx in rnd_idxs]  # scalar id per patch

# Hide the axes everywhere so unused grid cells stay blank.
for ax in axs:
    ax.axis('off')

for ax, img, id_ in zip(axs, rnd_imgs, rnd_ids):
    ax.imshow(img)
    ax.text(0.5, 0.05, f'id:{id_}', color='white', fontsize=14, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)

fig.tight_layout()