This notebook parses the PlantSeg LateralRootPrimordia dataset that we used in our paper. The data come in HDF5 format, and many of the volumes contain an 'ignore' label that will not work with Omnipose without special handling. We discarded any volume containing this -1 label, but kept the remaining files in their original 'test' or 'train' folders. The images were rescaled to 1/3 of their original size so that cells are roughly 20-30 px in diameter. This was done only to speed up training and evaluation by working with far smaller volumes.

In [None]:
%load_ext autoreload
%autoreload 2


import numpy as np
import time, os, sys

import skimage.io
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import matplotlib as mpl
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300

from scipy.ndimage import binary_closing,grey_closing
from skimage import measure
import napari 
import omnipose
import ncolor 
from omnipose.utils import sinebow

from scipy.ndimage import gaussian_filter
from omnipose.utils import normalize99, rescale
from scipy.ndimage import gaussian_filter

def localnormalize(im,sigma1=2,sigma2=3):
    """Normalize an image by a local (Gaussian-weighted) mean and spread.

    Steps, in order:
      1. rescale the input with ``normalize99``;
      2. subtract a Gaussian-blurred copy (width ``sigma1``) to remove the
         local mean;
      3. divide by the square root of the Gaussian-blurred squared residual
         (width ``sigma2``), i.e. a local standard-deviation estimate;
      4. rescale the ratio with ``normalize99`` again.

    Parameters
    ----------
    im : array-like
        Image or volume to normalize.
    sigma1 : scalar or sequence of scalars
        Gaussian sigma used for the local-mean estimate.
    sigma2 : scalar or sequence of scalars
        Gaussian sigma used for the local-variance estimate.

    Returns
    -------
    The locally normalized image as returned by ``normalize99``.
    """
    im = normalize99(im)
    local_mean = gaussian_filter(im,sigma=sigma1)
    num = im - local_mean
    local_var = gaussian_filter(num*num, sigma=sigma2)
    den = np.sqrt(local_var)

    # The epsilon must sit inside the denominator: written as `num/den+1e-8`
    # it was added after the division and could not prevent 0/0 or x/0 in
    # flat regions where the local variance is zero.
    return normalize99(num/(den+1e-8))

def cyclic_perm(a):
    """Return every cyclic rotation of the sequence ``a`` as a list of lists.

    Row ``shift`` holds ``a`` rotated right by ``shift`` positions, so row 0
    is ``a`` itself. Negative indices produced by ``pos - shift`` wrap around
    through normal Python indexing.
    """
    size = len(a)
    rotations = []
    for shift in range(size):
        rotations.append([a[pos - shift] for pos in range(size)])
    return rotations

import tifffile
from omnipose import core
import cellpose.core
use_GPU = cellpose.core.use_gpu()

from matplotlib.colors import ListedColormap


In [None]:
from cellpose import io
save0 = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/'
export = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_good'
do_all = False # training was done on the limited dataset; testing done on full, ignoring other regions?
if not do_all:
    export = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small'
else:
    export = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small_all'
io.check_dir(export)
savedir = []
expdir = []
cpdir = []
ext = '.tif'
subset = ['test','train']
for s in subset:
    sd = os.path.join(save0,s)
    savedir.append(sd)
    sd = os.path.join(export,s)
    io.check_dir(sd)
    expdir.append(sd)
    sd = os.path.join(export,'cellpose',s)
    io.check_dir(sd)
    cpdir.append(sd)
savedir,expdir,cpdir

In [None]:
# Make a function that extracts the file name 
from pathlib import Path
def getname(path,suffix=''):
    """Return the file name of ``path`` without directory or final extension.

    Every occurrence of ``suffix`` is then removed from that name (an empty
    ``suffix`` leaves it untouched).
    """
    filename = Path(path).name
    stem, _extension = os.path.splitext(filename)
    return stem.replace(suffix,'')

In [None]:
#prepare to read from the hd5 files 
import h5py, fastremap
d = []
names = []
ext = '.tif'
# bad = [4,6,10,12,14,16,18,20]

In [None]:
import scipy
downsize = True
factor = 1/3

Define scale bar

In [None]:
# Voxel size along (Z, Y, X) in microns per pixel. The base values are divided by the
# zoom `factor` so they describe the rescaled volumes (factor < 1 means each pixel
# covers more of the sample).
# NOTE(review): 0.25 / 0.1625 are presumably the acquisition voxel sizes - confirm against the dataset metadata.
ZYX = np.array([0.25,0.1625,0.1625])/factor #micron/px
# Length in pixels of a 10 micron scale bar along each axis:
# a px * (ZYX micron/px) = b micron  =>  a = b/ZYX px
bar_10 = 10/ZYX #<a>px*(ZYX micron/px) = b micron => b/ZYX = a px
bar_10

Load data and format

In [None]:
# Convert the HDF5 volumes in each source folder into TIFF image/mask pairs in the
# matching export folder. Nothing runs unless `clean` is set to a truthy value.
clean = 0 #toggle
if clean:
    for folder,newfolder in zip(savedir,expdir):
        image_names = io.get_image_files(folder,extensions=['h5'])
        for filename in image_names:
            with h5py.File(filename, "r") as f:
                name = getname(filename)
                # if not np.any(["%04d" % b in name for b in bad]):
                # if 'Movie1_t00030' in name:
                    # List all groups
                print("Keys: %s" % f.keys())

                # Ground-truth labels are read from the 'label' dataset of the file.
                masks = np.array(f['label'])
                # if do_all:
                #     masks[masks<1] = 0 # just set all ignore regions to zero
                mn = masks.min()
                print(mn,name)
                # Keep a volume only when its smallest label is 1, unless do_all is set.
                # NOTE(review): the intro text describes the ignore label as -1 while this
                # comment treats 0 as ignore - confirm which value the files actually use.
                if masks.min()==1 or do_all: # seems like 0 is ignore, 1 is background. For training, I just discarded any with 0.
                    masks = omnipose.utils.format_labels(masks,ignore=do_all) # this corrects for any duplicate labels on distinct cells
                    # The 'newminN' prints trace how the minimum label changes after each step.
                    print('newmin0',masks.min(),'same?',masks.min()==mn)
                    img = np.array(f['raw'])
                    # Renumber in place; the return value is intentionally discarded.
                    fastremap.renumber(masks,in_place=True)
                    print('newmin1',masks.min(),'same?',masks.min()==mn)
                    # NOTE(review): refit presumably casts to the smallest dtype that holds the labels - confirm.
                    labels = fastremap.refit(masks)
                    print('newmin2',labels.min(),'same?',labels.min()==mn)
                    if downsize:
                        # order=0 (nearest neighbour) keeps label values intact; the image uses linear interpolation.
                        labels = scipy.ndimage.zoom(labels, factor, order=0, mode='nearest') # 'nearest' prevents a single-pixel layer of zeros at the edge
                        img = scipy.ndimage.zoom(img, factor, order=1)
                    print('newmin3',labels.min(),'same?',labels.min()==mn)
                    # tifffile.imwrite(os.path.join(newfolder,name+ext),np.uint8(omnipose.utils.normalize99(img)*(2**8-1)))
                    # Image and labels are written side by side; the label file gets the '_masks' suffix.
                    tifffile.imwrite(os.path.join(newfolder,name+ext),img)
                    tifffile.imwrite(os.path.join(newfolder,name+'_masks'+ext),labels)
                    # Record the per-volume diameter estimate and the exported name for later cells.
                    d.append(omnipose.core.diameters(labels))
                    # mlist.append(masks)
                    names.append(name)

The above files will be used for Omnipose in 3D. To train Cellpose, we have to slice these volumes along each axis. This will produce a lot of images. I'll just be doing it on the AWS server prior to training. Turns out this is 3,070 training images. 

In [None]:
# Cut every exported training volume into 2D slices along each axis for Cellpose
# training. Nothing runs unless `do_slice` is set to a truthy value.
do_slice = 0 #toggle
if do_slice:
    # NOTE: `d` and `c` are rebound here; `d` previously held the diameter list
    # filled by the export cell.
    d = 3
    idx = np.arange(d)
    # One 0 (the axis being stepped through) followed by 1s (the axes kept whole);
    # the cyclic rotations of this pattern select each axis in turn.
    c = np.array([0]*(d-2)+[1]*2)
    ext = '.tif'
    a = 'zyx'  # axis letters used in the output file names
    # for basedir,save0 in zip(expdir,cpdir):
    for basedir,save0 in zip([expdir[1]],[cpdir[1]]): #only do train

        mask_filter = '_masks'
        img_names = io.get_image_files(basedir,mask_filter)
        # NOTE(review): this unpacks a 2-tuple, whereas the data-loading cell uses the
        # return value of get_label_files directly - confirm which matches the installed version.
        mask_names,_ = io.get_label_files(img_names, mask_filter)

        for p1,p2 in zip(img_names,mask_names):
            img = tifffile.imread(p1)
            mask = tifffile.imread(p2)
            s = img.shape
            name = getname(p1)

            for inds,i in zip(cyclic_perm(c),idx):
                for k in range(s[i]):
                    # Fix: the kept axes must use slice(None). The previous slice(-1)
                    # stops one element short and silently dropped the last row and
                    # column of every exported slice.
                    slc = tuple([slice(None) if keep else k for keep in inds])
                    suffix = '_'+a[i]+'{:03d}'.format(k)
                    l = omnipose.utils.format_labels(mask[slc],clean=True) # this corrects for any duplicate labels on distinct cells
                    if np.any(l): # only keep slices that have masks
                        tifffile.imwrite(os.path.join(save0,name+suffix+ext),img[slc])
                        tifffile.imwrite(os.path.join(save0,name+suffix+'_masks'+ext),l)

