## Figure 6, *C. elegans* and cyto2 datasets 

Set the `selection` index in the first code cell to choose which test set to analyze.

In [None]:
%load_ext autoreload
%autoreload 2

# `selection` picks the entry of `basedir` (the test-set directory) analyzed below.
# NOTE(review): `dataset` is not index-aligned with `basedir` (dataset[0] is
# 'bact_phase' but basedir[0] is the worm test set) and is not read in this
# section — confirm which list `selection` is meant to index.
selection = 0
dataset = ['bact_phase','bact_fluor','cyto2','worm']
# Test-set root directories; only the first entry is filled in.
basedir = ['/home/kcutler/DataDrive/omnipose_train/worm_combined/test',
           '',
           '']

from pathlib import Path
from cellpose import plot, models, io, dynamics
from omnipose import utils
import skimage.io
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import matplotlib as mpl
%matplotlib inline
# High-resolution inline figures.
mpl.rcParams['figure.dpi'] = 300


## Load ground truth

In [None]:
# Test images and their ground-truth label masks; mask files carry the '_masks' suffix.
mask_filter = '_masks'
img_names = io.get_image_files(basedir[selection], mask_filter, look_one_level_down=True)
mask_names = io.get_label_files(img_names, mask_filter)
imgs = list(map(skimage.io.imread, img_names))
masks_gt = list(map(skimage.io.imread, mask_names))

In [None]:
from omnipose.utils import format_labels
# Total number of ground-truth cells over all test images (label 0 = background is excluded).
cnt = sum(np.count_nonzero(np.unique(format_labels(m)) > 0) for m in masks_gt)
print(cnt)

### Initialize models
There are three models we are comparing:
1. Omnipose model on the worm dataset only
2. Cellpose model on the worm dataset only
3. Omnipose model on the combined worm and bacteria dataset

In [None]:
from cellpose import core, models

# Use the GPU if cellpose reports one is available.
use_GPU = core.use_gpu()
print('>>> GPU activated? %d'%use_GPU)

# pure worm omnipose and cellpose models as well as a combined bact+worm omnipose model 
# model[0] = 'worm_cp' (Cellpose), model[1] = 'worm_omni' (Omnipose),
# model[2] = 'worm_bact_omni' (Omnipose, worm + bacteria).
model = [models.CellposeModel(gpu=use_GPU, model_type='worm_cp'),
         models.CellposeModel(gpu=use_GPU, model_type='worm_omni'),
         models.CellposeModel(gpu=use_GPU, model_type='worm_bact_omni')]

# Channel setting passed to every eval call ([0,0]: presumably grayscale input — confirm).
chans = [0,0] 

In [None]:
# Display labels for the four evaluation conditions (index j of masks/flows/ap/...):
# 0 = Cellpose model, Cellpose reconstruction; 1 = Cellpose model with omni=True ('Mixed');
# 2 = worm-only Omnipose model; 3 = worm+bacteria Omnipose model.
names = ['Cellpose','Mixed','Omnipose','Omnipose_worm+bact']

In [None]:
# Shallow copy of the image list; n = number of conditions, N = number of images.
imglist = list(imgs)
n, N = len(names), len(imgs)
# Four independent outer lists, each with one slot per condition.
masks, flows, styles, d = ([[]] * n for _ in range(4))

### Run models
There are three models but four conditions, described below:

In [None]:
clean = 0 # toggle on to re-run analysis 
# With clean = 0 nothing is recomputed here; masks/flows are read back from disk
# in the "Read in results" cell.
if clean:
    # All four runs use rescale=None, flow_threshold=0, dist_threshold=-1, tile=False.
    # cellpose, no omni reconstruction 
    masks[0], flows[0], styles[0] = model[0].eval(imglist,channels=chans,rescale=None,dist_threshold=-1,flow_threshold=0,omni=False,resample=False,tile=False)
    # original + omni=True ('Mixed'): same Cellpose model (model[0]), omni reconstruction
    masks[1], flows[1], styles[1] = model[0].eval(imglist,channels=chans,rescale=None,dist_threshold=-1,flow_threshold=0,omni=True,resample=False,tile=False,cluster=True)
    # omnipose
    masks[2], flows[2], styles[2] = model[1].eval(imglist,channels=chans,rescale=None,dist_threshold=-1,flow_threshold=0,omni=True,resample=True,tile=False,cluster=True)
    #worm+bact
    masks[3], flows[3], styles[3] = model[2].eval(imglist,channels=chans,rescale=None,dist_threshold=-1,flow_threshold=0,omni=True,resample=True,tile=False,cluster=True)

### Save results

In [None]:
# Root output directory: one subfolder per condition (save0+names[j]) plus the .npz metric files.
save0 = '/home/kcutler/DataDrive/omnipose_paper/Comparison_Examples_C_elegans_combined/'

In [None]:
# Write masks, flows and outlines of every condition into save0/<name>/ (only when re-running).
if clean:
    io.check_dir(save0)
    for j in range(n):
    # for j in [-1]:
        savedir = save0+names[j]
        io.check_dir(savedir)
        io.save_masks(imgs, masks[j], flows[j], img_names,
                      save_flows=True, 
                      save_outlines=True, 
                      savedir=savedir,
                      in_folders=True,
                      save_txt=False)

### Read in results

In [None]:
import os
# Reload each condition's saved masks and flows. Masks pass through clean_boundary and
# format_labels on load, so every downstream metric sees cleaned predictions.
# NOTE(review): `mask_names`, `flow_names` and `savedir` leak out of this loop and end up
# pointing at the last condition's files (overwriting the GT `mask_names` from the loading
# cell). Later cells filter `mask_names` by substring, which only relies on the basenames.
for j in range(n):
    savedir = save0+names[j]
    mask_names = [os.path.join(savedir+'/masks',os.path.splitext(os.path.basename(name))[0]+'_cp_masks.png') for name in img_names]
    flow_names = [os.path.join(savedir+'/flows',os.path.splitext(os.path.basename(name))[0]+'_flows.tif') for name in img_names]
    masks[j] = [utils.format_labels(utils.clean_boundary(skimage.io.imread(f))) for f in mask_names]
    flows[j] = [skimage.io.imread(f) for f in flow_names]

### Analyze performance 
We use Cellpose's built-in metrics to compute the Jaccard index as a function of the IoU matching threshold. Note that Cellpose incorrectly calls this quantity the average precision, hence the label "ap" below. Before computing these performance metrics, we remove boundary labels (those cells within 3px of the boundary and below 30 square pixels).

In [None]:
from cellpose import metrics
from skimage import measure
import fastremap

# 100 IoU matching thresholds, evenly spaced from 0.5 to 1 (x-axis of the Jaccard-vs-IoU plot).
threshold=np.linspace(0.5,1,100)

In [None]:
# Per-condition accumulators, one slot per entry of `masks`: ap = cellpose's "average
# precision" (Jaccard index) per image, IoU = intersection-over-union matrices,
# OvR = overlap ratios. Boundary cells are ignored because both GT and predictions
# pass through clean_boundary.
ap, tp, fp, fn, IoU, OvR, pred_areas = ([[]] * len(masks) for _ in range(7))

In [None]:

# Clean the ground truth once: drop boundary labels (clean_boundary), relabel
# (format_labels), and record the area of every remaining GT cell per image.
nimg = len(masks_gt)
cell_areas = [[]] * nimg
masks_gt_clean = [None]*nimg
# The inner list is shared by every slot, but slots are only ever reassigned
# (never mutated in place), so the aliasing is harmless here.
masks_pred_clean = [[None]*nimg]*len(masks)
# remapping = [[]] * nimg
for j in range(nimg):
    mgt = utils.format_labels(utils.clean_boundary(masks_gt[j]))
    masks_gt_clean[j] = mgt
    # regionprops lists regions in label order, so areas[i] is assumed to belong to
    # label i+1 (requires consecutive labels from format_labels — TODO confirm)
    regions = measure.regionprops(mgt)
    areas = np.array([reg.area for reg in regions])
    cell_areas[j] =  areas

In [None]:
# go over each model
for j,masks_pred in enumerate(masks):
#     masks_pred = map(list,zip(*[ utils.format_labels(utils.clean_boundary(msk)) for msk in masks_pred]))
# just apply cleanup to the masks when reading them in 
    api,tpi,fpi,fni = metrics.average_precision(masks_gt_clean,masks_pred,threshold=threshold)
    ap[j] = ap[j]+[api]
    tp[j] = tp[j]+[tpi]
    fp[j] = fp[j]+[fpi]
    fn[j] = fn[j]+[fni]
    
    masks_pred_clean[j] = masks_pred
    # go over every image
    
    for k in range(nimg):
        # get the IoU matrix; axis 0 corresponds to GT, axis 1 to pred 

        regions = measure.regionprops(masks_pred[k])
        areas = np.array([reg.area for reg in regions])
        pred_areas[j] = pred_areas[j] + [areas]
        iou = metrics._intersection_over_union(masks_gt_clean[k], masks_pred[k])
        ovp = metrics._label_overlap(masks_gt_clean[k], masks_pred[k])[1:,1:] #throw out columns corresponding to zero  
#         tp = metrics._true_positive(iou, th)
        OvR[j] = OvR[j]+[ovp / areas[np.newaxis,:]] # Overlap Ratio           
        IoU[j] = IoU[j]+[iou]
#         seg_error_percent[j] = 
 

In [None]:
# Metric files go directly into the root output directory.
savedir = save0

### Save results

In [None]:
def _as_object_array(values):
    """Pack a (possibly ragged) list into a 1-D object array so np.savez can store it.

    np.savez converts its arguments with np.asanyarray; since NumPy 1.24 that raises
    ValueError for ragged nested sequences unless dtype=object is requested. Filling
    an empty object array element by element also stops NumPy from trying to infer a
    deeper shape. After np.load(..., allow_pickle=True) the result indexes like the
    original list (e.g. OvR[j][k], ap[j][0], cell_areas[k]).
    """
    packed = np.empty(len(values), dtype=object)
    for i, v in enumerate(values):
        packed[i] = v
    return packed

np.savez(savedir+'OvR',_as_object_array(OvR))
np.savez(savedir+'IoU',_as_object_array(IoU))
np.savez(savedir+'cell_areas',_as_object_array(cell_areas))
np.savez(savedir+'ap',_as_object_array(ap))
# np.savez(savedir+'remapping',remapping)

### Load results

In [None]:
# Reload the saved metrics. They are object arrays, hence allow_pickle=True; downstream
# code indexes them as OvR[j][k], IoU[j][k], ap[j][0] and cell_areas[k].
OvR = np.load(savedir+'OvR'+'.npz',allow_pickle=True)['arr_0']
IoU = np.load(savedir+'IoU'+'.npz',allow_pickle=True)['arr_0']
cell_areas = np.load(savedir+'cell_areas'+'.npz',allow_pickle=True)['arr_0']
# remapping = np.load(savedir+'remapping'+'.npz',allow_pickle=True)['arr_0']
ap = np.load(savedir+'ap'+'.npz',allow_pickle=True)['arr_0']

### Compute the number of cell errors for each model
We count 1 error for each cell without a matching predicted label, and 1 error for each additional predicted mask that is assigned to a cell.

In [None]:
# Count segmentation errors per GT cell for each condition, overall and split by
# cell-area percentile. Each predicted label is assigned to the GT cell(s) holding
# its largest overlap ratio, if that ratio exceeds 0.75 (ties count for every tied
# cell). A GT cell with 0 assigned labels is 1 error (missed); k > 1 labels count as
# k-1 errors (over-segmentation).
N = len(names)
nimg = len(cell_areas)
cell_errors = [[]]*N
total_errors = [0]*N
total_single_errors = [0]*N  # number of GT cells with at least one error
total_cells = len(np.concatenate(cell_areas))

# area_thresh = np.linspace(1,np.max(np.concatenate(cell_areas)),100)
# area thresholds: the 0th..99th percentiles of all GT cell areas
area_thresh = [np.percentile(np.concatenate(cell_areas),i) for i in range(100)]

M = len(area_thresh)
# ce_thresh = [[[]]*M]*N
# te_thresh = [[0]*M]*N
# tse_thresh = [[0]*M]*N
# tc_thresh = [0]*M

percent = [0]*N
# total = [0]*N

