Skip to content

Commit

Permalink
Merge d191b9a into 8d75afc
Browse files (browse the repository at this point in the history)
  • Loading branch information
ngreenwald committed Aug 9, 2021
2 parents 8d75afc + d191b9a commit e83e48b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
43 changes: 42 additions & 1 deletion deepcell/applications/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,47 @@ def _resize_output(self, image, original_shape):

return image

def _batch_predict(self, tiles, batch_size):
    """Batch process tiles to generate model predictions.

    The built-in keras.predict function has support for batching, but
    loads the entire image stack into GPU memory, which is prohibitive
    for large images. This function instead calls ``self.model.predict``
    on one slice of at most ``batch_size`` tiles at a time and copies
    each result into preallocated NumPy arrays.

    Args:
        tiles: Tiled data which will be fed to model; tiles are indexed
            along the first axis.
        batch_size: Number of images to predict on per batch.

    Returns:
        list: One array per model output, each with ``tiles.shape[0]``
        entries along the first axis. A list is returned even when the
        model has a single output. The list is empty when ``tiles``
        contains no tiles.
    """
    # list to hold final output; filled in lazily once the first batch
    # reveals how many outputs the model has and what they look like
    output_tiles = []

    # loop through each batch
    for i in range(0, tiles.shape[0], batch_size):
        batch_inputs = tiles[i:i + batch_size, ...]

        batch_outputs = self.model.predict(batch_inputs, batch_size=batch_size)

        # a model with a single output returns a bare array; normalize to
        # a list so single- and multi-output models share one code path
        if not isinstance(batch_outputs, list):
            batch_outputs = [batch_outputs]

        # initialize output list with empty arrays to hold all batches.
        # Each buffer takes the dtype of the prediction it stores (not the
        # dtype of the input tiles) so that e.g. float predictions made
        # from integer tiles are not truncated on assignment.
        if not output_tiles:
            output_tiles = [
                np.zeros((tiles.shape[0],) + batch_out.shape[1:],
                         dtype=batch_out.dtype)
                for batch_out in batch_outputs
            ]

        # save each batch to corresponding index in output list
        for j, batch_out in enumerate(batch_outputs):
            output_tiles[j][i:i + batch_size, ...] = batch_out

    return output_tiles

def _run_model(self,
image,
batch_size=4,
Expand All @@ -337,7 +378,7 @@ def _run_model(self,

# Run images through model
t = timeit.default_timer()
output_tiles = self.model.predict(tiles, batch_size=batch_size)
output_tiles = self._batch_predict(tiles=tiles, batch_size=batch_size)
self.logger.debug('Model inference finished in %s s',
timeit.default_timer() - t)

Expand Down
33 changes: 32 additions & 1 deletion deepcell/applications/application_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from tensorflow.keras.layers import Input
from tensorflow.python.platform import test
from unittest.mock import Mock

from deepcell.applications import Application

Expand Down Expand Up @@ -243,13 +244,43 @@ def _format_model_output(Lx):
y = app._format_model_output(x)
self.assertAllEqual(x, y['inner-distance'])

def test_batch_predict(self):
    """Check ``_batch_predict`` output structure and number of model calls.

    ``model.predict`` is replaced by a ``Mock`` whose side effect returns
    random arrays shaped like its input, so the test can count how many
    times the model was invoked for each combination of image count and
    number of prediction heads.
    """
    batch_size = 4

    def make_predict(num_heads):
        # Build the stub through a factory so each stub is bound to its
        # own head count rather than to the loop variable.
        def predict(x, batch_size=4):
            heads = [np.random.rand(*x.shape) for _ in range(num_heads)]
            # Return a bare array for a single head so the branch of
            # _batch_predict that wraps non-list outputs is exercised.
            return heads[0] if num_heads == 1 else heads
        return predict

    for num_images in [4, 8, 10]:
        for num_pred_heads in [1, 2]:
            model = DummyModel(n_out=num_pred_heads)
            app = Application(model)

            x = np.random.rand(num_images, 128, 128, 1)

            app.model.predict = Mock(side_effect=make_predict(num_pred_heads))
            y = app._batch_predict(x, batch_size=batch_size)

            # one output array per prediction head, each covering all images
            self.assertEqual(len(y), num_pred_heads)
            for head_output in y:
                self.assertEqual(x.shape, head_output.shape)

            # integer ceiling division: one call per (possibly partial) batch
            expected_calls = -(-num_images // batch_size)
            self.assertEqual(app.model.predict.call_count, expected_calls)

def test_run_model(self):
    """Check ``_run_model`` output shapes for a ``DummyModel`` with ``n_out=2``.

    Both entries of the result are expected to have the same shape as
    the input image.
    """
    # Build the model once; an earlier default-constructed DummyModel was
    # immediately overwritten and never used.
    model = DummyModel(n_out=2)
    app = Application(model)

    x = np.random.rand(1, 128, 128, 1)
    y = app._run_model(x)
    for head in (0, 1):
        self.assertEqual(x.shape, y[head].shape)

def test_predict_segmentation(self):
model = DummyModel()
Expand Down

0 comments on commit e83e48b

Please sign in to comment.