### Load data

In [None]:
# Load every exported image/mask pair (test first, then train) and record the
# number of cells in each subset.
imgs = []
masks_gt = []
text = []
# for k in [0]: #just test files now, switch to 1 to look at training set
for k in range(2):
    # Index of the first volume of this subset, so the count below covers only the
    # volumes loaded in this iteration and not those of earlier subsets.
    first = len(masks_gt)
    for basedir,save0 in zip([expdir[k]],[cpdir[k]]):
        mask_filter = '_masks'
        img_names = io.get_image_files(basedir,mask_filter)
        mask_names = io.get_label_files(img_names, mask_filter)

        for p1,p2 in zip(img_names,mask_names):
            print(p1)
            imgs.append(tifffile.imread(p1))
            mgt = tifffile.imread(p2)
            mgt = omnipose.utils.format_labels(mgt,clean=True,verbose=False,ignore=do_all)
            masks_gt.append(mgt)
    # Count cells in the current subset only. Previously the loop ran over all of
    # masks_gt, so the 'train' line reported test + train cells.
    c = 0
    for mgt in masks_gt[first:]:
        c+=len(fastremap.unique(mgt))-1 # don't count background

    text += ['number of cells in {} is {}'.format(subset[k],c)]

In [None]:
# Write the per-subset cell counts to a stats file and echo them to stdout.
basedir = '/home/kcutler/DataDrive/omnipose_all'
# Fix: format with the string itself. Wrapping it in a list put the list's repr
# (e.g. "['']") into the file name.
stats_name = 'plant_dataset{}_stats.txt'.format('_all' if do_all else '')
with open(os.path.join(basedir,stats_name), "w") as text_file:
    for t in text:
        print(t)
        print(t,file=text_file)

In [None]:
fastremap.unique(mgt)

In [None]:
# check the data
n = 4
viewer = napari.view_image(omnipose.utils.rescale(imgs[n]), name='cells',gamma=0.2,attenuation=0.05,depiction='volume')
mnc = ncolor.label(masks_gt[n],max_depth=11)
viewer.add_labels(mnc, name='labels_gt',visible = False, color=sinebow(mnc.max()))
viewer.add_labels(masks_gt[n])

In [None]:
[m.min() for m in masks_gt]

In [None]:
viewer = napari.view_image(omnipose.utils.rescale(img), name='cells',gamma=0.2,attenuation=0.05,depiction='volume')
mnc = ncolor.label(labels,max_depth=11)
viewer.add_labels(mnc, name='labels_gt',visible = False, color=sinebow(mnc.max()))

In [None]:
# save slices theough each axis for figures
%matplotlib inline
from omnipose.utils import rescale, sinebow
from cellpose import plot

# Build a viridis-based colormap with reversed colour order whose alpha channel
# ramps linearly from 0 to 1.
cmap = plt.get_cmap(name='viridis')
x = np.linspace(0,1,1000)
colors=cmap(1-x)
colors[...,-1] = x
# NOTE(review): this dict is overwritten by the ListedColormap right below and is never used in this cell.
cmap = {
    'colors': colors,
    'name': 'custom',
    'interpolation': 'linear'
}
cmap = ListedColormap(list(colors))

# NOTE(review): `k` is whatever an earlier cell left behind (1 after the data-loading
# loop has run), so this selects the output folder implicitly - confirm intended subset.
save0 = os.path.join('/home/kcutler/DataDrive/omnipose_paper/Figure 7','slices',subset[k])
io.check_dir(save0)

fig = plt.figure(figsize=[20]*2)
# Same definition as the cyclic_perm near the top of the notebook, repeated so this cell is self-contained.
def cyclic_perm(a):
    n = len(a)
    b = [[a[i - j] for i in range(n)] for j in range(n)]
    return b

