# Figure 6
This notebook extracts a particular example from Luca's high-resolution *C. elegans* dataset. These worms are far larger, in terms of pixel diameter, than any objects in our training sets.

In [None]:
%load_ext autoreload
%autoreload 2

# First, import dependencies.
import numpy as np
import time, os, sys
from cellpose import io, transforms, plot, models, core
import omnipose
import skimage.io


# for plotting 
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import matplotlib as mpl
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300

# Detect whether a usable GPU is present; CPU inference is much slower but is
# acceptable for the handful of test images handled in this notebook.
use_GPU = core.use_gpu()
print('>>> GPU activated? %d' % use_GPU)

# Test-set location and the filename suffix that marks label images.
basedir = '/home/kcutler/DataDrive/luca/test_manual'
mask_filter = '_masks'

# Resolve image paths, then the matching label paths, and read both sets.
img_names = io.get_image_files(basedir, mask_filter)
mask_names, flow_names = io.get_label_files(img_names, mask_filter)
imgs = list(map(skimage.io.imread, img_names))
masks_gt = list(map(skimage.io.imread, mask_names))

In [None]:
# The model was trained with diam_mean = 60. That value is not restored from the
# saved weights, so it has to be passed explicitly when building the model.
modeldir = '/home/kcutler/DataDrive/luca/train_good_old/models/cellpose_residual_on_style_on_concatenation_off_omni_train_good_2022_04_21_01_30_25.911493_epoch_3999'
diam_mean = 60
omni = True
model = models.CellposeModel(
    gpu=use_GPU,
    pretrained_model=modeldir,
    diam_mean=diam_mean,
    net_avg=False,
    omni=omni,
)

In [None]:
# Mean object diameter of every ground-truth label image; the evaluation cell
# uses these to compute a per-image rescale factor.
diams = list(map(omnipose.core.diameters, masks_gt))
diams

In [None]:
# Evaluation settings, kept as module-level names so later cells can reuse them.
chans = [0, 0]
mask_threshold = -1
diam_threshold = 12
net_avg = 0
cluster = 0
verbose = 0
use_gpu = 1  # NOTE(review): not used below; the device was fixed by gpu=use_GPU when the model was built

# One rescale factor per image: training diameter / measured diameter.
# The diameters come from masks_gt, so this is a ground-truth ("oracle") scale.
rs = diam_mean / np.array(diams)

masks, flows, styles = model.eval(
    imgs,
    channels=chans,
    rescale=rs,
    diam_threshold=diam_threshold,
    mask_threshold=mask_threshold,
    flow_threshold=0,
    net_avg=net_avg,
    omni=omni,
    cluster=cluster,
    resample=True,
    transparency=True,
    tile=False,
    verbose=verbose,
)

In [None]:
from cellpose import plot
import omnipose

# Draw the segmentation of every evaluated image with cellpose's
# show_segmentation helper, one figure per image.
nimg = len(imgs)
for k in range(nimg):
    maski, flowi = masks[k], flows[k][0]
    fig = plt.figure(figsize=(12, 5))
    plot.show_segmentation(fig, imgs[k], maski, flowi,
                           channels=chans, omni=omni, bg_color=0)
    plt.tight_layout()
    plt.show()

In [None]:
from omnipose import utils
from cellpose import metrics
from skimage import measure
import fastremap

# 100 evenly spaced IoU matching thresholds from 0.5 to 1.0 (both inclusive).
threshold = np.linspace(start=0.5, stop=1, num=100)

In [None]:
# Clean every ground-truth label image and record the area of each object.
# clean_boundary / format_labels come from omnipose.utils; presumably they drop
# border-touching objects and relabel consecutively — confirm in omnipose docs.
nimg = len(masks_gt)
masks_gt_clean = []
cell_areas = []
for mask_raw in masks_gt:
    mgt = utils.format_labels(utils.clean_boundary(mask_raw))
    masks_gt_clean.append(mgt)
    cell_areas.append(np.array([reg.area for reg in measure.regionprops(mgt)]))

