# 2D Nuclear Segmentation with FeatureNet

In [1]:
# This notebook is for benchmarking the output of the original trained featurenet model
import os
import errno

import numpy as np

import deepcell

In [2]:
# Locations used by this set of experiments. MODEL_DIR holds the trained
# weights; NPZ_DIR holds the frozen test splits.
experiment_folder = "featurenet_samir/"
MODEL_DIR = os.path.join("/data/analyses/", experiment_folder)
NPZ_DIR = "/data/npz_data/20201018_freeze/"
LOG_DIR = '/data/logs'

# Create the model directory on first run; leave it alone if it already exists.
os.makedirs(MODEL_DIR, exist_ok=True)

In [5]:
from tensorflow.keras.optimizers import SGD, Adam
from deepcell.utils.train_utils import rate_scheduler
from deepcell.utils.retinanet_anchor_utils import get_anchor_parameters
from deepcell.training import train_model_retinanet
from deepcell import model_zoo
from deepcell_toolbox.multiplex_utils import multiplex_preprocess
from timeit import default_timer
from skimage.measure import label
from deepcell_toolbox.deep_watershed import deep_watershed_mibi



def calc_jaccard_index_object(metric_predictions, true_labels, pred_labels):
    """Compute the Jaccard index (IoU) of every correctly matched object pair.

    Args:
        metric_predictions: per-image matching results. For image ``i``,
            ``metric_predictions[i][0]['correct']`` must hold two parallel
            sequences, ``'y_true'`` and ``'y_pred'``, of matched label ids
            (in this notebook it is ``DatasetBenchmarker.metrics.predictions``).
        true_labels: 4D array of ground-truth label images; channel 0 is used.
        pred_labels: 4D array of predicted label images; channel 0 is used.

    Returns:
        list of lists: one inner list per image, holding the IoU of each
        matched (true id, predicted id) pair in matching order.

    Raises:
        ValueError: if the matched true and predicted id lists of an image
            have different lengths.
    """
    jacc_list = []
    for image_idx in range(true_labels.shape[0]):
        y_true = true_labels[image_idx, :, :, 0]
        y_pred = pred_labels[image_idx, :, :, 0]
        matched = metric_predictions[image_idx][0]['correct']
        true_ids = matched['y_true']
        pred_ids = matched['y_pred']

        # The ids are pairwise matches, so both lists must line up exactly.
        if len(true_ids) != len(pred_ids):
            raise ValueError(
                'Image {}: {} matched true ids but {} matched predicted ids'.format(
                    image_idx, len(true_ids), len(pred_ids)))

        image_jacc = []
        for true_id, pred_id in zip(true_ids, pred_ids):
            true_mask = y_true == true_id
            pred_mask = y_pred == pred_id
            intersection = np.logical_and(true_mask, pred_mask).sum()
            union = np.logical_or(true_mask, pred_mask).sum()
            image_jacc.append(intersection / union)

        jacc_list.append(image_jacc)
    return jacc_list


# Evaluate one pretrained FeatureNet checkpoint on the test set of each data
# split. For each split, collect per-tissue and per-platform detection stats
# and the mean per-object Jaccard index of correctly matched cells.
#
# NOTE: `DatasetBenchmarker` is defined in a later cell of this notebook
# (executed as In [4]). Run that cell before this one, or this cell raises
# NameError on a fresh kernel.


def _relabel_connected(label_stack):
    """Relabel channel 0 of every image in a 4D label stack, in place.

    Uses ``skimage.measure.label``, so disconnected fragments that share an id
    become separate objects. Ground truth and predictions are then compared
    under the same labeling rules. Returns the same (mutated) array.
    """
    for image_idx in range(label_stack.shape[0]):
        label_stack[image_idx, :, :, 0] = label(label_stack[image_idx, :, :, 0])
    return label_stack


