In [None]:
# %cd drive/My\ Drive/Colab\ Notebooks

In [None]:
# train_networks: Training CNNs to be used by the main program

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import cv2
import os
import imutils
from processing import *

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from skimage.segmentation import watershed
from skimage.morphology import disk
from skimage.feature import peak_local_max
from skimage.filters import meijering

from scipy import ndimage as ndi


def plot_two_images(imgL, imgR, titleL, titleR, cmap='gray'):
    """Show two images side by side, each with its own title.

    Args:
        imgL, imgR: anything ``Axes.imshow`` accepts (the training loop passes
            NumPy arrays and CPU tensors).
        titleL, titleR: titles for the left and right panels.
        cmap: colormap used for both panels (default ``'gray'``).
    """
    fig, (ax_left, ax_right) = plt.subplots(1, 2)
    for ax, img, title in ((ax_left, imgL, titleL), (ax_right, imgR, titleR)):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
    plt.show(block=True)
    # This is called repeatedly during training; close the figure so that
    # non-inline backends do not accumulate open figures (memory growth).
    plt.close(fig)
    
# NOTE(review): legacy loader kept only for reference. It is wrapped in a bare
# string literal, so it is never defined or executed. The active ``load_data``
# below reads pre-computed .npy caches instead of raw .tif files. Consider
# deleting this block (or moving it to version control history).
'''
def load_data(dataset):
    data = []
    paths = [os.path.join(dataset, '01'), os.path.join(dataset, '02')]
    for path in paths:
        mask_path = path + '_ST'
        mask_path = os.path.join(mask_path, 'SEG')
        for f in os.listdir(mask_path):
            if not f.endswith(".tif"):
                continue
            image = cv2.imread(os.path.join(path, f.replace('man_seg', 't')), cv2.IMREAD_GRAYSCALE)
            image = equalize_clahe(image).astype(np.float32)
            mask = cv2.imread(os.path.join(mask_path, f), cv2.IMREAD_UNCHANGED)
            print("   Loaded " + os.path.join(mask_path, f) + ", " + os.path.join(path, f.replace('mask', '')))
            
            # Generate the Cell Mask and Markers from the Mask
            cell_mask = (mask > 0).astype(np.uint8)
            markers = (get_markers(mask) > 0).astype(np.uint8)
            weight_map = get_weight_map(markers)
            
            # Pack the data for the DataLoader
            target = (cell_mask, markers, weight_map)
            data.append((np.array([image]), target))
            
    return DataLoader(data, batch_size=6, shuffle=True)
'''

def load_data(dataset, batch_size=4, shuffle=True):
    """Build a DataLoader from a directory of pre-computed ``.npy`` caches.

    Expected layout (one identically named ``.npy`` file per sample in each
    sub-directory)::

        <dataset>/originals/    only used to enumerate the sample file names
        <dataset>/clahes/       CLAHE-equalised images (the network input)
        <dataset>/masks/        cell masks
        <dataset>/markers/      marker maps
        <dataset>/weight_maps/  per-pixel loss weights

    Each dataset item is ``(image, (cell_mask, markers, weight_map))``, where
    ``image`` is the CLAHE array cast to float32 with a leading channel axis
    added, i.e. shape ``(1, *clahe.shape)``.

    Args:
        dataset: root directory of the cache.
        batch_size: samples per batch (default 4).
        shuffle: whether the DataLoader reshuffles the samples every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` over the loaded samples.
    """
    path = os.path.join(dataset, "originals")
    # Join from the dataset root rather than using str.replace on ``path``:
    # replace() would also rewrite any "originals" substring inside ``dataset``.
    clahe_path = os.path.join(dataset, "clahes")
    mask_path = os.path.join(dataset, "masks")
    markers_path = os.path.join(dataset, "markers")
    wm_path = os.path.join(dataset, "weight_maps")

    data = []
    # sorted() makes the sample order independent of the file system.
    for f in sorted(os.listdir(path)):
        if not f.endswith(".npy"):
            continue
        # Conv2d weights are float32, so cast the input to match.
        clahe = np.load(os.path.join(clahe_path, f)).astype(np.float32)
        cell_mask = np.load(os.path.join(mask_path, f))
        markers = np.load(os.path.join(markers_path, f))
        weight_map = np.load(os.path.join(wm_path, f))
        print("   Loaded " + os.path.join(mask_path, f) + ", " + os.path.join(clahe_path, f))

        # Pack the data for the DataLoader: input gets a leading channel axis.
        target = (cell_mask, markers, weight_map)
        data.append((clahe[np.newaxis], target))

    return DataLoader(data, batch_size=batch_size, shuffle=shuffle)