# Placeholder grid for cleaned predictions (len(masks) rows x nimg columns).
# Built with a comprehension so each row is an independent list: the previous
# `[[None]*nimg]*len(masks)` made every row the SAME list object, so writing
# to one row would silently change all of them.
masks_pred_clean = [[None] * nimg for _ in range(len(masks))]

In [None]:
    
# Score predictions against the cleaned ground truth at every IoU threshold.
# Outputs are presumably per-image AP and TP/FP/FN counts (cellpose.metrics).
# NOTE(review): only the ground truth went through clean_boundary; predicted
# objects touching the border may therefore count as false positives — confirm
# this asymmetry is intended.
api, tpi, fpi, fni = metrics.average_precision(
    masks_gt_clean, masks, threshold=threshold
)



In [None]:
# Quick look: average precision (averaged over images, axis 0) versus the IoU
# matching threshold. Labelled so the figure stands alone when skimmed.
plt.plot(threshold, np.mean(api, axis=0))
plt.xlabel('IoU matching threshold')
plt.ylabel('Average precision')
plt.show()

In [None]:
mpl.rcParams.update(mpl.rcParamsDefault)
axcol = 'k'

from omnipose.utils import sinebow
names = ['omni']
n = len(names)
z = 1
master_color_scheme = [[i,0,0] for i in np.linspace(1,.5,z)]+[[i,i,i] for i in np.linspace(.75,.25,n-z)]

golden = (1 + 5 ** 0.5) / 2
sz = 1.5
labelsize = 7

%matplotlib inline
darkmode = 0
if darkmode:
    plt.style.use('dark_background')
    axcol = 'w'
    colors = sinebow(n+1)
    colors = [colors[j+1] for j in range(n)]
    background_color = 'k'
    suffix = '_dark'
else:
    mpl.rcParams.update(mpl.rcParamsDefault)
    axcol = 'k'
    cmap = mpl.cm.get_cmap('viridis')
#     colors = cmap(np.linspace(0,.9,len(names)))
    colors = master_color_scheme
    background_color = np.array([1,1,1,1])
    suffix = ''
    
mpl.rcParams['figure.dpi'] = 300

# could turn these results into a plot by repeating at different cutoffs 
fig = plt.figure(figsize=(sz, sz)) 
ax = plt.axes()
j=0
ax.plot(threshold,np.mean(api,axis=0),label=names[j],color=colors[j])
ax.legend(prop={'size': labelsize/1.5}, loc='best', frameon=False)

ax.tick_params(axis='both', which='major', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.tick_params(axis='both', which='minor', labelsize=labelsize,length=3, direction="out",colors=axcol,bottom=True,left=True)
ax.set_ylabel('Percent of cells \nwith >=1 error', fontsize = labelsize)
ax.set_ylabel('Jaccard Index', fontsize = labelsize)
ax.set_xlabel('IoU matching threshold', fontsize = labelsize)
# plt.set(xlabel='IoU threshold', ylabel='Average Precision',fontsize=labelsize)
# plt.tight_layout()
# plt.yscale('log')


ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)
plt.ylim([0,1])
plt.xlim([.5,1])

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.patch.set_alpha(0.0)

fig.patch.set_facecolor(background_color)
plt.show()

a = 35
tight_bbox_raw = ax.get_tightbbox(fig.canvas.get_renderer())
tight_bbox_raw._points+=[[-a,-a],[a,a]]
tight_bbox = mpl.transforms.TransformedBbox(tight_bbox_raw, mpl.transforms.Affine2D().scale(1./fig.dpi))


# name = '/home/kcutler/DataDrive/omnipose_paper/Figure 5/error_vs_area_percentiles'+suffix
# fig.savefig(name+'.eps',bbox_inches=tight_bbox)
# fig.savefig(name+'.pdf',bbox_inches=tight_bbox)
# fig.savefig(name+'.eps',pad_inches = 0.5)
# fig.savefig(name+'.png',pad_inches = 0.5)

In [None]:
import ncolor