model_splits = ['1', '2', '3']
# A single checkpoint is evaluated on every split (no per-split weights).
model_name = 'featurenet_samir.h5'
metrics = {}
for split in model_splits:
    print('loading data')
    test_name = "20201018_multiplex_seed_{}_test_256x256.npz".format(split)
    # Read every array up front so the .npz file handle is closed promptly.
    with np.load(os.path.join(NPZ_DIR, test_name)) as test_dict:
        X_test = test_dict['X'][..., :1]  # keep only the first input channel
        y_test = test_dict['y']
        tissue_list = test_dict['tissue_list']
        platform_list = test_dict['platform_list']
    # multiplex_preprocess is intentionally not applied to X_test here.

    # Timing covers model construction, weight loading, inference and watershed.
    time_start = default_timer()
    print('creating model')
    model = model_zoo.bn_feature_net_skip_2D(
        receptive_field=61,
        n_skips=3,
        n_features=3,
        norm_method='whole_image',
        n_conv_filters=32,
        n_dense_filters=128,
        last_only=False,
        input_shape=(256, 256, 1))

    model.load_weights(os.path.join(MODEL_DIR, model_name))
    print('predicting')
    # last_only=False, so predict() is indexed; the final output ([-1]) is used.
    pixelwise = model.predict(X_test)[-1]
    print('postprocessing')
    # Channel 1 of the pixelwise output is fed as both watershed inputs
    # (presumably the interior-probability channel; TODO confirm).
    interior = pixelwise[:, :, :, 1:2]
    labeled_images = deep_watershed_mibi(
        {'inner-distance': interior, 'pixelwise-interior': interior},
        maxima_threshold=0.3, maxima_model_smooth=0,
        interior_threshold=0.3, interior_model_smooth=0,
        radius=3,
        small_objects_threshold=10,
        fill_holes_threshold=10,
        pixel_expansion=3)

    # end time
    time_end = default_timer()
    print("elapsed time is {}".format(time_end - time_start))

    labeled_images = _relabel_connected(labeled_images)
    y_test = _relabel_connected(y_test)

    # calculating accuracy
    print("calculating accuracy")
    db = DatasetBenchmarker(y_true=y_test,
                            y_pred=labeled_images,
                            tissue_list=tissue_list,
                            platform_list=platform_list,
                            model_name='default_model')
    tissue_stats, platform_stats = db.benchmark()

    jacc = calc_jaccard_index_object(db.metrics.predictions, y_test, labeled_images)
    jacc = np.concatenate(jacc)
    jacc_mean = np.mean(jacc)
    print(jacc_mean)
    metrics[split] = {'tissue_stats': tissue_stats, 'platform_stats': platform_stats, 'jacc': jacc_mean}

    

loading data
creating model
predicting
postprocessing
elapsed time is 121.75926812099988
calculating accuracy

____________Object-based statistics____________

Number of true cells:		 139873
Number of predicted cells:	 100208

Correct detections:  69813	Recall: 49.9117%
Incorrect detections: 30395	Precision: 69.6681%

Gained detections: 14240	Perc Error: 22.6057%
Missed detections: 34510	Perc Error: 54.7839%
Merges: 12506		Perc Error: 19.853%
Splits: 726		Perc Error: 1.1525%
Catastrophes: 1011		Perc Error: 1.6049%

Gained detections from splits: 765
Missed detections from merges: 19174
True detections involved in catastrophes: 1621
Predicted detections involved in catastrophes: 1134 

Average Pixel IOU (Jaccard Index): 0.6942 

uid is breast
uid is gi
uid is immune
uid is lung
uid is pancreas
uid is skin
uid is codex
uid is cycif
uid is imc
uid is mibi
uid is mxif
uid is vectra
uid is all
0.7688963270910291
loading data
creating model
predicting
postprocessing
elapsed time is 124.19915

  label_image = remove_small_objects(label_image, min_size=small_objects_threshold)


elapsed time is 121.51146884599984
calculating accuracy

____________Object-based statistics____________

Number of true cells:		 139149
Number of predicted cells:	 100157

Correct detections:  69308	Recall: 49.8085%
Incorrect detections: 30849	Precision: 69.1994%

Gained detections: 14310	Perc Error: 22.9162%
Missed detections: 33726	Perc Error: 54.0091%
Merges: 12511		Perc Error: 20.0352%
Splits: 861		Perc Error: 1.3788%
Catastrophes: 1037		Perc Error: 1.6607%

Gained detections from splits: 939
Missed detections from merges: 19417
True detections involved in catastrophes: 1661
Predicted detections involved in catastrophes: 1196 

Average Pixel IOU (Jaccard Index): 0.7011 

uid is breast
uid is gi
uid is immune
uid is lung
uid is pancreas
uid is skin
uid is codex
uid is cycif
uid is imc
uid is mibi
uid is mxif
uid is vectra
uid is all
0.7678280266850324


In [8]:
# Overall (all-tissue) F1 score on the split-3 test set.
split3_all_tissues = metrics['3']['tissue_stats']['all']
split3_all_tissues['f1']

0.5792416404101861

