# Cellpose vs Omnipose performance
This script runs Cellpose and Omnipose on various 2D datasets: bact_phase, bact_fluor, cyto2, and worm.
It is very long because I would like to reduce the repetition of the loading and analysis code (factoring it into functions is overkill and kind of infeasible, since there are some special cases sprinkled throughout). I suggest using the ToC to navigate. 

Extended Data Figures 4,5 look at the failure modes of Cellpose. Figure 3 compares Omnipose to Cellpose on phase. Fig. 4 compares Omnipose to Cellpose on fluorescence. 

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from cellpose import plot, models, io, dynamics
from omnipose import utils
import skimage.io
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import matplotlib as mpl
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300
import time, os, sys

## Define datasets and models 
We will run the same code on each of the following datasets. We begin by defining names (used for saving output) as well as the paths to the test datasets. Later we will define the paths to the models trained on the corresponding training datasets. 

In [None]:
# Each dataset label (used to name saved output) is kept next to the folder
# holding its test images, so the two lists below cannot drift out of order.
_dataset_dirs = (
    ('bact_phase', '/home/kcutler/DataDrive/omnipose_all/phase/test_sorted'),
    ('bact_fluor', '/home/kcutler/DataDrive/omnipose_all/fluor/test_sorted'),
    ('cyto2', '/home/kcutler/DataDrive/cyto2/test'),
    ('worm', '/home/kcutler/DataDrive/omnipose_train/worm_combined/test'),
)

datasets = [label for label, _ in _dataset_dirs]
basedirs = [folder for _, folder in _dataset_dirs]


## Load ground truth

In [None]:
# Index into `datasets` / `basedirs` choosing which dataset this run works on.
which = 0
# Truthy: re-run the models and recompute/save the metrics in the cells below.
# Falsy: those steps are skipped, but the experiment names are still defined.
clean = 0
dataset, basedir = datasets[which], basedirs[which]

mask_filter = '_masks'

# Only the cyto2 branch uses an '_img' suffix on its image files.
img_filter = '_img' if dataset=='cyto2' else ''

img_names = io.get_image_files(basedir, mask_filter, img_filter, look_one_level_down=True)
mask_names = io.get_label_files(img_names, mask_filter, img_filter)
# Obviously sorting the masks by corresponding base name is critical. The different suffixes mess this up, but sorted()
# can take a function like the one below as a key. Modifying the default suffix requires the lambda syntax (e.g. cellpose). 
def getname(path,suffix='_masks'):
    """Return the base name of *path* without its extension and without *suffix*.

    The directory part and the last file extension are dropped, then every
    occurrence of *suffix* is removed from what remains. Used as a sort key
    so that images and masks line up by their shared base name.
    """
    filename = Path(path).name
    stem, _extension = os.path.splitext(filename)
    return stem.replace(suffix, '')

# Put both file lists in base-name order, then read every file into memory.
mask_names, img_names = sorted(mask_names, key=getname), sorted(img_names, key=getname)
imgs = list(map(skimage.io.imread, img_names))
masks_gt = list(map(skimage.io.imread, mask_names))

print(dataset)

In [None]:
# Total number of labelled objects across all ground-truth masks
# (label values <= 0 are not counted).
cnt = 0
for mask in masks_gt:
    lbls = np.unique(utils.format_labels(mask))
    cnt += int(np.count_nonzero(lbls > 0))
print(cnt)

## Initialize models
The list of models I want to run depends on the dataset and the figure. The `bact_phase` dataset is the main dataset for the paper. The `bact_fluor` dataset is evaluated with Omnipose only because we are only interested in how Omnipose runs on it compared to the `bact_phase` dataset. 

In [None]:
from cellpose import core, models
net_avg=False

use_GPU = core.use_gpu()
print('>>> GPU activated? %d'%use_GPU)

# dataset name -> (model_type of every network to load, channel spec passed to eval)
_model_setup = {
    'bact_phase': (['bact_phase_cp', 'bact_phase_omni'], [0,0]),
    'bact_fluor': (['bact_fluor_cp', 'bact_fluor_omni'], [0,0]),
    # cyto2_omni needs to be updated to use my bit depth version; the channel
    # spec should be updated too with the new model and different training parameters
    'cyto2': (['cyto2', 'cyto2_omni'], [[1,2],[2,1]]),
    # pure worm cellpose models and omnipose as well as a combined bact+worm omnipose model
    # all trained for 3201 epochs
    'worm': (['worm_cp', 'worm_omni', 'worm_bact_omni'], [0,0]),
}

# An unrecognised dataset name leaves `model` and `chans` untouched.
if dataset in _model_setup:
    _model_types, chans = _model_setup[dataset]
    model = [models.CellposeModel(gpu=use_GPU, model_type=_mt) for _mt in _model_types]

## Run models

In [None]:
from cellpose.io import logger_setup
logger,log_file=logger_setup()

# Options shared by the model evaluations below (individual dataset branches
# may override some of them).
resample = True
transparency = True # save flows with transparency
verbose = False
cluster = True #use clustering algorithm for Omnipose 

imglist = imgs[:]  # shallow copy of the image list handed to eval
N = len(imgs)
# Method names for the bact_phase comparison; reused when plotting bact_fluor
# results next to the phase results.
names_phase = ['Cellpose','Mixed','Omnipose']
# Fixed the misspelling "datasest" in the status message.
print('dataset is {} consisting of {} images and {} labels. Clean flag is {}'.format(dataset,N,cnt,clean))
if dataset == 'bact_phase':
    names = names_phase
    omni = [False,True,True]
    n = len(names)
    masks = [[]]*n
    flows = [[]]*n
    styles = [[]]*n
    d = [[]]*n

    if clean:
        # Keyword arguments identical for all three evaluations.
        shared_kwargs = dict(channels=chans, rescale=None, mask_threshold=-1, flow_threshold=0,
                             resample=resample, tile=False, transparency=transparency,
                             verbose=verbose)

        # Cellpose 'bact', no omni reconstruction (no `cluster` argument here)
        masks[0], flows[0], styles[0] = model[0].eval(imglist, omni=omni[0], **shared_kwargs)
        # Cellpose network + omni reconstruction ('Mixed')
        masks[1], flows[1], styles[1] = model[0].eval(imglist, omni=omni[1], cluster=cluster,
                                                      **shared_kwargs)
        # Omnipose network + omni reconstruction
        masks[2], flows[2], styles[2] = model[1].eval(imglist, omni=omni[2], cluster=cluster,
                                                      **shared_kwargs)

if dataset == 'bact_fluor':
    names = ['Cellpose','Omnipose']
    n = len(names)
    masks = [[]]*n
    flows = [[]]*n
    styles = [[]]*n
    d = [[]]*n
    if clean: 
        cluster = False # both models perform better without clustering in this case

        # The model index doubles as the omni flag: 0 for the first network,
        # 1 for the second.
        for k, net in enumerate(model):
            masks[k], flows[k], _ = net.eval(imglist, channels=chans, rescale=None,
                                             mask_threshold=-1, flow_threshold=-1,
                                             omni=k, cluster=cluster, resample=resample,
                                             tile=False, transparency=transparency,
                                             verbose=verbose)


if dataset=='cyto2':
    names = ['Cellpose','Mixed','Omnipose']
    omni = [False,True,True]
    n = len(names)
    masks,flows,styles,d = [[]]*n,[[]]*n,[[]]*n,[[]]*n
    if clean:
        # The notebook only uses `from omnipose import ...` statements, which
        # never bind the name `omnipose`; import the submodule explicitly so
        # the `omnipose.core.diameters` call below does not raise NameError.
        import omnipose.core

        resample = False # both models perform better without resampling in this case
        diam_threshold = 30 # I specified this at one point, clustering is on anyway
        
        # One diameter per image, computed from that image's ground-truth mask.
        diameters = [omnipose.core.diameters(mask) for mask in masks_gt]
        
        # cyto2
        masks[0], flows[0], styles[0] = model[0].eval(imglist, channels=chans[0], diameter=diameters, mask_threshold=0, flow_threshold=0,
                                                      omni=omni[0], resample=resample, tile=False, transparency=transparency, 
                                                      verbose=verbose)
        # cyto2 + omni=True ('Mixed')
        masks[1], flows[1], styles[1] = model[0].eval(imglist, channels=chans[0], diameter=diameters, mask_threshold=0, flow_threshold=0,
                                                      omni=omni[1], cluster=cluster, resample=resample, tile=False, transparency=transparency, 
                                                      verbose=verbose, diam_threshold=diam_threshold)
        # cyto2_omni
        masks[2], flows[2], styles[2] = model[1].eval(imglist, channels=chans[1], diameter=diameters, mask_threshold=-1, flow_threshold=0,
                                                      omni=omni[2], cluster=cluster, resample=resample, tile=False, transparency=transparency, 
                                                      verbose=verbose, diam_threshold=diam_threshold)


if dataset=='worm':
    names = ['Cellpose','Mixed','Omnipose','Omnipose_worm+bact']
    omni = [False,True,True,True]
    n = len(names)
    masks = [[]]*n
    flows = [[]]*n
    styles = [[]]*n
    d = [[]]*n
    if clean:
        # Network used for each entry of `names`:
        #   0 -> worm_cp (no omni reconstruction), 0 -> worm_cp + omni=True ('Mixed'),
        #   1 -> worm_omni, 2 -> worm_bact_omni
        net_index = [0,0,1,2]
        for i, m in enumerate(net_index):
            masks[i], flows[i], styles[i] = model[m].eval(imglist, channels=chans, rescale=None,
                                                          mask_threshold=-1, flow_threshold=0,
                                                          omni=omni[i], cluster=cluster,
                                                          resample=resample, tile=False,
                                                          transparency=transparency, verbose=verbose)

In [None]:
# All output for the selected dataset lives under <root>/<dataset>_comparison.
output_root = '/home/kcutler/DataDrive/omnipose_all'
save0 = os.path.join(output_root, dataset+'_comparison')
io.check_dir(save0)
print('Clean flag is {}'.format(clean))

In [None]:
if clean:
    # Saving options that do not change from one model to the next.
    save_options = dict(save_flows=True,
                        save_outlines=0,
                        in_folders=True,
                        save_txt=False,
                        save_plot=False)
    # Each model writes its masks and flows into its own sub-folder of save0.
    for j in range(n):
        savedir = os.path.join(save0,names[j])
        print(savedir)
        io.check_dir(savedir)
        io.save_masks(imgs, masks[j], flows[j], img_names, savedir=savedir, **save_options)

## Read in results
This step will also delete small masks from image boundaries. 

In [None]:
import os
from omnipose import utils 

# Reload every model's saved masks from disk. Each mask is passed through
# clean_boundary and format_labels on the way in and stored as uint32.
for j in range(n):
    savedir = os.path.join(save0,names[j])

    mask_names = []
    for name in img_names:
        stem = os.path.splitext(os.path.basename(name))[0]
        mask_names.append(os.path.join(savedir,'masks',stem+'_cp_masks.png'))

    loaded = []
    for f in mask_names:
        cleaned = utils.clean_boundary(skimage.io.imread(f))
        loaded.append(utils.format_labels(cleaned).astype(np.uint32))
    masks[j] = loaded
    # Flows are not read back here (the saved flow files would be
    # <savedir>/flows/<stem>_flows.tif; png vs tif still needs sorting out).

In [None]:
from cellpose import metrics
from skimage import measure
# 100 evenly spaced IoU matching thresholds from 0.5 to 1 (both ends included);
# passed to metrics.average_precision and used as the x-axis of the JI plots.
threshold=np.linspace(0.5,1,100)
import fastremap

