# Figure 6 and Extended Data Figure 8
Here we pull out a representative example from the cyto2 dataset. Cellpose and Omnipose really 'trade blows', so to speak: there are plenty of cases where one fails and the other succeeds, generally where cells overlap or are otherwise ambiguous. Our performance metrics show roughly equivalent performance.

In [None]:
%load_ext autoreload
%autoreload 2
# (IPython magics above: automatically reload edited modules before each cell runs.)

# First, import dependencies.
import numpy as np
import time, os, sys
from cellpose import models, core, utils, io
import skimage.io
import omnipose

# This checks to see if you have set up your GPU properly.
# CPU performance is a lot slower, but not a problem if you are only processing a few images.
use_GPU = core.use_gpu()
print('>>> GPU activated? %d'%use_GPU)

# for plotting (dark background style, inline figures rendered at 300 dpi)
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import matplotlib as mpl
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300

In [None]:
# Locate the cyto2 test images and their ground-truth label files.
# files = ['/home/kcutler/DataDrive/cyto2/test/005_img.tif']
basedir = '/home/kcutler/DataDrive/cyto2/test/'
mask_filter = '_masks'
img_filter = '_img'
img_names = io.get_image_files(basedir, mask_filter, img_filter)
# Depending on the cellpose version, get_label_files returns either the list of
# label paths or a (label_names, flow_names) tuple; keep only the label paths so
# that mask_names is always a list aligned with img_names.
_label_files = io.get_label_files(img_names, mask_filter, img_filter)
mask_names = _label_files[0] if isinstance(_label_files, tuple) else _label_files

In [None]:
# Optionally restrict to images whose first two planes (indexed along axis 0)
# both contain signal; by default every image/mask pair is used.
# NOTE(review): assumes the channel axis is first in the raw files — confirm.
multi_only = 0
if not multi_only:
    files = img_names
    maskfiles = mask_names
else:
    files, maskfiles = [], []
    for img_path, mask_path in zip(img_names, mask_names):
        arr = skimage.io.imread(img_path)
        has_both = np.any(arr[0]) and np.any(arr[1])
        if has_both:
            files.append(img_path)
            maskfiles.append(mask_path)
            print(arr.shape)

In [None]:
# def getname(path,suffix='_masks'):
#     return os.path.splitext(Path(path).name)[0].replace(suffix,'')
# names = [getname(path) for path in mask_names]
# select = [5,27]
# files = [path for path,name in zip(img_names,names) if any('%03d' % (n,)  in name for n in select)]
# files

In [None]:
# Read every selected image and its ground-truth label image into memory.
imgs = list(map(skimage.io.imread, files))
masks_gt = list(map(skimage.io.imread, maskfiles))


In [None]:
# Per-image diameter estimate from the ground-truth masks; the last line shows
# the mean over the dataset and the number of images.
diameters = list(map(omnipose.core.diameters, masks_gt))
np.mean(diameters), len(diameters)

The images in cyto2 generally have cytoplasm in channel 1 (red) and nuclei in channel 2 (green), regardless of how the underlying data was actually acquired.

In [None]:
from cellpose import io, transforms, plot

# Number of loaded images.
nimg = len(imgs)
print(nimg)

# Reorder axes with transforms.move_min_dim and normalize intensities with
# normalize99; imgs is updated in place so later cells see the normalized images.
# NOTE(review): the original note says the channel axis is moved last, but later
# cells index channels on axis 0 (e.g. imgs[i][1]) — confirm the resulting layout.
for k, img in enumerate(imgs):
    img = transforms.move_min_dim(img)
    imgs[k] = transforms.normalize99(img, omni=True)
    print(imgs[k].shape, k)

In [None]:
# Display labels for the models compared below; model[j] is labelled model_name[j].
# NOTE(review): the second label reads 'cyto2_omni_bit loss' while the model
# actually loaded is model_type='cyto2_omni' — confirm which one is intended.
model_name = ['cyto2','cyto2_omni_bit loss']
L = len(model_name)