for j in range(N):
# for j in [-1,-2]:
    
    print(names[j])
    # NOTE: `[]*2` is just `[]`; ce_thresh is unused (its only use below is commented out).
    ce_thresh = [[]*2]*M
    # Row 0: cells with area >= threshold; row 1: cells with area < threshold.
    te_thresh = np.zeros((2,M))
    tse_thresh = np.zeros((2,M))
    tc_thresh = np.zeros((2,M))
    
    for k in range(nimg):
        r = OvR[j][k].copy() #overlap ratio (rows: GT cells, columns: predicted labels)
        r[r<=0.75] = 0 
        mx = np.max(r,axis=0) #for each predicted label (column), its largest overlap ratio with any GT cell
        mx[mx==0] = np.nan #exclude case where a spurious label has no overlaps
        hits = np.sum(r==mx,axis=1) # per GT cell (row): number of predicted labels assigned to it; 0 = missed cell
        cell_errors[j] = cell_errors[j] + [np.abs(hits-1)] #error if >1 or =0 
        total_errors[j] += np.sum(hits[hits>1]-1)+np.sum(hits==0) # -1 because a hit of 2 is 2 pred labels, 1 'extra' = 1 error
        total_single_errors[j] += np.sum(hits>1)+np.sum(hits==0)
        
        # hits is assumed aligned with cell_areas[k] (both ordered by GT label) — TODO confirm
        ca = cell_areas[k]
        for i,a in enumerate(area_thresh):
           
            for l in range(2):
                cell_filter = ca>=a if l==0 else ca<a
                cell_count = np.count_nonzero(cell_filter)
                hits_thresh = hits[cell_filter] # hits of the cells on side l of this area threshold
#                 ce_thresh[l,i] += [np.abs(hits_thresh-1)] #error if >1 or =0 
                te_thresh[l,i] += np.sum(hits_thresh[hits_thresh>1]-1)+np.sum(hits_thresh==0) # -1 because a hit of 2 is 2 pred labels, 1 'extra' or 1 error
                tse_thresh[l,i] += np.sum(hits_thresh>1)+np.sum(hits_thresh==0)
                tc_thresh[l,i] += cell_count

    
#     for i in range(M):
#         print('percentage of cells with at least one error above cutoff:', tse_thresh[i]/tc_thresh[i] *100)
    # percent[j][l][i]: % of cells on side l of area percentile i with >= 1 error;
    # eps avoids a zero denominator for empty groups
    eps = 1e-8
    percent[j] = [[utils.safe_divide(tse_thresh[l,i],tc_thresh[l,i]+eps)*100 for i in range(M)] for l in range(2)]
        
  
    print('overall percentage of cells with at least one error:',total_single_errors[j]/total_cells *100)
#     print('percentage of cells with at least one error above cutoff:', tse_thresh[j]/tc_thresh *100)
#     print('percentage of cells with at least one error above cutoff:', tse_thresh[j]/tc_thresh *100)

### Percent of cells with at least one error above the 75th percentile in area:

In [None]:
k = 75  # area-percentile index into percent[j][0] (cells at or above this percentile)
for label, pct in zip(names, percent):
    print(label)
    print(pct[0][k])
# percent[0][0][k],percent[1][0][k]

In [None]:
# Restore matplotlib defaults (undoes the dark_background style set at the top);
# axcol is the axis/tick color used by the plots below.
mpl.rcParams.update(mpl.rcParamsDefault)
axcol = 'k'
def set_size(w, h, ax=None):
    """Resize the figure so that the axes area (not the whole figure) is w x h inches.

    The figure size is the target size divided by the fraction of the figure the
    subplot region occupies (from the figure's subplot parameters). Falls back to the
    current axes when `ax` is falsy.
    """
    ax = ax or plt.gca()
    pars = ax.figure.subplotpars
    ax.figure.set_size_inches(float(w) / (pars.right - pars.left),
                              float(h) / (pars.top - pars.bottom))

from omnipose.utils import sinebow
n = len(names)
z = 1
master_color_scheme = [[i,0,0] for i in np.linspace(1,.5,z)]+[[i,i,i] for i in np.linspace(.75,.25,n-z)]

golden = (1 + 5 ** 0.5) / 2
sz = 1.5
labelsize = 7

%matplotlib inline
darkmode = 0
if darkmode:
    plt.style.use('dark_background')
    axcol = 'w'
    colors = sinebow(n+1)
    colors = [colors[j+1] for j in range(n)]
    background_color = 'k'
    suffix = '_dark'
else:
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    cmap = mpl.cm.get_cmap('viridis')
#     colors = cmap(np.linspace(0,.9,len(names)))
    colors = master_color_scheme
    background_color = np.array([1,1,1,1])
    suffix = ''
    
mpl.rcParams['figure.dpi'] = 300

# could turn these results into a plot by repeating at different cutoffs 
fig = plt.figure(figsize=(sz, sz)) 
ax = plt.axes()
for j in range(n):
# for j in [-1,-2]:
# for j in [0,2]:
#     ax.plot(area_thresh,np.divide(tse_thresh[j],tc_thresh),label=names[j],color=colors[j])
#     ax.plot(range(100),percent[j],label=names[j],color=colors[j])
    ax.plot(range(100),percent[j][0],label=names[j],color=colors[j])
#     ax.plot(range(100),percent[j][1],'--',label='above',color=colors[j])
#     print('mean percent error is',np.mean(percent[j]))
    # ax.set_facecolor('w')
# plt.xscale('log')
# ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False,bbox_to_anchor=(1.05, 1))

ax.legend(prop={'size': labelsize/1.5}, loc='best', frameon=False)

ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.set_ylabel('Percent of cells \nwith >=1 error', fontsize = labelsize)
ax.set_xlabel('Area percentile cutoff', fontsize = labelsize)
plt.ylim([0,100])
plt.xlim([0,100])

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)

fig.patch.set_facecolor(background_color)
plt.show()

a = 35
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
tight_bbox_raw._points+=[[-a,-a],[a,a]]
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))


# name = '/home/kcutler/DataDrive/omnipose_paper/Figure 5/error_vs_area_percentiles'+suffix
# fig.savefig(name+'.eps',bbox_inches=tight_bbox)
# fig.savefig(name+'.pdf',bbox_inches=tight_bbox)
# fig.savefig(name+'.eps',pad_inches = 0.5)
# fig.savefig(name+'.png',pad_inches = 0.5)

### Print some more statistics used in the paper

In [None]:
# All GT cell areas pooled over images, and condition 0's (Cellpose) per-cell error
# counts, built image by image in the same order.
a = np.concatenate(cell_areas)
e = np.concatenate(cell_errors[0]) # specifically cellpose

In [None]:
# Large cells: GT cells at or above the 75th area percentile, and the share of
# Cellpose errors (from `e`) they account for.
area_cutoff = np.percentile(a,75)
cell_filter = a>=area_cutoff
n_large = np.count_nonzero(cell_filter)
print('75th area percentile:',area_cutoff)
print('Percentage of errors attributed to this group:',np.sum(e[cell_filter])/np.sum(e)*100)
print('Percent of all cells contributing to this group (confirm to be 25):',n_large/len(a)*100)
# Report the size of the group itself, not the total number of cells len(a).
print('Number of cells in this group:',n_large)

In [None]:
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300
# Jaccard index (cellpose's "ap") vs IoU matching threshold, averaged over images,
# for conditions 0, 2 and 3 (Cellpose, Omnipose, Omnipose worm+bact).
x = threshold
sz = 1.5
fig = plt.figure(figsize=(sz, sz)) 
ax = plt.axes()
# plt.tight_layout()
# plt.minorticks_on()
plt.xticks(np.arange(min(x), max(x)+1, .25))
plt.xlim([0.5,1])
plt.ylim([0,1])
# plt.yticks(np.arange(0, 1.1, .25))

# colors = ['g','r','b','y','c','m']
# NOTE(review): this counts distinct area values per image, not cells, and is not used below.
cell_count = np.array([len(np.unique(ca)) for ca in cell_areas])
# for j in range(n):
# for j in [0,1,-1]:
for j in [0,2,3]:
    colors[j]  # no-op expression (leftover)
    ax.plot(x,np.mean(ap[j][0],axis=0).T,label=names[j],color=colors[j],linewidth=.5)
#     ax.plot(x,np.mean((fp[j][0]+fn[j][0]),axis=0),label=pretty_names[j],color=colors[j])


# ax.set_facecolor('w')
# ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False,bbox_to_anchor=(1.05, 1))
# ax.legend(prop={'size': labelsize}, loc='best', frameon=False)
ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
# plt.setp(ax.xaxis.get_label(), visible=True, text='IoU')
# plt.setp(ax.get_xticklabels(), visible=True, ha='right')

# ax.grid(b=True, which='major', color='b', linestyle='-')
ax.set_ylabel('Jaccard Index', fontsize = labelsize)
ax.set_xlabel('IoU matching threshold', fontsize = labelsize)
# plt.set(xlabel='IoU threshold', ylabel='Average Precision',fontsize=labelsize)
# plt.tight_layout()
# plt.yscale('log')


ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)

fig.patch.set_facecolor(background_color)
plt.show()

path = '/home/kcutler/DataDrive/omnipose_paper/Figure 5/c_elegans_AP_vs_IoU'
# Tight bounding box grown by `a` display pixels per side, converted to inches for savefig.
a = 35
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
tight_bbox_raw._points+=[[-a,-a],[a,a]]
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))

# fig.savefig(path+suffix+'.eps',bbox_inches=tight_bbox)
# fig.savefig(path+suffix+'.pdf',bbox_inches=tight_bbox)
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/AP_vs_IoU_all.eps',bbox_inches="tight",pad_inches = 0.05)
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/AP_vs_IoU_all.png',bbox_inches="tight",pad_inches = 0.05)

In [None]:
# # make outline views with yellow
# from cellpose import plot, io,omnipose
# cmap = mpl.cm.get_cmap('plasma')
# outline_col = cmap(0.85)[:3]
# def getname(path,suffix='_cp_masks'):
#     return os.path.splitext(Path(path).name)[0].replace(suffix,'')
# for j in range(n):
#     savedir = save0+names[j]
#     mask_names = [os.path.join(savedir+'/masks',os.path.splitext(os.path.basename(name))[0]+'_cp_masks.png') for name in img_names]
#     flow_names = [os.path.join(savedir+'/flows',os.path.splitext(os.path.basename(name))[0]+'_flows.tif') for name in img_names]
#     masks[j] = [omnipose.utils.format_labels(skimage.io.imread(f)) for f in mask_names]
#     flows[j] = [skimage.io.imread(f) for f in flow_names]
    
#     savedir = savedir+'_yellow_outlines_thin'
#     io.check_dir(savedir)
#     for k in range(len(mask_names)):
#         outli = plot.outline_view(omnipose.utils.normalize99(imgs[k]),masks[j][k],color=outline_col)
#         skimage.io.imsave(os.path.join(savedir,'outline_'+getname(mask_names[k])+'.png'),((outli)*255).astype(np.uint8))
        

In [None]:
cats = ['worms']
subsets= [['elegans']]
exclude = 'fggdffdgdf'#'caulo_14'  (placeholder that matches no file: exclusion disabled)
K = len(cats)
# Per category: indices of files whose path contains one of its subset tags and not `exclude`.
indices = [np.array([i for i,s in enumerate(mask_names) if any((name in s) and (exclude not in s) for name in subsets[k])]).astype(int) for k in range(K)]

# Number of cleaned GT cells per category. Sized by K (one entry per category) and
# counting nonzero unique labels, so an image without background is not undercounted.
counts = np.zeros(K)
for j,inds in enumerate(indices):
    for i in inds:
        counts[j] += np.count_nonzero(np.unique(masks_gt_clean[i]))
counts

In [None]:
from scipy.optimize import linear_sum_assignment

# Matched IoU: one-to-one assignment between GT cells and predicted labels per image.
per_cell = 1  # 1: pool the individual matched IoUs; 0: one mean matched IoU per image
J = len(IoU)
K = len(cats)
y = [ [ [] for i in range(K) ] for i in range(J) ]
for j in range(J):
    for k in range(K):
        matched_iou = []
        mean_matched_iou = []
        for ind in indices[k]:
            iou = IoU[j][ind][1:,1:]  # drop background row/column
            th = 0
            n_min = min(iou.shape[0], iou.shape[1])
            # With th = 0 the first term is -1 for every pair, so the assignment
            # effectively maximizes total IoU over min(#GT, #pred) matched pairs.
            costs = -(iou >= th).astype(float) - iou / (2*n_min)
            true_ind, pred_ind = linear_sum_assignment(costs)
            miou = iou[true_ind, pred_ind]
            matched_iou.append(miou)
            # dividing by the larger label count scores unmatched labels as 0
            mean_matched_iou.append(np.sum(miou)/max(iou.shape[0], iou.shape[1]))
            
        y[j][k] = [m for sublist in matched_iou for m in sublist] if per_cell else mean_matched_iou
        

In [None]:
# one cat messes things up, must flatten


In [None]:
sz = 2
lw = 0.8
fig, axs = plt.subplots(figsize=(sz,sz))
plt.subplots_adjust(wspace=.4, hspace=.4)
np.random.seed(123)
w0 = 0.7
split=1
box = False
scatter = 0
violin = 1
ax = axs