In [None]:
#Note that 'ap' is called average precision in cellpose, but really it is the jaccard index
# One independent accumulator list per model. `[[]] * len(masks)` would put the
# SAME inner list in every slot, so any in-place append/extend on one model's
# entry would silently show up for all models; the comprehension avoids that.
ap = [[] for _ in masks]
tp = [[] for _ in masks]
fp = [[] for _ in masks]
fn = [[] for _ in masks]
IoU = [[] for _ in masks]
OvR = [[] for _ in masks]
pred_areas = [[] for _ in masks]

In [None]:
# Apply the same boundary cleaning to the ground-truth masks and record the
# area of every remaining ground-truth region, image by image.
import ncolor 

nimg = len(masks_gt)
cell_areas = [[]] * nimg
masks_gt_clean = [None]*nimg
masks_pred_clean = [[None]*nimg]*len(masks)
for j, gt in enumerate(masks_gt):
    mgt = ncolor.format_labels(utils.clean_boundary(gt))
    regions = measure.regionprops(mgt)
    areas = np.array([reg.area for reg in regions])
    masks_gt_clean[j], cell_areas[j] = mgt, areas

In [None]:
# Display the smallest region area found in cell_areas across all images.
np.min(np.concatenate(cell_areas))

## Compute performance metrics for each model

In [None]:
# Compute, for every model, the thresholded matching metrics plus the raw
# per-image IoU and overlap-ratio matrices. Only runs when `clean` is truthy.
# NOTE: the accumulators are extended with `x[j] = x[j] + [...]`, which builds
# a new list each time; keep it that way (not `+=`/append) if the accumulators
# were created with `[[]] * n`, because those inner lists are shared.
if clean:
    for j,masks_pred in enumerate(masks):
    #     masks_pred = map(list,zip(*[ utils.format_labels(utils.clean_boundary(msk)) for msk in masks_pred]))
    # just apply cleanup to the masks when reading them in 
        # metrics over all images at every value in `threshold`
        api,tpi,fpi,fni = metrics.average_precision(masks_gt_clean,masks_pred,threshold=threshold)
        ap[j] = ap[j]+[api]
        tp[j] = tp[j]+[tpi]
        fp[j] = fp[j]+[fpi]
        fn[j] = fn[j]+[fni]

        masks_pred_clean[j] = masks_pred
        # go over every image
        for k in range(nimg):
            # get the IoU matrix; axis 0 corresponds to GT, axis 1 to pred 
            regions = measure.regionprops(masks_pred[k])
            # area of each predicted region, in regionprops order
            areas = np.array([reg.area for reg in regions])
            pred_areas[j] = pred_areas[j] + [areas]
            iou = metrics._intersection_over_union(masks_gt_clean[k], masks_pred[k])
            ovp = metrics._label_overlap(masks_gt_clean[k], masks_pred[k])[1:,1:] #throw out the first row and column (label zero)
            # Overlap Ratio: each column of the overlap matrix divided by the
            # area of the corresponding predicted region
            # (assumes regionprops order matches the columns of ovp -- TODO confirm)
            OvR[j] = OvR[j]+[ovp / areas[np.newaxis,:]] # Overlap Ratio           
            IoU[j] = IoU[j]+[iou]
 

In [None]:
save0 = os.path.join('/home/kcutler/DataDrive/omnipose_all',dataset+'_comparison')
io.check_dir(save0)


def _as_object_array(items):
    """Pack *items* into a 1-D object array holding one entry per element.

    OvR, IoU and cell_areas hold arrays whose shapes differ from image to
    image. Handing such ragged nested lists straight to np.savez makes NumPy
    try to build a regular array, which raises ValueError on NumPy >= 1.24.
    Filling a pre-allocated object array element by element stores every
    entry untouched, whatever its shape.
    """
    packed = np.empty(len(items), dtype=object)
    for idx, item in enumerate(items):
        packed[idx] = item
    return packed


if clean:
    # Indexing of the reloaded data is unchanged: OvR[j][k], IoU[j][k] and
    # cell_areas[k] still return the individual arrays.
    np.savez(os.path.join(save0,'OvR'),_as_object_array(OvR))
    np.savez(os.path.join(save0,'IoU'),_as_object_array(IoU))
    np.savez(os.path.join(save0,'cell_areas'),_as_object_array(cell_areas))
    # ap is stored as before
    np.savez(os.path.join(save0,'ap'),ap)
# np.savez(savedir+'remapping',remapping)

In [None]:
# Display the output directory used for the saved metrics.
save0

## Read back in performance metrics

In [None]:

# Reload the saved metrics; each .npz file stores its single array under 'arr_0'.
OvR, IoU, cell_areas, ap = [
    np.load(os.path.join(save0, key+'.npz'), allow_pickle=True)['arr_0']
    for key in ('OvR', 'IoU', 'cell_areas', 'ap')
]

if dataset=='bact_fluor':
    # Fluorescence results are plotted next to the phase results, so pull in
    # the phase 'ap' file and prefix the curve names to tell them apart.
    savedir_phase = os.path.join('/home/kcutler/DataDrive/omnipose_all',datasets[0]+'_comparison')
    ap_phase = np.load(os.path.join(savedir_phase,'ap'+'.npz'),allow_pickle=True)['arr_0']
    ap_plot = list(ap_phase)+list(ap)
    names_plot = ['phase_'+s for s in names_phase]+['fluor_'+s for s in names]
else:
    names_plot = names
    ap_plot = list(ap)

# Number of DISTINCT area values per image (np.unique collapses equal areas).
# NOTE(review): if a per-image cell count is intended, confirm this is what is wanted.
cell_count = np.array([len(np.unique(ca)) for ca in cell_areas])

In [None]:
# Debug display: number of 'ap' entries, the last value left in the loop
# variable j, the model count n and the model names.
len(ap),j,n,names

## Plot JI averaged over entire dataset

In [None]:
# Plot the dataset-averaged Jaccard Index of every entry in ap_plot against
# the IoU matching threshold and save the figure as pdf/png/eps under save0.
%matplotlib inline
import matplotlib as mpl
from omnipose.utils import sinebow

# computing the JI for the whole dataset by averaging per image
# vs averaging by cell area (I think the latter makes more sense, but it is not standard)
# (also, it doesn't change the plot that much)
# weighted=1 additionally draws the area-weighted curve as a dashed line
weighted = 0

x = threshold
sz = 1  # figure width and height in inches
golden = (1 + 5 ** 0.5) / 2
labelsize = 7


# colors = ['g','r','b','y','c','m']
n = len(names_plot)
# the first z curves are drawn in shades of red, the remaining n-z in greys
z = 1
master_color_scheme = [[i,0,0] for i in np.linspace(1,.5,z)]+[[i,i,i] for i in np.linspace(.75,.25,n-z)]

darkmode = 0
if darkmode:
    plt.style.use('dark_background')
    axcol = 'w'
    colors = sinebow(n+1)
    colors = [colors[j+1] for j in range(n)]
    background_color = 'k'
    suffix = '_dark'
else:
    # reset any style changes made earlier before applying the light theme
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    colors = master_color_scheme
    # if weighted:
    #     colors = sinebow(n+1)
    #     colors = [colors[j+1] for j in range(n)]
    background_color = np.array([1,1,1,1])
    suffix = ''

mpl.rcParams['figure.dpi'] = 300
fig = plt.figure(figsize=(sz, sz)) 
ax = plt.axes()
# plt.tight_layout()
# plt.minorticks_on()
plt.xticks(np.arange(min(x), max(x)+1, .25))
plt.xlim([0.5,1])
plt.ylim([0,1])
# plt.yticks(np.arange(0, 1.1, .25))

selection = range(n)
print(names_plot)
if dataset=='bact_fluor':
    # selection = [-3,-2,-1]
    selection = range(len(ap_plot))
    # selection = [-1,-2]
for j in selection:
    
    if weighted:
        # weight each image's JI by that image's total cell area
        mJI = np.sum(np.array([ji*np.sum(a) for a,ji in zip(cell_areas,ap_plot[j][0])]),axis=0)/np.sum(np.concatenate(cell_areas))
        ax.plot(x,mJI.T,'--',label=names_plot[j],color=colors[j],linewidth=.5)

    # unweighted mean over the first axis of ap_plot[j][0]
    ax.plot(x,np.mean(ap_plot[j][0],axis=0).T,label=names_plot[j],color=colors[j],linewidth=.5)
#     ax.plot(x,np.mean((fp[j][0]+fn[j][0]),axis=0),label=pretty_names[j],color=colors[j])


# ax.set_facecolor('w')
# legend is placed outside the axes, to the right
ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False,bbox_to_anchor=(1.05, 1))
# ax.legend(prop={'size': labelsize}, loc='best', frameon=False)
ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
# plt.setp(ax.xaxis.get_label(), visible=True, text='IoU')
# plt.setp(ax.get_xticklabels(), visible=True, ha='right')

# ax.grid(b=True, which='major', color='b', linestyle='-')
ax.set_ylabel('Jaccard Index', fontsize = labelsize)
ax.set_xlabel('IoU matching threshold', fontsize = labelsize)
# plt.set(xlabel='IoU threshold', ylabel='Average Precision',fontsize=labelsize)
# plt.tight_layout()
# plt.yscale('log')


ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# transparent axes patch so the figure background shows through
ax.patch.set_alpha(0.0)

fig.patch.set_facecolor(background_color)
plt.show()

figname = 'JI_vs_IoU'
# fig.savefig(os.path.join(save0,figname+suffix+'.eps'),bbox_inches=tight_bbox)
# fig.savefig(os.path.join(save0,figname+suffix+'.pdf'),bbox_inches=tight_bbox)
for ext in ['.pdf','.png','.eps']:
    fig.savefig(os.path.join(save0,figname+suffix+ext),bbox_inches="tight",pad_inches = 0.05)


## Calculate errors 

In [None]:
# Count segmentation errors per model from the overlap-ratio matrices, both
# overall and split by cell area at every area percentile.
# Per the cell that builds OvR, rows of each matrix index ground-truth cells
# and columns index predicted labels.
N = len(names)
nimg = len(cell_areas)
cell_errors = [[]]*N          # per model: list (one entry per image) of per-row error counts
total_errors = [0]*N          # per model: missed cells plus every extra label
total_single_errors = [0]*N   # per model: number of cells with at least one error
total_cells = len(np.concatenate(cell_areas))

# area_thresh = np.linspace(1,np.max(np.concatenate(cell_areas)),100)
# area thresholds: the 0th..99th percentiles of all cell areas
area_thresh = [np.percentile(np.concatenate(cell_areas),i) for i in range(100)]
M = len(area_thresh)
percent = [0]*N               # per model: percent[j][l][i], see the end of the loop


for j in range(N):
    print(names[j])
    ce_thresh = [[]*2]*M      # only referenced by the commented-out line below
    # index 0 along the first axis: cells with area >= threshold; index 1: area < threshold
    te_thresh = np.zeros((2,M))
    tse_thresh = np.zeros((2,M))
    tc_thresh = np.zeros((2,M))
    
    for k in range(nimg):
        r = OvR[j][k].copy() #overlap ratio 
        r[r<=0.75] = 0  # ignore overlaps covering 75% or less of the predicted label
        mx = np.max(r,axis=0) # column-wise maximum: the largest remaining overlap of each predicted label
        mx[mx==0] = np.nan #exclude case where a spurious label has no overlaps (nan never equals anything below)
        hits = np.sum(r==mx,axis=1) # per row: how many columns attain their maximum in that row; zero if none do
        cell_errors[j] = cell_errors[j] + [np.abs(hits-1)] #error if >1 or =0 
        total_errors[j] += np.sum(hits[hits>1]-1)+np.sum(hits==0) # -1 because a hit of 2 is 2 pred labels, 1 'extra' = 1 error
        total_single_errors[j] += np.sum(hits>1)+np.sum(hits==0)
        
        ca = cell_areas[k]
        for i,a in enumerate(area_thresh):
           
            for l in range(2): # look at both above AND below the area threshold
                cell_filter = ca>=a if l==0 else ca<a
                cell_count = np.count_nonzero(cell_filter)
                hits_thresh = hits[cell_filter] # hits restricted to the cells on this side of the threshold
