In [14]:
import numpy as np
from os import listdir
from os.path import isfile, join
import tifffile
import cellpose
from cellpose import models, io, core, dynamics
import time
from sklearn.model_selection import train_test_split
from statistics import mean
from u_net import UNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from ptflops import get_model_complexity_info

In [15]:
def get_files(path,normalise=False,remove_txt=False):
    """Load every file found directly inside ``path``, in numeric file-name order.

    Files are ordered by the integer formed from all digits in their name,
    read with ``tifffile.imread`` and passed through ``np.squeeze``.

    Parameters
    ----------
    path : str
        Directory to read from. A trailing path separator is optional.
    normalise : bool, optional
        If True, min-max scale each image to the range [0, 1]. An image whose
        pixels all share one value is mapped to all zeros instead of NaN.
    remove_txt : bool, optional
        If True, skip files whose name ends with ``".txt"``.

    Returns
    -------
    list
        One array per file, in sorted order.

    Raises
    ------
    ValueError
        If a remaining file name contains no digits, because the numeric
        sort key cannot be built from it.
    """
    names = [name for name in listdir(path) if isfile(join(path, name))]

    if remove_txt:
        names = [name for name in names if not name.endswith(".txt")]

    # Sort by the number embedded in the name so that e.g. t2 precedes t10.
    names.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))

    # join() rather than string concatenation: works whether or not ``path``
    # ends with a separator.
    files = [np.squeeze(tifffile.imread(join(path, name))) for name in names]

    if normalise:
        scaled = []
        for image in files:
            low = np.min(image)
            span = np.max(image) - low
            # A constant image has span 0; dividing by 1 instead yields zeros
            # rather than a 0/0 NaN image.
            scaled.append((image - low) / (span if span != 0 else 1))
        files = scaled

    return files
   
def get_data(path, set='01',normalise_images=True):
    """Load images and their masks for one or both sequences under ``path``.

    For each selected sequence ``NN`` the images are read from ``<path>/NN/``
    and the masks from ``<path>/NN_GT/TRA/`` (``.txt`` files there are skipped).

    Parameters
    ----------
    path : str
        Dataset root directory. A trailing path separator is optional.
    set : str, optional
        Any two-character sequence name (e.g. ``'01'`` or ``'02'``) to load a
        single sequence, or ``'0102'`` to load sequences 01 and 02 and
        concatenate them in that order. Any other value selects nothing.
    normalise_images : bool, optional
        Forwarded to ``get_files`` as ``normalise`` for the images only; masks
        are never normalised.

    Returns
    -------
    tuple of (list, list)
        ``(images, masks)``. Both lists are empty when ``set`` is not recognised.
    """
    if len(set) == 2: #set 01 or set 02
        sequences = [set]
    elif set == '0102': #both sets
        sequences = ['01', '02']
    else:
        sequences = []

    images = []
    masks = []
    for sequence in sequences:
        # join() makes the result independent of whether ``path`` ends with a
        # separator; the trailing '' keeps a separator at the end of each
        # directory so file names can be appended to it directly.
        images += get_files(join(path, sequence, ''), normalise=normalise_images)
        masks += get_files(join(path, sequence + '_GT', 'TRA', ''), remove_txt=True)

    return images, masks

def get_IoU(predicted_masks,gt_masks, return_list=False):
    """Intersection-over-union between predicted and ground-truth masks.

    Any non-zero pixel counts as foreground (the masks are combined with
    ``np.logical_and`` / ``np.logical_or``).

    Parameters
    ----------
    predicted_masks, gt_masks : sequence of array-like
        Masks paired by index; ``gt_masks`` must provide at least as many
        entries as ``predicted_masks``.
    return_list : bool, optional
        If True, return the per-image scores instead of their mean.

    Returns
    -------
    list or float
        Per-image IoU values, or their mean (``statistics.mean``).
        A pair in which both masks are empty scores 1.0 (perfect agreement)
        rather than the NaN a 0/0 division would produce.
    """
    intersection_unions = []
    for i in range(len(predicted_masks)):
        intersection = np.logical_and(predicted_masks[i], gt_masks[i]).sum()
        union = np.logical_or(predicted_masks[i], gt_masks[i]).sum()
        if union == 0:
            # Nothing predicted and nothing to find: treat as a perfect match
            # so a single empty frame cannot turn the mean into NaN.
            intersection_unions.append(1.0)
        else:
            intersection_unions.append(intersection/union)
    if return_list:
        return intersection_unions
    return mean(intersection_unions)