# Locally trained checkpoints, kept for reference. None of them is loaded below
# (only the built-in model types are used); swap in the commented line to use one.
# Earlier checkpoints:
#   '/home/kcutler/DataDrive/cyto2/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_04_15_00_39_49.881936_epoch_301'  (no size model)
#   '/home/kcutler/DataDrive/cyto2/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_04_16_01_24_45.606751_epoch_2501'
# Latest: trained with the big-cell fixes and rescaling to 35 instead of 30, closer
# to the mean diameter of the dataset; still need to validate that szmean is loaded.
modeldir = '/home/kcutler/DataDrive/cyto2/train/models/cellpose_residual_on_style_on_concatenation_off_omni_train_2022_04_22_17_55_58.018802_epoch_2601'

# model = [models.CellposeModel(gpu=use_GPU, model_type='cyto2'), models.CellposeModel(gpu=use_GPU, pretrained_model=modeldir)]
model = [models.CellposeModel(gpu=use_GPU, model_type='cyto2'),
         models.CellposeModel(gpu=use_GPU, model_type='cyto2_omni', diam_mean=36)]

In [None]:
# Channel setting per model (selected through `ind` below).
# chans = [[2,3],[2,1]] # green cytoplasm and blue nucleus; cellpose documentation is confusing about this
chans = [[1,2],[2,1]] # cyto2 makes more sense

# Images (indices into imgs) to segment. Alternatives tried:
# range(nimg), [-2], [-4], range(10,15), [-1]; [-5] is a bad example.
n = [1]

# define parameters (per-option lists are indexed like omni/ind below)
mask_threshold = [0,0,-1]
verbose = 0 # turn on if you want to see more output
use_gpu = use_GPU # defined above
transparency = True # transparency in flow output
rescale = None # give this a number if you need to upscale or downscale your images (not passed to eval below)
flow_threshold = 0 # default is .4, but only needed if there are spurious masks to clean up; slows down output
resample = True # whether or not to run dynamics on rescaled grid or original grid

# Segmentation options (index i):
#   0: pure cellpose  -> model[0] with cellpose dynamics
#   1: mixed          -> model[0] with omnipose dynamics
#   2: omnipose       -> model[1] with omnipose dynamics
N = L+1
omni = [0,1,1] # omnipose dynamics on/off per option
ind = [0,0,1]  # which entry of `model` each option uses
# Independent result slots per option ([[]]*N would alias one list N times).
masks = [[] for _ in range(N)]
flows = [[] for _ in range(N)]
styles = [[] for _ in range(N)]

In [None]:
import time

# Segment the images listed in `n` with the selected option(s), storing results
# in masks/flows/styles[i] and the elapsed wall-clock seconds in t[i].
t = [None] * N
# for i in range(N):   # run every option
chans = [[1,2],[2,1]] # channel settings re-declared for this run
# n = range(nimg)
flow_threshold = 0
for i in [-1]: # last option: model[ind[-1]] with omnipose dynamics
    start_time = time.perf_counter()
    masks[i], flows[i], styles[i] = model[ind[i]].eval(
        [imgs[k] for k in n],
        channels=chans[ind[i]],
        diameter=[diameters[k] for k in n],
        mask_threshold=mask_threshold[i],
        transparency=transparency,
        flow_threshold=flow_threshold,
        omni=omni[i], # toggle omni
        resample=resample,
        verbose=verbose,
        cluster=omni[i],
        interp=True,
    )
    t[i] = time.perf_counter() - start_time
    print(t[i])

In [None]:
# plt.imshow(plot.image_to_rgb(imgs[2]))
# Inspect the ground-truth diameter estimate for image 2.
diameters[2]

In [None]:
# imgs[n[0]-1].shape
# plt.imshow(omnipose.utils.normalize99(imgs[-5][0]))

In [None]:
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300
from cellpose import plot
import omnipose

for idx,i in enumerate(n):
    
    # for k,ki in enumerate(ind):
    for k in [-1]:
        ki = ind[k]
        print('model is:',model_name[ki],', omni is:',omni[ki])
        maski = masks[k][idx]
        flowi = flows[k][idx][0]
        print('m',maski.shape,imgs[i].shape,'chans',chans[i%2])
        fig = plt.figure(figsize=(12,5))
        # im = transforms.move_min_dim(imgs[i])
        # print(im.shape)
        # plot.show_segmentation(fig, omnipose.utils.normalize_image(imgs[i],1-masks_gt[i]>0,bg = .4), maski, flowi, channels=chans[i%2], omni=True, bg_color=0)
        if not np.any(imgs[i][1]):
            im = imgs[i][0]
        else:
            im = imgs[i]
        plot.show_segmentation(fig, im, maski, flowi, channels=chans[i%2], omni=1, bg_color=0)
        
        plt.tight_layout()
        plt.show()