#                 ce_thresh[l,i] += [np.abs(hits_thresh-1)] #error if >1 or =0 
                te_thresh[l,i] += np.sum(hits_thresh[hits_thresh>1]-1)+np.sum(hits_thresh==0) # -1 because a hit of 2 is 2 pred labels, 1 'extra' or 1 error
                tse_thresh[l,i] += np.sum(hits_thresh>1)+np.sum(hits_thresh==0)
                tc_thresh[l,i] += cell_count

    
    # eps keeps the denominator non-zero when no cell falls on one side of a threshold
    eps = 1e-8
    # percentage of cells with >=1 error, for side l of area threshold i
    percent[j] = [[utils.safe_divide(tse_thresh[l,i],tc_thresh[l,i]+eps)*100 for i in range(M)] for l in range(2)]        


## Print and save error metrics  

Calculate probability of errors above and below cutoff
This is used in the section "XXX". We save the output to a text file in the `save0` directory. 

In [None]:
# Report how many cells have at least one error, split at the top quartile of
# cell area, and write the same text to <save0>/error_metrics.txt.
x = np.concatenate(cell_areas)
m = np.percentile(x,75)  # area cutoff: 75th percentile of all cell areas
top_count = np.count_nonzero(x>=m)
bottom_count = np.count_nonzero(x<m)
cell_count = np.count_nonzero(x)


def _percent(count, total):
    """Return count/total as a percentage, or nan when total is zero."""
    return count/total*100 if total else float('nan')


string = ['Top quartile of area is {} square pixels'.format(m),
         'Number of cells in top quartile is {}'.format(top_count),
         'Number of cells in bottom three quartiles is {}'.format(bottom_count),
          '']
# Iterate over the models actually present in cell_errors. The error count is
# kept in its own name (n_err) so the model count N is no longer overwritten,
# which previously made a second run of this cell index past the last model.
for j in range(len(cell_errors)):
    y = np.concatenate(cell_errors[j])
    has_error = y>0
    n_top = np.count_nonzero(np.logical_and(x>=m,has_error))
    n_bottom = np.count_nonzero(np.logical_and(x<m,has_error))
    n_err = np.count_nonzero(y)

    string += [names[j],
            'Number/fraction of cells in the top quartile with one or more errors: {} / {}'.format(n_top,_percent(n_top,top_count)),
            'Number/fraction of cells in the bottom three quartiles with one or more errors: {} / {}'.format(n_bottom,_percent(n_bottom,bottom_count)),
            'Number/fraction of all cells with one or more errors: {} / {} \n'.format(n_err,_percent(n_err,cell_count))]

# see the output in the notebook and save to txt file
with open(os.path.join(save0,'error_metrics.txt'), "w") as text_file:
    for line in string:
        print(line,file=text_file)
        print(line)

In [None]:
# Debug display of N.
# NOTE(review): N is set to len(names) in the error-calculation cell; confirm
# that no later cell has reassigned it before relying on this value.
N

## Plot Figure 4A
Plot percent of cells with >=1 error against the area percentile cutoff. This tells us if there are any differences in how the algorithm treats cells by area. A flat line indicates that there is no difference in error count between large and small cells. 

In [None]:
mpl.rcParams.update(mpl.rcParamsDefault)
axcol = 'k'

from omnipose.utils import sinebow
n = len(names)
z = 1
master_color_scheme = [[i,0,0] for i in np.linspace(1,.5,z)]+[[i,i,i] for i in np.linspace(.75,0,n-z)]

golden = (1 + 5 ** 0.5) / 2
sz = 1.5
labelsize = 7

%matplotlib inline
darkmode = 0
if darkmode:
    plt.style.use('dark_background')
    axcol = 'w'
    colors = sinebow(n+1)
    colors = [colors[j+1] for j in range(n)]
    background_color = 'k'
    suffix = '_dark'
else:
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    cmap = mpl.cm.get_cmap('viridis')
#     colors = cmap(np.linspace(0,.9,len(names)))
    colors = master_color_scheme
    background_color = np.array([1,1,1,1])
    suffix = ''
    
mpl.rcParams['figure.dpi'] = 300

# could turn these results into a plot by repeating at different cutoffs 
fig = plt.figure(figsize=(sz, sz)) 
ax = plt.axes()
for j in range(n):
# for j in [-1,-2]:
# for j in [0,1,-1]:
#     ax.plot(area_thresh,np.divide(tse_thresh[j],tc_thresh),label=names[j],color=colors[j])
#     ax.plot(range(100),percent[j],label=names[j],color=colors[j])
    ax.plot(range(100),percent[j][0],label=names[j],color=colors[j])
#     ax.plot(range(100),percent[j][1],'--',label='above',color=colors[j])
#     print('mean percent error is',np.mean(percent[j]))
    # ax.set_facecolor('w')
# plt.xscale('log')
# ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False,bbox_to_anchor=(1.05, 1))

# ax.legend(prop={'size': labelsize}, loc='best', frameon=False)
ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.set_ylabel('Percent of cells \nwith >=1 error', fontsize = labelsize)
ax.set_xlabel('Area percentile cutoff', fontsize = labelsize)
plt.ylim([0,100])
plt.xlim([0,100])

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)
ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False,bbox_to_anchor=(1.05, 1))
fig.patch.set_facecolor(background_color)
plt.show()

a = 35
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
tight_bbox_raw._points+=[[-a,-a],[a,a]]
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))



figname = 'Error_vs_area'
# fig.savefig(os.path.join(save0,figname+suffix+'.eps'),bbox_inches=tight_bbox)
# fig.savefig(os.path.join(save0,figname+suffix+'.pdf'),bbox_inches=tight_bbox)
for ext in ['.pdf','.png','.eps']:
    fig.savefig(os.path.join(save0,figname+suffix+ext),bbox_inches="tight",pad_inches = 0.05)

In [None]:
# Consolidate error and cell area for plotting 
a = np.concatenate(cell_areas)

area_cutoff = np.percentile(a,75)
cell_filter = a>=area_cutoff
text = ['number of cells: {}'.format(len(a)),
        'fraction of areas: {}'.format(np.count_nonzero(cell_filter)/len(a)*100),
        'top area quartile: {}\n'.format(area_cutoff)]
# for j in [2,1,0]:

for j in range(n):
    e = np.concatenate(cell_errors[j]) 
    p = np.array(percent[j][0])
    text+=[names[j],'fraction of errors: {}'.format(np.sum(e[cell_filter])/np.sum(e)*100),
           'minimum error percentage: {}'.format(np.min(p[p>0])),
           'maximum error percentage: {}'.format(np.max(p[p>0])),
           # 'cross 5 percent: {}'.format(np.argwhere(p>5)),
           'mean error percentage: {}\n'.format(np.mean(p[p>0]))]

with open(os.path.join(save0,'error_fraction_by_area.txt'), "w") as text_file:
    [print(t,file=f) for t in text for f in [text_file,None]]

In [None]:
# Display the directory the text summaries and figures were written to.
save0

## Separate into cell types for plotting statistics

This section of the notebook is only intended to be run for the bacterial datasets.

In [None]:
# Group the images into categories by filename fragments and count the
# ground-truth cells in each category.
exclude = 'xxxx' # could exclude some problematic images for all 

# dataset -> (category labels, filename fragments that identify each category)
_category_table = {
    'bact_phase': (['regular morphologies','mutants, antibiotics','elongated wildtype'],
                   [['5I_crop','PAO1_Staph','PSVB','Serratia_Ecoli','wiggins','vibrio','bthai'],
                    ['Hpylori','dnaA','ftsN','murA','cex','a22','Az'],
                    ['caulo','streptomyces']]),
    'bact_fluor': (['regular morphologies','mutants and antibiotics'],
                   [['wiggins','vibrio','bthai'],
                    ['cex','a22']]),
    'worm': (['worms'],
             [['elegans']]),
}
# any other dataset gets one unnamed category whose empty fragment matches every file
cats, subsets = _category_table.get(dataset, ([''], [['']]))
K = len(cats)

# indices[k]: positions in mask_names of the files belonging to category k
indices = []
for k in range(K):
    selected = [i for i,s in enumerate(mask_names)
                if (exclude not in s) and any(name in s for name in subsets[k])]
    indices.append(np.array(selected).astype(int))

# counts[j]: number of distinct labels (minus one for background) summed over category j
counts = np.zeros((K,1),int)
for j,inds in enumerate(indices):
    counts[j] += sum(len(np.unique(masks_gt_clean[i]))-1 for i in inds)


print('Are all images accounted for?',np.sum([len(indices[k]) for k in range(K)])==len(mask_names))
print('Image category counts:',[len(indices[k]) for k in range(K)])
print('Image category cell counts:',counts)

In [None]:
# Sanity check: number of images assigned to a category vs. total number of mask files.
# len(mask_names),np.sum([len(indices[k]) for k in range(K)])
# indices
# mask_names[0]
np.sum([len(indices[k]) for k in range(K)]),len(mask_names)

In [None]:
# Display the per-category cell counts.
counts

## Calculate Jaccard Index averages 

In [None]:
from scipy.optimize import linear_sum_assignment
per_cell = 1  # truthy: keep one IoU per matched cell; falsy: one mean per image
J = len(names)
K = len(cats)
# y[j][k] collects values for model j and category k. A comprehension is used
# because a nested list built with `*` would share the same inner list, see
# https://stackoverflow.com/questions/54673821/python-how-to-initialize-a-nested-list-with-empty-values-which-i-can-append-to
y = [ [ [] for i in range(K) ] for i in range(J) ]
th = 0  # IoU level rewarded in the assignment cost; 0 means every pair qualifies
for j in range(J):
    for k in range(K):
        matched_iou = []
        mean_matched_iou = []
        for ind in indices[k]:
            # drop the first row and column of the IoU matrix
            iou = IoU[j][ind][1:,1:]
            n_true, n_pred = iou.shape[0], iou.shape[1]
            n_min = min(n_true, n_pred)
            # one-to-one matching: pairs at or above `th` cost -1, with a
            # smaller IoU-proportional term subtracted on top
            costs = -(iou >= th).astype(float) - iou / (2*n_min)
            true_ind, pred_ind = linear_sum_assignment(costs)
            miou = iou[true_ind, pred_ind]
            matched_iou.extend(miou)
            mean_matched_iou.append(np.sum(miou)/max(n_true, n_pred))

        y[j][k] = matched_iou if per_cell else mean_matched_iou

## Plot Jaccard Index per image
Color-coded by cell type

In [None]:


from omnipose.utils import sinebow
J = len(names)
linestyle='-'

%matplotlib inline
# Theme and figure setup for the per-image Jaccard Index plot.
# The dark branch uses the sinebow palette; the light branch resets rcParams to the
# matplotlib defaults and samples one viridis colour per entry of `names`.
darkmode = 0
if darkmode:
    plt.style.use('dark_background')
    axcol = 'w'
    colors = sinebow(J+1)
    # NOTE(review): `n` is not defined in this cell; presumably this should be `J` - confirm before enabling darkmode
    colors = [colors[j+1] for j in range(n)]
    background_color = 'k'
else:
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    cmap = mpl.cm.get_cmap('viridis')
    colors = cmap(np.linspace(0,.9,len(names)))
#     colors = master_color_scheme
    background_color = np.array([1,1,1,1])
    
# set after the rcParams reset above so the dpi is the same in both themes
mpl.rcParams['figure.dpi'] = 300

