# Training a random forest classifier to predict whether segmented shapes are spindles or noise

In [None]:
import os
import numpy as np
import napari
from tifffile import imread, imwrite
from skimage import morphology
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from skimage.morphology import binary_erosion
from skimage.filters import threshold_otsu
from scipy import spatial
from skimage.measure import label, regionprops, regionprops_table
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import pandas as pd
import joblib

In [None]:

def wipe_layers(viewer_name):
    '''
    Empty a viewer by removing every layer it currently holds
    ---
    Parameters:
    viewer_name: the viewer object whose .layers collection should be emptied
    ---
    Returns:
    None
    '''
    layer_list = viewer_name.layers
    # keep discarding the first remaining layer until the collection is empty
    while layer_list:
        layer_list.remove(layer_list[0])

def remove_large_objects(labels_array: np.ndarray, max_size: int) -> np.ndarray:
    '''
    Zero out every labelled object whose pixel count exceeds max_size
    ---
    Parameters:
    labels_array: np.ndarray an ndArray of non-negative integer labels (required by np.bincount)
    max_size: int the largest pixel count an object may have and still be kept
    ---
    Returns:
    np.ndarray a copy of labels_array in which the oversized objects are set to 0
    '''
    # pixel count per label ID: entry i is the number of pixels carrying label i
    pixels_per_label = np.bincount(labels_array.ravel())
    # look up each pixel's own object size and flag the pixels of oversized objects
    oversized_pixels = pixels_per_label[labels_array] > max_size
    cleaned = np.copy(labels_array)
    cleaned[oversized_pixels] = 0
    return cleaned

def return_points(labels_array: np.ndarray, label_ID: int) -> np.ndarray:
    '''
    Collect the index coordinates of every element that carries a given label
    ---
    Parameters:
    labels_array: np.ndarray an ndArray of labels
    label_ID: int the label ID whose points should be returned
    ---
    Returns:
    points: np.ndarray an ndArray of shape (n, labels_array.ndim) where n is the number of
    points in the label and each column holds the index along one array axis, in axis order
    '''
    # one index array per axis, covering the elements that match the label
    per_axis_indices = np.nonzero(labels_array == label_ID)
    # place the per-axis index arrays side by side so that each row is one point
    return np.stack(per_axis_indices, axis=1)

def find_label_density(label_points: np.ndarray) -> float:
    '''
    Calculate the bounding box for a point cloud and return the density of points in the bounding box
    ---
    Parameters:
    label_points: np.ndarray the array of point coordinates for a given label, of shape
    (n, d): one row per point and one column per axis. Any number of axes is accepted
    ---
    Returns:
    np.nan if label_points contains no points
    density (float) the number of points in the label divided by the volume of the
    axis-aligned bounding box of the points
    '''
    # a single point passed as a flat array is treated as one row
    points = np.atleast_2d(np.asarray(label_points))

    # an empty label has no bounding box; report nan instead of failing inside min/max
    if points.size == 0:
        return np.nan

    num_points = points.shape[0]
    # extent of the bounding box along every axis. Add 1 so that the box is inclusive
    # of both end points, which also prevents division by 0 for flat point clouds
    ranges = points.max(axis=0) - points.min(axis=0) + 1
    vol = np.prod(ranges)
    return num_points / vol

def print_label_props(source: np.ndarray, label_num: int) -> None:
    '''
    Report the point count and bounding-box density of one label in a mask
    ---
    Parameters:
    source: np.ndarray the mask that contains the label
    label_num: int the label whose properties should be printed
    ---
    Returns:
    None, the report is written to stdout
    '''
    coords = return_points(source, label_num)
    bbox_density = find_label_density(coords)
    # one row per point, so the row count is the size of the label
    n_points = coords.shape[0]
    print(f'Label {label_num} has:')
    print(f'{n_points:,} points.')
    print(f'density of {round(bbox_density,4):,}')

def view_saved_files(file_path: str) -> None:
    '''
    Open a new napari viewer and load the saved output files of a folder into it
    ---
    Parameters:
    file_path: str the folder whose files should be displayed
    ---
    Returns:
    None
    ---
    Hidden entries (names starting with '.') are ignored. For the remaining entries:
    .tif files with 'tub' or 'PI' in the name are added as hidden image layers,
    all other .tif files are added as label layers,
    .txt files are read with np.loadtxt and added as points (1-D) or as a line shape (2-D),
    anything else is reported on stdout and skipped.
    '''
    dedicated_file_viewer = napari.Viewer()
    for entry in os.listdir(file_path):
        if entry.startswith('.'):
            continue

        full_path = os.path.join(file_path, entry)
        # the layer is named after everything in front of the first dot
        layer_name = entry.split('.')[0]

        if entry.endswith('.tif'):
            pixels = imread(full_path)
            is_raw_channel = 'tub' in entry or 'PI' in entry
            if is_raw_channel:
                dedicated_file_viewer.add_image(pixels, name=layer_name, blending='additive', visible=False)
            else:
                dedicated_file_viewer.add_labels(pixels, name=layer_name, blending='additive')
        elif entry.endswith('.txt'):
            nums = np.loadtxt(full_path)
            if nums.ndim == 1:
                dedicated_file_viewer.add_points(nums, name=layer_name, face_color='white', blending='additive')
            elif nums.ndim == 2:
                dedicated_file_viewer.add_shapes(nums, shape_type='line', name=layer_name, edge_color='white', blending='additive')
        else:
            print(f'file "{entry}" not imported to viewer')