# The outer loop runs exactly once and only prints; `name` and `i` are not used below.
# NOTE(review): names[0] raises IndexError when `names` is empty, which is the case
# if the export cell was skipped (clean = 0) - confirm this dependency is intended.
for i,name in enumerate([names[0]]):
    print(name,i)
    for j in range(len(imgs)):
        # Power 0.2 compresses the intensity range before normalization.
        img = omnipose.utils.normalize99(imgs[j]**.2)
        mgt = masks_gt[j]
        mnc = ncolor.label(mgt)
        # sinebow returns a mapping; its values become the entries of a ListedColormap.
        smap = sinebow(mnc.max()+1)
        smap = ListedColormap(list(smap.values()))
        # plot flow and dist slices
        d = img.ndim
        # One 0 followed by 1s: the 0 marks the axis fixed at its midpoint.
        c = np.array([0]*(d-2)+[1]*2)

        slices = []
        idx = np.arange(d)
        for inds in cyclic_perm(c):
            # Central slice through one axis, full extent along the others. The `i`, `k`
            # in this comprehension are local to it and do not touch the outer variables.
            slc = tuple([slice(None) if i else img.shape[k]//2 for i,k in zip(inds,idx)])
            # mgtnc = ncolor.label(mgt[slc])
            mgtnc = mnc[slc]

            pic = smap(mgtnc)
            # Alpha channel: opaque where a label is present, transparent on background.
            pic[...,-1] = mgtnc>0
            tifffile.imwrite(os.path.join(save0,'slice_'+str(j)+'_'+str(inds)+'_gt.tif'),np.uint8(omnipose.utils.rescale(pic)*(2**8-1)))
            plt.imshow(mgt[slc],interpolation='none')
            pic = cmap(img[slc])
            # pic[...,-1] = imgs[j][slc]
            tifffile.imwrite(os.path.join(save0,'slice_'+str(j)+'_'+str(inds)+'_img.tif'),np.uint8(omnipose.utils.rescale(pic)*(2**8-1)))
            # plt.imshow(pic)
            # plt.imshow(mnc[slc],interpolation='none')
            plt.axis('off')
            plt.show()

### Plot in 3D
Plot for figure 7 depicting ground truth distance field and how Euler integration works given a perfect flow prediction. 

In [None]:
# im = imgs[0]
# mask = masks_gt[0]
name = 'Movie3_T00004'
name = 'Movie2_T00016'
# name = 'Movie2_T00002'
im = tifffile.imread('/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small_all/train/'+name+'_crop_gt.tif')
mask = tifffile.imread('/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small_all/train/'+name+'_crop_gt_masks.tif')
mask = (mask-1).clip(0,mask.max()-1) #format 
dim = mask.ndim


In [None]:
# Select a single cell for the demonstration and clean up its mask a little with
# two iterations of binary opening (the closing-based variant is kept below, commented out).
from scipy.ndimage import binary_closing,binary_opening
# from skimage.morphology import ball, closing
cell_ID = 26  # label of the demonstration cell in `mask`
cellmask = binary_opening(mask==cell_ID,iterations=2)
# cellmask = closing(mask==cell_ID,ball(11))
# viewer = napari.view_labels(cellmask, depiction='volume')
# viewer.add_labels(mask==cell_ID, name='original',depiction='volume')

In [None]:
_, dt, sd, mu = omnipose.core.masks_to_flows(cellmask,use_gpu=True,dim=dim,omni=True)

In [None]:
viewer = napari.view_image(omnipose.utils.rescale(sd), name='distance',gamma=1,depiction='volume', visible=0, scale=ZYX)
# viewer.add_image(omnipose.utils.rescale(dt==1), name='boundary',visible=True)
mnc = ncolor.label(mask,max_depth=11,offset=1)

viewer.add_labels(mnc, name='labels_gt_ncolor',visible = 0, color=sinebow(mnc.max()))

In [None]:
mnc_nocell = mnc.copy()
mnc_nocell[mask==cell_ID] = 0
viewer.add_labels(mnc_nocell, name='labels_gt_ncolor_nocell',visible = 1)

In [None]:
cmap = plt.get_cmap(name='gray')
colors=cmap(np.linspace(0,1,mnc.max()))
colordict = {0:[0,0,0,0]}
N = mnc.max()
for j in range(N): 
    colordict.update({j+1:cmap(((j+1)/N)*.6 +.2)})

# viewer.layers['labels_gt_ncolor'].color = colordict
viewer.layers['labels_gt_ncolor_nocell'].color = colordict
gray_cmap = ListedColormap(list(colordict.values()))
# mnc.max()
# sinebow(mnc.max())
# dict(colors)
colordict

In [None]:
viewer.camera
# viewer.add_labels(mask==cell_ID, name='cellmask_gt',visible = 1,color=sinebow(2))

In [None]:
viewer.add_labels(mask==cell_ID, name='cellmask_gt',visible = 1,color=sinebow(1),blending='opaque')

In [None]:
viewer.dims.ndisplay = 3
# viewer.camera.center = [s//2 for s in im.shape]
viewer.camera.zoom=2.4
# viewer.camera.angles=(49.01724873267338, -59.486485045019705, 36.815397248657874)
viewer.camera.center = np.mean(np.where(mask==cell_ID),axis=1)
viewer.camera.angles = (-87.82189877549096, 4.423733296006766, 32.36094416005523)
viewer.camera.perspective = 0.0
viewer.camera.interactive = True

In [None]:
viewer.scale_bar.visible = True
viewer.scale_bar.unit = "um"

save slices showing how the cell is placed relative to the others

In [None]:
save0 = os.path.join('/home/kcutler/DataDrive/omnipose_paper/PlantSeg','slices')
io.check_dir(save0)

slices = [tuple([27,slice(None),slice(None)]),tuple([slice(None),slice(None),315])]
smap = sinebow(mnc.max()+1)
smap = ListedColormap(list(smap.values()))
cmap = plt.get_cmap(name='viridis')
x = np.linspace(0,1,1000)
colors=cmap(1-x)
colors[...,-1] = x
cmap = {
    'colors': colors,
    'name': 'custom',
    'interpolation': 'linear'
}
cmap = ListedColormap(list(colors))

gmap = plt.get_cmap(name='gray')

for slc in slices:
    mgtnc = mnc[slc]
    # pic = gmap(mgtnc)
    # pic = omnipose.utils.rescale(np.stack([mgtnc]*4,axis=-1))*.6 +.2
    pic = gray_cmap(mgtnc)
    
    print(pic.shape)
    pic[...,-1] = mgtnc>0
    tifffile.imwrite(os.path.join(save0,'slice_'+str(slc)+'_gt.tif'),np.uint8(omnipose.utils.rescale(pic)*(2**8-1)))
    
    cm = mask[slc]==cell_ID
    pic = smap(cm)
    pic[...,-1] = cm>0
    tifffile.imwrite(os.path.join(save0,'slice_'+str(slc)+'_gt_cell'+str(cell_ID)+'.tif'),np.uint8(omnipose.utils.rescale(pic)*(2**8-1)))
    # plt.imshow(mgt[slc],interpolation='none')
    plt.imshow(pic,interpolation='none')

    
    
    pic = cmap(omnipose.utils.normalize99(im[slc]**.2))
    # pic[...,-1] = imgs[j][slc]
    tifffile.imwrite(os.path.join(save0,'slice_'+str(slc)+'_img.tif'),np.uint8(omnipose.utils.rescale(pic)*(2**8-1)))
    # plt.imshow(pic)
    # plt.imshow(mnc[slc],interpolation='none')
    plt.axis('off')
    plt.show()

# io.imsave(os.path.join(basedir,'3D_'+name+'_seg.png'),img[slc])

In [None]:
plt.imshow(mask[slc]==26)

In [None]:
# viewer.dims.ndisplay = 3
# # viewer.camera.center = [s//2 for s in im.shape]
# viewer.camera.zoom=2
# # viewer.camera.angles=(49.01724873267338, -59.486485045019705, 36.815397248657874)
# viewer.camera.center = (63.78328046536635, 57.61679613498146, 240.23463821506314)
# viewer.camera.angles = (-93.20291090101367, -7.291412290975056, -89.99473338269273)
# viewer.camera.perspective = 0.0
# viewer.camera.interactive = True


In [None]:

cmap = plt.get_cmap(name='magma')
x = np.linspace(0,1,100)
# colors=cmap(1-x)
colors = cmap(x)
colors[...,-1] = x
# colors[0,...,-1] = 0
new_colormap = {
    'colors': colors,
    'name': 'custom',
    'interpolation': 'linear'
}
for key in ['distance']:
    viewer.layers[key].visible = 1
    viewer.layers[key].colormap = new_colormap
    viewer.layers[key].rendering='mip'#'translucent'
    viewer.layers[key].contrast_limits=[0,1]

In [None]:
# viewer.add_image(omnipose.utils.rescale(im), name='cells',gamma=.2,depiction='volume', visible=0)
cmap = plt.get_cmap(name='viridis')
x = np.linspace(0,1,1000)
colors=cmap(1-x)
colors[...,-1] = x
cmap = {
    'colors': colors,
    'name': 'custom',
    'interpolation': 'linear'
}
for key in ['cells']:
    viewer.layers[key].colormap = cmap
    viewer.layers[key].rendering='mip'#'translucent'
    viewer.layers[key].contrast_limits=[0,1]

In [None]:
# add scale bars
n_axes = 3
vectors = np.zeros((n_axes, 2, 3), dtype=np.float32)
x_px = np.array([1, 0, 0])
y_py = np.array([0, 1, 0])
z_pz = np.array([0, 0, 1])
vectors[0, 1] = x_px
vectors[1, 1] = y_py
vectors[2, 1] = z_pz
for k in range(im.ndim):
    viewer.add_vectors(vectors[k], name='ax'+str(k), edge_width=1,edge_color='k',blending='translucent_no_depth', length=bar_10[k])

In [None]:
viewer.add_vectors(vectors[k], name='ax'+str(k), edge_width=1,edge_color='k',blending='translucent_no_depth', length=im.shape[k])

### Show Euler integration 

In [None]:
dists = dt.copy()

# inds = np.array(np.nonzero(mask==cell_ID)).astype(np.int32)
inds = np.array(np.nonzero(cellmask)).astype(np.int32)

masks_pred, p, tr = omnipose.core.compute_masks(mu,dists,inds=inds,calc_trace=True,use_gpu=1,
                                           omni=1,niter=200,interp=1,mask_threshold=-1,verbose=True,
                                          flow_threshold = 0.0, min_size=50, dim=3, cluster=1)

In [None]:
mnc = ncolor.label(masks_pred,max_depth=11)
viewer.add_labels(mnc, name='labels_pred',visible = True, color=sinebow(mnc.max()))

In [None]:
import edt
dist = edt.edt(masks_pred)
inds = np.array(np.nonzero(masks_pred)).astype(np.int32)

In [None]:
# bd_inds = np.array(np.nonzero(dists[tuple(inds)]==1))
bd_inds = np.hstack(np.argwhere(dist[tuple(inds)]==1)) #equivalent, better shape
tr[:,bd_inds,:].shape

In [None]:
bdy_inds = np.hstack(np.argwhere(dist[tuple(inds)]>0))
bdy_inds.shape

In [None]:
# Build a napari-style track table from the Euler-integration trace `tr`, one
# track per boundary pixel in `bd_inds`.
# Row layout: [ID, time step, coordinates..., total path length, cumulative path length, speed]
# NOTE(review): assumes tr is indexed as (coordinate, pixel, step) - TODO confirm.
tracks = []
for ID in bd_inds:
# for ID in bdy_inds:
# for ID in np.random.choice(bd_inds,len(bd_inds)//1):

    track = tr[:,ID,:].T
    times = np.arange(0,track.shape[0])[...,np.newaxis]
    IDs = np.ones_like(times)*ID

    # Per-step displacement along each coordinate, then its Euclidean norm as speed.
    g = np.stack([np.gradient(track[:,k]) for k in range(track.shape[1])],axis=1)
    speed = np.sum(g**2,axis=1)**0.5
    speed = speed[:,np.newaxis]
    distance = np.sum(speed)
    # Named total_dist rather than `dist` so the distance-transform array `dist`
    # from the earlier cell is not clobbered and those cells can be re-run.
    total_dist = np.ones_like(times)*distance
    cumdist = np.cumsum(speed)[:,np.newaxis]

    tracks.append(np.concatenate([IDs,times,track, total_dist, cumdist, speed],axis=-1))

tracks = np.concatenate(tracks, axis=0)
# Per-vertex features; the last three columns are addressed from the end of each row.
features = {
    'time': tracks[:, 1],
    'dist': tracks[:, -3],
    'cumdist': tracks[:, -2],
    'speed': tracks[:, -1]
}

In [None]:
g = np.stack([np.gradient(track[:,k]) for k in range(track.shape[1])],axis=1)

In [None]:
# track.shape, g.shape, 
# speed = np.sum(g**2,axis=1)**0.5
# distance = np.sum(speed)
# distance
viewer.layers.translate(0,0,0,1)

In [None]:
viewer.add_tracks(tracks[:,:5],colormap='red',features=features, color_by='speed', blending='translucent',tail_length=1000)

In [None]:
points = tracks[:, 1:5]
point_properties = {
    'color': 1,
    'dist': tracks[:, -3],
    'speed': tracks[:, -1]
}

points_layer = viewer.add_points(
    points,
    properties=point_properties,
    face_colormap='red',
    face_color='dist',
    # face_color_by='dist',
    # face_color='color',
    # face_color_cycle=['red', 'green'],
    edge_width=0,
    size=1, blending='translucent_no_depth'
)

In [None]:
# capture extra wide view of masks

figdir = '/home/kcutler/DataDrive/omnipose_paper/current/Figure 6'

# toggle all layers off
for layer in  viewer.layers:
    layer.visible = False

# toggle specific one(s) on 
layerlist = ['labels_pred','labels_gt_ncolor_nocell']
for name in layerlist:
    viewer.layers[name].visible = True

# set slicing windiw
S = 2000
c = S//2
w = 400
slc = tuple([slice(None),slice(c-w,c+w)])


img = viewer.screenshot(size=(S,S),scale=1,canvas_only=True,flash=False)
io.imsave(os.path.join(figdir,'red_select.png'),img[slc])

In [None]:
# capture view of distance

figdir = '/home/kcutler/DataDrive/omnipose_paper/current/Figure 6'

# toggle all layers off
for layer in  viewer.layers:
    layer.visible = False

# toggle specific one(s) on 
layerlist = ['distance']
for name in layerlist:
    viewer.layers[name].visible = True

# set slicing windiw
S = 2000
c = S//2
w = 200
slc = tuple([slice(None),slice(c-w,c+w)])


img = viewer.screenshot(size=(S,S),scale=1,canvas_only=True,flash=False)
io.imsave(os.path.join(figdir,'distance.png'),img[slc])

In [None]:
gifdir = '/home/kcutler/DataDrive/omnipose_paper/current/Figure 6/gifs'
io.check_dir(gifdir)

#toggle all layers off
for layer in  viewer.layers:
    layer.visible = False
# toggle tspecific one(s) on 
layerlist = ['points','distance']
for name in layerlist:
    viewer.layers[name].visible = True

# set slicing windiw
S = 2000
c = S//2
w = 200
slc = tuple([slice(None),slice(c-w,c+w)])

for k in range(tr.shape[-1]):
    viewer.dims.current_step = (k,)+viewer.dims.current_step[1:]
    img = viewer.screenshot(size=(S,S),scale=1,canvas_only=True,flash=False)
    io.imsave(os.path.join(gifdir,str(k)+'.png'),img[slc])
    

In [None]:


img = viewer.screenshot(size=(S,S),scale=1,canvas_only=True,flash=False)
plt.imshow(img[slc])

In [None]:
from napari_animation import Animation
animation = Animation(viewer)
for k in range(tr.shape[-1]):
    viewer.dims.current_step = (k,)+viewer.dims.current_step[1:] 
    animation.capture_keyframe(steps=1)
animation.animate('/home/kcutler/DataDrive/omnipose_paper/Figure 7/demo3.gif', canvas_only=True)

In [None]:
img = viewer.screenshot(size=(4000,)*2,scale=1,canvas_only=True,flash=False)

In [None]:
basedir = '/home/kcutler/DataDrive/omnipose_paper/Figure 7/tracks'
io.check_dir(basedir)
io.imsave(os.path.join(basedir,'3D_'+str(cell_ID)+'_tracks.png'),img)

In [None]:
viewer.add_labels(masks_pred, name='single',visible = False)

In [None]:
m = np.zeros_like(dists,dtype=int)
m[tuple(inds)] = 1
viewer.add_labels(m, name='single_gt',visible = False)

### Run Cellpose3D

In [None]:
# Cellpose seems to need training with  --chan 2 --chan2 1 
modeldir = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small/cellpose/train/models/cellpose_residual_on_style_on_concatenation_off_train_2022_05_09_01_43_03.829853_epoch_499'
omni = False
rescale = False
diam_mean = 0
from cellpose import models, core
use_GPU = core.use_gpu()
model = models.CellposeModel(gpu=use_GPU, pretrained_model=modeldir, net_avg=False, 
                             diam_mean=diam_mean)


In [None]:
mask_threshold = 0 ##############
net_avg = 0
verbose = 0 #####

tile = 0
chans = [0,0]
# chans = None
compute_masks = 1
# n = [0]
rescale = None

masks_cp, flows_cp, _ = model.eval(imgs,channels = chans,rescale=rescale,
                                  mask_threshold=mask_threshold, net_avg=net_avg,
                                  transparency=True, flow_threshold=0., verbose=verbose,tile=1,
                                  compute_masks=compute_masks, do_3D=True, omni=False)

### Run Omnipose 

In [None]:
del model
#best
modeldir = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_08_01_56_56.310115_epoch_1951'
# no improvement over cp somehow 
# modeldir = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_08_01_56_56.310115_epoch_2951'
# try stepping back; now a ral improvement over 1951
modeldir = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_08_01_56_56.310115_epoch_2901'
# 2901 is great, 3051 is ok but not as strong, so try 3001 - also not as good as 1951
# modeldir = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_08_01_56_56.310115_epoch_3001'
#3551 also not great, nor is 3501, but 3401 is terrific, 3451 awful, 3201 meh despite low loss, 3351 bad, 3201 bad, 3251 pretty good, 3151 not great,3301 bad,
# modeldir = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_08_01_56_56.310115_epoch_3401'
# modeldir = '/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_small/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_08_01_56_56.310115_epoch_3301'

dim = 3
nclasses = dim+2
nchan = 1
omni = 1
rescale = False
diam_mean = 0
from cellpose import models, core
# use_GPU = core.use_gpu(3)
# print('>>> GPU activated? %d'%use_GPU)
use_GPU = 0
model = models.CellposeModel(gpu=use_GPU, pretrained_model=modeldir, net_avg=False, 
                             diam_mean=diam_mean, nclasses=nclasses, dim=dim, nchan=nchan)#,device=torch.device('cuda:1'))
model.pretrained_model

In [None]:
import torch
torch.cuda.empty_cache()
mask_threshold = -5 ##############
diam_threshold = 12
net_avg = 0
cluster = 0 ##########
verbose = 1
tile = 0
chans = None
compute_masks = 1
resample=False
rescale=None
omni=True

nimg = len(imgs)
masks_om, flows_om = [[]]*nimg,[[]]*nimg
for k in range(nimg):
    masks_om[k], flows_om[k], _ = model.eval(imgs[k],channels=None,rescale=rescale,mask_threshold=mask_threshold,net_avg=net_avg,
                                      transparency=True,flow_threshold=0.,omni=omni,resample=resample,verbose=verbose,
                                      diam_threshold=diam_threshold,cluster=cluster,tile=1,
                                      compute_masks=compute_masks,flow_factor=10)
#     else:

In [None]:
# [i.shape for i in imgs]


In [None]:
# combine data
masks = [masks_cp,masks_om]
flows = [flows_cp,flows_om]

In [None]:
from cellpose import io
basedir = '/home/kcutler/DataDrive/omnipose_paper/Figure 7'
io.check_dir(basedir)

names = ['cellpose','omnipose']

In [None]:
# now quanititative comparison 
from stardist import matching
a = 0.5
b = 1.0
x = np.arange(a,b,(b-a)/100)
fig = plt.figure()
c = ['r','g']
JI = []

# confogure plot
sz = 1
golden = (1 + 5 ** 0.5) / 2
labelsize = 7

n = len(names)
z = 1
master_color_scheme = [[i,0,0] for i in np.linspace(1,.5,z)]+[[i,i,i] for i in np.linspace(.75,.25,n-z)]

darkmode = 0
if darkmode:
    plt.style.use('dark_background')
    axcol = 'w'
    colors = sinebow(n+1)
    colors = [colors[j+1] for j in range(n)]
    background_color = 'k'
    suffix = '_dark'
else:
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    colors = master_color_scheme
    # colors = sinebow(n+1)
    # colors = [colors[j+1] for j in range(n)]
    background_color = np.array([1,1,1,1])
    suffix = ''

mpl.rcParams['figure.dpi'] = 300
fig = plt.figure(figsize=(sz, sz)) 
ax = plt.axes()

plt.xticks(np.arange(min(x), max(x)+1, .25))
plt.xlim([0.5,1])
plt.ylim([0,.4])


# for i,name in enumerate(names):
#     met =[matching.matching((mgt-1).clip(0,mgt.max()-1),m, thresh=x) for mgt,m in zip(masks_gt, masks[i])]
#     a = np.array([[met[k][i].accuracy for i in range(len(x))] for k in range(len(met))]).T
#     np.save(os.path.join(basedir,name+'_3D_plants_individual'),a)
#     a = np.mean(a,axis=1)
#     JI.append(a)
#     plt.plot(x,a, label=name,c=colors[i],linewidth=.5)#+['masks_cp'])
#     np.save(os.path.join(basedir,name+'cellpose_3D_plants'),a)
    # plt.legend

# File name and ground-truth cell count (background excluded) for each loaded volume.
cats = [getname(m) for m in mask_names]
cellnum = np.array([len(np.unique(m))-1 for m in masks_gt])

# Groups of volume indices to average over, with their legend names.
# Alternative groupings tried previously:
# groups = [[0,1,-2,-1],[2]];     group_names = ['early','late']
# groups = [[0,1],[2],[-2,-1]];   group_names = ['movie1_4_6','movie1_30','movie2_10_20']
# groups = [[0,1,2],[-2,-1]];     group_names = ['movie1','movie2']
# groups = [[0,1,2,-2,-1]];       group_names = ['all']
#
# Fix: the 'all' group previously read [0,1,2,-2,1], which counted volume 1 twice
# and left out the last volume (-1).
groups = [[0,1,2,-2,-1],[2]]
group_names = ['all','late']

n = len(groups)
colors = sinebow(n+1)
colors = np.array([colors[j+1] for j in range(n)])
style = ['--','-']
for i,name in enumerate(names):

    a = np.load(os.path.join(basedir,name+'_3D_plants_individual.npy'))
    for k,group in enumerate(groups):
        cn = [cellnum[g] for g in group]
        subcats = '\n'.join([cats[g] for g in group])
        sub_a = np.stack([a[:,g] for g in group])
        # y = np.sum(sub_a.T*cn,axis=1)/sum(cn) #weighting by cell number 
        y = np.mean(sub_a.T,axis=1)
        
        plt.plot(x,y,style[i],c=colors[k],linewidth=.5,label=group_names[k])#+['masks_cp'])
        
    y = a.copy()
    a = np.sum(a*cellnum,axis=1)/sum(cellnum)
    # a = np.mean(a,axis=1)
    JI.append(a)
    # plt.plot(x,a, label=name,c=colors[i],linewidth=.5)#+['masks_cp'])
    # np.save(os.path.join(basedir,name+'cellpose_3D_plants'),a)
    
ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False,bbox_to_anchor=(1.05, 1))
ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)