In [None]:
# Save masks, flows, outlines and ncolor masks for the images in `n`, using the
# results of option k. plt.get_cmap replaces mpl.cm.get_cmap, which was
# deprecated in matplotlib 3.7 and removed in 3.9.
cmap = plt.get_cmap('plasma')
outline_col = cmap(0.85)[:3]
k = -1
# masks[k] / flows[k] hold one entry per image in `n`, in the same order, so they
# are indexed by position rather than by image index.
positions = range(len(n))
io.save_masks([imgs[i] for i in n], [masks[k][p] for p in positions], [flows[k][p] for p in positions], [files[i] for i in n],
              tif=True, #whether to use PNG or TIF format
              suffix='', # suffix to add to files if needed
              save_flows=True,
              save_outlines=True, # save outline images
              outline_col = outline_col,
              dir_above=0, # save output in the image directory or in the directory above (at the level of the image directory)
              in_folders=True, # save output in folders (recommended)
              save_txt=False, # txt file for outlines in imageJ
              save_ncolor=True) # save ncolor version of masks for visualization and editing

In [None]:
# Save masks and outlines for every image using option k's results. Images whose
# second channel is empty are passed as their first channel only.
n = range(nimg)
k = -1 # option whose results are saved
imlist = [imgs[i][0] if not np.any(imgs[i][1]) else imgs[i] for i in n]
# masks[k] / flows[k] are ordered like the image list they were computed on;
# index them by position (identical to image index when that list was range(nimg)).
positions = range(len(n))
io.save_masks(imlist, [masks[k][p] for p in positions], [flows[k][p] for p in positions], [files[i] for i in n],
              tif=True, #whether to use PNG or TIF format
              suffix='', # suffix to add to files if needed
              save_flows=0,
              save_outlines=True, # save outline images
              outline_col = outline_col,
              dir_above=0, # save output in the image directory or in the directory above (at the level of the image directory)
              in_folders=True, # save output in folders (recommended)
              save_txt=False, # txt file for outlines in imageJ
              save_ncolor=0, # save ncolor version of masks for visualization and editing
              omni=1)

In [None]:
# For every image, draw option k's mask outlines on the image, crop a square from
# the top-left corner (side = shorter image dimension) and save it as a PNG.
save0 = os.path.join('/home/kcutler/DataDrive/omnipose_paper/Figure S8/cyto2/crop_outlines')
io.check_dir(save0)
ext='.png'
n = range(nimg)
k = -1 # option whose masks are drawn

for idx, i in enumerate(n):
    # fall back to the first channel when the second one is empty
    img0 = imgs[i][0] if not np.any(imgs[i][1]) else imgs[i]

    mgt = masks[k][idx] # results are ordered like the evaluated image list
    # crop side gets its own name so the model count `L` is not overwritten
    side = min(mgt.shape)
    outli = plot.outline_view(img0,mgt,color=outline_col)#,mode='thick')
    crop_outli = outli[:side, :side]

    name = '%03d' % (i,)

    # save the outlines
    savepath = os.path.join(save0,name+'_crop_outlines'+ext)
    io.imsave(savepath,crop_outli)

In [None]:
# Inspect the shape of the last saved crop together with the current value of L.
crop_outli.shape,L

In [None]:
# Show the current value of basedir.
basedir

In [None]:
import scipy

# Quick check: upsample flows[k][idx][1] (mu) by 2 along the last two axes,
# leaving axis 0 unscaled, with linear interpolation; compare shapes.
mu = flows[k][idx][1]
zoom_factors = (1, 2, 2)
f = scipy.ndimage.zoom(mu, zoom_factors, order=1)
f.shape, mu.shape