In [None]:
# IPython magic (not plain Python): enable Qt event loop integration for this session
%gui qt
# open a napari viewer and keep it in a notebook-level name
# NOTE(review): propsviewer is not referenced again in the visible part of the notebook
propsviewer = napari.Viewer()

In [None]:
# this folder contains raw data cubes and ground truth masks (or no masks, if no good ones were found)
main_folder = os.path.join(os.getcwd(), 'raw_training_data')

# list of per-cell subfolders. Hidden entries and stray files are skipped because every
# entry is opened below as a folder of .tif cubes
subfolders = [f for f in os.listdir(main_folder)
              if not f.startswith('.') and os.path.isdir(os.path.join(main_folder, f))]

# names of the data files expected inside every cell folder
valid_mask_name = 'thresh_mask.tif'
dog_tub_name = 'curr_tub_cube.tif'
cell_mask_name = 'curr_mask_cube.tif'
eroded_mask_name = 'eroded_mask.tif'

# feature columns handed to the classifier (in this order) and the class column
feature_columns = ['area', 'axis_major_length', 'axis_minor_length', 'dist_to_cell', 'density']
class_column = 'spindle'

# list to hold all the properties for each cell cube in the training set
label_properties = []

for curr_cell_num in tqdm(subfolders):

    # get the path to the relevant cell folder
    base_dir = os.path.join(main_folder, curr_cell_num)

    # load the data
    valid_mask = imread(os.path.join(base_dir, valid_mask_name))
    cubed_tub = imread(os.path.join(base_dir, dog_tub_name))
    cubed_label = imread(os.path.join(base_dir, cell_mask_name)).astype('bool')
    eroded_mask = imread(os.path.join(base_dir, eroded_mask_name)).astype('bool')

    # get the tubulin signal from the eroded mask region and define an Otsu threshold
    remaining_tub = np.zeros(shape=cubed_tub.shape)
    remaining_tub[eroded_mask] = cubed_tub[eroded_mask]
    remaining_vals = cubed_tub[eroded_mask].ravel()
    thresh_val = threshold_otsu(remaining_vals)
    # NOTE(review): everything outside the eroded mask is 0 in remaining_tub, so it would
    # pass this comparison if thresh_val were negative. Confirm the tubulin cube is non-negative
    thresh_mask = label(remaining_tub > thresh_val)

    # get the label properties. 'centroid' and 'extent' are computed per region by
    # regionprops, so the volume does not have to be rescanned once per label and the
    # background (label 0) is never measured
    props = regionprops_table(thresh_mask, properties=('area',
                                                    'axis_major_length',
                                                    'axis_minor_length',
                                                    'label',
                                                    'centroid',
                                                    'extent'))
    props_df = pd.DataFrame(props)

    # attempt to find the ground truth label value (the smallest non-zero value in the
    # ground truth mask). If none exists, assign to None
    truth_ids = np.unique(valid_mask)
    truth_ids = truth_ids[truth_ids != 0]
    valid_ID = truth_ids[0] if truth_ids.size else None

    # make a new column named "spindle": 1 where the label equals the ground truth ID,
    # otherwise 0 (all 0 when there is no ground truth ID)
    props_df[class_column] = (props_df['label'] == valid_ID).astype(int)

    # get the mask coordinates and centroid
    mask_coords = np.column_stack(np.where(cubed_label))
    cell_centroid = mask_coords.mean(axis=0)

    # distance from every label centroid to the cell centroid. regionprops_table returns
    # the centroid as one 'centroid-<axis>' column per axis, in axis order
    centroid_columns = [c for c in props_df.columns if c.startswith('centroid-')]
    centroid_offsets = props_df[centroid_columns].to_numpy() - cell_centroid
    props_df['dist_to_cell'] = np.linalg.norm(centroid_offsets, axis=1)

    # 'extent' is the label's pixel count divided by the size of its bounding box, which
    # is the same quantity that find_label_density computes
    props_df['density'] = props_df['extent']

    # keep only the model columns, convert to list of dicts and add to label_properties list
    label_properties.extend(props_df[feature_columns + [class_column]].to_dict('records'))

# merge all data into one dataframe
df = pd.DataFrame(label_properties, columns=feature_columns + [class_column])

# define the properties and classes
X = df[feature_columns].values
y = df[class_column].values

# split into training and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# define and train the classifier
classifier = RandomForestClassifier(n_estimators=100, random_state=42)
classifier.fit(X_train, Y_train)