def get_dice(predicted_masks,gt_masks, return_list=False):
    """Dice coefficient between predicted and ground-truth masks.

    Any non-zero pixel counts as foreground, both for the overlap and for the
    per-mask sizes, so labelled masks (values > 1) and binary masks give the
    same score.

    Parameters
    ----------
    predicted_masks, gt_masks : sequence of array-like
        Masks paired by index; ``gt_masks`` must provide at least as many
        entries as ``predicted_masks``.
    return_list : bool, optional
        If True, return the per-image scores instead of their mean.

    Returns
    -------
    list or float
        Per-image Dice values, or their mean (``statistics.mean``).
        A pair in which both masks are empty scores 1.0 rather than NaN.
    """
    dices = []
    for i in range(len(predicted_masks)):
        intersection = np.logical_and(predicted_masks[i], gt_masks[i]).sum()
        # Count foreground pixels instead of summing raw values: summing would
        # weight each pixel by its label id and disagree with the pixel-count
        # numerator above.
        total = np.count_nonzero(predicted_masks[i]) + np.count_nonzero(gt_masks[i])
        if total == 0:
            # Both masks empty: perfect agreement, and avoids 0/0.
            dices.append(1.0)
        else:
            dices.append((2*intersection)/total)
    if return_list:
        return dices
    return mean(dices)

def get_accuracy(predicted_masks,gt_masks,return_list=False):
    """Fraction of pixels whose value is identical in prediction and ground truth.

    Parameters
    ----------
    predicted_masks, gt_masks : sequence of arrays
        Masks paired by index; values are compared element-wise with ``==``.
    return_list : bool, optional
        If True, return the per-image accuracies instead of their mean.

    Returns
    -------
    list or float
        Per-image accuracies, or their mean (``statistics.mean``).
    """
    per_image = [
        np.mean(predicted_masks[idx] == gt_masks[idx])
        for idx in range(len(predicted_masks))
    ]
    return per_image if return_list else mean(per_image)

In [16]:
# Load both sequences (01 and 02) with min-max normalised images.
# NOTE(review): the dataset root is an absolute, machine-specific path — adjust
# it (or make it configurable) before running elsewhere.
images, masks = get_data(
    "c:\\Users\\rz200\\Documents\\development\\distillCellSegTrack\\" + 'datasets/Fluo-N2DH-GOWT1/',
    set='0102',
    normalise_images=True,
)

# Hold out 20% of the image/mask pairs for testing; the fixed seed keeps the
# split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    masks,
    test_size=0.2,
    random_state=42,
)

In [17]:
logger = io.logger_setup()

# Derive the torch device from the same GPU check that sets the ``gpu`` flag.
# Previously 'cuda:0' was hard-coded, so it was requested even on a machine
# where core.use_gpu() reports that no usable GPU is available.
use_gpu = core.use_gpu()
device = torch.device('cuda:0') if use_gpu else torch.device('cpu')

# Start from the pretrained 'cyto' weights and fine-tune on the training split.
model = models.CellposeModel(gpu=use_gpu, model_type='cyto', device=device)

# channels=[0,0] is passed for both training and test data; the trained
# weights are written below save_path under model_name.
new_model_path = model.train(X_train, y_train,
                              test_data=X_test,
                              test_labels=y_test,
                              channels=[0,0],
                              save_path='train_dir',
                              n_epochs=200,
                              learning_rate=0.1,
                              weight_decay=0.0001,
                              model_name='cellpose_trained_model_SIM_5',
                              batch_size=16,
                              SGD=True)

creating new log file
2023-05-11 04:27:10,380 [INFO] WRITING LOG OUTPUT TO C:\Users\rz200\.cellpose\run.log
2023-05-11 04:27:10,381 [INFO] 
cellpose version: 	2.2.1 
platform:       	win32 
python version: 	3.8.16 
torch version:  	1.11.0+cu113
2023-05-11 04:27:10,389 [INFO] ** TORCH CUDA version installed and working. **
2023-05-11 04:27:10,391 [INFO] >> cyto << model set to be used
2023-05-11 04:27:10,971 [INFO] >>>> model diam_mean =  30.000 (ROIs rescaled to this size during training)
2023-05-11 04:27:17,492 [INFO] computing flows for labels


100%|██████████| 147/147 [00:12<00:00, 11.84it/s]


2023-05-11 04:27:32,699 [INFO] computing flows for labels


100%|██████████| 37/37 [00:03<00:00, 11.89it/s]