In [9]:
# Persist the stats of every split, one archive entry per split key.
# The values are nested dicts, so numpy stores them as 0-d object arrays.
# Reload them with np.load(..., allow_pickle=True).
metrics_path = os.path.join('/data/analyses/', 'featurenet_metrics_samir_jacc.npz')
np.savez_compressed(metrics_path, **metrics)

In [4]:
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

from deepcell_toolbox.metrics import Metrics, stats_pixelbased
from scipy.stats import hmean


class DatasetBenchmarker(object):
    """Class to perform benchmarking across different tissue and platform types

    Args:
        y_true: true labels (4D array)
        y_pred: predicted labels (4D array with the same shape as y_true)
        tissue_list: list of tissue names for each image
        platform_list: list of platform names for each image
        model_name: name of the model used to generate the predictions
        metrics_kwargs: optional dict of keyword arguments passed to ``Metrics``

    Raises:
        ValueError: if y_true and y_pred have different shapes
        ValueError: if y_true and y_pred are not 4D
        ValueError: if tissue_list or platform_list is not same length as labels
    """
    def __init__(self,
                 y_true,
                 y_pred,
                 tissue_list,
                 platform_list,
                 model_name,
                 metrics_kwargs=None):
        if y_true.shape != y_pred.shape:
            raise ValueError('Shape mismatch: y_true has shape {}, '
                             'y_pred has shape {}. Labels must have the same '
                             'shape.'.format(y_true.shape, y_pred.shape))
        if len(y_true.shape) != 4:
            raise ValueError('Data must be 4D, supplied data is {}'.format(y_true.shape))

        # Every image needs exactly one tissue label and one platform label.
        if len({y_true.shape[0], len(tissue_list), len(platform_list)}) != 1:
            raise ValueError('Tissue_list and platform_list must have same length as labels')

        self.y_true = y_true
        self.y_pred = y_pred
        self.tissue_list = tissue_list
        self.platform_list = platform_list
        self.model_name = model_name
        # A None default avoids one mutable dict shared across instances.
        self.metrics = Metrics(model_name, **(metrics_kwargs or {}))

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 when the denominator is 0."""
        return numerator / denominator if denominator else 0.0

    def _benchmark_category(self, category_ids):
        """Compute benchmark stats over the different categories in supplied list

        Args:
            category_ids: list specifying which category each image belongs to

        Returns:
            stats_dict: dictionary mapping each unique category id to a dict of
                per-image stats summed over that category, plus 'recall',
                'precision', 'f1' and pixel-level 'jaccard'
        """
        unique_ids = np.unique(category_ids)

        # create dict to hold stats across each category
        stats_dict = {}
        for uid in unique_ids:
            print("uid is {}".format(uid))
            category_idx = np.isin(category_ids, uid)

            # Sum metrics across individual images. This assumes each entry of
            # self.metrics.stats can be indexed by a per-image boolean mask
            # (TODO confirm against deepcell_toolbox.metrics.Metrics).
            uid_stats = {key: self.metrics.stats[key][category_idx].sum()
                         for key in self.metrics.stats}

            # Compute additional metrics not produced by the Metrics class.
            # Empty denominators give 0.0, so nan/inf never reaches the F1.
            recall = self._safe_ratio(uid_stats['correct_detections'],
                                      uid_stats['n_true'])
            precision = self._safe_ratio(uid_stats['correct_detections'],
                                         uid_stats['n_pred'])
            uid_stats['recall'] = recall
            uid_stats['precision'] = precision
            uid_stats['f1'] = (hmean([recall, precision])
                               if recall > 0 and precision > 0 else 0.0)

            pixel_stats = stats_pixelbased(self.y_true[category_idx] != 0,
                                           self.y_pred[category_idx] != 0)
            uid_stats['jaccard'] = pixel_stats['jaccard']

            stats_dict[uid] = uid_stats

        return stats_dict

    def benchmark(self):
        """Compute object stats, then aggregate them per tissue, per platform and overall.

        Returns:
            tuple: (tissue_stats, platform_stats). Each is a dict produced by
            ``_benchmark_category``, with an extra 'all' entry covering every
            image (the same dict object is shared by both).
        """
        self.metrics.calc_object_stats(self.y_true, self.y_pred)
        tissue_stats = self._benchmark_category(category_ids=self.tissue_list)
        platform_stats = self._benchmark_category(category_ids=self.platform_list)
        all_stats = self._benchmark_category(category_ids=['all'] * len(self.tissue_list))
        tissue_stats['all'] = all_stats['all']
        platform_stats['all'] = all_stats['all']

        return tissue_stats, platform_stats