# One violin of per-cell matched IoU per condition, spread around x = j.
w = w0 / 3
j = 0
x = np.linspace(j-w,j+w,len(y))
height = [np.mean(yi) for yi in y[j]]
# Share of GT cells (condition j) whose matched IoU exceeds 0.8, per category and
# overall; both values are expressed in percent.
percents = [np.sum(np.array(yi)>0.8)/counts[k]*100 for k,yi in enumerate(y[j])]
overall_percent = np.sum([np.sum(np.array(yi)>0.8) for yi in y[j]]) / np.sum(counts) * 100
print(names[j], 'Percent above 0.8',percents,overall_percent)

# With a single category each y[j] is [[...]]; flatten to one 1-D sample per condition.
yp = [np.array(yi).flatten() for yi in y]

parts = ax.violinplot(yp,positions=x,showextrema=0,widths=w0/(len(y)+1))
boxes = ax.boxplot(yp,positions=x,showfliers=False,widths=w0/35,patch_artist=True,
                   boxprops=dict(facecolor='w', color='k',linewidth=lw),whiskerprops=dict(color='k',linewidth=lw),
                   medianprops=dict(color='k',linewidth=lw))

for i,pc in enumerate(parts['bodies']):
    pc.set_facecolor(colors[i])
    pc.set_alpha(1)
    pc.set_edgecolor(colors[i])
    pc.set_linewidth(lw)

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)

ax.set_ylim(0,1)
ax.set_xlim((j-2*w0/4,j+2*w0/4))
ax.set_xticks([])
plt.show()
# Tight bounding box grown by `a` display pixels per side, converted to inches for savefig.
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
a = 50
tight_bbox_raw._points+=[[-a,-a],[a,a]]
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))

# NOTE(review): the file name uses names[j] (j = 0, 'Cellpose') although all conditions are plotted.
fig.savefig('/home/kcutler/DataDrive/omnipose_paper/Figure 5/Worm_iou_'+names[j]+'.eps',bbox_inches=tight_bbox)

In [None]:
# Example images for the crop figures: index of the first file whose path contains each tag.
# NOTE(review): `mask_names` is whatever was assigned last (the results-reload loop leaves
# the last condition's predicted-mask paths), and i[0] raises IndexError if a tag matches
# no file.
subsets = ['wormpose_39','bbc010_44']
indices = [np.array([i for i,s in enumerate(mask_names) if (name in s)]).astype(int) for name in subsets]
indices = [i[0] for i in indices]


In [None]:
from scipy.ndimage import binary_dilation
def get_bbxs(masks,seeds,pad=10):
    '''
    Get bounding boxes that cover the seed masks.

    For each label image masks[k], the box covers the pixels of the seed label(s)
    seeds[k] plus every 8-connected foreground component that touches the seed after
    a one-pixel dilation. It is padded by `pad` pixels and clipped to the image.

    Parameters
    ----------
    masks : list of 2D integer label arrays (0 = background)
    seeds : list with one entry per mask: a label id, a list/tuple/array of label
        ids, or None (None -> box covering the whole image)
    pad : int
        Margin in pixels added on every side of the box.

    Returns
    -------
    list of [y0, y1, x0, x1] with exclusive upper bounds, usable directly as
    mask[y0:y1, x0:x1] (the same convention as the whole-image box for None).

    Raises
    ------
    ValueError
        If none of the seed labels occurs in the corresponding mask.
    '''
    from scipy.ndimage import binary_dilation, label as ndi_label

    xys = []
    for k, mgt in enumerate(masks):
        l = seeds[k]
        if l is None:
            ly, lx = mgt.shape
            xys.append([0, ly, 0, lx])
            continue

        seed_ids = list(l) if isinstance(l, (list, tuple, np.ndarray)) else [l]
        hits = np.isin(mgt, seed_ids)
        if not hits.any():
            raise ValueError(f'seed label(s) {l} not found in masks[{k}]')

        # Connected foreground regions with full (8-)connectivity, the same default
        # skimage.measure.label uses for 2D images.
        labels, _ = ndi_label(mgt > 0, structure=np.ones((3,) * mgt.ndim, dtype=bool))
        # Every region overlapping the dilated seed joins the box. The dilation does not
        # depend on the region, so it is computed once.
        grown = binary_dilation(hits, iterations=1)
        touching = np.unique(labels[grown & (labels > 0)])
        binmask = hits | np.isin(labels, touching)

        y, x = np.nonzero(binmask)
        ny, nx = mgt.shape
        # Upper bounds are exclusive (+1) so slicing keeps the last covered row/column,
        # including the image's last row/column when the box is clipped.
        y0 = max(0, y.min() - pad)
        y1 = min(ny, y.max() + pad + 1)
        x0 = max(0, x.min() - pad)
        x1 = min(nx, x.max() + pad + 1)
        xys.append([y0, y1, x0, x1])
    return xys

# Seed label ids in the worm+bacteria Omnipose predictions (masks[-1]) around which each
# example image is cropped; xys holds the resulting [y0, y1, x0, x1] boxes.
cell_seeds = [24,11]
masks_s = [masks[-1][i] for i in indices]
xys = get_bbxs(masks_s,cell_seeds)
xys

In [None]:
# Output file-name suffix for each condition, in the same order as `names`.
suffixes = ['CP','MX','OP','OP_Bact_worm']
# names

In [None]:
import edt
from matplotlib import rc
from scipy.ndimage.morphology import binary_erosion, binary_dilation
from cellpose import omnipose
from scipy.optimize import linear_sum_assignment

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
import matplotlib_inline
# matplotlib_inline.backend_inline.set_matplotlib_formats('retina', 'png')
mpl.rcParams['figure.dpi'] = 72
%matplotlib inline
A = 1
px = A/plt.rcParams['figure.dpi']  # pixel in inches
cmap = mpl.cm.get_cmap('plasma')
outline_col = cmap(0.85)[:3]
ext = '.png'
bg = 0.5
pad = 10

def getname(path,suffix=''):
    """Return the base name of `path` without its final extension, with every
    occurrence of `suffix` removed (a no-op for the default empty suffix)."""
    stem, _ext = os.path.splitext(Path(path).name)
    return stem.replace(suffix, '')
field_names = [getname(i) for i in img_names]

savedir = save0
io.check_dir(savedir)
J = len(suffixes)
for j in range(J): # select from models
    suffix = suffixes[j]
    for i,k in enumerate(indices):
        xy = xys[i]

        lx = xy[3]-xy[2]
        # outline thickness: thin 'inner' outlines when the crop is under 3 inches wide
        val = px*lx
        print(val,img_names[k])
        mode = 'inner' if val<3 else 'thick'
        mgt = masks_gt[k] # raw ground truth (not boundary-cleaned)
        p = omnipose.utils.rescale(imglist[k])
        # img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))
        img0 = p[xy[0]:xy[1],xy[2]:xy[3]]
        mpred = masks[j][k].copy()[xy[0]:xy[1],xy[2]:xy[3]]
        # gold and red outline renderings of the prediction over the cropped image
        for c,cn in zip([outline_col,[1,0,0]],['gold','red']):
            pic = plot.outline_view(img0,mpred,color=c,mode=mode)
            skimage.io.imsave(os.path.join(savedir,field_names[k]+'_outline_'+cn+suffix+ext),((pic)*255).astype(np.uint8))

        # Mean matched IoU between the cropped GT and prediction.
        mgt, remap = fastremap.renumber(mgt[xy[0]:xy[1],xy[2]:xy[3]])
        mpred, remap = fastremap.renumber(mpred)
        # Drop row/column 0 (background) so the background-vs-background overlap does not
        # enter the mean, consistent with the matched-IoU analysis earlier in the notebook.
        iou = metrics._intersection_over_union(mgt, mpred)[1:,1:]
        th = 0
        n_min = min(iou.shape[0], iou.shape[1])
        costs = -(iou >= th).astype(float) - iou / (2*n_min)
        true_ind, pred_ind = linear_sum_assignment(costs)
        miou = iou[true_ind, pred_ind]
        with open(os.path.join(savedir,'MMiou'+field_names[k]+suffix+'.txt'), "w") as text_file:
            print(f"Mean Matched IoU: {np.mean(miou)}", file=text_file)
        

In [None]:
%matplotlib widget
mgt, remap = fastremap.renumber(mgt[xy[0]:xy[1],xy[2]:xy[3]])
mpred, remap = fastremap.renumber(mpred)
plt.imshow(mpred)

In [None]:
# Full IoU matrix for the inspected crop; row/column 0 correspond to background label 0.
iou = metrics._intersection_over_union(mgt, mpred)

In [None]:
def _rel_change(new, ref):
    """Mean over thresholds of (new - ref) / ref; entries where ref is 0 or NaN count as 0."""
    return np.mean(np.divide(new - ref, ref, out=np.zeros_like(ref),
                             where=np.logical_and(ref != 0, ~np.isnan(ref))))

# Mean Jaccard index per IoU threshold for Cellpose and Omnipose. Omnipose is the
# 'Omnipose' condition in `names`; index 1 is 'Mixed' (Cellpose model + omni=True).
cp = np.mean(ap[names.index('Cellpose')][0],axis=0).T
om = np.mean(ap[names.index('Omnipose')][0],axis=0).T

# Relative change (in %) overall, above and below threshold index `cutoff`.
av_all = _rel_change(om, cp)
cutoff = 90  # index into `threshold` (IoU ~ 0.95)
av_up = _rel_change(om[cutoff:], cp[cutoff:])
av_dwn = _rel_change(om[:cutoff], cp[:cutoff])

print(av_all*100,av_up*100,av_dwn*100,(np.mean(om)-np.mean(cp))/np.mean(cp)*100)

In [None]:
from cellpose.omnipose.utils import ncolorlabel
from cellpose import transforms, omnipose
import cellpose

In [None]:
#code developing the error ideas 
ind = range(4)

from matplotlib.colors import ListedColormap
plt.style.use('dark_background')
# now find and crop specific cells for comparison 
# Over-segmented cells have several low-IoU predicted labels. A small predicted label has a small IoU with a large true cell however well it fits,
# presumably why the overlap ratio OvR (normalized by predicted-label area) is used below instead of IoU.
cm2 = ListedColormap([color for color in sinebow(5).values()])

k = ind[3] #specific small example as a test (overridden just below)
# k = ind[1]
k = 3
j = 2 # condition 2 = 'Omnipose' in names (index 0 is Cellpose)

# r = OvR[j][k]
# mx = np.max(r,axis=0)
# hits = np.sum(r==mx,axis=1)
# # true,pred = np.nonzero(r>mx/2)
# (unique,counts) = np.unique(true,return_counts=True)
# errors = counts-1

#defer 'belonging' to maximum; only issue is if there is a spurious label with no overlap at all (e.g. dead cell)
# opposite case (e.g. edge cell): maybe should filter by area 
# mx[mx==0] = np.nan
# hits_pred = r==mx,axis=1) # sum will be zero if a cell label has zero hits, not sure which ones these are
# err = np.abs(hits-1)
# cell_errors[j][k] = err  #error if >1 or =0 
# terr = np.sum(hits[hits>1]-1)+np.sum(hits==0)
# total_errors[j] += terr



m = 0
# Assign each predicted label to its best-overlapping GT cell. Unlike the error-counting
# cell, no 0.75 overlap-ratio cutoff is applied here.
r_pred = OvR[j][k].copy()
mx = np.max(r_pred,axis=0) 
# std = np.std(r_pred,axis=0)
mx[mx==0] = np.nan
hits_pred = r_pred==mx
err = np.abs(np.sum(hits_pred,axis=1)-1)



# err = cell_errors[j][k]# assuming proper precomputing 
# GT cells (row indices) with at least `error_cutoff` errors and area above `area_cutoff`
areas = cell_areas[k]
error_cutoff = 1
area_cutoff = 10
err_indexes = np.nonzero(np.logical_and(err>=error_cutoff,areas>area_cutoff))[0]

# print(true)
# print(pred)
# print(counts)
# print(hits,err,terr)

# fig, axes = plt.subplots(nrows=1, ncols=5)
# fig.set_size_inches(40,40)

# axes[0].imshow(imgs[k],cmap = 'gray')
# axes[0].axis('off')
# #     plt.show()
# axes[1].imshow(ncolorlabel(masks[j][k],n=5),cmap = cm2,interpolation='nearest')
# axes[1].axis('off')

# axes[2].imshow(ncolorlabel(masks_gt[k],n=5),cmap = cm2,interpolation='nearest')
# axes[2].axis('off')

# axes[3].imshow(r,cmap = 'viridis',interpolation='nearest')
# axes[3].axis('off')

# axes[4].imshow(r==mx,cmap = 'viridis',interpolation='nearest')
# axes[4].axis('off')
# plt.show()


# err_indexes = np.nonzero(err)[0]

pad = 5

