In [1]:
# This notebook benchmarks the model output from the StarDist pretrained nuclear model
import os
import errno

import numpy as np

In [2]:
# Folder layout for this set of experiments.
# NOTE: MODEL_DIR and NPZ_DIR keep their trailing slash because other cells
# build file paths with plain string concatenation (e.g. MODEL_DIR + 'seed_...').
experiment_folder = "stardist/"
MODEL_DIR = os.path.join("/data/analyses/", experiment_folder)
NPZ_DIR = "/data/npz_data/20201018_freeze/"
LOG_DIR = '/data/logs'

# exist_ok=True makes the creation idempotent and removes the window between a
# separate os.path.isdir() check and the makedirs() call. A path that exists
# but is not a directory still raises, exactly as before.
os.makedirs(MODEL_DIR, exist_ok=True)

In [3]:
from skimage.morphology import binary_dilation, square

In [8]:
print('loading data')
split = '1'
test_name = "20201018_multiplex_seed_{}_test_256x256.npz".format(split)
test_dict = np.load(NPZ_DIR + test_name)

y_test = test_dict['y']

# Label masks stored for this split under MODEL_DIR
# (presumably the StarDist nuclear predictions -- confirm against the export step).
original_labels = np.load(MODEL_DIR + 'seed_{}_labels.npz'.format(split))['y']
expanded_labels = np.zeros_like(original_labels)

# Grow every labelled object with a 7x7 square footprint.
# The footprint is loop-invariant, so it is built once.
footprint = square(7)

# Iterate over however many images the archive holds instead of a hard-coded
# count, and size the scratch array from the label plane itself instead of a
# fixed 256x256, so differently sized archives are handled completely.
for img in range(original_labels.shape[0]):
    current_label = original_labels[img, :, :, 0]
    expanded_label = np.zeros_like(current_label)
    # np.unique returns ids in ascending order, so where two dilated objects
    # overlap the larger id wins.
    for cell in np.unique(current_label):
        if cell != 0:  # 0 is skipped, i.e. treated as background
            mask = current_label == cell
            expanded_mask = binary_dilation(mask, footprint)
            expanded_label[expanded_mask] = cell
    expanded_labels[img, :, :, 0] = expanded_label

loading data


In [18]:
from tensorflow.keras.optimizers import SGD, Adam
from deepcell.utils.train_utils import rate_scheduler
from deepcell.utils.retinanet_anchor_utils import get_anchor_parameters
from deepcell.training import train_model_retinanet
from deepcell import model_zoo
from deepcell_toolbox.multiplex_utils import multiplex_preprocess
from timeit import default_timer
from skimage.measure import label
from deepcell_toolbox.deep_watershed import deep_watershed_mibi



def calc_jaccard_index_object(metric_predictions, true_labels, pred_labels):
    """Compute the Jaccard index (intersection over union) of matched objects.

    Args:
        metric_predictions: per-image sequence; ``metric_predictions[i][0]['correct']``
            must hold the paired id sequences ``'y_true'`` and ``'y_pred'`` for image ``i``.
        true_labels: ground-truth label array, indexed as ``[image, row, col, 0]``.
        pred_labels: predicted label array, indexed the same way as ``true_labels``.

    Returns:
        list: one inner list per image containing the Jaccard index of every
        matched (true id, predicted id) pair, in the order the ids are listed.
    """
    per_image_scores = []
    for img_idx in range(true_labels.shape[0]):
        truth = true_labels[img_idx, :, :, 0]
        prediction = pred_labels[img_idx, :, :, 0]
        matched_true = metric_predictions[img_idx][0]['correct']['y_true']
        matched_pred = metric_predictions[img_idx][0]['correct']['y_pred']

        scores = []
        # Pairs are positional: the k-th true id was matched to the k-th predicted id.
        for pair_idx, true_id in enumerate(matched_true):
            true_mask = truth == true_id
            pred_mask = prediction == matched_pred[pair_idx]

            overlap = np.logical_and(true_mask, pred_mask).sum()
            union = np.logical_or(true_mask, pred_mask).sum()
            scores.append(overlap / union)

        per_image_scores.append(scores)
    return per_image_scores


