In [1]:
import numpy as np
from os import listdir
from os.path import isfile, join
import tifffile
from cellpose import models, io, core
import time
from sklearn.model_selection import train_test_split
from statistics import mean
from u_net import UNet
import torch
from skimage import measure

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def get_data(path, num_imgs=4, set='01'):
    """Load the first image/mask pairs of one sequence.

    Images are read from ``path + set + '/'`` and masks from
    ``path + set + '_GT/TRA/'``. In each directory the files are ordered by the
    integer formed from all digits in the file name, and images and masks are
    paired by their position in that ordering.

    Args:
        path: Dataset directory prefix; it is concatenated with `set`, so it
            is expected to end with a path separator.
        num_imgs: Maximum number of pairs to load. Fewer are returned when
            either directory holds fewer usable files.
        set: Sequence identifier. The name shadows the builtin but is kept
            because callers pass it by keyword.

    Returns:
        ``(images, masks)``: two lists of equal length. Each image is min-max
        scaled to [0, 1]; a constant image becomes all zeros. Masks are
        returned as read, with singleton dimensions squeezed out.
    """

    def _numbered_files(directory, skip_suffix=None):
        # Regular files whose names contain at least one digit, sorted by the
        # number those digits spell. Digit-less names (e.g. '.DS_Store') are
        # skipped because the numeric sort key cannot be computed for them.
        names = []
        for name in listdir(directory):
            if not isfile(join(directory, name)):
                continue
            if skip_suffix is not None and name.endswith(skip_suffix):
                continue
            if not any(ch.isdigit() for ch in name):
                continue
            names.append(name)
        names.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
        return names

    def _normalise(image):
        # Min-max scaling, guarded against a zero intensity range.
        lowest = np.min(image)
        span = np.max(image) - lowest
        if span == 0:
            return np.zeros(np.shape(image), dtype=float)
        return (image - lowest) / span

    images_path = path + set + '/'
    masks_path = path + set + '_GT/TRA/'

    image_files = _numbered_files(images_path)
    mask_files = _numbered_files(masks_path, skip_suffix=".txt")

    # One shared count so that images and masks always stay paired.
    count = max(0, min(num_imgs, len(image_files), len(mask_files)))

    images = [
        _normalise(np.squeeze(tifffile.imread(join(images_path, name))))
        for name in image_files[:count]
    ]
    masks = [
        np.squeeze(tifffile.imread(join(masks_path, name)))
        for name in mask_files[:count]
    ]

    return images, masks

In [3]:
def get_IoU(predicted_masks,gt_masks):
    """Mean intersection-over-union across paired masks.

    Every non-zero pixel counts as foreground, so binary and label masks are
    both accepted.

    Args:
        predicted_masks: Sequence of predicted mask arrays.
        gt_masks: Sequence of reference mask arrays, same length and
            per-item shapes as `predicted_masks`.

    Returns:
        The arithmetic mean of the per-pair IoU values.

    Raises:
        ValueError: If the two sequences have different lengths.
        statistics.StatisticsError: If the sequences are empty.
    """
    if len(predicted_masks) != len(gt_masks):
        raise ValueError(
            "predicted_masks and gt_masks must have the same length "
            f"({len(predicted_masks)} != {len(gt_masks)})"
        )

    scores = []
    for predicted, truth in zip(predicted_masks, gt_masks):
        intersection = np.logical_and(predicted, truth).sum()
        union = np.logical_or(predicted, truth).sum()
        if union == 0:
            # Both masks are empty: treat as perfect agreement instead of
            # letting 0/0 produce a NaN that would poison the mean.
            scores.append(1.0)
        else:
            scores.append(float(intersection / union))
    return mean(scores)