# Show up to 10 erroneous GT cells. For each, crop around the cell and display
# [image | GT cell | n-colored predictions | outlines of the predicted labels assigned
# to this cell | predicted flows].
for l in err_indexes[0:10]:
    print(err[l])
    mgt = masks_gt_clean[k]
    # row l of OvR corresponds to GT label l+1 (background row 0 was dropped)
    y,x = np.nonzero(mgt==l+1)
    max_y,max_x = np.array(mgt.shape)-1
    
    y0 = max(0,min(y)-pad)
    y1 = min(max_y,max(y)+pad)
    x0 = max(0,min(x)-pad)
    x1 = min(max_x,max(x)+pad)
    
    p = transforms.normalize99(imgs[k][y0:y1,x0:x1],omni=True)

    mask_gt = mgt[y0:y1,x0:x1]
    bini = mask_gt==l+1
    img0 = np.stack((p,p,p),axis=2)
    outli = transforms.normalize99(plot.outline_view(img0*255,bini),omni=True)
    bin0 = np.stack((bini,bini,bini),axis=2)
    gt_pic = np.hstack((img0,outli,bin0))  # NOTE(review): built but not displayed
    mask_pred = masks[j][k][y0:y1,x0:x1]
    flow_pred = transforms.normalize99(flows[j][k][y0:y1,x0:x1],omni=True)
#                     print(bini.shape,flow_OG.shape)
#                     flow_OG = np.hstack([flow_OG[:,:,i]*bini for i in range(flow_OG.shape[-1])])
# #                     flow_OG = np.multiply(flow_OG,bini)
#                     print(flow_OG.shape)

    
    
    # predicted labels (column index + 1) assigned to this GT cell, relabeled 1..n for display
    inds = np.where(hits_pred[l,:])
#                     print('OG inds',inds,np.unique(mask_OG))
    rmask = np.zeros_like(mask_gt)
    for i,label in enumerate(inds[0]):
        rmask[mask_pred==label+1] = i+1

    outl_pred = transforms.normalize99(plot.outline_view(img0,rmask),omni=True)
    res_pred = np.hstack((img0,bin0,cm2(ncolorlabel(mask_pred))[:,:,:3],outl_pred,flow_pred))
    plt.figure(figsize=[2,2])
    plt.imshow(res_pred)
    plt.axis('off')
    plt.show()

# plt.imshow(r,cmap='viridis')
# print(np.sum(r,axis=1))
# plt.hist(r)
# p = plt.plot(threshold,IoU_hits[0][1],'w',threshold,IoU_hits[1][1],'g')

For this plot, I should put both the training and test errors. The training dataset has a lot more long cells in it. In fact, I should do the analysis with the train, test, and combined to show correlation and lack of overfitting. 

In [None]:
## import matplotlib
# Per-cell segmentation error count vs GT cell area (log x) for condition 0 (Cellpose);
# the dashed line marks the 75th area percentile.
labelsize = 7
mpl.rc('xtick', labelsize=labelsize) 
mpl.rc('ytick', labelsize=labelsize) 
# matplotlib.rcParams['text.usetex'] = True
mpl.rcParams['figure.dpi'] = 300 # resolution for pngs 
linewidth = 0.75
ymax = 8


fig = plt.figure(figsize=(sz, sz/golden)) 
ax = plt.axes()


# plt.minorticks_on()
# plt.xticks(np.arange(min(x), max(x)+1, .25))
# plt.xlim([0.4,1.1])
# plt.ylim([0,1])
# plt.yticks(np.arange(0, 1.1, .25))
# colors = ['r','g','b','k']
for j in [0]:
# for j in range(n):
# for j in [0,3,-1]:
    ax.scatter(np.concatenate(cell_areas),np.concatenate(cell_errors[j]),s=3,color=colors[j],label=names[j],alpha=1,edgecolors='none')
    ax.vlines(np.percentile(np.concatenate(cell_areas),75),0,ymax,colors='gray',linestyles='dashed',linewidth=.5)
# ax.scatter(np.concatenate(cell_areas),np.concatenate(cell_errors),s=1)
# plt.xlim([0.05,1e5])
plt.xscale('log')
plt.xticks()
# plt.yscale('log')

plt.ylim([0,ymax])
# labelsize = 11
# ax.set_facecolor('w')

# ax.legend(prop={'size': labelsize}, loc='upper left',frameon=False,markerscale=3)
ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
# ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="in",colors=axcol,bottom=True,left=True)
# plt.setp(ax.xaxis.get_label(), visible=True, text='IoU')
# plt.setp(ax.get_xticklabels(), visible=True, ha='right')

# ax.grid(b=True, which='major', color='b', linestyle='-')
ax.set_ylabel('Segmentation errors', fontsize = labelsize)
ax.set_xlabel('Cell area (px$^2$)', fontsize = labelsize)
# plt.set(xlabel='IoU threshold', ylabel='Average Precision',fontsize=labelsize)
# plt.gcf().subplots_adjust(bottom=0.15)
# plt.tight_layout()

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)

fig.patch.set_facecolor(background_color)
plt.show()

# Tight bounding box grown by `a` display pixels per side, converted to inches for savefig.
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
a = 50
tight_bbox_raw._points+=[[-a,-a],[a,a]]
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))

# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area.eps',bbox_inches=tight_bbox)
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area.png',bbox_inches=tight_bbox)

# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_loglog.eps',bbox_inches="tight")
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_linear.eps',bbox_inches="tight")

# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area.svg')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from matplotlib.ticker import FormatStrFormatter
from scipy.stats import gaussian_kde

import matplotlib as mpl

color = [1,0,0,.5]
axcol = 'k'

# Inputs for the marginal-histogram scatter: per-cell GT area (x) and Cellpose error count (y).
# x = np.random.randn(1000)
# y = np.random.randn(1000)
x = np.concatenate(cell_areas)
y = np.concatenate(cell_errors[0])
xmin = 10
xmax = 10**4
# plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib 3.9; use the registry (>= 3.5).
cm = mpl.colormaps['plasma']

def scatter_hist(x, y, ax, ax_histx, ax_histy):
    # no labels
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    # the scatter plot:
#     ax.scatter(x, y, s=3,edgecolors='none',c=color)
    
#         y = np.array(ap[j]).T.flatten()
#         X = np.repeat(x,len(ap[j][0]))
    X = x.copy()
    xy = np.vstack([X,y])
    z = gaussian_kde(xy)(xy)
    idx = z.argsort()
    X, y, z = X[idx], y[idx], z[idx]
    ax.scatter(X, y, c=np.log(z), s=1,cmap=cm)

    # now determine limits by hand:
    binwidth1 = 100
    binwidth2 = 1
    eps = 1e-20
    xymax = max(np.max(np.abs(x)), np.max(np.abs(y)))
    lim = (int(xymax/binwidth1) + 1) * binwidth1
    pp = 1
    bins1 = np.arange(xmin, lim + binwidth1, binwidth1)
    bins2 = np.arange(0, 55 + binwidth2, binwidth2)
    logbins1 = np.logspace(np.log10(bins1[0]),np.log10(bins1[-1]),len(bins1))
    
    # plot histogram
    n,bins,patches = ax_histx.hist(x, bins=logbins1,color=color)
    col = (n-n.min())/(n.max()-n.min())+eps
    for c, p in zip(col, patches):
        plt.setp(p, 'facecolor', cm(c**pp))
    
    n,bins,patches = ax_histy.hist(y, bins=bins2, orientation='horizontal',color=color,align='left')
    col = (n-n.min())/(n.max()-n.min())+eps
    for c, p in zip(col, patches):
        plt.setp(p, 'facecolor', cm(c**.25))
    
    
#     ax_histx.tick_params(axis='x',which='both',bottom=False,top=False)
#     ax_histy.tick_params(axis='y',which='both',left=False,right=False)
#     ax_histy.get_xaxis().set_ticks([])
#     ax_histy.get_yaxis().set_ticks([])
    

#     plt.xlim([0,10**3])
# start with a square Figure
# a = 1+golden
a = 5        # relative size of the main scatter axes in the gridspec
b = 1        # relative size of the marginal histogram axes
c = 2.5      # figure width and height (inches)
labelsize = 9
fig = plt.figure(figsize=(c, c))


# Add a gridspec with two rows and two columns and a ratio of 1 to 5 (b:a) between
# the size of the marginal axes and the main axes in both directions.
# Also adjust the subplot parameters for a square plot.

gs = fig.add_gridspec(2, 2,  width_ratios=(a, b), height_ratios=(b, a),
                      left=0.1, right=0.9, bottom=0.1, top=0.9,
                      wspace=0.075, hspace=0.075)

ax = fig.add_subplot(gs[1, 0])
ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
# pyplot acts on the current axes (ax_histx here); with sharex=ax this presumably
# log-scales the shared cell-area axis.
plt.xscale('log')
# plt.yscale('log')
ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)

# plt.xscale('log')
# use the previously defined function
scatter_hist(x, y, ax, ax_histx, ax_histy)
# dashed reference line at the 75th percentile of cell area
ax.vlines(np.percentile(np.concatenate(cell_areas),75),0,np.max(y),colors=[.5]*4,linestyles='dashed',linewidth=1)
# NOTE(review): the current axes is now ax_histy (the last one created), so this call
# log-scales ax_histy's count axis rather than ax's x axis — confirm that is intended.
plt.xscale('log')
ax.set_ylabel('Segmentation errors', fontsize = labelsize)
ax.set_xlabel('Cell area (px$^2$)', fontsize = labelsize)

ax.patch.set_alpha(0.0)

# NOTE(review): background_color is not assigned in this part of the notebook.
fig.patch.set_facecolor(background_color)

# base-10 log major ticks and unlabeled minor ticks within each decade on the area axis
x_major = mpl.ticker.LogLocator(base = 10.0, numticks = 10)
ax.xaxis.set_major_locator(x_major)
x_minor = mpl.ticker.LogLocator(base = 10.0, subs = np.arange(1.0, 10.0) * 0.1, numticks = 10)
ax.xaxis.set_minor_locator(x_minor)
ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())
# plt.yticks also targets the current axes (ax_histy), whose y axis is shared with ax
plt.yticks(range(0,55,10))
ax.tick_params(axis='both', which='major', labelsize=labelsize)

# x_major = matplotlib.ticker.LogLocator(base = 10.0, numticks = 10)
# ax_histy.yaxis.set_major_locator(x_major)
# x_minor = matplotlib.ticker.LogLocator(base = 10.0, subs = np.arange(1.0, 10.0) * 0.1, numticks = 10)
# ax_histy.yaxis.set_minor_locator(x_minor)
# ax_histy.yaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())

# open-frame main axes; marginal histograms show bars only
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax_histy.axis('off')
ax_histx.axis('off')

# ax_histy.spines['right'].set_visible(False)
# ax_histy.spines['top'].set_visible(False)
# ax_histy.spines['bottom'].set_visible(False)
# ax_histy.spines['left'].set_visible(False)
# ax_histy.get_xaxis().set_ticks([])
# ax_histy.get_yaxis().set_ticks([])
# ax_histy.get_yaxis().set_visible(False)
# plt.xlim([xmin,xmax])
plt.show()

# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_hist'+suffix+'.eps',bbox_inches='tight',transparent=True,pad_inches=0)
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_hist'+suffix+'.pdf',bbox_inches='tight')
# fig.savefig('/home/kcutler/DataDrive/omnipose_paper/errors_vs_area_hist.png',bbox_inches='tight')

In [None]:
# 75th percentile of per-cell area (the dashed reference line of the joint plot),
# assuming x still holds the concatenated cell areas from the joint-plot cell.
np.quantile(x, 0.75)

In [None]:
# Compare how often large vs. typical cells carry at least one segmentation error.
# Assumes x (per-cell areas) and y (per-cell error counts) still hold the arrays from
# the joint-plot cell; the example-selection cell further down rebinds both names.
r = x*y
p = y>0  # cells with at least one error

# p = x>=3000
# print(np.mean(r[p])/np.mean(y[p]),np.mean(x[p])/np.median(x),np.mean(x[p]),np.mean(y[p]),np.count_nonzero(p)/len(p)*100,np.median(x),np.mean(x))
# len(p)
# a = 1000
# a,np.mean(y[x>=a]),np.mean(y[x<a])
m = 5*np.median(x)  # "large cell" cutoff: five times the median area
# m = 1000
print(m)
q = np.logical_and(x>=m,p)
s = np.logical_and(x<m,p)
n_above = np.count_nonzero(x>=m)
n_below = np.count_nonzero(x<m)
# Either group can be empty (e.g. no cell reaches the cutoff); report NaN for it
# instead of raising ZeroDivisionError.
print('probability of having one error after cutoff',
      np.count_nonzero(q)/n_above*100 if n_above else float('nan'))
print('probability of having one error under cutoff',
      np.count_nonzero(s)/n_below*100 if n_below else float('nan'))

In [None]:
import os
from scipy.ndimage.morphology import binary_erosion, binary_dilation
from cellpose import utils, transforms
from cellpose.omnipose.utils import ncolorlabel, sinebow
import cv2
from cellpose.io import imsave
from PIL import Image, ImageFont, ImageDraw, ImageOps
from matplotlib.colors import ListedColormap
import os, datetime, gc, warnings, glob


