In [1]:
# %cd drive/My\ Drive/Colab\ Notebooks

In [2]:
# train_networks: Training CNNs to be used by the main program

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import cv2
import os
import imutils
from processing import *

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import random_split, Dataset, DataLoader

from skimage.segmentation import watershed
from skimage.morphology import disk
from skimage.feature import peak_local_max
from skimage.filters import meijering

from scipy import ndimage as ndi

# Number of samples per batch used by load_data for both the training and the
# evaluation DataLoaders.
BATCH_SIZE = 7

def plot_two_images(imgL, imgR, titleL, titleR):
    """
    Display two images side by side in one figure using a grayscale colormap.

    Args:
        imgL: Image drawn in the left panel.
        imgR: Image drawn in the right panel.
        titleL: Title of the left panel.
        titleR: Title of the right panel.
    """
    figure = plt.figure()
    panels = ((imgL, titleL), (imgR, titleR))
    for column, (picture, caption) in enumerate(panels, start=1):
        figure.add_subplot(1, 2, column)
        plt.imshow(picture, cmap='gray')
        plt.title(caption)
    plt.show(block=True)
    

# def load_data(dataset):
#     data = []
#     paths = [os.path.join(dataset, '01'), os.path.join(dataset, '02')]
#     for path in paths:
#         mask_path = path + '_ST'
#         mask_path = os.path.join(mask_path, 'SEG')
#         for f in os.listdir(mask_path):
#             if not f.endswith(".tif"):
#                 continue
#             image = cv2.imread(os.path.join(path, f.replace('man_seg', 't')), cv2.IMREAD_GRAYSCALE)
#             image = equalize_clahe(image).astype(np.float32)
#             mask = cv2.imread(os.path.join(mask_path, f), cv2.IMREAD_UNCHANGED)
#             print("   Loaded " + os.path.join(mask_path, f) + ", " + os.path.join(path, f.replace('man_seg', 't')))
            
#             # Generate the Cell Mask and Markers from the Mask
#             cell_mask = (mask > 0).astype(np.uint8)
#             markers = (get_markers(mask) > 0).astype(np.uint8)
#             weight_map = get_weight_map(markers)
            
#             # Pack the data for the DataLoader
#             target = (cell_mask, markers, weight_map)
#             data.append((np.array([image]), target))

#     train_size = int(0.8 * len(data))
#     test_size = len(data) - train_size
#     train_data, test_data = random_split(data, [train_size, test_size])
#     trainLoader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
#     testLoader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)
#     return trainLoader, testLoader

