Skip to content

Commit

Permalink
Update Multiplex model from 4 heads to 2. (#130)
Browse files Browse the repository at this point in the history
* Edit the expected multiplex model outputs from 4 predictions to 2.

* Upgrade `deepcell-toolbox` to `0.8.0` which includes the multiplex wrapper functions.

* Use the `multiplex-preprocess` and `multiplex-postprocess` from the new toolbox version.

Co-authored-by: willgraf <7930703+willgraf@users.noreply.github.com>
  • Loading branch information
ngreenwald and willgraf committed Sep 14, 2020
1 parent 8c9b0d3 commit fc537b8
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 14 deletions.
4 changes: 2 additions & 2 deletions redis_consumer/consumers/multiplex_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _consume(self, redis_hash):
image = np.expand_dims(image, axis=0) # add in the batch dim

# Preprocess image
image = self.preprocess(image, ['histogram_normalization'])
image = self.preprocess(image, ['multiplex_preprocess'])

# Send data to the model
self.update_key(redis_hash, {'status': 'predicting'})
Expand All @@ -128,7 +128,7 @@ def _consume(self, redis_hash):
# Post-process model results
self.update_key(redis_hash, {'status': 'post-processing'})
image = processing.format_output_multiplex(image)
image = self.postprocess(image, ['deep_watershed_subcellular'])
image = self.postprocess(image, ['multiplex_postprocess'])

# Save the post-processed results to a file
_ = timeit.default_timer()
Expand Down
6 changes: 1 addition & 5 deletions redis_consumer/consumers/multiplex_consumer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,11 @@ def make_grpc_image(model_shape=(-1, 256, 256, 2)):

def grpc(data, *args, **kwargs):
inner = np.random.random((1,) + shape + (1,))
outer = np.random.random((1,) + shape + (1,))
fgbg = np.random.random((1,) + shape + (2,))
feature = np.random.random((1,) + shape + (3,))

inner2 = np.random.random((1,) + shape + (1,))
outer2 = np.random.random((1,) + shape + (1,))
fgbg2 = np.random.random((1,) + shape + (2,))
feature2 = np.random.random((1,) + shape + (3,))
return [inner, outer, fgbg, feature, inner2, outer2, fgbg2, feature2]
return [inner, feature, inner2, feature2]
return grpc

image_shapes = [
Expand Down
6 changes: 3 additions & 3 deletions redis_consumer/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@
from deepcell_toolbox.deep_watershed import deep_watershed

# import mibi pre- and post-processing functions
from deepcell_toolbox.deep_watershed import deep_watershed_mibi
from deepcell_toolbox.deep_watershed import format_output_multiplex
from deepcell_toolbox.deep_watershed import deep_watershed_subcellular
from deepcell_toolbox.processing import phase_preprocess
from deepcell_toolbox.multiplex_utils import format_output_multiplex
from deepcell_toolbox.multiplex_utils import multiplex_preprocess
from deepcell_toolbox.multiplex_utils import multiplex_postprocess

from deepcell_toolbox import retinanet_semantic_to_label_image
from deepcell_toolbox import retinanet_to_label_image
Expand Down
6 changes: 3 additions & 3 deletions redis_consumer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def _strip(x):
'pre': {
'normalize': processing.normalize,
'histogram_normalization': processing.phase_preprocess,
'multiplex_preprocess': processing.multiplex_preprocess
},
'post': {
'deepcell': processing.pixelwise, # TODO: this is deprecated.
Expand All @@ -130,8 +131,7 @@ def _strip(x):
'retinanet': processing.retinanet_to_label_image,
'retinanet-semantic': processing.retinanet_semantic_to_label_image,
'deep_watershed': processing.deep_watershed,
'multiplex': processing.deep_watershed_mibi,
'deep_watershed_subcellular': processing.deep_watershed_subcellular,
'multiplex_postprocess': processing.multiplex_postprocess,
},
}

Expand Down Expand Up @@ -162,7 +162,7 @@ def _strip(x):
LABEL_DETECT_ENABLED = config('LABEL_DETECT_ENABLED', default=False, cast=bool)

# Multiplex model Settings
MULTIPLEX_MODEL = config('MULTIPLEX_MODEL', default='MultiplexSegmentation:3', cast=str)
MULTIPLEX_MODEL = config('MULTIPLEX_MODEL', default='MultiplexSegmentation:4', cast=str)

# Set default models based on label type
MODEL_CHOICES = {
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ grpcio==1.27.2
dict-to-protobuf==0.0.3.9
pytz==2019.1
deepcell-tracking==0.2.6
deepcell-toolbox==0.6.2
deepcell-toolbox>=0.8.0

0 comments on commit fc537b8

Please sign in to comment.