def rot_flow(f):
    """Return the intensity-inverted image ``1 - f``.

    Used to write the ``*_inv`` variants of flow/mask images. Despite its name, no
    rotation or channel mixing is applied; it works element-wise on any array with
    values in [0, 1].
    """
    return 1 - f

plt.style.use('dark_background')
# plt.get_cmap exists on every matplotlib release; mpl.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap('viridis')
cm2 = ListedColormap(list(sinebow(5).values()))
cm3 = ListedColormap(list(sinebow(5,bg_color=[1,1,1,1]).values()))

basedir = '/home/kcutler/DataDrive/omnipose_paper/Comparison Examples/newcompare3/'
# makedirs also creates missing parent folders and tolerates an existing directory
os.makedirs(basedir, exist_ok=True)

area_cutoff = 20
error_cutoff=2
pad = 10       # crop padding (px) around the selected cells
N = len(masks)
nimg = len(masks_gt)
bkct = 500
bg = 0.5       # value the mean background intensity is mapped to by the gamma adjustment
ext = '.png'

cnt = 0
txtpad = 10
yoffset = [0,0]

# Example images (by file stem) and, for each, the 0-based GT cell index to crop around
# (GT label = index + 1). An entry may also be a list of indexes, or None for no crop.
img_list = ['bbc010','wormpose']
cell_list = [1,0]
namelist = [os.path.splitext(os.path.split(file)[-1])[0] for file in img_names]
img_index = [namelist.index(im) for im in img_list]
for cell_id, image_idx in zip(cell_list, img_index):
    print(cell_id, image_idx)

n_exmpl = len(img_list)
label_list = []
coords = []   # one end-exclusive crop box [y0, y1, x0, x1] per example, in img_list order

for j in [0]: #base selection on cellpose examples
    for l,k in zip(cell_list,img_index):
        file = img_names[k]
        basename = os.path.splitext(os.path.split(file)[-1])[0]

        mgt = masks_gt_clean[k]
        p = transforms.normalize99(imgs[k],omni=True)
        # gamma chosen so that the mean (eroded) background intensity maps to `bg`
        img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))

        if l is not None:
            if isinstance(l, list):
                hits = np.any(np.stack([mgt==li+1 for li in l]),axis=0)
            else:
                hits = mgt==l+1

            # Grow the selection to every connected foreground region (cluster of
            # touching GT cells) that touches the one-pixel dilation of the target(s).
            microcolonies = mgt>0
            labels = skimage.measure.label(microcolonies)
            hits_grown = binary_dilation(hits, iterations=1)  # loop-invariant, computed once
            binmask = hits.copy()
            for cell_ID in np.unique(labels[labels>0]):
                mask = labels==cell_ID
                if np.any(np.logical_and(mask, hits_grown)):
                    binmask[mask] = 1

            # Padded bounding box of the selection. Slice ends are exclusive, so clamp
            # them to the image size (not size-1) to keep cells touching the last row/column.
            y,x = np.nonzero(binmask)
            y0 = max(0, y.min()-pad)
            y1 = min(mgt.shape[0], y.max()+pad)
            x0 = max(0, x.min()-pad)
            x1 = min(mgt.shape[1], x.max()+pad)
            xy = [y0,y1,x0,x1]
        else:
            ly,lx = mgt.shape
            xy = [0,ly,0,lx]

        coords.append(xy)

In [None]:
# display the image file paths loaded at the top of the notebook
img_names

In [None]:
# Inspect one cleaned GT mask.
# NOTE(review): img_index[6] needs at least 7 entries, but the selection cell above
# builds img_index from a 2-entry img_list (a 7-entry list exists in the "older
# version" cell further down) — this raises IndexError unless that cell ran first.
m = masks_gt_clean[img_index[6]]
plt.imshow(m)
np.max(m)-2

In [None]:
# Re-run both models on the full images to reproduce the exact results of the main
# analysis. Crop afterwards: very small crops can give noticeably different results.
chans = [0,0]
names = ['Cellpose','OmniSeg']
suffixes = ['_CP','_OM']
imglist = [imgs[k] for k in img_index]
J = len(names)
nimg = len(imgs)
masks,flows,styles,d = [[]]*J,[[]]*J,[[]]*J,[[]]*J
for j in range(J):
    # j doubles as the cluster/omni flag: both are 0 for model 0 and 1 for model 1
    masks[j], flows[j], styles[j] = model[j].eval(imglist,channels=chans,mask_threshold=-1,diameter=0,cluster=j,
                                                  flow_threshold=0,omni=j, calc_trace=True, min_size=0,tile=False,transparency=True)


In [None]:
%matplotlib inline
from cellpose import plot

# Preview the segmentation produced by model `j` for every example image.
j = 0
nimg = len(imgs)
for idx, im in enumerate(imglist):
    seg = masks[j][idx]
    flow_img = flows[j][idx][0]
    fig = plt.figure(figsize=(12, 5))
    plot.show_segmentation(fig, im, seg, flow_img, channels=chans, omni=True, bg_color=0)
    plt.tight_layout()
    plt.show()

In [None]:
# Show flows[0][4][0] (the element passed as the flow image to show_segmentation above).
# NOTE(review): imglist holds len(img_index) images (2 with the selection cell above),
# so flows[0][4] needs at least 5 — confirm which example list this was run with.
plt.figure(figsize=(10,10))
plt.imshow(flows[0][4][0])
# plt.show()
# flows[0][-1][0].shape
i = 0
# masks[0][i].shape,imglist[i].shape,masks_gt[img_index[i]].shape
# dP_pred = flows[0][0][1]
# dP_pred.shape
# dP_pred.shape,np.transpose(dP_pred,(0,2,1)).shape

# tr = flows[0][0][-1][0]
# flows[j][k][-1][0].shape
# len(flows[0][1])

In [None]:
import edt
from matplotlib import rc
from cellpose.omnipose import utils
darkmode=1

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
import matplotlib_inline
# matplotlib_inline.backend_inline.set_matplotlib_formats('retina', 'png')
mpl.rcParams['figure.dpi'] = 72
%matplotlib inline
A = 1
px = A/plt.rcParams['figure.dpi']  # pixel in inches
cmap = mpl.cm.get_cmap('plasma')

do_flows = 1
isolated = 0

basedir = '/home/kcutler/DataDrive/omnipose_paper/Figure 3/comparison4'
io.check_dir(basedir)
for j in range(J): # select from models
# for j in [-1]:
    suffix = suffixes[j]
    # for k in range(n_exmpl): # select from images 
    # for k in [4]:
    for k in [6]:
    
        xy = coords[k].copy() # for cropping
        ly,lx = xy[1]-xy[0],xy[3]-xy[2]
        
        # outline thickness
        if px*lx<2:
            mode = 'inner'
        else:
            mode = 'thick'
        
        mgt = masks_gt[img_index[k]].copy()
        p = transforms.normalize99(imglist[k],omni=True)
        img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))
        
        mpred = masks[j][k].copy()
        dP_pred = flows[j][k][1].copy()
        cellprob = flows[j][k][2].copy()
        tr = flows[j][k][-1][0].copy()

        # make veritcal
        if ly<lx:
            print('transposing')
            mgt = np.transpose(mgt)
            img0 =  np.transpose(img0)
            dP_pred = np.transpose(dP_pred,(0,2,1))
            #traspose
            v1 =  dP_pred[1].copy()
            v2 =  dP_pred[0].copy()
            theta = 90
            dP_pred[0] = (-v1 * np.sin(-theta) + v2*np.cos(-theta))
            dP_pred[1] = -(v1 * np.cos(-theta) + v2*np.sin(-theta))
            tr = np.stack((tr[1],tr[0]))
            cellprob = np.transpose(cellprob)
            mpred = np.transpose(mpred)
            xy = [xy[2],xy[3],xy[0],xy[1]]
        
        #crop
        mpred_full = mpred.copy()
        mpred = mpred[xy[0]:xy[1],xy[2]:xy[3]]
        mgt = mgt[xy[0]:xy[1],xy[2]:xy[3]]
        cellprob = cellprob[xy[0]:xy[1],xy[2]:xy[3]]
        dP_pred = dP_pred[:,xy[0]:xy[1],xy[2]:xy[3]]
        img0 = img0[xy[0]:xy[1],xy[2]:xy[3]]
        
        match = np.argwhere([np.logical_and.reduce((xy[3]>=tr[1,i,0],xy[2]<=tr[1,i,0],xy[0]<=tr[0,i,0],xy[1]>=tr[0,i,0])) for i in range(tr.shape[1])]).flatten()
        # tr = tr[:,match,:]
        tr = np.stack(([tr[0,i,:]-xy[0] for i in match],[tr[1,i,:]-xy[2] for i in match]))
        
        flow_pred = plot.dx_to_circ(dP_pred,transparency=True)
        l = cell_list[k] # cell label list
        
        bini = mgt>0
        bin0 = np.stack((bini,bini,bini),axis=2)
        
        savedir = os.path.join(basedir, img_list[k] + '_cell_number_'+ str(l))
        if not os.path.isdir(savedir):
            os.mkdir(savedir)
        path = os.path.join(savedir,'perim_flows' + suffix)
        print(path)
        
        skimage.io.imsave(os.path.join(savedir,'phase'+ext),np.uint8(img0*255))
        skimage.io.imsave(os.path.join(savedir,'flow'+suffix+ext),np.uint8(flow_pred*255))
        pic = cmap(.95*utils.rescale(cellprob))
        pic[:,:,-1] = utils.rescale(cellprob)
        skimage.io.imsave(os.path.join(savedir,'cellprob'+suffix+ext),((pic)*255).astype(np.uint8))
        
        if isolated:
            mag = transforms.normalize99(np.sqrt(np.sum(dP_pred**2,axis=0)),omni=True)
            f = flow_pred.copy()
            flow_gray = 0.2125*f[:,:,0] + 0.7154*f[:,:,1] + 0.0721*f[:,:,2]
            m = bini==0
            f[m] = np.stack([flow_gray,flow_gray,flow_gray,mag],axis=-1)[m]
            skimage.io.imsave(os.path.join(savedir,'flow_gray'+suffix+ext),np.uint8(f*255))
        
#         ncolor_gt = utils.ncolorlabel(mgt)
        m = mgt.copy()
        u = np.unique(mgt)
        U = len(u)
        A = 100
        v = [0]+list(np.linspace(.25,.55,U-1)*A)
        res = dict(zip(u, v))
        m = fastremap.remap(m,res,preserve_missing_labels=False, in_place=True)
        ncolor_gt = utils.rescale(utils.ncolorlabel(m.copy()))
        ncolor_gray = np.stack([utils.rescale(ncolor_gt)]*3,axis=-1)
        skimage.io.imsave(os.path.join(savedir,'ncolor_gray_masks'+ext),np.uint8(ncolor_gray*255))

        ncolor_gray = np.stack([1-ncolor_gt]*3,axis=-1)
        skimage.io.imsave(os.path.join(savedir,'ncolor_gray_masks_inv'+ext),np.uint8(ncolor_gray*255))

        outli = plot.outline_view(img0,mpred,color=cmap(.85)[:3],mode=mode)
        skimage.io.imsave(os.path.join(savedir,'outline_view_gold'+suffix+ext),np.uint8(outli*255))
        outli = plot.outline_view(img0,mpred,mode=mode)
        skimage.io.imsave(os.path.join(savedir,'outline_view_red'+suffix+ext),np.uint8(outli*255))
       
        mgt, remap = fastremap.renumber(mgt)
        iou = metrics._intersection_over_union(mgt, mpred)
        th = 0
        n_min = min(iou.shape[0], iou.shape[1])
        costs = -(iou >= th).astype(float) - iou / (2*n_min)
        true_ind, pred_ind = linear_sum_assignment(costs)
        miou = iou[true_ind, pred_ind]

        # print('match',miou,pred_ind,mask_match,pred_inds,true_ind)
        with open(os.path.join(savedir,'MMiou'+suffix+'.txt'), "w") as text_file:
            print(f"Mean Matched IoU: {np.mean(miou)}", file=text_file)
                               
        if do_flows:
            dists = edt.edt(mpred)
            bd = dists==1
            
            Y,X = np.nonzero(bd)
            # Y = Y[np.logical_and(Y>=xy[0],Y<=xy[1])]
            # X = X[np.logical_and(X>=xy[2],X<=xy[3])]
            a = .5
            match0 = [np.any(np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a)) for i in range(tr.shape[1])]
            select_inds = np.argwhere(match0).flatten()
            
            Y,X = np.nonzero(mpred)
            # Y = Y[np.logical_or(Y>=xy[0],Y<=xy[1])]
            # X = X[np.logical_or(X>=xy[2],X<=xy[3])]
            a = .5
            match2 = [np.any(np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a)) for i in range(tr.shape[1])]
            select_inds2 = np.argwhere(match2).flatten()

            lx = mpred.shape[1]
            ly = mpred.shape[0]

            fig,ax = plt.subplots(figsize=(ly*px,lx*px))
            # ax.imshow(dists)
            ax.set_aspect('equal')
            ax.axis('off')
            ax.set_position([0, 0, 1, 1])
            
            for i in select_inds:
                xs = tr[1,i,:]
                ys = tr[0,i,:]
                if darkmode:
                    c = [1,1,1,.5]
                    s2 = '_darkmode'
                else:
                    c = [0,0,0,.25]
                    s2 = '_lightmode'
                ax.plot(xs,ys,c=c,solid_capstyle='round',linewidth=1/2,zorder=1)
                
            for i in select_inds2:
                ax.scatter(tr[1,i,-1],tr[0,i,-1],marker='.',s=px*lx,edgecolor=None,facecolor='r',zorder=2)
            
            ax.set_xlim([0,lx])
            ax.set_ylim([ly,0])
            ax.patch.set_alpha(0.)
            fig.patch.set_facecolor(None)

            plt.show()
            
            fig.savefig(path+suffix+s2+'.pdf',bbox_inches='tight',transparent=True,pad_inches=0)
            fig.savefig(path+suffix+s2+'.png',bbox_inches='tight',transparent=True,pad_inches=0)
        

