In [36]:
import numpy as np
import os
from os import listdir
from os.path import isfile, join
import tifffile
import cellpose
from cellpose import models, io, core, dynamics
import time
from sklearn.model_selection import train_test_split
from statistics import mean
from u_net import UNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from ptflops import get_model_complexity_info

In [38]:
def get_files(path,normalise=False,remove_txt=False):
    """Read every file in directory ``path`` with ``tifffile.imread``.

    Files are ordered by the integer formed from the digits in their name, so
    ``t2.tif`` comes before ``t10.tif``. Names containing no digits sort first,
    and ties are broken by name so the order is deterministic.

    Args:
        path: directory to read. Paths are built with ``os.path.join``, so a
            trailing separator is optional.
        normalise: if True, min-max scale each image to [0, 1]. A constant image
            (max == min) becomes all zeros instead of NaN.
        remove_txt: if True, skip files whose name ends in ``.txt``.

    Returns:
        List of arrays, one per file, with singleton dimensions squeezed out.
    """
    names = [f for f in listdir(path) if isfile(join(path, f))]

    if remove_txt:
        names = [f for f in names if not f.endswith(".txt")]

    def _numeric_key(name):
        # Sort by the number embedded in the name; -1 keeps digit-less names sortable.
        digits = ''.join(filter(str.isdigit, name))
        return (int(digits) if digits else -1, name)

    def _min_max(image):
        lo, hi = np.min(image), np.max(image)
        if hi == lo:
            # Avoid 0/0 -> NaN for a constant image.
            return np.zeros(image.shape)
        return (image - lo) / (hi - lo)

    names.sort(key=_numeric_key)
    files = [np.squeeze(tifffile.imread(join(path, name))) for name in names]

    if normalise:
        files = [_min_max(image) for image in files]

    return files
    
def get_data(path, set='01',normalise_images=True):
    """Load the images and ground-truth masks of one or both sequences.

    ``set`` chooses what to load:
      * any two-character id (e.g. '01', '02') -> that sequence only;
      * '0102' -> sequence '01' followed by sequence '02';
      * anything else -> two empty lists.

    Images are read from ``path + id + '/'`` (min-max normalised when
    ``normalise_images`` is True). Masks are read from ``path + id + '_GT/TRA/'``
    with ``.txt`` files skipped. Paths are concatenated, so ``path`` must end
    with a separator.

    Returns:
        ``(images, masks)`` as two lists.
    """
    if len(set) == 2:
        sequence_ids = [set]
    elif set == '0102':
        sequence_ids = ['01', '02']
    else:
        return [], []

    # All image folders are read first, then all mask folders.
    images = []
    for seq_id in sequence_ids:
        images += get_files(path + seq_id + '/', normalise=normalise_images)

    masks = []
    for seq_id in sequence_ids:
        masks += get_files(path + seq_id + '_GT/TRA/', remove_txt=True)

    return images, masks

def split_image(img, tile_size=256):
    """Split an image into non-overlapping ``tile_size`` x ``tile_size`` tiles.

    Tiles are returned in row-major order (left to right, then top to bottom)
    as views into ``img``. Only the first two axes are tiled. If the image size
    is not a multiple of ``tile_size``, the tiles on the bottom/right edge are
    smaller.

    Args:
        img: array with at least two dimensions.
        tile_size: side length of each tile (default 256).

    Returns:
        List of sub-arrays.

    Raises:
        ValueError: if ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    height, width = img.shape[0], img.shape[1]
    return [img[r:r + tile_size, c:c + tile_size]
            for r in range(0, height, tile_size)
            for c in range(0, width, tile_size)]

def combine_images(sub_images, image_shape=(1024, 1024), tile_size=256):
    """Stitch row-major tiles back into one image (inverse of ``split_image``).

    Args:
        sub_images: sequence of tiles in row-major order, e.g. as produced by
            ``split_image``. Tiles beyond the number needed are ignored.
        image_shape: (height, width) of the output image (default 1024x1024).
        tile_size: side length of each tile (default 256).

    Returns:
        float64 array of shape ``image_shape``.

    Raises:
        ValueError: if fewer tiles are supplied than are needed to fill the image.
    """
    rows = range(0, image_shape[0], tile_size)
    cols = range(0, image_shape[1], tile_size)
    needed = len(rows) * len(cols)
    if len(sub_images) < needed:
        raise ValueError(f"need {needed} tiles to fill {image_shape}, got {len(sub_images)}")

    img = np.zeros(image_shape)
    k = 0
    for r in rows:
        for c in cols:
            img[r:r + tile_size, c:c + tile_size] = sub_images[k]
            k += 1

    return img

def get_IoU(predicted_masks,gt_masks, return_list=False):
    """Intersection over union (Jaccard index) of paired masks.

    Non-zero pixels count as foreground. When both masks of a pair are empty
    the union is zero. That pair scores 1.0 (perfect agreement) instead of the
    NaN that 0/0 would produce and propagate into the mean.

    Args:
        predicted_masks: sequence of predicted masks.
        gt_masks: sequence of ground-truth masks, paired by position with
            ``predicted_masks``.
        return_list: if True return the per-pair scores, otherwise their mean.

    Returns:
        List of per-pair IoU values, or their mean.

    Raises:
        ValueError: if the two sequences differ in length.
        statistics.StatisticsError: if the inputs are empty and ``return_list`` is False.
    """
    if len(predicted_masks) != len(gt_masks):
        raise ValueError(f"got {len(predicted_masks)} predictions but {len(gt_masks)} ground-truth masks")
    intersection_unions = []
    for pred, gt in zip(predicted_masks, gt_masks):
        intersection = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        intersection_unions.append(float(intersection / union) if union else 1.0)
    if return_list:
        return intersection_unions
    return mean(intersection_unions)

def get_dice(predicted_masks,gt_masks, return_list=False):
    """Dice coefficient of paired masks: 2|A∩B| / (|A| + |B|).

    Non-zero pixels count as foreground in both the intersection and the
    denominator. Using ``count_nonzero`` rather than ``.sum()`` keeps the score
    correct for label or non-0/1 masks. When both masks of a pair are empty the
    pair scores 1.0 instead of NaN.

    Args:
        predicted_masks: sequence of predicted masks.
        gt_masks: sequence of ground-truth masks, paired by position with
            ``predicted_masks``.
        return_list: if True return the per-pair scores, otherwise their mean.

    Returns:
        List of per-pair Dice values, or their mean.

    Raises:
        ValueError: if the two sequences differ in length.
        statistics.StatisticsError: if the inputs are empty and ``return_list`` is False.
    """
    if len(predicted_masks) != len(gt_masks):
        raise ValueError(f"got {len(predicted_masks)} predictions but {len(gt_masks)} ground-truth masks")
    dices = []
    for pred, gt in zip(predicted_masks, gt_masks):
        intersection = np.logical_and(pred, gt).sum()
        denominator = np.count_nonzero(pred) + np.count_nonzero(gt)
        dices.append(float(2 * intersection / denominator) if denominator else 1.0)
    if return_list:
        return dices
    return mean(dices)

def get_accuracy(predicted_masks,gt_masks,return_list=False):
    """Pixel-wise accuracy: the fraction of pixels where prediction equals ground truth.

    Pairs are matched by position. Pixel values are compared directly, so both
    masks should use the same encoding (e.g. 0/1).

    Returns:
        The per-pair accuracies if ``return_list`` is True, otherwise their mean.
    """
    accuracies = [np.mean(predicted_masks[k] == gt_masks[k])
                  for k in range(len(predicted_masks))]
    return accuracies if return_list else mean(accuracies)

In [40]:
# Dataset location. Override the machine-specific default by setting the
# DISTILL_CELLSEGTRACK_ROOT environment variable instead of editing this cell.
PROJECT_ROOT = os.environ.get(
    "DISTILL_CELLSEGTRACK_ROOT",
    "c:\\Users\\rz200\\Documents\\development\\distillCellSegTrack\\",
)
# get_data builds sub-paths by string concatenation, so keep a trailing separator ('' adds it).
DATASET_DIR = os.path.join(PROJECT_ROOT, 'datasets', 'Fluo-N2DH-GOWT1', '')
SEED = 42

images, masks = get_data(DATASET_DIR, set='0102', normalise_images=True)
# NOTE(review): this is a random frame-level split. If the frames are consecutive
# time points, near-identical neighbours can land in both train and test and
# inflate the test scores. Consider holding out by sequence or time range.
X_train, X_test, y_train, y_test = train_test_split(images, masks, test_size=0.2, random_state=SEED)

In [41]:
def pad_to_size(arr, size=1024):
    """Zero-pad a 2-D array on its bottom and right edges to ``size`` x ``size``.

    ``np.pad`` raises ValueError if ``arr`` is already larger than ``size``.
    """
    return np.pad(arr, ((0, size - arr.shape[0]), (0, size - arr.shape[1])),
                  mode='constant', constant_values=0)


# Pad every image and mask to 1024x1024 so each splits into a whole number of 256x256 tiles.
X_train_padded = [pad_to_size(img) for img in X_train]
X_test_padded = [pad_to_size(img) for img in X_test]
y_train_padded = [pad_to_size(mask) for mask in y_train]
y_test_padded = [pad_to_size(mask) for mask in y_test]

In [42]:
# Binarise the masks: every non-zero label becomes foreground (1), the rest background (0).
y_train_padded = [(label_img > 0).astype(int) for label_img in y_train_padded]
y_test_padded = [(label_img > 0).astype(int) for label_img in y_test_padded]

In [43]:
# Cut every padded 1024x1024 image/mask into 16 non-overlapping 256x256 tiles
# (row-major order). Reusing split_image keeps this tiling identical to the one
# used when predicting on full test images.
X_train_split = [tile for image in X_train_padded for tile in split_image(image)]
X_test_split = [tile for image in X_test_padded for tile in split_image(image)]
y_train_split = [tile for mask in y_train_padded for tile in split_image(mask)]
y_test_split = [tile for mask in y_test_padded for tile in split_image(mask)]

In [44]:
# 4x rotation augmentation: each training tile and its mask appear rotated by
# 0, 90, 180 and 270 degrees, in that order, so image/mask pairs stay aligned.
quarter_turns = range(4)
X_train_split_rot = [np.rot90(tile, k) for tile in X_train_split for k in quarter_turns]
y_train_split_rot = [np.rot90(tile, k) for tile in y_train_split for k in quarter_turns]

In [45]:
X_train_split_rot = np.array(X_train_split_rot)
y_train_cp_split_rot = np.array(y_train_split_rot)
# Keep only tiles whose mask has at least one foreground pixel. For these 0/1
# masks `.any()` is equivalent to a non-zero sum. The mask is computed once and
# applied to both arrays so image/mask pairs stay aligned.
has_foreground = y_train_cp_split_rot.any(axis=(1, 2))
X_train_split_rot_noz = X_train_split_rot[has_foreground]
y_train_cp_split_rot_noz = y_train_cp_split_rot[has_foreground]

In [46]:
class ImageDataset(Dataset):
    """Map-style dataset pairing an indexable collection of images with one of masks.

    Item ``idx`` is the tuple ``(image[idx], mask[idx])``. The dataset length is
    taken from the masks.
    """

    def __init__(self, image, mask):
        self.image, self.mask = image, mask

    def __len__(self):
        return len(self.mask)

    def __getitem__(self, idx):
        return self.image[idx], self.mask[idx]

def _to_float32_tensors(arrays):
    """Copy each array into its own float32 torch tensor."""
    return [torch.from_numpy(np.array(arr, dtype=np.float32)) for arr in arrays]


# Images are cast to float32 as well (normalised images may be float64). This
# matches the masks and the float32 tiles fed to the model in full-image evaluation.
X_train_torch = _to_float32_tensors(X_train_split_rot_noz)
y_train_cp_torch = _to_float32_tensors(y_train_cp_split_rot_noz)
X_test_torch = _to_float32_tensors(X_test_split)
y_test_cp_torch = _to_float32_tensors(y_test_split)

train_dataset = ImageDataset(X_train_torch, y_train_cp_torch)
# A dedicated, seeded generator makes the shuffle order reproducible across runs.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
test_dataset = ImageDataset(X_test_torch, y_test_cp_torch)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [47]:
# Seed torch's global RNG so the model's weight initialisation is reproducible.
torch.manual_seed(42)
# Use the first GPU when available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)

In [48]:
# The model produces raw logits (a sigmoid is applied only when predicting), so
# the loss combines sigmoid + binary cross-entropy in one numerically stable op.
loss_fn = nn.BCEWithLogitsLoss()
optimiser = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [51]:
NUM_EPOCHS = 10
# Run on whichever device the model lives on instead of hard-coding 'cuda:0'.
device = next(model.parameters()).device

for epoch in range(NUM_EPOCHS):
    start_time = time.time()

    # ---- training ----
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x = x.to(device).unsqueeze(1)  # add channel axis: (B, H, W) -> (B, 1, H, W)
        y = y.to(device).unsqueeze(1)
        pred = model(x)
        loss = loss_fn(pred, y)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        # .item() yields a plain float. Summing the loss tensors themselves would
        # chain every batch's autograd history for the whole epoch.
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # ---- evaluation ----
    # Switch off training-only layer behaviour (e.g. dropout / batch-norm updates),
    # in case UNet has any; model.train() above re-enables it next epoch.
    model.eval()
    test_loss = 0.0
    predictions = []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch = x_batch.to(device).unsqueeze(1)
            y_batch = y_batch.to(device).unsqueeze(1)
            test_loss += loss_fn(model(x_batch), y_batch).item()

        # Full-image metrics: tile each padded test image, predict all tiles at once, stitch back.
        for image in X_test_padded:
            tiles = torch.from_numpy(np.array(split_image(image), dtype=np.float32))
            logits = model(tiles.unsqueeze(1).to(device))
            logits = np.squeeze(logits.cpu().numpy())
            # sigmoid(z) > 0.5 exactly when z > 0, so threshold the logits directly
            # (avoids np.exp overflow warnings for large negative logits).
            predictions.append(combine_images(np.where(logits > 0, 1, 0)))
    test_loss /= len(test_loader)

    IoU = get_IoU(predictions, y_test_padded)
    dice = get_dice(predictions, y_test_padded)

    print('epoch: ', epoch, 'train loss', train_loss, 'test loss', test_loss, 'IoU:', IoU, 'dice:', dice, 'time: ', time.time()-start_time)

epoch:  0 train loss 0.03913314916940132 test loss 0.03053106166221 IoU: 0.7409294895985893 dice: 0.8510761179937784 time:  30.822853565216064
epoch:  1 train loss 0.03863542471359025 test loss 0.029542887533033215 IoU: 0.7413503223514488 dice: 0.8513089260051724 time:  30.853352069854736
epoch:  2 train loss 0.03824769611805995 test loss 0.029153617652686866 IoU: 0.7391213176252958 dice: 0.8498500986899811 time:  42.391671895980835
epoch:  3 train loss 0.037569269696786714 test loss 0.02837652129095954 IoU: 0.7466913102187824 dice: 0.8548789532741151 time:  61.607659339904785
epoch:  4 train loss 0.03727028364820013 test loss 0.02898452088639543 IoU: 0.7448622572575819 dice: 0.8536491037295325 time:  61.55084681510925
epoch:  5 train loss 0.03703627759205507 test loss 0.028667156760757033 IoU: 0.7403941642384259 dice: 0.8506864447626805 time:  61.31753492355347
epoch:  6 train loss 0.03651439609812267 test loss 0.027743371757301124 IoU: 0.7470379913794415 dice: 0.8550758332594555 time

In [53]:
# Final evaluation on the full-size test images: tile each padded image, run the
# model on all of its tiles at once, threshold, and stitch the tiles back together.
model.eval()  # disable training-only layer behaviour, if UNet has any
device = next(model.parameters()).device
predictions = []
with torch.no_grad():  # inference only: don't build autograd graphs
    for image in X_test_padded:
        tiles = torch.from_numpy(np.array(split_image(image), dtype=np.float32))
        logits = model(tiles.unsqueeze(1).to(device))
        logits = np.squeeze(logits.cpu().numpy())
        # sigmoid(z) > 0.5 exactly when z > 0, so threshold the logits directly
        # (avoids np.exp overflow warnings for large negative logits).
        predictions.append(combine_images(np.where(logits > 0, 1, 0)))
IoU = get_IoU(predictions, y_test_padded, return_list=True)
accuracy = get_accuracy(predictions, y_test_padded, return_list=True)


print('Mean IoU: ', mean(IoU))
print('Max IoU: ', max(IoU))
print('Min IoU:', min(IoU))
print('Mean Pixel-wise: ', mean(accuracy))
print('Max Pixel-wise: ', max(accuracy))
print('Min Pixel-wise: ', min(accuracy))

Mean IoU:  0.7457820136138066
Max IoU:  0.7783926218708828
Min IoU: 0.7052269808897953
Mean Pixel-wise:  0.988216632121318
Max Pixel-wise:  0.9906291961669922
Min Pixel-wise:  0.9846153259277344


In [None]:
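# TODO: save the trained model weights (e.g. torch.save(model.state_dict(), path)) so results can be reloaded without retraining.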

In [33]:
# Report MACs and parameter count for one single-channel 256x256 tile
# (the tile size used for training), with a per-layer breakdown.
with torch.cuda.device(0):
    net = model
    macs, params = get_model_complexity_info(
        net,
        (1, 256, 256),
        as_strings=True,
        print_per_layer_stat=True,
        verbose=True,
    )
    for label, value in (('Computational complexity: ', macs), ('Number of parameters: ', params)):
        print('{:<30}  {:<8}'.format(label, value))

UNet(
  116.75 k, 100.000% Params, 1.8 GMac, 100.000% MACs, 
  (encoder): Encoder(
    71.79 k, 61.490% Params, 620.76 MMac, 34.437% MACs, 
    (encBlocks): ModuleList(
      71.79 k, 61.490% Params, 618.92 MMac, 34.336% MACs, 
      (0): Block(
        2.48 k, 2.124% Params, 163.58 MMac, 9.075% MACs, 
        (conv1): Conv2d(160, 0.137% Params, 10.49 MMac, 0.582% MACs, 1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(0, 0.000% Params, 1.05 MMac, 0.058% MACs, )
        (conv2): Conv2d(2.32 k, 1.987% Params, 152.04 MMac, 8.435% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): Block(
        13.89 k, 11.895% Params, 228.07 MMac, 12.652% MACs, 
        (conv1): Conv2d(4.64 k, 3.974% Params, 76.02 MMac, 4.217% MACs, 16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU(0, 0.000% Params, 524.29 KMac, 0.029% MACs, )
        (conv2): Conv2d(9.25 k, 7.921% Params, 151.52 MMac, 8.406% MACs, 32, 32, kernel