## Retrieve a patch of a specified size around every cell detection.

Based on https://github.com/choosehappy/QuPathGeoJSONImportExport/blob/master/classify_geojson_objects_in_wsi_centeroid_based_withlevel_mask.py

In [None]:
from pathlib import Path
from math import ceil

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm.notebook import tqdm
import geojson
from shapely.geometry import Polygon, shape, box

from shapely.strtree import STRtree
from shapely.affinity import translate
import pandas as pd

# from utils import BioFormatsReader
from patch import Patch
from tiler import Tiler
from radial_profiler import get_radial_profiles, RadialProfilePlotter, RadialProfiler

from rasterio.features import rasterize

import logging

from scipy import ndimage
from skimage.measure import find_contours
from skimage.morphology import binary_dilation, disk
from skimage.exposure import equalize_hist
from skimage.color import gray2rgb

from bioformatsreader import BioFormatsReader

# import javabridge
# import bioformats
# javabridge.start_vm(class_path=bioformats.JARS)

# Module-level logger for this notebook; DEBUG level so the progress messages below are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Only let ERROR (and above) through from the 'rasterio' logger.
logging.getLogger('rasterio').setLevel(logging.ERROR)

%matplotlib inline
%load_ext line_profiler

In [None]:
tilesize = 2000 # side length of the square tile to load
patchsize = 200 # patch size needed by our DL model; half of it is also used as the tile padding/overlap

minhits = 5 # the minimum number of objects needed to be present within a tile for the tile to be computed on

# Input files: the detections export (JSON) and the whole-slide image it belongs to.
json_fname = 'data/b41f-sag_L-neun_pv-w09_2-cla-63x-detections_null.json'
wsi_fname = 'data/b41f-sag_L-neun_pv-w09_2-cla-63x.czi'
# channel_order = [1, 2, 0]  # G, B, R
# Channel names and their plot colours, paired by position.
# Presumably listed in the channel order of the image data — TODO confirm.
titles = ['NEUN', 'DAPI', 'PV']
colors = ['green', 'blue', 'red']

# The area to search for hits, given here as a polygon of explicit coordinates.
# NOTE(review): `search_area` is not referenced by any cell visible in this part of the notebook — confirm it is applied elsewhere.
search_area = Polygon([[1968, 6266], [4485, 6266], [4485, 8782], [1968, 8782]])

In [None]:
# Load the detection objects exported for this slide.
with open(json_fname) as f:
    allobjects = geojson.load(f)
logger.debug("Loaded detections")

# Prefer the nucleus outline when an object has one; otherwise fall back to
# the object's main geometry.
allshapes = [
    shape(obj["nucleusGeometry"] if "nucleusGeometry" in obj else obj["geometry"])
    for obj in allobjects
]
allcenters = [s.centroid for s in allshapes]
logger.debug("Converted %d objects", len(allcenters))

# Tag every centroid with its list index so that a search-tree hit can be
# mapped back to the matching entry of `allshapes`.
# NOTE(review): attaching attributes to geometries (and getting geometries
# back from STRtree.query) looks like Shapely 1.x behaviour — confirm the
# pinned Shapely version before upgrading.
for i, center in enumerate(allcenters):
    center.id = i

# Spatial index over the centroids, queried once per tile further below.
searchtree = STRtree(allcenters)
logger.debug("Built search tree")

# Reader for the whole-slide image file; used below for size queries and region reads.
reader = BioFormatsReader(wsi_fname)

def reader_func(*args):
    """Read one rectangular region of the slide through `reader`.

    Only positional arguments are accepted. Positions 0 and 1 are used as the
    X and Y of the region, positions 3 and 4 as its width and height.
    Position 2 and anything after position 4 are ignored — presumably the
    channel offset/size passed by the tiler for 3-D data; TODO confirm.

    Returns whatever `reader.read(XYWH=...)` returns for that region.
    """
    region = (args[0], args[1], args[3], args[4])
    return reader.read(XYWH=region)

# Image extent and channel count as reported by the underlying reader object.
sizeX = reader.rdr.getSizeX()
sizeY = reader.rdr.getSizeY()
sizeC = reader.rdr.getSizeC()
# Half a patch: the margin a detection centre needs on every side so that a
# full patch can be cut around it.
paddingsize = patchsize // 2
logger.debug('Read image')

In [None]:
%%time