x = threshold # x axis of every panel; `threshold` is defined in an earlier cell
golden = (1 + 5 ** 0.5) / 2
sz = 6.5
labelsize = 7
# fig = plt.figure(figsize=(sz, sz/golden)) 
# fig = plt.figure(figsize=(sz, sz/2)) 
# ax = plt.axes()
# one panel per model, laid out in a single row with shared axes
fig, axs = plt.subplots(1,len(names),figsize=(sz,sz/J),sharex=True, sharey=True)

# colors = ['tab:olive','tab:gray','tab:red']
# plt.tight_layout()
# plt.minorticks_on()
plt.xticks(np.arange(min(x), max(x)+1, .25))
# plt.xlim([0.45,1.05])
# plt.ylim([0,1])
# plt.yticks(np.arange(0, 1.1, .25))

# colors = ['g','r','b','y','c','m']
# pretty_names = ['Original Cellpose','','','Mixed Method','','','','New Method']
# number of distinct area values per image (not used by the plotting loop of this cell)
cell_count = np.array([len(np.unique(ca)) for ca in cell_areas])
# for j in range(n):
# sort_colors = [colors[j] for j in args]
# colors = ['r',[.75,.75,.75],[0,0,0],[.5,.5,.5],[1,0,0]]

# toggles selecting what the plotting loop draws
fill = 0 # makes mean+-"err" fill around the average JI
per_image = 1 # plots JI per image
density = 0 # plots JI per image as a dot for each IoU and colors by density 

# number of images in each of the K index groups (only needed by the commented-out alpha rule)
l = np.array([len(indices[k]) for k in range(K)])
# alpha = 1.5/(1+l/np.min(l))
# NOTE(review): hard-coded for three groups; `alpha[k]` in the plotting loop raises IndexError if K > 3
alpha = [1,1,1]

from scipy.stats import gaussian_kde

# Draw one panel per model: the per-image Jaccard Index curves coloured by model,
# with the across-image mean curve drawn in the axis colour.
for j in range(J):
    
    # mean and spread over axis 0 of ap[j][0] (presumably the image axis - TODO confirm shape)
    mean = np.mean(ap[j][0],axis=0).T
    err = np.std(ap[j][0],axis=0).T
    # plt.subplots returns a bare Axes instead of an array when there is a single panel
    if len(names)>1:
        ax = axs[j]
    else:
        ax = axs
    if fill:
        ax.plot(x,mean,label=names[j],color=colors[j],linestyle=linestyle)
        ax.fill_between(x,mean-err,mean+err,facecolor=colors[j],alpha=0.5)
    if per_image:
#         ax.plot(x,ap[j][0].T,label=names[j],color=colors[j],linestyle=linestyle,alpha=.05)
#         ax.plot(x,mean,label=names[j],color='k',linestyle=linestyle)
        for k in range(K):
#             if k<K-1:
#                 alpha = .25
#             else:
#                 alpha = .5
            # curves of the images belonging to group k
            arr = ap[j][0][indices[k]].T
            ax.plot(x,arr,label=names[j],color=colors[j],linestyle=linestyle,alpha=alpha[k],lw=.75)
            # the mean is re-drawn after every group, so its last copy ends up on top of all image curves
            ax.plot(x,mean,label=names[j],color=axcol,linestyle=linestyle,lw=.75)
    if density:
        # NOTE(review): this rebinds the notebook-level `y` (the matched IoUs read by the violin-plot cell)
        y = np.array(ap[j]).T.flatten()
        X = np.repeat(x,len(ap[j][0]))
        xy = np.vstack([X,y])
        z = gaussian_kde(xy)(xy)
        # sort by density so that the densest points are drawn last
        idx = z.argsort()
        X, y, z = X[idx], y[idx], z[idx]
        ax.scatter(X, y, c=np.log(z), s=1)
    # ax.set_facecolor('w')
    # ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False,bbox_to_anchor=(1.05, 1))
#     ax.legend(prop={'size': 5}, loc='best', frameon=False)
    ax.set_title(names[j],fontsize=labelsize)
    ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
    ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
#     ax.tick_params(axis='x', which='both', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=0,left=0,labelbottom=0)

#     if j==J-1:
#         ax.set_xlabel('IoU threshold', fontsize = labelsize)
#         ax.tick_params(axis='x', which='both', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True,labelbottom=1)
#         ax.tick_params(axis='x', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True,labelbottom=1)

#     ax.set_ylim((0,1))
    ax.set_xlim((.5-.05,1))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # transparent panel so the figure background colour shows through
    ax.patch.set_alpha(0.0)

plt.subplots_adjust(bottom=.3,left=.1)
fig.patch.set_facecolor(background_color)
fig.supylabel('Jaccard Index', fontsize = labelsize)
fig.supxlabel('IoU matching threshold', fontsize = labelsize)
# swap last two 
# pos = axs[-1].get_position()
# axs[-1].set_position(axs[-2].get_position())
# axs[-2].set_position(pos)
# 
# plt.set(xlabel='IoU threshold', ylabel='Average Precision',fontsize=labelsize)
# plt.tight_layout()
# plt.yscale('log')



plt.show()

# Tight bounding box of the last panel, grown by `a` display pixels on every side and
# converted to inches. It is computed here but not passed to any savefig call in this cell.
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
a = 50
# NOTE(review): mutates the private `_points` array of the Bbox in place
tight_bbox_raw._points+=[[-a,-a],[a,a]]
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))

## Compute fractional difference between Omnipose and Cellpose

In [None]:
def _mean_fractional_difference(new, ref):
    """Return the mean of (new - ref) / ref, counting entries where ref is 0 or NaN as 0."""
    valid = np.logical_and(ref != 0, ~np.isnan(ref))
    frac = np.divide(new - ref, ref, out=np.zeros_like(ref), where=valid)
    return np.mean(frac)


# Index into the threshold axis at which the curves are split.
# NOTE(review): the report prints this as IoU = cutoff/100, which presumably relies on
# `threshold` having 100 evenly spaced entries starting at 0 - confirm
cutoff = 80 # IoU cutoff of 0.8

# positions in `names` / `ap` of the two models being compared
if dataset=='bact_fluor':
    i1 = 0
    i2 = 1
else:
    i1 = 0
    i2 = 2
print('Comparing models {} and {}'.format(names[i1],names[i2]))

# curves averaged over axis 0 of ap[i][0], one value per threshold
cp = np.mean(ap[i1][0],axis=0).T
om = np.mean(ap[i2][0],axis=0).T

av_all = _mean_fractional_difference(om, cp)
av_up = _mean_fractional_difference(om[cutoff:], cp[cutoff:])
av_dwn = _mean_fractional_difference(om[:cutoff], cp[:cutoff])

text = ['average percent difference across all IoUs: {}'.format(av_all*100),
        'percent difference above and below IoU={}: {},{}'.format(cutoff/100,av_up*100,av_dwn*100),
        'percent difference of averages across all IoUs: {}'.format((np.mean(om)-np.mean(cp))/np.mean(cp)*100)]

# write the report to disk and echo every line to the notebook output
with open(os.path.join(save0,'relative_performance.txt'), "w") as text_file:
    for t in text:
        print(t,file=text_file)
        print(t)

In [None]:
save0

# Figure 4f - worm dataset violin plot
To run this section, ensure the worm dataset has been loaded and all non-dataset-specific cells above have been run. 

In [None]:
# Violin + box plot of the matched IoUs in `y` (one violin per model), plus a text
# report of the share of matched IoUs above 0.8.
sz = 2
lw = 0.8
fig, axs = plt.subplots(figsize=(sz,sz))
plt.subplots_adjust(wspace=.4, hspace=.4)
np.random.seed(123)
w0 = 0.7
# the four flags below are kept for compatibility but are not read anywhere in this cell
split=1
box = False
scatter = 0
violin = 1

ax = axs

text = []
for j in range(J):
    # number of matched IoUs above 0.8 in every category k of model j
    n_above = [np.sum(np.array(yi)>0.8) for yi in y[j]]
    percents = [n_k/counts[k]*100 for k,n_k in enumerate(n_above)]
    # expressed in percent, like the per-category values (previously reported as a fraction)
    overall_percent = np.sum(n_above) / np.sum(counts) * 100
    text += [names[j],'Percent above 0.8:\nBy category: {} \nOverall: {} \n'.format(percents,overall_percent)]

# horizontal positions of the violins, spread around 0
w = w0 / 3
j = 0
x = np.linspace(j-w,j+w,len(y))

# one flat array of IoUs per model
yp = [np.array(yi).flatten() for yi in y]

parts = ax.violinplot(yp,positions=x,showextrema=0,widths=w0/(len(y)+1)) #showmeans=1,quantiles = [[.25,.75]]*len(x),
boxes = ax.boxplot(yp,positions=x,showfliers=False,widths=w0/35,patch_artist=True,
                   boxprops=dict(facecolor='w', color='k',linewidth=lw),whiskerprops=dict(color='k',linewidth=lw),
                   medianprops=dict(color='k',linewidth=lw))

# colour every violin with the colour of its model
for i,pc in enumerate(parts['bodies']):
    pc.set_facecolor(colors[i])
    pc.set_alpha(1)
    pc.set_edgecolor(colors[i])
    pc.set_linewidth(lw)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)

ax.set_ylim(0,1)      
ax.set_xlim((j-2*w0/4,j+2*w0/4))
ax.set_xticks([])
plt.show()   

# tight bounding box of the axes, grown by `a` display pixels on every side and converted to inches
a = 50
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer()).padded(a)
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))

fig.savefig(os.path.join(save0,'Iou_violin.pdf'),bbox_inches=tight_bbox)

# echo the report to the notebook output and write it to disk
with open(os.path.join(save0,'IoU_stats.txt'), "w") as text_file:
    for t in text:
        print(t)
        print(t,file=text_file)

## Plot Segmentation error examples
Most of the errors we observe in small cells are ambiguous division events: cells that may or may not have divided. An error count of 1 means either that a mask is missing / two masks were merged into one (the prediction treats the cell as not yet divided even though the GT says it has divided), or that one cell was split into two masks (the prediction treats the cell as divided even though the GT says it has not). 

In [None]:
import ncolor 
from cellpose import transforms
import cellpose, omnipose 

# Exploratory code: crop and display ground-truth cells of one image that have segmentation errors.
# NOTE: broken for flows that are not rescaled, need to add zoom command
# NOTE(review): this cell rebinds the notebook-level names `x`, `y`, `err`, `j` and `k`
ind = range(4)

from matplotlib.colors import ListedColormap
plt.style.use('dark_background')
# Now find and crop specific cells for comparison.
# Over-segmented cells have low IoUs spread out over several hits; we take a specific image from the
# previous list as an example, where there can be fewer true labels than predicted labels.
# The IoU can be at most 1, but the IoU of a small predicted label with a large true label is always small.
cm2 = ListedColormap([color for color in sinebow(5).values()])

k = 0 # which image
j = -1 # look at Omnipose

m = 0
# presumably rows are ground-truth labels and columns are predicted labels: below, row index `l`
# is used as GT label l+1 and column index `label` as predicted label label+1 - TODO confirm
r_pred = OvR[j][k].copy()
mx = np.max(r_pred,axis=0) 
# std = np.std(r_pred,axis=0)
# NaN never compares equal, so all-zero columns produce no hit
mx[mx==0] = np.nan
# marks, for every column, the row(s) that attain the column maximum
hits_pred = r_pred==mx
# deviation of the number of hits per row from exactly one
err = np.abs(np.sum(hits_pred,axis=1)-1)

# err = cell_errors[j][k]# assuming proper precomputing 
areas = cell_areas[k]
error_cutoff = 1
area_cutoff = 10
err_indexes = np.nonzero(np.logical_and(err>=error_cutoff,areas>area_cutoff))[0]

pad = 5 # margin in pixels around the cell bounding box