In [None]:
# Class for creating the CNN
class Network(nn.Module):
    """U-Net style encoder/decoder producing per-pixel 2-class log-probabilities.

    Input:  ``(batch, 1, H, W)`` float tensor (``conv1`` takes one channel).
    Output: ``(batch, 2, H, W)`` log-probabilities (``log_softmax`` over dim 1).

    The encoder applies four 2x max-pool stages (32 -> 64 -> 128 -> 256 -> 512
    channels). The decoder upsamples bilinearly back to each skip connection's
    spatial size, concatenates the skip tensor first, then applies two 3x3
    convolutions. Because upsampling targets the skip tensor's exact size, any
    H and W work, not only multiples of 16.

    The layer attribute names (``conv1`` ... ``conv18``, ``conv_out``) define the
    ``state_dict`` keys of saved checkpoints and must not be renamed.
    """

    def __init__(self):
        super(Network, self).__init__()
        # Encoder (contracting path): pairs of 3x3 convs with "same" padding.
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv5 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv7 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv8 = nn.Conv2d(256, 256, 3, padding=1)
        # Bottleneck.
        self.conv9 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv10 = nn.Conv2d(512, 512, 3, padding=1)
        # Decoder (expanding path): in_channels = skip channels + upsampled channels.
        self.conv11 = nn.Conv2d(768, 256, 3, padding=1)  # 256 skip + 512 up
        self.conv12 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv13 = nn.Conv2d(384, 128, 3, padding=1)  # 128 skip + 256 up
        self.conv14 = nn.Conv2d(128, 128, 3, padding=1)
        self.conv15 = nn.Conv2d(192, 64, 3, padding=1)   # 64 skip + 128 up
        self.conv16 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv17 = nn.Conv2d(96, 32, 3, padding=1)    # 32 skip + 64 up
        self.conv18 = nn.Conv2d(32, 32, 3, padding=1)
        # 1x1 projection to the 2 output classes.
        self.conv_out = nn.Conv2d(32, 2, 1)

    @staticmethod
    def _up_and_concat(x, skip):
        """Upsample ``x`` bilinearly to ``skip``'s H x W and concatenate (skip first)."""
        # Targeting the skip size (rather than scale_factor=2) equals 2x upsampling
        # when sizes are even, and still matches when pooling floored an odd size.
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return torch.cat((skip, x), dim=1)

    def forward(self, x):
        """
        Forward pass through the network
        """
        x = F.relu(self.conv1(x))
        contraction_32 = F.relu(self.conv2(x))

        x = F.max_pool2d(contraction_32, kernel_size=2)
        x = F.relu(self.conv3(x))
        contraction_64 = F.relu(self.conv4(x))

        x = F.max_pool2d(contraction_64, kernel_size=2)
        x = F.relu(self.conv5(x))
        contraction_128 = F.relu(self.conv6(x))

        x = F.max_pool2d(contraction_128, kernel_size=2)
        x = F.relu(self.conv7(x))
        contraction_256 = F.relu(self.conv8(x))

        x = F.max_pool2d(contraction_256, kernel_size=2)
        x = F.relu(self.conv9(x))
        x = F.relu(self.conv10(x))

        x = self._up_and_concat(x, contraction_256)
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))

        x = self._up_and_concat(x, contraction_128)
        x = F.relu(self.conv13(x))
        x = F.relu(self.conv14(x))

        x = self._up_and_concat(x, contraction_64)
        x = F.relu(self.conv15(x))
        x = F.relu(self.conv16(x))

        x = self._up_and_concat(x, contraction_32)
        x = F.relu(self.conv17(x))
        x = F.relu(self.conv18(x))

        x = self.conv_out(x)
        output = F.log_softmax(x, dim=1)
        return output