2023-05-11 04:27:37,960 [INFO] >>>> median diameter set to = 30
2023-05-11 04:27:37,961 [INFO] >>>> mean of training label mask diameters (saved to model) 45.370
2023-05-11 04:27:37,965 [INFO] >>>> training network with 2 channel input <<<<
2023-05-11 04:27:37,965 [INFO] >>>> LR: 0.10000, batch_size: 16, weight_decay: 0.00010
2023-05-11 04:27:37,966 [INFO] >>>> ntrain = 147, ntest = 37
2023-05-11 04:27:37,966 [INFO] >>>> nimg_per_epoch = 147
2023-05-11 04:27:44,478 [INFO] Epoch 0, Time  6.5s, Loss 0.5519, Loss Test 0.5210, LR 0.0000
2023-05-11 04:27:50,153 [INFO] saving network parameters to train_dir\models/cellpose_trained_model_SIM_5
2023-05-11 04:28:14,500 [INFO] Epoch 5, Time 36.5s, Loss 0.2180, Loss Test 0.1326, LR 0.0556
2023-05-11 04:28:44,288 [INFO] Epoch 10, Time 66.3s, Loss 0.1217, Loss Test 0.1151, LR 0.1000
2023-05-11 04:29:43,886 [INFO] Epoch 20, Time 125.9s, Loss 0.1107, Loss Test 0.1087, LR 0.1000
2023-05-11 04:30:43,040 [INFO] Epoch 30, Time 185.1s, Loss 0.1056, Loss T

In [18]:
# Run the fine-tuned model on the held-out images, one image per batch, using
# the diameter stored on the model (model.diam_labels).
# Only the first element of the returned tuple is kept as the predictions.
predicted_masks = model.eval(
    X_test,
    batch_size=1,
    channels=[0,0],
    diameter=model.diam_labels,
)[0]

2023-05-11 04:44:28,591 [INFO] 0%|          | 0/37 [00:00<?, ?it/s]
2023-05-11 04:44:29,101 [INFO] 3%|2         | 1/37 [00:00<00:18,  1.96it/s]
2023-05-11 04:44:29,618 [INFO] 5%|5         | 2/37 [00:01<00:17,  1.95it/s]
2023-05-11 04:44:30,151 [INFO] 8%|8         | 3/37 [00:01<00:17,  1.91it/s]
2023-05-11 04:44:30,677 [INFO] 11%|#         | 4/37 [00:02<00:17,  1.91it/s]
2023-05-11 04:44:31,208 [INFO] 14%|#3        | 5/37 [00:02<00:16,  1.90it/s]
2023-05-11 04:44:31,716 [INFO] 16%|#6        | 6/37 [00:03<00:16,  1.92it/s]
2023-05-11 04:44:32,242 [INFO] 19%|#8        | 7/37 [00:03<00:15,  1.91it/s]
2023-05-11 04:44:32,738 [INFO] 22%|##1       | 8/37 [00:04<00:14,  1.95it/s]
2023-05-11 04:44:33,262 [INFO] 24%|##4       | 9/37 [00:04<00:14,  1.93it/s]
2023-05-11 04:44:33,777 [INFO] 27%|##7       | 10/37 [00:05<00:13,  1.94it/s]
2023-05-11 04:44:34,300 [INFO] 30%|##9       | 11/37 [00:05<00:13,  1.93it/s]
2023-05-11 04:44:34,817 [INFO] 32%|###2      | 12/37 [00:06<00:12,  1.93it/s]
2023-05-

In [19]:
# Collapse the masks to binary foreground/background: every pixel > 0 becomes 1,
# everything else 0. Note that predicted_masks is rebound to its binary version.
predicted_masks = [(mask > 0).astype(int) for mask in predicted_masks]
y_test_binary = [(mask > 0).astype(int) for mask in y_test]

In [20]:
# Per-image scores on the binarised test masks.
IoU = get_IoU(predicted_masks,y_test_binary,return_list=True)
accuracy = get_accuracy(predicted_masks,y_test_binary,return_list=True)

# (caption, reducer, scores) rows, printed in order as "<caption> <value>".
summary_rows = (
    ('Mean IoU: ', mean, IoU),
    ('Max IoU: ', max, IoU),
    ('Min IoU:', min, IoU),
    ('Mean Pixel-wise: ', mean, accuracy),
    ('Max Pixel-wise: ', max, accuracy),
    ('Min Pixel-wise: ', min, accuracy),
)
for caption, reducer, scores in summary_rows:
    print(caption, reducer(scores))

Mean IoU:  0.8096088846661483
Max IoU:  0.8418616722176565
Min IoU: 0.7799276202190316
Mean Pixel-wise:  0.9913675978377059
Max Pixel-wise:  0.9932928085327148
Min Pixel-wise:  0.9885635375976562
