In [None]:
def _first_tif(folder):
    """Return the alphabetically first ``.tif``/``.tiff`` file name in *folder*.

    Sorting makes the choice deterministic (``os.listdir`` order is arbitrary).

    Raises:
        FileNotFoundError: if *folder* contains no TIFF file.
    """
    names = sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith((".tif", ".tiff"))
    )
    if not names:
        raise FileNotFoundError(f"No .tif/.tiff image found in {folder!r}")
    return names[0]


def _predict_and_merge(model, emp, img_patches, indices, mode):
    """Run *model* over all patches and stitch the predictions back with *mode*."""
    steps = len(img_patches)
    # A fresh generator is built for every model; the original code re-created it
    # before the second prediction, presumably because it is single-use.
    # steps == number of patches suggests one patch per step -- TODO confirm.
    generator = testGenerators(img_patches)
    if hasattr(model, "predict_generator"):
        # Older Keras / TF 2.x: keep the original call unchanged.
        results = model.predict_generator(generator, steps, verbose=1)
    else:
        # predict_generator was removed in newer Keras; predict accepts generators.
        results = model.predict(generator, steps=steps, verbose=1)
    return emp.merge_patches(results, indices, mode=mode)


def _show_segmentation(img, cell_lbl, cell_mask, neurite_mask):
    """Display the input image, the label overlay and the three masks side by side."""
    import numpy as np  # local import: the file's top-level imports are not visible here

    # Count distinct non-background labels rather than trusting max(), which
    # over-counts if label ids are not consecutive.
    n_cells = np.unique(cell_lbl[cell_lbl > 0]).size
    color_labels = color.label2rgb(cell_lbl, img, alpha=0.4, bg_label=0)

    _, axs = plt.subplots(1, 5, figsize=(15, 5))
    axs[0].matshow(img, cmap='Greys_r')
    axs[0].set_title('DHM image')
    # label2rgb yields an RGB image, so a colormap would be ignored here.
    axs[1].imshow(color_labels)
    axs[1].set_title(' Number of segmented cells = %i' % n_cells)
    mask_panels = (
        (axs[2], cell_mask, 'Cell body'),
        (axs[3], neurite_mask, 'Neurites'),
        (axs[4], cell_mask | neurite_mask, 'Neuronal network'),
    )
    for ax, mask, title in mask_panels:
        ax.matshow(mask, cmap='Greys_r')
        ax.set_title(title)
    for ax in axs:
        ax.axis('off')
    plt.show()


def predict_and_segment(test_path, cell_model_path, neurite_model_path, patchsize=128, overlap=0.1,
                        *, cell_merge_mode='avg', neurite_merge_mode='min', show=True):
    """Segment cell bodies and neurites in the first TIFF image found in *test_path*.

    Two models (cell and neurite) are run patch-wise over the image; their patch
    predictions are merged back to full size, turned into masks, and the cell
    mask is split into individual cells with ``wtr_shed``.

    Args:
        test_path: Folder to search for the input image; only the alphabetically
            first ``.tif``/``.tiff`` file is processed.
        cell_model_path: Path passed to ``load_model`` for the cell-body model.
        neurite_model_path: Path passed to ``load_model`` for the neurite model.
        patchsize: Patch size forwarded to ``EMPatches.extract_patches``.
        overlap: Patch overlap forwarded to ``EMPatches.extract_patches``.
        cell_merge_mode: ``merge_patches`` mode for the cell predictions.
        neurite_merge_mode: ``merge_patches`` mode for the neurite predictions.
        show: When true (the default), display a 5-panel summary figure.

    Returns:
        ``(cell_mask, neurite_mask)``: boolean arrays, true where a cell body /
        neurite was segmented.

    Raises:
        FileNotFoundError: if *test_path* contains no TIFF image.
    """
    cell_model = load_model(cell_model_path)
    neurite_model = load_model(neurite_model_path)

    # NOTE(review): assumes img_read(folder, filename) loads folder/filename -- confirm.
    img = img_read(test_path, _first_tif(test_path))

    emp = EMPatches()
    img_patches, indices = emp.extract_patches(img, patchsize=patchsize, overlap=overlap)

    cell_merged_img = _predict_and_merge(cell_model, emp, img_patches, indices, cell_merge_mode)
    neurite_merged_img = _predict_and_merge(neurite_model, emp, img_patches, indices,
                                            neurite_merge_mode)

    neurite_mask = Neurite_Mask(neurite_merged_img) > 0
    # Watershed splits the binary cell map into individually labelled cells.
    cell_lbl = wtr_shed(Cell_Mask(cell_merged_img))
    cell_mask = cell_lbl > 0

    if show:
        _show_segmentation(img, cell_lbl, cell_mask, neurite_mask)

    return cell_mask, neurite_mask