# def weighted_cross_entropy_loss(inputs, targets, weights):

# #     Weighted Cross-Entropy Loss takes in a weight map
# #     and computes loss

#     device = torch.device("cpu")
#     inputs = inputs.to(device)
#     targets = targets.to(device)
#     weights = weights.to(device)
#     loss = torch.zeros(inputs.shape[0])

#     # Calculate loss for each sample in the batch
#     for sample in range(inputs.shape[0]):
#         print("Sample", sample+1)
#         sample_loss, total_weight = 0.0, 0.0
#         for row in range(inputs.shape[2]):
#             for col in range(inputs.shape[3]):
#                 # Get pixel q
#                 q = (row, col)
#                 w = weights[sample][q].item()
#                 total_weight += w
#                 if targets[sample][0][q] == 0:
#                     # Get predicted probability for q = 0
#                     p = inputs[sample][0][q]
#                 else:
#                     # Get predicted probability for q = 1
#                     p = inputs[sample][1][q]
#                 sample_loss -= w * torch.log(p).item()
#         sample_loss = sample_loss / total_weight
#         loss[sample] = sample_loss

#     return torch.mean(loss)

# '''
def weighted_cross_entropy_loss(inputs, targets, weights):
    """Pixel-wise weighted cross-entropy for 2-class log-probability maps.

    For each sample ``s`` the loss is

        L_s = -sum_q w(q) * log_p[y(q)](q) / sum_q w(q)

    where ``y(q)`` is 0 where the target pixel equals 0 and 1 otherwise, and
    ``log_p[c]`` is channel ``c`` of ``inputs``. The returned value is the
    mean of ``L_s`` over the batch.

    Args:
        inputs: ``(B, 2, H, W)`` tensor, expected to hold log-probabilities
            (e.g. the ``log_softmax`` output of ``Network``).
        targets: expected ``(B, H, W)``. Zero pixels select channel 0,
            every other value selects channel 1.
        weights: expected ``(B, H, W)`` per-pixel weight map.

    Returns:
        0-d tensor on ``inputs.device`` (differentiable w.r.t. ``inputs``).
    """
    # Follow the device of the predictions instead of hard-coding cuda:0.
    device = inputs.device
    targets = targets.to(device)
    weights = weights.to(device=device, dtype=inputs.dtype)

    # Per pixel, pick the log-probability of the target class (vectorized over the batch).
    log_p = torch.where(targets == 0, inputs[:, 0], inputs[:, 1])

    weighted_sum = (log_p * weights).flatten(start_dim=1).sum(dim=1)
    weight_total = weights.flatten(start_dim=1).sum(dim=1)

    # An all-zero weight map also makes the numerator zero. Clamping the
    # denominator yields a loss of 0 instead of NaN (which would poison training).
    per_sample = -weighted_sum / weight_total.clamp_min(torch.finfo(weights.dtype).tiny)
    return per_sample.mean()
# '''

def _train_step(net, optimiser, criterion, x, target, weight_map):
    """Run one optimisation step of ``net`` on ``x`` and return ``(output, loss)``."""
    optimiser.zero_grad()
    output = net(x)
    loss = criterion(output, target.float(), weight_map)
    loss.backward()
    optimiser.step()
    return output, loss


def _plot_progress(x, output_c, output_m, cell_masks, markers,
                   net_m, net_c, test_tensor, test_img):
    """Plot batch predictions for the first sample, then the full
    marker + mask -> watershed -> bounding-box pipeline on the test image."""
    plt.imshow(x[0][0].cpu(), cmap='gray')
    plt.title("Input")
    plt.show()

    # Predicted class per pixel = argmax over the 2 output channels.
    pred_c = torch.argmax(output_c[0], dim=0).cpu()
    pred_m = torch.argmax(output_m[0], dim=0).cpu()

    # Compare predicted to true images
    plot_two_images(pred_c, cell_masks[0].cpu(), "Predicted Cell Mask", "True Cell Mask")
    plot_two_images(pred_m, markers[0].cpu(), "Predicted Markers", "True Markers")

    with torch.no_grad():
        net_m.eval()
        net_c.eval()
        try:
            out_test_m = net_m(test_tensor)
            out_test_c = net_c(test_tensor)
        finally:
            # Always return to training mode, even if inference fails.
            net_m.train()
            net_c.train()

        pred_test_c = torch.argmax(out_test_c[0], dim=0).cpu()
        pred_test_m = torch.argmax(out_test_m[0], dim=0).cpu()

        ws_labels = get_ws_from_markers(pred_test_m.numpy(), pred_test_c.numpy())
        bound_box = get_bound_box_from_ws(test_img, ws_labels)

        plot_two_images(test_img, pred_test_c, "test_image", "Predicted Cell Mask Test")
        plot_two_images(pred_test_m, ws_labels, "Predicted Markers Test", "Watershed_labels")

        plt.imshow(bound_box, cmap='gray')
        plt.title("Cells with bounding box.")
        plt.show()