In [None]:
# # p.shape,mgt.shape,dP_pred.shape
# # ax.scatter(tr[1,i,-1],tr[0,i,-1],marker='.',s=px*lx,edgecolor=None,facecolor='r',rasterized=True)
# # tr = flows[j][k][-1][0].copy()
# # # plt.plot(xs,ys,solid_capstyle='round',linewidth=1/2,rasterized=True,color='r')
# # # plt.show()
# # tr.shape
# # np.sum(match)
# # len(flows[j][k][-1])
# # np.logical_and(np.array(1),np.array(1),np.array(0))
# # [np.any(np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a)) for i in range(tr.shape[1])]
# tr = flows[j][k][-1][0].copy()
# # # tr = np.stack((tr[0],tr[1]))
# # tr[0,1,0],Y.shape
# # # np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a)
# # # np.any(np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a))
# # # tr.shape[1]
# # # [np.any(np.logical_and((X-tr[1,i,0])**2<a,(Y-tr[0,i,0])**2<a)) for i in range(tr.shape[1])]
# match = np.argwhere([np.logical_and.reduce((xy[3]>=tr[1,i,0],xy[2]<=tr[1,i,0],xy[0]<=tr[0,i,0],xy[1]>=tr[0,i,0])) for i in range(tr.shape[1])]).flatten()
# tr = np.stack(([tr[0,i,:]-xy[0] for i in match],[tr[1,i,:]-xy[2] for i in match]))

# # # tr.shape,len(match)
# # len(match),tr.shape,tr[:,match].shape
# msk = np.zeros_like(mgt)
# msk[tr[0,:,0].astype(int),tr[1,:,0].astype(int)] == 1
# # np.argwhere(match).shape
# mgt.shape,ly,lx,mgt.shape[0]<mgt.shape[1]
# evaluate the outline-mode test used in the export cell (True -> 'inner' outlines)
px*lx<2

In [None]:
# # regions = measure.regionprops(masks_OG)
# # areas = np.array([reg.area for reg in regions])
# # pred_areas[j] = pred_areas[j] + [areas]
# ovp = metrics._label_overlap(mgt, masks_OG)[1:,1:]
# # ovr = np.array(ovp / areas[np.newaxis,:]) # Overlap Ratio       
# #         pred_inds = np.argwhere(ovr[cell_ind-1]>.9)
# pred_inds = np.argwhere(ovp[cell_ind-1]>np.percentile(ovp,50))
# pred_inds.flatten()
# pred_inds+=1
# # predmask = np.any(np.stack([masks_OG == x for x in pred_inds]),axis=0)
# # pred_inds,pred_inds0,np.unique(masks_OG),ovr,iou
# predmask = np.any(np.stack([masks_OG == x for x in pred_inds]),axis=0)
# plt.imshow(predmask)
# # len(pred_inds),len(np.unique(masks_OG*predmask))
# pred_inds

# np.unique(masks_OG[Y,X])
# Rasterise the end point of every crop-shifted flow trace into a binary image.
# Traces are selected by their start point only, so an end point can fall outside
# the crop. Negative indices would silently wrap to the opposite edge and large
# ones would raise, so out-of-bounds end points are skipped.
msk = np.zeros_like(mgt)
if tr.ndim == 3 and tr.shape[1] > 0:
    end_r = tr[0,:,-1].astype(int)   # truncates toward zero, like int()
    end_c = tr[1,:,-1].astype(int)
    inside = (end_r >= 0) & (end_r < msk.shape[0]) & (end_c >= 0) & (end_c < msk.shape[1])
    msk[end_r[inside], end_c[inside]] = 1
plt.figure(figsize=(10,10))
plt.imshow(msk)

In [None]:
# Inspect predicted label 6 of masks_OG and list its labels.
# NOTE(review): masks_OG is only referenced in commented-out code in this part of the
# notebook; it must be set by an earlier cell — confirm.
fig = plt.figure(figsize=(20,20))
# plt.imshow(ncolorlabel(masks_OG))
plt.imshow(masks_OG==6)
np.unique(masks_OG)

In [None]:
# NOTE(review): ovp is only assigned in commented-out code in this part of the notebook
# (metrics._label_overlap(mgt, masks_OG)); this raises NameError unless an earlier cell set it.
ovp

In [None]:
# plt.imshow(flows_OG)
# fig = plt.figure(figsize=(10,10))
# plt.imshow(ncolorlabel(mgt))
# Map GT labels to evenly spaced gray levels (the smallest label, i.e. background when
# present, goes to 0; the rest span 0.25..0.75) and preview the result.
k = np.unique(mgt)
K = k.size
A = 100
levels = np.linspace(.25, .75, K - 1) * A
v = [0, *levels]
res = {label: level for label, level in zip(k, v)}
print(res, k)

m = fastremap.remap(mgt.copy(), res, preserve_missing_labels=True, in_place=True)
gray = m / A
print(np.max(gray))
plt.imshow(gray, cmap='gray')

In [None]:
# np.random.seed(0)
# x = np.linspace(0,10,10)
# np.random.shuffle(x)
# x
# scratch arithmetic; presumably 54 pt expressed in inches (72 pt per inch) — unverified
54/72

In [None]:
            # NOTE(review): this cell is an indented loop body with no enclosing loop, so it
            # cannot run as-is. It reads k, l, y0:y1, x0:x1, hits_new, mask_gt, img0, outli,
            # ncolor_gt, outl_OG, flow_OG, bin0, bini, p and name from a per-cell loop that
            # is not present here (presumably the truncated "older version" cell) — confirm.
            mask_new =  masks[-1][k][y0:y1,x0:x1]
            flow_new = transforms.normalize99(flows[-1][k][y0:y1,x0:x1],omni=True)
            inds = np.where(hits_new[l,:])
#                     print('new inds',inds,np.unique(mask_new))
            # relabel the new-model cells flagged in row l of hits_new as 1..n
            rmask = np.zeros_like(mask_gt)
            for i,label in enumerate(inds[0]):
                rmask[mask_new==label+1] = i+1

            bin_new = np.stack((rmask>0,rmask>0,rmask>0),axis=2)
            outl_new = plot.outline_view(img0,rmask)
#                     new_pic = np.hstack((cm2(ncolorlabel(mask_new))[:,:,:3],outl_new,flow_new))


#                     img = np.vstack((gt_pic,OG_pic,new_pic))
#                     plt.imshow(img)
#                     plt.axis('off')
#                     plt.show()

#                     x = range(len(r_OG[l,:]))
#                     x2 = range(len(r_new[l,:]))
#                     plt.plot(x,r_OG[l,:],'r',x2,r_new[l,:],'w')
#                     plt.show()


#                     io.imsave(savepath,np.uint8(img*255))
            # one output folder per example; every panel is saved as a separate image
            savedir = os.path.join(basedir, name)
            if not os.path.isdir(savedir):
                os.mkdir(savedir)
            io.imsave(os.path.join(savedir,'phase'+ext),np.uint8(img0*255))
            io.imsave(os.path.join(savedir,'outlines_GT'+ext),np.uint8(outli*255))
            io.imsave(os.path.join(savedir,'ncolor_GT'+ext),np.uint8(cm3(ncolor_gt)[:,:,:3]*255))
            io.imsave(os.path.join(savedir,'outlines_OG'+ext),np.uint8(outl_OG*255))
            io.imsave(os.path.join(savedir,'flow_OG'+ext),np.uint8(flow_OG*255))
            io.imsave(os.path.join(savedir,'flow_OG_isolated'+ext),np.uint8(bin0*flow_OG*255))
            io.imsave(os.path.join(savedir,'flow_OG_isolated_inv'+ext),np.uint8(rot_flow(bin0*flow_OG)*255))
            io.imsave(os.path.join(savedir,'binary_mask'+ext),np.uint8(bin0*255))
            io.imsave(os.path.join(savedir,'binary_mask_inv'+ext),np.uint8((1-bin0)*255))
            io.imsave(os.path.join(savedir,'outlines_new'+ext),np.uint8(outl_new*255))
            io.imsave(os.path.join(savedir,'flow_new'+ext),np.uint8(flow_new*255))
            io.imsave(os.path.join(savedir,'flow_new_isolated'+ext),np.uint8(bin_new*flow_new*255))
            io.imsave(os.path.join(savedir,'flow_new_inverted'+ext),np.uint8(rot_flow(flow_new)*255))
            io.imsave(os.path.join(savedir,'flow_new_isolated_inv'+ext),np.uint8(rot_flow(bin_new*flow_new)*255))


            # grayscale n-colour GT with cell l highlighted in red (normal and inverted)
            ncolor_gray = np.stack([utils.rescale(ncolor_gt)]*3,axis=-1)
            ncolor_gray[mask_gt==l+1] = [1,0,0]
            io.imsave(os.path.join(savedir,'ncolor_gray_masks'+ext),np.uint8(ncolor_gray*255))

            ncolor_gray = np.stack([1-utils.rescale(ncolor_gt)]*3,axis=-1)
            ncolor_gray[mask_gt==l+1] = [1,0,0]
            io.imsave(os.path.join(savedir,'ncolor_gray_masks_inv'+ext),np.uint8(ncolor_gray*255))

            # flow image desaturated (luma) outside the GT foreground
            f = flow_OG.copy()
            flow_gray = 0.2125*f[:,:,0] + 0.7154*f[:,:,1] + 0.0721*f[:,:,2]
            m = bini==0
            f[m] = np.stack([flow_gray]*3,axis=-1)[m]
            io.imsave(os.path.join(savedir,'flow_OG_gray'+ext),np.uint8(f*255))
            io.imsave(os.path.join(savedir,'flow_OG_gray_inv'+ext),np.uint8(rot_flow(f)*255))


            # stack panels vertically for wide crops, horizontally otherwise
            if p.shape[0]<p.shape[1]:
                    stack = np.vstack
            else:
                stack  = np.hstack

            # NOTE(review): the *_inv stacks below are identical to the non-inverted ones
            # (rot_flow is not applied) and are never saved.
            stack_OG = stack((img0,outl_OG,flow_OG))
            stack_new = stack((img0,outl_new,flow_new))
            stack_OG_inv = stack((img0,outl_OG,flow_OG))
            stack_new_inv = stack((img0,outl_new,flow_new))
            stack_OG_isolated = stack((img0,outl_OG,bin0*flow_OG))
            stack_new_isolated = stack((img0,outl_new,bin_new*flow_new))
#                     plt.imshow(new_stack)
#                     plt.show()
            io.imsave(os.path.join(savedir,'stack_OG'+ext),np.uint8(stack_OG*255))
            io.imsave(os.path.join(savedir,'stack_new'+ext),np.uint8(stack_new*255))

            io.imsave(os.path.join(savedir,'stack_OG_isolated'+ext),np.uint8(stack_OG_isolated*255))
            io.imsave(os.path.join(savedir,'stack_new_isolated'+ext),np.uint8(stack_new_isolated*255))


#                 im = Image.open(savepath)

In [None]:
# something pretty surprising: there aren't all that many cells that get oversegmented, and omnipose and distpose are super close - edit: no longer!
# this metric doesn't tell the whole story, but the differnet thrsholds do show there is more *bad* oversegmentation in the original cellpose
# as for the rest, it might be the case that a lot of these are edge cells, especially the longer ones

darkmode = False
if not darkmode:
    # light theme: reset to matplotlib defaults; `colors` must already be defined by an
    # earlier cell because this branch does not set it
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    # cmap = mpl.cm.get_cmap('viridis')
    # colors = cmap(np.linspace(0,.9,len(names)))
