Skip to content

Commit

Permalink
Merge f5a8ceb into 8d75afc
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreenwald committed Aug 10, 2021
2 parents 8d75afc + f5a8ceb commit 0579fb9
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 3 deletions.
43 changes: 42 additions & 1 deletion deepcell/applications/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,47 @@ def _resize_output(self, image, original_shape):

return image

def _batch_predict(self, tiles, batch_size):
"""Batch process tiles to generate model predictions.
The built-in keras.predict function has support for batching, but
loads the entire image stack into GPU memory, which is prohibitive
for large images. This function uses similar code to the underlying
model.predict function without soaking up GPU memory.
Args:
tiles (numpy.array): Tiled data which will be fed to model
batch_size (int): Number of images to predict on per batch
Returns:
list: Model outputs
"""

# list to hold final output
output_tiles = []

# loop through each batch
for i in range(0, tiles.shape[0], batch_size):
batch_inputs = tiles[i:i + batch_size, ...]

batch_outputs = self.model.predict(batch_inputs, batch_size=batch_size)

# model with only a single output gets temporarily converted to a list
if not isinstance(batch_outputs, list):
batch_outputs = [batch_outputs]

# initialize output list with empty arrays to hold all batches
if not output_tiles:
for batch_out in batch_outputs:
shape = (tiles.shape[0],) + batch_out.shape[1:]
output_tiles.append(np.zeros(shape, dtype=tiles.dtype))

# save each batch to corresponding index in output list
for j, batch_out in enumerate(batch_outputs):
output_tiles[j][i:i + batch_size, ...] = batch_out

return output_tiles

def _run_model(self,
image,
batch_size=4,
Expand All @@ -337,7 +378,7 @@ def _run_model(self,

# Run images through model
t = timeit.default_timer()
output_tiles = self.model.predict(tiles, batch_size=batch_size)
output_tiles = self._batch_predict(tiles=tiles, batch_size=batch_size)
self.logger.debug('Model inference finished in %s s',
timeit.default_timer() - t)

Expand Down
40 changes: 38 additions & 2 deletions deepcell/applications/application_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from __future__ import division
from __future__ import print_function

from itertools import product
from unittest.mock import Mock

import numpy as np

from tensorflow.keras.layers import Input
from tensorflow.python.platform import test

from deepcell.applications import Application
Expand Down Expand Up @@ -243,13 +245,47 @@ def _format_model_output(Lx):
y = app._format_model_output(x)
self.assertAllEqual(x, y['inner-distance'])

def test_batch_predict(self):

def predict1(x, batch_size=4):
y = np.random.rand(*x.shape)
return [y]

def predict2(x, batch_size=4):
y = np.random.rand(*x.shape)
return [y] * 2

num_images = [4, 8, 10]
num_pred_heads = [1, 2]
batch_sizes = [1, 4, 5]
prod = product(num_images, num_pred_heads, batch_sizes)

for num_image, num_pred_head, batch_size in prod:
model = DummyModel(n_out=num_pred_head)
app = Application(model)

x = np.random.rand(num_image, 128, 128, 1)

if num_pred_head == 1:
app.model.predict = Mock(side_effect=predict1)
else:
app.model.predict = Mock(side_effect=predict2)
y = app._batch_predict(x, batch_size=batch_size)

assert app.model.predict.call_count == np.ceil(num_image / batch_size)

self.assertEqual(x.shape, y[0].shape)
if num_pred_head == 2:
self.assertEqual(x.shape, y[1].shape)

def test_run_model(self):
model = DummyModel()
model = DummyModel(n_out=2)
app = Application(model)

x = np.random.rand(1, 128, 128, 1)
y = app._run_model(x)
self.assertEqual(x.shape, y[0].shape)
self.assertEqual(x.shape, y[1].shape)

def test_predict_segmentation(self):
model = DummyModel()
Expand Down

0 comments on commit 0579fb9

Please sign in to comment.