# Radial profiles per detection: {detection id: {channel title: profile}}.
profile_plots = {}
tiler = Tiler(
    data_shape=(sizeX, sizeY, sizeC),
    tile_shape=(tilesize, tilesize, sizeC),
    overlap=paddingsize,
    mode='constant',
)
# NOTE(review): every tile's search box below is shrunk by `paddingsize` on
# each side, while neighbouring tiles overlap by only `paddingsize`. If the
# tile stride is `tilesize - overlap`, a strip between adjacent search boxes
# is never queried — confirm whether `overlap` should be `patchsize`.

detection_count = 0

# Iterate over tile indices rather than tiles so that a tile is only read
# from disk once it is known to contain enough detections.
for tile_idx in tqdm(range(len(tiler)), desc='Iterating tiles', leave=True):
    tile_bbox = tiler.get_tile_bbox_position(tile_idx)
    X, Y = tile_bbox[0][0:2]
    # Shrink the box that searches for detections so any patch would fit
    # entirely inside the tile.
    search_box = box(
        X + paddingsize,
        Y + paddingsize,
        X + tilesize - paddingsize,
        Y + tilesize - paddingsize,
    )

    hits = searchtree.query(search_box)
    if len(hits) < minhits:
        continue
    logger.debug('Found %d detections', len(hits))
    detection_count += len(hits)

    tile = tiler.get_tile(reader_func, tile_idx)

    # Map each hit (a tagged centroid) back to its id and full polygon.
    ids = [hit.id for hit in hits]
    polygons = [allshapes[id_] for id_ in ids]

    # Cut one patch per hit, centred on the hit in tile-local coordinates.
    imgs = np.empty((len(hits), patchsize, patchsize, sizeC), dtype=tile.dtype)
    for idx, hit in enumerate(hits):
        c, r = int(hit.x - X), int(hit.y - Y)
        imgs[idx, ...] = tile[r - paddingsize: r + paddingsize,
                              c - paddingsize: c + paddingsize, ...]

    # Create Patch objects
    patches = [
        Patch(img, polygon, id_)
        for img, polygon, id_ in zip(imgs, polygons, ids)
    ]
    logger.debug('Created %d patches', len(patches))

    # Process this batch
    for p in tqdm(patches, desc='Creating plots', leave=False):
        # NOTE(review): the meaning of the constants 30 and 50 is defined by
        # get_radial_profiles, which is not visible here.
        profs = get_radial_profiles(p.get_image(), p.get_mask(), 30, 50)
        # Keep profiles identifiable by channel title and detection id.
        profile_plots[p.id] = {titl: profs[ch] for ch, titl in enumerate(titles)}
logger.info('Processed %d detections', detection_count)

In [None]:
# Flatten `profile_plots` into three parallel columns: one row per
# (detection id, channel title) pair, with the profile as the row's payload.
df_ids, df_plot_keys, df_profiles = [], [], []
for id_, profiles in profile_plots.items():
    # The id is repeated once per channel so that it lines up with the
    # keys/values appended for this detection.
    df_ids.extend([id_] * len(profiles))
    df_plot_keys.extend(profiles.keys())
    df_profiles.extend(profiles.values())

# Indexed by detection id; an id appears once per channel.
df = pd.DataFrame(
    {'plot_key': df_plot_keys, 'profiles': df_profiles},
    index=df_ids,
)

In [None]:
# Plot the mean radial profile of every channel in that channel's colour.
for titl, color in zip(titles, colors):
    profs = df.loc[df['plot_key'] == titl, 'profiles'].to_numpy()
    if len(profs) == 0:
        # Nothing was collected for this channel; there is no mean to draw.
        logger.warning('No profiles for channel %s; skipping plot', titl)
        continue
    # Stack the per-detection profiles into one 2-D array (detections x bins)
    # so the statistics are ordinary numeric reductions rather than
    # reductions over an object-dtype container of arrays.
    stacked = np.stack(list(profs))
    avg = stacked.mean(axis=0)
    std_dev = stacked.std(axis=0)
    X = np.arange(0, len(avg))
    upper_err = avg + std_dev
    lower_err = avg - std_dev
    plt.plot(X, avg, c=color)
    # Optional spread displays:
    # plt.errorbar(X, avg, std_dev, c=color)
    # plt.fill_between(X, upper_err, lower_err, color=color, alpha=0.1)