In [None]:
# %%timeit 
a = 2.1
b = 1.9
# f = scipy.ndimage.zoom(mu, tuple([1,a,b]), order=1)
scipy.ndimage.zoom(imgs[0], a, order=1).shape,imgs[0].shape

In [None]:
%%timeit 
# Time zooming mu[0] and mu[1] separately by (a, b) and stacking the results.
a = 2.1
b = 1.9
# f = scipy.ndimage.zoom(mu, tuple([1,a,b]), order=1)
np.stack([ scipy.ndimage.zoom(mu[k], tuple([a,b]), order=1) for k in range(2)])

In [None]:
# Boundary output (element 4 of the flows entry) for option k, image position idx.
bd = flows[k][idx][4]
def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)), elementwise.

    exp() is only ever evaluated on non-positive arguments, so large negative
    inputs no longer trigger an overflow RuntimeWarning, and both tails keep
    full relative precision.

    Parameters
    ----------
    x : scalar or array_like
        Input values.

    Returns
    -------
    numpy scalar or ndarray
        Values in [0, 1] with the same shape as ``x``.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))  # in (0, 1]: cannot overflow
    out = np.where(x >= 0, 1 / (1 + z), z / (1 + z))
    return out[()]  # unwrap 0-d results to a scalar; arrays are returned as-is
# Show the boundary output mapped through a sigmoid.
plt.imshow(sigmoid(bd))

In [None]:
# Divergence of mu (element 1 of the flows entry) for option k, image position idx.
mu = flows[k][idx][1]
div = omnipose.core.divergence(mu)
plt.imshow(div)

In [None]:
%matplotlib widget
mpl.rcParams['figure.dpi'] = 100
# Interactive (widget backend) view of the ground-truth labels of image 1.
plt.imshow(masks_gt[1])

In [None]:
import ncolor
save0 = os.path.join('/home/kcutler/DataDrive/omnipose_paper/Figure 6/cyto2')
io.check_dir(save0)

# define colormaps (plt.get_cmap replaces mpl.cm.get_cmap, which was deprecated
# in matplotlib 3.7 and removed in 3.9)
cmap = plt.get_cmap('plasma')
outline_col = cmap(0.85)[:3]
dist_cmap = plt.get_cmap('plasma')
bd_cmap = plt.get_cmap('viridis')
ext = '.png'

k = -1 # k denotes the segmentation option: index into masks/flows/omni/name
pad = 10 # padding in pixels around the selected cells
labels = [34,36] # ground-truth labels of the cells to crop around
name = ['cellpose','mixed','omnipose'] # output folder per option
for idx,i in enumerate(n):
    print('model is:',name[k],', omni is:',omni[k])
    mgt = masks_gt[i]

    # Bounding box of the selected ground-truth cells, padded and clamped to the
    # image. Slice stops are exclusive, hence max + pad + 1.
    inds = np.nonzero(np.isin(mgt, labels))
    if inds[0].size == 0:
        print('labels', labels, 'not found in image', i, '- skipping')
        continue
    slc = tuple(slice(max(0, ax.min()-pad), min(size, ax.max()+pad+1))
                for ax, size in zip(inds, mgt.shape))

    # crop only the last two (spatial) axes, whatever leading axes the image has
    crop_img = imgs[i][(Ellipsis,)+slc]
    crop_masks = masks[k][idx][slc]
    crop_flow = flows[k][idx][0][slc]

    crop_outli = plot.outline_view(crop_img,crop_masks,color=outline_col)#,mode='thick')

    plt.imshow(crop_outli,interpolation='none')
    plt.axis('off')
    plt.show()

    # NOTE(review): file names depend only on the option, so each image in `n`
    # overwrites the previous one's outputs — intended for a single image.
    outdir = os.path.join(save0,name[k])
    io.check_dir(outdir)
    # save the cropped image, RGB uint8 is not interpolated in illustrator ;)
    # NOTE(review): assumes crop_img is 2D (single channel); a multichannel crop
    # would produce a 4D stack here — confirm.
    img0 = np.stack([crop_img,]*3,axis=-1)
    savepath = os.path.join(outdir,'crop_img'+ext)
    io.imsave(savepath,np.uint8(img0*(2**8-1)))

    # save the outlines
    savepath = os.path.join(outdir,'crop_outlines'+ext)
    io.imsave(savepath,crop_outli)

    # save the masks
    savepath = os.path.join(outdir,'crop_masks'+ext)
    io.imsave(savepath,np.uint8(crop_masks))

    # save the flows
    savepath = os.path.join(outdir,'crop_flows'+ext)
    skimage.io.imsave(savepath,np.uint8(crop_flow))

    # save the distance (alpha = inside a predicted mask)
    savepath = os.path.join(outdir,'crop_dist'+ext)
    dist = omnipose.utils.rescale(flows[k][idx][2][slc])
    pic = dist_cmap(dist)
    pic[:,:,-1] = crop_masks>0
    skimage.io.imsave(savepath,np.uint8(pic*(2**8-1)))

    # save the boundary (alpha = inside a predicted mask)
    savepath = os.path.join(outdir,'crop_bd'+ext)
    dist = omnipose.utils.rescale(flows[k][idx][4][slc])
    pic = bd_cmap(dist)
    pic[:,:,-1] = crop_masks>0
    skimage.io.imsave(savepath,np.uint8(pic*(2**8-1)))

    #save a grayscale version for adobe illustrator vectorization
    ncl = ncolor.label(crop_masks)
    grey_n = np.stack([1-omnipose.utils.rescale(ncl)]*3,axis=-1)
    savepath = os.path.join(outdir,'masks_gray'+ext)
    io.imsave(savepath,np.uint8(grey_n*(2**8-1)))

