In [1]:
#imports
import numpy as np
from os import listdir
from os.path import isfile, join
from sklearn.model_selection import train_test_split
import tifffile
from u_net import UNet
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from statistics import mean

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#functions to split images into patches
def split_image(img):
    # Split the input array into smaller arrays of size 256x256
    sub_images = []
    for i in range(0, img.shape[0], 256):
        for j in range(0, img.shape[1], 256):
            sub_img = img[i:i+256, j:j+256]
            sub_images.append(sub_img)
    return sub_images

def combine_images(sub_images):
    # Create a NumPy array of size 1024x1024 to store the combined image
    img = np.zeros((1024, 1024))
    # Combine the smaller arrays into the larger image
    k = 0
    for i in range(0, img.shape[0], 256):
        for j in range(0, img.shape[1], 256):
            img[i:i+256, j:j+256] = sub_images[k]
            k += 1
    return img

In [7]:
#getting images and masks

def get_files(path,normalise=False,remove_txt=False):
    onlyfiles = [f for f in listdir(path) if isfile(join(path, f))]

    if remove_txt:
        onlyfiles = [val for val in onlyfiles if not val.endswith(".txt")]

    onlyfiles.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
    #if num_imgs > len(onlyfiles): num_imgs = len(onlyfiles)
    files = [np.squeeze(tifffile.imread(path +  onlyfiles[i])) for i in range(len(onlyfiles))]
    
    if normalise:
        files = [(image-np.min(image))/(np.max(image)-np.min(image)) for image in files]
    
    return files

def get_data(path, set='01',normalise_images=True):
    if len(set) == 2: #set 01 or set 02
        images_path = path + set + '/'
        images = get_files(images_path,normalise=normalise_images)
        masks_path = path + set + '_GT/TRA/'
        masks = get_files(masks_path,remove_txt=True)
    elif set == '0102': #both sets
        images_path = path + '01/'
        images_01 = get_files(images_path,normalise=normalise_images)
        images_path = path + '02/'
        images_02 = get_files(images_path,normalise=normalise_images)
        images = images_01 + images_02

        masks_path = path + '01_GT/TRA/'
        masks_01 = get_files(masks_path,remove_txt=True)
        masks_path = path + '02_GT/TRA/'
        masks_02 = get_files(masks_path,remove_txt=True)
        masks = masks_01 + masks_02
    else:
        images = []
        masks = []
    return images, masks

images, masks = get_data("c:\\Users\\rz200\\Documents\\development\\distillCellSegTrack\\" + 'datasets/Fluo-N2DH-GOWT1/', set = '0102',normalise_images=True)
X_train, X_test, y_train, y_test = train_test_split(images, masks, test_size=0.2, random_state=42)

X_train_padded = [np.pad(img,((0,1024-img.shape[0]),(0,1024-img.shape[1])),mode='constant',constant_values=0) for img in X_train]
X_test_padded = [np.pad(img,((0,1024-img.shape[0]),(0,1024-img.shape[1])),mode='constant',constant_values=0) for img in X_test]
y_train_padded = [np.pad(img,((0,1024-img.shape[0]),(0,1024-img.shape[1])),mode='constant',constant_values=0) for img in y_train]
y_test_padded = [np.pad(img,((0,1024-img.shape[0]),(0,1024-img.shape[1])),mode='constant',constant_values=0) for img in y_test]
#splitting images and masks into 256x256 patches
X_train_split = []
for image in X_train_padded:
    for i in range(0, 1024, 256):
        for j in range(0, 1024, 256):
            sub_img = image[i:i+256, j:j+256]
            X_train_split.append(sub_img)

X_test_split = []
for image in X_test_padded:
    for i in range(0, 1024, 256):
        for j in range(0, 1024, 256):
            sub_img = image[i:i+256, j:j+256]
            X_test_split.append(sub_img)

y_train_split = []
for mask in y_train_padded:
    for i in range(0, 1024, 256):
        for j in range(0, 1024, 256):
            sub_mask = mask[i:i+256, j:j+256]
            sub_mask = np.where(sub_mask>0,1,0)
            y_train_split.append(sub_mask)

y_test_split = []
for mask in y_test_padded:
    for i in range(0, 1024, 256):
        for j in range(0, 1024, 256):
            sub_mask = mask[i:i+256, j:j+256]
            sub_mask = np.where(sub_mask>0,1,0)
            y_test_split.append(sub_mask)