def get_dice(predicted_masks,gt_masks):
    """Mean Dice coefficient across paired masks.

    Every non-zero pixel counts as foreground. Foreground sizes are taken as
    pixel counts rather than sums of pixel values, so label masks (values
    greater than 1) give the same score as their binarised versions.

    Args:
        predicted_masks: Sequence of predicted mask arrays.
        gt_masks: Sequence of reference mask arrays, same length and
            per-item shapes as `predicted_masks`.

    Returns:
        The arithmetic mean of the per-pair Dice values.

    Raises:
        ValueError: If the two sequences have different lengths.
        statistics.StatisticsError: If the sequences are empty.
    """
    if len(predicted_masks) != len(gt_masks):
        raise ValueError(
            "predicted_masks and gt_masks must have the same length "
            f"({len(predicted_masks)} != {len(gt_masks)})"
        )

    dices = []
    for predicted, truth in zip(predicted_masks, gt_masks):
        intersection = np.logical_and(predicted, truth).sum()
        foreground_total = np.count_nonzero(predicted) + np.count_nonzero(truth)
        if foreground_total == 0:
            # Both masks are empty: treat as perfect agreement instead of
            # letting 0/0 produce a NaN that would poison the mean.
            dices.append(1.0)
        else:
            dices.append(float((2 * intersection) / foreground_total))
    return mean(dices)

In [7]:
# Load up to 92 image/mask pairs of sequence '02' from the Fluo-N2DL-HeLa folder.
hela_dataset_path = "/Users/rehanzuberi/Downloads/distillCellSegTrack/" + 'datasets/Fluo-N2DL-HeLa/'
images_02, masks_02 = get_data(hela_dataset_path, num_imgs=92, set='02')

# Split with test_size=0.2; the fixed random_state keeps the split repeatable.
(
    images_02_train,
    images_02_test,
    masks_02_train,
    masks_02_test,
) = train_test_split(images_02, masks_02, test_size=0.2, random_state=42)

In [11]:
#Get predictions of the custom-trained cellpose checkpoint on the test images
cellpose_trained_model_path = '/Users/rehanzuberi/Downloads/distillCellSegTrack/segmentation/train_dir/models/cellpose_trained_model'
cellpose_model = models.CellposeModel(
    gpu=core.use_gpu(),
    pretrained_model=cellpose_trained_model_path,
)

# The diameter stored on the loaded model is passed back in for evaluation.
cellpose_eval_results = cellpose_model.eval(
    images_02_test,
    batch_size=1,
    channels=[0,0],
    diameter=cellpose_model.diam_labels,
)
# Only the first element of the result is kept; it is used as the mask list below.
cellpose_predicted_masks = cellpose_eval_results[0]

In [13]:
#Get predictions of the stock cellpose 'cyto' model on the same test images
cellpose_cyto_model = models.CellposeModel(
    gpu=core.use_gpu(),
    model_type='cyto',
)

# No diameter is passed here, unlike the custom-trained model.
cellpose_cyto_eval_results = cellpose_cyto_model.eval(
    images_02_test,
    batch_size=1,
    channels=[0,0],
)
# Only the first element of the result is kept; it is used as the mask list below.
cellpose_cyto_predicted_masks = cellpose_cyto_eval_results[0]

In [31]:
#Get distilled U-Net predictions
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = UNet()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('/Users/rehanzuberi/Downloads/distillCellSegTrack/segmentation/train_dir/models/unet_trained_model',map_location=device))
model = model.to(device)
# Inference mode: layers such as dropout / batch norm (if the U-Net has any)
# must not behave as in training or update their statistics on test data.
model.eval()

distilled_predicted_masks = []
# No gradients are needed for prediction; skipping them saves memory and time.
with torch.no_grad():
    for test_image in images_02_test:
        # Add two leading singleton dimensions before calling the model.
        image_tensor = torch.from_numpy(test_image).float().unsqueeze(0).unsqueeze(0).to(device)
        logits = model(image_tensor).squeeze(0).squeeze(0)
        # torch.sigmoid is numerically stable, unlike 1/(1+exp(-x)) on large negative logits.
        outputs = torch.sigmoid(logits).cpu().numpy()
        binary_outputs = np.where(outputs > 0.5, 1, 0)
        distilled_predicted_masks.append(binary_outputs)

