Skip to content

Commit

Permalink
Merge b6c2a77 into a1aad5b
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Aug 29, 2019
2 parents a1aad5b + b6c2a77 commit 81f1571
Show file tree
Hide file tree
Showing 11 changed files with 1,847 additions and 6 deletions.
1 change: 1 addition & 0 deletions deepcell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from __future__ import division
from __future__ import print_function

from deepcell import applications
from deepcell import datasets
from deepcell import layers
from deepcell import losses
Expand Down
37 changes: 37 additions & 0 deletions deepcell/applications/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deepcell Applications - Pre-trained models for specific functions"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from deepcell.applications.label_detection import LabelDetectionModel
from deepcell.applications.scale_detection import ScaleDetectionModel

del absolute_import
del division
del print_function
105 changes: 105 additions & 0 deletions deepcell/applications/label_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classify the type of an input image to send the data to the correct model"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import keras

try:
from tensorflow.python.keras.utils.data_utils import get_file
except ImportError: # tf v1.9 moves conv_utils from _impl to keras.utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file

from deepcell.layers import ImageNormalization2D, TensorProduct
from deepcell.utils.backbone_utils import get_backbone


WEIGHTS_PATH = ('https://deepcell-data.s3-us-west-1.amazonaws.com/'
'model-weights/LabelDetectionModel_VGG16.h5')


def LabelDetectionModel(input_shape=(None, None, 1),
inputs=None,
backbone='VGG16',
use_pretrained_weights=True):
"""Classify a microscopy image as Nuclear, Cytoplasm, or Phase.
This can be helpful in determining the type of data (nuclear, cytoplasm,
etc.) so that this data can be forwared to the correct segmenation model.
"""
required_channels = 3 # required for most backbones

if inputs is None:
inputs = keras.layers.Input(shape=input_shape)

if keras.backend.image_data_format() == 'channels_first':
channel_axis = 0
else:
channel_axis = -1

norm = ImageNormalization2D(norm_method='whole_image')(inputs)
fixed_inputs = TensorProduct(required_channels)(norm)

# force the input shape
fixed_input_shape = list(input_shape)
fixed_input_shape[channel_axis] = required_channels
fixed_input_shape = tuple(fixed_input_shape)

backbone_model = get_backbone(
backbone,
fixed_inputs,
use_imagenet=False,
return_dict=False,
include_top=False,
weights=None,
input_shape=fixed_input_shape,
pooling=None)

x = keras.layers.AveragePooling2D(4)(backbone_model.outputs[0])
x = TensorProduct(256)(x)
x = TensorProduct(3)(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Activation('softmax')(x)

model = keras.Model(inputs=backbone_model.inputs, outputs=outputs)

if use_pretrained_weights:
if backbone.upper() == 'VGG16':
weights_path = get_file(
'LabelDetectionModel_{}.h5'.format(backbone),
WEIGHTS_PATH,
cache_subdir='models',
md5_hash='090a0de7a33dceff7ad690b3c9852938')
else:
raise ValueError('Backbone %s does not have a weights file.' %
backbone)

model.load_weights(weights_path)

return model
73 changes: 73 additions & 0 deletions deepcell/applications/label_detection_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LabelDetectionModel"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras.layers import Input
from tensorflow.python.platform import test

from deepcell.applications import LabelDetectionModel


class TestLabelDetectionModel(test.TestCase):

def test_label_detection_model(self):

valid_backbones = ['VGG16']
input_shape = (256, 256, 1) # channels will be set to 3

batch_shape = tuple([8] + list(input_shape))

X = np.random.random(batch_shape)

for backbone in valid_backbones:
with self.test_session(use_gpu=True):
inputs = Input(shape=input_shape)
model = LabelDetectionModel(
inputs=inputs,
backbone=backbone,
use_pretrained_weights=False)

y = model.predict(X)

assert y.shape[0] == X.shape[0]
assert len(y.shape) == 2

with self.test_session(use_gpu=True):
model = LabelDetectionModel(
input_shape=input_shape,
backbone=backbone,
use_pretrained_weights=False)

y = model.predict(X)

assert y.shape[0] == X.shape[0]
assert len(y.shape) == 2
105 changes: 105 additions & 0 deletions deepcell/applications/scale_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detect the scale of input data for rescaling for other models"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import keras

try:
from tensorflow.python.keras.utils.data_utils import get_file
except ImportError: # tf v1.9 moves conv_utils from _impl to keras.utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file

from deepcell.layers import ImageNormalization2D, TensorProduct
from deepcell.utils.backbone_utils import get_backbone


WEIGHTS_PATH = ('https://deepcell-data.s3-us-west-1.amazonaws.com/'
'model-weights/ScaleDetectionModel_VGG16.h5')


def ScaleDetectionModel(input_shape=(None, None, 1),
inputs=None,
backbone='VGG16',
use_pretrained_weights=True):
"""Create a ScaleDetectionModel for detecting scales of input data.
This enables data to be scaled appropriately for other segmentation models
which may not be resolution tolerant.
"""
required_channels = 3 # required for most backbones

if inputs is None:
inputs = keras.layers.Input(shape=input_shape)

if keras.backend.image_data_format() == 'channels_first':
channel_axis = 0
else:
channel_axis = -1

norm = ImageNormalization2D(norm_method='whole_image')(inputs)
fixed_inputs = TensorProduct(required_channels)(norm)

# force the input shape
fixed_input_shape = list(input_shape)
fixed_input_shape[channel_axis] = required_channels
fixed_input_shape = tuple(fixed_input_shape)

backbone_model = get_backbone(
backbone,
fixed_inputs,
use_imagenet=False,
return_dict=False,
include_top=False,
weights=None,
input_shape=fixed_input_shape,
pooling=None)

x = keras.layers.AveragePooling2D(4)(backbone_model.outputs[0])
x = TensorProduct(256)(x)
x = TensorProduct(1)(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Activation('relu')(x)

model = keras.Model(inputs=backbone_model.inputs, outputs=outputs)

if use_pretrained_weights:
if backbone.upper() == 'VGG16':
weights_path = get_file(
'ScaleDetectionModel_{}.h5'.format(backbone),
WEIGHTS_PATH,
cache_subdir='models',
md5_hash='ab23e35676ffcdf1c72d3804cc65ea1d')
else:
raise ValueError('Backbone %s does not have a weights file.' %
backbone)

model.load_weights(weights_path)

return model

0 comments on commit 81f1571

Please sign in to comment.