In [None]:
# additional examples from cyto2
# Indices of further candidate example images (not referenced by the cells shown here).
img_idx = [34,49,51,67,66,64]

In [None]:
# Re-collect the cyto2 test image paths together with their label and flow files.
# files = ['/home/kcutler/DataDrive/cyto2/test/005_img.tif']
basedir = '/home/kcutler/DataDrive/cyto2/test/'
mask_filter, imf = '_masks', '_img'
img_names = io.get_image_files(basedir, mask_filter, imf)
label_result = io.get_label_files(img_names, mask_filter, imf)
mask_names, flow_names = label_result

In [None]:
# Load one cyto2 test image and its ground-truth mask directly by path.
im = skimage.io.imread('/home/kcutler/DataDrive/cyto2/test/005_img.tif')
mask = skimage.io.imread('/home/kcutler/DataDrive/cyto2/test/005_masks.tif')

In [None]:
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300
# Draw the ground-truth mask outlines over the normalized raw image.
fig = plt.figure(figsize=(20,)*2)
plt.imshow(plot.outline_view(omnipose.utils.normalize99(im), mask))

# plt.imshow(plot.image_to_rgb(im,omni=True))

In [None]:
from skimage import color
import ncolor
from omnipose.utils import sinebow, normalize99

# Overlay the ncolor-relabelled ground-truth mask on a grayscale version of the
# image, using sinebow colors without the first entry.
palette = sinebow(5)
colors = np.array(list(palette.values()))[1:]
im = normalize99(transforms.move_min_dim(im))
# average over the last axis when the image has more than two dimensions
img0 = im if im.ndim <= 2 else im.mean(axis=-1)
overlay = color.label2rgb(ncolor.label(mask, max_depth=20), img0, colors, bg_label=0, alpha=1/3)


In [None]:
# Display the label overlay.
plt.imshow(overlay)

In [None]:
# Compare the image shape before and after transforms.reshape with channels=[1,2].
im.shape,transforms.reshape(im,channels=[1,2]).shape


In [None]:
# Compare image 1 before (left) and after (right) background-aware normalization,
# showing channels 0 and 1 as red and green with an empty blue channel.
i = 1
im0 = imgs[i]
# Background = ground-truth label 0. (The former `1-masks_gt[i]>0` wraps around
# for unsigned label arrays and then selects every label except 1.)
im = omnipose.utils.normalize_image(im0,masks_gt[i]==0,bg = .6)


def _as_rg(a):
    """Stack channels 0 and 1 of *a* as red/green with a zero blue channel."""
    return np.stack((a[0], a[1], np.zeros_like(a[0])), axis=-1)


plt.imshow(np.hstack((_as_rg(im0), _as_rg(im))))