# show at most the first ten errored cells
for l in err_indexes[0:10]:
    print(err[l])
    mgt = masks_gt_clean[k]
    y,x = np.nonzero(mgt==l+1)
    max_y,max_x = np.array(mgt.shape)-1
    
    # padded bounding box, clipped to the image
    y0 = max(0,min(y)-pad)
    y1 = min(max_y,max(y)+pad)
    x0 = max(0,min(x)-pad)
    x1 = min(max_x,max(x)+pad)
    # the leading Ellipsis lets the same slice be applied to arrays with extra leading axes
    slc = (Ellipsis,)+omnipose.utils.bbox_to_slice([y0,y1,x0,x1],mgt.shape)
    p = transforms.normalize99(imgs[k][slc],omni=True)
    p = transforms.move_min_dim(p)
    # keep only the first channel of multi-channel crops
    if p.ndim>2:
        p=p[...,0]
    p = p
    
    mask_gt = mgt[slc]
    bini = mask_gt==l+1 # binary mask of the selected ground-truth cell
    img0 = np.stack([p]*3,axis=-1)
    # outli = transforms.normalize99(plot.outline_view(img0*255,bini),omni=True)
    bin0 = np.stack([bini]*3,axis=-1)
    # gt_pic = np.hstack((img0,outli,bin0))
    mask_pred = masks[j][k][slc]
    # flow_pred = transforms.normalize99(flows[j][k][0][y0:y1,x0:x1],omni=True)[:,:,:3] #transparency an issue here, at least with concatenating 
    mu = flows[j][k][1][slc]
    flow_pred = plot.dx_to_circ(mu)/255
    
    # relabel the predicted masks that hit this ground-truth cell as 1, 2, ...
    inds = np.where(hits_pred[l,:])
    rmask = np.zeros_like(mask_gt)
    for i,label in enumerate(inds[0]):
        rmask[mask_pred==label+1] = i+1

    
    outl_pred = plot.outline_view(img0,rmask)/255
    # panels, left to right: image, GT cell mask, coloured predicted masks, outlines of the hits, flow
    res_pred = np.hstack((img0,bin0,cm2(ncolor.label(mask_pred,max_depth=10))[:,:,:3],outl_pred,flow_pred))
    plt.figure(figsize=[2,2])
    plt.imshow(res_pred)
    plt.axis('off')
    plt.show()

# Extended Data Figure 4a
For this plot, I should put both the training and test errors. The training dataset has a lot more long cells in it. In fact, I should do the analysis with the train, test, and combined to show correlation and lack of overfitting. 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from matplotlib.ticker import FormatStrFormatter
import matplotlib.patches as patches
from scipy.stats import gaussian_kde
import cmasher as cmr
import ncolor


def custom_ceil(x,base=10):
    """Round ``x`` up to the next multiple of ``base`` and return the result truncated to ``int``."""
    n_steps = np.ceil(x / float(base))
    return int(n_steps * base)

# Notebook-level settings and data for the errors-vs-area scatter plot.
# `color`, `cm`, `xmin` and `rasterized` are read as globals by `scatter_hist`.
color = [1,0,0,.5]
axcol = 'k'

# some random data
# x = np.random.randn(1000)
# y = np.random.randn(1000)
j = 0 # which algorithm
# pool all images: x holds the cell areas, y the per-cell error counts of algorithm j
x = np.concatenate(cell_areas)
y = np.concatenate(cell_errors[j])
xmin = 10
xmax = 10**5
# cm = plt.cm.get_cmap('plasma')
cm = cmr.get_sub_cmap('plasma', 0.2, 0.95)

mpl.rcParams['figure.dpi'] = 1000
rasterized = True # passed to ax.scatter inside scatter_hist

def scatter_hist(x, y, ax, ax_histx, ax_histy,zorder=1, hist = 1):
    """Scatter ``y`` against ``x`` coloured by point density, with optional marginal histograms.

    Every point is coloured by the log of a Gaussian KDE evaluated at that point.
    When ``hist`` is truthy, a log-binned histogram of ``x`` is drawn on ``ax_histx``
    and a horizontal histogram of ``y`` on ``ax_histy``, with bars coloured by height.
    Reads the notebook-level names ``cm``, ``color``, ``xmin`` and ``rasterized``.
    """
    # the marginal axes do not repeat the tick labels of the main axes
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    # density estimate at every sample; points are drawn in order of increasing density
    samples = np.vstack([x.copy(), y])
    density = gaussian_kde(samples)(samples)
    order = density.argsort()
    ax.scatter(x[order], y[order], c=np.log(density[order]), s=1,
               cmap=cm, zorder=zorder, rasterized=rasterized)

    if not hist:
        return

    def _shade(counts, bars, power):
        # colour every bar by its min-max normalised height raised to `power`
        eps = 1e-20
        levels = (counts-counts.min())/(counts.max()-counts.min())+eps
        for level, bar in zip(levels, bars):
            plt.setp(bar, 'facecolor', cm(level**power))

    # bin edges are determined by hand
    step_x = 100
    step_y = 1
    largest = max(np.max(np.abs(x)), np.max(np.abs(y)))
    upper = (int(largest/step_x) + 1) * step_x
    edges_x = np.arange(xmin, upper + step_x, step_x)
    edges_y = np.arange(0, 55 + step_y, step_y)
    # same end points and number of edges as the linear bins, spaced logarithmically
    log_edges_x = np.logspace(np.log10(edges_x[0]),np.log10(edges_x[-1]),len(edges_x))

    counts, _, bars = ax_histx.hist(x, bins=log_edges_x,color=color)
    _shade(counts, bars, 1)

    counts, _, bars = ax_histy.hist(y, bins=edges_y, orientation='horizontal',color=color,align='left')
    _shade(counts, bars, .25)
    
    
#     ax_histx.tick_params(axis='x',which='both',bottom=False,top=False)
#     ax_histy.tick_params(axis='y',which='both',left=False,right=False)
#     ax_histy.get_xaxis().set_ticks([])
#     ax_histy.get_yaxis().set_ticks([])
    

#     plt.xlim([0,10**3])
# start with a square Figure
# a = 1+golden
a = 5 # relative size of the main axes
b = 1 # relative size of the marginal axes
c = 2.5 # figure side length passed to figsize
labelsize = 9
fig = plt.figure(figsize=(c, c))
xlog = 1

# Add a gridspec with two rows and two columns and a ratio of b to a between
# the size of the marginal axes and the main axes in both directions.
# Also adjust the subplot parameters for a square plot.

gs = fig.add_gridspec(2, 2,  width_ratios=(a, b), height_ratios=(b, a),
                      left=0.1, right=0.9, bottom=0.1, top=0.9,
                      wspace=0.075, hspace=0.075)

ax = fig.add_subplot(gs[1, 0])
ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
# plt.xscale acts on the current axes, which is ax_histx here; ax shares its x axis
if xlog:
    plt.xscale('log')
# plt.yscale('log')
ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)

# plt.xscale('log')
# shaded rectangle spanning the 75th percentile of the cell areas up to the largest area
x0 = np.percentile(x,75)
x1 = np.max(x)
# x1 = 10**5
y1 = np.max(y)
y1 = custom_ceil(y1,50) # upper y limit: largest error count rounded up to a multiple of 50
ax.add_patch(
     patches.Rectangle(
        (x0, 0),
        (x1-x0),
        y1,
        edgecolor = None,
        facecolor = [.2]*3 if darkmode else [.8]*3,
        fill=True,
        zorder=1
     ))
# ax.vlines(x0,0,y1,colors=[.5]*4,linestyles='dashed',linewidth=1)
# use the previously defined function to do scatter plot
sel = x>=5 # remove the 10ish 1-5px mistakes in GT
# hist=0: only the scatter is drawn; the marginal axes are switched off further down
scatter_hist(x[sel], y[sel], ax, ax_histx, ax_histy, zorder=2, hist=0)
# plt.xscale('log')
# plt.yscale('log')
ax.set_ylabel(names[j]+' \n Segmentation errors', fontsize = labelsize)
ax.set_xlabel('Cell area (px$^2$)', fontsize = labelsize)
ax.set_ylim(0,y1)
ax.set_xlim(1e0,custom_ceil(x1,10**4.3))
ax.patch.set_alpha(0.0)

# `background_color` comes from the Jaccard Index plotting cell
fig.patch.set_facecolor(background_color)

# set x ticks
if xlog:
    # plt.xscale('log')
    x_major = mpl.ticker.LogLocator(base = 10.0, numticks = 10)
    ax.xaxis.set_major_locator(x_major)
    x_minor = mpl.ticker.LogLocator(base = 10.0, subs = np.arange(1.0, 10.0) * 0.1, numticks = 10)
    ax.xaxis.set_minor_locator(x_minor)
    ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())

# set y ticks
# y_major = mpl.ticker.LinearLocator(numticks = 1+5)
# ax.yaxis.set_major_locator(y_major)
# y_minor = mpl.ticker.LogLocator(base = 10.0, subs = np.arange(1.0, 10.0) * 0.1, numticks = 10)
# ax.yaxis.set_minor_locator(y_minor)
# ax.yaxis.set_minor_formatter(mpl.ticker.NullFormatter())

ax.tick_params(axis='both', which='major', labelsize=labelsize)

# x_major = matplotlib.ticker.LogLocator(base = 10.0, numticks = 10)
# ax_histy.yaxis.set_major_locator(x_major)
# x_minor = matplotlib.ticker.LogLocator(base = 10.0, subs = np.arange(1.0, 10.0) * 0.1, numticks = 10)
# ax_histy.yaxis.set_minor_locator(x_minor)
# ax_histy.yaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())

# turn off hists
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax_histy.axis('off')
ax_histx.axis('off')

# ax_histy.spines['right'].set_visible(False)
# ax_histy.spines['top'].set_visible(False)
# ax_histy.spines['bottom'].set_visible(False)
# ax_histy.spines['left'].set_visible(False)
# ax_histy.get_xaxis().set_ticks([])
# ax_histy.get_yaxis().set_ticks([])
# ax_histy.get_yaxis().set_visible(False)
# plt.xlim([xmin,xmax])

plt.show()
# `save0` and `suffix` are defined in other cells; one file per format is written
for ext in ['.pdf','.eps']:
    fig.savefig(os.path.join(save0,'errors_vs_area_hist'+suffix+'_'+names[j]+ext),bbox_inches='tight') #<<<

# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_hist'+suffix+'.eps',bbox_inches='tight',transparent=True,pad_inches=0)
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_hist'+suffix+'_'+names[j]+'.pdf',bbox_inches='tight') #<<<
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_hist.png',bbox_inches='tight')

#  Extended Data Figure 4b

This portion of the script depends on the bact_phase dataset being loaded. It could be adapted fairly easily for the other datasets with alternative choices for image names and cell indexes.

## Select examples of errored cells
From inspection of the test image set, I defined the `img_list` with candidates for making a good figure. This is entirely qualitative, but I wanted to make sure that I included both branched and unbranched morphologies as well as over-segmentation by discrete clustering and over-segmentation in extended clustering. 

In [None]:
import os
from scipy.ndimage import binary_erosion, binary_dilation
from cellpose import utils, transforms
from omnipose.utils import sinebow
import ncolor
import cv2
from cellpose.io import imsave
from PIL import Image, ImageFont, ImageDraw, ImageOps
from matplotlib.colors import ListedColormap
import os, datetime, gc, warnings, glob



# `skimage.measure.label` is used below; importing `skimage.io` alone does not
# guarantee that the `measure` submodule is loaded
import skimage.measure

plt.style.use('dark_background')
cmap = mpl.cm.get_cmap('viridis')
cm2 = ListedColormap(list(sinebow(5).values()))
cm3 = ListedColormap(list(sinebow(5,bg_color=[1,1,1,1]).values()))