# Export a cropped view of one example worm: image, predicted outlines, masks,
# flow field, distance field, boundary field and an n-color grayscale mask.
# All inputs come from the test-set cells above, so this runs on a fresh kernel.
example = 0  # index into imgs / masks_gt / masks / flows — TODO confirm which worm the figure uses
save0 = '/home/kcutler/DataDrive/omnipose_paper/Figure 6/luca_worm'
io.check_dir(save0)

# define outline color (pyplot.get_cmap: matplotlib.cm.get_cmap was removed in 3.9)
cmap = plt.get_cmap('plasma')
outline_col = cmap(0.85)[:3]
ext = '.png'
mode = 'thick'

img = imgs[example]
mask = masks_gt[example]
diam = diams[example]
pred_masks = masks[example]
# Per-image flow list; below, [0] is the RGB flow (as in the plotting cell),
# [2] and [4] are exported as distance and boundary — presumably per omnipose's
# output ordering; confirm.
pred_flows = flows[example]

# Bounding box of the ground-truth objects, padded by half a diameter and
# clipped to the image. Slice stops are exclusive, so the stop is max+1+pad,
# clipped to the full shape (not shape-1).
pad = int(diam/2)
mgt = mask
bin0 = mgt>0
if bin0.any():
    inds = np.nonzero(bin0)
    slc = tuple(slice(max(0, inds[d].min()-pad), min(bin0.shape[d], inds[d].max()+1+pad))
                for d in range(mgt.ndim))
else:
    # no labelled objects: fall back to the whole image
    slc = tuple(slice(None) for _ in range(mgt.ndim))

crop_img = img[(Ellipsis,)*(img.ndim-2)+slc]
crop_masks = pred_masks[slc]
crop_flow = pred_flows[0][slc]

crop_outli = plot.outline_view(crop_img,crop_masks,color=outline_col,mode=mode)

plt.imshow(crop_outli,interpolation='none')
plt.axis('off')
plt.show()

# save the cropped image, RGB uint8 is not interpolated in illustrator ;) 
img0 = np.stack([omnipose.utils.normalize99(crop_img),]*3,axis=-1)
savepath = os.path.join(save0,'crop_img'+ext)
io.imsave(savepath,np.uint8(img0*(2**8-1)))

# save the outlines
savepath = os.path.join(save0,'crop_outlines'+ext)
io.imsave(savepath,crop_outli)

# save the masks (NOTE: uint8 cast wraps label ids above 255)
savepath = os.path.join(save0,'crop_masks'+ext)
io.imsave(savepath,np.uint8(crop_masks))

# save the flows
savepath = os.path.join(save0,'crop_flows'+ext)
skimage.io.imsave(savepath,np.uint8(crop_flow))

# save the distance, alpha channel set to the foreground
savepath = os.path.join(save0,'crop_dist'+ext)
dist = omnipose.utils.rescale(pred_flows[2][slc])
cmap = plt.get_cmap('plasma')
pic = cmap(dist)
pic[:,:,-1] = crop_masks>0
skimage.io.imsave(savepath,np.uint8(pic*(2**8-1)))

# save the boundary, alpha channel set to the foreground
savepath = os.path.join(save0,'crop_bd'+ext)
dist = omnipose.utils.rescale(pred_flows[4][slc])
cmap = plt.get_cmap('viridis')
pic = cmap(dist)
pic[:,:,-1] = crop_masks>0
skimage.io.imsave(savepath,np.uint8(pic*(2**8-1)))

#save a grayscale version for adobe illustator vectorization 
ncl = ncolor.label(crop_masks)
grey_n = np.stack([1-omnipose.utils.rescale(ncl)]*3,axis=-1)
savepath = os.path.join(save0,'masks_gray'+ext)
io.imsave(savepath,np.uint8(grey_n*(2**8-1)))

In [None]:
# Display the dimensions of the exported crop.
np.shape(crop_img)

In [None]:
# Display the largest value of the normalized RGB crop before the uint8 cast.
np.max(img0)