ax.set_ylabel('Jaccard Index', fontsize = labelsize)
ax.set_xlabel('IoU matching threshold', fontsize = labelsize)


ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)

fig.patch.set_facecolor(background_color)
plt.show()

figname = 'JI_vs_IoU'
for ext in ['.pdf','.png','.eps']:
    fig.savefig(os.path.join(basedir,figname+suffix+ext),bbox_inches="tight",pad_inches = 0.05)

In [None]:
mask_names

In [None]:
# Index into the threshold grid that splits the curve into a lower and an upper part.
# NOTE(review): if the grid is the one built in the plotting cell (0.5 to 1.0 in
# steps of 0.005), index 80 is an IoU threshold of 0.9, not 0.8 - confirm the intended cutoff.
cutoff = 80 # IoU cutoff of 0.8

cp = JI[0].T
om = JI[1].T

def _mean_rel_change(new, ref):
    """Mean of (new - ref) / ref, counting entries where ref is 0 or NaN as 0."""
    usable = np.logical_and(ref!=0,~np.isnan(ref))
    return np.mean(np.divide((new-ref),ref,out=np.zeros_like(ref), where=usable))

av_all = _mean_rel_change(om, cp)                    # whole curve
av_up = _mean_rel_change(om[cutoff:], cp[cutoff:])   # from the cutoff index upward
av_dwn = _mean_rel_change(om[:cutoff], cp[:cutoff])  # below the cutoff index