# basedir = '/home/kcutler/DataDrive/omnipose_paper/Comparison Examples/newcompare4/'
# io.check_dir(basedir)

area_cutoff = 20
error_cutoff=2
pad = 10 # margin in pixels around the selected microcolony
N = len(masks)
nimg = len(masks_gt)
bkct = 500
bg = 0.5 # target mean background intensity of the gamma normalization
ext = '.png'

cnt = 0
txtpad = 10
yoffset = [0,0]

# candidate images, chosen by inspecting the test set
img_list = ['ftsN_ensemble_0','caulo_15','streptomyces_XY15_1','Az_branch_ec_0','Hpylori2_2',
            'PSVB_ensemble_c_8','PSVB_ensemble_c_5','PSVB_ensemble_c_2','PSVB_ensemble_c_0','PSVB_ensemble_c_12','PSVB_ensemble_c_11','PSVB_ensemble_c_10',
            'vibrio_ensemble_2_19','wiggins_ensemble_1','wiggins_ensemble_11']

# Cells were chosen by using the Matplotlib widget and hovering over the cells
# to extract, which gives the ground-truth index,
# e.g. plt.imshow(masks_gt_clean[img_index[6]]).
# One entry per image in img_list: a 0-based cell index, a list of indices,
# or None to keep the whole field of view.
cell_list = [10,8,None,33,44,
             160,193,31,58,12,49,10,#[180,194,201]
             None,399,119]# [370,399]

# position of every selected image in img_names (raises ValueError if an image is not loaded)
namelist = [os.path.splitext(os.path.split(file)[-1])[0] for file in img_names]
img_index = [namelist.index(im) for im in img_list]
for l,k in zip(cell_list,img_index):
    print(l,k)

n_exmpl = len(img_list)
label_list = []
coords = [] # one crop window [y0,y1,x0,x1] per example, filled below

for j in [0]: #base selection on cellpose examples
    for l,k in zip(cell_list,img_index): # replace loops over k and l 
        
        file = img_names[k]
        basename = os.path.splitext(os.path.split(file)[-1])[0]

        mgt = masks_gt_clean[k]
        # gamma normalization: the exponent maps the mean of the eroded background to bg
        p = transforms.normalize99(imgs[k],omni=True)
        img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))

        if l is not None:
            # pixels of the selected ground-truth cell(s); index l corresponds to label l+1
            if isinstance(l, list):
                hits = np.any(np.stack(([mgt==li+1 for li in l])),axis=0)
            else:
                hits = mgt==l+1
            
            # connected components of the foreground
            microcolonies = mgt>0
            labels = skimage.measure.label(microcolonies)

            # grow the selection to every component that touches the selected cell(s)
            # after dilating them by one iteration; the dilation is computed only once
            near_hits = binary_dilation(hits, iterations=1)
            touching = np.unique(labels[np.logical_and(near_hits, labels>0)])
            binmask = hits.copy()
            binmask[np.isin(labels, touching)] = 1

            y,x = np.nonzero(binmask)
            max_y,max_x = np.array(mgt.shape)-1

            # padded bounding box of the selection, clipped to the image
            y0 = max(0,min(y)-pad)
            y1 = min(max_y,max(y)+pad)
            x0 = max(0,min(x)-pad)
            x1 = min(max_x,max(x)+pad)

            xy = [y0,y1,x0,x1]
        else:
            # no cell selected: use the full field of view
            ly,lx = mgt.shape
            xy = [0,ly,0,lx]
        
        coords.append(xy)


## Plot the pixel trajectories 

In [None]:
# Re-run both models on the full images to reproduce the exact results from the main code.
# (Crop later; for very small crops the results on a cropped image can differ noticeably.)
chans = [0,0] 
names = ['Cellpose','Omnipose']
suffixes = ['_CP','_OM']
# imglist = [imgs[k] for k in img_index]
imglist = []

# As it turns out, the sample I ultimately selected needs to be gamma-normalized
# to reproduce the exact flow field in the paper. Omnipose is not sensitive to this since I
# have a gamma augmentation, but Cellpose is, and the flow changes a lot based on that.
bg = 0.5
for k in img_index:
    mgt = masks_gt[k].copy()
    p = transforms.normalize99(imgs[k],omni=True)
    # the exponent maps the mean of the eroded background to bg
    img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))
    imglist.append(img0)
# imglist = [imgs[k] for k in [img_index[4]]]

J = len(names)
nimg = len(imgs)
# independent per-model slots ([[]]*J would repeat one shared inner list J times)
masks,flows,styles,d = ([[] for _ in range(J)] for _ in range(4))
# NOTE(review): assumes model[0] is the Cellpose model and model[1] the Omnipose model,
# matching the order of `names` above - confirm against the cell that builds `model`
for j in range(J):
# for j in [1]:
    masks[j], flows[j], styles[j] = model[j].eval(imglist,channels=chans, mask_threshold=-1, diameter=0,cluster=j,
                                                  flow_threshold=0, omni=j, calc_trace=True, min_size=0, tile=False,
                                                  transparency=transparency, verbose=verbose)

In [None]:
names

In [None]:
# show the results
%matplotlib inline
from cellpose import plot
# Show the segmentation of model j for every image in imglist.
j = 0 # index into `masks` / `flows`, i.e. which model to display
nimg = len(imgs)
for idx,im in enumerate(imglist):
    maski = masks[j][idx]
    flowi = flows[j][idx][0]
    fig = plt.figure(figsize=(12,5))
    plot.show_segmentation(fig, im, maski, flowi, channels=chans, omni=True, bg_color=0)
    plt.tight_layout()
    plt.show()

## Figure 2b

This generates our figure demonstrating how Cellpose over-segments cells. `do_flows` toggles whether or not to do the time-consuming flow line vector graphic showing how boundary pixels coalesce. `isolated` toggles whether or not the panels generate that information only for selected cells or for all cells in the FoV. 

In [None]:
import edt
from matplotlib import rc
from omnipose import utils
import ncolor
import cmasher as cmr
from skimage import filters

# Output sub-folder (joined onto save0 further down) and toggles for the Figure 2b panels.
subdir = 'error_examples'
# io.check_dir(subdir)

darkmode=0 # selects the colour of the flow lines and the file name suffix
do_flows = 1 # draw the (slow) vector graphic of the boundary pixel trajectories
isolated = 0 # also export a flow image that is greyed out outside the ground-truth cells
recompute = 0 # re-run the model on the cropped image instead of cropping the full-image result

# font type 42 embeds TrueType fonts in the PDF / PS output
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
import matplotlib_inline
# matplotlib_inline.backend_inline.set_matplotlib_formats('retina', 'png')
mpl.rcParams['figure.dpi'] = 72
%matplotlib inline
# Figure scale and colormaps for the Figure 2b panels.
A = 1 # scale factor: figure inches per image pixel, in units of 1/dpi
px = A/plt.rcParams['figure.dpi']  # pixel in inches
# cmap = mpl.cm.get_cmap('plasma')
cmap = cmr.get_sub_cmap('plasma', 0, 0.95) # used for the cell probability image
cmap2 = cmr.get_sub_cmap('gray', 0, 1) # used for the boundary image