# Benchmark the stored label predictions of every data split against the
# ground truth and collect the per-tissue / per-platform statistics.
# NOTE: DatasetBenchmarker is defined in a separate cell of this notebook and
# must have been executed before this one.
model_splits = ['1', '2', '3']
metrics = {}

# 7x7 square footprint used to grow each labelled object; loop-invariant.
footprint = square(7)

for split in model_splits:
    print('loading data')
    test_name = "20201018_multiplex_seed_{}_test_256x256.npz".format(split)
    test_dict = np.load(NPZ_DIR + test_name)

    y_test = test_dict['y']

    original_labels = np.load(MODEL_DIR + 'seed_{}_labels.npz'.format(split))['y']
    expanded_labels = np.zeros_like(original_labels)

    for img in range(original_labels.shape[0]):
        current_label = original_labels[img, :, :, 0]
        # Size the scratch array from the label plane itself rather than a
        # hard-coded 256x256 so other image sizes work unchanged.
        expanded_label = np.zeros_like(current_label)
        # np.unique yields ids in ascending order, so where two dilated
        # objects overlap the larger id wins.
        for cell in np.unique(current_label):
            if cell != 0:  # 0 is skipped, i.e. treated as background
                mask = current_label == cell
                expanded_mask = binary_dilation(mask, footprint)
                expanded_label[expanded_mask] = cell

        expanded_labels[img, :, :, 0] = expanded_label

    # Re-run connected-component labelling on both predictions and ground
    # truth, image by image, before they are handed to the benchmarker.
    print("relabeling")
    for i in range(expanded_labels.shape[0]):
        expanded_labels[i, :, :, 0] = label(expanded_labels[i, :, :, 0])

    for i in range(y_test.shape[0]):
        y_test[i, :, :, 0] = label(y_test[i, :, :, 0])

    # calculating accuracy
    print("calculating accuracy")
    db = DatasetBenchmarker(y_true=y_test,
                            y_pred=expanded_labels,
                            tissue_list=test_dict['tissue_list'],
                            platform_list=test_dict['platform_list'],
                            model_name='default_model')
    tissue_stats, platform_stats = db.benchmark()

    # Mean object-level Jaccard index over all correctly matched objects of
    # every image in this split.
    jacc = calc_jaccard_index_object(db.metrics.predictions, y_test, expanded_labels)
    jacc = np.concatenate(jacc)
    jacc_mean = np.mean(jacc)
    print(jacc_mean)
    metrics[split] = {'tissue_stats':tissue_stats, 'platform_stats': platform_stats, 'jacc':jacc_mean}

    

loading data
relabeling
calculating accuracy

____________Object-based statistics____________

Number of true cells:		 139873
Number of predicted cells:	 110090

Correct detections:  60129	Recall: 42.9883%
Incorrect detections: 49961	Precision: 54.618%

Gained detections: 23980	Perc Error: 27.5145%
Missed detections: 48757	Perc Error: 55.9435%
Merges: 6031		Perc Error: 6.9199%
Splits: 3304		Perc Error: 3.791%
Catastrophes: 5082		Perc Error: 5.8311%

Gained detections from splits: 4160
Missed detections from merges: 7551
True detections involved in catastrophes: 2342
Predicted detections involved in catastrophes: 2234 

Average Pixel IOU (Jaccard Index): 0.7015 

uid is breast
uid is gi
uid is immune
uid is lung
uid is pancreas
uid is skin
uid is codex
uid is cycif
uid is imc
uid is mibi
uid is mxif
uid is vectra
uid is all
0.7229020495056567
loading data
relabeling
calculating accuracy

____________Object-based statistics____________

Number of true cells:		 146194
Number of predicted 

In [25]:
# Persist the collected metrics; every key of `metrics` becomes one named
# entry of the compressed archive.
# NOTE(review): the values are nested dicts, so they are presumably stored as
# pickled object arrays -- confirm how the file is loaded downstream.
metrics_path = os.path.join('/data/analyses/', 'stardist_metrics_jacc.npz')
np.savez_compressed(metrics_path, **metrics)

In [15]:
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

from deepcell_toolbox.metrics import Metrics, stats_pixelbased
from scipy.stats import hmean


