In [None]:
import matplotlib.pyplot as plt
import cv2

from byotrack import Video, VideoTransformConfig

from byotrack.implementation.detector.stardist import StarDistDetector

import skimage.io as iio
from byotrack.video.transforms import ChannelSelect, ChannelAvg, ScaleAndNormalize 
from PIL import Image
import numpy as np #can switch this out for pytorch at somepoint - notation is identical
from roifile import ImagejRoi, ROI_TYPE, ROI_OPTIONS

# Presumably the local Icy install directory; it is not referenced by any cell visible in this notebook.
icy_path = "/home/noah/Documents/icy-2.4.2.0-all"
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
tifpath = '/home/noah/Desktop/cellsegtest/segTestNew/shortStack_adjusted' #path to sequence of tiff files

In [31]:
#reader function for tif sequences
def Read_Data_TIFseq(vid_path):
    """Load every ``*.tif`` file found directly in ``vid_path`` into one array.

    The images are concatenated into a single numpy array and a trailing
    singleton axis is appended, so a 3-D stack of shape (a, b, c) is returned
    as (a, b, c, 1).
    """
    pattern = vid_path + '/*.tif'
    stack = iio.ImageCollection(pattern).concatenate()
    dim0, dim1, dim2 = stack.shape[0], stack.shape[1], stack.shape[2]
    return stack.reshape(dim0, dim1, dim2, 1)


def coordReshaper_IJ(coords): #reshaping for use with imageJ rois
    """Turn two parallel coordinate rows into a list of swapped pairs.

    ``coords[0]`` and ``coords[1]`` are read element by element and each pair
    is emitted as ``[coords[1][i], coords[0][i]]``; the output has one pair per
    element of ``coords[0]``.
    """
    return [[coords[1][idx], coords[0][idx]] for idx in range(len(coords[0]))]

def reshape_all_rois(all_rois):
    """Apply ``coordReshaper_IJ`` to every entry of ``all_rois`` and return the results as a list."""
    return [coordReshaper_IJ(single_roi) for single_roi in all_rois]

def convert_to_ImageJ(allROIs):
    """Build one ``ImagejRoi`` per point list in ``allROIs``.

    Each ROI is created with ``ImagejRoi.frompoints``, marked as a polygon and
    given the SHOW_LABELS option. Returns the ROIs in input order.
    """
    def _polygon_roi(points):
        polygon = ImagejRoi.frompoints(points)
        polygon.roitype = ROI_TYPE.POLYGON
        polygon.options |= ROI_OPTIONS.SHOW_LABELS
        return polygon

    return [_polygon_roi(points) for points in allROIs]


def coordReshaper_CV_contours(coords):

    '''
    Reshape contours generated by StarDist for use with openCV's display contours function.

    coords: array-like of shape (n_polygons, 2, n_points)
            (assumed to be StarDist's ``'coord'`` output - TODO confirm the row order against StarDist).

    Returns an int32 array of shape (n_polygons, n_points, 1, 2). Iterating over it yields one
    OpenCV-style contour of shape (n_points, 1, 2) per polygon.

    Previously every polygon was flattened into a single (N, 1, 2) array, so a caller looping over
    the result drew isolated points instead of closed outlines. The points are still produced in
    the same order as before; they are now grouped per polygon.
    '''

    coords_arr = np.asarray(coords)
    if coords_arr.size == 0:
        # nothing detected: an empty array still iterates as "no contours"
        return np.empty((0, 0, 1, 2), dtype=np.int32)
    if coords_arr.ndim != 3 or coords_arr.shape[1] != 2:
        raise ValueError(f"expected coords of shape (n_polygons, 2, n_points), got {coords_arr.shape}")

    # np.flip without an axis reverses every axis: the two coordinate rows are swapped (needed so the
    # pairs come out in the order OpenCV draws correctly), and polygon/point order is reversed, which
    # does not affect the drawn outline.
    flipped = np.flip(coords_arr)

    # (n_polygons, 2, n_points) -> (n_polygons, n_points, 2) -> (n_polygons, n_points, 1, 2)
    # https://stackoverflow.com/questions/14161331/creating-your-own-contour-in-opencv-using-python
    points = np.transpose(flipped, (0, 2, 1))
    return points[:, :, np.newaxis, :].astype(np.int32)