# For every model and every selected example: crop the full-image results to the
# example window, export the individual panels (phase, flow, cell probability,
# boundary, masks, outlines), write the matched-IoU statistics and, when do_flows
# is set, draw the trajectories of the boundary pixels.
for j in range(J): # select from models
# for j in [-1]:
    suffix = suffixes[j]
    omni = 'OM' in suffix
    print(suffix,omni)
    # for k in range(n_exmpl): # select from images 
    for k in [3,4]:

        xy = coords[k].copy() # crop window [y0,y1,x0,x1]
        ly,lx = xy[1]-xy[0],xy[3]-xy[2]

        # outline thickness
        if px*lx<2:
            mode = 'inner'
        else:
            mode = 'thick'

        mgt = masks_gt[img_index[k]].copy()
        img0 = imglist[k]

        mpred = masks[j][k].copy()
        dP_pred = flows[j][k][1].copy()
        cellprob = flows[j][k][2].copy()
        # only the Omnipose output is read for a boundary map; a sigmoid maps it to (0,1)
        bd = flows[j][k][-2].copy() if omni else np.zeros_like(cellprob)
        bd = 1/(1+np.exp(-bd))
        plt.imshow(utils.rescale(bd))
        plt.show()
        # pixel trajectories, indexed below as [axis (0: y, 1: x), trajectory, step]
        tr = flows[j][k][-1][0].copy()

        # make vertical
        if ly<lx:
            print('transposing')
            mgt = np.transpose(mgt)
            img0 =  np.transpose(img0)
            # Transposing the image exchanges the y and x axes, so the two flow
            # components trade places as well. The previous code applied a rotation
            # with theta = 90 passed to np.sin / np.cos, which take radians, and so
            # rotated the flow by an unintended angle instead of swapping components.
            dP_pred = np.transpose(dP_pred,(0,2,1))[::-1].copy()
            tr = np.stack((tr[1],tr[0]))
            cellprob = np.transpose(cellprob)
            bd = np.transpose(bd)
            mpred = np.transpose(mpred)
            xy = [xy[2],xy[3],xy[0],xy[1]]

        #crop
        mpred_full = mpred.copy()
        mpred = mpred[xy[0]:xy[1],xy[2]:xy[3]]
        mgt = mgt[xy[0]:xy[1],xy[2]:xy[3]]
        cellprob = cellprob[xy[0]:xy[1],xy[2]:xy[3]]
        bd = bd[xy[0]:xy[1],xy[2]:xy[3]]
        dP_pred = dP_pred[:,xy[0]:xy[1],xy[2]:xy[3]]
        img0 = img0[xy[0]:xy[1],xy[2]:xy[3]]
        # keep the trajectories that start inside the crop window and shift them to crop coordinates
        y_start, x_start = tr[0,:,0], tr[1,:,0]
        inside = (x_start<=xy[3]) & (x_start>=xy[2]) & (y_start>=xy[0]) & (y_start<=xy[1])
        match = np.nonzero(inside)[0]
        tr = np.stack((tr[0,match,:]-xy[0], tr[1,match,:]-xy[2]))
        # for some reason, the ones I selected came from running the algorithm on
        # the cropped field of view. This actually changes the output. 

        if recompute:
            mpred,f,s = model[j].eval(img0,channels=[0,0], mask_threshold=-1, diameter=0,cluster=j,
                                      flow_threshold=0, omni=j, calc_trace=True, min_size=0, tile=False,
                                      transparency=transparency, verbose=verbose)

            dP_pred = f[1].copy()
            cellprob = f[2].copy()
            bd = f[-2].copy() if omni else np.zeros_like(cellprob)
            bd = 1/(1+np.exp(-bd))
            tr = f[-1][0].copy()

        flow_pred = plot.dx_to_circ(dP_pred,transparency=transparency)
        l = cell_list[k] # cell label list

        bini = mgt>0
        bin0 = np.stack((bini,bini,bini),axis=2)

        savedir = os.path.join(save0, subdir, img_list[k] + '_cell_number_'+ str(l))
        io.check_dir(savedir)
        path = os.path.join(savedir,'perim_flows' + suffix)
        print(path)

        skimage.io.imsave(os.path.join(savedir,'phase'+ext),np.uint8(img0*255))
        skimage.io.imsave(os.path.join(savedir,'flow'+suffix+ext),np.uint8(flow_pred))
        # cell probability: colour from the colormap, alpha from the rescaled probability
        pic = cmap(utils.rescale(cellprob))
        pic[:,:,-1] = utils.rescale(cellprob)
        skimage.io.imsave(os.path.join(savedir,'cellprob'+suffix+ext),((pic)*255).astype(np.uint8))

        pic = cmap2(utils.normalize99(bd))
        pic[:,:,-1] = utils.rescale(bd)
        skimage.io.imsave(os.path.join(savedir,'boundary'+suffix+ext),((pic)*255).astype(np.uint8))

        # thresholded cell probability, written into the red channel
        mask_threshold = -1
        if omni:
            mask0 = filters.apply_hysteresis_threshold(cellprob, mask_threshold-1, mask_threshold)
        else:
            mask0 = cellprob>mask_threshold
        b = np.zeros_like(mask0)
        pic = np.stack((mask0,b,b),axis=-1)
        skimage.io.imsave(os.path.join(savedir,'cellprob_thresh'+suffix+ext),((pic)*255).astype(np.uint8))

        if isolated:
            # grey out the flow outside the ground-truth cells
            # NOTE(review): flow_pred is saved above without rescaling, so it is presumably
            # already in the 0-255 range and `f*255` below would overflow - confirm before enabling
            mag = transforms.normalize99(np.sqrt(np.sum(dP_pred**2,axis=0)),omni=True)
            f = flow_pred.copy()
            flow_gray = 0.2125*f[:,:,0] + 0.7154*f[:,:,1] + 0.0721*f[:,:,2]
            m = bini==0
            f[m] = np.stack([flow_gray,flow_gray,flow_gray,mag],axis=-1)[m]
            skimage.io.imsave(os.path.join(savedir,'flow_gray'+suffix+ext),np.uint8(f*255))

        # grey-level rendering of the ground-truth masks: background stays 0 and the
        # labels are spread over evenly spaced values between .25*A and .55*A
        m = mgt.copy()
        u = np.unique(mgt)
        U = len(u)
        A = 100
        v = [0]+list(np.linspace(.25,.55,U-1)*A)
        res = dict(zip(u, v))
        m = fastremap.remap(m,res,preserve_missing_labels=False, in_place=True)
        ncolor_gt = utils.rescale(ncolor.label(m.copy()))
        ncolor_gray = np.stack([utils.rescale(ncolor_gt)]*3,axis=-1)
        skimage.io.imsave(os.path.join(savedir,'ncolor_gray_masks'+ext),np.uint8(ncolor_gray*255))

        ncolor_gray = np.stack([1-ncolor_gt]*3,axis=-1)
        skimage.io.imsave(os.path.join(savedir,'ncolor_gray_masks_inv'+ext),np.uint8(ncolor_gray*255))

        ncolor_pred = utils.rescale(ncolor.label(mpred.copy()))
        skimage.io.imsave(os.path.join(savedir,'ncolor_pred'+ext),np.uint8(ncolor_pred*255))

        outli = plot.outline_view(img0,mpred,color=cmap(.85)[:3],mode=mode)
        skimage.io.imsave(os.path.join(savedir,'outline_view_gold'+suffix+ext),np.uint8(outli))
        outli = plot.outline_view(img0,mpred,mode=mode)
        skimage.io.imsave(os.path.join(savedir,'outline_view_red'+suffix+ext),np.uint8(outli))

        # one-to-one matching of ground-truth and predicted masks (same cost as the main
        # analysis, with threshold 0), then the IoU of every matched pair
        mgt, remap = fastremap.renumber(mgt)
        iou = metrics._intersection_over_union(mgt, mpred)
        th = 0
        n_min = min(iou.shape[0], iou.shape[1])
        costs = -(iou >= th).astype(float) - iou / (2*n_min)
        true_ind, pred_ind = linear_sum_assignment(costs)
        miou = iou[true_ind, pred_ind]

        with open(os.path.join(savedir,'MMiou_and_area'+suffix+'.txt'), "w") as text_file:
            text = f"Mean Matched IoU: {np.mean(miou)}\nAreas: {[np.sum(mgt==lab) for lab in np.unique(mgt)]}\nThis area: {np.sum(bin0)}"
            print(text, file=text_file)

        if do_flows:
            # boundary pixels of the predicted masks: distance transform equal to 1
            dists = edt.edt(mpred)
            bd = dists==1

            # trajectories that start on a boundary pixel (these are drawn as lines)
            Y,X = np.nonzero(bd)
            a = .5
            match0 = [np.any(np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a)) for i in range(tr.shape[1])]
            select_inds = np.argwhere(match0).flatten()

            # trajectories that start on any mask pixel (their end points are drawn as dots)
            Y,X = np.nonzero(mpred)
            a = .5
            match2 = [np.any(np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a)) for i in range(tr.shape[1])]
            select_inds2 = np.argwhere(match2).flatten()

            lx = mpred.shape[1]
            ly = mpred.shape[0]

            # NOTE(review): figsize is (width, height); passing (ly*px, lx*px) looks swapped
            # relative to the axis limits set below - left unchanged, confirm
            fig,ax = plt.subplots(figsize=(ly*px,lx*px))
            ax.set_aspect('equal')
            ax.axis('off')
            ax.set_position([0, 0, 1, 1])

            # line colour and file name suffix are chosen before the loop so that `s2`
            # is defined for savefig even when no trajectory is selected
            if darkmode:
                c = [1,1,1,.5]
                s2 = '_darkmode'
            else:
                c = [0,0,0,.25]
                s2 = '_lightmode'

            for i in select_inds:
                xs = tr[1,i,:]
                ys = tr[0,i,:]
                ax.plot(xs,ys,c=c,solid_capstyle='round',linewidth=1/2,zorder=1)

            for i in select_inds2:
                ax.scatter(tr[1,i,-1],tr[0,i,-1],marker='.',s=px*lx,edgecolor=None,facecolor='r',zorder=2)

            ax.set_xlim([0,lx])
            ax.set_ylim([ly,0]) # image convention: y increases downwards
            ax.patch.set_alpha(0.)
            fig.patch.set_facecolor(None)

            plt.show()

            fig.savefig(path+s2+'.pdf',bbox_inches='tight',transparent=True,pad_inches=0)
            fig.savefig(path+s2+'.png',bbox_inches='tight',transparent=True,pad_inches=0,dpi=1000)
        

In [None]:
u

In [None]:
%matplotlib inline
# Re-run one model on a single example image and show the segmentation.
k=3 # position in imglist / img_index
j = 0 # index into `model`
# xy = [xy[2],xy[3],xy[0],xy[1]]
bg = 0.5
mgt = masks_gt[img_index[k]].copy()
p = utils.rescale(imglist[k])
# p = imglist[k]
# gamma normalization: the exponent maps the mean of the eroded background to bg
# (imglist[k] was already gamma-normalized when it was built, so this is applied a second time)
img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))
im = (img0)
# im = p
m,f,s = model[j].eval(im,channels=[0,0], mask_threshold=-1, diameter=0,cluster=j,
                                                   flow_threshold=0, omni=j, calc_trace=True, min_size=0, tile=False,
                                                   transparency=transparency, verbose=verbose)

maski = m
flowi = f[0]
fig = plt.figure(figsize=(12,12))
plot.show_segmentation(fig, im, maski, flowi, channels=chans, omni=True, bg_color=0)
plt.tight_layout()
plt.show()

## Export Error examples
From `cell_errors` we can look at tons of examples. The current code simply exports every single instance of an error. We see from the above histogram that Omnipose gives barely any cells with more than 1 error (I think these are issues with cell division). The following improvements are still to be made:

- Enforce a minimum crop size 
- Sort by number of errors 
- Dump images into a PDF with info 

In [None]:
# Imports and settings for exporting per-cell error examples (optionally into a PDF).
import os, datetime, gc, warnings, glob
# `scipy.ndimage.morphology` is a deprecated namespace; import from scipy.ndimage directly.
from scipy.ndimage import binary_erosion
# NOTE(review): this rebinds `utils` (imported from omnipose at the top of the notebook)
# and `plot` to the cellpose modules for every cell executed afterwards.
from cellpose import utils, transforms, plot
import cv2
from cellpose.io import imsave
from PIL import Image, ImageFont, ImageDraw, ImageOps
from matplotlib.colors import ListedColormap
from fpdf import FPDF

plt.style.use('dark_background')
# pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('viridis')
# Discrete colormap built from the values of sinebow(5); used to colour the n-colour labels.
cm2 = ListedColormap(list(sinebow(5).values()))

# NOTE: this overwrites the dataset `basedir` defined near the top of the notebook.
basedir = '/home/kcutler/DataDrive/omnipose_paper/Comparison Examples/Per_Cell_GT_Comparison/'
io.check_dir(basedir)

cutoff=1   # export cells whose error count is >= cutoff
pad = 10   # pixels of context added around each exported cell crop
N = len(masks)        # not used by the export loop below
nimg = len(masks_gt)  # number of images iterated by the export loop
bkct = 500            # not used by the export loop below
bg = 0.5   # target value for the mean background intensity after the gamma adjustment


# PDF layout parameters (only relevant when do_pdf is truthy)
txtpad = 10       # vertical space reserved for the caption text
yoffset = [0,0]   # current y position of the left and right PDF columns
buffer = 30       # extra bottom margin checked before starting a new page
do_pdf = 0        # 1: also collect the exported PNGs into a PDF
clean = 1         # 1: regenerate PNGs even if they already exist on disk
# for k in range(nimg):
# for k in [0]:
# Export one PNG per erroneous ground-truth cell and, if do_pdf is set, lay the
# PNGs out in a two-column PDF. j selects the model, k the image, l the cell.
for j in [0]:
    if do_pdf:
        # New PDF (units: points) with white text on a page painted black.
        pdf = FPDF(unit='pt')
        pdf.add_page()
        pdf.set_font('Arial', 'B', 8)
        pdf.set_text_color(r=255, g=255, b=255)
        pdf.set_fill_color(0,0,0)
        ph = pdf.h
        pw = pdf.w
        # Page breaks are handled manually below.
        # NOTE(review): presumably this keeps fpdf from breaking pages on its own - confirm.
        pdf.page_break_trigger = ph
        pdf.rect(0,0,pw,ph,'F')
        pdf.set_xy(0,0)
        yoffset = [0,0]  # y position of the left/right column; reset for each PDF
        ind = 0          # column the next image goes into (0 = left, 1 = right)
        cnt = 0          # number of images placed on the current page

    for k in range(nimg):
        file = img_names[k]
        basename = os.path.splitext(os.path.split(file)[-1])[0]

        err = cell_errors[j][k]
        err_indexes = np.nonzero(err>=cutoff)[0]
        num_errors = err_indexes.size
        if num_errors == 0:
            continue

        for l in err_indexes:
            name = basename + '_' + names[j] + '_cell_number_'+ str(l) + '.png'
            savepath = os.path.join(basedir, name)
            print(j,k,l,name)

            if not os.path.isfile(savepath) or clean:
                mgt = masks_gt_clean[k]
                p = transforms.normalize99(imgs[k],omni=True)
                # Gamma adjustment so that (mean eroded-background intensity)**gamma == bg.
                img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))

                # Padded bounding box of ground-truth label l+1, clipped to the image.
                # The stop indices are exclusive slice ends, hence the +1: without it
                # the last row/column of a cell touching the image edge is cut off.
                y,x = np.nonzero(mgt==l+1)
                ny,nx = mgt.shape

                y0 = max(0,min(y)-pad)
                y1 = min(ny,max(y)+pad+1)
                x0 = max(0,min(x)-pad)
                x1 = min(nx,max(x)+pad+1)

                p = img0[y0:y1,x0:x1]
                img0 = np.stack((p,p,p),axis=2)

                # Top row: image | outlines of masks[j][k] | normalized flows[j][k] crop.
                maski = masks[j][k][y0:y1,x0:x1]
                flowi = transforms.normalize99(flows[j][k][y0:y1,x0:x1],omni=True)
                outli = transforms.normalize99(plot.outline_view(img0,maski),omni=True)
                top = np.hstack((img0,outli,flowi))

                # Bottom row: n-colour ground-truth labels | ground-truth outlines |
                # binary mask of the cell in question.
                maski = mgt[y0:y1,x0:x1]
                bini = maski==l+1
                bini = np.stack((bini,bini,bini),axis=2)
                outli = transforms.normalize99(plot.outline_view(img0,maski),omni=True)
                bottom = np.hstack((cm2(utils.ncolor.label(maski))[:,:,:3],outli,bini))

                img = np.vstack((top,bottom))

                io.imsave(savepath,np.uint8(img*255))

            if do_pdf:
                # Only the pixel size is needed; close the file handle right away.
                with Image.open(savepath) as im:
                    im_w,im_h = im.size

                # Height of the image once scaled to half the page width.
                h = ((pw/2)/im_w)*im_h
                if yoffset[ind]+h+txtpad+buffer>=ph:
                    # Column is full: start a fresh black page and reset the layout.
                    pdf.add_page()
                    pdf.rect(0,0,pw,ph,'F')
                    pdf.set_xy(0,0)
                    yoffset = [0,0]
                    ind = 0
                    cnt = 0
                elif cnt==0 or cnt==1:
                    # First image of each column sits at the top without caption spacing.
                    pdf.set_xy(ind*pw/2,yoffset[ind])
                else:
                    pdf.set_xy(ind*pw/2,yoffset[ind]+txtpad)

                if h<ph:
                    pdf.image(savepath,w=pw/2)
                else:
                    # Too tall for a page at half width: constrain by height instead.
                    pdf.image(savepath,h=ph-2*txtpad)
                pdf.cell(30, 10, txt=name)
                yoffset[ind] = pdf.get_y()
                ind = (ind+1)%2
                cnt += 1
    if do_pdf:
        pdf.output('/home/kcutler/DataDrive/'+names[j]+'_errors.pdf','F')
                