def main(train_loader=None, epochs=100, log_every=10):
    """
    Train 2 networks for predicting markers and the cell mask respectively,
    then save them to ./CNN_m.pth (markers) and ./CNN_c.pth (cell mask).

    Args:
        train_loader: iterable of ``(x, (cell_masks, markers, weight_map))``
            batches. Defaults to the notebook-level ``trainLoader`` built in
            the data-loading cell.
        epochs: number of passes over ``train_loader`` (default 100).
        log_every: print losses and plot predictions on the first batch and
            on every ``log_every``-th batch of each epoch (default 10).

    Raises:
        FileNotFoundError: if the held-out test image cannot be read.
    """
    if train_loader is None:
        train_loader = trainLoader  # notebook global (see data-loading cell)

    # Held-out image used to visualise the whole segmentation pipeline.
    test_img_path = os.path.join("DIC-C2DH-HeLa", "Sequence 3", "t002.tif")
    _c_img = cv2.imread(test_img_path, cv2.IMREAD_GRAYSCALE)
    if _c_img is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError("Could not read test image: " + test_img_path)
    _c_img = equalize_clahe(_c_img)

    # torch.cuda.set_device raises on CPU-only machines, so only call it when
    # CUDA is actually available.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(0)
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print("Using device: " + str(device))

    # Net M predicts the markers. Net C predicts the cell mask
    net_m, net_c = Network().to(device), Network().to(device)
    criterion = weighted_cross_entropy_loss

    # Optimising using Adam algorithm
    optimiser_m = optim.Adam(net_m.parameters(), lr=0.001)
    optimiser_c = optim.Adam(net_c.parameters(), lr=0.001)

    # The test image never changes: build its single-image, single-channel
    # batch tensor once instead of at every logging step.
    test_tensor = torch.tensor(np.array([[_c_img]])).float().to(device)

    # Iterate over a number of epochs on the data
    for epoch in range(epochs):
        for i, (x, target) in enumerate(train_loader):
            x = x.to(device)
            cell_masks = target[0].to(device)
            markers = target[1].to(device)
            weight_map = target[2].to(device)

            # The two networks share no parameters, so they are stepped independently.
            output_m, loss_m = _train_step(net_m, optimiser_m, criterion, x, markers, weight_map)
            output_c, loss_c = _train_step(net_c, optimiser_c, criterion, x, cell_masks, weight_map)

            if i == 0 or (i + 1) % log_every == 0:
                print(f"Epoch: {epoch+1}, Batch: {i + 1}")
                print(f"Cell Mask Loss: {loss_c.item():.2f}, Markers Loss: {loss_m.item():.2f}")
                _plot_progress(x, output_c, output_m, cell_masks, markers,
                               net_m, net_c, test_tensor, _c_img)

    torch.save(net_m.state_dict(), "./CNN_m.pth")
    torch.save(net_c.state_dict(), "./CNN_c.pth")
    print("Saved models.")

In [None]:
# Build the training DataLoader once; main() reads it through the
# notebook-level name `trainLoader`.
DATASET_DIR = 'DIC-2_cache_npy'

print("Loading Data...")
trainLoader = load_data(DATASET_DIR)
print("Finished.")

In [None]:
# Entry point: start training. In a notebook kernel __name__ is also
# "__main__", so running this cell trains the networks.
if __name__ == "__main__":
    main()