def load_data(dataset, seed=None, batch_size=None):
    """
    Build train/evaluation DataLoaders from a directory of cached ``.npy`` arrays.

    ``dataset`` must contain the sub-directories ``originals``, ``clahes``,
    ``masks``, ``markers`` and ``weight_maps``. Each holds one ``.npy`` file per
    frame under the same file name. The ``.npy`` names found in ``originals``
    decide which frames are loaded. The network input is the CLAHE array with a
    leading channel axis added. The target is the tuple
    ``(cell_mask, markers, weight_map)``. Samples are split 80/20 into
    training and evaluation sets.

    Args:
        dataset: Path to the cache directory.
        seed: Optional integer seed for the train/test split. ``None`` uses
            PyTorch's global RNG, as before.
        batch_size: Samples per batch. ``None`` uses the module-level ``BATCH_SIZE``.

    Returns:
        ``(trainLoader, testLoader)``. The training loader is shuffled each
        epoch; the evaluation loader is not.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Join each sub-directory onto ``dataset`` directly. Deriving them with
    # str.replace("originals", ...) would also rewrite any "originals" that
    # happens to appear in the dataset path itself.
    originals_dir = os.path.join(dataset, "originals")
    clahe_dir = os.path.join(dataset, "clahes")
    mask_dir = os.path.join(dataset, "masks")
    markers_dir = os.path.join(dataset, "markers")
    wm_dir = os.path.join(dataset, "weight_maps")

    data = []
    # os.listdir order is arbitrary. Sort it so the sample order, and with it a
    # seeded split, is reproducible.
    for f in sorted(os.listdir(originals_dir)):
        if not f.endswith(".npy"):
            continue
        clahe = np.load(os.path.join(clahe_dir, f))
        cell_mask = np.load(os.path.join(mask_dir, f))
        markers = np.load(os.path.join(markers_dir, f))
        weight_map = np.load(os.path.join(wm_dir, f))
        print("   Loaded " + os.path.join(mask_dir, f) + ", " + os.path.join(originals_dir, f))

        # Pack the data for the DataLoader: the input gets a leading channel axis
        target = (cell_mask, markers, weight_map)
        data.append((np.array([clahe]), target))

    train_size = int(0.8 * len(data))
    test_size = len(data) - train_size
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_data, test_data = random_split(data, [train_size, test_size], **split_kwargs)
    trainLoader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation is not shuffled. The training loop averages per-batch scores,
    # so a fixed batch composition keeps the evaluation score deterministic
    # for fixed weights.
    testLoader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return trainLoader, testLoader

def get_score(outputs, ground_truth):
    """
    Calculates pixel accuracy averaged across the batch.

    Each sample's score is the fraction of its elements for which ``outputs``
    equals ``ground_truth``. The result is the mean of these per-sample
    fractions.

    Args:
        outputs: Predicted labels with the batch on dim 0 and any number of
            trailing dims (e.g. (B, H, W) or (B, 1, H, W)).
        ground_truth: Labels comparable element-wise with ``outputs``.

    Returns:
        Python float in [0, 1].

    Raises:
        ZeroDivisionError: if the batch is empty.
    """
    batch_size = outputs.shape[0]
    # flatten(start_dim=1) makes the per-sample element count correct for any
    # rank. The old shape[1] * shape[2] count was only right for 3-D input.
    # double() keeps float64 precision, matching plain Python float arithmetic.
    matches = (outputs == ground_truth).flatten(start_dim=1).double()
    per_sample = matches.mean(dim=1)
    return per_sample.sum().item() / batch_size

In [3]:
# Class for creating the CNN
class Network(nn.Module):
    """
    U-Net style encoder/decoder that outputs per-pixel probabilities.

    The encoder has four 2x max-pool stages (32, 64, 128, 256 channels)
    followed by a 512-channel bottleneck. The decoder has four bilinear
    upsampling stages. Each one concatenates the matching encoder feature map
    before two 3x3 convolutions. A final 1x1 convolution produces
    ``out_channels`` maps, which pass through a sigmoid.

    In the training code, output channel 0 is treated as the markers prediction
    and channel 1 as the cell-mask prediction.

    Args:
        in_channels: Number of input image channels (default 1).
        out_channels: Number of output maps (default 2).
    """
    def __init__(self, in_channels=1, out_channels=2):
        super(Network, self).__init__()
        # Encoder (attribute names are kept so saved state_dicts still load)
        self.conv1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv5 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv7 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv8 = nn.Conv2d(256, 256, 3, padding=1)
        # Bottleneck
        self.conv9 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv10 = nn.Conv2d(512, 512, 3, padding=1)
        # Decoder: input channels = upsampled features + skip connection
        self.conv11 = nn.Conv2d(512 + 256, 256, 3, padding=1)
        self.conv12 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv13 = nn.Conv2d(256 + 128, 128, 3, padding=1)
        self.conv14 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv15 = nn.Conv2d(128 + 64, 64, 3, padding=1)
        self.conv16 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv17 = nn.Conv2d(64 + 32, 32, 3, padding=1)
        self.conv18 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv_out = nn.Conv2d(32, out_channels, 1)

    @staticmethod
    def _upsample_concat(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s spatial size and concatenate ``(skip, x)`` on channels."""
        # Upsample to the skip tensor's exact size rather than by a fixed
        # factor of 2. This is identical for sizes divisible by 16 and avoids a
        # shape mismatch in torch.cat when max-pooling floored an odd size.
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return torch.cat((skip, x), dim=1)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x: Input batch, assumed (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, out_channels, H, W) with values in (0, 1).
        """
        x = F.relu(self.conv1(x))
        contraction_32 = F.relu(self.conv2(x))

        x = F.max_pool2d(contraction_32, kernel_size=2)
        x = F.relu(self.conv3(x))
        contraction_64 = F.relu(self.conv4(x))

        x = F.max_pool2d(contraction_64, kernel_size=2)
        x = F.relu(self.conv5(x))
        contraction_128 = F.relu(self.conv6(x))

        x = F.max_pool2d(contraction_128, kernel_size=2)
        x = F.relu(self.conv7(x))
        contraction_256 = F.relu(self.conv8(x))

        x = F.max_pool2d(contraction_256, kernel_size=2)
        x = F.relu(self.conv9(x))
        x = F.relu(self.conv10(x))

        x = self._upsample_concat(x, contraction_256)
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))

        x = self._upsample_concat(x, contraction_128)
        x = F.relu(self.conv13(x))
        x = F.relu(self.conv14(x))

        x = self._upsample_concat(x, contraction_64)
        x = F.relu(self.conv15(x))
        x = F.relu(self.conv16(x))

        x = self._upsample_concat(x, contraction_32)
        x = F.relu(self.conv17(x))
        x = F.relu(self.conv18(x))

        # torch.sigmoid replaces the deprecated F.sigmoid
        return torch.sigmoid(self.conv_out(x))

def weighted_mean_sq_error(inputs, targets_m, targets_c, weights):
    """
    Pixel-weighted mean squared error over the markers and cell-mask channels.

    For each sample ``b`` the loss is::

        0.5 * sum(w * ((p_m - t_m)**2 + (p_c - t_c)**2)) / sum(w)

    where ``p_m = inputs[b, 0]`` and ``p_c = inputs[b, 1]``. The returned
    value is the mean of the per-sample losses.

    Args:
        inputs: Network output. Channel 0 is the predicted markers and
            channel 1 the predicted cell mask (assumed shape (B, 2, H, W)).
        targets_m: Ground-truth markers, one map per sample (any shape with
            H*W elements per sample, e.g. (B, H, W) or (B, 1, H, W)).
        targets_c: Ground-truth cell masks, same layout as ``targets_m``.
        weights: Per-pixel weight maps, same layout as ``targets_m``.

    Returns:
        0-dim tensor on ``inputs``' device with ``inputs``' dtype,
        differentiable with respect to ``inputs``.
    """
    print('.', end='')  # progress indicator: one dot per call

    # Compute on the device the predictions already live on. The previous
    # version accumulated into a CPU tensor and paid a device copy per sample.
    device = inputs.device

    # Flatten every map to (B, pixels) so each per-sample sum is one reduction.
    pred_markers = inputs[:, 0].flatten(1)
    pred_cmask = inputs[:, 1].flatten(1)
    exp_markers = targets_m.to(device).flatten(1)
    exp_cmask = targets_c.to(device).flatten(1)
    w = weights.to(device).flatten(1)

    sq_err = (pred_markers - exp_markers) ** 2 + (pred_cmask - exp_cmask) ** 2
    numerator = (w * sq_err).sum(dim=1)
    denominator = w.sum(dim=1)
    per_sample = 0.5 * (numerator / denominator)

    return per_sample.to(inputs.dtype).mean()

def _show_predictions(image, pred_c, true_c, pred_m, true_m):
    """Plot an input image, then predicted-vs-true cell mask and markers."""
    plt.imshow(image.cpu(), cmap='gray')
    plt.title("Input")
    plt.show()

    # Compare predicted to true images
    plot_two_images(pred_c.cpu(), true_c.cpu(), "Predicted Cell Mask", "True Cell Mask")
    plot_two_images(pred_m.cpu(), true_m.cpu(), "Predicted Markers", "True Markers")


def _unpack_batch(batch, device):
    """Split a ``(image, (cell_mask, markers, weight_map))`` batch and move each part to ``device``."""
    x = batch[0].to(device)
    target = batch[1]
    cell_masks, markers = target[0].to(device), target[1].to(device)
    weight_map = target[2].to(device)
    return x, cell_masks, markers, weight_map


def _train_one_epoch(net, loader, criterion, optimiser, device, epoch, log_every):
    """Run one optimisation pass over ``loader``, logging loss and plots on the first and every ``log_every``-th batch."""
    for i, batch in enumerate(loader, 0):
        x, cell_masks, markers, weight_map = _unpack_batch(batch, device)

        # Clear gradients from last step
        optimiser.zero_grad()

        # Predict markers (channel 0) and cell mask (channel 1) from the image
        output = net(x)
        loss = criterion(output, markers.float(), cell_masks.float(), weight_map)
        loss.backward()
        optimiser.step()

        if i == 0 or (i + 1) % log_every == 0:
            print(f"Epoch: {epoch+1}, Batch: {i + 1}")
            print(f"Loss: {loss.item():.2f}")

            # Threshold the first sample's probabilities into binary maps
            pred_m = (output[0][0] > 0.5).int()
            pred_c = (output[0][1] > 0.5).int()
            _show_predictions(x[0][0], pred_c, cell_masks[0], pred_m, markers[0])


def _evaluate(net, loader, device):
    """
    Return the mean per-batch pixel accuracy as ``np.array([cell_mask, markers])``.

    Leaves ``net`` in eval mode; the caller switches it back to training mode.
    """
    net.eval()
    running_score = np.array([0.0, 0.0])
    with torch.no_grad():
        for i, batch in enumerate(loader):
            x, cell_masks, markers, _ = _unpack_batch(batch, device)

            output = net(x)
            pred_m = (output[:, 0] > 0.5).int()
            pred_c = (output[:, 1] > 0.5).int()

            running_score[0] += get_score(pred_c, cell_masks)
            running_score[1] += get_score(pred_m, markers)

            if i == 0:
                _show_predictions(x[0][0], pred_c[0], cell_masks[0], pred_m[0], markers[0])

    return running_score / len(loader)


def main(train_loader=None, test_loader=None, epochs=300, log_every=960):
    """
    Train one network that predicts both the markers and the cell mask.

    After every epoch the network is evaluated on the held-out set. Whenever a
    channel's accuracy improves, the weights are saved to one of ten rotating
    checkpoint files for that channel. The final weights are saved to
    ``./CNN_v1.pth``.

    Args:
        train_loader: Training DataLoader yielding
            ``(image, (cell_mask, markers, weight_map))`` batches. Defaults to
            the notebook-global ``trainLoader``.
        test_loader: Evaluation DataLoader with the same batch layout.
            Defaults to the notebook-global ``testLoader``.
        epochs: Number of passes over the training data.
        log_every: Interval, in batches, for printing loss and plotting predictions.
    """
    # Fall back to the globals created by the data-loading cell
    if train_loader is None:
        train_loader = trainLoader
    if test_loader is None:
        test_loader = testLoader

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device: " + str(device))

    # Net predicts the markers and cells
    net = Network().to(device)
    criterion = weighted_mean_sq_error

    # Optimising using Adam algorithm
    optimiser = optim.Adam(net.parameters(), lr=0.001)

    max_score = [0, 0]
    # Rotating checkpoint indices. They are initialised once, outside the
    # epoch loop, so successive improvements go to different files instead of
    # always overwriting index 0.
    file_count_c = 0
    file_count_m = 0

    # Iterate over a number of epochs on the data
    for epoch in range(epochs):
        _train_one_epoch(net, train_loader, criterion, optimiser, device, epoch, log_every)

        # Test on the evaluation set
        print("\n--- Evaluation ---")
        score = _evaluate(net, test_loader, device)

        if score[0] > max_score[0]:
            torch.save(net.state_dict(), "./CNN_c_v1_max_score_{}.pth".format(file_count_c))
            max_score[0] = score[0]
            file_count_c = (file_count_c + 1) % 10

        if score[1] > max_score[1]:
            torch.save(net.state_dict(), "./CNN_m_v1_max_score_{}.pth".format(file_count_m))
            max_score[1] = score[1]
            file_count_m = (file_count_m + 1) % 10

        print(f"EPOCH {epoch+1} SCORE\nCell Mask: {score[0]:.3f}, Markers: {score[1]:.3f}")
        print(f"Overall: {(score[0]+score[1])/2:.3f}\n\n")
        net.train()

    torch.save(net.state_dict(), "./CNN_v1.pth")
    print("Saved models.")

In [None]:
# Build the train/evaluation DataLoaders from the cached .npy arrays in DIC-3_cache.
# main() reads the resulting trainLoader/testLoader globals by default.
print("Loading Data...")
trainLoader, testLoader = load_data(dataset='DIC-3_cache')
print("Finished.")

Loading Data...
   Loaded DIC-3_cache\masks\t0.npy, DIC-3_cache\originals\t0.npy
   Loaded DIC-3_cache\masks\t1.npy, DIC-3_cache\originals\t1.npy
   Loaded DIC-3_cache\masks\t10.npy, DIC-3_cache\originals\t10.npy
   Loaded DIC-3_cache\masks\t100.npy, DIC-3_cache\originals\t100.npy
   Loaded DIC-3_cache\masks\t1000.npy, DIC-3_cache\originals\t1000.npy
   Loaded DIC-3_cache\masks\t1001.npy, DIC-3_cache\originals\t1001.npy
   Loaded DIC-3_cache\masks\t1002.npy, DIC-3_cache\originals\t1002.npy
   Loaded DIC-3_cache\masks\t1003.npy, DIC-3_cache\originals\t1003.npy
   Loaded DIC-3_cache\masks\t1004.npy, DIC-3_cache\originals\t1004.npy
   Loaded DIC-3_cache\masks\t1005.npy, DIC-3_cache\originals\t1005.npy
   Loaded DIC-3_cache\masks\t1006.npy, DIC-3_cache\originals\t1006.npy
   Loaded DIC-3_cache\masks\t1007.npy, DIC-3_cache\originals\t1007.npy
   Loaded DIC-3_cache\masks\t1008.npy, DIC-3_cache\originals\t1008.npy
   Loaded DIC-3_cache\masks\t1009.npy, DIC-3_cache\originals\t1009.npy
   Loade

   Loaded DIC-3_cache\masks\t1104.npy, DIC-3_cache\originals\t1104.npy
   Loaded DIC-3_cache\masks\t1105.npy, DIC-3_cache\originals\t1105.npy
   Loaded DIC-3_cache\masks\t1106.npy, DIC-3_cache\originals\t1106.npy
   Loaded DIC-3_cache\masks\t1107.npy, DIC-3_cache\originals\t1107.npy
   Loaded DIC-3_cache\masks\t1108.npy, DIC-3_cache\originals\t1108.npy
   Loaded DIC-3_cache\masks\t1109.npy, DIC-3_cache\originals\t1109.npy
   Loaded DIC-3_cache\masks\t111.npy, DIC-3_cache\originals\t111.npy
   Loaded DIC-3_cache\masks\t1110.npy, DIC-3_cache\originals\t1110.npy
   Loaded DIC-3_cache\masks\t1111.npy, DIC-3_cache\originals\t1111.npy
   Loaded DIC-3_cache\masks\t1112.npy, DIC-3_cache\originals\t1112.npy
   Loaded DIC-3_cache\masks\t1113.npy, DIC-3_cache\originals\t1113.npy
   Loaded DIC-3_cache\masks\t1114.npy, DIC-3_cache\originals\t1114.npy
   Loaded DIC-3_cache\masks\t1115.npy, DIC-3_cache\originals\t1115.npy
   Loaded DIC-3_cache\masks\t1116.npy, DIC-3_cache\originals\t1116.npy
   Loade

   Loaded DIC-3_cache\masks\t1210.npy, DIC-3_cache\originals\t1210.npy
   Loaded DIC-3_cache\masks\t1211.npy, DIC-3_cache\originals\t1211.npy
   Loaded DIC-3_cache\masks\t1212.npy, DIC-3_cache\originals\t1212.npy
   Loaded DIC-3_cache\masks\t1213.npy, DIC-3_cache\originals\t1213.npy
   Loaded DIC-3_cache\masks\t1214.npy, DIC-3_cache\originals\t1214.npy
   Loaded DIC-3_cache\masks\t1215.npy, DIC-3_cache\originals\t1215.npy
   Loaded DIC-3_cache\masks\t1216.npy, DIC-3_cache\originals\t1216.npy
   Loaded DIC-3_cache\masks\t1217.npy, DIC-3_cache\originals\t1217.npy
   Loaded DIC-3_cache\masks\t1218.npy, DIC-3_cache\originals\t1218.npy
   Loaded DIC-3_cache\masks\t1219.npy, DIC-3_cache\originals\t1219.npy

In [None]:
# In a notebook kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()