else:
    plt.style.use('dark_background')
    axcol = 'w'
    palette = sinebow(n+1)
    colors = [palette[j+1] for j in range(len(names))]

# Log-scale histogram of per-cell error counts, one overlaid series per model.
# NOTE(review): n is len(names) from the top of the notebook (4 names), but names was
# later rebound to a 2-entry list — confirm both refer to the same model set.
fig = plt.figure(figsize=(sz, sz/golden))
ax = plt.axes()
edges = np.arange(80)   # unit-width bin edges 0..79
for j in range(n):
    x = np.concatenate(cell_errors[j], axis=0)
    ax.hist(x, bins=edges, color=colors[j], log=True, range=(0,80), label=names[j], align='left')
    print(names[j] + ' error count:', np.count_nonzero(x))

print('Number of cells:', x.shape[0])
ax.legend(prop={'size': labelsize}, loc='upper left', frameon=False, bbox_to_anchor=(1.05, 1), markerscale=3)

ax.tick_params(axis='both', which='major', labelsize=labelsize, length=3, direction="out", colors=axcol, bottom=True, left=True)
plt.xticks(np.arange(0, 80, 5))
ax.set_ylabel('Number of cells', fontsize=labelsize)
ax.set_xlabel('Segmentation errors', fontsize=labelsize)
plt.show()

From `cell_errors` we can inspect many examples. The current code outputs every single instance of an error. The histogram above shows that my version produces almost no cells with more than one error (the remaining ones seem to be cell-division issues, I think).

- Enforce a minimum crop size 
- Sort by number of errors 
- Dump images into a PDF with info 

In [None]:
import os
from scipy.ndimage.morphology import binary_erosion 
from cellpose import utils, transforms, plot
import cv2
from cellpose.io import imsave
from PIL import Image, ImageFont, ImageDraw, ImageOps
from matplotlib.colors import ListedColormap
import os, datetime, gc, warnings, glob
from fpdf import FPDF

plt.style.use('dark_background')
# plt.get_cmap exists on every matplotlib release; mpl.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap('viridis')
cm2 = ListedColormap(list(sinebow(5).values()))

basedir = '/home/kcutler/DataDrive/omnipose_paper/Comparison Examples/Per_Cell_GT_Comparison/'
io.check_dir(basedir)

cutoff=1       # minimum number of errors for a GT cell to be exported
pad = 10       # crop padding (px) around each cell
N = len(masks)
nimg = len(masks_gt)
bkct = 500
bg = 0.5       # value the mean background intensity is mapped to by the gamma adjustment


txtpad = 10      # vertical gap (pt) between PDF tiles
yoffset = [0,0]  # current y position (pt) of the left/right PDF column
buffer = 30      # bottom margin (pt) kept free before starting a new PDF page
do_pdf = 0
clean = 1        # re-render PNGs even when they already exist
for j in [0]:
    if do_pdf:
        pdf = FPDF(unit='pt')
        pdf.add_page()
        pdf.set_font('Arial', 'B', 8)
        pdf.set_text_color(r=255, g=255, b=255)
        pdf.set_fill_color(0,0,0)
        ph = pdf.h
        pw = pdf.w
        pdf.page_break_trigger = ph
        pdf.rect(0,0,pw,ph,'F')   # black page background
        pdf.set_xy(0,0)
        ind = 0   # column (0 = left, 1 = right) for the next tile
        cnt = 0   # tiles placed on the current page

    for k in range(nimg):
        file = img_names[k]
        basename = os.path.splitext(os.path.split(file)[-1])[0]

        err = cell_errors[j][k]
        # indexes of GT cells (label = index + 1) with at least `cutoff` errors
        err_indexes = np.nonzero(err>=cutoff)[0]
        num_errors = err_indexes.size
        if num_errors == 0:
            continue

        for l in err_indexes:
            name = basename + '_' + names[j] + '_cell_number_'+ str(l) + '.png'
            savepath = os.path.join(basedir, name)
            print(j,k,l,name)

            if not os.path.isfile(savepath) or clean:
                mgt = masks_gt_clean[k]
                p = transforms.normalize99(imgs[k],omni=True)
                img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))

                # Padded crop around GT cell l. Slice ends are exclusive, so clamp them to
                # the image size (not size-1) to keep cells touching the last row/column.
                y,x = np.nonzero(mgt==l+1)
                y0 = max(0, y.min()-pad)
                y1 = min(mgt.shape[0], y.max()+pad)
                x0 = max(0, x.min()-pad)
                x1 = min(mgt.shape[1], x.max()+pad)

                p = img0[y0:y1,x0:x1]

                pic = [[]]*2
                img0 = np.stack((p,p,p),axis=2)
                # top row: image | predicted outlines | predicted flow
                maski = masks[j][k][y0:y1,x0:x1]
                # NOTE(review): flows[j][k] is sliced as an image here, while the model
                # re-run cell stores the list returned by eval in flows[j][k] — confirm
                # which `flows` this cell expects.
                flowi = transforms.normalize99(flows[j][k][y0:y1,x0:x1],omni=True)
                outli = transforms.normalize99(plot.outline_view(img0,maski),omni=True)
                pic[0] = np.hstack((img0,outli,flowi))

                # bottom row: GT n-colour labels | GT outlines | binary mask of cell l
                maski = mgt[y0:y1,x0:x1]
                bini = maski==l+1
                bini = np.stack((bini,bini,bini),axis=2)
                outli = transforms.normalize99(plot.outline_view(img0,maski),omni=True)
                pic[1] = np.hstack((cm2(utils.ncolorlabel(maski))[:,:,:3],outli,bini))

                img = np.vstack((pic[0],pic[-1]))

                io.imsave(savepath,np.uint8(img*255))

            if do_pdf:
                # read the tile size and close the file handle straight away
                with Image.open(savepath) as im:
                    im_w, im_h = im.size
                h = ((pw/2)/im_w)*im_h   # tile height when scaled to half the page width
                if yoffset[ind]+h+txtpad+buffer>=ph:
                    # column is full: start a new black page at the top-left corner
                    pdf.add_page()
                    pdf.rect(0,0,pw,ph,'F')
                    yoffset = [0,0]
                    ind = 0
                    pdf.set_xy(0,0)
                    cnt = 0
                elif cnt==0 or cnt==1:
                    pdf.set_xy(ind*pw/2,yoffset[ind])
                else:
                    pdf.set_xy(ind*pw/2,yoffset[ind]+txtpad)

                if h<ph:
                    pdf.image(savepath,w=pw/2)
                else:
                    pdf.image(savepath,h=ph-2*txtpad)
                pdf.cell(30, 10, txt=name)
                yoffset[ind] = pdf.get_y()
                ind = (ind+1)%2
                cnt += 1
    if do_pdf:
        pdf.output('/home/kcutler/DataDrive/'+names[j]+'_errors.pdf','F')
                


In [None]:
#older version relying just on saved flows 

import os
from scipy.ndimage.morphology import binary_erosion 
from cellpose import utils, transforms
from cellpose.utils import ncolorlabel
import cv2
from cellpose.io import imsave
from PIL import Image, ImageFont, ImageDraw, ImageOps
from matplotlib.colors import ListedColormap
import os, datetime, gc, warnings, glob


def rot_flow(f):
    """Return the colour-inverted image ``1 - f``.

    Despite the name, no rotation or channel mixing is applied; the function
    simply inverts values element-wise. For images with values in [0, 1]
    (as produced by the normalisation used in this notebook) the result is
    again in [0, 1].

    Parameters
    ----------
    f : array_like
        Image (any shape / number of channels) or scalar to invert.

    Returns
    -------
    Same type as ``1 - f``: the inverted image.
    """
    return 1 - f

# Plot styling and colormaps used by the per-cell export below.
plt.style.use('dark_background')
# plt.get_cmap is available on all Matplotlib versions (mpl.cm.get_cmap was
# deprecated in 3.7 and removed in 3.9).
cmap = plt.get_cmap('viridis')
# NOTE(review): `sinebow` is not imported in this cell; it relies on an earlier
# import (a later cell does `from cellpose.utils import sinebow`).
cm2 = ListedColormap(list(sinebow(5).values()))
cm3 = ListedColormap(list(sinebow(5,bg_color=[1,1,1,1]).values()))

# Output directory for the exported panels (machine-specific path).
basedir = '/home/kcutler/DataDrive/omnipose_paper/Comparison Examples/newcompare3/'
# makedirs also creates missing parent directories; exist_ok keeps re-runs idempotent.
os.makedirs(basedir, exist_ok=True)

# Thresholds used to select erroneous cells and to crop around them.
area_cutoff = 20
error_cutoff=2
pad = 10
N = len(masks)
nimg = len(masks_gt)
bkct = 500
bg = 0.5  # target mean background intensity after gamma correction
ext = '.png'

cnt = 0

# Layout parameters for PDF assembly (used by the PDF-export cell).
txtpad = 10
yoffset = [0,0]
buffer = 15

# Hand-picked example images and the (0-based) GT cell index to show from each.
img_list = ['ftsN_ensemble_0','caulo_15','streptomyces_XY15_1','Az_branch_ec_0','cex_xy1c1','Hpylori2_2','PSVB_ensemble_c_8']
cell_list = [10,8,1,33,219,44,160]
namelist = [os.path.splitext(os.path.split(fname)[-1])[0] for fname in img_names]
# Raises ValueError if any requested image name is not among img_names.
img_index = [namelist.index(im) for im in img_list]
for cell_num, img_idx in zip(cell_list, img_index):
    print(cell_num, img_idx)

# Export per-cell comparison panels for model j (j = 0, i.e. names[0]) against the
# last model (index -1): for each selected GT cell, crop around it and save the
# phase image, outlines, flows, and masks as separate images in basedir/<name>/.
# # for k in range(nimg):
# # for k in [0]:
for j in [0]:
#     ind = 0
#     for k in range(nimg):
# #     for k in range(26):

    # for l,k in zip(cell_list,img_index): # replace loops over k and l 
    # NOTE(review): this pairs cell_list[i] with image index i (0, 1, 2, ...), and zip
    # stops after len(cell_list) images. `img_index` (computed from img_list in the
    # setup cell) is not used here; the commented line above looks like the intended
    # pairing. Confirm before regenerating figures.
    for l,k in zip(cell_list,range(nimg)): # replace loops over k and l 
    

        file = img_names[k]
        # file name without directory or extension; used to build output names
        basename = os.path.splitext(os.path.split(file)[-1])[0]

        label_list = []  # NOTE(review): unused
        # Cells of image k with error >= error_cutoff and area > area_cutoff.
        # Only the count is used below (to gate the export).
        err = cell_errors[j][k]
        areas = cell_areas[k]
        err_indexes = np.nonzero(np.logical_and(err>=error_cutoff,areas>area_cutoff))[0]
        num_errors = err_indexes.size

        # Matching: for each column of OvR, flag rows within m standard deviations
        # of that column's maximum. Columns whose max is 0 are set to NaN, so nothing
        # in them is flagged. Rows are read per GT cell (hits_OG[l,:] below) and
        # columns map to predicted label-1. Presumably OvR is an overlap matrix;
        # confirm where it is built.
        m = 2
        r_OG = OvR[j][k] 
        mx = np.max(r_OG,axis=0) 
        std = np.std(r_OG,axis=0)
        mx[mx==0] = np.nan
        hits_OG = r_OG>=(mx-m*std)
    #         hits_OG = r_OG>=0.2

        # same matching for the last model
        r_new = OvR[-1][k] 
        mx = np.max(r_new,axis=0) 
        std = np.std(r_new,axis=0)
        mx[mx==0] = np.nan
        hits_new = r_new>=(mx-m*std)
    #         hits_new = r_new>= 0.2

        if num_errors > 0:

            # NOTE(review): the per-error loop is commented out, so the exported cell
            # is always `l` from cell_list, whether or not it is in err_indexes.
#             for l in err_indexes:
            name = basename + '_' + names[j] + '_cell_number_'+ str(l)
            savepath = os.path.join(basedir, name + ext)

            overwrite = True  # always regenerate, even if savepath exists
            if not os.path.isfile(savepath) or overwrite:
                mgt = masks_gt_clean[k]
                p = transforms.normalize99(imgs[k],omni=True)
                # Gamma-correct so that the mean of the eroded background region maps
                # to `bg`: mean**(log(bg)/log(mean)) == bg.
                img0 = p**(np.log(bg)/np.log(np.mean(p[binary_erosion(mgt==0)])))

                # Bounding box of GT cell l+1, padded by `pad` and clipped to the image.
                # NOTE(review): min()/max() raise ValueError if label l+1 is absent.
                # Also, y1/x1 are clipped to the last index but used as exclusive slice
                # ends, so a box touching the bottom/right edge loses one row/column.
                y,x = np.nonzero(mgt==l+1)
                max_y,max_x = np.array(mgt.shape)-1

                y0 = max(0,min(y)-pad)
                y1 = min(max_y,max(y)+pad)
                x0 = max(0,min(x)-pad)
                x1 = min(max_x,max(x)+pad)

                p = img0[y0:y1,x0:x1]

                # Crop the GT labels and build the selected-cell panels.
                mask_gt = mgt[y0:y1,x0:x1]
                ncolor_gt = utils.ncolorlabel(mask_gt)