class DatasetBenchmarker(object):
    """Class to perform benchmarking across different tissue and platform types

    Args:
        y_true: true labels
        y_pred: predicted labels
        tissue_list: list of tissue names for each image
        platform_list: list of platform names for each image
        model_name: name of the model used to generate the predictions
        metrics_kwargs: optional dict of keyword arguments to be passed to the
            metrics package. Defaults to None, which passes no extra arguments.

    Raises:
        ValueError: if y_true and y_pred have different shapes
        ValueError: if y_true and y_pred are not 4D
        ValueError: if tissue_list or platform_list is not same length as labels
    """
    def __init__(self,
                 y_true,
                 y_pred,
                 tissue_list,
                 platform_list,
                 model_name,
                 metrics_kwargs=None):
        if y_true.shape != y_pred.shape:
            raise ValueError('Shape mismatch: y_true has shape {}, '
                             'y_pred has shape {}. Labels must have the same '
                             'shape.'.format(y_true.shape, y_pred.shape))
        if len(y_true.shape) != 4:
            raise ValueError('Data must be 4D, supplied data is {}'.format(y_true.shape))

        self.y_true = y_true
        self.y_pred = y_pred

        # All three lengths collapse to a single set element only when the
        # number of images, tissue names and platform names agree.
        if len({y_true.shape[0], len(tissue_list), len(platform_list)}) != 1:
            raise ValueError('Tissue_list and platform_list must have same length as labels')

        # A None sentinel replaces the former mutable ``{}`` default so no
        # dict instance is shared between calls.
        if metrics_kwargs is None:
            metrics_kwargs = {}

        self.tissue_list = tissue_list
        self.platform_list = platform_list
        self.model_name = model_name
        self.metrics = Metrics(model_name, **metrics_kwargs)

    def _benchmark_category(self, category_ids):
        """Compute benchmark stats over the different categories in supplied list

        Must be called after ``self.metrics.calc_object_stats`` has populated
        ``self.metrics.stats`` (``benchmark`` takes care of this).

        Args:
            category_ids: list specifying which category each image belongs to

        Returns:
            stats_dict: dictionary of benchmarking results, keyed by category id
        """

        unique_ids = np.unique(category_ids)

        # create dict to hold stats across each category
        stats_dict = {}
        for uid in unique_ids:
            print("uid is {}".format(uid))
            stats_dict[uid] = {}
            # boolean selector of the images that belong to this category
            category_idx = np.isin(category_ids, uid)

            # sum metrics across individual images
            for key in self.metrics.stats:
                stats_dict[uid][key] = self.metrics.stats[key][category_idx].sum()

            # compute additional metrics not produced by Metrics class
            stats_dict[uid]['recall'] = \
                stats_dict[uid]['correct_detections'] / stats_dict[uid]['n_true']

            stats_dict[uid]['precision'] = \
                stats_dict[uid]['correct_detections'] / stats_dict[uid]['n_pred']

            # f1 is the harmonic mean of recall and precision
            stats_dict[uid]['f1'] = \
                hmean([stats_dict[uid]['recall'], stats_dict[uid]['precision']])

            # pixel-level stats compare foreground (non-zero) masks only
            pixel_stats = stats_pixelbased(self.y_true[category_idx] != 0,
                                           self.y_pred[category_idx] != 0)
            stats_dict[uid]['jaccard'] = pixel_stats['jaccard']

        return stats_dict

    def benchmark(self):
        """Run the object-level metrics and aggregate them per tissue and platform

        Returns:
            tuple: ``(tissue_stats, platform_stats)`` dictionaries keyed by
            tissue / platform name. Both additionally contain an ``'all'``
            entry with the stats aggregated over every image.
        """
        self.metrics.calc_object_stats(self.y_true, self.y_pred)
        tissue_stats = self._benchmark_category(category_ids=self.tissue_list)
        platform_stats = self._benchmark_category(category_ids=self.platform_list)
        # a single pseudo-category covering every image gives the overall stats
        all_stats = self._benchmark_category(category_ids=['all'] * len(self.tissue_list))
        tissue_stats['all'] = all_stats['all']
        platform_stats['all'] = all_stats['all']

        return tissue_stats, platform_stats