class ImageDataset(Dataset):
    def __init__(self, image, mask):
        self.image = image
        self.mask = mask

    def __len__(self):
        return len(self.mask)

    def __getitem__(self, idx):
        img = self.image[idx]
        label = self.mask[idx]
        return img, label
    
X_train_split_rot = []
y_train_cp_split_rot = []
for i in range(len(X_train_split)):
    for j in range(4):
        X_train_split_rot.append(np.rot90(X_train_split[i],j))
        y_train_cp_split_rot.append(np.rot90(y_train_split[i],j))
X_train_split_rot = np.array(X_train_split_rot)
y_train_cp_split_rot = np.array(y_train_cp_split_rot)

X_train_split_rot_noz = X_train_split_rot[y_train_cp_split_rot.sum(axis=(1,2))!=0]
y_train_cp_split_rot_noz = y_train_cp_split_rot[y_train_cp_split_rot.sum(axis=(1,2))!=0]

X_train_torch = [torch.from_numpy(np.array(X_train_split_rot_noz[i])) for i in range(len(X_train_split_rot_noz))]
y_train_cp_torch = [torch.from_numpy(np.array(y_train_cp_split_rot_noz[i])).type(torch.float32) for i in range(len(y_train_cp_split_rot_noz))]
X_test_torch = [torch.from_numpy(np.array(X_test_split[i])) for i in range(len(X_test_split))]
y_test_cp_torch = [torch.from_numpy(np.array(y_test_split[i])).type(torch.float32) for i in range(len(y_test_split))]
train_dataset = ImageDataset(X_train_torch, y_train_cp_torch)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataset = ImageDataset(X_test_torch, y_test_cp_torch)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [8]:
#training the model on the masks and images

def train_epoch(model, train_loader, test_loader, loss_fn, activation_fn, optimiser):
    model.train()

    #get train loss
    total_train_loss_per_epoch = 0
    for i, (x, y) in enumerate(train_loader):
        #x = x.copy()
        #y = y.copy()
        #print(i)
        #x = x.type(torch.float32)
        #y = y.type(torch.float32)
        (x,y) = (x.to('cuda:0'), y.to('cuda:0')) # sending the data to the device (cpu or GPU)

        x = x.unsqueeze(1)
        pred = model(x)# make a prediction
        #sigmoid the outputs
        if activation_fn == nn.Sigmoid:
            pred = activation_fn(pred)
        
        y = torch.unsqueeze(y,1)
        loss = loss_fn(pred, y) # calculate the loss of that prediction
        total_train_loss_per_epoch += loss
    
        optimiser.zero_grad() # zero out the accumulated gradients
        loss.backward() # backpropagate the loss
        optimiser.step() # update model parameters

        
        #total_train_loss_per_epoch += loss.detach().item()

    total_train_loss_per_epoch /= len(train_loader)
   
    #get test loss
    #total_test_loss_per_epoch = 0
    total_loss = 0
    with torch.no_grad():
        for images, cellprobs in test_loader:
            #images = images.copy()
            #cellprobs = cellprobs.copy()
            
            images = images.to('cuda:0')
            cellprobs = cellprobs.to('cuda:0')

            images = torch.unsqueeze(images,1)
            cellprobs = torch.unsqueeze(cellprobs,1)
            #cellprobs = cellprobs.to(torch.float32)
            outputs = model(images)
            #sigmoid the outputs
            if activation_fn == nn.Sigmoid:
                outputs = activation_fn(outputs)

            #outputs = activation_fn(outputs)
            loss = loss_fn(outputs, cellprobs)
            total_loss += loss
            #total_test_loss_per_epoch += loss.item()

            #calculate dice score
            #outputs = activation_fn(outputs)
            #outputs = torch.where(outputs>0.5,1.0,0.0)
            #outputs = outputs.view(-1)
            #cellprobs = cellprobs.view(-1)
            #intersection = (outputs * cellprobs).sum()  
            #dice = (2.*intersection+1)/(outputs.sum() + cellprobs.sum()+1)  
            #total_dice += dice
            
    #total_test_loss_per_epoch /= len(test_loader)
    #total_dice /= len(test_loader)
    #total_dice = total_dice.item()
    return total_train_loss_per_epoch.item()/len(train_loader), total_loss.item()/len(test_loader)