# Last value: relative change of the curve means (not the mean of relative changes).
overall = (np.mean(om)-np.mean(cp))/np.mean(cp)
print(av_all*100,av_up*100,av_dwn*100,overall*100)

In [None]:
import random
random.seed(42)

k = 0
viewer = napari.view_image(imgs[k], name='cells',visible = False)
viewer.dims.ndisplay = 3

mnc = ncolor.label(masks_gt[k],max_depth=11)
viewer.add_labels(mnc, name='labels_gt',visible = False, color=sinebow(mnc.max()))


# cmap = sinebow(masks[0][k].max())
# l = list(cmap.items())[1:]
# random.shuffle(l)
# cmap = [(0,cmap[0])]+l
# cmap = dict(cmap)
mnc = ncolor.label(masks[0][k],max_depth=20)
cmap =  sinebow(mnc.max())
viewer.add_labels(mnc, name=names[0],visible = False, color=cmap)



mnc = ncolor.label(masks[1][k],max_depth=11)
viewer.add_labels(mnc, name=names[1], color=sinebow(mnc.max()))

In [None]:
viewer.camera.center = [s//2 for s in imgs[k].shape]
viewer.camera.zoom=2
viewer.camera.angles=(49.01724873267338, -59.486485045019705, 36.815397248657874)
viewer.camera.perspective=0.0
viewer.camera.interactive=True

In [None]:
k = 0
j = 1
img = imgs[k]
n_axes = 3
vectors = np.zeros((n_axes, 2, 3), dtype=np.float32)
x_px = np.array([1, 0, 0])
y_py = np.array([0, 1, 0])
z_pz = np.array([0, 0, 1])
vectors[0, 1] = x_px
vectors[1, 1] = y_py
vectors[2, 1] = z_pz
for k in range(img.ndim):
    viewer.add_vectors(vectors[k], name='ax'+str(k), edge_width=1,edge_color='k',blending='translucent_no_depth', length=img.shape[k])

In [None]:
# turn off all masks
layerlist = names+['labels_gt']
for name in layerlist:
    viewer.layers[name].visible = False

for name in layerlist:
    viewer.layers[name].visible = True

    img = viewer.screenshot(size=(2000,2000),scale=1,canvas_only=True,flash=False)
    viewer.layers[name].visible = False

    io.imsave(os.path.join(basedir,'3D_'+name+'_seg.tiff'),img)

In [None]:
viewer.add_image(omnipose.utils.rescale(flows_om[0][2]), name='distance',gamma=0.2,attenuation=0.05,depiction='volume')

In [None]:
cmap = plt.get_cmap(name='viridis')
x = np.linspace(0,1,100)
colors=cmap(1-x)
colors[...,-1] = x
new_colormap = {
    'colors': colors,
    'name': 'custom',
    'interpolation': 'linear'
}
key = 'distance'
viewer.layers[key].colormap = new_colormap
viewer.layers[key].rendering='attenuated_mip'
viewer.layers[key].contrast_limits=[.01,1]

In [None]:
# Plot cross sections and save some output for the figure 

cmap = mpl.cm.get_cmap('viridis')
for i,name in enumerate(names):
    print(name,i)
    for j in range(len(imgs)):
        img = imgs[j]
        mgt = masks_gt[j]
        msk = masks[i][j]
        dt = flows[i][j][2]
        bd = flows[i][j][4]
        frgb = flows[i][j][0]
        k = img.shape[0]//2 # middle slice 
        fig = plt.figure()
        print(np.array(msk).shape,k)
        mgtnc = ncolor.label(mgt[k])
        plt.imshow(np.hstack(tuple(map(cmap,map(omnipose.utils.rescale,[img[k],dt[k], bd[k] ,
                                                                        ncolor.label(msk[k]),
                                                                        mgtnc])))
                             +(frgb[k]/255,)))
        
        tifffile.imwrite(os.path.join(basedir,'slice_'+name+'_'+str(j)+'_frgb.tif'),frgb[k])
        if i:
            smap = sinebow(4)
            smap = ListedColormap(list(smap.values()))
            pic = smap(mgtnc)
            pic[...,-1] = mgtnc>0
            tifffile.imwrite(os.path.join(basedir,'slice_'+str(j)+'_gt.tif'),np.uint8(omnipose.utils.rescale(pic)*(2**8-1)))


        plt.axis('off')
        

In [None]:
### Plot some ground truth 

In [None]:
5

In [None]:
# Quick matplotlib 3D voxel rendering of the first volume in `imgs`.
# NOTE(review): Axes3D.voxels presumably treats its argument as a filled /
# not-filled mask, so a raw intensity volume may need thresholding — confirm.
from mpl_toolkits import mplot3d


fig = plt.figure()
ax = plt.axes(projection='3d')
ax.voxels(imgs[0])

In [None]:
# Volume-render the first image with ipyvolume (light style, axes and box
# hidden). The commented lines are earlier experiments kept for reference.
# np.unique(sd)
# plt.imshow(sd[-1])
import ipyvolume as ipv
# ipv.figure()
# ipv.quickvolshow(sd)

import ipyvolume.pylab as p3
p3.figure()
p3.style.use('light')
p3.style.axes_off()
p3.style.box_off()
# vol = p3.volshow(np.log(T+1e-12))
# p3.volshow(ncolor.label(seg))
# points = np.nonzero(dt==1)
# m = np.max(mask)-100
# points = np.nonzero((mask>m-5)*(dt==1)*(mask<=m))
# select = np.logical_or.reduce([mask == i for i in np.unique(mask[n])[1:]])
# points = np.nonzero((dt==1)*select)


# mg = np.meshgrid(*points,indexing='ij')
# ax.quiver(*mg,*tuple(mu),length=0.1)
# sub_mu = [mu[(i,)+points] for i in range(mu.shape[0])]
# ax.quiver(*mg,*sub_mu,normalize=True)
# sub_p =  [p[(i,)+points] for i in range(p.shape[0])]
# p3.scatter(*sub_p, size=1, marker="sphere",origin=np.mean(sub_p,axis=1))
# The +.1 offset keeps the log finite where the image is zero.
vol = p3.volshow(np.log(imgs[0]+.1),lighting=0)
# ipv.quickvolshow(np.log(imgs[0]+1))
# vol.ray_steps =20


p3.show()

In [None]:
import numpy as np
import plotly.graph_objects as go

# Volume-render the first test image (normalized with normalize99) in plotly.
np.random.seed(0)  # kept from the original cell; nothing random is drawn here
vol = omnipose.utils.normalize99(imgs[0])
shape = vol.shape
print(shape,vol.min(),vol.max())
X, Y, Z = np.mgrid[:shape[0], :shape[1], :shape[2]]
# The original cell then ran `vol = np.zeros((l, l, l))`, which raised a
# NameError (`l` is undefined) and would have replaced the image with zeros.
# That line is dropped so the normalized image is what gets rendered.

# NOTE(review): the opacityscale breakpoints below span [-0.5, 0.5]; check
# them against the value range printed above.
fig = go.Figure(data=go.Volume(
    x=X.flatten(), y=Y.flatten(), z=Z.flatten(),
    value=vol.flatten(),
    # isomin=0.0,
    # isomax=0.99,
    opacity=0.9,
    # surface_count=1,
    opacityscale=[[-0.5, 1], [-0.2, 0], [0.2, 0], [0.5, 1]],
    surface_count = 30
    ))
fig.update_layout(scene_xaxis_showticklabels=False,
                  scene_yaxis_showticklabels=False,
                  scene_zaxis_showticklabels=False)
fig.show()

In [None]:
# Stand-alone plotly volume demo on a synthetic sin(XYZ)/(XYZ) field; it does
# not use any notebook data.
import plotly.graph_objects as go
import numpy as np
# 40 samples per axis on [-8, 8]: an even count, so 0 is not a grid point and
# the division below never hits X*Y*Z == 0.
X, Y, Z = np.mgrid[-8:8:40j, -8:8:40j, -8:8:40j]
values = np.sin(X*Y*Z) / (X*Y*Z)

fig = go.Figure(data=go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=values.flatten(),
    isomin=0.1,
    isomax=0.8,
    opacity=0.1, # needs to be small to see through all surfaces
    surface_count=17, # needs to be a large number for good volume rendering
    ))
fig.show()

In [None]:
# Try the mayavi notebook backend with mlab's built-in test plot; the
# commented lines are an earlier volume/contour experiment.
from mayavi import mlab
mlab.init_notebook()
# x, y, z = np.ogrid[-10:10:20j, -10:10:20j, -10:10:20j]
# s = np.sin(x*y*z)/(x*y*z)
# mlab.pipeline.volume(mlab.pipeline.scalar_field(omnipose.utils.normalize99(s)),vmin=0, vmax=0.8)
# mlab.contour3d(s)
mlab.test_plot3d()