def generate_contours(image, detector):
    """Predict instances on ``image`` with the detector's model and return the polygons as OpenCV contours.

    The detector's current ``prob_threshold`` and ``nms_threshold`` are passed to
    ``predict_instances``; only the ``'coord'`` entry of the returned details is used.
    """
    _labels, details = detector._model.predict_instances(
        image,
        prob_thresh=detector.prob_threshold,
        nms_thresh=detector.nms_threshold,
        predict_kwargs={"verbose": 0},
    )
    return coordReshaper_CV_contours(details['coord'])


def visualise_all_contours_cv(image, cvContours, colour = None):
    '''
    Takes contours in opencv format and plots them over image
    https://stackoverflow.com/questions/57576686/how-to-overlay-segmented-image-on-top-of-main-image-in-python

    image:      image passed to cv2.drawContours; it is drawn on IN PLACE (no copy is made)
    cvContours: iterable of contours, each drawn individually
    colour:     colour passed to cv2.drawContours; defaults to (0,255,255)

    Returns None. If drawing a contour raises, the error is printed and the remaining
    contours are skipped.
    '''

    # `is None` instead of `== None`: the latter compares element-wise (and then fails in `if`)
    # when the colour is given as a numpy array.
    if colour is None:
        colour = (0,255,255)

    for contour in cvContours:
        try:
            # Outline contour in that colour on main image, line thickness=1
            cv2.drawContours(image,[contour],-1,colour,1, cv2.LINE_8)
        except Exception as e:
            print(e)
            break

In [None]:
video = Read_Data_TIFseq(tifpath)

# normalize video: clip to the [0.5%, 99.9%] intensity quantiles, then rescale to [0, 1]
mini = np.quantile(video, 0.005)
maxi = np.quantile(video, 0.999)
if maxi <= mini:
    raise ValueError(f"Cannot normalise video: quantiles are degenerate (mini={mini}, maxi={maxi})")

# Work on a float copy. Clipping with float bounds directly into the loaded array (which may be an
# integer dtype) either raises a casting error or truncates the bounds, depending on the NumPy version.
video = video.astype(np.float64)
np.clip(video, mini, maxi, out=video)
video -= mini
video /= (maxi - mini)

In [None]:
# NOTE(review): absolute, machine-specific path to the trained StarDist model directory.
model_path = "/home/noah/Desktop/STARDIST_CONFOCAL/NEWEST TdT MODELS/10X_IMAGES_ONLY"
# `detector` is a notebook-level name already; a `global` statement at cell level has no effect and was dropped.
detector = StarDistDetector(model_path, batch_size=5)

In [None]:
#Set model parameters for your dataset
#
# Opens an OpenCV window with three trackbars (frame index, probability threshold, overlap/NMS threshold).
# Moving a trackbar re-runs the model on the selected frame and redraws the contours.
# Press Q to close the window; the thresholds left on `detector` are printed afterwards.

vidCopy = video[0:50] #test batch
scale = 1

window_name = 'Paramater Test - Segmentation   (Press Q to Quit)'


def _threshold_to_trackbar(threshold):
    """Inverse of the trackbar callbacks (threshold = (pos + 1) / 100), clamped to the slider range 0..99."""
    return min(99, max(0, int(round(float(threshold) * 100)) - 1))