model = UNet()
model.to('cuda:0')
loss_fn = nn.BCEWithLogitsLoss()
optimiser = optim.Adam(model.parameters(), lr=0.001)
epochs = 10
losses = []
for epoch in range(epochs):
    loss = train_epoch(model, train_loader, test_loader, loss_fn, None, optimiser)
    losses.append(loss)
    print('Epoch {}: train loss {}, test loss {}'.format(epoch, loss, loss))

Epoch 0: train loss (3.0106934371279248e-05, 0.034585105406271445), test loss (3.0106934371279248e-05, 0.034585105406271445)
Epoch 1: train loss (2.2340581805975453e-05, 0.03445101106489027), test loss (2.2340581805975453e-05, 0.03445101106489027)
Epoch 2: train loss (2.0571406493817312e-05, 0.02960210877495843), test loss (2.0571406493817312e-05, 0.02960210877495843)
Epoch 3: train loss (1.981395727662898e-05, 0.02800200436566327), test loss (1.981395727662898e-05, 0.02800200436566327)
Epoch 4: train loss (1.9195960211092985e-05, 0.029361434884973475), test loss (1.9195960211092985e-05, 0.029361434884973475)
Epoch 5: train loss (1.866326554179954e-05, 0.027335627658947093), test loss (1.866326554179954e-05, 0.027335627658947093)
Epoch 6: train loss (1.822272812061981e-05, 0.026363454960487986), test loss (1.822272812061981e-05, 0.026363454960487986)
Epoch 7: train loss (1.8035840076297076e-05, 0.02666054545222102), test loss (1.8035840076297076e-05, 0.02666054545222102)
Epoch 8: train

In [9]:
#testing the model's accuracy

def get_prediction(image, model):
    splitted = split_image(image)
    predictions = []
    for split in splitted:
        split = torch.from_numpy(split)
        split = split.unsqueeze(0)
        split = split.unsqueeze(0)
        split = split.to('cuda:0')
        prediction = model(split)
        prediction = prediction.squeeze(0)
        prediction = prediction.squeeze(0)
        predictions.append(prediction.cpu().detach().numpy())
    combined_prediction = combine_images(predictions)
    return combined_prediction

def get_IoU(predicted_masks,gt_masks,return_list=False):
    intersection_unions = []
    for i in range(len(predicted_masks)):
        intersection = np.logical_and(predicted_masks[i], gt_masks[i]).sum()
        union = np.logical_or(predicted_masks[i], gt_masks[i]).sum()
        intersection_unions.append(intersection/union)
    if return_list:
        return intersection_unions
    return mean(intersection_unions)

def get_dice(predicted_masks,gt_masks, return_list=False):
    dices = []
    for i in range(len(predicted_masks)):
        intersection = np.logical_and(predicted_masks[i], gt_masks[i]).sum()
        dice = (2*intersection)/(predicted_masks[i].sum() + gt_masks[i].sum())
        dices.append(dice)
    if return_list:
        return dices
    return mean(dices)

combined_images = []
for i in range(0, len(X_test_torch), 16):
    combined_image = combine_images(X_test_torch[i:i+16])
    combined_images.append(combined_image)

combined_masks = []
for i in range(0, len(y_test_cp_torch), 16):
    combined_mask = combine_images(y_test_cp_torch[i:i+16])
    combined_mask = np.where(combined_mask>0.4,1,0)
    combined_masks.append(combined_mask)

predictions = []
for image in combined_images:
    prediction = get_prediction(image, model)
    prediction = torch.sigmoid(torch.from_numpy(prediction))
    prediction = np.where(prediction>0.4,1,0)
    predictions.append(prediction)

predictions_cropped = []
for i in range(len(predictions)):
    prediction = predictions[i]
    mask = y_test[i]
    prediction_cropped = prediction[0:mask.shape[0],0:mask.shape[1]]
    predictions_cropped.append(prediction_cropped)

dices = get_dice(predictions_cropped, y_test,return_list=True)
print('mean dice score: ', np.mean(dices))
iou = get_IoU(predictions_cropped, y_test, return_list=True)
print('mean iou score: ', np.mean(iou))

mean dice score:  0.08099949622469936
mean iou score:  0.770016708993167


In [10]:
#save the model
torch.save(model.state_dict(), 'train_dir/models/unet_no_distillation_GOWT1')