In [None]:
# Plot some slices: for every entry of `names` and every volume in `imgs`,
# show the middle slice along axis 0 as viridis-mapped panels (image, dt, bd,
# predicted labels, ground-truth labels) followed by the RGB flow image.

# mpl.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the supported accessor already used elsewhere in this file.
cmap = plt.get_cmap('viridis')
for i,name in enumerate(names):
    print(name,i)
    for j,img in enumerate(imgs):
        mgt = masks_gt[j]
        msk = masks[i][j]
        dt = flows[i][j][2]
        bd = flows[i][j][4]
        frgb = flows[i][j][0]
        k = img.shape[0]//2 # middle slice
        fig = plt.figure()
        print(np.array(msk).shape,k)
        panels = tuple(map(cmap,map(omnipose.utils.rescale,[img[k],dt[k],bd[k],
                                                            ncolor.label(msk[k]),
                                                            ncolor.label(mgt[k])])))
        plt.imshow(np.hstack(panels+(frgb[k]/255,)))
        plt.axis('off')

In [None]:
basedir

In [None]:
# plt.imshow(msk)
msk.dtype

In [None]:
im = tifffile.imread('/home/kcutler/DataDrive/plantseg/traintest/LateralRootPrimordia/export_good/test/Movie1_t00030_crop_gt.tif')