# evaluate on the held-out test set
y_pred = classifier.predict(X_test)
a = accuracy_score(Y_test, y_pred)
print(f'accuracy: {a}')
print(classification_report(Y_test, y_pred))
print(confusion_matrix(Y_test, y_pred))

# save the trained classifier next to the notebook
cwd = os.getcwd()
joblib.dump(classifier, os.path.join(cwd, 'spindle_classifier.joblib'))

# piecewise: the same pipeline as above, run on a single cell folder whose name is typed in by the user

In [None]:
# this folder contains raw data cubes and ground truth masks (or no masks, if no good ones were found)
main_folder = os.path.join(os.getcwd(), 'raw_training_data')

# list of subfolders (used to tell the user which names are available)
subfolders = [f for f in os.listdir(main_folder) if not f.startswith('.')]

# feature columns handed to the classifier (in this order) and the class column
feature_columns = ['area', 'axis_major_length', 'axis_minor_length', 'dist_to_cell', 'density']
class_column = 'spindle'

# list to hold all the properties for the cell cube
label_properties = []

# name of the single cell folder to process, typed in by the user
curr_cell_num = input().strip()

# get the path to the relevant cell folder and the data files
base_dir = os.path.join(main_folder, curr_cell_num)
if not os.path.isdir(base_dir):
    # fail early with the available names instead of a bare error from imread
    raise FileNotFoundError(f'cell folder "{curr_cell_num}" not found in {main_folder}; '
                            f'available folders: {sorted(subfolders)}')
valid_mask_name = 'thresh_mask.tif'
dog_tub_name = 'curr_tub_cube.tif'
cell_mask_name = 'curr_mask_cube.tif'
eroded_mask_name = 'eroded_mask.tif'

# load the data
valid_mask = imread(os.path.join(base_dir, valid_mask_name))
cubed_tub = imread(os.path.join(base_dir, dog_tub_name))
cubed_label = imread(os.path.join(base_dir, cell_mask_name)).astype('bool')
eroded_mask = imread(os.path.join(base_dir, eroded_mask_name)).astype('bool')

# get the tubulin signal from the eroded mask region and define an Otsu threshold
remaining_tub = np.zeros(shape=cubed_tub.shape)
remaining_tub[eroded_mask] = cubed_tub[eroded_mask]
remaining_vals = cubed_tub[eroded_mask].ravel()
thresh_val = threshold_otsu(remaining_vals)
# NOTE(review): everything outside the eroded mask is 0 in remaining_tub, so it would
# pass this comparison if thresh_val were negative. Confirm the tubulin cube is non-negative
thresh_mask = label(remaining_tub > thresh_val)

# get the label properties. 'centroid' and 'extent' are computed per region by
# regionprops, so the volume does not have to be rescanned once per label and the
# background (label 0) is never measured
props = regionprops_table(thresh_mask, properties=('area',
                                                'axis_major_length',
                                                'axis_minor_length',
                                                'label',
                                                'centroid',
                                                'extent'))
props_df = pd.DataFrame(props)

# attempt to find the ground truth label value (the smallest non-zero value in the
# ground truth mask). If none exists, assign to None
truth_ids = np.unique(valid_mask)
truth_ids = truth_ids[truth_ids != 0]
valid_ID = truth_ids[0] if truth_ids.size else None

# make a new column named "spindle": 1 where the label equals the ground truth ID,
# otherwise 0 (all 0 when there is no ground truth ID)
props_df[class_column] = (props_df['label'] == valid_ID).astype(int)

# get the mask coordinates and centroid
mask_coords = np.column_stack(np.where(cubed_label))
cell_centroid = mask_coords.mean(axis=0)

# distance from every label centroid to the cell centroid. regionprops_table returns
# the centroid as one 'centroid-<axis>' column per axis, in axis order
centroid_columns = [c for c in props_df.columns if c.startswith('centroid-')]
centroid_offsets = props_df[centroid_columns].to_numpy() - cell_centroid
props_df['dist_to_cell'] = np.linalg.norm(centroid_offsets, axis=1)

# 'extent' is the label's pixel count divided by the size of its bounding box, which
# is the same quantity that find_label_density computes
props_df['density'] = props_df['extent']

# keep only the model columns, convert to list of dicts and add to label_properties list
label_properties.extend(props_df[feature_columns + [class_column]].to_dict('records'))

# merge all data into one dataframe
df = pd.DataFrame(label_properties, columns=feature_columns + [class_column])

# define the properties and classes
X = df[feature_columns].values
y = df[class_column].values

# split into training and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# define and train the classifier
classifier = RandomForestClassifier(n_estimators=100, random_state=42)
classifier.fit(X_train, Y_train)

# evaluate on the held-out test set
y_pred = classifier.predict(X_test)
a = accuracy_score(Y_test, y_pred)
print(f'accuracy: {a}')
print(classification_report(Y_test, y_pred))
print(confusion_matrix(Y_test, y_pred))

# save the trained classifier next to the notebook
cwd = os.getcwd()
joblib.dump(classifier, os.path.join(cwd, 'spindle_classifier.joblib'))