In [23]:
def _binarise(mask_list):
    """Return a copy of each mask with every positive pixel set to 1 and the rest to 0."""
    binary_masks = []
    for mask in mask_list:
        binary_masks.append(np.where(mask>0,1,0))
    return binary_masks


#Foreground/background version of the ground truth masks
masks_02_test_binary = _binarise(masks_02_test)

#Foreground/background versions of both cellpose predictions
cellpose_predicted_masks_binary = _binarise(cellpose_predicted_masks)
cellpose_cyto_predicted_masks_binary = _binarise(cellpose_cyto_predicted_masks)

In [33]:
#Score every model against the binarised ground truth.
#Each row is (text printed before the score, metric function, predicted binary masks).
ground_truth_comparisons = [
    ("IoU between cellpose trained and groundtruth: ", get_IoU, cellpose_predicted_masks_binary),
    ("Dice coeff between cellpose trained and groundtruth: ", get_dice, cellpose_predicted_masks_binary),
    ("IoU between cellpose 'cyto' and groundtruth: ", get_IoU, cellpose_cyto_predicted_masks_binary),
    ("Dice coeff between cellpose 'cyto' and groundtruth: ", get_dice, cellpose_cyto_predicted_masks_binary),
    ("IoU between distilled U-Net and groundtruth: ", get_IoU, distilled_predicted_masks),
    ("Dice coeff between distilled U-Net and groundtruth: ", get_dice, distilled_predicted_masks),
]

for label, metric, predictions in ground_truth_comparisons:
    print(label, metric(predictions, masks_02_test_binary))

IoU between cellpose trained and groundtruth:  0.21182149734420092
Dice coeff between cellpose trained and groundtruth:  0.3493621693404443
IoU between cellpose 'cyto' and groundtruth:  0.07610879702163514
Dice coeff between cellpose 'cyto' and groundtruth:  0.1414164945667137
IoU between distilled U-Net and groundtruth:  0.07672062126296861
Dice coeff between distilled U-Net and groundtruth:  0.14247836277167722


In [21]:
#Agreement between the two cellpose models' predictions (no ground truth involved)
trained_vs_cyto_iou = get_IoU(cellpose_predicted_masks_binary, cellpose_cyto_predicted_masks_binary)
print("IoU between cellpose trained and cellpose 'cyto' predictions: ", trained_vs_cyto_iou)

trained_vs_cyto_dice = get_dice(cellpose_predicted_masks_binary, cellpose_cyto_predicted_masks_binary)
print("Dice coeff between cellpose trained and cellpose 'cyto' predictions: ", trained_vs_cyto_dice)

IoU between cellpose trained and cellpose 'cyto' predictions:  0.26420578382284404
Dice coeff between cellpose trained and cellpose 'cyto' predictions:  0.4179438099609645


In [34]:
#These are the distillation results
#Compare the distilled U-Net with each cellpose model's predictions (not with the ground truth).
#NOTE(review): in the output recorded with this notebook the U-Net scores ~0.94 IoU against
#'cyto' but only ~0.27 against the trained model - confirm which model supplied the
#distillation targets.
#Each row is (text printed before the score, metric function, reference binary masks).
distillation_comparisons = [
    ("IoU between distilled U-Net and cellpose trained: ", get_IoU, cellpose_predicted_masks_binary),
    ("Dice coeff between distilled U-Net and cellpose trained: ", get_dice, cellpose_predicted_masks_binary),
    ("IoU between distilled U-Net and cellpose 'cyto': ", get_IoU, cellpose_cyto_predicted_masks_binary),
    ("Dice coeff between distilled U-Net and cellpose 'cyto': ", get_dice, cellpose_cyto_predicted_masks_binary),
]

for label, metric, reference_masks in distillation_comparisons:
    print(label, metric(distilled_predicted_masks, reference_masks))

IoU between distilled U-Net and cellpose trained:  0.26588922437791973
Dice coeff between distilled U-Net and cellpose trained:  0.42005711997575107
IoU between distilled U-Net and cellpose 'cyto':  0.9350706966050253
Dice coeff between distilled U-Net and cellpose 'cyto':  0.966438367836746