In [None]:
# Open `im` in napari as a 3D volume with a fixed camera, then add one vectors
# layer per image axis.
# viewer.axes.visible = False
# viewer.overlays.interaction_box.show_handle=1
viewer = napari.view_image(omnipose.utils.rescale(im), name='cells',gamma=0.2,attenuation=0.05,depiction='volume')
viewer.dims.ndisplay = 3
viewer.camera.center = [s//2 for s in im.shape]
viewer.camera.zoom=.7
viewer.camera.angles=(49.01724873267338, -59.486485045019705, 36.815397248657874)
viewer.camera.perspective=0.0
viewer.camera.interactive=True
n_axes = 3
# vectors[k, 0] stays at zero and vectors[k, 1] is set to the k-th unit vector.
vectors = np.zeros((n_axes, 2, 3), dtype=np.float32)
x_px = np.array([1, 0, 0])
y_py = np.array([0, 1, 0])
z_pz = np.array([0, 0, 1])
vectors[0, 1] = x_px
vectors[1, 1] = y_py
vectors[2, 1] = z_pz
# Each axis vector is drawn with length equal to the image extent on that axis.
for k in range(im.ndim):
    viewer.add_vectors(vectors[k], name='ax'+str(k), edge_width=1,edge_color='k',blending='translucent_no_depth', length=im.shape[k])

In [None]:
viewer.layers[0].contrast_limits=[.002,1]


In [None]:

# Style the first viewer layer: reversed viridis whose last (alpha) channel is
# a linear 0-to-1 ramp, rendered with attenuated MIP.
# colors = np.linspace(
#     start=[1, 1, 1, 0],
#     stop=[0, 0, 0, 1],
#     num=10,
#     endpoint=True
# )
cmap = plt.get_cmap(name='viridis')
x = np.linspace(0,1,100)
colors=cmap(1-x)
colors[...,-1] = x
new_colormap = {
    'colors': colors,
    'name': 'custom',
    'interpolation': 'linear'
}
viewer.layers[0].colormap = new_colormap
viewer.layers[0].rendering='attenuated_mip'
viewer.layers[0].contrast_limits=[.01,1]

# k = 0
# im = imgs[k]

In [None]:
# Take a 2000x2000 canvas-only screenshot of the viewer and save it as
# '3D.tiff' in the figure directory. This cell also (re)defines `basedir`,
# which other cells use as their output directory.
img = viewer.screenshot(size=(2000,2000),scale=1,canvas_only=True,flash=False)
from cellpose import io
basedir = '/home/kcutler/DataDrive/omnipose_paper/Figure 7'
io.check_dir(basedir)
io.imsave(os.path.join(basedir,'3D.tiff'),img)

In [None]:
# Display `img` inline without axes.
# Note that the rcParams assignment changes the figure dpi for later figures
# too, not just this one.
fig = plt.figure(figsize=(20,20))
plt.rcParams['figure.dpi'] = 300
plt.imshow(img)
plt.axis('off')

In [None]:
viewer.__dict__

In [None]:
len(masks_gt)

In [None]:
# For every volume, show the middle slice along axis 0 as viridis-mapped
# panels (image, dt, bd, predicted labels, ground-truth labels) followed by
# the RGB flow image. Here `masks` and `flows` are indexed per volume only.

# mpl.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the supported accessor already used elsewhere in this file.
cmap = plt.get_cmap('viridis')

for j,img in enumerate(imgs):
    msk = masks[j]
    dt = flows[j][2]
    bd = flows[j][4]
    frgb = flows[j][0]
    k = img.shape[0]//2 # middle slice
    fig = plt.figure()
    panels = tuple(map(cmap,map(omnipose.utils.rescale,[img[k],dt[k],bd[k],
                                                        ncolor.label(msk[k]),
                                                        ncolor.label(masks_gt[j][k])])))
    plt.imshow(np.hstack(panels+(frgb[k]/255,)))
    plt.axis('off')

In [None]:
# Now the quantitative comparison: stardist matching accuracy of `masks`
# against `masks_gt` over 100 thresholds in [0.5, 1.0).
from stardist import matching

a, b = 0.5, 1.0
thresh = np.arange(a, b, (b - a) / 100)
met = [matching.matching(mgt, m, thresh=thresh) for mgt, m in zip(masks_gt, masks)]

# Rows: thresholds, columns: volumes. Note that `a` is reused for the curves.
n_thresh = len(thresh)
a = np.array([[stats[i].accuracy for i in range(n_thresh)] for stats in met]).T
plt.plot(thresh, a)  # +['masks_cp'])
np.save(os.path.join(basedir, 'omnipose_3D_plants_individual'), a)

# Mean curve over volumes.
a = a.mean(axis=1)
np.save(os.path.join(basedir, 'omnipose_3D_plants'), a)


In [None]:
# Reload the per-volume and mean accuracy curves saved under `basedir` for
# each model name.
# NOTE(review): only the 'omnipose_*' files are written in the visible cells;
# the 'cellpose_*' files are presumably produced by a separate run — confirm.
names = ['cellpose','omnipose']
JI = [np.array(np.load(os.path.join(basedir,name+'_3D_plants_individual.npy'))) for name in names]
JI_mean = [np.array(np.load(os.path.join(basedir,name+'_3D_plants.npy'))) for name in names]

In [None]:
JI[0].shape

In [None]:
# Plot the per-volume curves of both models against `thresh`.
# Because JI[k] is 2D, each plot call draws one line per column while passing
# a single label.
plt.figure()
plt.plot(thresh,JI[0],label=names[0])#+['masks_cp'])
plt.plot(thresh,JI[1],label=names[1])#+['masks_cp'])
plt.legend()

In [None]:
masks_all,

In [None]:
plt.imshow(masks[0])

In [None]:
# Load one HDF5 file (the 11th found in savedir[1]) and read its 'label' and
# 'raw' datasets into `masks`/`labels` and `img`.
# Relies on h5py, fastremap, getname and savedir being defined in other cells.
# d = omnipose.core.diameters(masks_gt)
# d
image_names = io.get_image_files(savedir[1],extensions=['h5'])
filename = image_names[10]
with h5py.File(filename, "r") as f:
        # List all groups
        print("Keys: %s" % f.keys())
        masks = np.array(f['label'])
        # Printing the minimum shows whether negative labels are present.
        print(masks.min())
        masks = omnipose.utils.format_labels(masks)
        name = getname(filename)
        print(name)
        # The return value is discarded; this relies on in_place=True actually
        # modifying `masks` — TODO confirm against fastremap's behaviour.
        fastremap.renumber(masks,in_place=True)
        labels = fastremap.refit(masks)
        img = np.array(f['raw'])

In [None]:
img.shape

In [None]:
import edt
dt = edt.edt(labels)

In [None]:
# np.max(masks),np.max(labels),np.min(masks),np.min(labels),filename
omnipose.core.diameters(masks,dt)

In [None]:
# Show index k along axis 1 of the image, the n-colored labels and the
# distance transform side by side. `rescale` is re-imported here because other
# cells rebind that name to a non-function value.
from omnipose.utils import rescale 
k = 100
plt.imshow(np.hstack(tuple(map(rescale,(img[:,k],ncolor.label(labels[:,k]),dt[:,k])))))
# plt.imshow(dt[:,500])

In [None]:
6

In [None]:
viewer = napari.view_image(img, name='cells')
viewer.add_labels(labels, name='labels')

In [None]:
# len(flows)
# Scratch check: doubled and halved values of 0..3 as two separate lists.
a = [l*2 for l in range(4)]
b = [l/2 for l in range(4)]
a,b

In [None]:
from omnipose.core import diameters
diameters(masks[slc])

In [None]:
# Scratch check that a symmetric np.pad can be undone by slicing.
pad = 4
a = np.ones((10, 10))
apad = np.pad(a, pad)
# slice(pad, -pad) would be empty for pad == 0, hence the full-slice fallback.
if pad:
    crop = slice(pad, -pad)
else:
    crop = slice(None, None)
unpad = (crop,) * a.ndim
apad[unpad].shape, a.shape, apad.shape

In [None]:
np.pad(a,None)

In [None]:
import fastremap
m,_ = fastremap.renumber(masks,in_place=False)
np.max(m)

In [None]:
import edt
dt = edt.edt(m[slc][50])
dt

In [None]:
im.shape

In [None]:
# modeldir = '/home/kcutler/DataDrive/3D_BBBC/BBBC027/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_03_18_58_51.851134_epoch_301'
import tifffile
im = tifffile.imread('/home/kcutler/DataDrive/3D_BBBC/BBBC027/test/0000.tif')

In [None]:
masks_gt = tifffile.imread('/home/kcutler/DataDrive/3D_BBBC/BBBC027/test/0000_masks.tif')
# d = omnipose.core.diameters(masks_gt)
# d
im.shape

In [None]:
# Build the 3D CellposeModel used for evaluation.
# The epoch_101 checkpoint was assigned and then immediately overwritten in
# the original cell; it is kept as a commented alternative.
# modeldir = '/home/kcutler/DataDrive/3D_BBBC/BBBC027/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_03_18_58_51.851134_epoch_101'

modeldir = '/home/kcutler/DataDrive/3D_BBBC/BBBC027/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_05_03_22_16_13.356647_epoch_2601'

dim = 3
nclasses = dim+2
nchan = 1
omni = 1
# The original cell also set `rescale = False`. It was never used here and
# rebound the name of the `rescale` function imported from omnipose.utils,
# so later bare `rescale(...)` calls failed until it was re-imported.
diam_mean = 0
# NOTE(review): this import rebinds the notebook-level name `core` to
# cellpose.core.
from cellpose import models, core
use_GPU = core.use_gpu()

# print('>>> GPU activated? %d'%use_GPU)
# use_GPU = 0 # uh oh, 3D models not working for CPU
# use_GPU = 0
model = models.CellposeModel(gpu=use_GPU, pretrained_model=modeldir, net_avg=False,
                             diam_mean=diam_mean, nclasses=nclasses, dim=dim, nchan=nchan)

In [None]:
# Run the model on a crop of `im` and keep masks, flows and styles.
mask_threshold = -5 ##############
diam_threshold = 12
net_avg = 0
cluster = 1
verbose = 0 #####

# NOTE(review): `tile` is set to 0 here but never passed on; the eval call
# below hard-codes tile=1. Confirm which value is intended.
tile = 0
# chans = [0,0]  (was assigned and immediately overwritten in the original)
chans = None
compute_masks = 1
# Drops the last plane along axis 0.
slc_crop = slice(None, -1, None)
# slc_crop = slice(100, 120, None)
# slc_crop = slice(13, 53, None) # 733
# slc_crop = slice(50, 100, None) # 704
# slc_crop = slice(15, 52, None) # 707
# slc_crop = slice(30, 70, None) # 712
# slc_crop = (Ellipsis,)+tuple([slice(0,129)]*2)
# slc_crop = (Ellipsis,)+tuple([slice(0,400)]*2)

resample=False
# The original stored this value in a variable called `rescale`, which rebound
# the `rescale` function imported from omnipose.utils. A separate name keeps
# that function callable in other cells.
# rescale_factor = .5 #None
rescale_factor = None

imcrop = im[slc_crop]
# pad = 30
# imcrop = np.pad(imcrop,pad,mode='reflect')
# imcrop = im[100:120]
# imcrop = im.copy()
# import torch
#TILING not working in 3D, should fix that
masks, flows, styles = model.eval(imcrop,channels = chans,rescale=rescale_factor,mask_threshold=mask_threshold,net_avg=net_avg,
                                  transparency=True,flow_threshold=0.,omni=omni,resample=resample,verbose=verbose,
                                  diam_threshold=diam_threshold,cluster=cluster,tile=1,
                                  compute_masks=compute_masks,batch_size=2)

In [None]:
# Scratch check of (index, item) pairing; note it rebinds `i` and `a`.
for i, a in zip(range(2), ('a', 'b')):
    print(i, a)

In [None]:
# Show slice k of the cropped image, dt, bd and masks (viridis-mapped) next to
# the RGB flow image from the single-volume eval output.
k = 50
dt = flows[2]
bd = flows[4]
# mpl.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the supported accessor already used elsewhere in this file.
cmap = plt.get_cmap('viridis')

plt.imshow(np.hstack(tuple(map(cmap,map(omnipose.utils.rescale,[imcrop[k],dt[k],bd[k],masks[k]])))+(flows[0][k]/255,)))

In [None]:
# Accuracy-vs-threshold curve of `newmasks` against the cropped ground truth.
from stardist import matching

a, b = 0.5, 1.0
thresh = np.arange(a, b, (b - a) / 100)
# One-element list so the indexing matches the multi-volume version.
met = [matching.matching(masks_gt[slc_crop], newmasks, thresh=thresh)]
# gotta make sure precision is jaccard index!!!!
n_thresh = len(thresh)
accuracy = np.array([[stats[i].accuracy for i in range(n_thresh)] for stats in met]).T
plt.plot(thresh, accuracy)  # +['masks_cp'])
# plt.legend()

In [None]:
imgi = im.copy()[np.newaxis]

In [None]:
# Prototype of N-dimensional tiling: compute per-axis tile sizes, tile counts
# and start offsets for `imgi` (channel axis first), and allocate the tile
# buffer. The commented code is the 2D version this generalizes.
tile_overlap = 0.1
bsize = 224

nchan = imgi.shape[0]
shape = imgi.shape[1:]
dim = len(shape)
# Clamp the overlap fraction to [0.05, 0.5].
tile_overlap = min(0.5, max(0.05, tile_overlap))
# bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
# B = [np.int32(min(b,s)) for s,b in zip(im.shape,bsize)] if bsize variable
# Tile size per axis: bsize, or the whole axis when it is shorter.
bbox = tuple([np.int32(min(bsize,s)) for s in shape])

# tiles overlap by 10% tile size
# Number of tiles per axis (1 when the axis fits in a single tile).
ntyx = [1 if s<=bsize else int(np.ceil((1.+2*tile_overlap) * (s / bsize))) for s in shape]
# Evenly spaced start offsets so the last tile ends at the axis end.
start = [np.linspace(0, s-b, n).astype(int) for s,b,n in zip(shape,bbox,ntyx)]

# `sub` is allocated but not filled in this cell.
sub = [[] for n in range(dim)]
# IMG = np.zeros((len(ystart), len(xstart), nchan,  bsizeY, bsizeX), np.float32)
IMG = np.zeros(tuple([len(s) for s in start])+(nchan,)+bbox, np.float32)

# for j in range(len(ystart)):
#     for i in range(len(xstart)):
#         ysub.append([ystart[j], ystart[j]+bsizeY])
#         xsub.append([xstart[i], xstart[i]+bsizeX])
#         IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1],  xsub[-1][0]:xsub[-1][1]]
IMG.shape, bbox,ntyx, np.diff(start[1]),

In [None]:
int(np.ceil((1.+2*tile_overlap) * (shape[1] / bsize)))

In [None]:
start

In [None]:
# Build, per axis, the list of tile slices from the start offsets, then take
# their Cartesian product to get one slice tuple per tile.
import itertools
# Slices use bsize rather than the per-axis bbox size; on an axis shorter than
# bsize the slice simply runs past the end and is clipped by indexing.
intervals = [[slice(si,si+bsize) for si in s] for s in start]
# intervals = []
# for s in start:
#     intervals.append([slice(si,si+bsize) for si in s])
#     # if len(s)>1:
#     #     intervals.append([slice(si,si+bsize) for si in s])
#     # else:
    #     intervals.append([slice(None)])
subs = list(itertools.product(*intervals))
len(subs)

In [None]:
# Crop every tile out of `imgi` (all leading axes kept) and collect them in I.
# NOTE(review): the loop reuses the notebook-global names `slc` and `imcrop`,
# which other cells also assign or read (the eval cell assigns `imcrop`), so
# running this cell changes what those cells see — confirm that is intended.
I = []
for slc in subs:
    imcrop = imgi[(Ellipsis,)+slc]
    print(imcrop.shape)
    I.append(imcrop)

In [None]:
I2 = np.stack(I)
I2.shape

In [None]:
# viewer = napari.view_image(I[0], name='cells')
I[0].shape


In [None]:
from cellpose.transforms import make_tiles
IMG, ysub, xsub, Ly, Lx = make_tiles(im[0][np.newaxis])

In [None]:
IMG.shape,Ly,Lx,ysub

In [None]:
viewer = napari.view_image(imcrop, name='cells')
viewer.add_labels(masks, name='labels')

In [None]:
# viewer.add_labels(masks_gt[slc_crop], name='labels_gt')
viewer.add_image(bd, name='bd')
viewer.add_image(dt, name='dt')

In [None]:
viewer.add_labels(ncolor.label(masks_gt[slc_crop]), name='labels_gt')


In [None]:
viewer.add_image(flows[0], name='flows')

In [None]:
# Recompute masks from the network output stored in `flows`, to compare
# clustering back-ends.
# NOTE(review): `hdbscan` is an int flag here, while another cell runs
# `import hdbscan`; whichever cell ran last decides what the name refers to.
mu = flows[1]
dt = flows[2]
p = flows[3]
bd = flows[4]
hdbscan = 1
cluster = 0
newmasks,p,_,lab= omnipose.core.compute_masks(mu,dt,bd=bd,mask_threshold=mask_threshold,
                                           use_gpu=True,verbose=True,cluster=cluster, 
                                           flow_threshold=0., dim=3, hdbscan=hdbscan, 
                                              debug=1, eps=3, min_size=30, nclasses=4)

In [None]:
# Add the n-colored recomputed masks as a labels layer named 'DBSCAN' or
# 'HDBSCAN' depending on the integer flag `hdbscan`.
alg = ['','H']
mnc = ncolor.label(newmasks)
viewer.add_labels(mnc,name='{}DBSCAN'.format(alg[hdbscan]),color=sinebow(mnc.max()))
# 17.91925573348999 for hdbscan, 10.16240119934082 for dbscan ; then 25 vs 9 later 

In [None]:
# Show the entries of `p` at all foreground pixels as a napari points layer,
# coloured by the n-color label of the pixel they belong to.
np.unique(newmasks)
# The palette is sized by newmasks.max() but indexed with the n-color labels
# in `mnc` below.
color = sinebow(newmasks.max())
# [color[l]]
selection = newmasks>0
# One 1D coordinate array per axis of p, restricted to foreground pixels.
sub_p = [p[i].ravel()[selection.ravel()] for i in range(p.shape[0])]
# sub_col = [color[l] for l in lab[selection.ravel()]]
sub_col = [color[l] for l in mnc[mnc>0]]
data = np.stack(sub_p,axis=-1)
viewer.add_points(np.round(data),size=1,symbol='disc',face_color=sub_col,edge_color=sub_col)

In [None]:
# Cluster the selected points directly with HDBSCAN and count distinct labels.
# NOTE(review): this import rebinds `hdbscan`, which another cell uses as an
# int flag (`hdbscan = 1`).
import hdbscan
clusterer = hdbscan.HDBSCAN(min_cluster_size=400,min_samples=3)
labels = clusterer.fit_predict(np.array(sub_p).T)
len(np.unique(labels))

In [None]:
clusterer = hdbscan.HDBSCAN(cluster_selection_epsilon=2**0.5,min_samples=3, allow_single_cluster=True)
clusterer.fit(np.stack(sub_p,axis=1))

In [None]:
np.unique(clusterer.labels_)
# np.sqrt(2)

In [None]:
clusterer.condensed_tree_.plot()

In [None]:
# Cluster the same points with scikit-learn's DBSCAN for comparison.
from sklearn.cluster import DBSCAN
eps=1.33
db = DBSCAN(eps=eps, min_samples=5, n_jobs=-1).fit(np.array(sub_p).T)

In [None]:
len(np.unique(db.labels_))

In [None]:
# np.prod(p.shape)
# p[0].ravel().shape,(points.ravel()).shape
%timeit p[0].ravel()[points.ravel()]
# p[0][points].ravel()

In [None]:
%timeit p[0][points].ravel()

In [None]:
# import datashader as ds, pandas as pd, colorcet
# !pip install vaex-jupyter
import vaex
# df = vaex.from_arrays(z=p[(0,+)],y=p[1],x=p[2])

In [None]:
# Collect the entries of `p` at pixels where bd > 0 into a dict keyed
# 'z', 'y', 'x' (one 1D array per axis of p).
axes = [i for i in 'zyx']
points = bd>0
pdict = dict(zip(axes, [p[i].ravel()[points.ravel()] for i in range(p.shape[0])]))

In [None]:
df = vaex.from_dict(pdict)
df

In [None]:
import vaex.jupyter.model as vjm
await vaex.jupyter.gather()

In [None]:
# df.viz.scatter(df.x, df.y, selection=df.z==40, c="red", alpha=0.5, s=4)
%matplotlib inline
plt.figure()
df.viz.heatmap([["x", "y"], ["x", "z"], ["y", "z"]], title="Face on and edge on", figsize=(8, 4), limits='100%')

In [None]:
# Faceted vaex heatmap.
# NOTE(review): z="FeH:-3,-1,8" refers to a column named FeH, but the
# DataFrame built in this notebook (vaex.from_dict(pdict)) only has the keys
# 'z', 'y', 'x' — this looks like a leftover from a vaex example; confirm.
df.viz.heatmap("z", "y", z="FeH:-3,-1,8",
               visual=dict(row="z"),
               figsize=(12, 8),
               f="log",
               wrap_columns=3,
               limits='99%');


In [None]:
# Compute flows for the label array `mask` with omnipose.
# `mask` is not defined in the nearby cells — it must already exist in the
# notebook namespace.
omni = 1
masks, dists, T, mu = omnipose.core.masks_to_flows(mask,use_gpu=1,omni=omni,dim=mask.ndim)

In [None]:
from omnipose.core import divergence
from omnipose.utils import rescale
from omnipose import utils
from cellpose import plot
from scipy.ndimage import gaussian_filter
# Smooth each flow component spatially for display. Axis 0 of `mu` indexes the
# flow components (this notebook reads the dimension count from mu.shape[0]),
# so it gets sigma 0. The original scalar sigma=.5 also blurred along that
# axis and mixed the components into one another.
mu_f = gaussian_filter(mu,sigma=(0,)+(.5,)*(mu.ndim-1))

In [None]:
%matplotlib inline
fig = plt.figure(figsize=[20]*2)

if mask.ndim==2:
    plt.imshow(plot.dx_to_circ(mu))
else:
    # plot flow and dist slices 
    from omnipose.utils import rescale
    d = mu.shape[0]
    c = np.array([1]*2+[0]*(d-2))
    # c = np.arange(d)
    def cyclic_perm(a):
        n = len(a)
        b = [[a[i - j] for i in range(n)] for j in range(n)]
        return b
    slices = []
    idx = np.arange(d)
    cmap = mpl.cm.get_cmap('magma')
    # pic = cmap(np.sqrt(heat))
    # pic[:,:,-1] = masks>0
    for inds in cyclic_perm(c):
        slc = tuple([slice(-1) if i else mu.shape[k+1]//2 for i,k in zip(inds,idx)])
        flow = rescale(plot.dx_to_circ(mu_f[np.where(inds)+slc],transparency=1))
        # dist = cmap(rescale(T)[slc])
        vals = rescale(np.log(T+1)[slc]) if omni else rescale(np.log(T+1e-5)[slc])
        dist = cmap(vals)
        dist[:,:,-1]=vals
        
        
        vals = utils.normalize99(divergence(mu_f))[slc]
        div = cmap(vals)
        div[:,:,-1]=vals
        div = utils.normalize99(div)
        
        
        msk = rescale(np.stack([mask[slc]]*4,axis=-1))>0
        dist[:,:,-1]=mask[slc]>0
        fig = plt.figure(figsize=[80]*2)
        plt.imshow(np.hstack((msk,flow,dist,div)),interpolation='none')
        plt.axis('off')
        plt.show()

In [None]:
# Run the training augmentation once on the cropped image and its masks.
# tyx is the requested output size passed to random_rotate_and_resize.
from cellpose import transforms
tyx = (50,)*3
rs = None
omni = 1
#[imgs[k][np.newaxis,:,:] if imgs[k].ndim==2 else imgs[k] for k in n], [flows[k] for k in n]
img,lbl,sc = transforms.random_rotate_and_resize([imcrop[np.newaxis]],[masks],
                                                 tyx=tyx,
                                                 omni=omni,
                                                 rescale=rs,
                                                 dim=masks.ndim)


In [None]:

# Unpack the j-th augmented sample into its channels. The channel meanings
# below follow the variable names only — TODO confirm against
# random_rotate_and_resize.
from omnipose.utils import rescale
from omnipose.core import sigmoid

j = 0
img0 = img[j].squeeze()
mask0 = lbl[j][0]
bin0 = lbl[j][1]
bd0 = lbl[j][2]
dist0 = lbl[j][3]
weight0 = lbl[j][4]
# Remaining channels: one flow component per dimension (d below).
mu0 = lbl[j][5:]

d = mu0.shape[0]
# Flags for the two displayed axes; cyclic shifts give the other planes.
c = np.array([1]*2+[0]*(d-2))
# c = np.arange(d)
def cyclic_perm(a):
    """Return every cyclic shift of sequence ``a`` as a list of lists.

    Row ``j`` is ``a`` rotated right by ``j`` positions, so row 0 is ``a``
    itself and an empty input yields an empty list.
    """
    n = len(a)
    shifts = []
    for j in range(n):
        # Rotating right by j moves the last j items to the front.
        split = n - j
        shifts.append(list(a[split:]) + list(a[:split]))
    return shifts
# For each axis pair, show a central cross-section of the augmented sample as
# [mask | binary | flow | distance | boundary | weight | divergence].
slices = []
idx = np.arange(d)
# mpl.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap is the supported accessor already used elsewhere in this file.
cmap = plt.get_cmap('magma')
# The divergence does not depend on the cross-section: compute it once.
div_vals = utils.normalize99(divergence(mu0))

for inds in cyclic_perm(c):

    # Displayed axes use slice(-1) (everything but the last index); the
    # remaining axes are fixed at their midpoint.
    slc = tuple([slice(-1) if i else mu0.shape[k+1]//2 for i,k in zip(inds,idx)])
    flow = rescale(plot.dx_to_circ(mu0[np.where(inds)+slc],transparency=1))
    # dist = cmap(rescale(T)[slc])
    vals = rescale(dist0)[slc] if omni else rescale(np.log(dist0+1e-5)[slc])
    dist = cmap(vals)
    dist[:,:,-1]=vals


    vals = div_vals[slc]
    div = cmap(vals)
    div[:,:,-1]=vals
    div = utils.normalize99(div)

    bin_ = cmap(bin0[slc])
    bd = cmap(bd0[slc])
    weight = cmap(rescale(weight0)[slc])

    msk = cmap(rescale(ncolor.label(mask0[slc])))
    fig = plt.figure(figsize=[80]*2)
    plt.imshow(np.hstack((msk,bin_,flow,dist,bd,weight,div)),interpolation='none')
    plt.axis('off')
    plt.show()

In [None]:
ncolor.label(mask0).shape

In [None]:
plt.imshow(bd)

In [None]:
lbl.shape