# Figure 5f
This section is specifically for extracting example images. 

In [None]:
from scipy.ndimage import binary_dilation
def get_bbxs(masks,seeds,pad=10):
    '''
    Get padded bounding boxes covering the microcolonies that contain the seed cells.

    A microcolony is a connected component (full connectivity, diagonals included)
    of the foreground ``mask > 0``. For each mask, every component that overlaps
    the seed cells (dilated by one step) is included in the box.

    Parameters
    ----------
    masks : sequence of 2D integer label arrays, 0 being background.
    seeds : sequence with one entry per mask. An entry is a single label, a
        collection of labels (list, tuple or array), or None to select the
        whole image.
    pad : int
        Pixels added on every side of the box, clipped to the image extent.

    Returns
    -------
    list of ``[y0, x0, y1, x1]`` boxes, one per mask. ``y1``/``x1`` are exclusive
    stop indices in both the seeded and the ``None`` case, so ``mask[y0:y1, x0:x1]``
    contains the whole selection.

    Raises
    ------
    ValueError
        If none of the requested seed labels occur in the corresponding mask.
    '''
    # Local import so the function does not rely on `skimage.measure`, which the
    # notebook never imports explicitly.
    from scipy.ndimage import binary_dilation, label

    full_connectivity = np.ones((3, 3), dtype=bool)
    bboxes = []
    for k, mgt in enumerate(masks):
        seed = seeds[k]
        ly, lx = mgt.shape
        if seed is None:
            bboxes.append([0, 0, ly, lx])
            continue

        # np.isin handles a single label as well as any collection of labels.
        hits = np.isin(mgt, seed)
        if not hits.any():
            raise ValueError(f'seed label(s) {seed!r} not found in mask {k}')

        # Select every microcolony touched by the (slightly dilated) seed cells.
        dilated = binary_dilation(hits, iterations=1)
        components, _ = label(mgt > 0, structure=full_connectivity)
        touched = np.unique(components[dilated])
        binmask = hits | np.isin(components, touched[touched > 0])

        y, x = np.nonzero(binmask)
        y0 = max(0, int(y.min()) - pad)
        x0 = max(0, int(x.min()) - pad)
        # +1 turns the inclusive maximum index into an exclusive stop, matching
        # the [0, 0, ly, lx] convention of the whole-image case.
        y1 = min(ly, int(y.max()) + pad + 1)
        x1 = min(lx, int(x.max()) + pad + 1)
        bboxes.append([y0, x0, y1, x1])
    return bboxes

# Per-dataset settings for the example figure:
#   suffixes   - one filename suffix per model whose output is exported
#   subsets    - substrings identifying the example images within mask_names
#   cell_seeds - for each subset, the label(s) of the seed cell(s) passed to get_bbxs
# NOTE(review): only 'worm' and 'cyto2' are handled; for any other dataset these names
# are left undefined (or stale from a previous run) - confirm this cell is only run for those two.
if dataset=='worm':
    suffixes = ['_CP','_MX','_OP','_OP_bact_worm']
    subsets = ['wormpose_39','bbc010_44']
    cell_seeds = [24,11]

elif dataset=='cyto2':
    suffixes = ['_CP','_MX','_OP']
    subsets = ['001_img']
    cell_seeds = [[34,36]]

# For each subset name, collect the positions in mask_names whose path contains it ...
indices = [np.array([i for i,s in enumerate(mask_names) if (name in s)]).astype(int) for name in subsets]
# ... and keep only the first match (raises IndexError if a subset matches nothing).
indices = [i[0] for i in indices]

# I picked the worm labels off of omnipose masks,
# cyto2 off of the ground truth
if dataset=='worm':
    # masks[-1] is the last entry of `masks`
    # NOTE(review): presumably the Omnipose model's output, per the comment above - confirm.
    masks_s = [masks[-1][i] for i in indices]
else:
    masks_s = [masks_gt[i] for i in indices]

# One bounding box per selected image; the bare expression displays them in the notebook.
bboxes = get_bbxs(masks_s,cell_seeds)
bboxes

In [None]:
import edt
from matplotlib import rc
from scipy.ndimage import binary_erosion, binary_dilation, zoom
import omnipose
from scipy.optimize import linear_sum_assignment

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
rc('font',**{'family':'sans-serif','sans-serif':['Arial']})
import matplotlib_inline
# matplotlib_inline.backend_inline.set_matplotlib_formats('retina', 'png')
mpl.rcParams['figure.dpi'] = 72
%matplotlib inline
A = 1
px = A/plt.rcParams['figure.dpi']  # pixel in inches
cmap = mpl.cm.get_cmap('plasma')
outline_col = cmap(0.85)[:3]
ext = '.png'
bg = 0.5
pad = 10

def getname(path,suffix=''):
    """Return the file name of `path` without its last extension and with every
    occurrence of `suffix` removed."""
    filename = Path(path).name
    stem, _extension = os.path.splitext(filename)
    return stem.replace(suffix,'')
# Export, for every model (j) and every selected example image (k), the cropped
# outlines, flows, distance field, boundary field, grayscale masks, image, and the
# mean matched IoU against the ground truth.
field_names = [getname(i) for i in img_names]

savedir = os.path.join(save0,'examples')
io.check_dir(savedir)
J = len(suffixes)
K = len(masks_s)
for j in range(J): # select from models
    suffix = suffixes[j]
    for i,k in enumerate(indices):
        xy = bboxes[i]

        # Boxes from get_bbxs are laid out as [y0, x0, y1, x1] (not [y0, y1, x0, x1]),
        # so the height is xy[2]-xy[0] and the width is xy[3]-xy[1].
        ly,lx = xy[2]-xy[0],xy[3]-xy[1]
        # outline thickness: narrow crops get 'inner' outlines, wider ones 'thick'
        val = px*lx
        print(val,img_names[k])
        if val<3:
            mode = 'inner'
        else:
            mode = 'thick'

        mgt = masks_gt[k]
        p = omnipose.utils.rescale(imglist[k])
        slc0 = omnipose.utils.bbox_to_slice(xy,mgt.shape)
        # Ellipsis lets the same spatial crop apply to arrays with leading axes.
        slc = (Ellipsis,)+slc0

        img0 = p[slc]
        mpred = masks[j][k].copy()[slc]

        # save yellow/gold outline plots
        for c,cn in zip([outline_col,[1,0,0]],['gold','red']):
            pic = plot.outline_view(img0,mpred,color=c,mode=mode)
            skimage.io.imsave(os.path.join(savedir,field_names[k]+'_outline_'+cn+suffix+ext),pic)

        crop_img = img0
        crop_masks = mpred
        flow_rgb = flows[j][k][0]
        # Resize the network outputs to the ground-truth mask size before cropping.
        scale = np.array(mgt.shape[0:2])/np.array(flow_rgb.shape[0:2])
        crop_flow = zoom(flow_rgb, tuple(scale)+(1,), order=1)[slc0]

        # save RGB flows
        savepath = os.path.join(savedir,field_names[k]+'_crop_flows'+suffix+ext)
        skimage.io.imsave(savepath,np.uint8(crop_flow))

        # save distance field, coloured with 'plasma'; alpha is 1 inside masks, 0 outside
        savepath = os.path.join(savedir,field_names[k]+'_crop_dist'+suffix+ext)
        dist = zoom(flows[j][k][2], tuple(scale), order=1)[slc]
        dist = omnipose.utils.rescale(dist)
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        cmap = plt.get_cmap('plasma')
        pic = cmap(dist)
        pic[:,:,-1] = crop_masks>0
        skimage.io.imsave(savepath,np.uint8(pic*(2**8-1)))

        # save the boundary, only for omnipose (entry is None otherwise)
        bd = flows[j][k][4]
        if bd is not None:
            savepath = os.path.join(savedir,field_names[k]+'_crop_bd'+suffix+ext)
            bd = zoom(bd, tuple(scale), order=1)[slc]
            bd = omnipose.utils.rescale(bd)
            cmap = plt.get_cmap('viridis')
            pic = cmap(bd)
            pic[:,:,-1] = crop_masks>0
            skimage.io.imsave(savepath,np.uint8(pic*(2**8-1)))

        # save a grayscale mask version for adobe illustator vectorization
        ncl = ncolor.label(crop_masks)
        grey_n = np.stack([1-omnipose.utils.rescale(ncl)]*3,axis=-1)
        savepath = os.path.join(savedir,field_names[k]+'_masks_gray'+suffix+ext)
        io.imsave(savepath,np.uint8(grey_n*(2**8-1)))

        # save the cropped image, RGB uint8 is not interpolated in illustrator ;)
        # The file name has no model suffix, so it is rewritten once per model j.
        img0 = transforms.move_min_dim(img0)
        if img0.shape[-1] < 3 or img0.ndim < 3:
            img0 = plot.image_to_rgb(img0, channels=chans[0] if len(chans)>1 else chans, omni=omni)
        savepath = os.path.join(savedir,field_names[k]+'_crop_img'+ext)
        io.imsave(savepath,np.uint8(img0))

        # Mean matched IoU between the cropped ground truth and prediction:
        # labels are renumbered, then pairs are matched with the Hungarian algorithm.
        # The cost prefers pairs with IoU >= th and breaks ties by higher IoU.
        # NOTE(review): if the IoU matrix includes a background row/column (label 0),
        # it takes part in the matching here - confirm whether iou[1:, 1:] is intended.
        mgt, remap = fastremap.renumber(mgt[slc])
        mpred, remap = fastremap.renumber(mpred)
        iou = metrics._intersection_over_union(mgt, mpred)
        th = 0
        n_min = min(iou.shape[0], iou.shape[1])
        costs = -(iou >= th).astype(float) - iou / (2*n_min)
        true_ind, pred_ind = linear_sum_assignment(costs)
        miou = iou[true_ind, pred_ind]
        with open(os.path.join(savedir,'MMiou'+field_names[k]+suffix+'.txt'), "w") as text_file:
            print(f"Mean Matched IoU: {np.mean(miou)}", file=text_file)