def update_frame(x): #callback function for trackbar - default argument is the position of the track bar
    """Load frame ``x`` of the test batch, convert it to 8-bit BGR for display and recompute its contours."""
    global contours
    global frame
    global frame_cv

    frame = vidCopy[x].copy()
    frame_cv = cv2.normalize(src=frame, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    frame_cv = cv2.cvtColor(frame_cv, cv2.COLOR_GRAY2BGR)
    contours = generate_contours(frame, detector)


def update_probability_threshold(x):
    """Trackbar callback: position 0..99 maps to a probability threshold of 0.01..1.00."""
    detector.prob_threshold = (x+1)/100
    update_frame(frameID)


def update_overlap_threshold(x):
    """Trackbar callback: position 0..99 maps to an overlap (NMS) threshold of 0.01..1.00."""
    detector.nms_threshold = (x+1)/100
    update_frame(frameID)


frameID = 0
update_frame(frameID)  # initialises frame, frame_cv and contours for the first frame
h, w = frame.shape[0:2]

try:
    #create and rescale window - cv2.resizeWindow expects (width, height), i.e. (columns, rows)
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(window_name, int(w*scale), int(h*scale))

    cv2.createTrackbar('Frame',window_name,0,len(vidCopy)-1,update_frame)
    # Start the sliders at the position that corresponds to the detector's current thresholds, using
    # the same mapping as the callbacks, so touching a slider does not shift the value by 0.01.
    cv2.createTrackbar('Probability Threshold', window_name, _threshold_to_trackbar(detector.prob_threshold), 99, update_probability_threshold)
    cv2.createTrackbar('Overlap Threshold', window_name, _threshold_to_trackbar(detector.nms_threshold), 99, update_overlap_threshold)

    while True:
        # keep frameID current so the threshold callbacks re-run detection on the frame being shown
        frameID = cv2.getTrackbarPos('Frame',window_name)
        visualise_all_contours_cv(frame_cv, contours)
        cv2.imshow(window_name, frame_cv)

        #exit on q
        if cv2.waitKey(5) == ord('q'):
            break

except Exception as e:
    print(e)
finally:
    # always release the window, whether the loop ended normally or with an error
    cv2.destroyAllWindows()

# Report the values actually stored on the detector (these are what detector.run will use),
# rather than re-deriving them from slider positions.
probabilityThreshold = detector.prob_threshold
nmsThreshold = detector.nms_threshold

print('Probability Threshold: ', probabilityThreshold)
print('nms Threshold: ', nmsThreshold)



In [None]:
#run detection on full video
# Show the thresholds currently stored on the detector - these are the values detector.run uses.
print('Probability Threshold: ', detector.prob_threshold)
print('nms Threshold: ', detector.nms_threshold)
user_confirmation = input('Confirm parameters (y/n)')

# Accept 'y' / 'Y' / 'yes' with stray whitespace; anything else aborts.
# NOTE: when aborted, `detections_sequence` is not (re)assigned, so later cells that use it will
# fail or see a stale value from an earlier run.
if user_confirmation.strip().lower() in ('y', 'yes'):
    detections_sequence = detector.run(video)
else:
    print('Aborting Segmentstion')

In [42]:
# Positions of the first frame's detections with the two coordinate columns swapped
# (tensor.flip(1) reverses dim 1). NOTE(review): presumably converts (row, col) to (x, y) — confirm against byotrack.
print(detections_sequence[0].position.flip(1))

tensor([[341.7426, 594.4257],
        [ 78.0154, 477.9692],
        [607.3692, 375.9539],
        [168.2889, 362.2222],
        [126.6454, 271.5107],
        [103.2619, 408.1429],
        [730.7476, 709.3204],
        [737.7736, 418.2264],
        [681.3594, 811.9844],
        [617.7812, 365.4219],
        [306.7733, 315.9333],
        [657.5763, 663.4068],
        [743.6411, 396.0000],
        [729.8302, 409.3773],
        [374.0000, 436.0000],
        [255.3462, 600.2115],
        [709.2090, 520.7910],
        [289.8644, 210.9661],
        [306.2414, 604.7586],
        [850.1000, 448.3800],
        [796.4000, 907.8000],
        [125.3019, 574.3962],
        [703.4832, 483.6846],
        [505.7467, 896.4533],
        [245.3457, 644.8642],
        [291.4259, 848.7037],
        [296.8361, 319.9344],
        [225.6042, 696.1771],
        [539.7903, 659.3226],
        [490.5625, 477.9844],
        [315.1887, 504.5094],
        [495.8281, 567.4688],
        [266.3333, 594.4762],
        [8

In [None]:
#inspect all detections

#to inspect previously loaded detections - simply load the detections from a np array first

In [27]:
#save detections (not tracks) as np array

# NOTE(review): absolute, machine-specific output path.
save_path_numpy = '/home/noah/Documents/NoahT2022/CodeRepos/Utopia/ExampleData/shortStack_adjusted/detections'
# Presumably yields an object-dtype array with one detection object per frame — TODO confirm numpy
# does not try to unpack the objects as sequences.
detection_array = np.asarray(detections_sequence)
# np.save appends ".npy" because the path has no extension. The objects are stored via pickle, so
# loading needs np.load(..., allow_pickle=True) — only do that with files you trust.
np.save(save_path_numpy, detection_array, allow_pickle=True)

In [35]:
# #save tracks as ImageJ ROIs
# save_path_IJ = '/home/noah/Documents/NoahT2022/CodeRepos/Utopia/ExampleOutputs'

# #detections objects do not contain polygon information, needs to be changed in byotracks or use stardist directly
# IJreshaped_rois = reshape_all_rois(detections_sequence)
# IJrois = convert_to_ImageJ(IJreshaped_rois)


In [21]:
#TODO

#add code for inspecting detections (same as setting parameters but easier)

#add code for saving detections in imageJ and numpy format

#add colours to plots so overlapping segments can be distinguished

#fix plotting so contours are continuous and smoother - perhaps this was the case when I was using other methods...

In [None]:
# imageJ roi style plotting

In [None]:
# cell seg style plotting

In [None]:
# Scratch check of the ROI element type.
# NOTE(review): `ijrois` is not assigned at notebook level in any cell visible here (it is only a
# local inside convert_to_ImageJ) — this cell relies on hidden kernel state.
type(ijrois[0])

In [36]:
# NOTE(review): `plot_image_overlays` is not defined in any visible cell (the saved output below is the
# resulting NameError); `image`, `ijrois` and `data` are likewise not notebook-level names at this point.
plot_image_overlays(image, ijrois)
print(data.keys())

NameError: name 'plot_image_overlays' is not defined

In [None]:
# Min-max rescale `image` to 0-255 and convert to 8-bit.
# NOTE(review): `image` is only assigned in a later cell (the tifffile load) — out-of-order dependency.
img_n = cv2.normalize(src=image, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)

In [None]:
##cv contours
# image32 = image2.astype('uint8') #opencv needs float32, images loaded as float64 here :)
# Expand the 8-bit grayscale image into a 3-channel BGR image (all channels identical).
imagecv = cv2.cvtColor(img_n, cv2.COLOR_GRAY2BGR)

In [None]:
# Trace external contours on the first channel of the display image.
# NOTE(review): this rebinds the notebook-level `contours` used by the interactive parameter cell.
# The two-value unpacking matches OpenCV 4.x; OpenCV 3.x returns three values.
contours,_ = cv2.findContours(imagecv[:,:,0],cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)

In [None]:
# Scratch inspection of the contour container's shape.
# NOTE(review): likely fails or is meaningless when the contours have different lengths — confirm.
np.shape(contours)

In [None]:
# Scratch inspection of the first polygon's coordinates.
# NOTE(review): `data` is only a local inside generate_contours in the visible cells — hidden kernel state.
data['coord'][0]

In [None]:
# Scratch inspection: first coordinate of the first point of the first contour.
contours[0][0][0][0]

In [None]:
# Scratch inspection.
# NOTE(review): `cvContours` is not assigned at notebook level in any visible cell — hidden kernel state.
cvContours[0][0]

In [None]:
# Show the 3-channel display image with matplotlib and report its shape and dtype.
plt.imshow(imagecv)
print(imagecv.shape)
print(imagecv.dtype)

In [None]:
# Scratch version of the contour overlay: draws every entry of `cvContours` on a copy of `imagecv`
# and shows the result in a blocking OpenCV window (any key closes it).
# NOTE(review): this duplicates visualise_all_contours_cv; `cvContours` is not assigned at notebook level
# in any visible cell; `mask` is created here but only referenced by commented-out code.
mask = np.zeros(imagecv.shape, np.uint8)
image_contours = imagecv.copy()
# Iterate over all contours
for i,c in enumerate(cvContours):
    try:
        # Find mean colour inside this contour by doing a masked mean
        # mask = np.zeros(image_contours.shape, np.uint8)
        # cv2.drawContours(mask,[c],-1,255, -1)
        # DEBUG: cv2.imwrite(f"mask-{i}.png",mask)
        # print(':)')
        # mean,_,_,_ = cv2.mean(imagecv, mask=mask)
        # DEBUG: print(f"i: {i}, mean: {mean}")

        # Get appropriate colour for this label
        # label = 2 if mean > 1.0 else 1
        colour = (0,255,255)
        # DEBUG: print(f"Colour: {colour}")
        # NOTE(review): debug print — emits every contour to the cell output
        print(c)

        # Outline contour in that colour on main image, line thickness=1
        cv2.drawContours(image_contours,[c],-1,colour,1)
    except Exception as e:
        # first failure is printed and stops the loop; remaining contours are not drawn
        print(e)

        break
try:
    cv2.namedWindow('contour', cv2.WINDOW_NORMAL)
    cv2.imshow('contour',image_contours) 

    # waitKey(0) blocks the kernel until a key is pressed in the window
    cv2.waitKey(0)
    cv2.destroyAllWindows()
except Exception as e:
    print(e)
    cv2.destroyAllWindows()

# try:
#     cv2.drawContours(mask,[c],-1,255, -1)
# except Exception as e:
#     print(e)

In [None]:
# Show the plain display image in an OpenCV window; waitKey(0) blocks the kernel until a key is pressed.
cv2.namedWindow('contour', cv2.WINDOW_NORMAL)
cv2.imshow('contour', imagecv)
cv2.waitKey(0)
  
# closing all open windows
cv2.destroyAllWindows()

In [None]:
#evaluate model parameters on data (binary-mask view)
#
# WARNING: moving either threshold trackbar overwrites detector.prob_threshold / detector.nms_threshold.
# Do not run this cell if you want to keep the defaults prob_thresh and nms_thresh.
# This cell also redefines update_frame / update_probability_threshold / update_overlap_threshold,
# replacing the versions from the contour-based tuning cell.

#TODO: add roi functionality to this so plotting shows mask over image (better evaluation) Will need to edit both stardist detector and the detector class .detect() function to include polygon data and add imageJ ROI code
vidCopy = video[0:50] #test batch
scale = 1

window_name = 'Paramater Test - Segmentation   (Press Q to Quit)'


def _threshold_to_trackbar(threshold):
    """Inverse of the trackbar callbacks (threshold = (pos + 1) / 100), clamped to the slider range 0..99."""
    return min(99, max(0, int(round(float(threshold) * 100)) - 1))


def update_frame(x): #callback function for trackbar - default argument is the position of the track bar
    """Run detection on frame ``x`` of the test batch and store its binary (0/255) mask in ``mask_glob``."""
    global mask_glob
    detections = detector.detect(vidCopy[x][None, ...])
    mask_glob = (detections[0].segmentation.numpy() != 0).astype(np.uint8) * 255


def update_probability_threshold(x):
    """Trackbar callback: position 0..99 maps to a probability threshold of 0.01..1.00."""
    detector.prob_threshold = (x+1)/100
    update_frame(frameID)


def update_overlap_threshold(x):
    """Trackbar callback: position 0..99 maps to an overlap (NMS) threshold of 0.01..1.00."""
    detector.nms_threshold = (x+1)/100
    update_frame(frameID)


frameID = 0
frame = video[frameID]
h, w = frame.shape[0:2]
detection_zero = detector.run([vidCopy[frameID]])
# detection_zero holds exactly one entry (one frame was passed), so index it with 0, not the frame number
mask_glob = (detection_zero[0].segmentation.numpy() != 0).astype(np.uint8) * 255

try:
    #create and rescale window - cv2.resizeWindow expects (width, height), i.e. (columns, rows)
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(window_name, int(w*scale), int(h*scale))

    cv2.createTrackbar('Frame',window_name,0,len(vidCopy)-1,update_frame)
    # Start the sliders at the detector's current thresholds so the display matches what is shown
    cv2.createTrackbar('Probability Threshold', window_name, _threshold_to_trackbar(detector.prob_threshold), 99, update_probability_threshold)
    cv2.createTrackbar('Overlap Threshold', window_name, _threshold_to_trackbar(detector.nms_threshold), 99, update_overlap_threshold)

    while True:
        # keep frameID current so the threshold callbacks re-run detection on the frame being shown
        frameID = cv2.getTrackbarPos('Frame',window_name)
        cv2.imshow(window_name, mask_glob)

        #exit on q
        if cv2.waitKey(5) == ord('q'):
            break

except Exception as e:
    print(e)
finally:
    # always release the window, whether the loop ended normally or with an error
    cv2.destroyAllWindows()

# Report the values actually stored on the detector rather than slider positions
probabilityThreshold = detector.prob_threshold
nmsThreshold = detector.nms_threshold
print('Prob: ', probabilityThreshold)
print('nms: ', nmsThreshold)

In [None]:
# Scratch: build a single ImageJ polygon ROI (same steps as convert_to_ImageJ performs per ROI).
# NOTE(review): this import duplicates the one in the first cell, and `coords` is not assigned at
# notebook level in any visible cell.
from roifile import ImagejRoi, ROI_TYPE, ROI_OPTIONS
roi = ImagejRoi.frompoints(coords)
roi.roitype = ROI_TYPE.POLYGON
roi.options |= ROI_OPTIONS.SHOW_LABELS

In [None]:
#reader funtion for tif sequences - handles reshaping and normalising (using the stardist recommended normaliser) (doesn't seem to perform that well)
from csbdeep.utils import normalize as csbdeepNormaliser
    
def Read_Data_TIFseq(vid_path):
    """Load every ``*.tif`` in ``vid_path``, append a trailing singleton axis and normalise each frame.

    Each entry along the first axis is passed through ``csbdeepNormaliser`` individually and the
    results are stacked back into one array.

    NOTE(review): this redefines the ``Read_Data_TIFseq`` from earlier in the notebook, which returns
    the un-normalised stack; whichever cell ran last wins.
    """
    stack = iio.ImageCollection(vid_path + '/*.tif').concatenate()
    stack = stack.reshape(stack.shape[0], stack.shape[1], stack.shape[2], 1)
    normalised_frames = []
    for single_frame in stack:
        normalised_frames.append(csbdeepNormaliser(single_frame))
    return np.asarray(normalised_frames)

# NOTE(review): reloads the tif sequence and rebinds `video`, replacing the quantile-normalised `video`
# computed earlier in the notebook with the csbdeep-normalised one.
video = Read_Data_TIFseq(tifpath)

In [None]:
import colorsys
from scipy.ndimage.morphology import binary_dilation, binary_erosion
from skimage.morphology import disk
import random
import pandas as pd
from matplotlib.ticker import NullLocator

In [None]:
#vis from cellseg

def get_bounding_box(masks):
    """Return per-label bounding-box corners of a 2-D label image.

    Every non-zero pixel is tabulated as (y, x, id) and grouped by id. Returns
    ``(bb_mins, bb_maxes)``: two lists of ``(y, x)`` tuples holding the minimum and
    maximum coordinates of each label, ordered by ascending label id.
    """
    nonzero = np.where(masks != 0)
    rows, cols = nonzero[0], nonzero[1]
    table = pd.DataFrame(
        np.array([rows, cols, masks[rows, cols]]).T,
        columns=["y", "x", "id"],
    )

    per_label = table.groupby('id')
    lower = per_label.agg({'y': 'min', 'x': 'min'})
    upper = per_label.agg({'y': 'max', 'x': 'max'})

    return lower.to_records(index = False).tolist(), upper.to_records(index = False).tolist()



def compute_snippet_bounds(minY, minX, maxY, maxX, Y, X):
    """Clamp a bounding box to valid pixel indices of a (Y, X) image.

    Minimums below 0 become 0; maximums at or beyond the image size become the last
    valid index (Y - 1 / X - 1). The returned bounds are therefore INCLUSIVE indices.
    """
    if minX < 0: minX = 0
    if minY < 0: minY = 0
    if maxX >= X: maxX = X - 1
    if maxY >= Y: maxY = Y - 1

    return minY, minX, maxY, maxX


def extract_snippet(Y, X, masks, mins, maxes):
    """Cut the region around one object out of ``masks`` with a 1-pixel margin where available.

    mins / maxes: (y, x) bounding-box corners of the object (inclusive).
    Returns ``(snippet, minY, minX, maxY, maxX)`` where the bounds are the clamped,
    inclusive snippet corners in image coordinates.
    """
    minY, minX, maxY, maxX = compute_snippet_bounds(mins[0] - 1, mins[1] - 1, maxes[0] + 1, maxes[1] + 1, Y, X)

    # The bounds are inclusive, so the slice end needs +1. Slicing with maxY/maxX directly dropped
    # the margin on the max side and cut off the last row/column of objects touching the image edge.
    return masks[minY:maxY + 1, minX:maxX + 1], minY, minX, maxY, maxX

def get_mask_ids(masks):
    """Return the non-zero label values present in ``masks`` as an ascending list.

    0 is treated as background and excluded. An image without any background pixel is
    handled (previously ``list.remove(0)`` raised ValueError in that case).
    """
    # np.unique already returns the values sorted, so no extra sort is needed
    return [label for label in np.unique(masks) if label != 0]

def random_colors(N, bright=True):
    """
    Produce N colours in shuffled order.

    Hues are spaced evenly in HSV space (full saturation; value 1.0 when ``bright``
    else 0.7), converted to RGB and scaled to integer 0-255 triples.
    """
    value = 1.0 if bright else 0.7
    palette = []
    for k in range(N):
        red, green, blue = colorsys.hsv_to_rgb(k / N, 1, value)
        palette.append((int(red * 255), int(green * 255), int(blue * 255)))
    random.shuffle(palette)
    return palette

def generate_mask_outlines(masks):
    """Return an image containing only the outline pixels of each labelled object.

    masks: 2-D label image (0 = background).
    Returns a uint32 array of the same shape where the boundary pixels of every object
    (object minus its erosion with a radius-1 disk) carry the object's label and
    everything else is 0.
    """
    maskids = get_mask_ids(masks)
    # bounding boxes come back ordered by ascending label id, matching maskids
    bb_mins, bb_maxes = get_bounding_box(masks)

    Y, X = masks.shape

    output_im = np.zeros(masks.shape, dtype = np.uint32)

    struc = disk(1)
    for i, maskid in enumerate(maskids):

        currreg, minY, minX, maxY, maxX = extract_snippet(Y, X, masks, bb_mins[i], bb_maxes[i])

        mask_snippet = (currreg == maskid)
        interior = binary_erosion(mask_snippet, struc)
        boundary = mask_snippet ^ interior

        # Shift snippet coordinates back into image coordinates. Integer arrays from np.nonzero stay
        # valid index arrays even when the boundary is empty (a Python-built empty list became a float
        # array and raised IndexError). The clamp targets the last valid index, not the image size.
        rows, cols = np.nonzero(boundary)
        pix_Y = np.minimum(rows + minY, Y - 1)
        pix_X = np.minimum(cols + minX, X - 1)

        output_im[pix_Y, pix_X] = maskid

    return output_im
    
def overlay_outlines_and_save(image, masks, outputpath, figsize, colors = None):
    """Colour every labelled object on a copy of ``image`` and save the figure.

    image:      array with a trailing colour axis (it is converted to uint8; the input is not modified)
    masks:      2-D label image (0 = background)
    outputpath: target passed to ``savefig``
    figsize:    matplotlib figure size
    colors:     optional sequence with one colour per label (ascending label order);
                random colours are generated when omitted or empty

    NOTE(review): despite the name, the whole object region is filled, not just its outline.
    """
    maskids = get_mask_ids(masks)

    # Generate random colors
    if colors is None or len(colors) == 0:
        colors = random_colors(len(maskids))

    # bounding boxes come back ordered by ascending label id, matching maskids
    bb_mins, bb_maxes = get_bounding_box(masks)

    rgb_im = image.astype(np.uint8)  # astype returns a copy

    Y, X = masks.shape

    for i, maskid in enumerate(maskids):

        currreg, minY, minX, maxY, maxX = extract_snippet(Y, X, masks, bb_mins[i], bb_maxes[i])

        # Integer index arrays stay valid when empty; clamp to the last valid index
        rows, cols = np.nonzero(currreg == maskid)
        pix_Y = np.minimum(rows + minY, Y - 1)
        pix_X = np.minimum(cols + minX, X - 1)

        rgb_im[pix_Y, pix_X, :] = colors[i]

    # Create the figure only after the image is ready and always close it, so a failure
    # above or during saving does not leak an open figure.
    fig, ax = plt.subplots(1, figsize=figsize)
    try:
        ax.axis('off')
        ax.imshow(rgb_im)
        # This is needed to remove all whitespace
        ax.set_axis_off()
        fig.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
                    hspace = 0, wspace = 0)
        ax.margins(0,0)
        ax.xaxis.set_major_locator(NullLocator())
        ax.yaxis.set_major_locator(NullLocator())
        fig.savefig(outputpath, dpi=75)
    finally:
        plt.close(fig)

In [None]:
def maskplotter(image, segmentation, colors = None):
    """Return a uint8 copy of ``image`` with every labelled object filled with its colour.

    image:        array with a trailing colour axis (the input is not modified)
    segmentation: 2-D label image (0 = background)
    colors:       optional sequence with one colour per label (ascending label order);
                  random colours are generated when omitted or empty
    """
    masks = segmentation
    maskids = get_mask_ids(masks)

    # Generate random colors
    if colors is None or len(colors) == 0:
        colors = random_colors(len(maskids))

    # bounding boxes come back ordered by ascending label id, matching maskids
    bb_mins, bb_maxes = get_bounding_box(masks)

    rgb_im = image.astype(np.uint8)  # astype returns a copy

    Y, X = masks.shape

    for i, maskid in enumerate(maskids):

        currreg, minY, minX, maxY, maxX = extract_snippet(Y, X, masks, bb_mins[i], bb_maxes[i])

        # Integer index arrays stay valid when empty; clamp to the last valid index
        rows, cols = np.nonzero(currreg == maskid)
        pix_Y = np.minimum(rows + minY, Y - 1)
        pix_X = np.minimum(cols + minX, X - 1)

        rgb_im[pix_Y, pix_X, :] = colors[i]

    return rgb_im

In [None]:
# Build a 3-channel image with `image` in the middle channel and zeros in the other two.
# NOTE(review): `image` is only assigned in a later cell; `dummy` is float64, so cv2.merge presumably
# needs `image` to share that dtype — confirm.
dummy = np.zeros(image.shape)
img2 = cv2.merge((dummy,image,dummy))

In [None]:
# Display the 3-channel composite built in the previous cell.
plt.imshow(img2)
plt.show()

In [None]:
# Binarise the segmentation to 0/255, then compute outlines and a coloured overlay from it.
# NOTE(review): `segmentation` is not assigned at notebook level in any visible cell. Because `mask`
# holds only the values 0 and 255, both helpers see a single label (255), so the individual objects
# are not distinguished — pass the label image itself if per-object output is wanted.
mask = (segmentation != 0).astype(np.uint8) * 255
masksim  = generate_mask_outlines(mask)
colorMaskIm = maskplotter(img2, mask)

In [None]:
import tifffile
# Load a single tif and rearrange its axes.
# NOTE(review): absolute, machine-specific path.
path = '/home/noah/Desktop/cellsegtest/segTestNew/shortStack_adjusted/camera2_NDTiffStack0002.tif'
image = np.array(tifffile.imread(path))
# (a, b) -> (a, b, 1); this reshape only succeeds for an image with a*b elements
image = image.reshape(image.shape[0], image.shape[1], 1)
print(image.shape)
shape = image.shape
# With shape == (a, b, 1), SHAPE is (b, 1, a) and the transpose below already yields that shape, so
# the final reshape is a no-op. NOTE(review): the result is (b, 1, a), which looks unintended for an
# array that was already channel-last — confirm what layout was wanted.
SHAPE = (shape[1], shape[2], shape[0])
image = np.transpose(image, (1, 2, 0))
image = image.reshape(SHAPE)

image.shape



In [None]:
# Save the coloured overlay of `mask` on `image`.
# NOTE(review): the output path has no file extension (it looks like a directory name); matplotlib
# presumably falls back to its default format — confirm where the file actually ends up.
overlay_outlines_and_save(image, mask, '/home/noah/Documents/NoahT2022/CodeRepos/Utopia/ExampleData', (30,30), colors = None)

In [None]:
# Display the coloured mask overlay without axes, ticks or surrounding whitespace
# (same figure set-up as overlay_outlines_and_save, but shown instead of saved).
_, ax = plt.subplots(1)
ax.axis('off')
img = ax.imshow(colorMaskIm)
# This is needed to remove all whitespace
plt.gca().set_axis_off()
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, 
            hspace = 0, wspace = 0)
plt.margins(0,0)
plt.gca().xaxis.set_major_locator(NullLocator())
plt.gca().yaxis.set_major_locator(NullLocator())
# plt.savefig(outputpath, dpi=75)
plt.show()

In [None]:
# Add the outline label image onto the 2-D view of `image` and display the sum.
# NOTE(review): the reshape needs `image` to hold exactly shape[0]*shape[1] elements, which does not
# match the (b, 1, a) layout produced by the tif-loading cell — this cell appears to depend on a
# different `image` state; confirm the intended cell order.
newim = masksim + image.reshape(image.shape[0], image.shape[1])
print(masksim.shape)
print(image.shape)
print(newim.shape)
plt.imshow(newim)
plt.axis('off')
plt.show()

In [None]:
# Put the outline+image sum in the first channel of a 3-channel image and display it.
# NOTE(review): cv2.merge presumably requires all inputs to share size and dtype; `newim` and `dummy`
# are built differently (dummy is float64 with image.shape) — confirm this runs as intended.
imgmask = cv2.merge((newim,dummy,dummy))
plt.imshow(imgmask)
plt.show()

In [None]:
# Boolean map of outline pixels (True wherever masksim is non-zero).
idx=~(masksim==0)
# newim = np.where(idx, image[::,0], image[::,0])
# newim2 = np.putmask(image[::,0], idx, 255)

indx = idx.nonzero()
# NOTE(review): this first assignment is overwritten immediately by the next line.
newim3 = image.reshape(image.shape[0], image.shape[1])
# newim3 is now another name for img2 (no copy is made).
newim3 = img2
print(newim3.shape)
# NOTE(review): newim3[indx] uses index arrays and therefore returns a copy, so this assignment
# modifies a temporary and leaves newim3 unchanged. If the goal was to set one channel at the outline
# pixels, something like newim3[idx, 2] = 0.5 may be what was meant — confirm intent.
newim3[indx][2] = 0.5
plt.imshow(newim3[::])
plt.show()

print(newim3.max())
print(newim3.shape)


In [None]:
# Scratch inspection. Only the last expression (`idx`) is echoed by the notebook; the bare
# `newim.shape` on the first line produces no output.
newim.shape
plt.imshow(idx)
idx

In [None]:
# Display the outline+image sum again.
plt.imshow(newim)