#                 print(np.unique(maski),l)
                bini = mask_gt==l+1  # binary mask of the selected GT cell in the crop
                img0 = np.stack((p,p,p),axis=2)  # 3-channel crop; replaces full-size img0
                outli = plot.outline_view(img0,bini)
                bin0 = np.stack((bini,bini,bini),axis=2)
#                     bin0 = np.zeros_like(img0)
                gt_pic = np.hstack((img0,outli,bin0))  # NOTE(review): unused below
                # Predicted masks and normalised flows of model j, cropped.
                mask_OG = masks[j][k][y0:y1,x0:x1]
                flow_OG = transforms.normalize99(flows[j][k][y0:y1,x0:x1],omni=True)
#                     print(bini.shape,flow_OG.shape)
#                     flow_OG = np.hstack([flow_OG[:,:,i]*bini for i in range(flow_OG.shape[-1])])
# #                     flow_OG = np.multiply(flow_OG,bini)
#                     print(flow_OG.shape)


                # Relabel the predicted masks matched to GT cell l as 1..n (others 0).
                inds = np.where(hits_OG[l,:])
#                     print('OG inds',inds,np.unique(mask_OG))
                rmask = np.zeros_like(mask_gt)
                for i,label in enumerate(inds[0]):
                    rmask[mask_OG==label+1] = i+1

                outl_OG = plot.outline_view(img0,rmask)
#                     OG_pic = np.hstack((cm2(ncolorlabel(mask_OG))[:,:,:3],outl_OG,flow_OG))

                # Same for the last model; bin_new is the union of its matched masks.
                mask_new =  masks[-1][k][y0:y1,x0:x1]
                flow_new = transforms.normalize99(flows[-1][k][y0:y1,x0:x1],omni=True)
                inds = np.where(hits_new[l,:])
#                     print('new inds',inds,np.unique(mask_new))
                rmask = np.zeros_like(mask_gt)
                for i,label in enumerate(inds[0]):
                    rmask[mask_new==label+1] = i+1

                bin_new = np.stack((rmask>0,rmask>0,rmask>0),axis=2)
                outl_new = plot.outline_view(img0,rmask)
#                     new_pic = np.hstack((cm2(ncolorlabel(mask_new))[:,:,:3],outl_new,flow_new))


#                     img = np.vstack((gt_pic,OG_pic,new_pic))
#                     plt.imshow(img)
#                     plt.axis('off')
#                     plt.show()

#                     x = range(len(r_OG[l,:]))
#                     x2 = range(len(r_new[l,:]))
#                     plt.plot(x,r_OG[l,:],'r',x2,r_new[l,:],'w')
#                     plt.show()


#                     io.imsave(savepath,np.uint8(img*255))
                # Save each panel as an 8-bit image in basedir/<name>/
                # (io.check_dir presumably creates the folder if missing).
                savedir = os.path.join(basedir, name)
                io.check_dir(savedir)
                io.imsave(os.path.join(savedir,'phase'+ext),np.uint8(img0*255))
                io.imsave(os.path.join(savedir,'outlines_GT'+ext),np.uint8(outli*255))
                io.imsave(os.path.join(savedir,'ncolor_GT'+ext),np.uint8(cm3(ncolor_gt)[:,:,:3]*255))
                io.imsave(os.path.join(savedir,'outlines_OG'+ext),np.uint8(outl_OG*255))
                io.imsave(os.path.join(savedir,'flow_OG'+ext),np.uint8(flow_OG*255))
                io.imsave(os.path.join(savedir,'flow_OG_isolated'+ext),np.uint8(bin0*flow_OG*255))
                io.imsave(os.path.join(savedir,'flow_OG_isolated_inv'+ext),np.uint8(rot_flow(bin0*flow_OG)*255))
                io.imsave(os.path.join(savedir,'binary_mask'+ext),np.uint8(bin0*255))
                io.imsave(os.path.join(savedir,'binary_mask_inv'+ext),np.uint8((1-bin0)*255))
                io.imsave(os.path.join(savedir,'outlines_new'+ext),np.uint8(outl_new*255))
                io.imsave(os.path.join(savedir,'flow_new'+ext),np.uint8(flow_new*255))
                io.imsave(os.path.join(savedir,'flow_new_isolated'+ext),np.uint8(bin_new*flow_new*255))
                io.imsave(os.path.join(savedir,'flow_new_inverted'+ext),np.uint8(rot_flow(flow_new)*255))
                io.imsave(os.path.join(savedir,'flow_new_isolated_inv'+ext),np.uint8(rot_flow(bin_new*flow_new)*255))
                
                
                # GT n-colour labels rescaled to grey, with the selected cell painted red.
                ncolor_gray = np.stack([utils.rescale(ncolor_gt)]*3,axis=-1)
                ncolor_gray[mask_gt==l+1] = [1,0,0]
                io.imsave(os.path.join(savedir,'ncolor_gray_masks'+ext),np.uint8(ncolor_gray*255))
                
                # Same, with the grey levels inverted.
                ncolor_gray = np.stack([1-utils.rescale(ncolor_gt)]*3,axis=-1)
                ncolor_gray[mask_gt==l+1] = [1,0,0]
                io.imsave(os.path.join(savedir,'ncolor_gray_masks_inv'+ext),np.uint8(ncolor_gray*255))
                
                # Model-j flow kept in colour inside the selected GT cell and converted
                # to a weighted grey outside it (same weights as skimage's rgb2gray).
                f = flow_OG.copy()
                flow_gray = 0.2125*f[:,:,0] + 0.7154*f[:,:,1] + 0.0721*f[:,:,2]
                m = bini==0  # reuses `m`; the std multiplier is reset at the top of each iteration
                f[m] = np.stack([flow_gray]*3,axis=-1)[m]
                io.imsave(os.path.join(savedir,'flow_OG_gray'+ext),np.uint8(f*255))
                io.imsave(os.path.join(savedir,'flow_OG_gray_inv'+ext),np.uint8(rot_flow(f)*255))
                

                # Stack panels vertically when the crop is wider than tall, else side by side.
                if p.shape[0]<p.shape[1]:
                        stack = np.vstack
                else:
                    stack  = np.hstack

                stack_OG = stack((img0,outl_OG,flow_OG))
                stack_new = stack((img0,outl_new,flow_new))
                # NOTE(review): the two *_inv stacks are identical to the ones above and are never saved.
                stack_OG_inv = stack((img0,outl_OG,flow_OG))
                stack_new_inv = stack((img0,outl_new,flow_new))
                stack_OG_isolated = stack((img0,outl_OG,bin0*flow_OG))
                stack_new_isolated = stack((img0,outl_new,bin_new*flow_new))
#                     plt.imshow(new_stack)
#                     plt.show()
                io.imsave(os.path.join(savedir,'stack_OG'+ext),np.uint8(stack_OG*255))
                io.imsave(os.path.join(savedir,'stack_new'+ext),np.uint8(stack_new*255))

                io.imsave(os.path.join(savedir,'stack_OG_isolated'+ext),np.uint8(stack_OG_isolated*255))
                io.imsave(os.path.join(savedir,'stack_new_isolated'+ext),np.uint8(stack_new_isolated*255))


#                 im = Image.open(savepath)

In [None]:
# Preview: keep the model flow colours inside the selected GT cell (`bini`) and
# paint everything outside it with the rescaled ground-truth label image in grey.
flow_gray = utils.rescale(mask_gt)
m = bini==0
f = flow_OG.copy()
f[m] = np.repeat(flow_gray[..., np.newaxis], 3, axis=-1)[m]
plt.imshow(f)
plt.axis('off')
plt.show()
bini

In [None]:
from skimage.color import rgb2hsv, hsv2rgb

# Experiment: map the first (rescaled) flow channel onto a cyclic RGB palette
# made of three cosines offset by 120 degrees.
f = utils.rescale(flow_OG.copy())
r = f[:,:,0]
# After rescaling r is in [0, 1], so r-1 is in [-1, 0] and arccos is defined.
angle = np.arccos(r-1)+np.pi
a = 1  # offset so each cosine channel is non-negative
b = 2  # divisor mapping [0, 2] onto [0, 1]
# The list is fully built (with b == 2) before unpacking rebinds r, g, b.
r, g, b = [(np.cos(angle + k*2*np.pi/3) + a)/b for k in range(3)]
f = np.stack((r,g,b),axis=-1)
plt.imshow(f)
plt.show()

In [None]:
from skimage.color import rgb2hsv

# Experiment: re-colour the flow image by converting its hue into an angle and
# rebuilding RGB from three cosines offset by 120 degrees.
f = utils.rescale(flow_OG.copy())
h = rgb2hsv(f)
# rgb2hsv returns hue as a fraction of a full turn in [0, 1]; scale by 2*pi so
# the whole colour circle is used (adding the raw hue to pi covered only ~57 deg).
angle = 2*np.pi*h[:,:,0]+np.pi
a = 1  # offset so each cosine channel is non-negative
b = 2  # divisor mapping [0, 2] onto [0, 1]; rebound to the blue channel below
r = ((np.cos(angle)+a)/b)
g = ((np.cos(angle+2*np.pi/3)+a)/b)
b = ((np.cos(angle+4*np.pi/3)+a)/b)
f = np.stack((r,g,b),axis=-1)
# NOTE(review): only the green channel is displayed; use plt.imshow(f) for the composite.
plt.imshow(g)
plt.show()

In [None]:
# Load the stack of hand-edited label images used for the diameter comparison
# below (indexed per frame there). Path is machine-specific.
masks_cmp = io.imread('/home/kcutler/DataDrive/merge_ftsN/xy1/edited_labels.tif')

In [None]:
# Compute utils.diameters with omni=False and omni=True for every frame of the
# edited label stack, for the comparison plot below.
from cellpose import utils

n = len(masks_cmp)
print(n)
x = range(n)
diam_old, diam_new = [], []
for k, m in zip(x, masks_cmp):
    # Relabel in place (modifies masks_cmp) so label IDs are consecutive.
    fastremap.renumber(m, in_place=True)
    diam_old.append(utils.diameters(m, omni=False)[0])
    diam_new.append(utils.diameters(m, omni=True)[0])

In [None]:
def set_size(w,h, ax=None):
    """Resize the figure so that the *axes* area measures ``w`` x ``h`` inches.

    ``Figure.set_size_inches`` sizes the whole figure; this divides the requested
    size by the fraction of the figure occupied by the subplot area (taken from
    ``figure.subplotpars``) so that the axes themselves get the requested size.

    Parameters
    ----------
    w, h : float
        Desired axes width and height in inches.
    ax : matplotlib Axes, optional
        Axes to size. Defaults to the current axes (``plt.gca()``).
    """
    if ax is None:
        ax = plt.gca()
    pars = ax.figure.subplotpars
    figw = float(w) / (pars.right - pars.left)
    figh = float(h) / (pars.top - pars.bottom)
    ax.figure.set_size_inches(figw, figh)

from cellpose.utils import sinebow
master_color_scheme = [[i,0,0] for i in np.linspace(1,.5,3)]+[[i,i,i] for i in np.linspace(.75,0,n-3)]
golden = (1 + 5 ** 0.5) / 2
sz = 2.5
labelsize = 7

%matplotlib inline
darkmode = False
if darkmode:
    plt.style.use('dark_background')
    axcol = 'w'
    colors = sinebow(n+1)
    colors = [colors[j+1] for j in range(n)]
    background_color = 'k'
else:
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    cmap = mpl.cm.get_cmap('viridis')
#     colors = cmap(np.linspace(0,.9,len(names)))
    colors = master_color_scheme
    background_color = np.array([242,242,242])/255
    
mpl.rcParams['figure.dpi'] = 300

fig = plt.figure(figsize=(sz, sz/golden)) 
ax = plt.axes()

plt.plot(range(n),diam_old,'r',range(n),diam_new,'k')

ax.legend(prop={'size': labelsize}, loc='best', frameon=False)
ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.set_ylabel('Diameter Metric', fontsize = labelsize)
ax.set_xlabel('Frame number', fontsize = labelsize)
# plt.ylim([0,100])
fig.patch.set_facecolor(background_color)
plt.show()

name = '/home/kcutler/DataDrive/omnipose_paper/diameter metric_ftsN_comparison'
fig.savefig(name+'.eps',bbox_inches='tight')
fig.savefig(name+'.png',bbox_inches='tight')

In [None]:
# Scratch: display the RGBA colour at position 0.85 of the current `cmap`.
cmap(.85)


In [None]:
# Scratch arithmetic; the result is only displayed and not used elsewhere.
217+157